1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-28 11:02:02 +03:00

storage: split the repository trait

This commit is contained in:
Quentin Gliech
2023-01-24 16:04:18 +01:00
parent 6a8c79c497
commit d14ca156ad
18 changed files with 401 additions and 308 deletions

View File

@ -1,4 +1,4 @@
doc-valid-idents = ["OpenID", "OAuth", ".."]
doc-valid-idents = ["OpenID", "OAuth", "..", "PostgreSQL"]
disallowed-methods = [
{ path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" },

View File

@ -31,7 +31,7 @@ use mas_http::HttpServiceExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
use mas_keystore::Encrypter;
use mas_storage::{oauth2::OAuth2ClientRepository, Repository};
use mas_storage::{oauth2::OAuth2ClientRepository, RepositoryAccess};
use serde::{de::DeserializeOwned, Deserialize};
use serde_json::Value;
use thiserror::Error;
@ -74,7 +74,7 @@ pub enum Credentials {
impl Credentials {
pub async fn fetch<E>(
&self,
repo: &mut (impl Repository<Error = E> + ?Sized),
repo: &mut (impl RepositoryAccess<Error = E> + ?Sized),
) -> Result<Option<Client>, E> {
let client_id = match self {
Credentials::None { client_id }

View File

@ -14,7 +14,7 @@
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use mas_data_model::BrowserSession;
use mas_storage::{user::BrowserSessionRepository, Repository};
use mas_storage::{user::BrowserSessionRepository, RepositoryAccess};
use serde::{Deserialize, Serialize};
use ulid::Ulid;
@ -45,7 +45,7 @@ impl SessionInfo {
/// Load the [`BrowserSession`] from database
pub async fn load_session<E>(
&self,
repo: &mut (impl Repository<Error = E> + ?Sized),
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<Option<BrowserSession>, E> {
let session_id = if let Some(id) = self.current {
id

View File

@ -29,7 +29,7 @@ use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode
use mas_data_model::Session;
use mas_storage::{
oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
Clock, Repository,
Clock, RepositoryAccess,
};
use serde::{de::DeserializeOwned, Deserialize};
use thiserror::Error;
@ -53,7 +53,7 @@ enum AccessToken {
impl AccessToken {
async fn fetch<E>(
&self,
repo: &mut (impl Repository<Error = E> + ?Sized),
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
let token = match self {
AccessToken::Form(t) | AccessToken::Header(t) => t,
@ -86,7 +86,7 @@ impl<F: Send> UserAuthorization<F> {
// TODO: take scopes to validate as parameter
pub async fn protected_form<E>(
self,
repo: &mut (impl Repository<Error = E> + ?Sized),
repo: &mut impl RepositoryAccess<Error = E>,
clock: &impl Clock,
) -> Result<(Session, F), AuthorizationVerificationError<E>> {
let form = match self.form {
@ -106,7 +106,7 @@ impl<F: Send> UserAuthorization<F> {
// TODO: take scopes to validate as parameter
pub async fn protected<E>(
self,
repo: &mut (impl Repository<Error = E> + ?Sized),
repo: &mut impl RepositoryAccess<Error = E>,
clock: &impl Clock,
) -> Result<Session, AuthorizationVerificationError<E>> {
let (token, session) = self.access_token.fetch(repo).await?;

View File

@ -21,7 +21,7 @@ use mas_storage::{
oauth2::OAuth2ClientRepository,
upstream_oauth2::UpstreamOAuthProviderRepository,
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
Repository, SystemClock,
Repository, RepositoryAccess, SystemClock,
};
use mas_storage_pg::PgRepository;
use oauth2_types::scope::Scope;

View File

@ -15,7 +15,7 @@
use oauth2_types::scope::ScopeToken;
use rand::{
distributions::{Alphanumeric, DistString},
Rng,
RngCore,
};
use serde::Serialize;
use thiserror::Error;
@ -48,7 +48,7 @@ impl Device {
}
/// Generate a random device ID
pub fn generate<R: Rng + ?Sized>(rng: &mut R) -> Self {
pub fn generate<R: RngCore + ?Sized>(rng: &mut R) -> Self {
let id: String = Alphanumeric.sample_string(rng, DEVICE_ID_LENGTH);
Self { id }
}

View File

@ -28,7 +28,9 @@ use mas_data_model::{BrowserSession, User, UserEmail};
use mas_email::Mailer;
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, Repository};
use mas_storage::{
user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
};
use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates};
use rand::{distributions::Uniform, Rng};
use serde::Deserialize;
@ -71,7 +73,7 @@ async fn render<E: std::error::Error>(
templates: Templates,
session: BrowserSession,
cookie_jar: PrivateCookieJar<Encrypter>,
repo: &mut (impl Repository<Error = E> + ?Sized),
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
@ -88,7 +90,7 @@ async fn render<E: std::error::Error>(
async fn start_email_verification<E: std::error::Error + Send + Sync + 'static>(
mailer: &Mailer,
repo: &mut (impl Repository<Error = E> + ?Sized),
repo: &mut impl RepositoryAccess<Error = E>,
mut rng: impl Rng + Send,
clock: &impl Clock,
user: &User,

View File

@ -26,7 +26,7 @@ use mas_keystore::Encrypter;
use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository,
user::{BrowserSessionRepository, UserPasswordRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, Clock, Repository,
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
};
use mas_templates::{
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
@ -161,7 +161,7 @@ pub(crate) async fn post(
// TODO: move that logic elsewhere?
async fn login(
password_manager: PasswordManager,
repo: &mut (impl Repository + ?Sized),
repo: &mut impl RepositoryAccess,
mut rng: impl Rng + CryptoRng + Send,
clock: &impl Clock,
username: &str,
@ -235,7 +235,7 @@ async fn render(
ctx: LoginContext,
action: OptionalPostAuthAction,
csrf_token: CsrfToken,
repo: &mut (impl Repository + ?Sized),
repo: &mut impl RepositoryAccess,
templates: &Templates,
) -> Result<String, FancyError> {
let next = action.load_context(repo).await?;

View File

@ -33,7 +33,7 @@ use mas_policy::PolicyFactory;
use mas_router::Route;
use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, Repository,
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
};
use mas_templates::{
EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField,
@ -233,7 +233,7 @@ async fn render(
ctx: RegisterContext,
action: OptionalPostAuthAction,
csrf_token: CsrfToken,
repo: &mut (impl Repository + ?Sized),
repo: &mut impl RepositoryAccess,
templates: &Templates,
) -> Result<String, FancyError> {
let next = action.load_context(repo).await?;

View File

@ -18,7 +18,7 @@ use mas_storage::{
compat::CompatSsoLoginRepository,
oauth2::OAuth2AuthorizationGrantRepository,
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
Repository,
RepositoryAccess,
};
use mas_templates::{PostAuthContext, PostAuthContextInner};
use serde::{Deserialize, Serialize};
@ -42,7 +42,7 @@ impl OptionalPostAuthAction {
pub async fn load_context<'a>(
&'a self,
repo: &'a mut (impl Repository + ?Sized),
repo: &'a mut impl RepositoryAccess,
) -> anyhow::Result<Option<PostAuthContext>> {
let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
let ctx = match action {

View File

@ -32,7 +32,7 @@ mod tests {
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
},
user::UserRepository,
Clock, Repository,
Clock, Repository, RepositoryAccess,
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! Interactions with the database
//! An implementation of the storage traits for a PostgreSQL database
#![forbid(unsafe_code)]
#![deny(

View File

@ -27,7 +27,7 @@ use mas_storage::{
UpstreamOAuthSessionRepository,
},
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
Repository,
Repository, RepositoryAccess, RepositoryTransaction,
};
use sqlx::{PgPool, Postgres, Transaction};
@ -62,7 +62,9 @@ impl PgRepository {
}
}
impl Repository for PgRepository {
impl Repository<DatabaseError> for PgRepository {}
impl RepositoryTransaction for PgRepository {
type Error = DatabaseError;
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
@ -72,6 +74,10 @@ impl Repository for PgRepository {
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
self.txn.rollback().map_err(DatabaseError::from).boxed()
}
}
impl RepositoryAccess for PgRepository {
type Error = DatabaseError;
fn upstream_oauth_link<'c>(
&'c mut self,

View File

@ -31,7 +31,7 @@ mod tests {
UpstreamOAuthSessionRepository,
},
user::UserRepository,
Pagination, Repository,
Pagination, RepositoryAccess,
};
use oauth2_types::scope::{Scope, OPENID};
use rand::SeedableRng;

View File

@ -16,7 +16,7 @@ use chrono::Duration;
use mas_storage::{
clock::MockClock,
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
Repository,
Repository, RepositoryAccess,
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! Interactions with the database
//! Interactions with the storage backend
#![forbid(unsafe_code)]
#![deny(
@ -42,20 +42,25 @@ pub mod user;
pub use self::{
clock::{Clock, SystemClock},
pagination::{Page, Pagination},
repository::{BoxRepository, Repository, RepositoryError},
repository::{
BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
},
};
pub struct MapErr<Repository, Mapper> {
inner: Repository,
mapper: Mapper,
/// A wrapper which is used to map the error type of a repository to another
pub struct MapErr<R, F> {
inner: R,
mapper: F,
}
impl<Repository, Mapper> MapErr<Repository, Mapper> {
fn new(inner: Repository, mapper: Mapper) -> Self {
impl<R, F> MapErr<R, F> {
fn new(inner: R, mapper: F) -> Self {
Self { inner, mapper }
}
}
/// A macro to implement a repository trait for the [`MapErr`] wrapper and for
/// [`Box<R>`]
#[macro_export]
macro_rules! repository_impl {
($repo_trait:ident:

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
use futures_util::future::BoxFuture;
use thiserror::Error;
use crate::{
@ -32,83 +32,27 @@ use crate::{
MapErr,
};
pub trait Repository: Send {
type Error: std::error::Error + Send + Sync + 'static;
/// A [`Repository`] helps interacting with the underlying storage backend.
pub trait Repository<E>:
RepositoryAccess<Error = E> + RepositoryTransaction<Error = E> + Send
where
E: std::error::Error + Send + Sync + 'static,
{
/// Construct a (boxed) typed-erased repository
fn boxed(self) -> BoxRepository<E>
where
Self: Sync + Sized + 'static,
{
Box::new(self)
}
/// Map the error type of all the methods of a [`Repository`]
fn map_err<Mapper>(self, mapper: Mapper) -> MapErr<Self, Mapper>
where
Self: Sized,
{
MapErr::new(self, mapper)
}
fn boxed(self) -> BoxRepository<Self::Error>
where
Self: Sized + Sync + 'static,
{
Box::new(self)
}
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
fn upstream_oauth_link<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c>;
fn upstream_oauth_provider<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c>;
fn upstream_oauth_session<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c>;
fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c>;
fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c>;
fn user_password<'c>(&'c mut self)
-> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c>;
fn browser_session<'c>(
&'c mut self,
) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c>;
fn oauth2_client<'c>(&'c mut self)
-> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c>;
fn oauth2_authorization_grant<'c>(
&'c mut self,
) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c>;
fn oauth2_session<'c>(
&'c mut self,
) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c>;
fn oauth2_access_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c>;
fn oauth2_refresh_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c>;
fn compat_session<'c>(
&'c mut self,
) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c>;
fn compat_sso_login<'c>(
&'c mut self,
) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c>;
fn compat_access_token<'c>(
&'c mut self,
) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c>;
fn compat_refresh_token<'c>(
&'c mut self,
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c>;
}
/// An opaque, type-erased error
@ -119,6 +63,7 @@ pub struct RepositoryError {
}
impl RepositoryError {
/// Construct a [`RepositoryError`] from any error kind
pub fn from_error<E>(value: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
@ -129,16 +74,155 @@ impl RepositoryError {
}
}
pub type BoxRepository<E = RepositoryError> =
Box<dyn Repository<Error = E> + Send + Sync + 'static>;
/// A type-erased [`Repository`]
pub type BoxRepository<E = RepositoryError> = Box<dyn Repository<E> + Send + Sync + 'static>;
impl<R, F, E> Repository for crate::MapErr<R, F>
where
R: Repository,
/// A [`RepositoryTransaction`] can be saved or cancelled, after a series
/// of operations.
pub trait RepositoryTransaction {
/// The error type used by the [`Self::save`] and [`Self::cancel`] functions
type Error;
/// Commit the transaction
///
/// # Errors
///
/// Returns an error if the underlying storage backend failed to commit the
/// transaction.
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
/// Rollback the transaction
///
/// # Errors
///
/// Returns an error if the underlying storage backend failed to rollback
/// the transaction.
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
}
/// Access the various repositories the backend implements.
pub trait RepositoryAccess: Send {
/// The backend-specific error type used by each repository.
type Error: std::error::Error + Send + Sync + 'static;
/// Get an [`UpstreamOAuthLinkRepository`]
fn upstream_oauth_link<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c>;
/// Get an [`UpstreamOAuthProviderRepository`]
fn upstream_oauth_provider<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c>;
/// Get an [`UpstreamOAuthSessionRepository`]
fn upstream_oauth_session<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c>;
/// Get an [`UserRepository`]
fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c>;
/// Get an [`UserEmailRepository`]
fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c>;
/// Get an [`UserPasswordRepository`]
fn user_password<'c>(&'c mut self)
-> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c>;
/// Get a [`BrowserSessionRepository`]
fn browser_session<'c>(
&'c mut self,
) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c>;
/// Get an [`OAuth2ClientRepository`]
fn oauth2_client<'c>(&'c mut self)
-> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c>;
/// Get an [`OAuth2AuthorizationGrantRepository`]
fn oauth2_authorization_grant<'c>(
&'c mut self,
) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c>;
/// Get an [`OAuth2SessionRepository`]
fn oauth2_session<'c>(
&'c mut self,
) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c>;
/// Get an [`OAuth2AccessTokenRepository`]
fn oauth2_access_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c>;
/// Get an [`OAuth2RefreshTokenRepository`]
fn oauth2_refresh_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c>;
/// Get a [`CompatSessionRepository`]
fn compat_session<'c>(
&'c mut self,
) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c>;
/// Get a [`CompatSsoLoginRepository`]
fn compat_sso_login<'c>(
&'c mut self,
) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c>;
/// Get a [`CompatAccessTokenRepository`]
fn compat_access_token<'c>(
&'c mut self,
) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c>;
/// Get a [`CompatRefreshTokenRepository`]
fn compat_refresh_token<'c>(
&'c mut self,
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c>;
}
/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and
/// [`Repository`] for the [`MapErr`] wrapper and [`Box<R>`]
mod impls {
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
use super::RepositoryAccess;
use crate::{
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository,
},
oauth2::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
},
upstream_oauth2::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
UpstreamOAuthSessionRepository,
},
user::{
BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository,
},
MapErr, Repository, RepositoryTransaction,
};
// --- Repository ---
impl<R, F, E1, E2> Repository<E2> for MapErr<R, F>
where
R: Repository<E1> + RepositoryAccess<Error = E1> + RepositoryTransaction<Error = E1>,
F: FnMut(E1) -> E2 + Send + Sync + 'static,
E1: std::error::Error + Send + Sync + 'static,
E2: std::error::Error + Send + Sync + 'static,
{
}
// --- RepositoryTransaction --
impl<R, F, E> RepositoryTransaction for MapErr<R, F>
where
R: RepositoryTransaction,
R::Error: 'static,
F: FnMut(R::Error) -> E + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static,
{
E: std::error::Error,
{
type Error = E;
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
@ -148,6 +232,17 @@ where
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
Box::new(self.inner).cancel().map_err(self.mapper).boxed()
}
}
// --- RepositoryAccess --
impl<R, F, E> RepositoryAccess for MapErr<R, F>
where
R: RepositoryAccess,
R::Error: 'static,
F: FnMut(R::Error) -> E + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static,
{
type Error = E;
fn upstream_oauth_link<'c>(
&'c mut self,
@ -264,27 +359,11 @@ where
&mut self.mapper,
))
}
}
}
impl<R: Repository + ?Sized> Repository for Box<R> {
impl<R: RepositoryAccess + ?Sized> RepositoryAccess for Box<R> {
type Error = R::Error;
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>
where
Self: Sized,
{
// This shouldn't be callable?
unimplemented!()
}
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>
where
Self: Sized,
{
// This shouldn't be callable?
unimplemented!()
}
fn upstream_oauth_link<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
@ -376,4 +455,5 @@ impl<R: Repository + ?Sized> Repository for Box<R> {
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
(**self).compat_refresh_token()
}
}
}

View File

@ -14,7 +14,7 @@
//! Database-related tasks
use mas_storage::{oauth2::OAuth2AccessTokenRepository, Repository, SystemClock};
use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess, SystemClock};
use mas_storage_pg::PgRepository;
use sqlx::{Pool, Postgres};
use tracing::{debug, error, info};