diff --git a/clippy.toml b/clippy.toml index 61c5c04f..a300ae05 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1,4 +1,4 @@ -doc-valid-idents = ["OpenID", "OAuth", ".."] +doc-valid-idents = ["OpenID", "OAuth", "..", "PostgreSQL"] disallowed-methods = [ { path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" }, diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 67930c3f..6f5b3e27 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -31,7 +31,7 @@ use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; -use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, RepositoryAccess}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; use thiserror::Error; @@ -74,7 +74,7 @@ pub enum Credentials { impl Credentials { pub async fn fetch( &self, - repo: &mut (impl Repository + ?Sized), + repo: &mut (impl RepositoryAccess + ?Sized), ) -> Result, E> { let client_id = match self { Credentials::None { client_id } diff --git a/crates/axum-utils/src/session.rs b/crates/axum-utils/src/session.rs index 5e966152..c4fece7b 100644 --- a/crates/axum-utils/src/session.rs +++ b/crates/axum-utils/src/session.rs @@ -14,7 +14,7 @@ use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use mas_data_model::BrowserSession; -use mas_storage::{user::BrowserSessionRepository, Repository}; +use mas_storage::{user::BrowserSessionRepository, RepositoryAccess}; use serde::{Deserialize, Serialize}; use ulid::Ulid; @@ -45,7 +45,7 @@ impl SessionInfo { /// Load the [`BrowserSession`] from database pub async fn load_session( &self, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, ) -> Result, E> { let session_id = if let Some(id) = self.current { id diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index 2d37c40c..c9bc537c 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -29,7 +29,7 @@ use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode use mas_data_model::Session; use mas_storage::{ oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository}, - Clock, Repository, + Clock, RepositoryAccess, }; use serde::{de::DeserializeOwned, Deserialize}; use thiserror::Error; @@ -53,7 +53,7 @@ enum AccessToken { impl AccessToken { async fn fetch( &self, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { let token = match self { AccessToken::Form(t) | AccessToken::Header(t) => t, @@ -86,7 +86,7 @@ impl UserAuthorization { // TODO: take scopes to validate as parameter pub async fn protected_form( self, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, clock: &impl Clock, ) -> Result<(Session, F), AuthorizationVerificationError> { let form = match self.form { @@ -106,7 +106,7 @@ impl UserAuthorization { // TODO: take scopes to validate as parameter pub async fn protected( self, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, clock: &impl Clock, ) -> Result> { let (token, session) = self.access_token.fetch(repo).await?; diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 4e74569a..b685a167 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -21,7 +21,7 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, - Repository, SystemClock, + Repository, RepositoryAccess, SystemClock, }; use mas_storage_pg::PgRepository; use oauth2_types::scope::Scope; diff --git a/crates/data-model/src/compat/device.rs b/crates/data-model/src/compat/device.rs index 84bdd067..eebfd9ed 100644 --- a/crates/data-model/src/compat/device.rs +++ b/crates/data-model/src/compat/device.rs @@ -15,7 +15,7 @@ use oauth2_types::scope::ScopeToken; use rand::{ distributions::{Alphanumeric, DistString}, - Rng, + RngCore, }; use serde::Serialize; use thiserror::Error; @@ -48,7 +48,7 @@ impl Device { } /// Generate a random device ID - pub fn generate(rng: &mut R) -> Self { + pub fn generate(rng: &mut R) -> Self { let id: String = Alphanumeric.sample_string(rng, DEVICE_ID_LENGTH); Self { id } } diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 83038337..ad997e0d 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -28,7 +28,9 @@ use mas_data_model::{BrowserSession, User, UserEmail}; use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, Repository}; +use mas_storage::{ + user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, +}; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; @@ -71,7 +73,7 @@ async fn render( templates: Templates, session: BrowserSession, cookie_jar: PrivateCookieJar, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); @@ -88,7 +90,7 @@ async fn render( async fn start_email_verification( mailer: &Mailer, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, mut rng: impl Rng + Send, clock: &impl Clock, user: &User, diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 2836eee3..3083eae0 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -26,7 +26,7 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, - BoxClock, BoxRepository, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, }; use mas_templates::{ FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, @@ -161,7 +161,7 @@ pub(crate) async fn post( // TODO: move that logic elsewhere? async fn login( password_manager: PasswordManager, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, mut rng: impl Rng + CryptoRng + Send, clock: &impl Clock, username: &str, @@ -235,7 +235,7 @@ async fn render( ctx: LoginContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, templates: &Templates, ) -> Result { let next = action.load_context(repo).await?; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index ad1ff378..64e30af7 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -33,7 +33,7 @@ use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - BoxClock, BoxRepository, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; use mas_templates::{ EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, @@ -233,7 +233,7 @@ async fn render( ctx: RegisterContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - repo: &mut (impl Repository + ?Sized), + repo: &mut impl RepositoryAccess, templates: &Templates, ) -> Result { let next = action.load_context(repo).await?; diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index b2946084..69fdf901 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -18,7 +18,7 @@ use mas_storage::{ compat::CompatSsoLoginRepository, oauth2::OAuth2AuthorizationGrantRepository, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, - Repository, + RepositoryAccess, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; use serde::{Deserialize, Serialize}; @@ -42,7 +42,7 @@ impl OptionalPostAuthAction { pub async fn load_context<'a>( &'a self, - repo: &'a mut (impl Repository + ?Sized), + repo: &'a mut impl RepositoryAccess, ) -> anyhow::Result> { let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let ctx = match action { diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index 9b340756..5d99f332 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -32,7 +32,7 @@ mod tests { CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, }, user::UserRepository, - Clock, Repository, + Clock, Repository, RepositoryAccess, }; use rand::SeedableRng; use rand_chacha::ChaChaRng; diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs index 459c8c3b..08c89db2 100644 --- a/crates/storage-pg/src/lib.rs +++ b/crates/storage-pg/src/lib.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Interactions with the database +//! An implementation of the storage traits for a PostgreSQL database #![forbid(unsafe_code)] #![deny( diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 6448b61a..0f1fdfb4 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -27,7 +27,7 @@ use mas_storage::{ UpstreamOAuthSessionRepository, }, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Repository, + Repository, RepositoryAccess, RepositoryTransaction, }; use sqlx::{PgPool, Postgres, Transaction}; @@ -62,7 +62,9 @@ impl PgRepository { } } -impl Repository for PgRepository { +impl Repository for PgRepository {} + +impl RepositoryTransaction for PgRepository { type Error = DatabaseError; fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { @@ -72,6 +74,10 @@ impl Repository for PgRepository { fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { self.txn.rollback().map_err(DatabaseError::from).boxed() } +} + +impl RepositoryAccess for PgRepository { + type Error = DatabaseError; fn upstream_oauth_link<'c>( &'c mut self, diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index af631f15..9ff7699e 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -31,7 +31,7 @@ mod tests { UpstreamOAuthSessionRepository, }, user::UserRepository, - Pagination, Repository, + Pagination, RepositoryAccess, }; use oauth2_types::scope::{Scope, OPENID}; use rand::SeedableRng; diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index 7c3eab37..29f828ab 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -16,7 +16,7 @@ use chrono::Duration; use mas_storage::{ clock::MockClock, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Repository, + Repository, RepositoryAccess, }; use rand::SeedableRng; use rand_chacha::ChaChaRng; diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index aa1db0af..69bc2881 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Interactions with the database +//! Interactions with the storage backend #![forbid(unsafe_code)] #![deny( @@ -42,20 +42,25 @@ pub mod user; pub use self::{ clock::{Clock, SystemClock}, pagination::{Page, Pagination}, - repository::{BoxRepository, Repository, RepositoryError}, + repository::{ + BoxRepository, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction, + }, }; -pub struct MapErr { - inner: Repository, - mapper: Mapper, +/// A wrapper which is used to map the error type of a repository to another +pub struct MapErr { + inner: R, + mapper: F, } -impl MapErr { - fn new(inner: Repository, mapper: Mapper) -> Self { +impl MapErr { + fn new(inner: R, mapper: F) -> Self { Self { inner, mapper } } } +/// A macro to implement a repository trait for the [`MapErr`] wrapper and for +/// [`Box`] #[macro_export] macro_rules! repository_impl { ($repo_trait:ident: diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index d6772b9a..f023e469 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; +use futures_util::future::BoxFuture; use thiserror::Error; use crate::{ @@ -32,83 +32,27 @@ use crate::{ MapErr, }; -pub trait Repository: Send { - type Error: std::error::Error + Send + Sync + 'static; +/// A [`Repository`] helps interacting with the underlying storage backend. +pub trait Repository: + RepositoryAccess + RepositoryTransaction + Send +where + E: std::error::Error + Send + Sync + 'static, +{ + /// Construct a (boxed) typed-erased repository + fn boxed(self) -> BoxRepository + where + Self: Sync + Sized + 'static, + { + Box::new(self) + } + /// Map the error type of all the methods of a [`Repository`] fn map_err(self, mapper: Mapper) -> MapErr where Self: Sized, { MapErr::new(self, mapper) } - - fn boxed(self) -> BoxRepository - where - Self: Sized + Sync + 'static, - { - Box::new(self) - } - - fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; - fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; - - fn upstream_oauth_link<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn upstream_oauth_provider<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn upstream_oauth_session<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn user<'c>(&'c mut self) -> Box + 'c>; - - fn user_email<'c>(&'c mut self) -> Box + 'c>; - - fn user_password<'c>(&'c mut self) - -> Box + 'c>; - - fn browser_session<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn oauth2_client<'c>(&'c mut self) - -> Box + 'c>; - - fn oauth2_authorization_grant<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn oauth2_session<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn oauth2_access_token<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn oauth2_refresh_token<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn compat_session<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn compat_sso_login<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn compat_access_token<'c>( - &'c mut self, - ) -> Box + 'c>; - - fn compat_refresh_token<'c>( - &'c mut self, - ) -> Box + 'c>; } /// An opaque, type-erased error @@ -119,6 +63,7 @@ pub struct RepositoryError { } impl RepositoryError { + /// Construct a [`RepositoryError`] from any error kind pub fn from_error(value: E) -> Self where E: std::error::Error + Send + Sync + 'static, @@ -129,251 +74,386 @@ impl RepositoryError { } } -pub type BoxRepository = - Box + Send + Sync + 'static>; +/// A type-erased [`Repository`] +pub type BoxRepository = Box + Send + Sync + 'static>; -impl Repository for crate::MapErr -where - R: Repository, - R::Error: 'static, - F: FnMut(R::Error) -> E + Send + Sync + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - type Error = E; +/// A [`RepositoryTransaction`] can be saved or cancelled, after a series +/// of operations. +pub trait RepositoryTransaction { + /// The error type used by the [`Self::save`] and [`Self::cancel`] functions + type Error; - fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { - Box::new(self.inner).save().map_err(self.mapper).boxed() - } + /// Commit the transaction + /// + /// # Errors + /// + /// Returns an error if the underlying storage backend failed to commit the + /// transaction. + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; - fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { - Box::new(self.inner).cancel().map_err(self.mapper).boxed() - } - - fn upstream_oauth_link<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.upstream_oauth_link(), - &mut self.mapper, - )) - } - - fn upstream_oauth_provider<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.upstream_oauth_provider(), - &mut self.mapper, - )) - } - - fn upstream_oauth_session<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.upstream_oauth_session(), - &mut self.mapper, - )) - } - - fn user<'c>(&'c mut self) -> Box + 'c> { - Box::new(MapErr::new(self.inner.user(), &mut self.mapper)) - } - - fn user_email<'c>(&'c mut self) -> Box + 'c> { - Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper)) - } - - fn user_password<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper)) - } - - fn browser_session<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper)) - } - - fn oauth2_client<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper)) - } - - fn oauth2_authorization_grant<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.oauth2_authorization_grant(), - &mut self.mapper, - )) - } - - fn oauth2_session<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper)) - } - - fn oauth2_access_token<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.oauth2_access_token(), - &mut self.mapper, - )) - } - - fn oauth2_refresh_token<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.oauth2_refresh_token(), - &mut self.mapper, - )) - } - - fn compat_session<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper)) - } - - fn compat_sso_login<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper)) - } - - fn compat_access_token<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.compat_access_token(), - &mut self.mapper, - )) - } - - fn compat_refresh_token<'c>( - &'c mut self, - ) -> Box + 'c> { - Box::new(MapErr::new( - self.inner.compat_refresh_token(), - &mut self.mapper, - )) - } + /// Rollback the transaction + /// + /// # Errors + /// + /// Returns an error if the underlying storage backend failed to rollback + /// the transaction. + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; } -impl Repository for Box { - type Error = R::Error; - - fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> - where - Self: Sized, - { - // This shouldn't be callable? - unimplemented!() - } - - fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> - where - Self: Sized, - { - // This shouldn't be callable? - unimplemented!() - } +/// Access the various repositories the backend implements. +pub trait RepositoryAccess: Send { + /// The backend-specific error type used by each repository. + type Error: std::error::Error + Send + Sync + 'static; + /// Get an [`UpstreamOAuthLinkRepository`] fn upstream_oauth_link<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).upstream_oauth_link() - } + ) -> Box + 'c>; + /// Get an [`UpstreamOAuthProviderRepository`] fn upstream_oauth_provider<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).upstream_oauth_provider() - } + ) -> Box + 'c>; + /// Get an [`UpstreamOAuthSessionRepository`] fn upstream_oauth_session<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).upstream_oauth_session() - } + ) -> Box + 'c>; - fn user<'c>(&'c mut self) -> Box + 'c> { - (**self).user() - } + /// Get an [`UserRepository`] + fn user<'c>(&'c mut self) -> Box + 'c>; - fn user_email<'c>(&'c mut self) -> Box + 'c> { - (**self).user_email() - } + /// Get an [`UserEmailRepository`] + fn user_email<'c>(&'c mut self) -> Box + 'c>; - fn user_password<'c>( - &'c mut self, - ) -> Box + 'c> { - (**self).user_password() - } + /// Get an [`UserPasswordRepository`] + fn user_password<'c>(&'c mut self) + -> Box + 'c>; + /// Get a [`BrowserSessionRepository`] fn browser_session<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).browser_session() - } + ) -> Box + 'c>; - fn oauth2_client<'c>( - &'c mut self, - ) -> Box + 'c> { - (**self).oauth2_client() - } + /// Get an [`OAuth2ClientRepository`] + fn oauth2_client<'c>(&'c mut self) + -> Box + 'c>; + /// Get an [`OAuth2AuthorizationGrantRepository`] fn oauth2_authorization_grant<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).oauth2_authorization_grant() - } + ) -> Box + 'c>; + /// Get an [`OAuth2SessionRepository`] fn oauth2_session<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).oauth2_session() - } + ) -> Box + 'c>; + /// Get an [`OAuth2AccessTokenRepository`] fn oauth2_access_token<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).oauth2_access_token() - } + ) -> Box + 'c>; + /// Get an [`OAuth2RefreshTokenRepository`] fn oauth2_refresh_token<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).oauth2_refresh_token() - } + ) -> Box + 'c>; + /// Get a [`CompatSessionRepository`] fn compat_session<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).compat_session() - } + ) -> Box + 'c>; + /// Get a [`CompatSsoLoginRepository`] fn compat_sso_login<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).compat_sso_login() - } + ) -> Box + 'c>; + /// Get a [`CompatAccessTokenRepository`] fn compat_access_token<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).compat_access_token() - } + ) -> Box + 'c>; + /// Get a [`CompatRefreshTokenRepository`] fn compat_refresh_token<'c>( &'c mut self, - ) -> Box + 'c> { - (**self).compat_refresh_token() + ) -> Box + 'c>; +} + +/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and +/// [`Repository`] for the [`MapErr`] wrapper and [`Box`] +mod impls { + use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; + + use super::RepositoryAccess; + use crate::{ + compat::{ + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, + }, + oauth2::{ + OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, + OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, + }, + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + user::{ + BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository, + }, + MapErr, Repository, RepositoryTransaction, + }; + + // --- Repository --- + impl Repository for MapErr + where + R: Repository + RepositoryAccess + RepositoryTransaction, + F: FnMut(E1) -> E2 + Send + Sync + 'static, + E1: std::error::Error + Send + Sync + 'static, + E2: std::error::Error + Send + Sync + 'static, + { + } + + // --- RepositoryTransaction -- + impl RepositoryTransaction for MapErr + where + R: RepositoryTransaction, + R::Error: 'static, + F: FnMut(R::Error) -> E + Send + Sync + 'static, + E: std::error::Error, + { + type Error = E; + + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::new(self.inner).save().map_err(self.mapper).boxed() + } + + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::new(self.inner).cancel().map_err(self.mapper).boxed() + } + } + + // --- RepositoryAccess -- + impl RepositoryAccess for MapErr + where + R: RepositoryAccess, + R::Error: 'static, + F: FnMut(R::Error) -> E + Send + Sync + 'static, + E: std::error::Error + Send + Sync + 'static, + { + type Error = E; + + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_link(), + &mut self.mapper, + )) + } + + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_provider(), + &mut self.mapper, + )) + } + + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_session(), + &mut self.mapper, + )) + } + + fn user<'c>(&'c mut self) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user(), &mut self.mapper)) + } + + fn user_email<'c>(&'c mut self) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper)) + } + + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper)) + } + + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper)) + } + + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper)) + } + + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_authorization_grant(), + &mut self.mapper, + )) + } + + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper)) + } + + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_access_token(), + &mut self.mapper, + )) + } + + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_refresh_token(), + &mut self.mapper, + )) + } + + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper)) + } + + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper)) + } + + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.compat_access_token(), + &mut self.mapper, + )) + } + + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.compat_refresh_token(), + &mut self.mapper, + )) + } + } + + impl RepositoryAccess for Box { + type Error = R::Error; + + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_link() + } + + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_provider() + } + + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).upstream_oauth_session() + } + + fn user<'c>(&'c mut self) -> Box + 'c> { + (**self).user() + } + + fn user_email<'c>(&'c mut self) -> Box + 'c> { + (**self).user_email() + } + + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).user_password() + } + + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).browser_session() + } + + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_client() + } + + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_authorization_grant() + } + + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_session() + } + + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_access_token() + } + + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).oauth2_refresh_token() + } + + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_session() + } + + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_sso_login() + } + + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_access_token() + } + + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).compat_refresh_token() + } } } diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index e7947ce2..ebade53a 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -14,7 +14,7 @@ //! Database-related tasks -use mas_storage::{oauth2::OAuth2AccessTokenRepository, Repository, SystemClock}; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, RepositoryAccess, SystemClock}; use mas_storage_pg::PgRepository; use sqlx::{Pool, Postgres}; use tracing::{debug, error, info};