1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-06 06:02:40 +03:00

storage: freeze the error type on BoxRepository

This avoids having to deal with traits bounds everywhere. It also moves
the `boxed()` method to the PgRepository, because it was unnecessary to
keep it on the `Repository` trait
This commit is contained in:
Quentin Gliech
2024-07-24 15:26:01 +02:00
parent 48c4c34e88
commit 144de0deb2
12 changed files with 41 additions and 48 deletions

View File

@@ -30,7 +30,7 @@ use mas_matrix::BoxHomeserverConnection;
use mas_matrix_synapse::SynapseConnection; use mas_matrix_synapse::SynapseConnection;
use mas_policy::{Policy, PolicyFactory}; use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock}; use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock};
use mas_storage_pg::PgRepository; use mas_storage_pg::PgRepository;
use mas_templates::Templates; use mas_templates::Templates;
use opentelemetry::{ use opentelemetry::{
@@ -351,8 +351,6 @@ impl FromRequestParts<AppState> for BoxRepository {
histogram.record(duration_ms, &[]); histogram.record(duration_ms, &[]);
} }
Ok(repo Ok(repo.boxed())
.map_err(mas_storage::RepositoryError::from_error)
.boxed())
} }
} }

View File

@@ -15,7 +15,7 @@
use std::{collections::HashMap, net::IpAddr}; use std::{collections::HashMap, net::IpAddr};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_storage::{user::BrowserSessionRepository, Repository, RepositoryAccess}; use mas_storage::{user::BrowserSessionRepository, RepositoryAccess};
use opentelemetry::{ use opentelemetry::{
metrics::{Counter, Histogram}, metrics::{Counter, Histogram},
Key, Key,

View File

@@ -40,9 +40,7 @@ use mas_axum_utils::{
use mas_data_model::{BrowserSession, Session, SiteConfig, User}; use mas_data_model::{BrowserSession, Session, SiteConfig, User};
use mas_matrix::HomeserverConnection; use mas_matrix::HomeserverConnection;
use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_storage::{ use mas_storage::{BoxClock, BoxRepository, BoxRng, Clock, RepositoryError, SystemClock};
BoxClock, BoxRepository, BoxRng, Clock, Repository, RepositoryError, SystemClock,
};
use mas_storage_pg::PgRepository; use mas_storage_pg::PgRepository;
use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME}; use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME};
use rand::{thread_rng, SeedableRng}; use rand::{thread_rng, SeedableRng};
@@ -82,7 +80,7 @@ impl state::State for GraphQLState {
.await .await
.map_err(RepositoryError::from_error)?; .map_err(RepositoryError::from_error)?;
Ok(repo.map_err(RepositoryError::from_error).boxed()) Ok(repo.boxed())
} }
async fn policy(&self) -> Result<Policy, InstantiateError> { async fn policy(&self) -> Result<Policy, InstantiateError> {

View File

@@ -43,7 +43,7 @@ use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use mas_matrix::{BoxHomeserverConnection, HomeserverConnection, MockHomeserverConnection}; use mas_matrix::{BoxHomeserverConnection, HomeserverConnection, MockHomeserverConnection};
use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_router::{SimpleRoute, UrlBuilder}; use mas_router::{SimpleRoute, UrlBuilder};
use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng};
use mas_storage_pg::{DatabaseError, PgRepository}; use mas_storage_pg::{DatabaseError, PgRepository};
use mas_templates::{SiteConfigExt, Templates}; use mas_templates::{SiteConfigExt, Templates};
use rand::SeedableRng; use rand::SeedableRng;
@@ -272,9 +272,7 @@ impl TestState {
pub async fn repository(&self) -> Result<BoxRepository, DatabaseError> { pub async fn repository(&self) -> Result<BoxRepository, DatabaseError> {
let repo = PgRepository::from_pool(&self.pool).await?; let repo = PgRepository::from_pool(&self.pool).await?;
Ok(repo Ok(repo.boxed())
.map_err(mas_storage::RepositoryError::from_error)
.boxed())
} }
/// Returns a new random number generator. /// Returns a new random number generator.
@@ -330,9 +328,7 @@ impl graphql::State for TestGraphQLState {
.await .await
.map_err(mas_storage::RepositoryError::from_error)?; .map_err(mas_storage::RepositoryError::from_error)?;
Ok(repo Ok(repo.boxed())
.map_err(mas_storage::RepositoryError::from_error)
.boxed())
} }
async fn policy(&self) -> Result<Policy, InstantiateError> { async fn policy(&self) -> Result<Policy, InstantiateError> {
@@ -500,9 +496,7 @@ impl FromRequestParts<TestState> for BoxRepository {
state: &TestState, state: &TestState,
) -> Result<Self, Self::Rejection> { ) -> Result<Self, Self::Rejection> {
let repo = PgRepository::from_pool(&state.pool).await?; let repo = PgRepository::from_pool(&state.pool).await?;
Ok(repo Ok(repo.boxed())
.map_err(mas_storage::RepositoryError::from_error)
.boxed())
} }
} }

View File

@@ -36,7 +36,7 @@ mod tests {
CompatSessionRepository, CompatSsoLoginFilter, CompatSessionRepository, CompatSsoLoginFilter,
}, },
user::UserRepository, user::UserRepository,
Clock, Pagination, Repository, RepositoryAccess, Clock, Pagination, RepositoryAccess,
}; };
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaChaRng; use rand_chacha::ChaChaRng;

View File

@@ -36,7 +36,7 @@ mod tests {
use mas_storage::{ use mas_storage::{
clock::MockClock, clock::MockClock,
oauth2::{OAuth2DeviceCodeGrantParams, OAuth2SessionFilter, OAuth2SessionRepository}, oauth2::{OAuth2DeviceCodeGrantParams, OAuth2SessionFilter, OAuth2SessionRepository},
Clock, Pagination, Repository, Clock, Pagination,
}; };
use oauth2_types::{ use oauth2_types::{
requests::{GrantType, ResponseMode}, requests::{GrantType, ResponseMode},

View File

@@ -31,7 +31,7 @@ use mas_storage::{
UpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository,
}, },
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
Repository, RepositoryAccess, RepositoryTransaction, BoxRepository, MapErr, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
}; };
use sqlx::{PgConnection, PgPool, Postgres, Transaction}; use sqlx::{PgConnection, PgPool, Postgres, Transaction};
use tracing::Instrument; use tracing::Instrument;
@@ -76,6 +76,11 @@ impl PgRepository {
let txn = pool.begin().await?; let txn = pool.begin().await?;
Ok(Self::from_conn(txn)) Ok(Self::from_conn(txn))
} }
/// Transform the repository into a type-erased [`BoxRepository`]
pub fn boxed(self) -> BoxRepository {
Box::new(MapErr::new(self, RepositoryError::from_error))
}
} }
impl<C> PgRepository<C> { impl<C> PgRepository<C> {

View File

@@ -19,7 +19,7 @@ use mas_storage::{
BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository, BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository,
UserFilter, UserPasswordRepository, UserRepository, UserFilter, UserPasswordRepository, UserRepository,
}, },
Pagination, Repository, RepositoryAccess, Pagination, RepositoryAccess,
}; };
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaChaRng; use rand_chacha::ChaChaRng;

View File

@@ -104,6 +104,13 @@ impl Pagination {
self self
} }
/// Clear the before cursor
#[must_use]
pub const fn clear_before(mut self) -> Self {
self.before = None;
self
}
/// Get items after the given cursor /// Get items after the given cursor
#[must_use] #[must_use]
pub const fn after(mut self, id: Ulid) -> Self { pub const fn after(mut self, id: Ulid) -> Self {
@@ -111,6 +118,13 @@ impl Pagination {
self self
} }
/// Clear the after cursor
#[must_use]
pub const fn clear_after(mut self) -> Self {
self.after = None;
self
}
/// Process a page returned by a paginated query /// Process a page returned by a paginated query
#[must_use] #[must_use]
pub fn process<T>(&self, mut edges: Vec<T>) -> Page<T> { pub fn process<T>(&self, mut edges: Vec<T>) -> Page<T> {

View File

@@ -34,7 +34,6 @@ use crate::{
BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, BrowserSessionRepository, UserEmailRepository, UserPasswordRepository,
UserRecoveryRepository, UserRepository, UserTermsRepository, UserRecoveryRepository, UserRepository, UserTermsRepository,
}, },
MapErr,
}; };
/// A [`Repository`] helps interacting with the underlying storage backend. /// A [`Repository`] helps interacting with the underlying storage backend.
@@ -43,21 +42,6 @@ pub trait Repository<E>:
where where
E: std::error::Error + Send + Sync + 'static, E: std::error::Error + Send + Sync + 'static,
{ {
/// Construct a (boxed) typed-erased repository
fn boxed(self) -> BoxRepository<E>
where
Self: Sync + Sized + 'static,
{
Box::new(self)
}
/// Map the error type of all the methods of a [`Repository`]
fn map_err<Mapper>(self, mapper: Mapper) -> MapErr<Self, Mapper>
where
Self: Sized,
{
MapErr::new(self, mapper)
}
} }
/// An opaque, type-erased error /// An opaque, type-erased error
@@ -80,7 +64,7 @@ impl RepositoryError {
} }
/// A type-erased [`Repository`] /// A type-erased [`Repository`]
pub type BoxRepository<E = RepositoryError> = Box<dyn Repository<E> + Send + Sync + 'static>; pub type BoxRepository = Box<dyn Repository<RepositoryError> + Send + Sync + 'static>;
/// A [`RepositoryTransaction`] can be saved or cancelled, after a series /// A [`RepositoryTransaction`] can be saved or cancelled, after a series
/// of operations. /// of operations.
@@ -113,7 +97,7 @@ pub trait RepositoryTransaction {
/// repository is used at a time. /// repository is used at a time.
/// ///
/// When adding a new repository, you should add a new method to this trait, and /// When adding a new repository, you should add a new method to this trait, and
/// update the implementations for [`MapErr`] and [`Box<R>`] below. /// update the implementations for [`crate::MapErr`] and [`Box<R>`] below.
/// ///
/// Note: this used to have generic associated types to avoid boxing all the /// Note: this used to have generic associated types to avoid boxing all the
/// repository traits, but that was removed because it made almost impossible to /// repository traits, but that was removed because it made almost impossible to
@@ -218,7 +202,7 @@ pub trait RepositoryAccess: Send {
} }
/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and /// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and
/// [`Repository`] for the [`MapErr`] wrapper and [`Box<R>`] /// [`Repository`] for the [`crate::MapErr`] wrapper and [`Box<R>`]
mod impls { mod impls {
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};

View File

@@ -26,7 +26,10 @@ pub struct MapErr<R, F> {
} }
impl<R, F> MapErr<R, F> { impl<R, F> MapErr<R, F> {
pub(crate) fn new(inner: R, mapper: F) -> Self { /// Create a new [`MapErr`] wrapper from an inner repository and a mapper
/// function
#[must_use]
pub fn new(inner: R, mapper: F) -> Self {
Self { Self {
inner, inner,
mapper, mapper,

View File

@@ -18,7 +18,7 @@ use apalis_core::{executor::TokioExecutor, layers::extensions::Extension, monito
use mas_email::Mailer; use mas_email::Mailer;
use mas_matrix::HomeserverConnection; use mas_matrix::HomeserverConnection;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, Repository, SystemClock}; use mas_storage::{BoxClock, BoxRepository, SystemClock};
use mas_storage_pg::{DatabaseError, PgRepository}; use mas_storage_pg::{DatabaseError, PgRepository};
use rand::SeedableRng; use rand::SeedableRng;
use sqlx::{Pool, Postgres}; use sqlx::{Pool, Postgres};
@@ -83,10 +83,7 @@ impl State {
} }
pub async fn repository(&self) -> Result<BoxRepository, DatabaseError> { pub async fn repository(&self) -> Result<BoxRepository, DatabaseError> {
let repo = PgRepository::from_pool(self.pool()) let repo = PgRepository::from_pool(self.pool()).await?.boxed();
.await?
.map_err(mas_storage::RepositoryError::from_error)
.boxed();
Ok(repo) Ok(repo)
} }