diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index 092b8c8b..45073667 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -288,6 +288,16 @@ mod tests { .unwrap(), 1 ); + + // Check that we can batch finish sessions + let affected = repo + .compat_session() + .finish_bulk(&clock, all.sso_login_only().active_only()) + .await + .unwrap(); + assert_eq!(affected, 1); + assert_eq!(repo.compat_session().count(finished).await.unwrap(), 2); + assert_eq!(repo.compat_session().count(active).await.unwrap(), 0); } #[sqlx::test(migrator = "crate::MIGRATOR")] diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs index 2b183253..b8580108 100644 --- a/crates/storage-pg/src/compat/session.rs +++ b/crates/storage-pg/src/compat/session.rs @@ -271,7 +271,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { sqlx::query!( r#" - INSERT INTO compat_sessions + INSERT INTO compat_sessions (compat_session_id, user_id, device_id, user_session_id, created_at, is_synapse_admin) VALUES ($1, $2, $3, $4, $5, $6) @@ -341,6 +341,64 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { Ok(compat_session) } + #[tracing::instrument( + name = "db.compat_session.finish_bulk", + skip_all, + fields(db.statement), + err, + )] + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: CompatSessionFilter<'_>, + ) -> Result { + let finished_at = clock.now(); + let (sql, arguments) = Query::update() + .table(CompatSessions::Table) + .value(CompatSessions::FinishedAt, finished_at) + .and_where_option(filter.user().map(|user| { + Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id)) + })) + .and_where_option(filter.state().map(|state| { + if state.is_active() { + Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null() + } else { + Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null() + } + })) + .and_where_option(filter.auth_type().map(|auth_type| { + // This builds either a: + // `WHERE compat_session_id = ANY(...)` + // or a `WHERE compat_session_id <> ALL(...)` + let compat_sso_logins = Query::select() + .expr(Expr::col(( + CompatSsoLogins::Table, + CompatSsoLogins::CompatSessionId, + ))) + .from(CompatSsoLogins::Table) + .take(); + + if auth_type.is_sso_login() { + Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)) + .eq(Expr::any(compat_sso_logins)) + } else { + Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)) + .ne(Expr::all(compat_sso_logins)) + } + })) + .and_where_option(filter.device().map(|device| { + Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str()) + })) + .build_sqlx(PostgresQueryBuilder); + + let res = sqlx::query_with(&sql, arguments) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) + } + #[tracing::instrument( name = "db.compat_session.list", skip_all, @@ -545,7 +603,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip) FROM ( SELECT * - FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) + FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) AS t(compat_session_id, last_active_at, last_active_ip) ) AS t WHERE compat_sessions.compat_session_id = t.compat_session_id diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index 6a0b4ab5..227399ac 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -209,6 +209,24 @@ pub trait CompatSessionRepository: Send + Sync { compat_session: CompatSession, ) -> Result; + /// Mark all the [`CompatSession`] matching the given filter as finished + /// + /// Returns the number of sessions affected + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `filter`: The filter to apply + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: CompatSessionFilter<'_>, + ) -> Result; + /// List [`CompatSession`] with the given filter and pagination /// /// Returns a page of compat sessions, with the associated SSO logins if any @@ -289,6 +307,12 @@ repository_impl!(CompatSessionRepository: compat_session: CompatSession, ) -> Result; + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: CompatSessionFilter<'_>, + ) -> Result; + async fn list( &mut self, filter: CompatSessionFilter<'_>,