1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Better user emails pagination and filtering

This commit is contained in:
Quentin Gliech
2023-07-21 15:31:55 +02:00
parent 12ad572db8
commit a75a53cc24
9 changed files with 266 additions and 74 deletions

View File

@ -14,14 +14,14 @@
use async_graphql::{
connection::{query, Connection, Edge, OpaqueCursor},
Context, Description, Object, ID,
Context, Description, Enum, Object, ID,
};
use chrono::{DateTime, Utc};
use mas_storage::{
compat::{CompatSessionFilter, CompatSsoLoginFilter, CompatSsoLoginRepository},
oauth2::OAuth2SessionRepository,
upstream_oauth2::UpstreamOAuthLinkRepository,
user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailRepository},
user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository},
Pagination, RepositoryAccess,
};
@ -300,13 +300,16 @@ impl User {
&self,
ctx: &Context<'_>,
#[graphql(name = "state", desc = "List only emails in the given state.")]
state_param: Option<UserEmailState>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UserEmail, UserEmailsPagination>, async_graphql::Error> {
) -> Result<Connection<Cursor, UserEmail, PreloadedTotalCount>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
@ -324,17 +327,29 @@ impl User {
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let page = repo
.user_email()
.list_paginated(&self.0, pagination)
.await?;
let filter = UserEmailFilter::new().for_user(&self.0);
let filter = match state_param {
Some(UserEmailState::Pending) => filter.pending_only(),
Some(UserEmailState::Confirmed) => filter.verified_only(),
None => filter,
};
let page = repo.user_email().list(filter, pagination).await?;
// Preload the total count if requested
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.user_email().count(filter).await?)
} else {
None
};
repo.cancel().await?;
let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
UserEmailsPagination(self.0.clone()),
PreloadedTotalCount(count),
);
connection.edges.extend(page.edges.into_iter().map(|u| {
Edge::new(
@ -493,16 +508,12 @@ impl UserEmail {
}
}
pub struct UserEmailsPagination(mas_data_model::User);
/// The state of a compatibility session.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum UserEmailState {
/// The email address is pending confirmation.
Pending,
#[Object]
impl UserEmailsPagination {
/// Identifies the total count of items in the connection.
async fn total_count(&self, ctx: &Context<'_>) -> Result<usize, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let count = repo.user_email().count(&self.0).await?;
repo.cancel().await?;
Ok(count)
}
/// The email address has been confirmed.
Confirmed,
}

View File

@ -31,6 +31,16 @@ pub enum Users {
PrimaryUserEmailId,
}
#[derive(sea_query::Iden)]
pub enum UserEmails {
Table,
UserEmailId,
UserId,
Email,
CreatedAt,
ConfirmedAt,
}
#[derive(sea_query::Iden)]
pub enum CompatSessions {
Table,

View File

@ -15,15 +15,20 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState};
use mas_storage::{user::UserEmailRepository, Clock, Page, Pagination};
use mas_storage::{
user::{UserEmailFilter, UserEmailRepository},
Clock, Page, Pagination,
};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query};
use sqlx::PgConnection;
use tracing::{info_span, Instrument};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
iden::UserEmails, pagination::QueryBuilderExt, sea_query_sqlx::map_values, tracing::ExecuteExt,
DatabaseError, DatabaseInconsistencyError,
};
/// An implementation of [`UserEmailRepository`] for a PostgreSQL connection
@ -40,6 +45,7 @@ impl<'c> PgUserEmailRepository<'c> {
}
#[derive(Debug, Clone, sqlx::FromRow)]
#[enum_def]
struct UserEmailLookup {
user_email_id: Uuid,
user_id: Uuid,
@ -225,42 +231,65 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> {
}
#[tracing::instrument(
name = "db.user_email.list_paginated",
name = "db.user_email.list",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn list_paginated(
async fn list(
&mut self,
user: &User,
filter: UserEmailFilter<'_>,
pagination: Pagination,
) -> Result<Page<UserEmail>, DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT user_email_id
, user_id
, email
, created_at
, confirmed_at
FROM user_emails
"#,
);
let (sql, values) = Query::select()
.expr_as(
Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
UserEmailLookupIden::UserEmailId,
)
.expr_as(
Expr::col((UserEmails::Table, UserEmails::UserId)),
UserEmailLookupIden::UserId,
)
.expr_as(
Expr::col((UserEmails::Table, UserEmails::Email)),
UserEmailLookupIden::Email,
)
.expr_as(
Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
UserEmailLookupIden::CreatedAt,
)
.expr_as(
Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)),
UserEmailLookupIden::ConfirmedAt,
)
.from(UserEmails::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_verified() {
Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)).is_not_null()
} else {
Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)).is_null()
}
}))
.generate_pagination(
(UserEmails::Table, UserEmails::UserEmailId).into_column_ref(),
pagination,
)
.build(PostgresQueryBuilder);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("user_email_id", pagination);
let arguments = map_values(values);
let edges: Vec<UserEmailLookup> = query
.build_query_as()
let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).map(UserEmail::from);
Ok(page)
}
@ -269,28 +298,35 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> {
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn count(&mut self, user: &User) -> Result<usize, Self::Error> {
let res = sqlx::query_scalar!(
r#"
SELECT COUNT(*)
FROM user_emails
WHERE user_id = $1
"#,
Uuid::from(user.id),
)
async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
let (sql, values) = Query::select()
.expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
.from(UserEmails::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_verified() {
Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)).is_not_null()
} else {
Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)).is_null()
}
}))
.build(PostgresQueryBuilder);
let arguments = map_values(values);
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
.traced()
.fetch_one(&mut *self.conn)
.await?;
let res = res.unwrap_or_default();
Ok(res
count
.try_into()
.map_err(DatabaseError::to_invalid_operation)?)
.map_err(DatabaseError::to_invalid_operation)
}
#[tracing::instrument(

View File

@ -16,7 +16,7 @@ use chrono::Duration;
use mas_storage::{
clock::MockClock,
user::{
BrowserSessionFilter, BrowserSessionRepository, UserEmailRepository,
BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository,
UserPasswordRepository, UserRepository,
},
Pagination, Repository, RepositoryAccess,
@ -98,7 +98,14 @@ async fn test_user_email_repo(pool: PgPool) {
.unwrap()
.is_none());
assert_eq!(repo.user_email().count(&user).await.unwrap(), 0);
let all = UserEmailFilter::new().for_user(&user);
let pending = all.pending_only();
let verified = all.verified_only();
// Check the counts
assert_eq!(repo.user_email().count(all).await.unwrap(), 0);
assert_eq!(repo.user_email().count(pending).await.unwrap(), 0);
assert_eq!(repo.user_email().count(verified).await.unwrap(), 0);
let user_email = repo
.user_email()
@ -110,7 +117,10 @@ async fn test_user_email_repo(pool: PgPool) {
assert_eq!(user_email.email, EMAIL);
assert!(user_email.confirmed_at.is_none());
assert_eq!(repo.user_email().count(&user).await.unwrap(), 1);
// Check the counts
assert_eq!(repo.user_email().count(all).await.unwrap(), 1);
assert_eq!(repo.user_email().count(pending).await.unwrap(), 1);
assert_eq!(repo.user_email().count(verified).await.unwrap(), 0);
assert!(repo
.user_email()
@ -181,6 +191,11 @@ async fn test_user_email_repo(pool: PgPool) {
.await
.unwrap();
// Check the counts
assert_eq!(repo.user_email().count(all).await.unwrap(), 1);
assert_eq!(repo.user_email().count(pending).await.unwrap(), 0);
assert_eq!(repo.user_email().count(verified).await.unwrap(), 1);
// Reload the user_email
let user_email = repo
.user_email()
@ -236,16 +251,35 @@ async fn test_user_email_repo(pool: PgPool) {
// Listing the user emails should work
let emails = repo
.user_email()
.list_paginated(&user, Pagination::first(10))
.list(all, Pagination::first(10))
.await
.unwrap();
assert!(!emails.has_next_page);
assert_eq!(emails.edges.len(), 1);
assert_eq!(emails.edges[0], user_email);
let emails = repo
.user_email()
.list(verified, Pagination::first(10))
.await
.unwrap();
assert!(!emails.has_next_page);
assert_eq!(emails.edges.len(), 1);
assert_eq!(emails.edges[0], user_email);
let emails = repo
.user_email()
.list(pending, Pagination::first(10))
.await
.unwrap();
assert!(!emails.has_next_page);
assert!(emails.edges.is_empty());
// Deleting the user email should work
repo.user_email().remove(user_email).await.unwrap();
assert_eq!(repo.user_email().count(&user).await.unwrap(), 0);
assert_eq!(repo.user_email().count(all).await.unwrap(), 0);
assert_eq!(repo.user_email().count(pending).await.unwrap(), 0);
assert_eq!(repo.user_email().count(verified).await.unwrap(), 0);
// Reload the user
let user = repo

View File

@ -19,6 +19,76 @@ use ulid::Ulid;
use crate::{pagination::Page, repository_impl, Clock, Pagination};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserEmailState {
Pending,
Verified,
}
impl UserEmailState {
/// Returns true if the filter should only return non-verified emails
pub fn is_pending(self) -> bool {
matches!(self, Self::Pending)
}
/// Returns true if the filter should only return verified emails
pub fn is_verified(self) -> bool {
matches!(self, Self::Verified)
}
}
/// Filter parameters for listing user emails
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UserEmailFilter<'a> {
user: Option<&'a User>,
state: Option<UserEmailState>,
}
impl<'a> UserEmailFilter<'a> {
/// Create a new [`UserEmailFilter`] with default values
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Filter for emails of a specific user
#[must_use]
pub fn for_user(mut self, user: &'a User) -> Self {
self.user = Some(user);
self
}
/// Get the user filter
///
/// Returns [`None`] if no user filter is set
#[must_use]
pub fn user(&self) -> Option<&User> {
self.user
}
/// Filter for emails that are verified
#[must_use]
pub fn verified_only(mut self) -> Self {
self.state = Some(UserEmailState::Verified);
self
}
/// Filter for emails that are not verified
#[must_use]
pub fn pending_only(mut self) -> Self {
self.state = Some(UserEmailState::Pending);
self
}
/// Get the state filter
///
/// Returns [`None`] if no state filter is set
#[must_use]
pub fn state(&self) -> Option<UserEmailState> {
self.state
}
}
/// A [`UserEmailRepository`] helps interacting with [`UserEmail`] saved in the
/// storage backend
#[async_trait]
@ -77,32 +147,32 @@ pub trait UserEmailRepository: Send + Sync {
/// Returns [`Self::Error`] if the underlying repository fails
async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error>;
/// List [`UserEmail`] of a [`User`] with the given pagination
/// List [`UserEmail`] with the given filter and pagination
///
/// # Parameters
///
/// * `user`: The [`User`] for whom to lookup the [`UserEmail`]
/// * `filter`: The filter parameters
/// * `pagination`: The pagination parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn list_paginated(
async fn list(
&mut self,
user: &User,
filter: UserEmailFilter<'_>,
pagination: Pagination,
) -> Result<Page<UserEmail>, Self::Error>;
/// Count the [`UserEmail`] of a [`User`]
/// Count the [`UserEmail`] with the given filter
///
/// # Parameters
///
/// * `user`: The [`User`] for whom to count the [`UserEmail`]
/// * `filter`: The filter parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn count(&mut self, user: &User) -> Result<usize, Self::Error>;
async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error>;
/// Create a new [`UserEmail`] for a [`User`]
///
@ -235,12 +305,12 @@ repository_impl!(UserEmailRepository:
async fn get_primary(&mut self, user: &User) -> Result<Option<UserEmail>, Self::Error>;
async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error>;
async fn list_paginated(
async fn list(
&mut self,
user: &User,
filter: UserEmailFilter<'_>,
pagination: Pagination,
) -> Result<Page<UserEmail>, Self::Error>;
async fn count(&mut self, user: &User) -> Result<usize, Self::Error>;
async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error>;
async fn add(
&mut self,

View File

@ -26,7 +26,7 @@ mod password;
mod session;
pub use self::{
email::UserEmailRepository,
email::{UserEmailFilter, UserEmailRepository},
password::UserPasswordRepository,
session::{BrowserSessionFilter, BrowserSessionRepository},
};

View File

@ -937,6 +937,7 @@ type User implements Node {
Get the list of emails, chronologically sorted
"""
emails(
state: UserEmailState
after: String
before: String
first: Int
@ -1018,6 +1019,20 @@ type UserEmailEdge {
cursor: String!
}
"""
The state of a compatibility session.
"""
enum UserEmailState {
"""
The email address is pending confirmation.
"""
PENDING
"""
The email address has been confirmed.
"""
CONFIRMED
}
"""
The input for the `verifyEmail` mutation
"""

View File

@ -727,6 +727,7 @@ export type UserEmailsArgs = {
before?: InputMaybe<Scalars["String"]["input"]>;
first?: InputMaybe<Scalars["Int"]["input"]>;
last?: InputMaybe<Scalars["Int"]["input"]>;
state?: InputMaybe<UserEmailState>;
};
/** A user is an individual's account. */
@ -783,6 +784,14 @@ export type UserEmailEdge = {
node: UserEmail;
};
/** The state of a compatibility session. */
export enum UserEmailState {
/** The email address has been confirmed. */
Confirmed = "CONFIRMED",
/** The email address is pending confirmation. */
Pending = "PENDING",
}
/** The input for the `verifyEmail` mutation */
export type VerifyEmailInput = {
/** The verification code */

View File

@ -2147,6 +2147,13 @@ export default {
name: "Any",
},
},
{
name: "state",
type: {
kind: "SCALAR",
name: "Any",
},
},
],
},
{