1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: split the repository trait

This commit is contained in:
Quentin Gliech
2023-01-24 16:04:18 +01:00
parent 6a8c79c497
commit d14ca156ad
18 changed files with 401 additions and 308 deletions

View File

@ -1,4 +1,4 @@
doc-valid-idents = ["OpenID", "OAuth", ".."] doc-valid-idents = ["OpenID", "OAuth", "..", "PostgreSQL"]
disallowed-methods = [ disallowed-methods = [
{ path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" }, { path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" },

View File

@ -31,7 +31,7 @@ use mas_http::HttpServiceExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; use mas_storage::{oauth2::OAuth2ClientRepository, RepositoryAccess};
use serde::{de::DeserializeOwned, Deserialize}; use serde::{de::DeserializeOwned, Deserialize};
use serde_json::Value; use serde_json::Value;
use thiserror::Error; use thiserror::Error;
@ -74,7 +74,7 @@ pub enum Credentials {
impl Credentials { impl Credentials {
pub async fn fetch<E>( pub async fn fetch<E>(
&self, &self,
repo: &mut (impl Repository<Error = E> + ?Sized), repo: &mut (impl RepositoryAccess<Error = E> + ?Sized),
) -> Result<Option<Client>, E> { ) -> Result<Option<Client>, E> {
let client_id = match self { let client_id = match self {
Credentials::None { client_id } Credentials::None { client_id }

View File

@ -14,7 +14,7 @@
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use mas_data_model::BrowserSession; use mas_data_model::BrowserSession;
use mas_storage::{user::BrowserSessionRepository, Repository}; use mas_storage::{user::BrowserSessionRepository, RepositoryAccess};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use ulid::Ulid; use ulid::Ulid;
@ -45,7 +45,7 @@ impl SessionInfo {
/// Load the [`BrowserSession`] from database /// Load the [`BrowserSession`] from database
pub async fn load_session<E>( pub async fn load_session<E>(
&self, &self,
repo: &mut (impl Repository<Error = E> + ?Sized), repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<Option<BrowserSession>, E> { ) -> Result<Option<BrowserSession>, E> {
let session_id = if let Some(id) = self.current { let session_id = if let Some(id) = self.current {
id id

View File

@ -29,7 +29,7 @@ use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode
use mas_data_model::Session; use mas_data_model::Session;
use mas_storage::{ use mas_storage::{
oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
Clock, Repository, Clock, RepositoryAccess,
}; };
use serde::{de::DeserializeOwned, Deserialize}; use serde::{de::DeserializeOwned, Deserialize};
use thiserror::Error; use thiserror::Error;
@ -53,7 +53,7 @@ enum AccessToken {
impl AccessToken { impl AccessToken {
async fn fetch<E>( async fn fetch<E>(
&self, &self,
repo: &mut (impl Repository<Error = E> + ?Sized), repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> { ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
let token = match self { let token = match self {
AccessToken::Form(t) | AccessToken::Header(t) => t, AccessToken::Form(t) | AccessToken::Header(t) => t,
@ -86,7 +86,7 @@ impl<F: Send> UserAuthorization<F> {
// TODO: take scopes to validate as parameter // TODO: take scopes to validate as parameter
pub async fn protected_form<E>( pub async fn protected_form<E>(
self, self,
repo: &mut (impl Repository<Error = E> + ?Sized), repo: &mut impl RepositoryAccess<Error = E>,
clock: &impl Clock, clock: &impl Clock,
) -> Result<(Session, F), AuthorizationVerificationError<E>> { ) -> Result<(Session, F), AuthorizationVerificationError<E>> {
let form = match self.form { let form = match self.form {
@ -106,7 +106,7 @@ impl<F: Send> UserAuthorization<F> {
// TODO: take scopes to validate as parameter // TODO: take scopes to validate as parameter
pub async fn protected<E>( pub async fn protected<E>(
self, self,
repo: &mut (impl Repository<Error = E> + ?Sized), repo: &mut impl RepositoryAccess<Error = E>,
clock: &impl Clock, clock: &impl Clock,
) -> Result<Session, AuthorizationVerificationError<E>> { ) -> Result<Session, AuthorizationVerificationError<E>> {
let (token, session) = self.access_token.fetch(repo).await?; let (token, session) = self.access_token.fetch(repo).await?;

View File

@ -21,7 +21,7 @@ use mas_storage::{
oauth2::OAuth2ClientRepository, oauth2::OAuth2ClientRepository,
upstream_oauth2::UpstreamOAuthProviderRepository, upstream_oauth2::UpstreamOAuthProviderRepository,
user::{UserEmailRepository, UserPasswordRepository, UserRepository}, user::{UserEmailRepository, UserPasswordRepository, UserRepository},
Repository, SystemClock, Repository, RepositoryAccess, SystemClock,
}; };
use mas_storage_pg::PgRepository; use mas_storage_pg::PgRepository;
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;

View File

@ -15,7 +15,7 @@
use oauth2_types::scope::ScopeToken; use oauth2_types::scope::ScopeToken;
use rand::{ use rand::{
distributions::{Alphanumeric, DistString}, distributions::{Alphanumeric, DistString},
Rng, RngCore,
}; };
use serde::Serialize; use serde::Serialize;
use thiserror::Error; use thiserror::Error;
@ -48,7 +48,7 @@ impl Device {
} }
/// Generate a random device ID /// Generate a random device ID
pub fn generate<R: Rng + ?Sized>(rng: &mut R) -> Self { pub fn generate<R: RngCore + ?Sized>(rng: &mut R) -> Self {
let id: String = Alphanumeric.sample_string(rng, DEVICE_ID_LENGTH); let id: String = Alphanumeric.sample_string(rng, DEVICE_ID_LENGTH);
Self { id } Self { id }
} }

View File

@ -28,7 +28,9 @@ use mas_data_model::{BrowserSession, User, UserEmail};
use mas_email::Mailer; use mas_email::Mailer;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, Repository}; use mas_storage::{
user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
};
use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates};
use rand::{distributions::Uniform, Rng}; use rand::{distributions::Uniform, Rng};
use serde::Deserialize; use serde::Deserialize;
@ -71,7 +73,7 @@ async fn render<E: std::error::Error>(
templates: Templates, templates: Templates,
session: BrowserSession, session: BrowserSession,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
repo: &mut (impl Repository<Error = E> + ?Sized), repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
@ -88,7 +90,7 @@ async fn render<E: std::error::Error>(
async fn start_email_verification<E: std::error::Error + Send + Sync + 'static>( async fn start_email_verification<E: std::error::Error + Send + Sync + 'static>(
mailer: &Mailer, mailer: &Mailer,
repo: &mut (impl Repository<Error = E> + ?Sized), repo: &mut impl RepositoryAccess<Error = E>,
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &impl Clock, clock: &impl Clock,
user: &User, user: &User,

View File

@ -26,7 +26,7 @@ use mas_keystore::Encrypter;
use mas_storage::{ use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, upstream_oauth2::UpstreamOAuthProviderRepository,
user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, Clock, Repository, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
}; };
use mas_templates::{ use mas_templates::{
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
@ -161,7 +161,7 @@ pub(crate) async fn post(
// TODO: move that logic elsewhere? // TODO: move that logic elsewhere?
async fn login( async fn login(
password_manager: PasswordManager, password_manager: PasswordManager,
repo: &mut (impl Repository + ?Sized), repo: &mut impl RepositoryAccess,
mut rng: impl Rng + CryptoRng + Send, mut rng: impl Rng + CryptoRng + Send,
clock: &impl Clock, clock: &impl Clock,
username: &str, username: &str,
@ -235,7 +235,7 @@ async fn render(
ctx: LoginContext, ctx: LoginContext,
action: OptionalPostAuthAction, action: OptionalPostAuthAction,
csrf_token: CsrfToken, csrf_token: CsrfToken,
repo: &mut (impl Repository + ?Sized), repo: &mut impl RepositoryAccess,
templates: &Templates, templates: &Templates,
) -> Result<String, FancyError> { ) -> Result<String, FancyError> {
let next = action.load_context(repo).await?; let next = action.load_context(repo).await?;

View File

@ -33,7 +33,7 @@ use mas_policy::PolicyFactory;
use mas_router::Route; use mas_router::Route;
use mas_storage::{ use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, Repository, BoxClock, BoxRepository, BoxRng, RepositoryAccess,
}; };
use mas_templates::{ use mas_templates::{
EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField,
@ -233,7 +233,7 @@ async fn render(
ctx: RegisterContext, ctx: RegisterContext,
action: OptionalPostAuthAction, action: OptionalPostAuthAction,
csrf_token: CsrfToken, csrf_token: CsrfToken,
repo: &mut (impl Repository + ?Sized), repo: &mut impl RepositoryAccess,
templates: &Templates, templates: &Templates,
) -> Result<String, FancyError> { ) -> Result<String, FancyError> {
let next = action.load_context(repo).await?; let next = action.load_context(repo).await?;

View File

@ -18,7 +18,7 @@ use mas_storage::{
compat::CompatSsoLoginRepository, compat::CompatSsoLoginRepository,
oauth2::OAuth2AuthorizationGrantRepository, oauth2::OAuth2AuthorizationGrantRepository,
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
Repository, RepositoryAccess,
}; };
use mas_templates::{PostAuthContext, PostAuthContextInner}; use mas_templates::{PostAuthContext, PostAuthContextInner};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -42,7 +42,7 @@ impl OptionalPostAuthAction {
pub async fn load_context<'a>( pub async fn load_context<'a>(
&'a self, &'a self,
repo: &'a mut (impl Repository + ?Sized), repo: &'a mut impl RepositoryAccess,
) -> anyhow::Result<Option<PostAuthContext>> { ) -> anyhow::Result<Option<PostAuthContext>> {
let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
let ctx = match action { let ctx = match action {

View File

@ -32,7 +32,7 @@ mod tests {
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
}, },
user::UserRepository, user::UserRepository,
Clock, Repository, Clock, Repository, RepositoryAccess,
}; };
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaChaRng; use rand_chacha::ChaChaRng;

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//! Interactions with the database //! An implementation of the storage traits for a PostgreSQL database
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
#![deny( #![deny(

View File

@ -27,7 +27,7 @@ use mas_storage::{
UpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository,
}, },
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
Repository, Repository, RepositoryAccess, RepositoryTransaction,
}; };
use sqlx::{PgPool, Postgres, Transaction}; use sqlx::{PgPool, Postgres, Transaction};
@ -62,7 +62,9 @@ impl PgRepository {
} }
} }
impl Repository for PgRepository { impl Repository<DatabaseError> for PgRepository {}
impl RepositoryTransaction for PgRepository {
type Error = DatabaseError; type Error = DatabaseError;
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> { fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
@ -72,6 +74,10 @@ impl Repository for PgRepository {
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> { fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
self.txn.rollback().map_err(DatabaseError::from).boxed() self.txn.rollback().map_err(DatabaseError::from).boxed()
} }
}
impl RepositoryAccess for PgRepository {
type Error = DatabaseError;
fn upstream_oauth_link<'c>( fn upstream_oauth_link<'c>(
&'c mut self, &'c mut self,

View File

@ -31,7 +31,7 @@ mod tests {
UpstreamOAuthSessionRepository, UpstreamOAuthSessionRepository,
}, },
user::UserRepository, user::UserRepository,
Pagination, Repository, Pagination, RepositoryAccess,
}; };
use oauth2_types::scope::{Scope, OPENID}; use oauth2_types::scope::{Scope, OPENID};
use rand::SeedableRng; use rand::SeedableRng;

View File

@ -16,7 +16,7 @@ use chrono::Duration;
use mas_storage::{ use mas_storage::{
clock::MockClock, clock::MockClock,
user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository},
Repository, Repository, RepositoryAccess,
}; };
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaChaRng; use rand_chacha::ChaChaRng;

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//! Interactions with the database //! Interactions with the storage backend
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
#![deny( #![deny(
@ -42,20 +42,25 @@ pub mod user;
pub use self::{ pub use self::{
clock::{Clock, SystemClock}, clock::{Clock, SystemClock},
pagination::{Page, Pagination}, pagination::{Page, Pagination},
repository::{BoxRepository, Repository, RepositoryError}, repository::{
BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
},
}; };
pub struct MapErr<Repository, Mapper> { /// A wrapper which is used to map the error type of a repository to another
inner: Repository, pub struct MapErr<R, F> {
mapper: Mapper, inner: R,
mapper: F,
} }
impl<Repository, Mapper> MapErr<Repository, Mapper> { impl<R, F> MapErr<R, F> {
fn new(inner: Repository, mapper: Mapper) -> Self { fn new(inner: R, mapper: F) -> Self {
Self { inner, mapper } Self { inner, mapper }
} }
} }
/// A macro to implement a repository trait for the [`MapErr`] wrapper and for
/// [`Box<R>`]
#[macro_export] #[macro_export]
macro_rules! repository_impl { macro_rules! repository_impl {
($repo_trait:ident: ($repo_trait:ident:

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; use futures_util::future::BoxFuture;
use thiserror::Error; use thiserror::Error;
use crate::{ use crate::{
@ -32,83 +32,27 @@ use crate::{
MapErr, MapErr,
}; };
pub trait Repository: Send { /// A [`Repository`] helps interacting with the underlying storage backend.
type Error: std::error::Error + Send + Sync + 'static; pub trait Repository<E>:
RepositoryAccess<Error = E> + RepositoryTransaction<Error = E> + Send
where
E: std::error::Error + Send + Sync + 'static,
{
/// Construct a (boxed) typed-erased repository
fn boxed(self) -> BoxRepository<E>
where
Self: Sync + Sized + 'static,
{
Box::new(self)
}
/// Map the error type of all the methods of a [`Repository`]
fn map_err<Mapper>(self, mapper: Mapper) -> MapErr<Self, Mapper> fn map_err<Mapper>(self, mapper: Mapper) -> MapErr<Self, Mapper>
where where
Self: Sized, Self: Sized,
{ {
MapErr::new(self, mapper) MapErr::new(self, mapper)
} }
fn boxed(self) -> BoxRepository<Self::Error>
where
Self: Sized + Sync + 'static,
{
Box::new(self)
}
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
fn upstream_oauth_link<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c>;
fn upstream_oauth_provider<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c>;
fn upstream_oauth_session<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c>;
fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c>;
fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c>;
fn user_password<'c>(&'c mut self)
-> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c>;
fn browser_session<'c>(
&'c mut self,
) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c>;
fn oauth2_client<'c>(&'c mut self)
-> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c>;
fn oauth2_authorization_grant<'c>(
&'c mut self,
) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c>;
fn oauth2_session<'c>(
&'c mut self,
) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c>;
fn oauth2_access_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c>;
fn oauth2_refresh_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c>;
fn compat_session<'c>(
&'c mut self,
) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c>;
fn compat_sso_login<'c>(
&'c mut self,
) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c>;
fn compat_access_token<'c>(
&'c mut self,
) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c>;
fn compat_refresh_token<'c>(
&'c mut self,
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c>;
} }
/// An opaque, type-erased error /// An opaque, type-erased error
@ -119,6 +63,7 @@ pub struct RepositoryError {
} }
impl RepositoryError { impl RepositoryError {
/// Construct a [`RepositoryError`] from any error kind
pub fn from_error<E>(value: E) -> Self pub fn from_error<E>(value: E) -> Self
where where
E: std::error::Error + Send + Sync + 'static, E: std::error::Error + Send + Sync + 'static,
@ -129,251 +74,386 @@ impl RepositoryError {
} }
} }
pub type BoxRepository<E = RepositoryError> = /// A type-erased [`Repository`]
Box<dyn Repository<Error = E> + Send + Sync + 'static>; pub type BoxRepository<E = RepositoryError> = Box<dyn Repository<E> + Send + Sync + 'static>;
impl<R, F, E> Repository for crate::MapErr<R, F> /// A [`RepositoryTransaction`] can be saved or cancelled, after a series
where /// of operations.
R: Repository, pub trait RepositoryTransaction {
R::Error: 'static, /// The error type used by the [`Self::save`] and [`Self::cancel`] functions
F: FnMut(R::Error) -> E + Send + Sync + 'static, type Error;
E: std::error::Error + Send + Sync + 'static,
{
type Error = E;
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> { /// Commit the transaction
Box::new(self.inner).save().map_err(self.mapper).boxed() ///
} /// # Errors
///
/// Returns an error if the underlying storage backend failed to commit the
/// transaction.
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> { /// Rollback the transaction
Box::new(self.inner).cancel().map_err(self.mapper).boxed() ///
} /// # Errors
///
fn upstream_oauth_link<'c>( /// Returns an error if the underlying storage backend failed to rollback
&'c mut self, /// the transaction.
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> { fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
Box::new(MapErr::new(
self.inner.upstream_oauth_link(),
&mut self.mapper,
))
}
fn upstream_oauth_provider<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.upstream_oauth_provider(),
&mut self.mapper,
))
}
fn upstream_oauth_session<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.upstream_oauth_session(),
&mut self.mapper,
))
}
fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.user(), &mut self.mapper))
}
fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper))
}
fn user_password<'c>(
&'c mut self,
) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper))
}
fn browser_session<'c>(
&'c mut self,
) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper))
}
fn oauth2_client<'c>(
&'c mut self,
) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper))
}
fn oauth2_authorization_grant<'c>(
&'c mut self,
) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.oauth2_authorization_grant(),
&mut self.mapper,
))
}
fn oauth2_session<'c>(
&'c mut self,
) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper))
}
fn oauth2_access_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.oauth2_access_token(),
&mut self.mapper,
))
}
fn oauth2_refresh_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.oauth2_refresh_token(),
&mut self.mapper,
))
}
fn compat_session<'c>(
&'c mut self,
) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper))
}
fn compat_sso_login<'c>(
&'c mut self,
) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper))
}
fn compat_access_token<'c>(
&'c mut self,
) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.compat_access_token(),
&mut self.mapper,
))
}
fn compat_refresh_token<'c>(
&'c mut self,
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.compat_refresh_token(),
&mut self.mapper,
))
}
} }
impl<R: Repository + ?Sized> Repository for Box<R> { /// Access the various repositories the backend implements.
type Error = R::Error; pub trait RepositoryAccess: Send {
/// The backend-specific error type used by each repository.
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> type Error: std::error::Error + Send + Sync + 'static;
where
Self: Sized,
{
// This shouldn't be callable?
unimplemented!()
}
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>
where
Self: Sized,
{
// This shouldn't be callable?
unimplemented!()
}
/// Get an [`UpstreamOAuthLinkRepository`]
fn upstream_oauth_link<'c>( fn upstream_oauth_link<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> { ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c>;
(**self).upstream_oauth_link()
}
/// Get an [`UpstreamOAuthProviderRepository`]
fn upstream_oauth_provider<'c>( fn upstream_oauth_provider<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> { ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c>;
(**self).upstream_oauth_provider()
}
/// Get an [`UpstreamOAuthSessionRepository`]
fn upstream_oauth_session<'c>( fn upstream_oauth_session<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> { ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c>;
(**self).upstream_oauth_session()
}
fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> { /// Get an [`UserRepository`]
(**self).user() fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c>;
}
fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> { /// Get an [`UserEmailRepository`]
(**self).user_email() fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c>;
}
fn user_password<'c>( /// Get an [`UserPasswordRepository`]
&'c mut self, fn user_password<'c>(&'c mut self)
) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> { -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c>;
(**self).user_password()
}
/// Get a [`BrowserSessionRepository`]
fn browser_session<'c>( fn browser_session<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> { ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c>;
(**self).browser_session()
}
fn oauth2_client<'c>( /// Get an [`OAuth2ClientRepository`]
&'c mut self, fn oauth2_client<'c>(&'c mut self)
) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> { -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c>;
(**self).oauth2_client()
}
/// Get an [`OAuth2AuthorizationGrantRepository`]
fn oauth2_authorization_grant<'c>( fn oauth2_authorization_grant<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> { ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c>;
(**self).oauth2_authorization_grant()
}
/// Get an [`OAuth2SessionRepository`]
fn oauth2_session<'c>( fn oauth2_session<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> { ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c>;
(**self).oauth2_session()
}
/// Get an [`OAuth2AccessTokenRepository`]
fn oauth2_access_token<'c>( fn oauth2_access_token<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> { ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c>;
(**self).oauth2_access_token()
}
/// Get an [`OAuth2RefreshTokenRepository`]
fn oauth2_refresh_token<'c>( fn oauth2_refresh_token<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> { ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c>;
(**self).oauth2_refresh_token()
}
/// Get a [`CompatSessionRepository`]
fn compat_session<'c>( fn compat_session<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> { ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c>;
(**self).compat_session()
}
/// Get a [`CompatSsoLoginRepository`]
fn compat_sso_login<'c>( fn compat_sso_login<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> { ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c>;
(**self).compat_sso_login()
}
/// Get a [`CompatAccessTokenRepository`]
fn compat_access_token<'c>( fn compat_access_token<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> { ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c>;
(**self).compat_access_token()
}
/// Get a [`CompatRefreshTokenRepository`]
fn compat_refresh_token<'c>( fn compat_refresh_token<'c>(
&'c mut self, &'c mut self,
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> { ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c>;
(**self).compat_refresh_token() }
/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and
/// [`Repository`] for the [`MapErr`] wrapper and [`Box<R>`]
mod impls {
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
use super::RepositoryAccess;
use crate::{
compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository,
},
oauth2::{
OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
},
upstream_oauth2::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
UpstreamOAuthSessionRepository,
},
user::{
BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository,
},
MapErr, Repository, RepositoryTransaction,
};
// --- Repository ---
impl<R, F, E1, E2> Repository<E2> for MapErr<R, F>
where
R: Repository<E1> + RepositoryAccess<Error = E1> + RepositoryTransaction<Error = E1>,
F: FnMut(E1) -> E2 + Send + Sync + 'static,
E1: std::error::Error + Send + Sync + 'static,
E2: std::error::Error + Send + Sync + 'static,
{
}
// --- RepositoryTransaction --
impl<R, F, E> RepositoryTransaction for MapErr<R, F>
where
R: RepositoryTransaction,
R::Error: 'static,
F: FnMut(R::Error) -> E + Send + Sync + 'static,
E: std::error::Error,
{
type Error = E;
fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
Box::new(self.inner).save().map_err(self.mapper).boxed()
}
fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
Box::new(self.inner).cancel().map_err(self.mapper).boxed()
}
}
// --- RepositoryAccess --
impl<R, F, E> RepositoryAccess for MapErr<R, F>
where
R: RepositoryAccess,
R::Error: 'static,
F: FnMut(R::Error) -> E + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static,
{
type Error = E;
fn upstream_oauth_link<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.upstream_oauth_link(),
&mut self.mapper,
))
}
fn upstream_oauth_provider<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.upstream_oauth_provider(),
&mut self.mapper,
))
}
fn upstream_oauth_session<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.upstream_oauth_session(),
&mut self.mapper,
))
}
fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.user(), &mut self.mapper))
}
fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper))
}
fn user_password<'c>(
&'c mut self,
) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper))
}
fn browser_session<'c>(
&'c mut self,
) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper))
}
fn oauth2_client<'c>(
&'c mut self,
) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper))
}
fn oauth2_authorization_grant<'c>(
&'c mut self,
) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.oauth2_authorization_grant(),
&mut self.mapper,
))
}
fn oauth2_session<'c>(
&'c mut self,
) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper))
}
fn oauth2_access_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.oauth2_access_token(),
&mut self.mapper,
))
}
fn oauth2_refresh_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.oauth2_refresh_token(),
&mut self.mapper,
))
}
fn compat_session<'c>(
&'c mut self,
) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper))
}
fn compat_sso_login<'c>(
&'c mut self,
) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper))
}
fn compat_access_token<'c>(
&'c mut self,
) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.compat_access_token(),
&mut self.mapper,
))
}
fn compat_refresh_token<'c>(
&'c mut self,
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
Box::new(MapErr::new(
self.inner.compat_refresh_token(),
&mut self.mapper,
))
}
}
impl<R: RepositoryAccess + ?Sized> RepositoryAccess for Box<R> {
type Error = R::Error;
fn upstream_oauth_link<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
(**self).upstream_oauth_link()
}
fn upstream_oauth_provider<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
(**self).upstream_oauth_provider()
}
fn upstream_oauth_session<'c>(
&'c mut self,
) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
(**self).upstream_oauth_session()
}
fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
(**self).user()
}
fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
(**self).user_email()
}
fn user_password<'c>(
&'c mut self,
) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
(**self).user_password()
}
fn browser_session<'c>(
&'c mut self,
) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
(**self).browser_session()
}
fn oauth2_client<'c>(
&'c mut self,
) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
(**self).oauth2_client()
}
fn oauth2_authorization_grant<'c>(
&'c mut self,
) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
(**self).oauth2_authorization_grant()
}
fn oauth2_session<'c>(
&'c mut self,
) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
(**self).oauth2_session()
}
fn oauth2_access_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
(**self).oauth2_access_token()
}
fn oauth2_refresh_token<'c>(
&'c mut self,
) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
(**self).oauth2_refresh_token()
}
fn compat_session<'c>(
&'c mut self,
) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
(**self).compat_session()
}
fn compat_sso_login<'c>(
&'c mut self,
) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
(**self).compat_sso_login()
}
fn compat_access_token<'c>(
&'c mut self,
) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
(**self).compat_access_token()
}
fn compat_refresh_token<'c>(
&'c mut self,
) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
(**self).compat_refresh_token()
}
} }
} }

View File

@ -14,7 +14,7 @@
//! Database-related tasks //! Database-related tasks
use mas_storage::{oauth2::OAuth2AccessTokenRepository, Repository, SystemClock}; use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess, SystemClock};
use mas_storage_pg::PgRepository; use mas_storage_pg::PgRepository;
use sqlx::{Pool, Postgres}; use sqlx::{Pool, Postgres};
use tracing::{debug, error, info}; use tracing::{debug, error, info};