1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Batch finish compatibility sessions

This commit is contained in:
Quentin Gliech
2024-07-15 22:57:24 +02:00
parent fa32387ca5
commit f8d12cc305
3 changed files with 94 additions and 2 deletions

View File

@ -288,6 +288,16 @@ mod tests {
.unwrap(),
1
);
// Check that we can batch finish sessions
let affected = repo
.compat_session()
.finish_bulk(&clock, all.sso_login_only().active_only())
.await
.unwrap();
assert_eq!(affected, 1);
assert_eq!(repo.compat_session().count(finished).await.unwrap(), 2);
assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
}
#[sqlx::test(migrator = "crate::MIGRATOR")]

View File

@ -271,7 +271,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
sqlx::query!(
r#"
INSERT INTO compat_sessions
INSERT INTO compat_sessions
(compat_session_id, user_id, device_id,
user_session_id, created_at, is_synapse_admin)
VALUES ($1, $2, $3, $4, $5, $6)
@ -341,6 +341,64 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
Ok(compat_session)
}
#[tracing::instrument(
name = "db.compat_session.finish_bulk",
skip_all,
fields(db.statement),
err,
)]
async fn finish_bulk(
&mut self,
clock: &dyn Clock,
filter: CompatSessionFilter<'_>,
) -> Result<usize, Self::Error> {
let finished_at = clock.now();
let (sql, arguments) = Query::update()
.table(CompatSessions::Table)
.value(CompatSessions::FinishedAt, finished_at)
.and_where_option(filter.user().map(|user| {
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
} else {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.auth_type().map(|auth_type| {
// This builds either a:
// `WHERE compat_session_id = ANY(...)`
// or a `WHERE compat_session_id <> ALL(...)`
let compat_sso_logins = Query::select()
.expr(Expr::col((
CompatSsoLogins::Table,
CompatSsoLogins::CompatSessionId,
)))
.from(CompatSsoLogins::Table)
.take();
if auth_type.is_sso_login() {
Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
.eq(Expr::any(compat_sso_logins))
} else {
Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
.ne(Expr::all(compat_sso_logins))
}
}))
.and_where_option(filter.device().map(|device| {
Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
}))
.build_sqlx(PostgresQueryBuilder);
let res = sqlx::query_with(&sql, arguments)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
}
#[tracing::instrument(
name = "db.compat_session.list",
skip_all,
@ -545,7 +603,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
, last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
FROM (
SELECT *
FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
AS t(compat_session_id, last_active_at, last_active_ip)
) AS t
WHERE compat_sessions.compat_session_id = t.compat_session_id