1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

storage: methods to list and count users with filters and pagination

This commit is contained in:
Quentin Gliech
2024-07-05 11:53:31 +02:00
parent 9486460aae
commit e75df0752d
3 changed files with 317 additions and 12 deletions

View File

@@ -16,15 +16,16 @@
//! repositories
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::User;
use mas_storage::{user::UserRepository, Clock};
use rand::RngCore;
use sea_query::{Expr, PostgresQueryBuilder, Query};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError};
use crate::{iden::Users, pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError};
mod email;
mod password;
@@ -53,16 +54,29 @@ impl<'c> PgUserRepository<'c> {
}
}
#[derive(Debug, Clone)]
struct UserLookup {
user_id: Uuid,
username: String,
primary_user_email_id: Option<Uuid>,
created_at: DateTime<Utc>,
locked_at: Option<DateTime<Utc>>,
can_request_admin: bool,
mod priv_ {
// The enum_def macro generates a public enum, which we don't want, because it
// triggers the missing docs warning
#![allow(missing_docs)]
use chrono::{DateTime, Utc};
use sea_query::enum_def;
use uuid::Uuid;
#[derive(Debug, Clone, sqlx::FromRow)]
#[enum_def]
pub(super) struct UserLookup {
pub(super) user_id: Uuid,
pub(super) username: String,
pub(super) primary_user_email_id: Option<Uuid>,
pub(super) created_at: DateTime<Utc>,
pub(super) locked_at: Option<DateTime<Utc>>,
pub(super) can_request_admin: bool,
}
}
use priv_::{UserLookup, UserLookupIden};
impl From<UserLookup> for User {
fn from(value: UserLookup) -> Self {
let id = value.user_id.into();
@@ -324,4 +338,103 @@ impl<'c> UserRepository for PgUserRepository<'c> {
Ok(user)
}
#[tracing::instrument(
name = "db.user.list",
skip_all,
fields(
db.statement,
),
err,
)]
async fn list(
&mut self,
filter: mas_storage::user::UserFilter<'_>,
pagination: mas_storage::Pagination,
) -> Result<mas_storage::Page<User>, Self::Error> {
let (sql, arguments) = Query::select()
.expr_as(
Expr::col((Users::Table, Users::UserId)),
UserLookupIden::UserId,
)
.expr_as(
Expr::col((Users::Table, Users::Username)),
UserLookupIden::Username,
)
.expr_as(
Expr::col((Users::Table, Users::PrimaryUserEmailId)),
UserLookupIden::PrimaryUserEmailId,
)
.expr_as(
Expr::col((Users::Table, Users::CreatedAt)),
UserLookupIden::CreatedAt,
)
.expr_as(
Expr::col((Users::Table, Users::LockedAt)),
UserLookupIden::LockedAt,
)
.expr_as(
Expr::col((Users::Table, Users::CanRequestAdmin)),
UserLookupIden::CanRequestAdmin,
)
.from(Users::Table)
.and_where_option(filter.state().map(|state| {
if state.is_locked() {
Expr::col((Users::Table, Users::LockedAt)).is_not_null()
} else {
Expr::col((Users::Table, Users::LockedAt)).is_null()
}
}))
.and_where_option(filter.can_request_admin().map(|can_request_admin| {
Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
}))
.generate_pagination((Users::Table, Users::UserId), pagination)
.build_sqlx(PostgresQueryBuilder);
let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).map(User::from);
Ok(page)
}
#[tracing::instrument(
name = "db.user.count",
skip_all,
fields(
db.statement,
),
err,
)]
async fn count(
&mut self,
filter: mas_storage::user::UserFilter<'_>,
) -> Result<usize, Self::Error> {
let (sql, arguments) = Query::select()
.expr(Expr::col((Users::Table, Users::UserId)).count())
.from(Users::Table)
.and_where_option(filter.state().map(|state| {
if state.is_locked() {
Expr::col((Users::Table, Users::LockedAt)).is_not_null()
} else {
Expr::col((Users::Table, Users::LockedAt)).is_null()
}
}))
.and_where_option(filter.can_request_admin().map(|can_request_admin| {
Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
}))
.build_sqlx(PostgresQueryBuilder);
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
.traced()
.fetch_one(&mut *self.conn)
.await?;
count
.try_into()
.map_err(DatabaseError::to_invalid_operation)
}
}

View File

@@ -17,7 +17,7 @@ use mas_storage::{
clock::MockClock,
user::{
BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository,
UserPasswordRepository, UserRepository,
UserFilter, UserPasswordRepository, UserRepository,
},
Pagination, Repository, RepositoryAccess,
};
@@ -36,6 +36,12 @@ async fn test_user_repo(pool: PgPool) {
let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();
let all = UserFilter::new();
let admin = all.can_request_admin_only();
let non_admin = all.cannot_request_admin_only();
let active = all.active_only();
let locked = all.locked_only();
// Initially, the user shouldn't exist
assert!(!repo.user().exists(USERNAME).await.unwrap());
assert!(repo
@@ -45,6 +51,12 @@ async fn test_user_repo(pool: PgPool) {
.unwrap()
.is_none());
assert_eq!(repo.user().count(all).await.unwrap(), 0);
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
assert_eq!(repo.user().count(non_admin).await.unwrap(), 0);
assert_eq!(repo.user().count(active).await.unwrap(), 0);
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
// Adding the user should work
let user = repo
.user()
@@ -62,6 +74,12 @@ async fn test_user_repo(pool: PgPool) {
.is_some());
assert!(repo.user().lookup(user.id).await.unwrap().is_some());
assert_eq!(repo.user().count(all).await.unwrap(), 1);
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
assert_eq!(repo.user().count(active).await.unwrap(), 1);
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
// Adding a second time should give a conflict
// It should not poison the transaction though
assert!(repo
@@ -75,6 +93,12 @@ async fn test_user_repo(pool: PgPool) {
let user = repo.user().lock(&clock, user).await.unwrap();
assert!(!user.is_valid());
assert_eq!(repo.user().count(all).await.unwrap(), 1);
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
assert_eq!(repo.user().count(active).await.unwrap(), 0);
assert_eq!(repo.user().count(locked).await.unwrap(), 1);
// Check that the property is retrieved on lookup
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
assert!(!user.is_valid());
@@ -99,6 +123,12 @@ async fn test_user_repo(pool: PgPool) {
let user = repo.user().set_can_request_admin(user, true).await.unwrap();
assert!(user.can_request_admin);
assert_eq!(repo.user().count(all).await.unwrap(), 1);
assert_eq!(repo.user().count(admin).await.unwrap(), 1);
assert_eq!(repo.user().count(non_admin).await.unwrap(), 0);
assert_eq!(repo.user().count(active).await.unwrap(), 1);
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
// Check that the property is retrieved on lookup
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
assert!(user.can_request_admin);
@@ -115,6 +145,47 @@ async fn test_user_repo(pool: PgPool) {
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
assert!(!user.can_request_admin);
assert_eq!(repo.user().count(all).await.unwrap(), 1);
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
assert_eq!(repo.user().count(active).await.unwrap(), 1);
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
// Check the list method
let list = repo.user().list(all, Pagination::first(10)).await.unwrap();
assert_eq!(list.edges.len(), 1);
assert_eq!(list.edges[0].id, user.id);
let list = repo
.user()
.list(admin, Pagination::first(10))
.await
.unwrap();
assert_eq!(list.edges.len(), 0);
let list = repo
.user()
.list(non_admin, Pagination::first(10))
.await
.unwrap();
assert_eq!(list.edges.len(), 1);
assert_eq!(list.edges[0].id, user.id);
let list = repo
.user()
.list(active, Pagination::first(10))
.await
.unwrap();
assert_eq!(list.edges.len(), 1);
assert_eq!(list.edges[0].id, user.id);
let list = repo
.user()
.list(locked, Pagination::first(10))
.await
.unwrap();
assert_eq!(list.edges.len(), 0);
repo.save().await.unwrap();
}

View File

@@ -19,7 +19,7 @@ use mas_data_model::User;
use rand_core::RngCore;
use ulid::Ulid;
use crate::{repository_impl, Clock};
use crate::{repository_impl, Clock, Page, Pagination};
mod email;
mod password;
@@ -35,6 +35,94 @@ pub use self::{
terms::UserTermsRepository,
};
/// The state of a user account
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserState {
/// The account is locked, it has the `locked_at` timestamp set
Locked,
/// The account is active
Active,
}
impl UserState {
/// Returns `true` if the user state is [`Locked`].
///
/// [`Locked`]: UserState::Locked
#[must_use]
pub fn is_locked(&self) -> bool {
matches!(self, Self::Locked)
}
/// Returns `true` if the user state is [`Active`].
///
/// [`Active`]: UserState::Active
#[must_use]
pub fn is_active(&self) -> bool {
matches!(self, Self::Active)
}
}
/// Filter parameters for listing users
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UserFilter<'a> {
state: Option<UserState>,
can_request_admin: Option<bool>,
_phantom: std::marker::PhantomData<&'a ()>,
}
impl<'a> UserFilter<'a> {
/// Create a new [`UserFilter`] with default values
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Filter for active users
#[must_use]
pub fn active_only(mut self) -> Self {
self.state = Some(UserState::Active);
self
}
/// Filter for locked users
#[must_use]
pub fn locked_only(mut self) -> Self {
self.state = Some(UserState::Locked);
self
}
/// Filter for users that can request admin privileges
#[must_use]
pub fn can_request_admin_only(mut self) -> Self {
self.can_request_admin = Some(true);
self
}
/// Filter for users that can't request admin privileges
#[must_use]
pub fn cannot_request_admin_only(mut self) -> Self {
self.can_request_admin = Some(false);
self
}
/// Get the state filter
///
/// Returns [`None`] if no state filter was set
#[must_use]
pub fn state(&self) -> Option<UserState> {
self.state
}
/// Get the can request admin filter
///
/// Returns [`None`] if no can request admin filter was set
#[must_use]
pub fn can_request_admin(&self) -> Option<bool> {
self.can_request_admin
}
}
/// A [`UserRepository`] helps interacting with [`User`] saved in the storage
/// backend
#[async_trait]
@@ -144,6 +232,33 @@ pub trait UserRepository: Send + Sync {
user: User,
can_request_admin: bool,
) -> Result<User, Self::Error>;
/// List [`User`] with the given filter and pagination
///
/// # Parameters
///
/// * `filter`: The filter parameters
/// * `pagination`: The pagination parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn list(
&mut self,
filter: UserFilter<'_>,
pagination: Pagination,
) -> Result<Page<User>, Self::Error>;
/// Count the [`User`] with the given filter
///
/// # Parameters
///
/// * `filter`: The filter parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error>;
}
repository_impl!(UserRepository:
@@ -163,4 +278,10 @@ repository_impl!(UserRepository:
user: User,
can_request_admin: bool,
) -> Result<User, Self::Error>;
async fn list(
&mut self,
filter: UserFilter<'_>,
pagination: Pagination,
) -> Result<Page<User>, Self::Error>;
async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error>;
);