1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Use dynamic filters on compatibility sessions

This commit is contained in:
Quentin Gliech
2024-07-16 16:07:04 +02:00
parent 452024764a
commit 15c2c740a7
2 changed files with 57 additions and 86 deletions

View File

@@ -12,5 +12,5 @@
}, },
"nullable": [] "nullable": []
}, },
"hash": "d0c02576b1550fe2eb877d24f7cdfc819307ee0c47af9fbbf1a3b484290b321d" "hash": "047990a99794b565c2cad396946299db5b617f52f6c24bcca0a24c0c185c4478"
} }

View File

@@ -33,6 +33,7 @@ use url::Url;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
filter::{Filter, StatementExt, StatementWithJoinsExt},
iden::{CompatSessions, CompatSsoLogins}, iden::{CompatSessions, CompatSsoLogins},
pagination::QueryBuilderExt, pagination::QueryBuilderExt,
tracing::ExecuteExt, tracing::ExecuteExt,
@@ -203,6 +204,57 @@ impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSs
} }
} }
impl Filter for CompatSessionFilter<'_> {
fn generate_condition(&self, has_joins: bool) -> impl sea_query::IntoCondition {
sea_query::Condition::all()
.add_option(self.user().map(|user| {
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
}))
.add_option(self.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
} else {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
}
}))
.add_option(self.auth_type().map(|auth_type| {
// In in the SELECT to list sessions, we can rely on the JOINed table, whereas
// in other queries we need to do a subquery
if has_joins {
if auth_type.is_sso_login() {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
.is_not_null()
} else {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
.is_null()
}
} else {
// This builds either a:
// `WHERE compat_session_id = ANY(...)`
// or a `WHERE compat_session_id <> ALL(...)`
let compat_sso_logins = Query::select()
.expr(Expr::col((
CompatSsoLogins::Table,
CompatSsoLogins::CompatSessionId,
)))
.from(CompatSsoLogins::Table)
.take();
if auth_type.is_sso_login() {
Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
.eq(Expr::any(compat_sso_logins))
} else {
Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
.ne(Expr::all(compat_sso_logins))
}
}
}))
.add_option(self.device().map(|device| {
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
}))
}
}
#[async_trait] #[async_trait]
impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
type Error = DatabaseError; type Error = DatabaseError;
@@ -356,39 +408,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
let (sql, arguments) = Query::update() let (sql, arguments) = Query::update()
.table(CompatSessions::Table) .table(CompatSessions::Table)
.value(CompatSessions::FinishedAt, finished_at) .value(CompatSessions::FinishedAt, finished_at)
.and_where_option(filter.user().map(|user| { .apply_filter(filter)
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
} else {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.auth_type().map(|auth_type| {
// This builds either a:
// `WHERE compat_session_id = ANY(...)`
// or a `WHERE compat_session_id <> ALL(...)`
let compat_sso_logins = Query::select()
.expr(Expr::col((
CompatSsoLogins::Table,
CompatSsoLogins::CompatSessionId,
)))
.from(CompatSsoLogins::Table)
.take();
if auth_type.is_sso_login() {
Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
.eq(Expr::any(compat_sso_logins))
} else {
Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
.ne(Expr::all(compat_sso_logins))
}
}))
.and_where_option(filter.device().map(|device| {
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
}))
.build_sqlx(PostgresQueryBuilder); .build_sqlx(PostgresQueryBuilder);
let res = sqlx::query_with(&sql, arguments) let res = sqlx::query_with(&sql, arguments)
@@ -483,27 +503,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)) Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
.equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)), .equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
) )
.and_where_option(filter.user().map(|user| { .apply_filter_with_joins(filter)
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
} else {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.auth_type().map(|auth_type| {
if auth_type.is_sso_login() {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
.is_not_null()
} else {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)).is_null()
}
}))
.and_where_option(filter.device().map(|device| {
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
}))
.generate_pagination( .generate_pagination(
(CompatSessions::Table, CompatSessions::CompatSessionId), (CompatSessions::Table, CompatSessions::CompatSessionId),
pagination, pagination,
@@ -532,36 +532,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
let (sql, arguments) = sea_query::Query::select() let (sql, arguments) = sea_query::Query::select()
.expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count()) .expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
.from(CompatSessions::Table) .from(CompatSessions::Table)
.and_where_option(filter.user().map(|user| { .apply_filter(filter)
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
} else {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.auth_type().map(|auth_type| {
// Check if it is an SSO login by checking if there is a SSO login for the
// session.
let exists = Expr::exists(
Query::select()
.expr(Expr::cust("1"))
.from(CompatSsoLogins::Table)
.and_where(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId))
.equals((CompatSessions::Table, CompatSessions::CompatSessionId)),
)
.take(),
);
if auth_type.is_sso_login() {
exists
} else {
exists.not()
}
}))
.build_sqlx(PostgresQueryBuilder); .build_sqlx(PostgresQueryBuilder);
let count: i64 = sqlx::query_scalar_with(&sql, arguments) let count: i64 = sqlx::query_scalar_with(&sql, arguments)