You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
storage: methods to list and count users with filters and pagination
This commit is contained in:
@@ -16,15 +16,16 @@
|
|||||||
//! repositories
|
//! repositories
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::{DateTime, Utc};
|
|
||||||
use mas_data_model::User;
|
use mas_data_model::User;
|
||||||
use mas_storage::{user::UserRepository, Clock};
|
use mas_storage::{user::UserRepository, Clock};
|
||||||
use rand::RngCore;
|
use rand::RngCore;
|
||||||
|
use sea_query::{Expr, PostgresQueryBuilder, Query};
|
||||||
|
use sea_query_binder::SqlxBinder;
|
||||||
use sqlx::PgConnection;
|
use sqlx::PgConnection;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{tracing::ExecuteExt, DatabaseError};
|
use crate::{iden::Users, pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError};
|
||||||
|
|
||||||
mod email;
|
mod email;
|
||||||
mod password;
|
mod password;
|
||||||
@@ -53,16 +54,29 @@ impl<'c> PgUserRepository<'c> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
mod priv_ {
|
||||||
struct UserLookup {
|
// The enum_def macro generates a public enum, which we don't want, because it
|
||||||
user_id: Uuid,
|
// triggers the missing docs warning
|
||||||
username: String,
|
#![allow(missing_docs)]
|
||||||
primary_user_email_id: Option<Uuid>,
|
|
||||||
created_at: DateTime<Utc>,
|
use chrono::{DateTime, Utc};
|
||||||
locked_at: Option<DateTime<Utc>>,
|
use sea_query::enum_def;
|
||||||
can_request_admin: bool,
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow)]
|
||||||
|
#[enum_def]
|
||||||
|
pub(super) struct UserLookup {
|
||||||
|
pub(super) user_id: Uuid,
|
||||||
|
pub(super) username: String,
|
||||||
|
pub(super) primary_user_email_id: Option<Uuid>,
|
||||||
|
pub(super) created_at: DateTime<Utc>,
|
||||||
|
pub(super) locked_at: Option<DateTime<Utc>>,
|
||||||
|
pub(super) can_request_admin: bool,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
use priv_::{UserLookup, UserLookupIden};
|
||||||
|
|
||||||
impl From<UserLookup> for User {
|
impl From<UserLookup> for User {
|
||||||
fn from(value: UserLookup) -> Self {
|
fn from(value: UserLookup) -> Self {
|
||||||
let id = value.user_id.into();
|
let id = value.user_id.into();
|
||||||
@@ -324,4 +338,103 @@ impl<'c> UserRepository for PgUserRepository<'c> {
|
|||||||
|
|
||||||
Ok(user)
|
Ok(user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.user.list",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn list(
|
||||||
|
&mut self,
|
||||||
|
filter: mas_storage::user::UserFilter<'_>,
|
||||||
|
pagination: mas_storage::Pagination,
|
||||||
|
) -> Result<mas_storage::Page<User>, Self::Error> {
|
||||||
|
let (sql, arguments) = Query::select()
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((Users::Table, Users::UserId)),
|
||||||
|
UserLookupIden::UserId,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((Users::Table, Users::Username)),
|
||||||
|
UserLookupIden::Username,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((Users::Table, Users::PrimaryUserEmailId)),
|
||||||
|
UserLookupIden::PrimaryUserEmailId,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((Users::Table, Users::CreatedAt)),
|
||||||
|
UserLookupIden::CreatedAt,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((Users::Table, Users::LockedAt)),
|
||||||
|
UserLookupIden::LockedAt,
|
||||||
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((Users::Table, Users::CanRequestAdmin)),
|
||||||
|
UserLookupIden::CanRequestAdmin,
|
||||||
|
)
|
||||||
|
.from(Users::Table)
|
||||||
|
.and_where_option(filter.state().map(|state| {
|
||||||
|
if state.is_locked() {
|
||||||
|
Expr::col((Users::Table, Users::LockedAt)).is_not_null()
|
||||||
|
} else {
|
||||||
|
Expr::col((Users::Table, Users::LockedAt)).is_null()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.and_where_option(filter.can_request_admin().map(|can_request_admin| {
|
||||||
|
Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
|
||||||
|
}))
|
||||||
|
.generate_pagination((Users::Table, Users::UserId), pagination)
|
||||||
|
.build_sqlx(PostgresQueryBuilder);
|
||||||
|
|
||||||
|
let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
|
||||||
|
.traced()
|
||||||
|
.fetch_all(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let page = pagination.process(edges).map(User::from);
|
||||||
|
|
||||||
|
Ok(page)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.user.count",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn count(
|
||||||
|
&mut self,
|
||||||
|
filter: mas_storage::user::UserFilter<'_>,
|
||||||
|
) -> Result<usize, Self::Error> {
|
||||||
|
let (sql, arguments) = Query::select()
|
||||||
|
.expr(Expr::col((Users::Table, Users::UserId)).count())
|
||||||
|
.from(Users::Table)
|
||||||
|
.and_where_option(filter.state().map(|state| {
|
||||||
|
if state.is_locked() {
|
||||||
|
Expr::col((Users::Table, Users::LockedAt)).is_not_null()
|
||||||
|
} else {
|
||||||
|
Expr::col((Users::Table, Users::LockedAt)).is_null()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.and_where_option(filter.can_request_admin().map(|can_request_admin| {
|
||||||
|
Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
|
||||||
|
}))
|
||||||
|
.build_sqlx(PostgresQueryBuilder);
|
||||||
|
|
||||||
|
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
||||||
|
.traced()
|
||||||
|
.fetch_one(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
count
|
||||||
|
.try_into()
|
||||||
|
.map_err(DatabaseError::to_invalid_operation)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@@ -17,7 +17,7 @@ use mas_storage::{
|
|||||||
clock::MockClock,
|
clock::MockClock,
|
||||||
user::{
|
user::{
|
||||||
BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository,
|
BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository,
|
||||||
UserPasswordRepository, UserRepository,
|
UserFilter, UserPasswordRepository, UserRepository,
|
||||||
},
|
},
|
||||||
Pagination, Repository, RepositoryAccess,
|
Pagination, Repository, RepositoryAccess,
|
||||||
};
|
};
|
||||||
@@ -36,6 +36,12 @@ async fn test_user_repo(pool: PgPool) {
|
|||||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||||
let clock = MockClock::default();
|
let clock = MockClock::default();
|
||||||
|
|
||||||
|
let all = UserFilter::new();
|
||||||
|
let admin = all.can_request_admin_only();
|
||||||
|
let non_admin = all.cannot_request_admin_only();
|
||||||
|
let active = all.active_only();
|
||||||
|
let locked = all.locked_only();
|
||||||
|
|
||||||
// Initially, the user shouldn't exist
|
// Initially, the user shouldn't exist
|
||||||
assert!(!repo.user().exists(USERNAME).await.unwrap());
|
assert!(!repo.user().exists(USERNAME).await.unwrap());
|
||||||
assert!(repo
|
assert!(repo
|
||||||
@@ -45,6 +51,12 @@ async fn test_user_repo(pool: PgPool) {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
.is_none());
|
.is_none());
|
||||||
|
|
||||||
|
assert_eq!(repo.user().count(all).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.user().count(non_admin).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.user().count(active).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
|
||||||
|
|
||||||
// Adding the user should work
|
// Adding the user should work
|
||||||
let user = repo
|
let user = repo
|
||||||
.user()
|
.user()
|
||||||
@@ -62,6 +74,12 @@ async fn test_user_repo(pool: PgPool) {
|
|||||||
.is_some());
|
.is_some());
|
||||||
assert!(repo.user().lookup(user.id).await.unwrap().is_some());
|
assert!(repo.user().lookup(user.id).await.unwrap().is_some());
|
||||||
|
|
||||||
|
assert_eq!(repo.user().count(all).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.user().count(active).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
|
||||||
|
|
||||||
// Adding a second time should give a conflict
|
// Adding a second time should give a conflict
|
||||||
// It should not poison the transaction though
|
// It should not poison the transaction though
|
||||||
assert!(repo
|
assert!(repo
|
||||||
@@ -75,6 +93,12 @@ async fn test_user_repo(pool: PgPool) {
|
|||||||
let user = repo.user().lock(&clock, user).await.unwrap();
|
let user = repo.user().lock(&clock, user).await.unwrap();
|
||||||
assert!(!user.is_valid());
|
assert!(!user.is_valid());
|
||||||
|
|
||||||
|
assert_eq!(repo.user().count(all).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.user().count(active).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.user().count(locked).await.unwrap(), 1);
|
||||||
|
|
||||||
// Check that the property is retrieved on lookup
|
// Check that the property is retrieved on lookup
|
||||||
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
|
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
|
||||||
assert!(!user.is_valid());
|
assert!(!user.is_valid());
|
||||||
@@ -99,6 +123,12 @@ async fn test_user_repo(pool: PgPool) {
|
|||||||
let user = repo.user().set_can_request_admin(user, true).await.unwrap();
|
let user = repo.user().set_can_request_admin(user, true).await.unwrap();
|
||||||
assert!(user.can_request_admin);
|
assert!(user.can_request_admin);
|
||||||
|
|
||||||
|
assert_eq!(repo.user().count(all).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.user().count(admin).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.user().count(non_admin).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.user().count(active).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
|
||||||
|
|
||||||
// Check that the property is retrieved on lookup
|
// Check that the property is retrieved on lookup
|
||||||
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
|
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
|
||||||
assert!(user.can_request_admin);
|
assert!(user.can_request_admin);
|
||||||
@@ -115,6 +145,47 @@ async fn test_user_repo(pool: PgPool) {
|
|||||||
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
|
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
|
||||||
assert!(!user.can_request_admin);
|
assert!(!user.can_request_admin);
|
||||||
|
|
||||||
|
assert_eq!(repo.user().count(all).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
|
||||||
|
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.user().count(active).await.unwrap(), 1);
|
||||||
|
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
|
||||||
|
|
||||||
|
// Check the list method
|
||||||
|
let list = repo.user().list(all, Pagination::first(10)).await.unwrap();
|
||||||
|
assert_eq!(list.edges.len(), 1);
|
||||||
|
assert_eq!(list.edges[0].id, user.id);
|
||||||
|
|
||||||
|
let list = repo
|
||||||
|
.user()
|
||||||
|
.list(admin, Pagination::first(10))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(list.edges.len(), 0);
|
||||||
|
|
||||||
|
let list = repo
|
||||||
|
.user()
|
||||||
|
.list(non_admin, Pagination::first(10))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(list.edges.len(), 1);
|
||||||
|
assert_eq!(list.edges[0].id, user.id);
|
||||||
|
|
||||||
|
let list = repo
|
||||||
|
.user()
|
||||||
|
.list(active, Pagination::first(10))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(list.edges.len(), 1);
|
||||||
|
assert_eq!(list.edges[0].id, user.id);
|
||||||
|
|
||||||
|
let list = repo
|
||||||
|
.user()
|
||||||
|
.list(locked, Pagination::first(10))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(list.edges.len(), 0);
|
||||||
|
|
||||||
repo.save().await.unwrap();
|
repo.save().await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -19,7 +19,7 @@ use mas_data_model::User;
|
|||||||
use rand_core::RngCore;
|
use rand_core::RngCore;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
|
|
||||||
use crate::{repository_impl, Clock};
|
use crate::{repository_impl, Clock, Page, Pagination};
|
||||||
|
|
||||||
mod email;
|
mod email;
|
||||||
mod password;
|
mod password;
|
||||||
@@ -35,6 +35,94 @@ pub use self::{
|
|||||||
terms::UserTermsRepository,
|
terms::UserTermsRepository,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// The state of a user account
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
|
pub enum UserState {
|
||||||
|
/// The account is locked, it has the `locked_at` timestamp set
|
||||||
|
Locked,
|
||||||
|
|
||||||
|
/// The account is active
|
||||||
|
Active,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserState {
|
||||||
|
/// Returns `true` if the user state is [`Locked`].
|
||||||
|
///
|
||||||
|
/// [`Locked`]: UserState::Locked
|
||||||
|
#[must_use]
|
||||||
|
pub fn is_locked(&self) -> bool {
|
||||||
|
matches!(self, Self::Locked)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns `true` if the user state is [`Active`].
|
||||||
|
///
|
||||||
|
/// [`Active`]: UserState::Active
|
||||||
|
#[must_use]
|
||||||
|
pub fn is_active(&self) -> bool {
|
||||||
|
matches!(self, Self::Active)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter parameters for listing users
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||||
|
pub struct UserFilter<'a> {
|
||||||
|
state: Option<UserState>,
|
||||||
|
can_request_admin: Option<bool>,
|
||||||
|
_phantom: std::marker::PhantomData<&'a ()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> UserFilter<'a> {
|
||||||
|
/// Create a new [`UserFilter`] with default values
|
||||||
|
#[must_use]
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter for active users
|
||||||
|
#[must_use]
|
||||||
|
pub fn active_only(mut self) -> Self {
|
||||||
|
self.state = Some(UserState::Active);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter for locked users
|
||||||
|
#[must_use]
|
||||||
|
pub fn locked_only(mut self) -> Self {
|
||||||
|
self.state = Some(UserState::Locked);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter for users that can request admin privileges
|
||||||
|
#[must_use]
|
||||||
|
pub fn can_request_admin_only(mut self) -> Self {
|
||||||
|
self.can_request_admin = Some(true);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter for users that can't request admin privileges
|
||||||
|
#[must_use]
|
||||||
|
pub fn cannot_request_admin_only(mut self) -> Self {
|
||||||
|
self.can_request_admin = Some(false);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the state filter
|
||||||
|
///
|
||||||
|
/// Returns [`None`] if no state filter was set
|
||||||
|
#[must_use]
|
||||||
|
pub fn state(&self) -> Option<UserState> {
|
||||||
|
self.state
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the can request admin filter
|
||||||
|
///
|
||||||
|
/// Returns [`None`] if no can request admin filter was set
|
||||||
|
#[must_use]
|
||||||
|
pub fn can_request_admin(&self) -> Option<bool> {
|
||||||
|
self.can_request_admin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// A [`UserRepository`] helps interacting with [`User`] saved in the storage
|
/// A [`UserRepository`] helps interacting with [`User`] saved in the storage
|
||||||
/// backend
|
/// backend
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@@ -144,6 +232,33 @@ pub trait UserRepository: Send + Sync {
|
|||||||
user: User,
|
user: User,
|
||||||
can_request_admin: bool,
|
can_request_admin: bool,
|
||||||
) -> Result<User, Self::Error>;
|
) -> Result<User, Self::Error>;
|
||||||
|
|
||||||
|
/// List [`User`] with the given filter and pagination
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
///
|
||||||
|
/// * `filter`: The filter parameters
|
||||||
|
/// * `pagination`: The pagination parameters
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`Self::Error`] if the underlying repository fails
|
||||||
|
async fn list(
|
||||||
|
&mut self,
|
||||||
|
filter: UserFilter<'_>,
|
||||||
|
pagination: Pagination,
|
||||||
|
) -> Result<Page<User>, Self::Error>;
|
||||||
|
|
||||||
|
/// Count the [`User`] with the given filter
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
///
|
||||||
|
/// * `filter`: The filter parameters
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`Self::Error`] if the underlying repository fails
|
||||||
|
async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
repository_impl!(UserRepository:
|
repository_impl!(UserRepository:
|
||||||
@@ -163,4 +278,10 @@ repository_impl!(UserRepository:
|
|||||||
user: User,
|
user: User,
|
||||||
can_request_admin: bool,
|
can_request_admin: bool,
|
||||||
) -> Result<User, Self::Error>;
|
) -> Result<User, Self::Error>;
|
||||||
|
async fn list(
|
||||||
|
&mut self,
|
||||||
|
filter: UserFilter<'_>,
|
||||||
|
pagination: Pagination,
|
||||||
|
) -> Result<Page<User>, Self::Error>;
|
||||||
|
async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error>;
|
||||||
);
|
);
|
||||||
|
Reference in New Issue
Block a user