1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Batch finish OAuth 2.0 sessions

This commit is contained in:
Quentin Gliech
2024-07-16 11:28:47 +02:00
parent f8d12cc305
commit 04b96b87b8
3 changed files with 101 additions and 1 deletions

View File

@@ -709,6 +709,37 @@ mod tests {
assert_eq!(list.edges.len(), 1);
assert_eq!(list.edges[0], session11);
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
// Finish all sessions of a client in batch
let affected = repo
.oauth2_session()
.finish_bulk(
&clock,
OAuth2SessionFilter::new()
.for_client(&client1)
.active_only(),
)
.await
.unwrap();
assert_eq!(affected, 1);
// We should have 3 finished sessions
assert_eq!(
repo.oauth2_session()
.count(OAuth2SessionFilter::new().finished_only())
.await
.unwrap(),
3
);
// We should have 1 active sessions
assert_eq!(
repo.oauth2_session()
.count(OAuth2SessionFilter::new().active_only())
.await
.unwrap(),
1
);
}
/// Test the [`OAuth2DeviceCodeGrantRepository`] implementation

View File

@@ -206,6 +206,51 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
})
}
#[tracing::instrument(
name = "db.oauth2_session.finish_bulk",
skip_all,
fields(
db.statement,
),
err,
)]
async fn finish_bulk(
&mut self,
clock: &dyn Clock,
filter: OAuth2SessionFilter<'_>,
) -> Result<usize, Self::Error> {
let finished_at = clock.now();
let (sql, arguments) = Query::update()
.table(OAuth2Sessions::Table)
.value(OAuth2Sessions::FinishedAt, finished_at)
.and_where_option(filter.user().map(|user| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.client().map(|client| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
.eq(Uuid::from(client.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
} else {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.scope().map(|scope| {
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
}))
.build_sqlx(PostgresQueryBuilder);
let res = sqlx::query_with(&sql, arguments)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
}
#[tracing::instrument(
name = "db.oauth2_session.finish",
skip_all,

View File

@@ -1,4 +1,4 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -245,6 +245,24 @@ pub trait OAuth2SessionRepository: Send + Sync {
async fn finish(&mut self, clock: &dyn Clock, session: Session)
-> Result<Session, Self::Error>;
/// Mark all the [`Session`] matching the given filter as finished
///
/// Returns the number of sessions affected
///
/// # Parameters
///
/// * `clock`: The clock used to generate timestamps
/// * `filter`: The filter parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn finish_bulk(
&mut self,
clock: &dyn Clock,
filter: OAuth2SessionFilter<'_>,
) -> Result<usize, Self::Error>;
/// List [`Session`]s matching the given filter and pagination parameters
///
/// # Parameters
@@ -333,6 +351,12 @@ repository_impl!(OAuth2SessionRepository:
async fn finish(&mut self, clock: &dyn Clock, session: Session)
-> Result<Session, Self::Error>;
async fn finish_bulk(
&mut self,
clock: &dyn Clock,
filter: OAuth2SessionFilter<'_>,
) -> Result<usize, Self::Error>;
async fn list(
&mut self,
filter: OAuth2SessionFilter<'_>,