You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
storage: Allow filtering oauth2 sessions by scope
This commit is contained in:
@ -8,8 +8,8 @@ license = "Apache-2.0"
|
||||
[dependencies]
|
||||
async-trait = "0.1.73"
|
||||
sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "json", "uuid"] }
|
||||
sea-query = { version = "0.30.1", features = ["derive", "attr", "with-uuid", "with-chrono"] }
|
||||
sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-uuid", "with-chrono"] }
|
||||
sea-query = { version = "0.30.1", features = ["derive", "attr", "with-uuid", "with-chrono", "postgres-array"] }
|
||||
sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-uuid", "with-chrono", "postgres-array"] }
|
||||
chrono.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
@ -38,7 +38,7 @@ mod tests {
|
||||
};
|
||||
use oauth2_types::{
|
||||
requests::{GrantType, ResponseMode},
|
||||
scope::{Scope, OPENID},
|
||||
scope::{Scope, EMAIL, OPENID, PROFILE},
|
||||
};
|
||||
use rand::SeedableRng;
|
||||
use rand_chacha::ChaChaRng;
|
||||
@ -456,7 +456,8 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let scope = Scope::from_iter([OPENID]);
|
||||
let scope = Scope::from_iter([OPENID, EMAIL]);
|
||||
let scope2 = Scope::from_iter([OPENID, PROFILE]);
|
||||
|
||||
// Create two sessions for each user, one with each client
|
||||
// We're moving the clock forward by 1 minute between each session to ensure
|
||||
@ -477,14 +478,14 @@ mod tests {
|
||||
|
||||
let session21 = repo
|
||||
.oauth2_session()
|
||||
.add(&mut rng, &clock, &client2, &user1_session, scope.clone())
|
||||
.add(&mut rng, &clock, &client2, &user1_session, scope2.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
clock.advance(Duration::minutes(1));
|
||||
|
||||
let session22 = repo
|
||||
.oauth2_session()
|
||||
.add(&mut rng, &clock, &client2, &user2_session, scope.clone())
|
||||
.add(&mut rng, &clock, &client2, &user2_session, scope2.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
clock.advance(Duration::minutes(1));
|
||||
@ -645,5 +646,48 @@ mod tests {
|
||||
assert_eq!(list.edges[0], session21);
|
||||
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
|
||||
|
||||
// Try the scope filter. We should get all sessions with the "openid" scope
|
||||
let scope = Scope::from_iter([OPENID]);
|
||||
let filter = OAuth2SessionFilter::new().with_scope(&scope);
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 4);
|
||||
assert_eq!(list.edges[0], session11);
|
||||
assert_eq!(list.edges[1], session12);
|
||||
assert_eq!(list.edges[2], session21);
|
||||
assert_eq!(list.edges[3], session22);
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 4);
|
||||
|
||||
// We should get all sessions with the "openid" and "email" scope
|
||||
let scope = Scope::from_iter([OPENID, EMAIL]);
|
||||
let filter = OAuth2SessionFilter::new().with_scope(&scope);
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!list.has_next_page);
|
||||
assert_eq!(list.edges.len(), 2);
|
||||
assert_eq!(list.edges[0], session11);
|
||||
assert_eq!(list.edges[1], session12);
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);
|
||||
|
||||
// Try combining the scope filter with the user filter
|
||||
let filter = OAuth2SessionFilter::new()
|
||||
.with_scope(&scope)
|
||||
.for_user(&user1);
|
||||
let list = repo
|
||||
.oauth2_session()
|
||||
.list(filter, pagination)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(list.edges.len(), 1);
|
||||
assert_eq!(list.edges[0], session11);
|
||||
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
|
||||
}
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ use mas_storage::{
|
||||
};
|
||||
use oauth2_types::scope::{Scope, ScopeToken};
|
||||
use rand::RngCore;
|
||||
use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query};
|
||||
use sea_query::{enum_def, extension::postgres::PgExpr, Expr, PostgresQueryBuilder, Query};
|
||||
use sea_query_binder::SqlxBinder;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
@ -288,6 +288,10 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
||||
}
|
||||
}))
|
||||
.and_where_option(filter.scope().map(|scope| {
|
||||
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
|
||||
}))
|
||||
.generate_pagination(
|
||||
(OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
|
||||
pagination,
|
||||
@ -345,6 +349,10 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
||||
}
|
||||
}))
|
||||
.and_where_option(filter.scope().map(|scope| {
|
||||
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
|
||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
|
||||
}))
|
||||
.build_sqlx(PostgresQueryBuilder);
|
||||
|
||||
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
||||
|
@ -42,6 +42,7 @@ pub struct OAuth2SessionFilter<'a> {
|
||||
user: Option<&'a User>,
|
||||
client: Option<&'a Client>,
|
||||
state: Option<OAuth2SessionState>,
|
||||
scope: Option<&'a Scope>,
|
||||
}
|
||||
|
||||
impl<'a> OAuth2SessionFilter<'a> {
|
||||
@ -102,6 +103,21 @@ impl<'a> OAuth2SessionFilter<'a> {
|
||||
pub fn state(&self) -> Option<OAuth2SessionState> {
|
||||
self.state
|
||||
}
|
||||
|
||||
/// Only return sessions with the given scope
|
||||
#[must_use]
|
||||
pub fn with_scope(mut self, scope: &'a Scope) -> Self {
|
||||
self.scope = Some(scope);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the scope filter
|
||||
///
|
||||
/// Returns [`None`] if no scope filter was set
|
||||
#[must_use]
|
||||
pub fn scope(&self) -> Option<&Scope> {
|
||||
self.scope
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`OAuth2SessionRepository`] helps interacting with [`Session`]
|
||||
|
Reference in New Issue
Block a user