diff --git a/Cargo.lock b/Cargo.lock index 253809cf..87dd60de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2804,11 +2804,10 @@ dependencies = [ "chrono", "mas-data-model", "mas-storage", - "mas-storage-pg", "oauth2-types", "serde", - "sqlx", "thiserror", + "tokio", "tracing", "ulid", "url", @@ -3101,6 +3100,7 @@ version = "0.1.0" dependencies = [ "async-trait", "chrono", + "futures-util", "mas-data-model", "mas-iana", "mas-jose", @@ -3117,6 +3117,7 @@ version = "0.1.0" dependencies = [ "async-trait", "chrono", + "futures-util", "mas-data-model", "mas-iana", "mas-jose", diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 09090230..67930c3f 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -72,10 +72,10 @@ pub enum Credentials { } impl Credentials { - pub async fn fetch<'r, R>(&self, repo: &'r mut R) -> Result, R::Error> - where - R: Repository, - { + pub async fn fetch( + &self, + repo: &mut (impl Repository + ?Sized), + ) -> Result, E> { let client_id = match self { Credentials::None { client_id } | Credentials::ClientSecretBasic { client_id, .. } diff --git a/crates/axum-utils/src/session.rs b/crates/axum-utils/src/session.rs index 71961367..5e966152 100644 --- a/crates/axum-utils/src/session.rs +++ b/crates/axum-utils/src/session.rs @@ -43,10 +43,10 @@ impl SessionInfo { } /// Load the [`BrowserSession`] from database - pub async fn load_session( + pub async fn load_session( &self, - repo: &mut R, - ) -> Result, R::Error> { + repo: &mut (impl Repository + ?Sized), + ) -> Result, E> { let session_id = if let Some(id) = self.current { id } else { diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index 9a5956c9..2d37c40c 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -51,11 +51,10 @@ enum AccessToken { } impl AccessToken { - async fn fetch( + async fn fetch( &self, - repo: &mut R, - ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> - { + repo: &mut (impl Repository + ?Sized), + ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { let token = match self { AccessToken::Form(t) | AccessToken::Header(t) => t, AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), @@ -85,11 +84,11 @@ pub struct UserAuthorization { impl UserAuthorization { // TODO: take scopes to validate as parameter - pub async fn protected_form( + pub async fn protected_form( self, - repo: &mut R, - clock: &C, - ) -> Result<(Session, F), AuthorizationVerificationError> { + repo: &mut (impl Repository + ?Sized), + clock: &impl Clock, + ) -> Result<(Session, F), AuthorizationVerificationError> { let form = match self.form { Some(f) => f, None => return Err(AuthorizationVerificationError::MissingForm), @@ -105,11 +104,11 @@ impl UserAuthorization { } // TODO: take scopes to validate as parameter - pub async fn protected( + pub async fn protected( self, - repo: &mut R, - clock: &C, - ) -> Result> { + repo: &mut (impl Repository + ?Sized), + clock: &impl Clock, + ) -> Result> { let (token, session) = self.access_token.fetch(repo).await?; if !token.is_valid(clock.now()) || !session.is_valid() { diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 2f3e8852..4e74569a 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -203,7 +203,7 @@ impl Options { let pool = database_from_config(&database_config).await?; let password_manager = password_manager_from_config(&passwords_config).await?; - let mut repo = PgRepository::from_pool(&pool).await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); let user = repo .user() .find_by_username(username) @@ -234,7 +234,7 @@ impl Options { let config: DatabaseConfig = root.load_config()?; let pool = database_from_config(&config).await?; - let mut repo = PgRepository::from_pool(&pool).await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); let user = repo .user() @@ -262,7 +262,7 @@ impl Options { let pool = database_from_config(&config.database).await?; let encrypter = config.secrets.encrypter(); - let mut repo = PgRepository::from_pool(&pool).await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); for client in config.clients.iter() { let client_id = client.client_id; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 1a7e39e6..00230953 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -102,7 +102,7 @@ impl Options { watch_templates(&templates).await?; } - let graphql_schema = mas_handlers::graphql_schema(&pool); + let graphql_schema = mas_handlers::graphql_schema(); // Maximum 50 outgoing HTTP requests at a time let http_client_factory = HttpClientFactory::new(50); diff --git a/crates/graphql/Cargo.toml b/crates/graphql/Cargo.toml index 16f7e5b5..1f8bda4e 100644 --- a/crates/graphql/Cargo.toml +++ b/crates/graphql/Cargo.toml @@ -10,7 +10,7 @@ anyhow = "1.0.68" async-graphql = { version = "5.0.4", features = ["chrono", "url"] } chrono = "0.4.23" serde = { version = "1.0.152", features = ["derive"] } -sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } +tokio = { version = "1.23.0", features = ["sync"] } thiserror = "1.0.38" tracing = "0.1.37" ulid = "1.0.0" @@ -19,7 +19,6 @@ url = "2.3.1" oauth2-types = { path = "../oauth2-types" } mas-data-model = { path = "../data-model" } mas-storage = { path = "../storage" } -mas-storage-pg = { path = "../storage-pg" } [[bin]] name = "schema" diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 159387ae..ca374565 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -34,11 +34,10 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, user::{BrowserSessionRepository, UserEmailRepository}, - Pagination, Repository, + BoxRepository, Pagination, }; -use mas_storage_pg::PgRepository; use model::CreationEvent; -use sqlx::PgPool; +use tokio::sync::Mutex; use self::model::{ BrowserSession, Cursor, Node, NodeCursor, NodeType, OAuth2Client, UpstreamOAuth2Link, @@ -94,7 +93,7 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::OAuth2Client.extract_ulid(&id)?; - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let client = repo.oauth2_client().lookup(id).await?; @@ -124,7 +123,7 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::BrowserSession.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let Some(session) = session else { return Ok(None) }; let current_user = session.user; @@ -150,7 +149,7 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::UserEmail.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let Some(session) = session else { return Ok(None) }; let current_user = session.user; @@ -172,7 +171,7 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let Some(session) = session else { return Ok(None) }; let current_user = session.user; @@ -192,7 +191,7 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?; - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let provider = repo.upstream_oauth_provider().lookup(id).await?; @@ -211,7 +210,7 @@ impl RootQuery { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index e5cd66bc..38fdd4ba 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,9 +15,8 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository}; -use mas_storage_pg::PgRepository; -use sqlx::PgPool; +use mas_storage::{compat::CompatSessionRepository, user::UserRepository, BoxRepository}; +use tokio::sync::Mutex; use url::Url; use super::{NodeType, User}; @@ -36,7 +35,7 @@ impl CompatSession { /// The user authorized for this session. async fn user(&self, ctx: &Context<'_>) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let user = repo .user() .lookup(self.0.user_id) @@ -101,7 +100,7 @@ impl CompatSsoLogin { ) -> Result, async_graphql::Error> { let Some(session_id) = self.0.session_id() else { return Ok(None) }; - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let session = repo .compat_session() .lookup(session_id) diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 90a0c6b7..19612f6d 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,10 +14,9 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, BoxRepository}; use oauth2_types::scope::Scope; -use sqlx::PgPool; +use tokio::sync::Mutex; use ulid::Ulid; use url::Url; @@ -37,7 +36,7 @@ impl OAuth2Session { /// OAuth 2.0 client used by this session. pub async fn client(&self, ctx: &Context<'_>) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let client = repo .oauth2_client() .lookup(self.0.client_id) @@ -57,7 +56,7 @@ impl OAuth2Session { &self, ctx: &Context<'_>, ) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let browser_session = repo .browser_session() .lookup(self.0.user_session_id) @@ -69,7 +68,7 @@ impl OAuth2Session { /// User authorized for this session. pub async fn user(&self, ctx: &Context<'_>) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let browser_session = repo .browser_session() .lookup(self.0.user_session_id) @@ -139,7 +138,7 @@ impl OAuth2Consent { /// OAuth 2.0 client for which the user granted access. pub async fn client(&self, ctx: &Context<'_>) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let client = repo .oauth2_client() .lookup(self.client_id) diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index d65158c9..76b3a44a 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -16,10 +16,9 @@ use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; use mas_storage::{ - upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, Repository, + upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, BoxRepository, }; -use mas_storage_pg::PgRepository; -use sqlx::PgPool; +use tokio::sync::Mutex; use super::{NodeType, User}; @@ -103,7 +102,7 @@ impl UpstreamOAuth2Link { provider.clone() } else { // Fetch on-the-fly - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let provider = repo .upstream_oauth_provider() .lookup(self.link.provider_id) @@ -122,7 +121,7 @@ impl UpstreamOAuth2Link { user.clone() } else if let Some(user_id) = &self.link.user_id { // Fetch on-the-fly - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let user = repo .user() .lookup(*user_id) diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index a8036dc8..35c2cae4 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -22,10 +22,9 @@ use mas_storage::{ oauth2::OAuth2SessionRepository, upstream_oauth2::UpstreamOAuthLinkRepository, user::{BrowserSessionRepository, UserEmailRepository}, - Pagination, Repository, + BoxRepository, Pagination, }; -use mas_storage_pg::PgRepository; -use sqlx::PgPool; +use tokio::sync::Mutex; use super::{ compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session, @@ -65,10 +64,9 @@ impl User { &self, ctx: &Context<'_>, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let mut user_email_repo = repo.user_email(); - Ok(user_email_repo.get_primary(&self.0).await?.map(UserEmail)) } @@ -84,7 +82,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -131,7 +129,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -178,7 +176,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -229,7 +227,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -276,7 +274,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; query( after, @@ -350,7 +348,7 @@ pub struct UserEmailsPagination(mas_data_model::User); impl UserEmailsPagination { /// Identifies the total count of items in the connection. async fn total_count(&self, ctx: &Context<'_>) -> Result { - let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let mut repo = ctx.data::>()?.lock().await; let count = repo.user_email().count(&self.0).await?; Ok(count) } diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index 4271f896..2e826bad 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -25,7 +25,7 @@ use mas_email::Mailer; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; -use mas_storage::{BoxClock, BoxRng, SystemClock}; +use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock}; use mas_storage_pg::PgRepository; use mas_templates::Templates; use rand::SeedableRng; @@ -156,7 +156,7 @@ impl IntoResponse for RepositoryError { } #[async_trait] -impl FromRequestParts for PgRepository { +impl FromRequestParts for BoxRepository { type Rejection = RepositoryError; async fn from_request_parts( @@ -164,6 +164,8 @@ impl FromRequestParts for PgRepository { state: &AppState, ) -> Result { let repo = PgRepository::from_pool(&state.pool).await?; - Ok(repo) + Ok(repo + .map_err(mas_storage::RepositoryError::from_error) + .boxed()) } } diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index f76cda71..07077ea1 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -22,9 +22,8 @@ use mas_storage::{ CompatSsoLoginRepository, }, user::{UserPasswordRepository, UserRepository}, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use rand::{CryptoRng, RngCore}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; @@ -154,7 +153,7 @@ pub enum RouteError { InvalidLoginToken, } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -196,7 +195,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(homeserver): State, Json(input): Json, ) -> Result { @@ -262,7 +261,7 @@ pub(crate) async fn post( } async fn token_login( - repo: &mut PgRepository, + repo: &mut BoxRepository, clock: &dyn Clock, token: &str, ) -> Result<(CompatSession, User), RouteError> { @@ -331,7 +330,7 @@ async fn user_password_login( mut rng: &mut (impl RngCore + CryptoRng + Send), clock: &impl Clock, password_manager: &PasswordManager, - repo: &mut PgRepository, + repo: &mut BoxRepository, username: String, password: String, ) -> Result<(CompatSession, User), RouteError> { diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 602b4d80..ba3dee13 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -31,9 +31,8 @@ use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_storage::{ compat::{CompatSessionRepository, CompatSsoLoginRepository}, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; use ulid::Ulid; @@ -55,7 +54,7 @@ pub struct Params { pub async fn get( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, @@ -64,7 +63,7 @@ pub async fn get( let (session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session @@ -117,7 +116,7 @@ pub async fn get( pub async fn post( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, @@ -127,7 +126,7 @@ pub async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); cookie_jar.verify_form(&clock, form)?; - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index d8ef0fb2..da013cf7 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -19,8 +19,7 @@ use axum::{ }; use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; -use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRng, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{compat::CompatSsoLoginRepository, BoxClock, BoxRepository, BoxRng}; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; @@ -48,7 +47,7 @@ pub enum RouteError { InvalidRedirectUrl, } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -59,7 +58,7 @@ impl IntoResponse for RouteError { pub async fn get( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(url_builder): State, Query(params): Query, ) -> Result { diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index e1ef02be..096b22de 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -18,9 +18,8 @@ use hyper::StatusCode; use mas_data_model::TokenType; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - BoxClock, Clock, Repository, + BoxClock, BoxRepository, Clock, }; -use mas_storage_pg::PgRepository; use thiserror::Error; use super::MatrixError; @@ -41,7 +40,7 @@ pub enum RouteError { InvalidAuthorization, } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -68,7 +67,7 @@ impl IntoResponse for RouteError { pub(crate) async fn post( clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, maybe_authorization: Option>>, ) -> Result { let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 6b90464e..eb970c57 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -18,9 +18,8 @@ use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; use thiserror::Error; @@ -69,7 +68,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { @@ -89,7 +88,7 @@ pub struct ResponseBody { pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, Json(input): Json, ) -> Result { let token_type = TokenType::check(&input.refresh_token)?; diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index fcc6aa3c..2d1f7fcc 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -22,20 +22,19 @@ use axum::{ Json, TypedHeader, }; use axum_extra::extract::PrivateCookieJar; -use futures_util::{StreamExt, TryStreamExt}; +use futures_util::TryStreamExt; use headers::{ContentType, HeaderValue}; use hyper::header::CACHE_CONTROL; use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_graphql::Schema; use mas_keystore::Encrypter; -use mas_storage_pg::PgRepository; -use sqlx::PgPool; +use mas_storage::BoxRepository; +use tokio::sync::Mutex; use tracing::{info_span, Instrument}; #[must_use] -pub fn schema(pool: &PgPool) -> Schema { +pub fn schema() -> Schema { mas_graphql::schema_builder() - .data(pool.clone()) .extension(Tracing) .extension(ApolloTracing) .finish() @@ -59,8 +58,8 @@ fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span { } pub async fn post( - State(pool): State, State(schema): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, content_type: Option>, body: BodyStream, @@ -68,62 +67,46 @@ pub async fn post( let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let (session_info, _cookie_jar) = cookie_jar.session_info(); - let mut repo = PgRepository::from_pool(&pool).await?; - let maybe_session = session_info.load_session(&mut repo).await?; - repo.cancel().await?; + let maybe_session = session_info.load_session(&mut *repo).await?; - let mut request = async_graphql::http::receive_batch_body( + let mut request = async_graphql::http::receive_body( content_type, body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) .into_async_read(), MultipartOptions::default(), ) - .await?; // XXX: this should probably return another error response? + .await? // XXX: this should probably return another error response? + .data(Mutex::new(repo)); if let Some(session) = maybe_session { request = request.data(session); } - let response = match request { - async_graphql::BatchRequest::Single(request) => { - let span = span_for_graphql_request(&request); - let response = schema.execute(request).instrument(span).await; - async_graphql::BatchResponse::Single(response) - } - async_graphql::BatchRequest::Batch(requests) => async_graphql::BatchResponse::Batch( - futures_util::stream::iter(requests.into_iter()) - .then(|request| { - let span = span_for_graphql_request(&request); - schema.execute(request).instrument(span) - }) - .collect() - .await, - ), - }; + let span = span_for_graphql_request(&request); + let response = schema.execute(request).instrument(span).await; let cache_control = response - .cache_control() + .cache_control .value() .and_then(|v| HeaderValue::from_str(&v).ok()) .map(|h| [(CACHE_CONTROL, h)]); - let headers = response.http_headers(); + let headers = response.http_headers.clone(); Ok((headers, cache_control, Json(response))) } pub async fn get( - State(pool): State, State(schema): State, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, RawQuery(query): RawQuery, ) -> Result { let (session_info, _cookie_jar) = cookie_jar.session_info(); - let mut repo = PgRepository::from_pool(&pool).await?; - let maybe_session = session_info.load_session(&mut repo).await?; - repo.cancel().await?; + let maybe_session = session_info.load_session(&mut *repo).await?; - let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; + let mut request = + async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(Mutex::new(repo)); if let Some(session) = maybe_session { request = request.data(session); diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 48ca5560..30519f42 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -43,8 +43,7 @@ use mas_http::CorsLayerExt; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::{Route, UrlBuilder}; -use mas_storage::{BoxClock, BoxRng}; -use mas_storage_pg::PgRepository; +use mas_storage::{BoxClock, BoxRepository, BoxRng}; use mas_templates::{ErrorContext, Templates}; use passwords::PasswordManager; use sqlx::PgPool; @@ -98,7 +97,7 @@ where ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, mas_graphql::Schema: FromRef, - PgPool: FromRef, + BoxRepository: FromRequestParts, Encrypter: FromRef, { let mut router = Router::new().route( @@ -158,7 +157,7 @@ where Keystore: FromRef, UrlBuilder: FromRef, Arc: FromRef, - PgRepository: FromRequestParts, + BoxRepository: FromRequestParts, Encrypter: FromRef, HttpClientFactory: FromRef, BoxClock: FromRequestParts, @@ -213,7 +212,7 @@ where ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, - PgRepository: FromRequestParts, + BoxRepository: FromRequestParts, MatrixHomeserver: FromRef, PasswordManager: FromRef, BoxClock: FromRequestParts, @@ -258,7 +257,7 @@ where S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, Arc: FromRef, - PgRepository: FromRequestParts, + BoxRepository: FromRequestParts, Encrypter: FromRef, Templates: FromRef, Mailer: FromRef, @@ -401,7 +400,7 @@ async fn test_state(pool: sqlx::PgPool) -> Result { let policy_factory = Arc::new(policy_factory); - let graphql_schema = graphql_schema(&pool); + let graphql_schema = graphql_schema(); let http_client_factory = HttpClientFactory::new(10); diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index c17fb9f1..91121df9 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -27,9 +27,8 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::Templates; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; use thiserror::Error; @@ -69,7 +68,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -81,13 +80,13 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let grant = repo .oauth2_authorization_grant() @@ -147,7 +146,7 @@ pub enum GrantCompletionError { NoSuchClient, } -impl_from_error_for_route!(GrantCompletionError: mas_storage_pg::DatabaseError); +impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError); @@ -159,7 +158,7 @@ pub(crate) async fn complete( grant: AuthorizationGrant, browser_session: BrowserSession, policy_factory: &PolicyFactory, - mut repo: PgRepository, + mut repo: BoxRepository, ) -> Result>, GrantCompletionError> { // Verify that the grant is in a pending stage if !grant.stage.is_pending() { diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 30efcaa3..4ce10baa 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -27,9 +27,8 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::Templates; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -90,7 +89,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); @@ -135,7 +134,7 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Form(params): Form, ) -> Result { @@ -168,7 +167,7 @@ pub(crate) async fn get( let templates = templates.clone(); let callback_destination = callback_destination.clone(); async move { - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let prompt = params.auth.prompt.as_deref().unwrap_or_default(); // Check if the request/request_uri/registration params are used. If so, reply diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index c83dca03..f365a9f3 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -30,9 +30,8 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use thiserror::Error; use ulid::Ulid; @@ -61,7 +60,7 @@ pub enum RouteError { } impl_from_error_for_route!(mas_templates::TemplateError); -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -77,13 +76,13 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let grant = repo .oauth2_authorization_grant() @@ -130,7 +129,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(policy_factory): State>, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, Form(form): Form>, @@ -139,7 +138,7 @@ pub(crate) async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let grant = repo .oauth2_authorization_grant() diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 65e48e06..d0dcd26c 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -25,9 +25,8 @@ use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, user::{BrowserSessionRepository, UserRepository}, - BoxClock, Clock, Repository, + BoxClock, BoxRepository, Clock, }; -use mas_storage_pg::PgRepository; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, requests::{IntrospectionRequest, IntrospectionResponse}, @@ -96,7 +95,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { @@ -125,13 +124,13 @@ const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc pub(crate) async fn post( clock: BoxClock, State(http_client_factory): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { let client = client_authorization .credentials - .fetch(&mut repo) + .fetch(&mut *repo) .await .unwrap() .ok_or(RouteError::ClientNotFound)?; diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index 129f636f..650a19ab 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,8 +19,7 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRng, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, registration::{ @@ -48,7 +47,7 @@ pub(crate) enum RouteError { PolicyDenied(Vec), } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -108,7 +107,7 @@ impl IntoResponse for RouteError { pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(policy_factory): State>, State(encrypter): State, Json(body): Json, diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 5b6b7565..76943e7e 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -37,9 +37,8 @@ use mas_storage::{ OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, user::BrowserSessionRepository, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, pkce::CodeChallengeError, @@ -150,7 +149,7 @@ impl IntoResponse for RouteError { } } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::claims::ClaimError); impl_from_error_for_route!(mas_jose::claims::TokenHashError); @@ -163,13 +162,13 @@ pub(crate) async fn post( State(http_client_factory): State, State(key_store): State, State(url_builder): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { let client = client_authorization .credentials - .fetch(&mut repo) + .fetch(&mut *repo) .await? .ok_or(RouteError::ClientNotFound)?; @@ -185,7 +184,7 @@ pub(crate) async fn post( let form = client_authorization.form.ok_or(RouteError::BadRequest)?; - let reply = match form { + let (reply, repo) = match form { AccessTokenRequest::AuthorizationCode(grant) => { authorization_code_grant( &mut rng, @@ -206,6 +205,8 @@ pub(crate) async fn post( } }; + repo.save().await?; + let mut headers = HeaderMap::new(); headers.typed_insert(CacheControl::new().with_no_store()); headers.typed_insert(Pragma::no_cache()); @@ -221,8 +222,8 @@ async fn authorization_code_grant( client: &Client, key_store: &Keystore, url_builder: &UrlBuilder, - mut repo: PgRepository, -) -> Result { + mut repo: BoxRepository, +) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { let authz_grant = repo .oauth2_authorization_grant() .find_by_code(&grant.code) @@ -367,9 +368,7 @@ async fn authorization_code_grant( .exchange(clock, authz_grant) .await?; - repo.save().await?; - - Ok(params) + Ok((params, repo)) } async fn refresh_token_grant( @@ -377,8 +376,8 @@ async fn refresh_token_grant( clock: &impl Clock, grant: &RefreshTokenGrant, client: &Client, - mut repo: PgRepository, -) -> Result { + mut repo: BoxRepository, +) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { let refresh_token = repo .oauth2_refresh_token() .find_by_token(&grant.refresh_token) @@ -439,7 +438,5 @@ async fn refresh_token_grant( .with_refresh_token(new_refresh_token.refresh_token) .with_scope(session.scope); - repo.save().await?; - - Ok(params) + Ok((params, repo)) } diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index eb9e1cc2..e56dafbc 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -31,9 +31,8 @@ use mas_router::UrlBuilder; use mas_storage::{ oauth2::OAuth2ClientRepository, user::{BrowserSessionRepository, UserEmailRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use oauth2_types::scope; use serde::Serialize; use serde_with::skip_serializing_none; @@ -65,7 +64,7 @@ pub enum RouteError { #[error("failed to authenticate")] AuthorizationVerificationError( - #[from] AuthorizationVerificationError, + #[from] AuthorizationVerificationError, ), #[error("no suitable key found for signing")] @@ -78,7 +77,7 @@ pub enum RouteError { NoSuchBrowserSession, } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); @@ -100,11 +99,11 @@ pub async fn get( mut rng: BoxRng, clock: BoxClock, State(url_builder): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(key_store): State, user_authorization: UserAuthorization, ) -> Result { - let session = user_authorization.protected(&mut repo, &clock).await?; + let session = user_authorization.protected(&mut *repo, &clock).await?; let browser_session = repo .browser_session() diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index ff47084b..8da6231a 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -24,9 +24,8 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; use mas_storage::{ upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use thiserror::Error; use ulid::Ulid; @@ -45,7 +44,7 @@ pub(crate) enum RouteError { impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError); -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -60,7 +59,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(http_client_factory): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(url_builder): State, cookie_jar: PrivateCookieJar, Path(provider_id): Path, diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index b324cfb2..bc24c399 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -30,9 +30,8 @@ use mas_storage::{ UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; use thiserror::Error; @@ -99,7 +98,7 @@ pub(crate) enum RouteError { Internal(Box), } -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::JwksError); @@ -123,7 +122,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(http_client_factory): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(url_builder): State, State(encrypter): State, State(keystore): State, diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index bdd5df1f..89614dcb 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -27,9 +27,8 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::{ EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink, @@ -72,7 +71,7 @@ pub(crate) enum RouteError { impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); -impl_from_error_for_route!(mas_storage_pg::DatabaseError); +impl_from_error_for_route!(mas_storage::RepositoryError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -95,7 +94,7 @@ pub(crate) enum FormData { pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(link_id): Path, @@ -129,7 +128,7 @@ pub(crate) async fn get( let (user_session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - let maybe_user_session = user_session_info.load_session(&mut repo).await?; + let maybe_user_session = user_session_info.load_session(&mut *repo).await?; let render = match (maybe_user_session, link.user_id) { (Some(session), Some(user_id)) if session.user.id == user_id => { @@ -211,7 +210,7 @@ pub(crate) async fn get( pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Path(link_id): Path, Form(form): Form>, @@ -250,7 +249,7 @@ pub(crate) async fn post( } let (user_session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_user_session = user_session_info.load_session(&mut repo).await?; + let maybe_user_session = user_session_info.load_session(&mut *repo).await?; let session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 64218e3a..7b89b2d8 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -24,8 +24,7 @@ use mas_axum_utils::{ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng}; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; @@ -41,13 +40,13 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session @@ -68,7 +67,7 @@ pub(crate) async fn get( pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, State(mailer): State, cookie_jar: PrivateCookieJar, Query(query): Query, @@ -77,7 +76,7 @@ pub(crate) async fn post( let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session @@ -99,7 +98,7 @@ pub(crate) async fn post( }; start_email_verification( &mailer, - &mut repo, + &mut *repo, &mut rng, &clock, &session.user, diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index fd2f2981..251e5adb 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -28,8 +28,7 @@ use mas_data_model::{BrowserSession, User, UserEmail}; use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Clock, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, Clock, Repository}; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; @@ -51,28 +50,28 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; if let Some(session) = maybe_session { - render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await + render(&mut rng, &clock, templates, session, cookie_jar, &mut *repo).await } else { let login = mas_router::Login::default(); Ok((cookie_jar, login.go()).into_response()) } } -async fn render( +async fn render( rng: impl Rng + Send, clock: &impl Clock, templates: Templates, session: BrowserSession, cookie_jar: PrivateCookieJar, - repo: &mut impl Repository, + repo: &mut (impl Repository + ?Sized), ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); @@ -87,9 +86,9 @@ async fn render( Ok((cookie_jar, Html(content)).into_response()) } -async fn start_email_verification( +async fn start_email_verification( mailer: &Mailer, - repo: &mut impl Repository, + repo: &mut (impl Repository + ?Sized), mut rng: impl Rng + Send, clock: &impl Clock, user: &User, @@ -124,14 +123,14 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, State(mailer): State, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let mut session = if let Some(session) = maybe_session { session @@ -150,7 +149,7 @@ pub(crate) async fn post( .await?; let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) + start_email_verification(&mailer, &mut *repo, &mut rng, &clock, &session.user, email) .await?; repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); @@ -169,7 +168,7 @@ pub(crate) async fn post( } let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) + start_email_verification(&mailer, &mut *repo, &mut rng, &clock, &session.user, email) .await?; repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); @@ -212,7 +211,7 @@ pub(crate) async fn post( templates.clone(), session, cookie_jar, - &mut repo, + &mut *repo, ) .await?; diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index e330c944..6a701b50 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -24,8 +24,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{user::UserEmailRepository, BoxClock, BoxRepository, BoxRng}; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; use ulid::Ulid; @@ -41,7 +40,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, Path(id): Path, cookie_jar: PrivateCookieJar, @@ -49,7 +48,7 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session @@ -82,7 +81,7 @@ pub(crate) async fn get( pub(crate) async fn post( clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Query(query): Query, Path(id): Path, @@ -91,7 +90,7 @@ pub(crate) async fn post( let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 76ea5667..8860c43c 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -25,22 +25,21 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::{AccountContext, TemplateContext, Templates}; pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 4fa86eae..ddc47779 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -27,9 +27,8 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, }; -use mas_storage_pg::PgRepository; use mas_templates::{EmptyContext, TemplateContext, Templates}; use rand::Rng; use serde::Deserialize; @@ -48,12 +47,12 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; if let Some(session) = maybe_session { render(&mut rng, &clock, templates, session, cookie_jar).await @@ -86,7 +85,7 @@ pub(crate) async fn post( clock: BoxClock, State(password_manager): State, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { @@ -94,7 +93,7 @@ pub(crate) async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index d4322eef..0cfe0d05 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -20,8 +20,7 @@ use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::UrlBuilder; -use mas_storage::{BoxClock, BoxRng}; -use mas_storage_pg::PgRepository; +use mas_storage::{BoxClock, BoxRepository, BoxRng}; use mas_templates::{IndexContext, TemplateContext, Templates}; pub async fn get( @@ -29,12 +28,12 @@ pub async fn get( clock: BoxClock, State(templates): State, State(url_builder): State, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let session = session_info.load_session(&mut repo).await?; + let session = session_info.load_session(&mut *repo).await?; let ctx = IndexContext::new(url_builder.oidc_discovery()) .maybe_with_session(session) diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index b245b597..4f9ecfd8 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -26,9 +26,8 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, - BoxClock, BoxRng, Clock, Repository, + BoxClock, BoxRepository, BoxRng, Clock, Repository, }; -use mas_storage_pg::PgRepository; use mas_templates::{ FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, }; @@ -53,14 +52,14 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; if maybe_session.is_some() { let reply = query.go_next(); @@ -71,7 +70,7 @@ pub(crate) async fn get( LoginContext::default().with_upstrem_providers(providers), query, csrf_token, - &mut repo, + &mut *repo, &templates, ) .await?; @@ -85,7 +84,7 @@ pub(crate) async fn post( clock: BoxClock, State(password_manager): State, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, @@ -117,7 +116,7 @@ pub(crate) async fn post( .with_upstrem_providers(providers), query, csrf_token, - &mut repo, + &mut *repo, &templates, ) .await?; @@ -127,7 +126,7 @@ pub(crate) async fn post( match login( password_manager, - &mut repo, + &mut *repo, rng, &clock, &form.username, @@ -149,7 +148,7 @@ pub(crate) async fn post( LoginContext::default().with_form_state(state), query, csrf_token, - &mut repo, + &mut *repo, &templates, ) .await?; @@ -162,7 +161,7 @@ pub(crate) async fn post( // TODO: move that logic elsewhere? async fn login( password_manager: PasswordManager, - repo: &mut impl Repository, + repo: &mut (impl Repository + ?Sized), mut rng: impl Rng + CryptoRng + Send, clock: &impl Clock, username: &str, @@ -236,7 +235,7 @@ async fn render( ctx: LoginContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - repo: &mut impl Repository, + repo: &mut (impl Repository + ?Sized), templates: &Templates, ) -> Result { let next = action.load_context(repo).await?; diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 9cdc93f0..189331fd 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -20,12 +20,11 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; -use mas_storage::{user::BrowserSessionRepository, BoxClock, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage::{user::BrowserSessionRepository, BoxClock, BoxRepository}; pub(crate) async fn post( clock: BoxClock, - mut repo: PgRepository, + mut repo: BoxRepository, cookie_jar: PrivateCookieJar, Form(form): Form>>, ) -> Result { @@ -33,7 +32,7 @@ pub(crate) async fn post( let (session_info, mut cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; if let Some(session) = maybe_session { repo.browser_session().finish(&clock, session).await?; diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index ced97902..2750711c 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -26,9 +26,8 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, }; -use mas_storage_pg::PgRepository; use mas_templates::{ReauthContext, TemplateContext, Templates}; use serde::Deserialize; use zeroize::Zeroizing; @@ -45,14 +44,14 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session @@ -64,7 +63,7 @@ pub(crate) async fn get( }; let ctx = ReauthContext::default(); - let next = query.load_context(&mut repo).await?; + let next = query.load_context(&mut *repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { @@ -81,7 +80,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, @@ -90,7 +89,7 @@ pub(crate) async fn post( let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; let session = if let Some(session) = maybe_session { session diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 68cf5c49..467352af 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -33,9 +33,8 @@ use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - BoxClock, BoxRng, Repository, + BoxClock, BoxRepository, BoxRng, Repository, }; -use mas_storage_pg::PgRepository; use mas_templates::{ EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, TemplateContext, Templates, ToFormState, @@ -63,14 +62,14 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut repo).await?; + let maybe_session = session_info.load_session(&mut *repo).await?; if maybe_session.is_some() { let reply = query.go_next(); @@ -80,7 +79,7 @@ pub(crate) async fn get( RegisterContext::default(), query, csrf_token, - &mut repo, + &mut *repo, &templates, ) .await?; @@ -97,7 +96,7 @@ pub(crate) async fn post( State(mailer): State, State(policy_factory): State>, State(templates): State, - mut repo: PgRepository, + mut repo: BoxRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, @@ -175,7 +174,7 @@ pub(crate) async fn post( RegisterContext::default().with_form_state(state), query, csrf_token, - &mut repo, + &mut *repo, &templates, ) .await?; @@ -234,7 +233,7 @@ async fn render( ctx: RegisterContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - repo: &mut impl Repository, + repo: &mut (impl Repository + ?Sized), templates: &Templates, ) -> Result { let next = action.load_context(repo).await?; diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index db3c3392..b2946084 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -40,9 +40,9 @@ impl OptionalPostAuthAction { self.go_next_or_default(&mas_router::Index) } - pub async fn load_context( - &self, - repo: &mut R, + pub async fn load_context<'a>( + &'a self, + repo: &'a mut (impl Repository + ?Sized), ) -> anyhow::Result> { let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let ctx = match action { diff --git a/crates/storage-pg/Cargo.toml b/crates/storage-pg/Cargo.toml index fad6e30e..3373a21f 100644 --- a/crates/storage-pg/Cargo.toml +++ b/crates/storage-pg/Cargo.toml @@ -13,6 +13,7 @@ serde = { version = "1.0.152", features = ["derive"] } serde_json = "1.0.91" thiserror = "1.0.38" tracing = "0.1.37" +futures-util = "0.3.25" rand = "0.8.5" rand_chacha = "0.3.1" diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index dd68e4d5..9b340756 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -103,7 +103,7 @@ mod tests { const SECOND_TOKEN: &str = "second_access_token"; let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); // Create a user let user = repo @@ -139,7 +139,7 @@ mod tests { repo.save().await.unwrap(); { - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); // Adding the same token a second time should conflict assert!(repo .compat_access_token() @@ -156,7 +156,7 @@ mod tests { } // Grab a new repo - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); // Looking up via ID works let token_lookup = repo @@ -223,7 +223,7 @@ mod tests { const REFRESH_TOKEN: &str = "refresh_token"; let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); // Create a user let user = repo diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 54002755..6448b61a 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; use mas_storage::{ compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, @@ -59,21 +60,19 @@ impl PgRepository { let txn = pool.begin().await?; Ok(PgRepository { txn }) } - - pub async fn save(self) -> Result<(), DatabaseError> { - self.txn.commit().await?; - Ok(()) - } - - pub async fn cancel(self) -> Result<(), DatabaseError> { - self.txn.rollback().await?; - Ok(()) - } } impl Repository for PgRepository { type Error = DatabaseError; + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + self.txn.commit().map_err(DatabaseError::from).boxed() + } + + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + self.txn.rollback().map_err(DatabaseError::from).boxed() + } + fn upstream_oauth_link<'c>( &'c mut self, ) -> Box + 'c> { diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index b3b88232..7c3eab37 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -29,7 +29,7 @@ use crate::PgRepository; async fn test_user_repo(pool: PgPool) { const USERNAME: &str = "john"; - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); @@ -77,7 +77,7 @@ async fn test_user_email_repo(pool: PgPool) { const CODE2: &str = "543210"; const EMAIL: &str = "john@example.com"; - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); @@ -259,7 +259,7 @@ async fn test_user_password_repo(pool: PgPool) { const FIRST_PASSWORD_HASH: &str = "doesntmatter"; const SECOND_PASSWORD_HASH: &str = "alsodoesntmatter"; - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index 86ca9f07..cea7b03b 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -9,6 +9,7 @@ license = "Apache-2.0" async-trait = "0.1.60" chrono = "0.4.23" thiserror = "1.0.38" +futures-util = "0.3.25" rand_core = "0.6.4" url = "2.3.1" diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 0cdc4e39..aa1db0af 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -28,21 +28,21 @@ clippy::module_name_repetitions )] +use rand_core::CryptoRngCore; + pub mod clock; +pub mod pagination; +pub(crate) mod repository; pub mod compat; pub mod oauth2; -pub mod pagination; -pub(crate) mod repository; pub mod upstream_oauth2; pub mod user; -use rand_core::CryptoRngCore; - pub use self::{ clock::{Clock, SystemClock}, pagination::{Page, Pagination}, - repository::Repository, + repository::{BoxRepository, Repository, RepositoryError}, }; pub struct MapErr { @@ -86,7 +86,6 @@ macro_rules! repository_impl { where R: $repo_trait, F: FnMut(::Error) -> E + ::std::marker::Send + ::std::marker::Sync, - E: ::std::error::Error + ::std::marker::Send + ::std::marker::Sync, { type Error = E; diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 085c06ab..3da64a8c 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; +use thiserror::Error; + use crate::{ compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, @@ -32,6 +35,23 @@ use crate::{ pub trait Repository: Send { type Error: std::error::Error + Send + Sync + 'static; + fn map_err(self, mapper: Mapper) -> MapErr + where + Self: Sized, + { + MapErr::new(self, mapper) + } + + fn boxed(self) -> BoxRepository + where + Self: Sized + Sync + 'static, + { + Box::new(self) + } + + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>>; + fn upstream_oauth_link<'c>( &'c mut self, ) -> Box + 'c>; @@ -91,14 +111,44 @@ pub trait Repository: Send { ) -> Box + 'c>; } +/// An opaque, type-erased error +#[derive(Debug, Error)] +#[error(transparent)] +pub struct RepositoryError { + source: Box, +} + +impl RepositoryError { + pub fn from_error(value: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + source: Box::new(value), + } + } +} + +pub type BoxRepository = + Box + Send + Sync + 'static>; + impl Repository for crate::MapErr where R: Repository, - F: FnMut(R::Error) -> E + Send + Sync, + R::Error: 'static, + F: FnMut(R::Error) -> E + Send + Sync + 'static, E: std::error::Error + Send + Sync + 'static, { type Error = E; + fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::new(self.inner).save().map_err(self.mapper).boxed() + } + + fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::new(self.inner).cancel().map_err(self.mapper).boxed() + } + fn upstream_oauth_link<'c>( &'c mut self, ) -> Box + 'c> { diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index c5e024af..0057f2d6 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -21,7 +21,7 @@ use crate::{pagination::Page, repository_impl, Clock, Pagination}; #[async_trait] pub trait UpstreamOAuthLinkRepository: Send + Sync { - type Error: std::error::Error + Send + Sync; + type Error; /// Lookup an upstream OAuth link by its ID async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>;