diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs index a6b3ca55..fe66cff4 100644 --- a/crates/storage-pg/src/oauth2/mod.rs +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -709,6 +709,37 @@ mod tests { assert_eq!(list.edges.len(), 1); assert_eq!(list.edges[0], session11); assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1); + + // Finish all sessions of a client in batch + let affected = repo + .oauth2_session() + .finish_bulk( + &clock, + OAuth2SessionFilter::new() + .for_client(&client1) + .active_only(), + ) + .await + .unwrap(); + assert_eq!(affected, 1); + + // We should have 3 finished sessions + assert_eq!( + repo.oauth2_session() + .count(OAuth2SessionFilter::new().finished_only()) + .await + .unwrap(), + 3 + ); + + // We should have 1 active sessions + assert_eq!( + repo.oauth2_session() + .count(OAuth2SessionFilter::new().active_only()) + .await + .unwrap(), + 1 + ); } /// Test the [`OAuth2DeviceCodeGrantRepository`] implementation diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index 663ecd35..9eb5f1a4 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -206,6 +206,51 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { }) } + #[tracing::instrument( + name = "db.oauth2_session.finish_bulk", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: OAuth2SessionFilter<'_>, + ) -> Result { + let finished_at = clock.now(); + let (sql, arguments) = Query::update() + .table(OAuth2Sessions::Table) + .value(OAuth2Sessions::FinishedAt, finished_at) + .and_where_option(filter.user().map(|user| { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id)) + })) + .and_where_option(filter.client().map(|client| { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)) + .eq(Uuid::from(client.id)) + })) + .and_where_option(filter.state().map(|state| { + if state.is_active() { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null() + } else { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() + } + })) + .and_where_option(filter.scope().map(|scope| { + let scope: Vec = scope.iter().map(|s| s.as_str().to_owned()).collect(); + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope) + })) + .build_sqlx(PostgresQueryBuilder); + + let res = sqlx::query_with(&sql, arguments) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) + } + #[tracing::instrument( name = "db.oauth2_session.finish", skip_all, diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 60a130e0..46e96938 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -1,4 +1,4 @@ -// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// Copyright 2022-2024 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -245,6 +245,24 @@ pub trait OAuth2SessionRepository: Send + Sync { async fn finish(&mut self, clock: &dyn Clock, session: Session) -> Result; + /// Mark all the [`Session`] matching the given filter as finished + /// + /// Returns the number of sessions affected + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `filter`: The filter parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: OAuth2SessionFilter<'_>, + ) -> Result; + /// List [`Session`]s matching the given filter and pagination parameters /// /// # Parameters @@ -333,6 +351,12 @@ repository_impl!(OAuth2SessionRepository: async fn finish(&mut self, clock: &dyn Clock, session: Session) -> Result; + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: OAuth2SessionFilter<'_>, + ) -> Result; + async fn list( &mut self, filter: OAuth2SessionFilter<'_>,