You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Use dynamic filters on app sessions by reusing the OAuth/compat sessions filters
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -17,20 +17,22 @@
|
||||
use async_trait::async_trait;
|
||||
use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState, UserAgent};
|
||||
use mas_storage::{
|
||||
app_session::{AppSession, AppSessionFilter, AppSessionRepository},
|
||||
app_session::{AppSession, AppSessionFilter, AppSessionRepository, AppSessionState},
|
||||
compat::CompatSessionFilter,
|
||||
oauth2::OAuth2SessionFilter,
|
||||
Page, Pagination,
|
||||
};
|
||||
use oauth2_types::scope::{Scope, ScopeToken};
|
||||
use sea_query::{
|
||||
Alias, ColumnRef, CommonTableExpression, Expr, PgFunc, PostgresQueryBuilder, Query, UnionType,
|
||||
Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType,
|
||||
};
|
||||
use sea_query_binder::SqlxBinder;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
errors::DatabaseInconsistencyError,
|
||||
filter::StatementExt,
|
||||
iden::{CompatSessions, OAuth2Sessions},
|
||||
pagination::QueryBuilderExt,
|
||||
DatabaseError, ExecuteExt,
|
||||
@ -202,6 +204,44 @@ impl TryFrom<AppSessionLookup> for AppSession {
|
||||
}
|
||||
}
|
||||
|
||||
/// Split a [`AppSessionFilter`] into two separate filters: a
|
||||
/// [`CompatSessionFilter`] and an [`OAuth2SessionFilter`].
|
||||
fn split_filter(
|
||||
filter: AppSessionFilter<'_>,
|
||||
) -> (CompatSessionFilter<'_>, OAuth2SessionFilter<'_>) {
|
||||
let mut compat_filter = CompatSessionFilter::new();
|
||||
let mut oauth2_filter = OAuth2SessionFilter::new();
|
||||
|
||||
if let Some(user) = filter.user() {
|
||||
compat_filter = compat_filter.for_user(user);
|
||||
oauth2_filter = oauth2_filter.for_user(user);
|
||||
}
|
||||
|
||||
match filter.state() {
|
||||
Some(AppSessionState::Active) => {
|
||||
compat_filter = compat_filter.active_only();
|
||||
oauth2_filter = oauth2_filter.active_only();
|
||||
}
|
||||
Some(AppSessionState::Finished) => {
|
||||
compat_filter = compat_filter.finished_only();
|
||||
oauth2_filter = oauth2_filter.finished_only();
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
if let Some(device) = filter.device() {
|
||||
compat_filter = compat_filter.for_device(device);
|
||||
oauth2_filter = oauth2_filter.for_device(device);
|
||||
}
|
||||
|
||||
if let Some(browser_session) = filter.browser_session() {
|
||||
compat_filter = compat_filter.for_browser_session(browser_session);
|
||||
oauth2_filter = oauth2_filter.for_browser_session(browser_session);
|
||||
}
|
||||
|
||||
(compat_filter, oauth2_filter)
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
@ -220,6 +260,8 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
||||
filter: AppSessionFilter<'_>,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<AppSession>, Self::Error> {
|
||||
let (compat_filter, oauth2_filter) = split_filter(filter);
|
||||
|
||||
let mut oauth2_session_select = Query::select()
|
||||
.expr_as(
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
|
||||
@ -269,26 +311,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
||||
AppSessionLookupIden::LastActiveIp,
|
||||
)
|
||||
.from(OAuth2Sessions::Table)
|
||||
.and_where_option(filter.user().map(|user| {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
|
||||
}))
|
||||
.and_where_option(filter.state().map(|state| {
|
||||
if state.is_active() {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
||||
} else {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
||||
}
|
||||
}))
|
||||
.and_where_option(filter.browser_session().map(|browser_session| {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
|
||||
.eq(Uuid::from(browser_session.id))
|
||||
}))
|
||||
.and_where_option(filter.device().map(|device| {
|
||||
Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col((
|
||||
OAuth2Sessions::Table,
|
||||
OAuth2Sessions::ScopeList,
|
||||
))))
|
||||
}))
|
||||
.apply_filter(oauth2_filter)
|
||||
.clone();
|
||||
|
||||
let compat_session_select = Query::select()
|
||||
@ -340,23 +363,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
||||
AppSessionLookupIden::LastActiveIp,
|
||||
)
|
||||
.from(CompatSessions::Table)
|
||||
.and_where_option(filter.user().map(|user| {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
|
||||
}))
|
||||
.and_where_option(filter.state().map(|state| {
|
||||
if state.is_active() {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
|
||||
} else {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
|
||||
}
|
||||
}))
|
||||
.and_where_option(filter.browser_session().map(|browser_session| {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
|
||||
.eq(Uuid::from(browser_session.id))
|
||||
}))
|
||||
.and_where_option(filter.device().map(|device| {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string())
|
||||
}))
|
||||
.apply_filter(compat_filter)
|
||||
.clone();
|
||||
|
||||
let common_table_expression = CommonTableExpression::new()
|
||||
@ -397,51 +404,17 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
||||
err,
|
||||
)]
|
||||
async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error> {
|
||||
let (compat_filter, oauth2_filter) = split_filter(filter);
|
||||
let mut oauth2_session_select = Query::select()
|
||||
.expr(Expr::cust("1"))
|
||||
.from(OAuth2Sessions::Table)
|
||||
.and_where_option(filter.user().map(|user| {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
|
||||
}))
|
||||
.and_where_option(filter.state().map(|state| {
|
||||
if state.is_active() {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
||||
} else {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
||||
}
|
||||
}))
|
||||
.and_where_option(filter.browser_session().map(|browser_session| {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
|
||||
.eq(Uuid::from(browser_session.id))
|
||||
}))
|
||||
.and_where_option(filter.device().map(|device| {
|
||||
Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col((
|
||||
OAuth2Sessions::Table,
|
||||
OAuth2Sessions::ScopeList,
|
||||
))))
|
||||
}))
|
||||
.apply_filter(oauth2_filter)
|
||||
.clone();
|
||||
|
||||
let compat_session_select = Query::select()
|
||||
.expr(Expr::cust("1"))
|
||||
.from(CompatSessions::Table)
|
||||
.and_where_option(filter.user().map(|user| {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
|
||||
}))
|
||||
.and_where_option(filter.state().map(|state| {
|
||||
if state.is_active() {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
|
||||
} else {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
|
||||
}
|
||||
}))
|
||||
.and_where_option(filter.browser_session().map(|browser_session| {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
|
||||
.eq(Uuid::from(browser_session.id))
|
||||
}))
|
||||
.and_where_option(filter.device().map(|device| {
|
||||
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string())
|
||||
}))
|
||||
.apply_filter(compat_filter)
|
||||
.clone();
|
||||
|
||||
let common_table_expression = CommonTableExpression::new()
|
||||
|
Reference in New Issue
Block a user