1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-06 06:02:40 +03:00

storage: freeze the error type on BoxRepository

This avoids having to deal with traits bounds everywhere. It also moves
the `boxed()` method to the PgRepository, because it was unnecessary to
keep it on the `Repository` trait
This commit is contained in:
Quentin Gliech
2024-07-24 15:26:01 +02:00
parent 48c4c34e88
commit 144de0deb2
12 changed files with 41 additions and 48 deletions

View File

@@ -30,7 +30,7 @@ use mas_matrix::BoxHomeserverConnection;
use mas_matrix_synapse::SynapseConnection;
use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock};
use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock};
use mas_storage_pg::PgRepository;
use mas_templates::Templates;
use opentelemetry::{
@@ -351,8 +351,6 @@ impl FromRequestParts<AppState> for BoxRepository {
histogram.record(duration_ms, &[]);
}
Ok(repo
.map_err(mas_storage::RepositoryError::from_error)
.boxed())
Ok(repo.boxed())
}
}

View File

@@ -15,7 +15,7 @@
use std::{collections::HashMap, net::IpAddr};
use chrono::{DateTime, Utc};
use mas_storage::{user::BrowserSessionRepository, Repository, RepositoryAccess};
use mas_storage::{user::BrowserSessionRepository, RepositoryAccess};
use opentelemetry::{
metrics::{Counter, Histogram},
Key,

View File

@@ -40,9 +40,7 @@ use mas_axum_utils::{
use mas_data_model::{BrowserSession, Session, SiteConfig, User};
use mas_matrix::HomeserverConnection;
use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_storage::{
BoxClock, BoxRepository, BoxRng, Clock, Repository, RepositoryError, SystemClock,
};
use mas_storage::{BoxClock, BoxRepository, BoxRng, Clock, RepositoryError, SystemClock};
use mas_storage_pg::PgRepository;
use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME};
use rand::{thread_rng, SeedableRng};
@@ -82,7 +80,7 @@ impl state::State for GraphQLState {
.await
.map_err(RepositoryError::from_error)?;
Ok(repo.map_err(RepositoryError::from_error).boxed())
Ok(repo.boxed())
}
async fn policy(&self) -> Result<Policy, InstantiateError> {

View File

@@ -43,7 +43,7 @@ use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use mas_matrix::{BoxHomeserverConnection, HomeserverConnection, MockHomeserverConnection};
use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_router::{SimpleRoute, UrlBuilder};
use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository};
use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng};
use mas_storage_pg::{DatabaseError, PgRepository};
use mas_templates::{SiteConfigExt, Templates};
use rand::SeedableRng;
@@ -272,9 +272,7 @@ impl TestState {
pub async fn repository(&self) -> Result<BoxRepository, DatabaseError> {
let repo = PgRepository::from_pool(&self.pool).await?;
Ok(repo
.map_err(mas_storage::RepositoryError::from_error)
.boxed())
Ok(repo.boxed())
}
/// Returns a new random number generator.
@@ -330,9 +328,7 @@ impl graphql::State for TestGraphQLState {
.await
.map_err(mas_storage::RepositoryError::from_error)?;
Ok(repo
.map_err(mas_storage::RepositoryError::from_error)
.boxed())
Ok(repo.boxed())
}
async fn policy(&self) -> Result<Policy, InstantiateError> {
@@ -500,9 +496,7 @@ impl FromRequestParts<TestState> for BoxRepository {
state: &TestState,
) -> Result<Self, Self::Rejection> {
let repo = PgRepository::from_pool(&state.pool).await?;
Ok(repo
.map_err(mas_storage::RepositoryError::from_error)
.boxed())
Ok(repo.boxed())
}
}

View File

@@ -36,7 +36,7 @@ mod tests {
CompatSessionRepository, CompatSsoLoginFilter,
},
user::UserRepository,
Clock, Pagination, Repository, RepositoryAccess,
Clock, Pagination, RepositoryAccess,
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;

View File

@@ -36,7 +36,7 @@ mod tests {
use mas_storage::{
clock::MockClock,
oauth2::{OAuth2DeviceCodeGrantParams, OAuth2SessionFilter, OAuth2SessionRepository},
Clock, Pagination, Repository,
Clock, Pagination,
};
use oauth2_types::{
requests::{GrantType, ResponseMode},

View File

@@ -31,7 +31,7 @@ use mas_storage::{
UpstreamOAuthSessionRepository,
},
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
Repository, RepositoryAccess, RepositoryTransaction,
BoxRepository, MapErr, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
};
use sqlx::{PgConnection, PgPool, Postgres, Transaction};
use tracing::Instrument;
@@ -76,6 +76,11 @@ impl PgRepository {
let txn = pool.begin().await?;
Ok(Self::from_conn(txn))
}
/// Transform the repository into a type-erased [`BoxRepository`]
pub fn boxed(self) -> BoxRepository {
Box::new(MapErr::new(self, RepositoryError::from_error))
}
}
impl<C> PgRepository<C> {

View File

@@ -19,7 +19,7 @@ use mas_storage::{
BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository,
UserFilter, UserPasswordRepository, UserRepository,
},
Pagination, Repository, RepositoryAccess,
Pagination, RepositoryAccess,
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;

View File

@@ -104,6 +104,13 @@ impl Pagination {
self
}
/// Clear the before cursor
#[must_use]
pub const fn clear_before(mut self) -> Self {
self.before = None;
self
}
/// Get items after the given cursor
#[must_use]
pub const fn after(mut self, id: Ulid) -> Self {
@@ -111,6 +118,13 @@ impl Pagination {
self
}
/// Clear the after cursor
#[must_use]
pub const fn clear_after(mut self) -> Self {
self.after = None;
self
}
/// Process a page returned by a paginated query
#[must_use]
pub fn process<T>(&self, mut edges: Vec<T>) -> Page<T> {

View File

@@ -34,7 +34,6 @@ use crate::{
BrowserSessionRepository, UserEmailRepository, UserPasswordRepository,
UserRecoveryRepository, UserRepository, UserTermsRepository,
},
MapErr,
};
/// A [`Repository`] helps interacting with the underlying storage backend.
@@ -43,21 +42,6 @@ pub trait Repository<E>:
where
E: std::error::Error + Send + Sync + 'static,
{
/// Construct a (boxed) typed-erased repository
fn boxed(self) -> BoxRepository<E>
where
Self: Sync + Sized + 'static,
{
Box::new(self)
}
/// Map the error type of all the methods of a [`Repository`]
fn map_err<Mapper>(self, mapper: Mapper) -> MapErr<Self, Mapper>
where
Self: Sized,
{
MapErr::new(self, mapper)
}
}
/// An opaque, type-erased error
@@ -80,7 +64,7 @@ impl RepositoryError {
}
/// A type-erased [`Repository`]
pub type BoxRepository<E = RepositoryError> = Box<dyn Repository<E> + Send + Sync + 'static>;
pub type BoxRepository = Box<dyn Repository<RepositoryError> + Send + Sync + 'static>;
/// A [`RepositoryTransaction`] can be saved or cancelled, after a series
/// of operations.
@@ -113,7 +97,7 @@ pub trait RepositoryTransaction {
/// repository is used at a time.
///
/// When adding a new repository, you should add a new method to this trait, and
/// update the implementations for [`MapErr`] and [`Box<R>`] below.
/// update the implementations for [`crate::MapErr`] and [`Box<R>`] below.
///
/// Note: this used to have generic associated types to avoid boxing all the
/// repository traits, but that was removed because it made almost impossible to
@@ -218,7 +202,7 @@ pub trait RepositoryAccess: Send {
}
/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and
/// [`Repository`] for the [`MapErr`] wrapper and [`Box<R>`]
/// [`Repository`] for the [`crate::MapErr`] wrapper and [`Box<R>`]
mod impls {
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};

View File

@@ -26,7 +26,10 @@ pub struct MapErr<R, F> {
}
impl<R, F> MapErr<R, F> {
pub(crate) fn new(inner: R, mapper: F) -> Self {
/// Create a new [`MapErr`] wrapper from an inner repository and a mapper
/// function
#[must_use]
pub fn new(inner: R, mapper: F) -> Self {
Self {
inner,
mapper,

View File

@@ -18,7 +18,7 @@ use apalis_core::{executor::TokioExecutor, layers::extensions::Extension, monito
use mas_email::Mailer;
use mas_matrix::HomeserverConnection;
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, Repository, SystemClock};
use mas_storage::{BoxClock, BoxRepository, SystemClock};
use mas_storage_pg::{DatabaseError, PgRepository};
use rand::SeedableRng;
use sqlx::{Pool, Postgres};
@@ -83,10 +83,7 @@ impl State {
}
pub async fn repository(&self) -> Result<BoxRepository, DatabaseError> {
let repo = PgRepository::from_pool(self.pool())
.await?
.map_err(mas_storage::RepositoryError::from_error)
.boxed();
let repo = PgRepository::from_pool(self.pool()).await?.boxed();
Ok(repo)
}