You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-06 06:02:40 +03:00
Use dynamic filters on app sessions by reusing the OAuth/compat sessions filters
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@@ -17,20 +17,22 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState, UserAgent};
|
use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState, UserAgent};
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
app_session::{AppSession, AppSessionFilter, AppSessionRepository},
|
app_session::{AppSession, AppSessionFilter, AppSessionRepository, AppSessionState},
|
||||||
|
compat::CompatSessionFilter,
|
||||||
|
oauth2::OAuth2SessionFilter,
|
||||||
Page, Pagination,
|
Page, Pagination,
|
||||||
};
|
};
|
||||||
use oauth2_types::scope::{Scope, ScopeToken};
|
use oauth2_types::scope::{Scope, ScopeToken};
|
||||||
use sea_query::{
|
use sea_query::{
|
||||||
Alias, ColumnRef, CommonTableExpression, Expr, PgFunc, PostgresQueryBuilder, Query, UnionType,
|
Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType,
|
||||||
};
|
};
|
||||||
use sea_query_binder::SqlxBinder;
|
use sea_query_binder::SqlxBinder;
|
||||||
use sqlx::PgConnection;
|
use sqlx::PgConnection;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
errors::DatabaseInconsistencyError,
|
errors::DatabaseInconsistencyError,
|
||||||
|
filter::StatementExt,
|
||||||
iden::{CompatSessions, OAuth2Sessions},
|
iden::{CompatSessions, OAuth2Sessions},
|
||||||
pagination::QueryBuilderExt,
|
pagination::QueryBuilderExt,
|
||||||
DatabaseError, ExecuteExt,
|
DatabaseError, ExecuteExt,
|
||||||
@@ -202,6 +204,44 @@ impl TryFrom<AppSessionLookup> for AppSession {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Split a [`AppSessionFilter`] into two separate filters: a
|
||||||
|
/// [`CompatSessionFilter`] and an [`OAuth2SessionFilter`].
|
||||||
|
fn split_filter(
|
||||||
|
filter: AppSessionFilter<'_>,
|
||||||
|
) -> (CompatSessionFilter<'_>, OAuth2SessionFilter<'_>) {
|
||||||
|
let mut compat_filter = CompatSessionFilter::new();
|
||||||
|
let mut oauth2_filter = OAuth2SessionFilter::new();
|
||||||
|
|
||||||
|
if let Some(user) = filter.user() {
|
||||||
|
compat_filter = compat_filter.for_user(user);
|
||||||
|
oauth2_filter = oauth2_filter.for_user(user);
|
||||||
|
}
|
||||||
|
|
||||||
|
match filter.state() {
|
||||||
|
Some(AppSessionState::Active) => {
|
||||||
|
compat_filter = compat_filter.active_only();
|
||||||
|
oauth2_filter = oauth2_filter.active_only();
|
||||||
|
}
|
||||||
|
Some(AppSessionState::Finished) => {
|
||||||
|
compat_filter = compat_filter.finished_only();
|
||||||
|
oauth2_filter = oauth2_filter.finished_only();
|
||||||
|
}
|
||||||
|
None => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(device) = filter.device() {
|
||||||
|
compat_filter = compat_filter.for_device(device);
|
||||||
|
oauth2_filter = oauth2_filter.for_device(device);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(browser_session) = filter.browser_session() {
|
||||||
|
compat_filter = compat_filter.for_browser_session(browser_session);
|
||||||
|
oauth2_filter = oauth2_filter.for_browser_session(browser_session);
|
||||||
|
}
|
||||||
|
|
||||||
|
(compat_filter, oauth2_filter)
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
||||||
type Error = DatabaseError;
|
type Error = DatabaseError;
|
||||||
@@ -220,6 +260,8 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
|||||||
filter: AppSessionFilter<'_>,
|
filter: AppSessionFilter<'_>,
|
||||||
pagination: Pagination,
|
pagination: Pagination,
|
||||||
) -> Result<Page<AppSession>, Self::Error> {
|
) -> Result<Page<AppSession>, Self::Error> {
|
||||||
|
let (compat_filter, oauth2_filter) = split_filter(filter);
|
||||||
|
|
||||||
let mut oauth2_session_select = Query::select()
|
let mut oauth2_session_select = Query::select()
|
||||||
.expr_as(
|
.expr_as(
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
|
||||||
@@ -269,26 +311,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
|||||||
AppSessionLookupIden::LastActiveIp,
|
AppSessionLookupIden::LastActiveIp,
|
||||||
)
|
)
|
||||||
.from(OAuth2Sessions::Table)
|
.from(OAuth2Sessions::Table)
|
||||||
.and_where_option(filter.user().map(|user| {
|
.apply_filter(oauth2_filter)
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.state().map(|state| {
|
|
||||||
if state.is_active() {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
|
||||||
} else {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.browser_session().map(|browser_session| {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
|
|
||||||
.eq(Uuid::from(browser_session.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.device().map(|device| {
|
|
||||||
Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col((
|
|
||||||
OAuth2Sessions::Table,
|
|
||||||
OAuth2Sessions::ScopeList,
|
|
||||||
))))
|
|
||||||
}))
|
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
let compat_session_select = Query::select()
|
let compat_session_select = Query::select()
|
||||||
@@ -340,23 +363,7 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
|||||||
AppSessionLookupIden::LastActiveIp,
|
AppSessionLookupIden::LastActiveIp,
|
||||||
)
|
)
|
||||||
.from(CompatSessions::Table)
|
.from(CompatSessions::Table)
|
||||||
.and_where_option(filter.user().map(|user| {
|
.apply_filter(compat_filter)
|
||||||
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.state().map(|state| {
|
|
||||||
if state.is_active() {
|
|
||||||
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
|
|
||||||
} else {
|
|
||||||
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.browser_session().map(|browser_session| {
|
|
||||||
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
|
|
||||||
.eq(Uuid::from(browser_session.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.device().map(|device| {
|
|
||||||
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string())
|
|
||||||
}))
|
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
let common_table_expression = CommonTableExpression::new()
|
let common_table_expression = CommonTableExpression::new()
|
||||||
@@ -397,51 +404,17 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
|
|||||||
err,
|
err,
|
||||||
)]
|
)]
|
||||||
async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error> {
|
async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error> {
|
||||||
|
let (compat_filter, oauth2_filter) = split_filter(filter);
|
||||||
let mut oauth2_session_select = Query::select()
|
let mut oauth2_session_select = Query::select()
|
||||||
.expr(Expr::cust("1"))
|
.expr(Expr::cust("1"))
|
||||||
.from(OAuth2Sessions::Table)
|
.from(OAuth2Sessions::Table)
|
||||||
.and_where_option(filter.user().map(|user| {
|
.apply_filter(oauth2_filter)
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.state().map(|state| {
|
|
||||||
if state.is_active() {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
|
||||||
} else {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.browser_session().map(|browser_session| {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
|
|
||||||
.eq(Uuid::from(browser_session.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.device().map(|device| {
|
|
||||||
Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col((
|
|
||||||
OAuth2Sessions::Table,
|
|
||||||
OAuth2Sessions::ScopeList,
|
|
||||||
))))
|
|
||||||
}))
|
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
let compat_session_select = Query::select()
|
let compat_session_select = Query::select()
|
||||||
.expr(Expr::cust("1"))
|
.expr(Expr::cust("1"))
|
||||||
.from(CompatSessions::Table)
|
.from(CompatSessions::Table)
|
||||||
.and_where_option(filter.user().map(|user| {
|
.apply_filter(compat_filter)
|
||||||
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.state().map(|state| {
|
|
||||||
if state.is_active() {
|
|
||||||
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
|
|
||||||
} else {
|
|
||||||
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.browser_session().map(|browser_session| {
|
|
||||||
Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
|
|
||||||
.eq(Uuid::from(browser_session.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.device().map(|device| {
|
|
||||||
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.to_string())
|
|
||||||
}))
|
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
let common_table_expression = CommonTableExpression::new()
|
let common_table_expression = CommonTableExpression::new()
|
||||||
|
Reference in New Issue
Block a user