1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Use dynamic filters on app sessions by reusing the OAuth/compat sessions filters

This commit is contained in:
Quentin Gliech
2024-07-16 17:54:53 +02:00
parent 12d2f1f827
commit e89a818ff2

View File

@ -1,4 +1,4 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -17,20 +17,22 @@
use async_trait::async_trait;
use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState, UserAgent};
use mas_storage::{
app_session::{AppSession, AppSessionFilter, AppSessionRepository},
app_session::{AppSession, AppSessionFilter, AppSessionRepository, AppSessionState},
compat::CompatSessionFilter,
oauth2::OAuth2SessionFilter,
Page, Pagination,
};
use oauth2_types::scope::{Scope, ScopeToken};
use sea_query::{
Alias, ColumnRef, CommonTableExpression, Expr, PgFunc, PostgresQueryBuilder, Query, UnionType,
Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType,
};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{
errors::DatabaseInconsistencyError,
filter::StatementExt,
iden::{CompatSessions, OAuth2Sessions},
pagination::QueryBuilderExt,
DatabaseError, ExecuteExt,
@ -202,6 +204,44 @@ impl TryFrom<AppSessionLookup> for AppSession {
}
}
/// Split a [`AppSessionFilter`] into two separate filters: a
/// [`CompatSessionFilter`] and an [`OAuth2SessionFilter`].
fn split_filter(
filter: AppSessionFilter<'_>,
) -> (CompatSessionFilter<'_>, OAuth2SessionFilter<'_>) {
let mut compat_filter = CompatSessionFilter::new();
let mut oauth2_filter = OAuth2SessionFilter::new();
if let Some(user) = filter.user() {
compat_filter = compat_filter.for_user(user);
oauth2_filter = oauth2_filter.for_user(user);
}
match filter.state() {
Some(AppSessionState::Active) => {
compat_filter = compat_filter.active_only();
oauth2_filter = oauth2_filter.active_only();
}
Some(AppSessionState::Finished) => {
compat_filter = compat_filter.finished_only();
oauth2_filter = oauth2_filter.finished_only();
}
None => {}
}
if let Some(device) = filter.device() {
compat_filter = compat_filter.for_device(device);
oauth2_filter = oauth2_filter.for_device(device);
}
if let Some(browser_session) = filter.browser_session() {
compat_filter = compat_filter.for_browser_session(browser_session);
oauth2_filter = oauth2_filter.for_browser_session(browser_session);
}
(compat_filter, oauth2_filter)
}
#[async_trait]
impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
type Error = DatabaseError;
@ -220,6 +260,8 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
filter: AppSessionFilter<'_>,
pagination: Pagination,
) -> Result<Page<AppSession>, Self::Error> {
let (compat_filter, oauth2_filter) = split_filter(filter);
let mut oauth2_session_select = Query::select()
.expr_as(
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
@ -269,26 +311,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
AppSessionLookupIden::LastActiveIp,
)
.from(OAuth2Sessions::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
} else {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.browser_session().map(|browser_session| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
.eq(Uuid::from(browser_session.id))
}))
.and_where_option(filter.device().map(|device| {
Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col((
OAuth2Sessions::Table,
OAuth2Sessions::ScopeList,
))))
}))
.apply_filter(oauth2_filter)
.clone();
let compat_session_select = Query::select()
@ -340,23 +363,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
AppSessionLookupIden::LastActiveIp,
)
.from(CompatSessions::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
} else {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.browser_session().map(|browser_session| {
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
.eq(Uuid::from(browser_session.id))
}))
.and_where_option(filter.device().map(|device| {
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string())
}))
.apply_filter(compat_filter)
.clone();
let common_table_expression = CommonTableExpression::new()
@ -397,51 +404,17 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
err,
)]
async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error> {
let (compat_filter, oauth2_filter) = split_filter(filter);
let mut oauth2_session_select = Query::select()
.expr(Expr::cust("1"))
.from(OAuth2Sessions::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
} else {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.browser_session().map(|browser_session| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
.eq(Uuid::from(browser_session.id))
}))
.and_where_option(filter.device().map(|device| {
Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col((
OAuth2Sessions::Table,
OAuth2Sessions::ScopeList,
))))
}))
.apply_filter(oauth2_filter)
.clone();
let compat_session_select = Query::select()
.expr(Expr::cust("1"))
.from(CompatSessions::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
} else {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.browser_session().map(|browser_session| {
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
.eq(Uuid::from(browser_session.id))
}))
.and_where_option(filter.device().map(|device| {
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string())
}))
.apply_filter(compat_filter)
.clone();
let common_table_expression = CommonTableExpression::new()