1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Better SSO login pagination and filtering

This commit is contained in:
Quentin Gliech
2023-07-21 12:23:44 +02:00
parent 24b29498a7
commit 12ad572db8
10 changed files with 353 additions and 52 deletions

View File

@@ -18,7 +18,7 @@ use async_graphql::{
}; };
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_storage::{ use mas_storage::{
compat::{CompatSessionFilter, CompatSsoLoginRepository}, compat::{CompatSessionFilter, CompatSsoLoginFilter, CompatSsoLoginRepository},
oauth2::OAuth2SessionRepository, oauth2::OAuth2SessionRepository,
upstream_oauth2::UpstreamOAuthLinkRepository, upstream_oauth2::UpstreamOAuthLinkRepository,
user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailRepository}, user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailRepository},
@@ -98,7 +98,7 @@ impl User {
before: Option<String>, before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, CompatSsoLogin>, async_graphql::Error> { ) -> Result<Connection<Cursor, CompatSsoLogin, PreloadedTotalCount>, async_graphql::Error> {
let state = ctx.state(); let state = ctx.state();
let mut repo = state.repository().await?; let mut repo = state.repository().await?;
@@ -116,14 +116,24 @@ impl User {
.transpose()?; .transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?; let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let page = repo let filter = CompatSsoLoginFilter::new().for_user(&self.0);
.compat_sso_login()
.list_paginated(&self.0, pagination) let page = repo.compat_sso_login().list(filter, pagination).await?;
.await?;
// Preload the total count if requested
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.compat_sso_login().count(filter).await?)
} else {
None
};
repo.cancel().await?; repo.cancel().await?;
let mut connection = Connection::new(page.has_previous_page, page.has_next_page); let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection.edges.extend(page.edges.into_iter().map(|u| { connection.edges.extend(page.edges.into_iter().map(|u| {
Edge::new( Edge::new(
OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)), OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)),

View File

@@ -33,7 +33,7 @@ mod tests {
clock::MockClock, clock::MockClock,
compat::{ compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter, CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter,
CompatSessionRepository, CompatSessionRepository, CompatSsoLoginFilter,
}, },
user::UserRepository, user::UserRepository,
Clock, Pagination, Repository, RepositoryAccess, Clock, Pagination, Repository, RepositoryAccess,
@@ -494,6 +494,19 @@ mod tests {
let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap(); let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap();
assert_eq!(login, None); assert_eq!(login, None);
let all = CompatSsoLoginFilter::new();
let for_user = all.for_user(&user);
let pending = all.pending_only();
let fulfilled = all.fulfilled_only();
let exchanged = all.exchanged_only();
// Check the initial counts
assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
// Lookup an unknown login token // Lookup an unknown login token
let login = repo let login = repo
.compat_sso_login() .compat_sso_login()
@@ -515,6 +528,13 @@ mod tests {
.unwrap(); .unwrap();
assert!(login.is_pending()); assert!(login.is_pending());
// Check the counts
assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
// Lookup the login by ID // Lookup the login by ID
let login_lookup = repo let login_lookup = repo
.compat_sso_login() .compat_sso_login()
@@ -557,6 +577,13 @@ mod tests {
.unwrap(); .unwrap();
assert!(login.is_fulfilled()); assert!(login.is_fulfilled());
// Check the counts
assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
// Fulfilling again should not work // Fulfilling again should not work
// Note: It should also not poison the SQL transaction // Note: It should also not poison the SQL transaction
let res = repo let res = repo
@@ -573,6 +600,13 @@ mod tests {
.unwrap(); .unwrap();
assert!(login.is_exchanged()); assert!(login.is_exchanged());
// Check the counts
assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 1);
// Exchange again should not work // Exchange again should not work
// Note: It should also not poison the SQL transaction // Note: It should also not poison the SQL transaction
let res = repo let res = repo
@@ -589,13 +623,47 @@ mod tests {
.await; .await;
assert!(res.is_err()); assert!(res.is_err());
let pagination = Pagination::first(10);
// List all logins
let logins = repo.compat_sso_login().list(all, pagination).await.unwrap();
assert!(!logins.has_next_page);
assert_eq!(logins.edges, &[login.clone()]);
// List the logins for the user // List the logins for the user
let logins = repo let logins = repo
.compat_sso_login() .compat_sso_login()
.list_paginated(&user, Pagination::first(10)) .list(for_user, pagination)
.await .await
.unwrap(); .unwrap();
assert!(!logins.has_next_page); assert!(!logins.has_next_page);
assert_eq!(logins.edges, vec![login]); assert_eq!(logins.edges, &[login.clone()]);
// List only the pending logins for the user
let logins = repo
.compat_sso_login()
.list(for_user.pending_only(), pagination)
.await
.unwrap();
assert!(!logins.has_next_page);
assert!(logins.edges.is_empty());
// List only the fulfilled logins for the user
let logins = repo
.compat_sso_login()
.list(for_user.fulfilled_only(), pagination)
.await
.unwrap();
assert!(!logins.has_next_page);
assert!(logins.edges.is_empty());
// List only the exchanged logins for the user
let logins = repo
.compat_sso_login()
.list(for_user.exchanged_only(), pagination)
.await
.unwrap();
assert!(!logins.has_next_page);
assert_eq!(logins.edges, &[login]);
} }
} }

View File

@@ -323,7 +323,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
filter: CompatSessionFilter<'_>, filter: CompatSessionFilter<'_>,
pagination: Pagination, pagination: Pagination,
) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> { ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
let (sql, values) = sea_query::Query::select() let (sql, values) = Query::select()
.expr_as( .expr_as(
Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)), Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
CompatSessionAndSsoLoginLookupIden::CompatSessionId, CompatSessionAndSsoLoginLookupIden::CompatSessionId,
@@ -441,7 +441,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
// session. // session.
let exists = Expr::exists( let exists = Expr::exists(
Query::select() Query::select()
.expr(Expr::val(1)) .expr(Expr::cust("1"))
.from(CompatSsoLogins::Table) .from(CompatSsoLogins::Table)
.and_where( .and_where(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)) Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId))

View File

@@ -14,16 +14,24 @@
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User}; use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState};
use mas_storage::{compat::CompatSsoLoginRepository, Clock, Page, Pagination}; use mas_storage::{
compat::{CompatSsoLoginFilter, CompatSsoLoginRepository},
Clock, Page, Pagination,
};
use rand::RngCore; use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder}; use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query};
use sqlx::PgConnection;
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, iden::{CompatSessions, CompatSsoLogins},
pagination::QueryBuilderExt,
sea_query_sqlx::map_values,
tracing::ExecuteExt,
DatabaseError, DatabaseInconsistencyError,
}; };
/// An implementation of [`CompatSsoLoginRepository`] for a PostgreSQL /// An implementation of [`CompatSsoLoginRepository`] for a PostgreSQL
@@ -41,6 +49,7 @@ impl<'c> PgCompatSsoLoginRepository<'c> {
} }
#[derive(sqlx::FromRow)] #[derive(sqlx::FromRow)]
#[enum_def]
struct CompatSsoLoginLookup { struct CompatSsoLoginLookup {
compat_sso_login_id: Uuid, compat_sso_login_id: Uuid,
login_token: String, login_token: String,
@@ -295,49 +304,149 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
} }
#[tracing::instrument( #[tracing::instrument(
name = "db.compat_sso_login.list_paginated", name = "db.compat_sso_login.list",
skip_all, skip_all,
fields( fields(
db.statement, db.statement,
%user.id,
%user.username,
), ),
err err
)] )]
async fn list_paginated( async fn list(
&mut self, &mut self,
user: &User, filter: CompatSsoLoginFilter<'_>,
pagination: Pagination, pagination: Pagination,
) -> Result<Page<CompatSsoLogin>, Self::Error> { ) -> Result<Page<CompatSsoLogin>, Self::Error> {
let mut query = QueryBuilder::new( let (sql, values) = Query::select()
r#" .expr_as(
SELECT cl.compat_sso_login_id Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
, cl.login_token CompatSsoLoginLookupIden::CompatSsoLoginId,
, cl.redirect_uri )
, cl.created_at .expr_as(
, cl.fulfilled_at Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
, cl.exchanged_at CompatSsoLoginLookupIden::CompatSessionId,
, cl.compat_session_id )
.expr_as(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
CompatSsoLoginLookupIden::LoginToken,
)
.expr_as(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
CompatSsoLoginLookupIden::RedirectUri,
)
.expr_as(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
CompatSsoLoginLookupIden::CreatedAt,
)
.expr_as(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
CompatSsoLoginLookupIden::FulfilledAt,
)
.expr_as(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
CompatSsoLoginLookupIden::ExchangedAt,
)
.from(CompatSsoLogins::Table)
.and_where_option(filter.user().map(|user| {
Expr::exists(
Query::select()
.expr(Expr::cust("1"))
.from(CompatSessions::Table)
.and_where(
Expr::col((CompatSessions::Table, CompatSessions::UserId))
.eq(Uuid::from(user.id)),
)
.and_where(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId))
.equals((CompatSessions::Table, CompatSessions::CompatSessionId)),
)
.take(),
)
}))
.and_where_option(filter.state().map(|state| {
if state.is_exchanged() {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)).is_not_null()
} else if state.is_fulfilled() {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt))
.is_not_null()
.and(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt))
.is_null(),
)
} else {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)).is_null()
}
}))
.generate_pagination(
(CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId).into_column_ref(),
pagination,
)
.build(PostgresQueryBuilder);
FROM compat_sso_logins cl let arguments = map_values(values);
INNER JOIN compat_sessions cs USING (compat_session_id)
"#,
);
query let edges: Vec<CompatSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
.push(" WHERE cs.user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("cl.compat_sso_login_id", pagination);
let edges: Vec<CompatSsoLoginLookup> = query
.build_query_as()
.traced() .traced()
.fetch_all(&mut *self.conn) .fetch_all(&mut *self.conn)
.await?; .await?;
let page = pagination let page = pagination.process(edges).try_map(TryFrom::try_from)?;
.process(edges)
.try_map(CompatSsoLogin::try_from)?;
Ok(page) Ok(page)
} }
#[tracing::instrument(
name = "db.compat_sso_login.count",
skip_all,
fields(
db.statement,
),
err
)]
async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result<usize, Self::Error> {
let (sql, values) = Query::select()
.expr(Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)).count())
.from(CompatSsoLogins::Table)
.and_where_option(filter.user().map(|user| {
Expr::exists(
Query::select()
.expr(Expr::cust("1"))
.from(CompatSessions::Table)
.and_where(
Expr::col((CompatSessions::Table, CompatSessions::UserId))
.eq(Uuid::from(user.id)),
)
.and_where(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId))
.equals((CompatSessions::Table, CompatSessions::CompatSessionId)),
)
.take(),
)
}))
.and_where_option(filter.state().map(|state| {
if state.is_exchanged() {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)).is_not_null()
} else if state.is_fulfilled() {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt))
.is_not_null()
.and(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt))
.is_null(),
)
} else {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)).is_null()
}
}))
.build(PostgresQueryBuilder);
let arguments = map_values(values);
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
.traced()
.fetch_one(&mut *self.conn)
.await?;
count
.try_into()
.map_err(DatabaseError::to_invalid_operation)
}
} }

View File

@@ -23,5 +23,5 @@ pub use self::{
access_token::CompatAccessTokenRepository, access_token::CompatAccessTokenRepository,
refresh_token::CompatRefreshTokenRepository, refresh_token::CompatRefreshTokenRepository,
session::{CompatSessionFilter, CompatSessionRepository}, session::{CompatSessionFilter, CompatSessionRepository},
sso_login::CompatSsoLoginRepository, sso_login::{CompatSsoLoginFilter, CompatSsoLoginRepository},
}; };

View File

@@ -59,7 +59,7 @@ impl CompatSessionType {
} }
} }
/// Filter parameters for listing browser sessions /// Filter parameters for listing compatibility sessions
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct CompatSessionFilter<'a> { pub struct CompatSessionFilter<'a> {
user: Option<&'a User>, user: Option<&'a User>,

View File

@@ -20,6 +20,88 @@ use url::Url;
use crate::{pagination::Page, repository_impl, Clock, Pagination}; use crate::{pagination::Page, repository_impl, Clock, Pagination};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompatSsoLoginState {
Pending,
Fulfilled,
Exchanged,
}
impl CompatSsoLoginState {
/// Returns [`true`] if we're looking for pending SSO logins
#[must_use]
pub fn is_pending(self) -> bool {
matches!(self, Self::Pending)
}
/// Returns [`true`] if we're looking for fulfilled SSO logins
#[must_use]
pub fn is_fulfilled(self) -> bool {
matches!(self, Self::Fulfilled)
}
/// Returns [`true`] if we're looking for exchanged SSO logins
#[must_use]
pub fn is_exchanged(self) -> bool {
matches!(self, Self::Exchanged)
}
}
/// Filter parameters for listing compat SSO logins
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct CompatSsoLoginFilter<'a> {
user: Option<&'a User>,
state: Option<CompatSsoLoginState>,
}
impl<'a> CompatSsoLoginFilter<'a> {
/// Create a new empty filter
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Set the user who owns the SSO logins sessions
#[must_use]
pub fn for_user(mut self, user: &'a User) -> Self {
self.user = Some(user);
self
}
/// Get the user filter
#[must_use]
pub fn user(&self) -> Option<&User> {
self.user
}
/// Only return pending SSO logins
#[must_use]
pub fn pending_only(mut self) -> Self {
self.state = Some(CompatSsoLoginState::Pending);
self
}
/// Only return fulfilled SSO logins
#[must_use]
pub fn fulfilled_only(mut self) -> Self {
self.state = Some(CompatSsoLoginState::Fulfilled);
self
}
/// Only return exchanged SSO logins
#[must_use]
pub fn exchanged_only(mut self) -> Self {
self.state = Some(CompatSsoLoginState::Exchanged);
self
}
/// Get the state filter
#[must_use]
pub fn state(&self) -> Option<CompatSsoLoginState> {
self.state
}
}
/// A [`CompatSsoLoginRepository`] helps interacting with /// A [`CompatSsoLoginRepository`] helps interacting with
/// [`CompatSsoLoginRepository`] saved in the storage backend /// [`CompatSsoLoginRepository`] saved in the storage backend
#[async_trait] #[async_trait]
@@ -117,21 +199,34 @@ pub trait CompatSsoLoginRepository: Send + Sync {
compat_sso_login: CompatSsoLogin, compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, Self::Error>; ) -> Result<CompatSsoLogin, Self::Error>;
/// Get a paginated list of compat SSO logins for a user /// List [`CompatSsoLogin`] with the given filter and pagination
///
/// Returns a page of compat SSO logins
/// ///
/// # Parameters /// # Parameters
/// ///
/// * `user`: The user to get the compat SSO logins for /// * `filter`: The filter to apply
/// * `pagination`: The pagination parameters /// * `pagination`: The pagination parameters
/// ///
/// # Errors /// # Errors
/// ///
/// Returns [`Self::Error`] if the underlying repository fails /// Returns [`Self::Error`] if the underlying repository fails
async fn list_paginated( async fn list(
&mut self, &mut self,
user: &User, filter: CompatSsoLoginFilter<'_>,
pagination: Pagination, pagination: Pagination,
) -> Result<Page<CompatSsoLogin>, Self::Error>; ) -> Result<Page<CompatSsoLogin>, Self::Error>;
/// Count the number of [`CompatSsoLogin`] with the given filter
///
/// # Parameters
///
/// * `filter`: The filter to apply
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result<usize, Self::Error>;
} }
repository_impl!(CompatSsoLoginRepository: repository_impl!(CompatSsoLoginRepository:
@@ -163,9 +258,11 @@ repository_impl!(CompatSsoLoginRepository:
compat_sso_login: CompatSsoLogin, compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, Self::Error>; ) -> Result<CompatSsoLogin, Self::Error>;
async fn list_paginated( async fn list(
&mut self, &mut self,
user: &User, filter: CompatSsoLoginFilter<'_>,
pagination: Pagination, pagination: Pagination,
) -> Result<Page<CompatSsoLogin>, Self::Error>; ) -> Result<Page<CompatSsoLogin>, Self::Error>;
async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result<usize, Self::Error>;
); );

View File

@@ -273,6 +273,10 @@ type CompatSsoLoginConnection {
A list of nodes. A list of nodes.
""" """
nodes: [CompatSsoLogin!]! nodes: [CompatSsoLogin!]!
"""
Identifies the total count of items in the connection.
"""
totalCount: Int!
} }
""" """

View File

@@ -216,6 +216,8 @@ export type CompatSsoLoginConnection = {
nodes: Array<CompatSsoLogin>; nodes: Array<CompatSsoLogin>;
/** Information to aid in pagination. */ /** Information to aid in pagination. */
pageInfo: PageInfo; pageInfo: PageInfo;
/** Identifies the total count of items in the connection. */
totalCount: Scalars["Int"]["output"];
}; };
/** An edge in a connection. */ /** An edge in a connection. */

View File

@@ -557,6 +557,17 @@ export default {
}, },
args: [], args: [],
}, },
{
name: "totalCount",
type: {
kind: "NON_NULL",
ofType: {
kind: "SCALAR",
name: "Any",
},
},
args: [],
},
], ],
interfaces: [], interfaces: [],
}, },