diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 25e520ad..e0e1dc23 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -18,7 +18,7 @@ use async_graphql::{ }; use chrono::{DateTime, Utc}; use mas_storage::{ - compat::{CompatSessionFilter, CompatSsoLoginRepository}, + compat::{CompatSessionFilter, CompatSsoLoginFilter, CompatSsoLoginRepository}, oauth2::OAuth2SessionRepository, upstream_oauth2::UpstreamOAuthLinkRepository, user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailRepository}, @@ -98,7 +98,7 @@ impl User { before: Option, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, - ) -> Result, async_graphql::Error> { + ) -> Result, async_graphql::Error> { let state = ctx.state(); let mut repo = state.repository().await?; @@ -116,14 +116,24 @@ impl User { .transpose()?; let pagination = Pagination::try_new(before_id, after_id, first, last)?; - let page = repo - .compat_sso_login() - .list_paginated(&self.0, pagination) - .await?; + let filter = CompatSsoLoginFilter::new().for_user(&self.0); + + let page = repo.compat_sso_login().list(filter, pagination).await?; + + // Preload the total count if requested + let count = if ctx.look_ahead().field("totalCount").exists() { + Some(repo.compat_sso_login().count(filter).await?) + } else { + None + }; repo.cancel().await?; - let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + let mut connection = Connection::with_additional_fields( + page.has_previous_page, + page.has_next_page, + PreloadedTotalCount(count), + ); connection.edges.extend(page.edges.into_iter().map(|u| { Edge::new( OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)), diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index c926e117..12da6159 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -33,7 +33,7 @@ mod tests { clock::MockClock, compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter, - CompatSessionRepository, + CompatSessionRepository, CompatSsoLoginFilter, }, user::UserRepository, Clock, Pagination, Repository, RepositoryAccess, @@ -494,6 +494,19 @@ mod tests { let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap(); assert_eq!(login, None); + let all = CompatSsoLoginFilter::new(); + let for_user = all.for_user(&user); + let pending = all.pending_only(); + let fulfilled = all.fulfilled_only(); + let exchanged = all.exchanged_only(); + + // Check the initial counts + assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 0); + assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0); + assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0); + assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0); + assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0); + // Lookup an unknown login token let login = repo .compat_sso_login() @@ -515,6 +528,13 @@ mod tests { .unwrap(); assert!(login.is_pending()); + // Check the counts + assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1); + assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0); + assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 1); + assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0); + assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0); + // Lookup the login by ID let login_lookup = repo .compat_sso_login() @@ -557,6 +577,13 @@ mod tests { .unwrap(); assert!(login.is_fulfilled()); + // Check the counts + assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1); + assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1); + assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0); + assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 1); + assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0); + // Fulfilling again should not work // Note: It should also not poison the SQL transaction let res = repo @@ -573,6 +600,13 @@ mod tests { .unwrap(); assert!(login.is_exchanged()); + // Check the counts + assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1); + assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1); + assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0); + assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0); + assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 1); + // Exchange again should not work // Note: It should also not poison the SQL transaction let res = repo @@ -589,13 +623,47 @@ mod tests { .await; assert!(res.is_err()); + let pagination = Pagination::first(10); + + // List all logins + let logins = repo.compat_sso_login().list(all, pagination).await.unwrap(); + assert!(!logins.has_next_page); + assert_eq!(logins.edges, &[login.clone()]); + // List the logins for the user let logins = repo .compat_sso_login() - .list_paginated(&user, Pagination::first(10)) + .list(for_user, pagination) .await .unwrap(); assert!(!logins.has_next_page); - assert_eq!(logins.edges, vec![login]); + assert_eq!(logins.edges, &[login.clone()]); + + // List only the pending logins for the user + let logins = repo + .compat_sso_login() + .list(for_user.pending_only(), pagination) + .await + .unwrap(); + assert!(!logins.has_next_page); + assert!(logins.edges.is_empty()); + + // List only the fulfilled logins for the user + let logins = repo + .compat_sso_login() + .list(for_user.fulfilled_only(), pagination) + .await + .unwrap(); + assert!(!logins.has_next_page); + assert!(logins.edges.is_empty()); + + // List only the exchanged logins for the user + let logins = repo + .compat_sso_login() + .list(for_user.exchanged_only(), pagination) + .await + .unwrap(); + assert!(!logins.has_next_page); + assert_eq!(logins.edges, &[login]); } } diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs index 006b0560..91b8d33a 100644 --- a/crates/storage-pg/src/compat/session.rs +++ b/crates/storage-pg/src/compat/session.rs @@ -323,7 +323,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { filter: CompatSessionFilter<'_>, pagination: Pagination, ) -> Result)>, Self::Error> { - let (sql, values) = sea_query::Query::select() + let (sql, values) = Query::select() .expr_as( Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)), CompatSessionAndSsoLoginLookupIden::CompatSessionId, @@ -441,7 +441,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { // session. let exists = Expr::exists( Query::select() - .expr(Expr::val(1)) + .expr(Expr::cust("1")) .from(CompatSsoLogins::Table) .and_where( Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)) diff --git a/crates/storage-pg/src/compat/sso_login.rs b/crates/storage-pg/src/compat/sso_login.rs index 1687289c..b201ed22 100644 --- a/crates/storage-pg/src/compat/sso_login.rs +++ b/crates/storage-pg/src/compat/sso_login.rs @@ -14,16 +14,24 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User}; -use mas_storage::{compat::CompatSsoLoginRepository, Clock, Page, Pagination}; +use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState}; +use mas_storage::{ + compat::{CompatSsoLoginFilter, CompatSsoLoginRepository}, + Clock, Page, Pagination, +}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; +use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query}; +use sqlx::PgConnection; use ulid::Ulid; use url::Url; use uuid::Uuid; use crate::{ - pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + iden::{CompatSessions, CompatSsoLogins}, + pagination::QueryBuilderExt, + sea_query_sqlx::map_values, + tracing::ExecuteExt, + DatabaseError, DatabaseInconsistencyError, }; /// An implementation of [`CompatSsoLoginRepository`] for a PostgreSQL @@ -41,6 +49,7 @@ impl<'c> PgCompatSsoLoginRepository<'c> { } #[derive(sqlx::FromRow)] +#[enum_def] struct CompatSsoLoginLookup { compat_sso_login_id: Uuid, login_token: String, @@ -295,49 +304,149 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { } #[tracing::instrument( - name = "db.compat_sso_login.list_paginated", + name = "db.compat_sso_login.list", skip_all, fields( db.statement, - %user.id, - %user.username, ), err )] - async fn list_paginated( + async fn list( &mut self, - user: &User, + filter: CompatSsoLoginFilter<'_>, pagination: Pagination, ) -> Result, Self::Error> { - let mut query = QueryBuilder::new( - r#" - SELECT cl.compat_sso_login_id - , cl.login_token - , cl.redirect_uri - , cl.created_at - , cl.fulfilled_at - , cl.exchanged_at - , cl.compat_session_id + let (sql, values) = Query::select() + .expr_as( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)), + CompatSsoLoginLookupIden::CompatSsoLoginId, + ) + .expr_as( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)), + CompatSsoLoginLookupIden::CompatSessionId, + ) + .expr_as( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)), + CompatSsoLoginLookupIden::LoginToken, + ) + .expr_as( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)), + CompatSsoLoginLookupIden::RedirectUri, + ) + .expr_as( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)), + CompatSsoLoginLookupIden::CreatedAt, + ) + .expr_as( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)), + CompatSsoLoginLookupIden::FulfilledAt, + ) + .expr_as( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)), + CompatSsoLoginLookupIden::ExchangedAt, + ) + .from(CompatSsoLogins::Table) + .and_where_option(filter.user().map(|user| { + Expr::exists( + Query::select() + .expr(Expr::cust("1")) + .from(CompatSessions::Table) + .and_where( + Expr::col((CompatSessions::Table, CompatSessions::UserId)) + .eq(Uuid::from(user.id)), + ) + .and_where( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)) + .equals((CompatSessions::Table, CompatSessions::CompatSessionId)), + ) + .take(), + ) + })) + .and_where_option(filter.state().map(|state| { + if state.is_exchanged() { + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)).is_not_null() + } else if state.is_fulfilled() { + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)) + .is_not_null() + .and( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)) + .is_null(), + ) + } else { + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)).is_null() + } + })) + .generate_pagination( + (CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId).into_column_ref(), + pagination, + ) + .build(PostgresQueryBuilder); - FROM compat_sso_logins cl - INNER JOIN compat_sessions cs USING (compat_session_id) - "#, - ); + let arguments = map_values(values); - query - .push(" WHERE cs.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("cl.compat_sso_login_id", pagination); - - let edges: Vec = query - .build_query_as() + let edges: Vec = sqlx::query_as_with(&sql, arguments) .traced() .fetch_all(&mut *self.conn) .await?; - let page = pagination - .process(edges) - .try_map(CompatSsoLogin::try_from)?; + let page = pagination.process(edges).try_map(TryFrom::try_from)?; + Ok(page) } + + #[tracing::instrument( + name = "db.compat_sso_login.count", + skip_all, + fields( + db.statement, + ), + err + )] + async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result { + let (sql, values) = Query::select() + .expr(Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)).count()) + .from(CompatSsoLogins::Table) + .and_where_option(filter.user().map(|user| { + Expr::exists( + Query::select() + .expr(Expr::cust("1")) + .from(CompatSessions::Table) + .and_where( + Expr::col((CompatSessions::Table, CompatSessions::UserId)) + .eq(Uuid::from(user.id)), + ) + .and_where( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)) + .equals((CompatSessions::Table, CompatSessions::CompatSessionId)), + ) + .take(), + ) + })) + .and_where_option(filter.state().map(|state| { + if state.is_exchanged() { + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)).is_not_null() + } else if state.is_fulfilled() { + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)) + .is_not_null() + .and( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)) + .is_null(), + ) + } else { + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)).is_null() + } + })) + .build(PostgresQueryBuilder); + + let arguments = map_values(values); + + let count: i64 = sqlx::query_scalar_with(&sql, arguments) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + count + .try_into() + .map_err(DatabaseError::to_invalid_operation) + } } diff --git a/crates/storage/src/compat/mod.rs b/crates/storage/src/compat/mod.rs index f1f5e2f8..b064d090 100644 --- a/crates/storage/src/compat/mod.rs +++ b/crates/storage/src/compat/mod.rs @@ -23,5 +23,5 @@ pub use self::{ access_token::CompatAccessTokenRepository, refresh_token::CompatRefreshTokenRepository, session::{CompatSessionFilter, CompatSessionRepository}, - sso_login::CompatSsoLoginRepository, + sso_login::{CompatSsoLoginFilter, CompatSsoLoginRepository}, }; diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index 6f634d2a..dd2b39cd 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -59,7 +59,7 @@ impl CompatSessionType { } } -/// Filter parameters for listing browser sessions +/// Filter parameters for listing compatibility sessions #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] pub struct CompatSessionFilter<'a> { user: Option<&'a User>, diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index 7c823d62..0782f01b 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -20,6 +20,88 @@ use url::Url; use crate::{pagination::Page, repository_impl, Clock, Pagination}; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum CompatSsoLoginState { + Pending, + Fulfilled, + Exchanged, +} + +impl CompatSsoLoginState { + /// Returns [`true`] if we're looking for pending SSO logins + #[must_use] + pub fn is_pending(self) -> bool { + matches!(self, Self::Pending) + } + + /// Returns [`true`] if we're looking for fulfilled SSO logins + #[must_use] + pub fn is_fulfilled(self) -> bool { + matches!(self, Self::Fulfilled) + } + + /// Returns [`true`] if we're looking for exchanged SSO logins + #[must_use] + pub fn is_exchanged(self) -> bool { + matches!(self, Self::Exchanged) + } +} + +/// Filter parameters for listing compat SSO logins +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] +pub struct CompatSsoLoginFilter<'a> { + user: Option<&'a User>, + state: Option, +} + +impl<'a> CompatSsoLoginFilter<'a> { + /// Create a new empty filter + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Set the user who owns the SSO logins sessions + #[must_use] + pub fn for_user(mut self, user: &'a User) -> Self { + self.user = Some(user); + self + } + + /// Get the user filter + #[must_use] + pub fn user(&self) -> Option<&User> { + self.user + } + + /// Only return pending SSO logins + #[must_use] + pub fn pending_only(mut self) -> Self { + self.state = Some(CompatSsoLoginState::Pending); + self + } + + /// Only return fulfilled SSO logins + #[must_use] + pub fn fulfilled_only(mut self) -> Self { + self.state = Some(CompatSsoLoginState::Fulfilled); + self + } + + /// Only return exchanged SSO logins + #[must_use] + pub fn exchanged_only(mut self) -> Self { + self.state = Some(CompatSsoLoginState::Exchanged); + self + } + + /// Get the state filter + #[must_use] + pub fn state(&self) -> Option { + self.state + } +} + /// A [`CompatSsoLoginRepository`] helps interacting with /// [`CompatSsoLoginRepository`] saved in the storage backend #[async_trait] @@ -117,21 +199,34 @@ pub trait CompatSsoLoginRepository: Send + Sync { compat_sso_login: CompatSsoLogin, ) -> Result; - /// Get a paginated list of compat SSO logins for a user + /// List [`CompatSsoLogin`] with the given filter and pagination + /// + /// Returns a page of compat SSO logins /// /// # Parameters /// - /// * `user`: The user to get the compat SSO logins for + /// * `filter`: The filter to apply /// * `pagination`: The pagination parameters /// /// # Errors /// /// Returns [`Self::Error`] if the underlying repository fails - async fn list_paginated( + async fn list( &mut self, - user: &User, + filter: CompatSsoLoginFilter<'_>, pagination: Pagination, ) -> Result, Self::Error>; + + /// Count the number of [`CompatSsoLogin`] with the given filter + /// + /// # Parameters + /// + /// * `filter`: The filter to apply + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result; } repository_impl!(CompatSsoLoginRepository: @@ -163,9 +258,11 @@ repository_impl!(CompatSsoLoginRepository: compat_sso_login: CompatSsoLogin, ) -> Result; - async fn list_paginated( + async fn list( &mut self, - user: &User, + filter: CompatSsoLoginFilter<'_>, pagination: Pagination, ) -> Result, Self::Error>; + + async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result; ); diff --git a/frontend/schema.graphql b/frontend/schema.graphql index 44d00611..6a2af24f 100644 --- a/frontend/schema.graphql +++ b/frontend/schema.graphql @@ -273,6 +273,10 @@ type CompatSsoLoginConnection { A list of nodes. """ nodes: [CompatSsoLogin!]! + """ + Identifies the total count of items in the connection. + """ + totalCount: Int! } """ diff --git a/frontend/src/gql/graphql.ts b/frontend/src/gql/graphql.ts index 1ed92af0..27f00a45 100644 --- a/frontend/src/gql/graphql.ts +++ b/frontend/src/gql/graphql.ts @@ -216,6 +216,8 @@ export type CompatSsoLoginConnection = { nodes: Array; /** Information to aid in pagination. */ pageInfo: PageInfo; + /** Identifies the total count of items in the connection. */ + totalCount: Scalars["Int"]["output"]; }; /** An edge in a connection. */ diff --git a/frontend/src/gql/schema.ts b/frontend/src/gql/schema.ts index 7bf275f7..382741bc 100644 --- a/frontend/src/gql/schema.ts +++ b/frontend/src/gql/schema.ts @@ -557,6 +557,17 @@ export default { }, args: [], }, + { + name: "totalCount", + type: { + kind: "NON_NULL", + ofType: { + kind: "SCALAR", + name: "Any", + }, + }, + args: [], + }, ], interfaces: [], },