1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Batch finish browser sessions

This commit is contained in:
Quentin Gliech
2024-07-16 11:53:04 +02:00
parent 04b96b87b8
commit dcaf65e6e7
3 changed files with 119 additions and 10 deletions

View File

@@ -259,6 +259,43 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
Ok(user_session) Ok(user_session)
} }
#[tracing::instrument(
name = "db.browser_session.finish_bulk",
skip_all,
fields(
db.statement,
),
err,
)]
async fn finish_bulk(
&mut self,
clock: &dyn Clock,
filter: mas_storage::user::BrowserSessionFilter<'_>,
) -> Result<usize, Self::Error> {
let finished_at = clock.now();
let (sql, arguments) = sea_query::Query::update()
.table(UserSessions::Table)
.value(UserSessions::FinishedAt, finished_at)
.and_where_option(filter.user().map(|user| {
Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
} else {
Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
}
}))
.build_sqlx(PostgresQueryBuilder);
let res = sqlx::query_with(&sql, arguments)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
}
#[tracing::instrument( #[tracing::instrument(
name = "db.browser_session.list", name = "db.browser_session.list",
skip_all, skip_all,
@@ -560,7 +597,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
, last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip) , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
FROM ( FROM (
SELECT * SELECT *
FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
AS t(user_session_id, last_active_at, last_active_ip) AS t(user_session_id, last_active_at, last_active_ip)
) AS t ) AS t
WHERE user_sessions.user_session_id = t.user_session_id WHERE user_sessions.user_session_id = t.user_session_id

View File

@@ -534,19 +534,23 @@ async fn test_user_password_repo(pool: PgPool) {
#[sqlx::test(migrator = "crate::MIGRATOR")] #[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_user_session(pool: PgPool) { async fn test_user_session(pool: PgPool) {
const USERNAME: &str = "john";
let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut repo = PgRepository::from_pool(&pool).await.unwrap();
let mut rng = ChaChaRng::seed_from_u64(42); let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default(); let clock = MockClock::default();
let user = repo let alice = repo
.user() .user()
.add(&mut rng, &clock, USERNAME.to_owned()) .add(&mut rng, &clock, "alice".to_owned())
.await .await
.unwrap(); .unwrap();
let all = BrowserSessionFilter::default().for_user(&user); let bob = repo
.user()
.add(&mut rng, &clock, "bob".to_owned())
.await
.unwrap();
let all = BrowserSessionFilter::default();
let active = all.active_only(); let active = all.active_only();
let finished = all.finished_only(); let finished = all.finished_only();
@@ -556,10 +560,10 @@ async fn test_user_session(pool: PgPool) {
let session = repo let session = repo
.browser_session() .browser_session()
.add(&mut rng, &clock, &user, None) .add(&mut rng, &clock, &alice, None)
.await .await
.unwrap(); .unwrap();
assert_eq!(session.user.id, user.id); assert_eq!(session.user.id, alice.id);
assert!(session.finished_at.is_none()); assert!(session.finished_at.is_none());
assert_eq!(repo.browser_session().count(all).await.unwrap(), 1); assert_eq!(repo.browser_session().count(all).await.unwrap(), 1);
@@ -584,7 +588,7 @@ async fn test_user_session(pool: PgPool) {
.expect("user session not found"); .expect("user session not found");
assert_eq!(session_lookup.id, session.id); assert_eq!(session_lookup.id, session.id);
assert_eq!(session_lookup.user.id, user.id); assert_eq!(session_lookup.user.id, alice.id);
assert!(session_lookup.finished_at.is_none()); assert!(session_lookup.finished_at.is_none());
// Finish the session // Finish the session
@@ -616,9 +620,53 @@ async fn test_user_session(pool: PgPool) {
.expect("user session not found"); .expect("user session not found");
assert_eq!(session_lookup.id, session.id); assert_eq!(session_lookup.id, session.id);
assert_eq!(session_lookup.user.id, user.id); assert_eq!(session_lookup.user.id, alice.id);
// This time the session is finished // This time the session is finished
assert!(session_lookup.finished_at.is_some()); assert!(session_lookup.finished_at.is_some());
// Create a bunch of other sessions
for _ in 0..5 {
for user in &[&alice, &bob] {
repo.browser_session()
.add(&mut rng, &clock, user, None)
.await
.unwrap();
}
}
let all_alice = BrowserSessionFilter::new().for_user(&alice);
let active_alice = BrowserSessionFilter::new().for_user(&alice).active_only();
let all_bob = BrowserSessionFilter::new().for_user(&bob);
let active_bob = BrowserSessionFilter::new().for_user(&bob).active_only();
assert_eq!(repo.browser_session().count(all).await.unwrap(), 11);
assert_eq!(repo.browser_session().count(active).await.unwrap(), 10);
assert_eq!(repo.browser_session().count(finished).await.unwrap(), 1);
assert_eq!(repo.browser_session().count(all_alice).await.unwrap(), 6);
assert_eq!(repo.browser_session().count(active_alice).await.unwrap(), 5);
assert_eq!(repo.browser_session().count(all_bob).await.unwrap(), 5);
assert_eq!(repo.browser_session().count(active_bob).await.unwrap(), 5);
// Finish all the sessions for alice
let affected = repo
.browser_session()
.finish_bulk(&clock, active_alice)
.await
.unwrap();
assert_eq!(affected, 5);
assert_eq!(repo.browser_session().count(all_alice).await.unwrap(), 6);
assert_eq!(repo.browser_session().count(active_alice).await.unwrap(), 0);
assert_eq!(repo.browser_session().count(finished).await.unwrap(), 6);
// Finish all the sessions for bob
let affected = repo
.browser_session()
.finish_bulk(&clock, active_bob)
.await
.unwrap();
assert_eq!(affected, 5);
assert_eq!(repo.browser_session().count(all_bob).await.unwrap(), 5);
assert_eq!(repo.browser_session().count(active_bob).await.unwrap(), 0);
assert_eq!(repo.browser_session().count(finished).await.unwrap(), 11);
} }
#[sqlx::test(migrator = "crate::MIGRATOR")] #[sqlx::test(migrator = "crate::MIGRATOR")]

View File

@@ -148,6 +148,24 @@ pub trait BrowserSessionRepository: Send + Sync {
user_session: BrowserSession, user_session: BrowserSession,
) -> Result<BrowserSession, Self::Error>; ) -> Result<BrowserSession, Self::Error>;
/// Mark all the [`BrowserSession`] matching the given filter as finished
///
/// Returns the number of sessions affected
///
/// # Parameters
///
/// * `clock`: The clock used to generate timestamps
/// * `filter`: The filter parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn finish_bulk(
&mut self,
clock: &dyn Clock,
filter: BrowserSessionFilter<'_>,
) -> Result<usize, Self::Error>;
/// List [`BrowserSession`] with the given filter and pagination /// List [`BrowserSession`] with the given filter and pagination
/// ///
/// # Parameters /// # Parameters
@@ -262,6 +280,12 @@ repository_impl!(BrowserSessionRepository:
user_session: BrowserSession, user_session: BrowserSession,
) -> Result<BrowserSession, Self::Error>; ) -> Result<BrowserSession, Self::Error>;
async fn finish_bulk(
&mut self,
clock: &dyn Clock,
filter: BrowserSessionFilter<'_>,
) -> Result<usize, Self::Error>;
async fn list( async fn list(
&mut self, &mut self,
filter: BrowserSessionFilter<'_>, filter: BrowserSessionFilter<'_>,