From f4c64c21712829a7c85331f07201e3d7b4797cfd Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 19 Jan 2023 19:10:35 +0100 Subject: [PATCH] storage: ensure the repository trait can be boxed and define some wrappers to map the errors --- crates/graphql/src/model/upstream_oauth.rs | 12 +- crates/graphql/src/model/users.rs | 4 +- crates/handlers/src/oauth2/token.rs | 3 +- crates/storage-pg/src/repository.rs | 126 ++++++---- crates/storage/src/compat/access_token.rs | 26 +- crates/storage/src/compat/refresh_token.rs | 26 +- crates/storage/src/compat/session.rs | 20 +- crates/storage/src/compat/sso_login.rs | 38 ++- crates/storage/src/lib.rs | 54 ++++ crates/storage/src/oauth2/access_token.rs | 28 ++- .../storage/src/oauth2/authorization_grant.rs | 43 +++- crates/storage/src/oauth2/client.rs | 60 ++++- crates/storage/src/oauth2/refresh_token.rs | 26 +- crates/storage/src/oauth2/session.rs | 23 +- crates/storage/src/repository.rs | 236 +++++++++++++----- crates/storage/src/upstream_oauth2/link.rs | 34 ++- .../storage/src/upstream_oauth2/provider.rs | 25 +- crates/storage/src/upstream_oauth2/session.rs | 33 ++- crates/storage/src/user/email.rs | 55 +++- crates/storage/src/user/mod.rs | 14 +- crates/storage/src/user/password.rs | 15 +- crates/storage/src/user/session.rs | 39 ++- crates/tasks/src/database.rs | 3 +- 23 files changed, 801 insertions(+), 142 deletions(-) diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index 5767f8d4..d65158c9 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -104,10 +104,12 @@ impl UpstreamOAuth2Link { } else { // Fetch on-the-fly let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - repo.upstream_oauth_provider() + let provider = repo + .upstream_oauth_provider() .lookup(self.link.provider_id) .await? - .context("Upstream OAuth 2.0 provider not found")? + .context("Upstream OAuth 2.0 provider not found")?; + provider }; Ok(UpstreamOAuth2Provider::new(provider)) @@ -121,10 +123,12 @@ impl UpstreamOAuth2Link { } else if let Some(user_id) = &self.link.user_id { // Fetch on-the-fly let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - repo.user() + let user = repo + .user() .lookup(*user_id) .await? - .context("User not found")? + .context("User not found")?; + user } else { return Ok(None); }; diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 3f587eb0..a8036dc8 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -67,7 +67,9 @@ impl User { ) -> Result, async_graphql::Error> { let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - Ok(repo.user_email().get_primary(&self.0).await?.map(UserEmail)) + let mut user_email_repo = repo.user_email(); + + Ok(user_email_repo.get_primary(&self.0).await?.map(UserEmail)) } /// Get the list of compatibility SSO logins, chronologically sorted diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index ed566261..5b6b7565 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -426,7 +426,8 @@ async fn refresh_token_grant( .await?; if let Some(access_token_id) = refresh_token.access_token_id { - if let Some(access_token) = repo.oauth2_access_token().lookup(access_token_id).await? { + let access_token = repo.oauth2_access_token().lookup(access_token_id).await?; + if let Some(access_token) = access_token { repo.oauth2_access_token() .revoke(clock, access_token) .await?; diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 288181a6..54002755 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -12,7 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -use mas_storage::Repository; +use mas_storage::{ + compat::{ + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, + }, + oauth2::{ + OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, + OAuth2RefreshTokenRepository, OAuth2SessionRepository, + }, + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + Repository, +}; use sqlx::{PgPool, Postgres, Transaction}; use crate::{ @@ -59,84 +74,95 @@ impl PgRepository { impl Repository for PgRepository { type Error = DatabaseError; - type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; - type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; - type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; - type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; - type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; - type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; - type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; - type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; - type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c; - type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; - type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c; - type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c; - type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; - type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; - type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; - type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; - - fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { - PgUpstreamOAuthLinkRepository::new(&mut self.txn) + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUpstreamOAuthLinkRepository::new(&mut self.txn)) } - fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { - PgUpstreamOAuthProviderRepository::new(&mut self.txn) + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUpstreamOAuthProviderRepository::new(&mut self.txn)) } - fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { - PgUpstreamOAuthSessionRepository::new(&mut self.txn) + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUpstreamOAuthSessionRepository::new(&mut self.txn)) } - fn user(&mut self) -> Self::UserRepository<'_> { - PgUserRepository::new(&mut self.txn) + fn user<'c>(&'c mut self) -> Box + 'c> { + Box::new(PgUserRepository::new(&mut self.txn)) } - fn user_email(&mut self) -> Self::UserEmailRepository<'_> { - PgUserEmailRepository::new(&mut self.txn) + fn user_email<'c>(&'c mut self) -> Box + 'c> { + Box::new(PgUserEmailRepository::new(&mut self.txn)) } - fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { - PgUserPasswordRepository::new(&mut self.txn) + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUserPasswordRepository::new(&mut self.txn)) } - fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { - PgBrowserSessionRepository::new(&mut self.txn) + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgBrowserSessionRepository::new(&mut self.txn)) } - fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { - PgOAuth2ClientRepository::new(&mut self.txn) + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2ClientRepository::new(&mut self.txn)) } - fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { - PgOAuth2AuthorizationGrantRepository::new(&mut self.txn) + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2AuthorizationGrantRepository::new(&mut self.txn)) } - fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { - PgOAuth2SessionRepository::new(&mut self.txn) + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2SessionRepository::new(&mut self.txn)) } - fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { - PgOAuth2AccessTokenRepository::new(&mut self.txn) + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2AccessTokenRepository::new(&mut self.txn)) } - fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { - PgOAuth2RefreshTokenRepository::new(&mut self.txn) + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgOAuth2RefreshTokenRepository::new(&mut self.txn)) } - fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { - PgCompatSessionRepository::new(&mut self.txn) + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatSessionRepository::new(&mut self.txn)) } - fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { - PgCompatSsoLoginRepository::new(&mut self.txn) + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatSsoLoginRepository::new(&mut self.txn)) } - fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { - PgCompatAccessTokenRepository::new(&mut self.txn) + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatAccessTokenRepository::new(&mut self.txn)) } - fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { - PgCompatRefreshTokenRepository::new(&mut self.txn) + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgCompatRefreshTokenRepository::new(&mut self.txn)) } } diff --git a/crates/storage/src/compat/access_token.rs b/crates/storage/src/compat/access_token.rs index 32ba1f73..c6d4eb7f 100644 --- a/crates/storage/src/compat/access_token.rs +++ b/crates/storage/src/compat/access_token.rs @@ -18,7 +18,7 @@ use mas_data_model::{CompatAccessToken, CompatSession}; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait CompatAccessTokenRepository: Send + Sync { @@ -50,3 +50,27 @@ pub trait CompatAccessTokenRepository: Send + Sync { compat_access_token: CompatAccessToken, ) -> Result; } + +repository_impl!(CompatAccessTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + compat_session: &CompatSession, + token: String, + expires_after: Option, + ) -> Result; + + async fn expire( + &mut self, + clock: &dyn Clock, + compat_access_token: CompatAccessToken, + ) -> Result; +); diff --git a/crates/storage/src/compat/refresh_token.rs b/crates/storage/src/compat/refresh_token.rs index 627b59a1..3fd916da 100644 --- a/crates/storage/src/compat/refresh_token.rs +++ b/crates/storage/src/compat/refresh_token.rs @@ -17,7 +17,7 @@ use mas_data_model::{CompatAccessToken, CompatRefreshToken, CompatSession}; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait CompatRefreshTokenRepository: Send + Sync { @@ -49,3 +49,27 @@ pub trait CompatRefreshTokenRepository: Send + Sync { compat_refresh_token: CompatRefreshToken, ) -> Result; } + +repository_impl!(CompatRefreshTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + compat_session: &CompatSession, + compat_access_token: &CompatAccessToken, + token: String, + ) -> Result; + + async fn consume( + &mut self, + clock: &dyn Clock, + compat_refresh_token: CompatRefreshToken, + ) -> Result; +); diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index 0c5bc125..f867a332 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -17,7 +17,7 @@ use mas_data_model::{CompatSession, Device, User}; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait CompatSessionRepository: Send + Sync { @@ -42,3 +42,21 @@ pub trait CompatSessionRepository: Send + Sync { compat_session: CompatSession, ) -> Result; } + +repository_impl!(CompatSessionRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + device: Device, + ) -> Result; + + async fn finish( + &mut self, + clock: &dyn Clock, + compat_session: CompatSession, + ) -> Result; +); diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index 1ed3e5d8..a6fa0735 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -18,7 +18,7 @@ use rand_core::RngCore; use ulid::Ulid; use url::Url; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait CompatSsoLoginRepository: Send + Sync { @@ -64,3 +64,39 @@ pub trait CompatSsoLoginRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } + +repository_impl!(CompatSsoLoginRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + login_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + login_token: String, + redirect_uri: Url, + ) -> Result; + + async fn fulfill( + &mut self, + clock: &dyn Clock, + compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, + ) -> Result; + + async fn exchange( + &mut self, + clock: &dyn Clock, + compat_sso_login: CompatSsoLogin, + ) -> Result; + + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; +); diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index d5a45372..0cdc4e39 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -45,5 +45,59 @@ pub use self::{ repository::Repository, }; +pub struct MapErr { + inner: Repository, + mapper: Mapper, +} + +impl MapErr { + fn new(inner: Repository, mapper: Mapper) -> Self { + Self { inner, mapper } + } +} + +#[macro_export] +macro_rules! repository_impl { + ($repo_trait:ident: + $( + async fn $method:ident ( + &mut self + $(, $arg:ident: $arg_ty:ty )* + $(,)? + ) -> Result<$ret_ty:ty, Self::Error>; + )* + ) => { + #[::async_trait::async_trait] + impl $repo_trait for ::std::boxed::Box + where + R: $repo_trait, + { + type Error = ::Error; + + $( + async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> { + (**self).$method ( $($arg),* ).await + } + )* + } + + #[::async_trait::async_trait] + impl $repo_trait for $crate::MapErr + where + R: $repo_trait, + F: FnMut(::Error) -> E + ::std::marker::Send + ::std::marker::Sync, + E: ::std::error::Error + ::std::marker::Send + ::std::marker::Sync, + { + type Error = E; + + $( + async fn $method (&mut self $(, $arg: $arg_ty)*) -> Result<$ret_ty, Self::Error> { + self.inner.$method ( $($arg),* ).await.map_err(&mut self.mapper) + } + )* + } + }; +} + pub type BoxClock = Box; pub type BoxRng = Box; diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index 1148136f..8a536243 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -18,7 +18,7 @@ use mas_data_model::{AccessToken, Session}; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait OAuth2AccessTokenRepository: Send + Sync { @@ -53,3 +53,29 @@ pub trait OAuth2AccessTokenRepository: Send + Sync { /// Cleanup expired access tokens async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result; } + +repository_impl!(OAuth2AccessTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + session: &Session, + access_token: String, + expires_after: Duration, + ) -> Result; + + async fn revoke( + &mut self, + clock: &dyn Clock, + access_token: AccessToken, + ) -> Result; + + async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result; +); diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 1130e6a8..8852f796 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -21,7 +21,7 @@ use rand_core::RngCore; use ulid::Ulid; use url::Url; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait OAuth2AuthorizationGrantRepository: Send + Sync { @@ -67,3 +67,44 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { authorization_grant: AuthorizationGrant, ) -> Result; } + +repository_impl!(OAuth2AuthorizationGrantRepository: + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client: &Client, + redirect_uri: Url, + scope: Scope, + code: Option, + state: Option, + nonce: Option, + max_age: Option, + response_mode: ResponseMode, + response_type_id_token: bool, + requires_consent: bool, + ) -> Result; + + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_code(&mut self, code: &str) + -> Result, Self::Error>; + + async fn fulfill( + &mut self, + clock: &dyn Clock, + session: &Session, + authorization_grant: AuthorizationGrant, + ) -> Result; + + async fn exchange( + &mut self, + clock: &dyn Clock, + authorization_grant: AuthorizationGrant, + ) -> Result; + + async fn give_consent( + &mut self, + authorization_grant: AuthorizationGrant, + ) -> Result; +); diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 3c7d7dbb..98acaaf7 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -23,7 +23,7 @@ use rand_core::RngCore; use ulid::Ulid; use url::Url; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait OAuth2ClientRepository: Send + Sync { @@ -92,3 +92,61 @@ pub trait OAuth2ClientRepository: Send + Sync { scope: &Scope, ) -> Result<(), Self::Error>; } + +repository_impl!(OAuth2ClientRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn load_batch( + &mut self, + ids: BTreeSet, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + redirect_uris: Vec, + encrypted_client_secret: Option, + grant_types: Vec, + contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, + ) -> Result; + + async fn add_from_config( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client_id: Ulid, + client_auth_method: OAuthClientAuthenticationMethod, + encrypted_client_secret: Option, + jwks: Option, + jwks_uri: Option, + redirect_uris: Vec, + ) -> Result; + + async fn get_consent_for_user( + &mut self, + client: &Client, + user: &User, + ) -> Result; + + async fn give_consent_for_user( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + client: &Client, + user: &User, + scope: &Scope, + ) -> Result<(), Self::Error>; +); diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 66ec2c32..e8ac63ce 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -17,7 +17,7 @@ use mas_data_model::{AccessToken, RefreshToken, Session}; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait OAuth2RefreshTokenRepository: Send + Sync { @@ -49,3 +49,27 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync { refresh_token: RefreshToken, ) -> Result; } + +repository_impl!(OAuth2RefreshTokenRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + session: &Session, + access_token: &AccessToken, + refresh_token: String, + ) -> Result; + + async fn consume( + &mut self, + clock: &dyn Clock, + refresh_token: RefreshToken, + ) -> Result; +); diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 3813810b..f348d9e6 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -17,7 +17,7 @@ use mas_data_model::{AuthorizationGrant, BrowserSession, Session, User}; use rand_core::RngCore; use ulid::Ulid; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait OAuth2SessionRepository: Send + Sync { @@ -42,3 +42,24 @@ pub trait OAuth2SessionRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } + +repository_impl!(OAuth2SessionRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn create_from_grant( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + grant: &AuthorizationGrant, + user_session: &BrowserSession, + ) -> Result; + + async fn finish(&mut self, clock: &dyn Clock, session: Session) + -> Result; + + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; +); diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 55afe41b..085c06ab 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -26,92 +26,192 @@ use crate::{ UpstreamOAuthSessionRepository, }, user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + MapErr, }; pub trait Repository: Send { type Error: std::error::Error + Send + Sync + 'static; - type UpstreamOAuthLinkRepository<'c>: UpstreamOAuthLinkRepository + 'c - where - Self: 'c; + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c>; - type UpstreamOAuthProviderRepository<'c>: UpstreamOAuthProviderRepository - + 'c - where - Self: 'c; + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c>; - type UpstreamOAuthSessionRepository<'c>: UpstreamOAuthSessionRepository - + 'c - where - Self: 'c; + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c>; - type UserRepository<'c>: UserRepository + 'c - where - Self: 'c; + fn user<'c>(&'c mut self) -> Box + 'c>; - type UserEmailRepository<'c>: UserEmailRepository + 'c - where - Self: 'c; + fn user_email<'c>(&'c mut self) -> Box + 'c>; - type UserPasswordRepository<'c>: UserPasswordRepository + 'c - where - Self: 'c; + fn user_password<'c>(&'c mut self) + -> Box + 'c>; - type BrowserSessionRepository<'c>: BrowserSessionRepository + 'c - where - Self: 'c; + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c>; - type OAuth2ClientRepository<'c>: OAuth2ClientRepository + 'c - where - Self: 'c; + fn oauth2_client<'c>(&'c mut self) + -> Box + 'c>; - type OAuth2AuthorizationGrantRepository<'c>: OAuth2AuthorizationGrantRepository - + 'c - where - Self: 'c; + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c>; - type OAuth2SessionRepository<'c>: OAuth2SessionRepository + 'c - where - Self: 'c; + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c>; - type OAuth2AccessTokenRepository<'c>: OAuth2AccessTokenRepository + 'c - where - Self: 'c; + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c>; - type OAuth2RefreshTokenRepository<'c>: OAuth2RefreshTokenRepository + 'c - where - Self: 'c; + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c>; - type CompatSessionRepository<'c>: CompatSessionRepository + 'c - where - Self: 'c; + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c>; - type CompatSsoLoginRepository<'c>: CompatSsoLoginRepository + 'c - where - Self: 'c; + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c>; - type CompatAccessTokenRepository<'c>: CompatAccessTokenRepository + 'c - where - Self: 'c; + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c>; - type CompatRefreshTokenRepository<'c>: CompatRefreshTokenRepository + 'c - where - Self: 'c; - - fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; - fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; - fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; - fn user(&mut self) -> Self::UserRepository<'_>; - fn user_email(&mut self) -> Self::UserEmailRepository<'_>; - fn user_password(&mut self) -> Self::UserPasswordRepository<'_>; - fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>; - fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>; - fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_>; - fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>; - fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_>; - fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_>; - fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>; - fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>; - fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>; - fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>; + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c>; +} + +impl Repository for crate::MapErr +where + R: Repository, + F: FnMut(R::Error) -> E + Send + Sync, + E: std::error::Error + Send + Sync + 'static, +{ + type Error = E; + + fn upstream_oauth_link<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_link(), + &mut self.mapper, + )) + } + + fn upstream_oauth_provider<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_provider(), + &mut self.mapper, + )) + } + + fn upstream_oauth_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.upstream_oauth_session(), + &mut self.mapper, + )) + } + + fn user<'c>(&'c mut self) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user(), &mut self.mapper)) + } + + fn user_email<'c>(&'c mut self) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper)) + } + + fn user_password<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper)) + } + + fn browser_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper)) + } + + fn oauth2_client<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper)) + } + + fn oauth2_authorization_grant<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_authorization_grant(), + &mut self.mapper, + )) + } + + fn oauth2_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper)) + } + + fn oauth2_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_access_token(), + &mut self.mapper, + )) + } + + fn oauth2_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.oauth2_refresh_token(), + &mut self.mapper, + )) + } + + fn compat_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper)) + } + + fn compat_sso_login<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper)) + } + + fn compat_access_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.compat_access_token(), + &mut self.mapper, + )) + } + + fn compat_refresh_token<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new( + self.inner.compat_refresh_token(), + &mut self.mapper, + )) + } } diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index bf9e0aad..c5e024af 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -17,11 +17,11 @@ use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; use rand_core::RngCore; use ulid::Ulid; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait UpstreamOAuthLinkRepository: Send + Sync { - type Error; + type Error: std::error::Error + Send + Sync; /// Lookup an upstream OAuth link by its ID async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; @@ -56,3 +56,33 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { pagination: Pagination, ) -> Result, Self::Error>; } + +repository_impl!(UpstreamOAuthLinkRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_subject( + &mut self, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, + ) -> Result; + + async fn associate_to_user( + &mut self, + upstream_oauth_link: &UpstreamOAuthLink, + user: &User, + ) -> Result<(), Self::Error>; + + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; +); diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 521a7e7a..8aaca0da 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -19,7 +19,7 @@ use oauth2_types::scope::Scope; use rand_core::RngCore; use ulid::Ulid; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait UpstreamOAuthProviderRepository: Send + Sync { @@ -51,3 +51,26 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { /// Get all upstream OAuth providers async fn all(&mut self) -> Result, Self::Error>; } + +repository_impl!(UpstreamOAuthProviderRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option + ) -> Result; + + async fn list_paginated( + &mut self, + pagination: Pagination + ) -> Result, Self::Error>; + + async fn all(&mut self) -> Result, Self::Error>; +); diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index f4441b2a..e878444b 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -17,7 +17,7 @@ use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, Upstr use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait UpstreamOAuthSessionRepository: Send + Sync { @@ -56,3 +56,34 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, ) -> Result; } + +repository_impl!(UpstreamOAuthSessionRepository: + async fn lookup( + &mut self, + id: Ulid, + ) -> Result, Self::Error>; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + state: String, + code_challenge_verifier: Option, + nonce: String, + ) -> Result; + + async fn complete_with_link( + &mut self, + clock: &dyn Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + upstream_oauth_link: &UpstreamOAuthLink, + id_token: Option, + ) -> Result; + + async fn consume( + &mut self, + clock: &dyn Clock, + upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, + ) -> Result; +); diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 65ee465b..4c8601c2 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -17,7 +17,7 @@ use mas_data_model::{User, UserEmail, UserEmailVerification}; use rand_core::RngCore; use ulid::Ulid; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait UserEmailRepository: Send + Sync { @@ -74,3 +74,56 @@ pub trait UserEmailRepository: Send + Sync { verification: UserEmailVerification, ) -> Result; } + +repository_impl!(UserEmailRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn find(&mut self, user: &User, email: &str) -> Result, Self::Error>; + async fn get_primary(&mut self, user: &User) -> Result, Self::Error>; + + async fn all(&mut self, user: &User) -> Result, Self::Error>; + async fn list_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; + async fn count(&mut self, user: &User) -> Result; + + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + email: String, + ) -> Result; + async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error>; + + async fn mark_as_verified( + &mut self, + clock: &dyn Clock, + user_email: UserEmail, + ) -> Result; + + async fn set_as_primary(&mut self, user_email: &UserEmail) -> Result<(), Self::Error>; + + async fn add_verification_code( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_email: &UserEmail, + max_age: chrono::Duration, + code: String, + ) -> Result; + + async fn find_verification_code( + &mut self, + clock: &dyn Clock, + user_email: &UserEmail, + code: &str, + ) -> Result, Self::Error>; + + async fn consume_verification_code( + &mut self, + clock: &dyn Clock, + verification: UserEmailVerification, + ) -> Result; +); diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index b3bd0bc2..49003335 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -17,7 +17,7 @@ use mas_data_model::User; use rand_core::RngCore; use ulid::Ulid; -use crate::Clock; +use crate::{repository_impl, Clock}; mod email; mod password; @@ -41,3 +41,15 @@ pub trait UserRepository: Send + Sync { ) -> Result; async fn exists(&mut self, username: &str) -> Result; } + +repository_impl!(UserRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn find_by_username(&mut self, username: &str) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + username: String, + ) -> Result; + async fn exists(&mut self, username: &str) -> Result; +); diff --git a/crates/storage/src/user/password.rs b/crates/storage/src/user/password.rs index 609198b2..06f03f55 100644 --- a/crates/storage/src/user/password.rs +++ b/crates/storage/src/user/password.rs @@ -16,7 +16,7 @@ use async_trait::async_trait; use mas_data_model::{Password, User}; use rand_core::RngCore; -use crate::Clock; +use crate::{repository_impl, Clock}; #[async_trait] pub trait UserPasswordRepository: Send + Sync { @@ -33,3 +33,16 @@ pub trait UserPasswordRepository: Send + Sync { upgraded_from: Option<&Password>, ) -> Result; } + +repository_impl!(UserPasswordRepository: + async fn active(&mut self, user: &User) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + version: u16, + hashed_password: String, + upgraded_from: Option<&Password>, + ) -> Result; +); diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index 5556547c..0dfc581c 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -17,7 +17,7 @@ use mas_data_model::{BrowserSession, Password, UpstreamOAuthLink, User}; use rand_core::RngCore; use ulid::Ulid; -use crate::{pagination::Page, Clock, Pagination}; +use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait BrowserSessionRepository: Send + Sync { @@ -58,3 +58,40 @@ pub trait BrowserSessionRepository: Send + Sync { upstream_oauth_link: &UpstreamOAuthLink, ) -> Result; } + +repository_impl!(BrowserSessionRepository: + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user: &User, + ) -> Result; + async fn finish( + &mut self, + clock: &dyn Clock, + user_session: BrowserSession, + ) -> Result; + async fn list_active_paginated( + &mut self, + user: &User, + pagination: Pagination, + ) -> Result, Self::Error>; + async fn count_active(&mut self, user: &User) -> Result; + + async fn authenticate_with_password( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_session: BrowserSession, + user_password: &Password, + ) -> Result; + + async fn authenticate_with_upstream( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_session: BrowserSession, + upstream_oauth_link: &UpstreamOAuthLink, + ) -> Result; +); diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index 9e31880c..e7947ce2 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -35,7 +35,8 @@ impl Task for CleanupExpired { async fn run(&self) { let res = async move { let mut repo = PgRepository::from_pool(&self.0).await?; - repo.oauth2_access_token().cleanup_expired(&self.1).await + let res = repo.oauth2_access_token().cleanup_expired(&self.1).await; + res } .await;