You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
Batch finish OAuth 2.0 sessions
This commit is contained in:
@@ -709,6 +709,37 @@ mod tests {
|
|||||||
assert_eq!(list.edges.len(), 1);
|
assert_eq!(list.edges.len(), 1);
|
||||||
assert_eq!(list.edges[0], session11);
|
assert_eq!(list.edges[0], session11);
|
||||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
|
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
|
||||||
|
|
||||||
|
// Finish all sessions of a client in batch
|
||||||
|
let affected = repo
|
||||||
|
.oauth2_session()
|
||||||
|
.finish_bulk(
|
||||||
|
&clock,
|
||||||
|
OAuth2SessionFilter::new()
|
||||||
|
.for_client(&client1)
|
||||||
|
.active_only(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(affected, 1);
|
||||||
|
|
||||||
|
// We should have 3 finished sessions
|
||||||
|
assert_eq!(
|
||||||
|
repo.oauth2_session()
|
||||||
|
.count(OAuth2SessionFilter::new().finished_only())
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
3
|
||||||
|
);
|
||||||
|
|
||||||
|
// We should have 1 active sessions
|
||||||
|
assert_eq!(
|
||||||
|
repo.oauth2_session()
|
||||||
|
.count(OAuth2SessionFilter::new().active_only())
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
1
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Test the [`OAuth2DeviceCodeGrantRepository`] implementation
|
/// Test the [`OAuth2DeviceCodeGrantRepository`] implementation
|
||||||
|
@@ -206,6 +206,51 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.oauth2_session.finish_bulk",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn finish_bulk(
|
||||||
|
&mut self,
|
||||||
|
clock: &dyn Clock,
|
||||||
|
filter: OAuth2SessionFilter<'_>,
|
||||||
|
) -> Result<usize, Self::Error> {
|
||||||
|
let finished_at = clock.now();
|
||||||
|
let (sql, arguments) = Query::update()
|
||||||
|
.table(OAuth2Sessions::Table)
|
||||||
|
.value(OAuth2Sessions::FinishedAt, finished_at)
|
||||||
|
.and_where_option(filter.user().map(|user| {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
|
||||||
|
}))
|
||||||
|
.and_where_option(filter.client().map(|client| {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
|
||||||
|
.eq(Uuid::from(client.id))
|
||||||
|
}))
|
||||||
|
.and_where_option(filter.state().map(|state| {
|
||||||
|
if state.is_active() {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
||||||
|
} else {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.and_where_option(filter.scope().map(|scope| {
|
||||||
|
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
|
||||||
|
}))
|
||||||
|
.build_sqlx(PostgresQueryBuilder);
|
||||||
|
|
||||||
|
let res = sqlx::query_with(&sql, arguments)
|
||||||
|
.traced()
|
||||||
|
.execute(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
|
||||||
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
name = "db.oauth2_session.finish",
|
name = "db.oauth2_session.finish",
|
||||||
skip_all,
|
skip_all,
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@@ -245,6 +245,24 @@ pub trait OAuth2SessionRepository: Send + Sync {
|
|||||||
async fn finish(&mut self, clock: &dyn Clock, session: Session)
|
async fn finish(&mut self, clock: &dyn Clock, session: Session)
|
||||||
-> Result<Session, Self::Error>;
|
-> Result<Session, Self::Error>;
|
||||||
|
|
||||||
|
/// Mark all the [`Session`] matching the given filter as finished
|
||||||
|
///
|
||||||
|
/// Returns the number of sessions affected
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
///
|
||||||
|
/// * `clock`: The clock used to generate timestamps
|
||||||
|
/// * `filter`: The filter parameters
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`Self::Error`] if the underlying repository fails
|
||||||
|
async fn finish_bulk(
|
||||||
|
&mut self,
|
||||||
|
clock: &dyn Clock,
|
||||||
|
filter: OAuth2SessionFilter<'_>,
|
||||||
|
) -> Result<usize, Self::Error>;
|
||||||
|
|
||||||
/// List [`Session`]s matching the given filter and pagination parameters
|
/// List [`Session`]s matching the given filter and pagination parameters
|
||||||
///
|
///
|
||||||
/// # Parameters
|
/// # Parameters
|
||||||
@@ -333,6 +351,12 @@ repository_impl!(OAuth2SessionRepository:
|
|||||||
async fn finish(&mut self, clock: &dyn Clock, session: Session)
|
async fn finish(&mut self, clock: &dyn Clock, session: Session)
|
||||||
-> Result<Session, Self::Error>;
|
-> Result<Session, Self::Error>;
|
||||||
|
|
||||||
|
async fn finish_bulk(
|
||||||
|
&mut self,
|
||||||
|
clock: &dyn Clock,
|
||||||
|
filter: OAuth2SessionFilter<'_>,
|
||||||
|
) -> Result<usize, Self::Error>;
|
||||||
|
|
||||||
async fn list(
|
async fn list(
|
||||||
&mut self,
|
&mut self,
|
||||||
filter: OAuth2SessionFilter<'_>,
|
filter: OAuth2SessionFilter<'_>,
|
||||||
|
Reference in New Issue
Block a user