diff --git a/crates/storage-pg/Cargo.toml b/crates/storage-pg/Cargo.toml index 3ea8ab81..b39dca7b 100644 --- a/crates/storage-pg/Cargo.toml +++ b/crates/storage-pg/Cargo.toml @@ -8,8 +8,8 @@ license = "Apache-2.0" [dependencies] async-trait = "0.1.73" sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "json", "uuid"] } -sea-query = { version = "0.30.1", features = ["derive", "attr", "with-uuid", "with-chrono"] } -sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-uuid", "with-chrono"] } +sea-query = { version = "0.30.1", features = ["derive", "attr", "with-uuid", "with-chrono", "postgres-array"] } +sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-uuid", "with-chrono", "postgres-array"] } chrono.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs index 11c28fd4..1b00c8e2 100644 --- a/crates/storage-pg/src/oauth2/mod.rs +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -38,7 +38,7 @@ mod tests { }; use oauth2_types::{ requests::{GrantType, ResponseMode}, - scope::{Scope, OPENID}, + scope::{Scope, EMAIL, OPENID, PROFILE}, }; use rand::SeedableRng; use rand_chacha::ChaChaRng; @@ -456,7 +456,8 @@ mod tests { .await .unwrap(); - let scope = Scope::from_iter([OPENID]); + let scope = Scope::from_iter([OPENID, EMAIL]); + let scope2 = Scope::from_iter([OPENID, PROFILE]); // Create two sessions for each user, one with each client // We're moving the clock forward by 1 minute between each session to ensure @@ -477,14 +478,14 @@ mod tests { let session21 = repo .oauth2_session() - .add(&mut rng, &clock, &client2, &user1_session, scope.clone()) + .add(&mut rng, &clock, &client2, &user1_session, scope2.clone()) .await .unwrap(); clock.advance(Duration::minutes(1)); let session22 = repo .oauth2_session() - .add(&mut rng, &clock, &client2, &user2_session, scope.clone()) + .add(&mut rng, &clock, &client2, &user2_session, scope2.clone()) .await .unwrap(); clock.advance(Duration::minutes(1)); @@ -645,5 +646,48 @@ mod tests { assert_eq!(list.edges[0], session21); assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1); + + // Try the scope filter. We should get all sessions with the "openid" scope + let scope = Scope::from_iter([OPENID]); + let filter = OAuth2SessionFilter::new().with_scope(&scope); + let list = repo + .oauth2_session() + .list(filter, pagination) + .await + .unwrap(); + assert!(!list.has_next_page); + assert_eq!(list.edges.len(), 4); + assert_eq!(list.edges[0], session11); + assert_eq!(list.edges[1], session12); + assert_eq!(list.edges[2], session21); + assert_eq!(list.edges[3], session22); + assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 4); + + // We should get all sessions with the "openid" and "email" scope + let scope = Scope::from_iter([OPENID, EMAIL]); + let filter = OAuth2SessionFilter::new().with_scope(&scope); + let list = repo + .oauth2_session() + .list(filter, pagination) + .await + .unwrap(); + assert!(!list.has_next_page); + assert_eq!(list.edges.len(), 2); + assert_eq!(list.edges[0], session11); + assert_eq!(list.edges[1], session12); + assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2); + + // Try combining the scope filter with the user filter + let filter = OAuth2SessionFilter::new() + .with_scope(&scope) + .for_user(&user1); + let list = repo + .oauth2_session() + .list(filter, pagination) + .await + .unwrap(); + assert_eq!(list.edges.len(), 1); + assert_eq!(list.edges[0], session11); + assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1); } } diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index ad144f06..6a7713da 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -21,7 +21,7 @@ use mas_storage::{ }; use oauth2_types::scope::{Scope, ScopeToken}; use rand::RngCore; -use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query}; +use sea_query::{enum_def, extension::postgres::PgExpr, Expr, PostgresQueryBuilder, Query}; use sea_query_binder::SqlxBinder; use sqlx::PgConnection; use ulid::Ulid; @@ -288,6 +288,10 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() } })) + .and_where_option(filter.scope().map(|scope| { + let scope: Vec = scope.iter().map(|s| s.as_str().to_owned()).collect(); + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope) + })) .generate_pagination( (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId), pagination, @@ -345,6 +349,10 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() } })) + .and_where_option(filter.scope().map(|scope| { + let scope: Vec = scope.iter().map(|s| s.as_str().to_owned()).collect(); + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope) + })) .build_sqlx(PostgresQueryBuilder); let count: i64 = sqlx::query_scalar_with(&sql, arguments) diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index ad20710c..a3247f55 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -42,6 +42,7 @@ pub struct OAuth2SessionFilter<'a> { user: Option<&'a User>, client: Option<&'a Client>, state: Option, + scope: Option<&'a Scope>, } impl<'a> OAuth2SessionFilter<'a> { @@ -102,6 +103,21 @@ impl<'a> OAuth2SessionFilter<'a> { pub fn state(&self) -> Option { self.state } + + /// Only return sessions with the given scope + #[must_use] + pub fn with_scope(mut self, scope: &'a Scope) -> Self { + self.scope = Some(scope); + self + } + + /// Get the scope filter + /// + /// Returns [`None`] if no scope filter was set + #[must_use] + pub fn scope(&self) -> Option<&Scope> { + self.scope + } } /// An [`OAuth2SessionRepository`] helps interacting with [`Session`]