diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs index 0ec57481..5f9a5da4 100644 --- a/crates/storage-pg/src/user/session.rs +++ b/crates/storage-pg/src/user/session.rs @@ -259,6 +259,43 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { Ok(user_session) } + #[tracing::instrument( + name = "db.browser_session.finish_bulk", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: mas_storage::user::BrowserSessionFilter<'_>, + ) -> Result { + let finished_at = clock.now(); + let (sql, arguments) = sea_query::Query::update() + .table(UserSessions::Table) + .value(UserSessions::FinishedAt, finished_at) + .and_where_option(filter.user().map(|user| { + Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id)) + })) + .and_where_option(filter.state().map(|state| { + if state.is_active() { + Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null() + } else { + Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null() + } + })) + .build_sqlx(PostgresQueryBuilder); + + let res = sqlx::query_with(&sql, arguments) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) + } + #[tracing::instrument( name = "db.browser_session.list", skip_all, @@ -560,7 +597,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip) FROM ( SELECT * - FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) + FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) AS t(user_session_id, last_active_at, last_active_ip) ) AS t WHERE user_sessions.user_session_id = t.user_session_id diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index 917ebc85..789e9b4c 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -534,19 +534,23 @@ async fn test_user_password_repo(pool: PgPool) { #[sqlx::test(migrator = "crate::MIGRATOR")] async fn test_user_session(pool: PgPool) { - const USERNAME: &str = "john"; - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); - let user = repo + let alice = repo .user() - .add(&mut rng, &clock, USERNAME.to_owned()) + .add(&mut rng, &clock, "alice".to_owned()) .await .unwrap(); - let all = BrowserSessionFilter::default().for_user(&user); + let bob = repo + .user() + .add(&mut rng, &clock, "bob".to_owned()) + .await + .unwrap(); + + let all = BrowserSessionFilter::default(); let active = all.active_only(); let finished = all.finished_only(); @@ -556,10 +560,10 @@ async fn test_user_session(pool: PgPool) { let session = repo .browser_session() - .add(&mut rng, &clock, &user, None) + .add(&mut rng, &clock, &alice, None) .await .unwrap(); - assert_eq!(session.user.id, user.id); + assert_eq!(session.user.id, alice.id); assert!(session.finished_at.is_none()); assert_eq!(repo.browser_session().count(all).await.unwrap(), 1); @@ -584,7 +588,7 @@ async fn test_user_session(pool: PgPool) { .expect("user session not found"); assert_eq!(session_lookup.id, session.id); - assert_eq!(session_lookup.user.id, user.id); + assert_eq!(session_lookup.user.id, alice.id); assert!(session_lookup.finished_at.is_none()); // Finish the session @@ -616,9 +620,53 @@ async fn test_user_session(pool: PgPool) { .expect("user session not found"); assert_eq!(session_lookup.id, session.id); - assert_eq!(session_lookup.user.id, user.id); + assert_eq!(session_lookup.user.id, alice.id); // This time the session is finished assert!(session_lookup.finished_at.is_some()); + + // Create a bunch of other sessions + for _ in 0..5 { + for user in &[&alice, &bob] { + repo.browser_session() + .add(&mut rng, &clock, user, None) + .await + .unwrap(); + } + } + + let all_alice = BrowserSessionFilter::new().for_user(&alice); + let active_alice = BrowserSessionFilter::new().for_user(&alice).active_only(); + let all_bob = BrowserSessionFilter::new().for_user(&bob); + let active_bob = BrowserSessionFilter::new().for_user(&bob).active_only(); + assert_eq!(repo.browser_session().count(all).await.unwrap(), 11); + assert_eq!(repo.browser_session().count(active).await.unwrap(), 10); + assert_eq!(repo.browser_session().count(finished).await.unwrap(), 1); + assert_eq!(repo.browser_session().count(all_alice).await.unwrap(), 6); + assert_eq!(repo.browser_session().count(active_alice).await.unwrap(), 5); + assert_eq!(repo.browser_session().count(all_bob).await.unwrap(), 5); + assert_eq!(repo.browser_session().count(active_bob).await.unwrap(), 5); + + // Finish all the sessions for alice + let affected = repo + .browser_session() + .finish_bulk(&clock, active_alice) + .await + .unwrap(); + assert_eq!(affected, 5); + assert_eq!(repo.browser_session().count(all_alice).await.unwrap(), 6); + assert_eq!(repo.browser_session().count(active_alice).await.unwrap(), 0); + assert_eq!(repo.browser_session().count(finished).await.unwrap(), 6); + + // Finish all the sessions for bob + let affected = repo + .browser_session() + .finish_bulk(&clock, active_bob) + .await + .unwrap(); + assert_eq!(affected, 5); + assert_eq!(repo.browser_session().count(all_bob).await.unwrap(), 5); + assert_eq!(repo.browser_session().count(active_bob).await.unwrap(), 0); + assert_eq!(repo.browser_session().count(finished).await.unwrap(), 11); } #[sqlx::test(migrator = "crate::MIGRATOR")] diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index fa1d1763..a10c65b6 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -148,6 +148,24 @@ pub trait BrowserSessionRepository: Send + Sync { user_session: BrowserSession, ) -> Result; + /// Mark all the [`BrowserSession`] matching the given filter as finished + /// + /// Returns the number of sessions affected + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `filter`: The filter parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: BrowserSessionFilter<'_>, + ) -> Result; + /// List [`BrowserSession`] with the given filter and pagination /// /// # Parameters @@ -262,6 +280,12 @@ repository_impl!(BrowserSessionRepository: user_session: BrowserSession, ) -> Result; + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: BrowserSessionFilter<'_>, + ) -> Result; + async fn list( &mut self, filter: BrowserSessionFilter<'_>,