You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
storage: methods to list and count users with filters and pagination
This commit is contained in:
@@ -16,15 +16,16 @@
|
||||
//! repositories
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::User;
|
||||
use mas_storage::{user::UserRepository, Clock};
|
||||
use rand::RngCore;
|
||||
use sea_query::{Expr, PostgresQueryBuilder, Query};
|
||||
use sea_query_binder::SqlxBinder;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{tracing::ExecuteExt, DatabaseError};
|
||||
use crate::{iden::Users, pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError};
|
||||
|
||||
mod email;
|
||||
mod password;
|
||||
@@ -53,16 +54,29 @@ impl<'c> PgUserRepository<'c> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct UserLookup {
|
||||
user_id: Uuid,
|
||||
username: String,
|
||||
primary_user_email_id: Option<Uuid>,
|
||||
created_at: DateTime<Utc>,
|
||||
locked_at: Option<DateTime<Utc>>,
|
||||
can_request_admin: bool,
|
||||
mod priv_ {
|
||||
// The enum_def macro generates a public enum, which we don't want, because it
|
||||
// triggers the missing docs warning
|
||||
#![allow(missing_docs)]
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use sea_query::enum_def;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, sqlx::FromRow)]
|
||||
#[enum_def]
|
||||
pub(super) struct UserLookup {
|
||||
pub(super) user_id: Uuid,
|
||||
pub(super) username: String,
|
||||
pub(super) primary_user_email_id: Option<Uuid>,
|
||||
pub(super) created_at: DateTime<Utc>,
|
||||
pub(super) locked_at: Option<DateTime<Utc>>,
|
||||
pub(super) can_request_admin: bool,
|
||||
}
|
||||
}
|
||||
|
||||
use priv_::{UserLookup, UserLookupIden};
|
||||
|
||||
impl From<UserLookup> for User {
|
||||
fn from(value: UserLookup) -> Self {
|
||||
let id = value.user_id.into();
|
||||
@@ -324,4 +338,103 @@ impl<'c> UserRepository for PgUserRepository<'c> {
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user.list",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn list(
|
||||
&mut self,
|
||||
filter: mas_storage::user::UserFilter<'_>,
|
||||
pagination: mas_storage::Pagination,
|
||||
) -> Result<mas_storage::Page<User>, Self::Error> {
|
||||
let (sql, arguments) = Query::select()
|
||||
.expr_as(
|
||||
Expr::col((Users::Table, Users::UserId)),
|
||||
UserLookupIden::UserId,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((Users::Table, Users::Username)),
|
||||
UserLookupIden::Username,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((Users::Table, Users::PrimaryUserEmailId)),
|
||||
UserLookupIden::PrimaryUserEmailId,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((Users::Table, Users::CreatedAt)),
|
||||
UserLookupIden::CreatedAt,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((Users::Table, Users::LockedAt)),
|
||||
UserLookupIden::LockedAt,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((Users::Table, Users::CanRequestAdmin)),
|
||||
UserLookupIden::CanRequestAdmin,
|
||||
)
|
||||
.from(Users::Table)
|
||||
.and_where_option(filter.state().map(|state| {
|
||||
if state.is_locked() {
|
||||
Expr::col((Users::Table, Users::LockedAt)).is_not_null()
|
||||
} else {
|
||||
Expr::col((Users::Table, Users::LockedAt)).is_null()
|
||||
}
|
||||
}))
|
||||
.and_where_option(filter.can_request_admin().map(|can_request_admin| {
|
||||
Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
|
||||
}))
|
||||
.generate_pagination((Users::Table, Users::UserId), pagination)
|
||||
.build_sqlx(PostgresQueryBuilder);
|
||||
|
||||
let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let page = pagination.process(edges).map(User::from);
|
||||
|
||||
Ok(page)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.user.count",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn count(
|
||||
&mut self,
|
||||
filter: mas_storage::user::UserFilter<'_>,
|
||||
) -> Result<usize, Self::Error> {
|
||||
let (sql, arguments) = Query::select()
|
||||
.expr(Expr::col((Users::Table, Users::UserId)).count())
|
||||
.from(Users::Table)
|
||||
.and_where_option(filter.state().map(|state| {
|
||||
if state.is_locked() {
|
||||
Expr::col((Users::Table, Users::LockedAt)).is_not_null()
|
||||
} else {
|
||||
Expr::col((Users::Table, Users::LockedAt)).is_null()
|
||||
}
|
||||
}))
|
||||
.and_where_option(filter.can_request_admin().map(|can_request_admin| {
|
||||
Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
|
||||
}))
|
||||
.build_sqlx(PostgresQueryBuilder);
|
||||
|
||||
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
count
|
||||
.try_into()
|
||||
.map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
}
|
||||
|
@@ -17,7 +17,7 @@ use mas_storage::{
|
||||
clock::MockClock,
|
||||
user::{
|
||||
BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository,
|
||||
UserPasswordRepository, UserRepository,
|
||||
UserFilter, UserPasswordRepository, UserRepository,
|
||||
},
|
||||
Pagination, Repository, RepositoryAccess,
|
||||
};
|
||||
@@ -36,6 +36,12 @@ async fn test_user_repo(pool: PgPool) {
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
let clock = MockClock::default();
|
||||
|
||||
let all = UserFilter::new();
|
||||
let admin = all.can_request_admin_only();
|
||||
let non_admin = all.cannot_request_admin_only();
|
||||
let active = all.active_only();
|
||||
let locked = all.locked_only();
|
||||
|
||||
// Initially, the user shouldn't exist
|
||||
assert!(!repo.user().exists(USERNAME).await.unwrap());
|
||||
assert!(repo
|
||||
@@ -45,6 +51,12 @@ async fn test_user_repo(pool: PgPool) {
|
||||
.unwrap()
|
||||
.is_none());
|
||||
|
||||
assert_eq!(repo.user().count(all).await.unwrap(), 0);
|
||||
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
|
||||
assert_eq!(repo.user().count(non_admin).await.unwrap(), 0);
|
||||
assert_eq!(repo.user().count(active).await.unwrap(), 0);
|
||||
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
|
||||
|
||||
// Adding the user should work
|
||||
let user = repo
|
||||
.user()
|
||||
@@ -62,6 +74,12 @@ async fn test_user_repo(pool: PgPool) {
|
||||
.is_some());
|
||||
assert!(repo.user().lookup(user.id).await.unwrap().is_some());
|
||||
|
||||
assert_eq!(repo.user().count(all).await.unwrap(), 1);
|
||||
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
|
||||
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
|
||||
assert_eq!(repo.user().count(active).await.unwrap(), 1);
|
||||
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
|
||||
|
||||
// Adding a second time should give a conflict
|
||||
// It should not poison the transaction though
|
||||
assert!(repo
|
||||
@@ -75,6 +93,12 @@ async fn test_user_repo(pool: PgPool) {
|
||||
let user = repo.user().lock(&clock, user).await.unwrap();
|
||||
assert!(!user.is_valid());
|
||||
|
||||
assert_eq!(repo.user().count(all).await.unwrap(), 1);
|
||||
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
|
||||
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
|
||||
assert_eq!(repo.user().count(active).await.unwrap(), 0);
|
||||
assert_eq!(repo.user().count(locked).await.unwrap(), 1);
|
||||
|
||||
// Check that the property is retrieved on lookup
|
||||
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
|
||||
assert!(!user.is_valid());
|
||||
@@ -99,6 +123,12 @@ async fn test_user_repo(pool: PgPool) {
|
||||
let user = repo.user().set_can_request_admin(user, true).await.unwrap();
|
||||
assert!(user.can_request_admin);
|
||||
|
||||
assert_eq!(repo.user().count(all).await.unwrap(), 1);
|
||||
assert_eq!(repo.user().count(admin).await.unwrap(), 1);
|
||||
assert_eq!(repo.user().count(non_admin).await.unwrap(), 0);
|
||||
assert_eq!(repo.user().count(active).await.unwrap(), 1);
|
||||
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
|
||||
|
||||
// Check that the property is retrieved on lookup
|
||||
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
|
||||
assert!(user.can_request_admin);
|
||||
@@ -115,6 +145,47 @@ async fn test_user_repo(pool: PgPool) {
|
||||
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
|
||||
assert!(!user.can_request_admin);
|
||||
|
||||
assert_eq!(repo.user().count(all).await.unwrap(), 1);
|
||||
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
|
||||
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
|
||||
assert_eq!(repo.user().count(active).await.unwrap(), 1);
|
||||
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
|
||||
|
||||
// Check the list method
|
||||
let list = repo.user().list(all, Pagination::first(10)).await.unwrap();
|
||||
assert_eq!(list.edges.len(), 1);
|
||||
assert_eq!(list.edges[0].id, user.id);
|
||||
|
||||
let list = repo
|
||||
.user()
|
||||
.list(admin, Pagination::first(10))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(list.edges.len(), 0);
|
||||
|
||||
let list = repo
|
||||
.user()
|
||||
.list(non_admin, Pagination::first(10))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(list.edges.len(), 1);
|
||||
assert_eq!(list.edges[0].id, user.id);
|
||||
|
||||
let list = repo
|
||||
.user()
|
||||
.list(active, Pagination::first(10))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(list.edges.len(), 1);
|
||||
assert_eq!(list.edges[0].id, user.id);
|
||||
|
||||
let list = repo
|
||||
.user()
|
||||
.list(locked, Pagination::first(10))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(list.edges.len(), 0);
|
||||
|
||||
repo.save().await.unwrap();
|
||||
}
|
||||
|
||||
|
@@ -19,7 +19,7 @@ use mas_data_model::User;
|
||||
use rand_core::RngCore;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::{repository_impl, Clock};
|
||||
use crate::{repository_impl, Clock, Page, Pagination};
|
||||
|
||||
mod email;
|
||||
mod password;
|
||||
@@ -35,6 +35,94 @@ pub use self::{
|
||||
terms::UserTermsRepository,
|
||||
};
|
||||
|
||||
/// The state of a user account
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum UserState {
|
||||
/// The account is locked, it has the `locked_at` timestamp set
|
||||
Locked,
|
||||
|
||||
/// The account is active
|
||||
Active,
|
||||
}
|
||||
|
||||
impl UserState {
|
||||
/// Returns `true` if the user state is [`Locked`].
|
||||
///
|
||||
/// [`Locked`]: UserState::Locked
|
||||
#[must_use]
|
||||
pub fn is_locked(&self) -> bool {
|
||||
matches!(self, Self::Locked)
|
||||
}
|
||||
|
||||
/// Returns `true` if the user state is [`Active`].
|
||||
///
|
||||
/// [`Active`]: UserState::Active
|
||||
#[must_use]
|
||||
pub fn is_active(&self) -> bool {
|
||||
matches!(self, Self::Active)
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter parameters for listing users
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||
pub struct UserFilter<'a> {
|
||||
state: Option<UserState>,
|
||||
can_request_admin: Option<bool>,
|
||||
_phantom: std::marker::PhantomData<&'a ()>,
|
||||
}
|
||||
|
||||
impl<'a> UserFilter<'a> {
|
||||
/// Create a new [`UserFilter`] with default values
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Filter for active users
|
||||
#[must_use]
|
||||
pub fn active_only(mut self) -> Self {
|
||||
self.state = Some(UserState::Active);
|
||||
self
|
||||
}
|
||||
|
||||
/// Filter for locked users
|
||||
#[must_use]
|
||||
pub fn locked_only(mut self) -> Self {
|
||||
self.state = Some(UserState::Locked);
|
||||
self
|
||||
}
|
||||
|
||||
/// Filter for users that can request admin privileges
|
||||
#[must_use]
|
||||
pub fn can_request_admin_only(mut self) -> Self {
|
||||
self.can_request_admin = Some(true);
|
||||
self
|
||||
}
|
||||
|
||||
/// Filter for users that can't request admin privileges
|
||||
#[must_use]
|
||||
pub fn cannot_request_admin_only(mut self) -> Self {
|
||||
self.can_request_admin = Some(false);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the state filter
|
||||
///
|
||||
/// Returns [`None`] if no state filter was set
|
||||
#[must_use]
|
||||
pub fn state(&self) -> Option<UserState> {
|
||||
self.state
|
||||
}
|
||||
|
||||
/// Get the can request admin filter
|
||||
///
|
||||
/// Returns [`None`] if no can request admin filter was set
|
||||
#[must_use]
|
||||
pub fn can_request_admin(&self) -> Option<bool> {
|
||||
self.can_request_admin
|
||||
}
|
||||
}
|
||||
|
||||
/// A [`UserRepository`] helps interacting with [`User`] saved in the storage
|
||||
/// backend
|
||||
#[async_trait]
|
||||
@@ -144,6 +232,33 @@ pub trait UserRepository: Send + Sync {
|
||||
user: User,
|
||||
can_request_admin: bool,
|
||||
) -> Result<User, Self::Error>;
|
||||
|
||||
/// List [`User`] with the given filter and pagination
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `filter`: The filter parameters
|
||||
/// * `pagination`: The pagination parameters
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn list(
|
||||
&mut self,
|
||||
filter: UserFilter<'_>,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<User>, Self::Error>;
|
||||
|
||||
/// Count the [`User`] with the given filter
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `filter`: The filter parameters
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error>;
|
||||
}
|
||||
|
||||
repository_impl!(UserRepository:
|
||||
@@ -163,4 +278,10 @@ repository_impl!(UserRepository:
|
||||
user: User,
|
||||
can_request_admin: bool,
|
||||
) -> Result<User, Self::Error>;
|
||||
async fn list(
|
||||
&mut self,
|
||||
filter: UserFilter<'_>,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<User>, Self::Error>;
|
||||
async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error>;
|
||||
);
|
||||
|
Reference in New Issue
Block a user