// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use futures_util::future::BoxFuture; use thiserror::Error; use crate::{ app_session::AppSessionRepository, compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatSsoLoginRepository, }, job::JobRepository, oauth2::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, user::{ BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRecoveryRepository, UserRepository, UserTermsRepository, }, MapErr, }; /// A [`Repository`] helps interacting with the underlying storage backend. pub trait Repository: RepositoryAccess + RepositoryTransaction + Send where E: std::error::Error + Send + Sync + 'static, { /// Construct a (boxed) typed-erased repository fn boxed(self) -> BoxRepository where Self: Sync + Sized + 'static, { Box::new(self) } /// Map the error type of all the methods of a [`Repository`] fn map_err(self, mapper: Mapper) -> MapErr where Self: Sized, { MapErr::new(self, mapper) } } /// An opaque, type-erased error #[derive(Debug, Error)] #[error(transparent)] pub struct RepositoryError { source: Box, } impl RepositoryError { /// Construct a [`RepositoryError`] from any error kind pub fn from_error(value: E) -> Self where E: std::error::Error + Send + Sync + 'static, { Self { source: Box::new(value), } } } /// A type-erased [`Repository`] pub type BoxRepository = Box + Send + Sync + 'static>; /// A [`RepositoryTransaction`] can be saved or cancelled, after a series /// of operations. pub trait RepositoryTransaction { /// The error type used by the [`Self::save`] and [`Self::cancel`] functions type Error; /// Commit the transaction /// /// # Errors /// /// Returns an error if the underlying storage backend failed to commit the /// transaction. fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; /// Rollback the transaction /// /// # Errors /// /// Returns an error if the underlying storage backend failed to rollback /// the transaction. fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; } /// Access the various repositories the backend implements. /// /// All the methods return a boxed trait object, which can be used to access a /// particular repository. The lifetime of the returned object is bound to the /// lifetime of the whole repository, so that only one mutable reference to the /// repository is used at a time. /// /// When adding a new repository, you should add a new method to this trait, and /// update the implementations for [`MapErr`] and [`Box`] below. /// /// Note: this used to have generic associated types to avoid boxing all the /// repository traits, but that was removed because it made almost impossible to /// box the trait object. This might be a shortcoming of the initial /// implementation of generic associated types, and might be fixed in the /// future. pub trait RepositoryAccess: Send { /// The backend-specific error type used by each repository. type Error: std::error::Error + Send + Sync + 'static; /// Get an [`UpstreamOAuthLinkRepository`] fn upstream_oauth_link<'c>( &'c mut self, ) -> Box + 'c>; /// Get an [`UpstreamOAuthProviderRepository`] fn upstream_oauth_provider<'c>( &'c mut self, ) -> Box + 'c>; /// Get an [`UpstreamOAuthSessionRepository`] fn upstream_oauth_session<'c>( &'c mut self, ) -> Box + 'c>; /// Get an [`UserRepository`] fn user<'c>(&'c mut self) -> Box + 'c>; /// Get an [`UserEmailRepository`] fn user_email<'c>(&'c mut self) -> Box + 'c>; /// Get an [`UserPasswordRepository`] fn user_password<'c>(&'c mut self) -> Box + 'c>; /// Get an [`UserRecoveryRepository`] fn user_recovery<'c>(&'c mut self) -> Box + 'c>; /// Get an [`UserTermsRepository`] fn user_terms<'c>(&'c mut self) -> Box + 'c>; /// Get a [`BrowserSessionRepository`] fn browser_session<'c>( &'c mut self, ) -> Box + 'c>; /// Get a [`AppSessionRepository`] fn app_session<'c>(&'c mut self) -> Box + 'c>; /// Get an [`OAuth2ClientRepository`] fn oauth2_client<'c>(&'c mut self) -> Box + 'c>; /// Get an [`OAuth2AuthorizationGrantRepository`] fn oauth2_authorization_grant<'c>( &'c mut self, ) -> Box + 'c>; /// Get an [`OAuth2SessionRepository`] fn oauth2_session<'c>( &'c mut self, ) -> Box + 'c>; /// Get an [`OAuth2AccessTokenRepository`] fn oauth2_access_token<'c>( &'c mut self, ) -> Box + 'c>; /// Get an [`OAuth2RefreshTokenRepository`] fn oauth2_refresh_token<'c>( &'c mut self, ) -> Box + 'c>; /// Get an [`OAuth2DeviceCodeGrantRepository`] fn oauth2_device_code_grant<'c>( &'c mut self, ) -> Box + 'c>; /// Get a [`CompatSessionRepository`] fn compat_session<'c>( &'c mut self, ) -> Box + 'c>; /// Get a [`CompatSsoLoginRepository`] fn compat_sso_login<'c>( &'c mut self, ) -> Box + 'c>; /// Get a [`CompatAccessTokenRepository`] fn compat_access_token<'c>( &'c mut self, ) -> Box + 'c>; /// Get a [`CompatRefreshTokenRepository`] fn compat_refresh_token<'c>( &'c mut self, ) -> Box + 'c>; /// Get a [`JobRepository`] fn job<'c>(&'c mut self) -> Box + 'c>; } /// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and /// [`Repository`] for the [`MapErr`] wrapper and [`Box`] mod impls { use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; use super::RepositoryAccess; use crate::{ app_session::AppSessionRepository, compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatSsoLoginRepository, }, job::JobRepository, oauth2::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, upstream_oauth2::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, user::{ BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository, UserTermsRepository, }, MapErr, Repository, RepositoryTransaction, }; // --- Repository --- impl Repository for MapErr where R: Repository + RepositoryAccess + RepositoryTransaction, F: FnMut(E1) -> E2 + Send + Sync + 'static, E1: std::error::Error + Send + Sync + 'static, E2: std::error::Error + Send + Sync + 'static, { } // --- RepositoryTransaction -- impl RepositoryTransaction for MapErr where R: RepositoryTransaction, R::Error: 'static, F: FnMut(R::Error) -> E + Send + Sync + 'static, E: std::error::Error, { type Error = E; fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { Box::new(self.inner).save().map_err(self.mapper).boxed() } fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { Box::new(self.inner).cancel().map_err(self.mapper).boxed() } } // --- RepositoryAccess -- impl RepositoryAccess for MapErr where R: RepositoryAccess, R::Error: 'static, F: FnMut(R::Error) -> E + Send + Sync + 'static, E: std::error::Error + Send + Sync + 'static, { type Error = E; fn upstream_oauth_link<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new( self.inner.upstream_oauth_link(), &mut self.mapper, )) } fn upstream_oauth_provider<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new( self.inner.upstream_oauth_provider(), &mut self.mapper, )) } fn upstream_oauth_session<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new( self.inner.upstream_oauth_session(), &mut self.mapper, )) } fn user<'c>(&'c mut self) -> Box + 'c> { Box::new(MapErr::new(self.inner.user(), &mut self.mapper)) } fn user_email<'c>(&'c mut self) -> Box + 'c> { Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper)) } fn user_password<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper)) } fn user_recovery<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new(self.inner.user_recovery(), &mut self.mapper)) } fn user_terms<'c>(&'c mut self) -> Box + 'c> { Box::new(MapErr::new(self.inner.user_terms(), &mut self.mapper)) } fn browser_session<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper)) } fn app_session<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new(self.inner.app_session(), &mut self.mapper)) } fn oauth2_client<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper)) } fn oauth2_authorization_grant<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new( self.inner.oauth2_authorization_grant(), &mut self.mapper, )) } fn oauth2_session<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper)) } fn oauth2_access_token<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new( self.inner.oauth2_access_token(), &mut self.mapper, )) } fn oauth2_refresh_token<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new( self.inner.oauth2_refresh_token(), &mut self.mapper, )) } fn oauth2_device_code_grant<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new( self.inner.oauth2_device_code_grant(), &mut self.mapper, )) } fn compat_session<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper)) } fn compat_sso_login<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper)) } fn compat_access_token<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new( self.inner.compat_access_token(), &mut self.mapper, )) } fn compat_refresh_token<'c>( &'c mut self, ) -> Box + 'c> { Box::new(MapErr::new( self.inner.compat_refresh_token(), &mut self.mapper, )) } fn job<'c>(&'c mut self) -> Box + 'c> { Box::new(MapErr::new(self.inner.job(), &mut self.mapper)) } } impl RepositoryAccess for Box { type Error = R::Error; fn upstream_oauth_link<'c>( &'c mut self, ) -> Box + 'c> { (**self).upstream_oauth_link() } fn upstream_oauth_provider<'c>( &'c mut self, ) -> Box + 'c> { (**self).upstream_oauth_provider() } fn upstream_oauth_session<'c>( &'c mut self, ) -> Box + 'c> { (**self).upstream_oauth_session() } fn user<'c>(&'c mut self) -> Box + 'c> { (**self).user() } fn user_email<'c>(&'c mut self) -> Box + 'c> { (**self).user_email() } fn user_password<'c>( &'c mut self, ) -> Box + 'c> { (**self).user_password() } fn user_recovery<'c>( &'c mut self, ) -> Box + 'c> { (**self).user_recovery() } fn user_terms<'c>(&'c mut self) -> Box + 'c> { (**self).user_terms() } fn browser_session<'c>( &'c mut self, ) -> Box + 'c> { (**self).browser_session() } fn app_session<'c>( &'c mut self, ) -> Box + 'c> { (**self).app_session() } fn oauth2_client<'c>( &'c mut self, ) -> Box + 'c> { (**self).oauth2_client() } fn oauth2_authorization_grant<'c>( &'c mut self, ) -> Box + 'c> { (**self).oauth2_authorization_grant() } fn oauth2_session<'c>( &'c mut self, ) -> Box + 'c> { (**self).oauth2_session() } fn oauth2_access_token<'c>( &'c mut self, ) -> Box + 'c> { (**self).oauth2_access_token() } fn oauth2_refresh_token<'c>( &'c mut self, ) -> Box + 'c> { (**self).oauth2_refresh_token() } fn oauth2_device_code_grant<'c>( &'c mut self, ) -> Box + 'c> { (**self).oauth2_device_code_grant() } fn compat_session<'c>( &'c mut self, ) -> Box + 'c> { (**self).compat_session() } fn compat_sso_login<'c>( &'c mut self, ) -> Box + 'c> { (**self).compat_sso_login() } fn compat_access_token<'c>( &'c mut self, ) -> Box + 'c> { (**self).compat_access_token() } fn compat_refresh_token<'c>( &'c mut self, ) -> Box + 'c> { (**self).compat_refresh_token() } fn job<'c>(&'c mut self) -> Box + 'c> { (**self).job() } } }