1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: Allow filtering oauth2 sessions by scope

This commit is contained in:
Quentin Gliech
2023-08-29 14:26:09 +02:00
parent 1826120f10
commit d7abdccc0a
4 changed files with 75 additions and 7 deletions

View File

@ -8,8 +8,8 @@ license = "Apache-2.0"
[dependencies]
async-trait = "0.1.73"
sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "json", "uuid"] }
sea-query = { version = "0.30.1", features = ["derive", "attr", "with-uuid", "with-chrono"] }
sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-uuid", "with-chrono"] }
sea-query = { version = "0.30.1", features = ["derive", "attr", "with-uuid", "with-chrono", "postgres-array"] }
sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-uuid", "with-chrono", "postgres-array"] }
chrono.workspace = true
serde.workspace = true
serde_json.workspace = true

View File

@ -38,7 +38,7 @@ mod tests {
};
use oauth2_types::{
requests::{GrantType, ResponseMode},
scope::{Scope, OPENID},
scope::{Scope, EMAIL, OPENID, PROFILE},
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
@ -456,7 +456,8 @@ mod tests {
.await
.unwrap();
let scope = Scope::from_iter([OPENID]);
let scope = Scope::from_iter([OPENID, EMAIL]);
let scope2 = Scope::from_iter([OPENID, PROFILE]);
// Create two sessions for each user, one with each client
// We're moving the clock forward by 1 minute between each session to ensure
@ -477,14 +478,14 @@ mod tests {
let session21 = repo
.oauth2_session()
.add(&mut rng, &clock, &client2, &user1_session, scope.clone())
.add(&mut rng, &clock, &client2, &user1_session, scope2.clone())
.await
.unwrap();
clock.advance(Duration::minutes(1));
let session22 = repo
.oauth2_session()
.add(&mut rng, &clock, &client2, &user2_session, scope.clone())
.add(&mut rng, &clock, &client2, &user2_session, scope2.clone())
.await
.unwrap();
clock.advance(Duration::minutes(1));
@ -645,5 +646,48 @@ mod tests {
assert_eq!(list.edges[0], session21);
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
// Try the scope filter. We should get all sessions with the "openid" scope
let scope = Scope::from_iter([OPENID]);
let filter = OAuth2SessionFilter::new().with_scope(&scope);
let list = repo
.oauth2_session()
.list(filter, pagination)
.await
.unwrap();
assert!(!list.has_next_page);
assert_eq!(list.edges.len(), 4);
assert_eq!(list.edges[0], session11);
assert_eq!(list.edges[1], session12);
assert_eq!(list.edges[2], session21);
assert_eq!(list.edges[3], session22);
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 4);
// We should get all sessions with the "openid" and "email" scope
let scope = Scope::from_iter([OPENID, EMAIL]);
let filter = OAuth2SessionFilter::new().with_scope(&scope);
let list = repo
.oauth2_session()
.list(filter, pagination)
.await
.unwrap();
assert!(!list.has_next_page);
assert_eq!(list.edges.len(), 2);
assert_eq!(list.edges[0], session11);
assert_eq!(list.edges[1], session12);
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);
// Try combining the scope filter with the user filter
let filter = OAuth2SessionFilter::new()
.with_scope(&scope)
.for_user(&user1);
let list = repo
.oauth2_session()
.list(filter, pagination)
.await
.unwrap();
assert_eq!(list.edges.len(), 1);
assert_eq!(list.edges[0], session11);
assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);
}
}

View File

@ -21,7 +21,7 @@ use mas_storage::{
};
use oauth2_types::scope::{Scope, ScopeToken};
use rand::RngCore;
use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query};
use sea_query::{enum_def, extension::postgres::PgExpr, Expr, PostgresQueryBuilder, Query};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use ulid::Ulid;
@ -288,6 +288,10 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.scope().map(|scope| {
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
}))
.generate_pagination(
(OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
pagination,
@ -345,6 +349,10 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.scope().map(|scope| {
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
}))
.build_sqlx(PostgresQueryBuilder);
let count: i64 = sqlx::query_scalar_with(&sql, arguments)

View File

@ -42,6 +42,7 @@ pub struct OAuth2SessionFilter<'a> {
user: Option<&'a User>,
client: Option<&'a Client>,
state: Option<OAuth2SessionState>,
scope: Option<&'a Scope>,
}
impl<'a> OAuth2SessionFilter<'a> {
@ -102,6 +103,21 @@ impl<'a> OAuth2SessionFilter<'a> {
pub fn state(&self) -> Option<OAuth2SessionState> {
self.state
}
/// Only return sessions with the given scope
#[must_use]
pub fn with_scope(mut self, scope: &'a Scope) -> Self {
self.scope = Some(scope);
self
}
/// Get the scope filter
///
/// Returns [`None`] if no scope filter was set
#[must_use]
pub fn scope(&self) -> Option<&Scope> {
self.scope
}
}
/// An [`OAuth2SessionRepository`] helps interacting with [`Session`]