1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Better user emails pagination and filtering

This commit is contained in:
Quentin Gliech
2023-07-21 15:31:55 +02:00
parent 12ad572db8
commit a75a53cc24
9 changed files with 266 additions and 74 deletions

View File

@ -14,14 +14,14 @@
use async_graphql::{ use async_graphql::{
connection::{query, Connection, Edge, OpaqueCursor}, connection::{query, Connection, Edge, OpaqueCursor},
Context, Description, Object, ID, Context, Description, Enum, Object, ID,
}; };
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_storage::{ use mas_storage::{
compat::{CompatSessionFilter, CompatSsoLoginFilter, CompatSsoLoginRepository}, compat::{CompatSessionFilter, CompatSsoLoginFilter, CompatSsoLoginRepository},
oauth2::OAuth2SessionRepository, oauth2::OAuth2SessionRepository,
upstream_oauth2::UpstreamOAuthLinkRepository, upstream_oauth2::UpstreamOAuthLinkRepository,
user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailRepository}, user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository},
Pagination, RepositoryAccess, Pagination, RepositoryAccess,
}; };
@ -300,13 +300,16 @@ impl User {
&self, &self,
ctx: &Context<'_>, ctx: &Context<'_>,
#[graphql(name = "state", desc = "List only emails in the given state.")]
state_param: Option<UserEmailState>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")] #[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>, after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")] #[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>, before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UserEmail, UserEmailsPagination>, async_graphql::Error> { ) -> Result<Connection<Cursor, UserEmail, PreloadedTotalCount>, async_graphql::Error> {
let state = ctx.state(); let state = ctx.state();
let mut repo = state.repository().await?; let mut repo = state.repository().await?;
@ -324,17 +327,29 @@ impl User {
.transpose()?; .transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?; let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let page = repo let filter = UserEmailFilter::new().for_user(&self.0);
.user_email()
.list_paginated(&self.0, pagination) let filter = match state_param {
.await?; Some(UserEmailState::Pending) => filter.pending_only(),
Some(UserEmailState::Confirmed) => filter.verified_only(),
None => filter,
};
let page = repo.user_email().list(filter, pagination).await?;
// Preload the total count if requested
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.user_email().count(filter).await?)
} else {
None
};
repo.cancel().await?; repo.cancel().await?;
let mut connection = Connection::with_additional_fields( let mut connection = Connection::with_additional_fields(
page.has_previous_page, page.has_previous_page,
page.has_next_page, page.has_next_page,
UserEmailsPagination(self.0.clone()), PreloadedTotalCount(count),
); );
connection.edges.extend(page.edges.into_iter().map(|u| { connection.edges.extend(page.edges.into_iter().map(|u| {
Edge::new( Edge::new(
@ -493,16 +508,12 @@ impl UserEmail {
} }
} }
pub struct UserEmailsPagination(mas_data_model::User); /// The state of a compatibility session.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum UserEmailState {
/// The email address is pending confirmation.
Pending,
#[Object] /// The email address has been confirmed.
impl UserEmailsPagination { Confirmed,
/// Identifies the total count of items in the connection.
async fn total_count(&self, ctx: &Context<'_>) -> Result<usize, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let count = repo.user_email().count(&self.0).await?;
repo.cancel().await?;
Ok(count)
}
} }

View File

@ -31,6 +31,16 @@ pub enum Users {
PrimaryUserEmailId, PrimaryUserEmailId,
} }
#[derive(sea_query::Iden)]
pub enum UserEmails {
Table,
UserEmailId,
UserId,
Email,
CreatedAt,
ConfirmedAt,
}
#[derive(sea_query::Iden)] #[derive(sea_query::Iden)]
pub enum CompatSessions { pub enum CompatSessions {
Table, Table,

View File

@ -15,15 +15,20 @@
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState}; use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState};
use mas_storage::{user::UserEmailRepository, Clock, Page, Pagination}; use mas_storage::{
user::{UserEmailFilter, UserEmailRepository},
Clock, Page, Pagination,
};
use rand::RngCore; use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder}; use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query};
use sqlx::PgConnection;
use tracing::{info_span, Instrument}; use tracing::{info_span, Instrument};
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, iden::UserEmails, pagination::QueryBuilderExt, sea_query_sqlx::map_values, tracing::ExecuteExt,
DatabaseError, DatabaseInconsistencyError,
}; };
/// An implementation of [`UserEmailRepository`] for a PostgreSQL connection /// An implementation of [`UserEmailRepository`] for a PostgreSQL connection
@ -40,6 +45,7 @@ impl<'c> PgUserEmailRepository<'c> {
} }
#[derive(Debug, Clone, sqlx::FromRow)] #[derive(Debug, Clone, sqlx::FromRow)]
#[enum_def]
struct UserEmailLookup { struct UserEmailLookup {
user_email_id: Uuid, user_email_id: Uuid,
user_id: Uuid, user_id: Uuid,
@ -225,42 +231,65 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> {
} }
#[tracing::instrument( #[tracing::instrument(
name = "db.user_email.list_paginated", name = "db.user_email.list",
skip_all, skip_all,
fields( fields(
db.statement, db.statement,
%user.id,
), ),
err, err,
)] )]
async fn list_paginated( async fn list(
&mut self, &mut self,
user: &User, filter: UserEmailFilter<'_>,
pagination: Pagination, pagination: Pagination,
) -> Result<Page<UserEmail>, DatabaseError> { ) -> Result<Page<UserEmail>, DatabaseError> {
let mut query = QueryBuilder::new( let (sql, values) = Query::select()
r#" .expr_as(
SELECT user_email_id Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
, user_id UserEmailLookupIden::UserEmailId,
, email )
, created_at .expr_as(
, confirmed_at Expr::col((UserEmails::Table, UserEmails::UserId)),
FROM user_emails UserEmailLookupIden::UserId,
"#, )
); .expr_as(
Expr::col((UserEmails::Table, UserEmails::Email)),
UserEmailLookupIden::Email,
)
.expr_as(
Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
UserEmailLookupIden::CreatedAt,
)
.expr_as(
Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)),
UserEmailLookupIden::ConfirmedAt,
)
.from(UserEmails::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_verified() {
Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)).is_not_null()
} else {
Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)).is_null()
}
}))
.generate_pagination(
(UserEmails::Table, UserEmails::UserEmailId).into_column_ref(),
pagination,
)
.build(PostgresQueryBuilder);
query let arguments = map_values(values);
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("user_email_id", pagination);
let edges: Vec<UserEmailLookup> = query let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
.build_query_as()
.traced() .traced()
.fetch_all(&mut *self.conn) .fetch_all(&mut *self.conn)
.await?; .await?;
let page = pagination.process(edges).map(UserEmail::from); let page = pagination.process(edges).map(UserEmail::from);
Ok(page) Ok(page)
} }
@ -269,28 +298,35 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> {
skip_all, skip_all,
fields( fields(
db.statement, db.statement,
%user.id,
), ),
err, err,
)] )]
async fn count(&mut self, user: &User) -> Result<usize, Self::Error> { async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
let res = sqlx::query_scalar!( let (sql, values) = Query::select()
r#" .expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
SELECT COUNT(*) .from(UserEmails::Table)
FROM user_emails .and_where_option(filter.user().map(|user| {
WHERE user_id = $1 Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
"#, }))
Uuid::from(user.id), .and_where_option(filter.state().map(|state| {
) if state.is_verified() {
.traced() Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)).is_not_null()
.fetch_one(&mut *self.conn) } else {
.await?; Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)).is_null()
}
}))
.build(PostgresQueryBuilder);
let res = res.unwrap_or_default(); let arguments = map_values(values);
Ok(res let count: i64 = sqlx::query_scalar_with(&sql, arguments)
.traced()
.fetch_one(&mut *self.conn)
.await?;
count
.try_into() .try_into()
.map_err(DatabaseError::to_invalid_operation)?) .map_err(DatabaseError::to_invalid_operation)
} }
#[tracing::instrument( #[tracing::instrument(

View File

@ -16,7 +16,7 @@ use chrono::Duration;
use mas_storage::{ use mas_storage::{
clock::MockClock, clock::MockClock,
user::{ user::{
BrowserSessionFilter, BrowserSessionRepository, UserEmailRepository, BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository,
UserPasswordRepository, UserRepository, UserPasswordRepository, UserRepository,
}, },
Pagination, Repository, RepositoryAccess, Pagination, Repository, RepositoryAccess,
@ -98,7 +98,14 @@ async fn test_user_email_repo(pool: PgPool) {
.unwrap() .unwrap()
.is_none()); .is_none());
assert_eq!(repo.user_email().count(&user).await.unwrap(), 0); let all = UserEmailFilter::new().for_user(&user);
let pending = all.pending_only();
let verified = all.verified_only();
// Check the counts
assert_eq!(repo.user_email().count(all).await.unwrap(), 0);
assert_eq!(repo.user_email().count(pending).await.unwrap(), 0);
assert_eq!(repo.user_email().count(verified).await.unwrap(), 0);
let user_email = repo let user_email = repo
.user_email() .user_email()
@ -110,7 +117,10 @@ async fn test_user_email_repo(pool: PgPool) {
assert_eq!(user_email.email, EMAIL); assert_eq!(user_email.email, EMAIL);
assert!(user_email.confirmed_at.is_none()); assert!(user_email.confirmed_at.is_none());
assert_eq!(repo.user_email().count(&user).await.unwrap(), 1); // Check the counts
assert_eq!(repo.user_email().count(all).await.unwrap(), 1);
assert_eq!(repo.user_email().count(pending).await.unwrap(), 1);
assert_eq!(repo.user_email().count(verified).await.unwrap(), 0);
assert!(repo assert!(repo
.user_email() .user_email()
@ -181,6 +191,11 @@ async fn test_user_email_repo(pool: PgPool) {
.await .await
.unwrap(); .unwrap();
// Check the counts
assert_eq!(repo.user_email().count(all).await.unwrap(), 1);
assert_eq!(repo.user_email().count(pending).await.unwrap(), 0);
assert_eq!(repo.user_email().count(verified).await.unwrap(), 1);
// Reload the user_email // Reload the user_email
let user_email = repo let user_email = repo
.user_email() .user_email()
@ -236,16 +251,35 @@ async fn test_user_email_repo(pool: PgPool) {
// Listing the user emails should work // Listing the user emails should work
let emails = repo let emails = repo
.user_email() .user_email()
.list_paginated(&user, Pagination::first(10)) .list(all, Pagination::first(10))
.await .await
.unwrap(); .unwrap();
assert!(!emails.has_next_page); assert!(!emails.has_next_page);
assert_eq!(emails.edges.len(), 1); assert_eq!(emails.edges.len(), 1);
assert_eq!(emails.edges[0], user_email); assert_eq!(emails.edges[0], user_email);
let emails = repo
.user_email()
.list(verified, Pagination::first(10))
.await
.unwrap();
assert!(!emails.has_next_page);
assert_eq!(emails.edges.len(), 1);
assert_eq!(emails.edges[0], user_email);
let emails = repo
.user_email()
.list(pending, Pagination::first(10))
.await
.unwrap();
assert!(!emails.has_next_page);
assert!(emails.edges.is_empty());
// Deleting the user email should work // Deleting the user email should work
repo.user_email().remove(user_email).await.unwrap(); repo.user_email().remove(user_email).await.unwrap();
assert_eq!(repo.user_email().count(&user).await.unwrap(), 0); assert_eq!(repo.user_email().count(all).await.unwrap(), 0);
assert_eq!(repo.user_email().count(pending).await.unwrap(), 0);
assert_eq!(repo.user_email().count(verified).await.unwrap(), 0);
// Reload the user // Reload the user
let user = repo let user = repo

View File

@ -19,6 +19,76 @@ use ulid::Ulid;
use crate::{pagination::Page, repository_impl, Clock, Pagination}; use crate::{pagination::Page, repository_impl, Clock, Pagination};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserEmailState {
Pending,
Verified,
}
impl UserEmailState {
/// Returns true if the filter should only return non-verified emails
pub fn is_pending(self) -> bool {
matches!(self, Self::Pending)
}
/// Returns true if the filter should only return verified emails
pub fn is_verified(self) -> bool {
matches!(self, Self::Verified)
}
}
/// Filter parameters for listing user emails
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UserEmailFilter<'a> {
user: Option<&'a User>,
state: Option<UserEmailState>,
}
impl<'a> UserEmailFilter<'a> {
/// Create a new [`UserEmailFilter`] with default values
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Filter for emails of a specific user
#[must_use]
pub fn for_user(mut self, user: &'a User) -> Self {
self.user = Some(user);
self
}
/// Get the user filter
///
/// Returns [`None`] if no user filter is set
#[must_use]
pub fn user(&self) -> Option<&User> {
self.user
}
/// Filter for emails that are verified
#[must_use]
pub fn verified_only(mut self) -> Self {
self.state = Some(UserEmailState::Verified);
self
}
/// Filter for emails that are not verified
#[must_use]
pub fn pending_only(mut self) -> Self {
self.state = Some(UserEmailState::Pending);
self
}
/// Get the state filter
///
/// Returns [`None`] if no state filter is set
#[must_use]
pub fn state(&self) -> Option<UserEmailState> {
self.state
}
}
/// A [`UserEmailRepository`] helps interacting with [`UserEmail`] saved in the /// A [`UserEmailRepository`] helps interacting with [`UserEmail`] saved in the
/// storage backend /// storage backend
#[async_trait] #[async_trait]
@ -77,32 +147,32 @@ pub trait UserEmailRepository: Send + Sync {
/// Returns [`Self::Error`] if the underlying repository fails /// Returns [`Self::Error`] if the underlying repository fails
async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error>; async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error>;
/// List [`UserEmail`] of a [`User`] with the given pagination /// List [`UserEmail`] with the given filter and pagination
/// ///
/// # Parameters /// # Parameters
/// ///
/// * `user`: The [`User`] for whom to lookup the [`UserEmail`] /// * `filter`: The filter parameters
/// * `pagination`: The pagination parameters /// * `pagination`: The pagination parameters
/// ///
/// # Errors /// # Errors
/// ///
/// Returns [`Self::Error`] if the underlying repository fails /// Returns [`Self::Error`] if the underlying repository fails
async fn list_paginated( async fn list(
&mut self, &mut self,
user: &User, filter: UserEmailFilter<'_>,
pagination: Pagination, pagination: Pagination,
) -> Result<Page<UserEmail>, Self::Error>; ) -> Result<Page<UserEmail>, Self::Error>;
/// Count the [`UserEmail`] of a [`User`] /// Count the [`UserEmail`] with the given filter
/// ///
/// # Parameters /// # Parameters
/// ///
/// * `user`: The [`User`] for whom to count the [`UserEmail`] /// * `filter`: The filter parameters
/// ///
/// # Errors /// # Errors
/// ///
/// Returns [`Self::Error`] if the underlying repository fails /// Returns [`Self::Error`] if the underlying repository fails
async fn count(&mut self, user: &User) -> Result<usize, Self::Error>; async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error>;
/// Create a new [`UserEmail`] for a [`User`] /// Create a new [`UserEmail`] for a [`User`]
/// ///
@ -235,12 +305,12 @@ repository_impl!(UserEmailRepository:
async fn get_primary(&mut self, user: &User) -> Result<Option<UserEmail>, Self::Error>; async fn get_primary(&mut self, user: &User) -> Result<Option<UserEmail>, Self::Error>;
async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error>; async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error>;
async fn list_paginated( async fn list(
&mut self, &mut self,
user: &User, filter: UserEmailFilter<'_>,
pagination: Pagination, pagination: Pagination,
) -> Result<Page<UserEmail>, Self::Error>; ) -> Result<Page<UserEmail>, Self::Error>;
async fn count(&mut self, user: &User) -> Result<usize, Self::Error>; async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error>;
async fn add( async fn add(
&mut self, &mut self,

View File

@ -26,7 +26,7 @@ mod password;
mod session; mod session;
pub use self::{ pub use self::{
email::UserEmailRepository, email::{UserEmailFilter, UserEmailRepository},
password::UserPasswordRepository, password::UserPasswordRepository,
session::{BrowserSessionFilter, BrowserSessionRepository}, session::{BrowserSessionFilter, BrowserSessionRepository},
}; };

View File

@ -937,6 +937,7 @@ type User implements Node {
Get the list of emails, chronologically sorted Get the list of emails, chronologically sorted
""" """
emails( emails(
state: UserEmailState
after: String after: String
before: String before: String
first: Int first: Int
@ -1018,6 +1019,20 @@ type UserEmailEdge {
cursor: String! cursor: String!
} }
"""
The state of a compatibility session.
"""
enum UserEmailState {
"""
The email address is pending confirmation.
"""
PENDING
"""
The email address has been confirmed.
"""
CONFIRMED
}
""" """
The input for the `verifyEmail` mutation The input for the `verifyEmail` mutation
""" """

View File

@ -727,6 +727,7 @@ export type UserEmailsArgs = {
before?: InputMaybe<Scalars["String"]["input"]>; before?: InputMaybe<Scalars["String"]["input"]>;
first?: InputMaybe<Scalars["Int"]["input"]>; first?: InputMaybe<Scalars["Int"]["input"]>;
last?: InputMaybe<Scalars["Int"]["input"]>; last?: InputMaybe<Scalars["Int"]["input"]>;
state?: InputMaybe<UserEmailState>;
}; };
/** A user is an individual's account. */ /** A user is an individual's account. */
@ -783,6 +784,14 @@ export type UserEmailEdge = {
node: UserEmail; node: UserEmail;
}; };
/** The state of a compatibility session. */
export enum UserEmailState {
/** The email address has been confirmed. */
Confirmed = "CONFIRMED",
/** The email address is pending confirmation. */
Pending = "PENDING",
}
/** The input for the `verifyEmail` mutation */ /** The input for the `verifyEmail` mutation */
export type VerifyEmailInput = { export type VerifyEmailInput = {
/** The verification code */ /** The verification code */

View File

@ -2147,6 +2147,13 @@ export default {
name: "Any", name: "Any",
}, },
}, },
{
name: "state",
type: {
kind: "SCALAR",
name: "Any",
},
},
], ],
}, },
{ {