diff --git a/Cargo.lock b/Cargo.lock index f299df15..b780f20f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2673,7 +2673,6 @@ dependencies = [ "serde_json", "serde_urlencoded", "serde_with", - "sqlx", "thiserror", "tokio", "tower", diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index 0b6572c7..1cbae228 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -21,7 +21,6 @@ serde = "1.0.152" serde_with = "2.1.0" serde_urlencoded = "0.7.1" serde_json = "1.0.91" -sqlx = "0.6.2" thiserror = "1.0.38" tokio = "1.23.0" tower = { version = "0.4.13", features = ["util"] } diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 6f212369..09090230 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -31,10 +31,9 @@ use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; -use mas_storage::{oauth2::OAuth2ClientRepository, DatabaseError, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; -use sqlx::PgConnection; use thiserror::Error; use tower::{Service, ServiceExt}; @@ -73,7 +72,10 @@ pub enum Credentials { } impl Credentials { - pub async fn fetch(&self, conn: &mut PgConnection) -> Result, DatabaseError> { + pub async fn fetch<'r, R>(&self, repo: &'r mut R) -> Result, R::Error> + where + R: Repository, + { let client_id = match self { Credentials::None { client_id } | Credentials::ClientSecretBasic { client_id, .. } @@ -81,7 +83,7 @@ impl Credentials { | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id, }; - conn.oauth2_client().find_by_client_id(client_id).await + repo.oauth2_client().find_by_client_id(client_id).await } #[tracing::instrument(skip_all, err)] diff --git a/crates/axum-utils/src/session.rs b/crates/axum-utils/src/session.rs index 64887895..71961367 100644 --- a/crates/axum-utils/src/session.rs +++ b/crates/axum-utils/src/session.rs @@ -14,9 +14,8 @@ use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use mas_data_model::BrowserSession; -use mas_storage::{user::BrowserSessionRepository, DatabaseError, Repository}; +use mas_storage::{user::BrowserSessionRepository, Repository}; use serde::{Deserialize, Serialize}; -use sqlx::PgConnection; use ulid::Ulid; use crate::CookieExt; @@ -44,17 +43,17 @@ impl SessionInfo { } /// Load the [`BrowserSession`] from database - pub async fn load_session( + pub async fn load_session( &self, - conn: &mut PgConnection, - ) -> Result, DatabaseError> { + repo: &mut R, + ) -> Result, R::Error> { let session_id = if let Some(id) = self.current { id } else { return Ok(None); }; - let maybe_session = conn + let maybe_session = repo .browser_session() .lookup(session_id) .await? diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index ec60103d..11d79312 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -30,10 +30,9 @@ use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode use mas_data_model::Session; use mas_storage::{ oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository}, - DatabaseError, Repository, + Repository, }; use serde::{de::DeserializeOwned, Deserialize}; -use sqlx::PgConnection; use thiserror::Error; #[derive(Debug, Deserialize)] @@ -53,22 +52,23 @@ enum AccessToken { } impl AccessToken { - async fn fetch( + async fn fetch( &self, - conn: &mut PgConnection, - ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> { + repo: &mut R, + ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError> + { let token = match self { AccessToken::Form(t) | AccessToken::Header(t) => t, AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), }; - let token = conn + let token = repo .oauth2_access_token() .find_by_token(token.as_str()) .await? .ok_or(AuthorizationVerificationError::InvalidToken)?; - let session = conn + let session = repo .oauth2_session() .lookup(token.session_id) .await? @@ -86,17 +86,17 @@ pub struct UserAuthorization { impl UserAuthorization { // TODO: take scopes to validate as parameter - pub async fn protected_form( + pub async fn protected_form( self, - conn: &mut PgConnection, + repo: &mut R, now: DateTime, - ) -> Result<(Session, F), AuthorizationVerificationError> { + ) -> Result<(Session, F), AuthorizationVerificationError> { let form = match self.form { Some(f) => f, None => return Err(AuthorizationVerificationError::MissingForm), }; - let (token, session) = self.access_token.fetch(conn).await?; + let (token, session) = self.access_token.fetch(repo).await?; if !token.is_valid(now) || !session.is_valid() { return Err(AuthorizationVerificationError::InvalidToken); @@ -106,12 +106,12 @@ impl UserAuthorization { } // TODO: take scopes to validate as parameter - pub async fn protected( + pub async fn protected( self, - conn: &mut PgConnection, + repo: &mut R, now: DateTime, - ) -> Result { - let (token, session) = self.access_token.fetch(conn).await?; + ) -> Result> { + let (token, session) = self.access_token.fetch(repo).await?; if !token.is_valid(now) || !session.is_valid() { return Err(AuthorizationVerificationError::InvalidToken); @@ -129,7 +129,7 @@ pub enum UserAuthorizationError { } #[derive(Debug, Error)] -pub enum AuthorizationVerificationError { +pub enum AuthorizationVerificationError { #[error("missing token")] MissingToken, @@ -140,7 +140,7 @@ pub enum AuthorizationVerificationError { MissingForm, #[error(transparent)] - Internal(#[from] DatabaseError), + Internal(#[from] E), } enum BearerError { @@ -248,7 +248,10 @@ impl IntoResponse for UserAuthorizationError { } } -impl IntoResponse for AuthorizationVerificationError { +impl IntoResponse for AuthorizationVerificationError +where + E: ToString, +{ fn into_response(self) -> Response { match self { Self::MissingForm | Self::MissingToken => { diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index d159f3e3..15378940 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -21,7 +21,7 @@ use mas_storage::{ oauth2::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use oauth2_types::scope::Scope; use rand::SeedableRng; @@ -202,8 +202,8 @@ impl Options { let pool = database_from_config(&database_config).await?; let password_manager = password_manager_from_config(&passwords_config).await?; - let mut txn = pool.begin().await?; - let user = txn + let mut repo = PgRepository::from_pool(&pool).await?; + let user = repo .user() .find_by_username(username) .await? @@ -213,12 +213,12 @@ impl Options { let (version, hashed_password) = password_manager.hash(&mut rng, password).await?; - txn.user_password() + repo.user_password() .add(&mut rng, &clock, &user, version, hashed_password, None) .await?; info!(%user.id, %user.username, "Password changed"); - txn.commit().await?; + repo.save().await?; Ok(()) } @@ -233,22 +233,22 @@ impl Options { let config: DatabaseConfig = root.load_config()?; let pool = database_from_config(&config).await?; - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; - let user = txn + let user = repo .user() .find_by_username(username) .await? .context("User not found")?; - let email = txn + let email = repo .user_email() .find(&user, email) .await? .context("Email not found")?; - let email = txn.user_email().mark_as_verified(&clock, email).await?; + let email = repo.user_email().mark_as_verified(&clock, email).await?; - txn.commit().await?; + repo.save().await?; info!(?email, "Email marked as verified"); Ok(()) @@ -261,12 +261,12 @@ impl Options { let pool = database_from_config(&config.database).await?; let encrypter = config.secrets.encrypter(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; for client in config.clients.iter() { let client_id = client.client_id; - let existing = txn.oauth2_client().lookup(client_id).await?.is_some(); + let existing = repo.oauth2_client().lookup(client_id).await?.is_some(); if !update && existing { warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients."); continue; @@ -288,7 +288,7 @@ impl Options { .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .transpose()?; - txn.oauth2_client() + repo.oauth2_client() .add_from_config( &mut rng, &clock, @@ -302,7 +302,7 @@ impl Options { .await?; } - txn.commit().await?; + repo.save().await?; Ok(()) } @@ -326,7 +326,7 @@ impl Options { let encrypter = config.secrets.encrypter(); let pool = database_from_config(&config.database).await?; let url_builder = UrlBuilder::new(config.http.public_base); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let requires_client_secret = token_endpoint_auth_method.requires_client_secret(); @@ -347,7 +347,7 @@ impl Options { .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .transpose()?; - let provider = conn + let provider = repo .upstream_oauth_provider() .add( &mut rng, diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index d01be16c..6e58bec7 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -32,9 +32,9 @@ use async_graphql::{ }; use mas_storage::{ oauth2::OAuth2ClientRepository, - upstream_oauth2::UpstreamOAuthProviderRepository, + upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, user::{BrowserSessionRepository, UserEmailRepository}, - Repository, UpstreamOAuthLinkRepository, + PgRepository, Repository, }; use model::CreationEvent; use sqlx::PgPool; @@ -93,10 +93,9 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::OAuth2Client.extract_ulid(&id)?; - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - let client = conn.oauth2_client().lookup(id).await?; + let client = repo.oauth2_client().lookup(id).await?; Ok(client.map(OAuth2Client)) } @@ -124,13 +123,12 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::BrowserSession.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let browser_session = conn.browser_session().lookup(id).await?; + let browser_session = repo.browser_session().lookup(id).await?; let ret = browser_session.and_then(|browser_session| { if browser_session.user.id == current_user.id { @@ -151,13 +149,12 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::UserEmail.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let user_email = conn + let user_email = repo .user_email() .lookup(id) .await? @@ -174,13 +171,12 @@ impl RootQuery { ) -> Result, async_graphql::Error> { let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; let session = ctx.data_opt::().cloned(); - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let link = conn.upstream_oauth_link().lookup(id).await?; + let link = repo.upstream_oauth_link().lookup(id).await?; // Ensure that the link belongs to the current user let link = link.filter(|link| link.user_id == Some(current_user.id)); @@ -195,10 +191,9 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?; - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - let provider = conn.upstream_oauth_provider().lookup(id).await?; + let provider = repo.upstream_oauth_provider().lookup(id).await?; Ok(provider.map(UpstreamOAuth2Provider::new)) } @@ -215,7 +210,7 @@ impl RootQuery { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -223,7 +218,6 @@ impl RootQuery { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| { x.extract_for_type(NodeType::UpstreamOAuth2Provider) @@ -235,7 +229,7 @@ impl RootQuery { }) .transpose()?; - let page = conn + let page = repo .upstream_oauth_provider() .list_paginated(before_id, after_id, first, last) .await?; diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index 3c94c672..a2196e36 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,7 +15,9 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository}; +use mas_storage::{ + compat::CompatSessionRepository, user::UserRepository, PgRepository, Repository, +}; use sqlx::PgPool; use url::Url; @@ -35,8 +37,8 @@ impl CompatSession { /// The user authorized for this session. async fn user(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let user = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let user = repo .user() .lookup(self.0.user_id) .await? @@ -100,8 +102,8 @@ impl CompatSsoLogin { ) -> Result, async_graphql::Error> { let Some(session_id) = self.0.session_id() else { return Ok(None) }; - let mut conn = ctx.data::()?.acquire().await?; - let session = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let session = repo .compat_session() .lookup(session_id) .await? diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 0ab2bc68..171c800f 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,7 +14,9 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, Repository}; +use mas_storage::{ + oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, PgRepository, Repository, +}; use oauth2_types::scope::Scope; use sqlx::PgPool; use ulid::Ulid; @@ -36,8 +38,8 @@ impl OAuth2Session { /// OAuth 2.0 client used by this session. pub async fn client(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let client = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let client = repo .oauth2_client() .lookup(self.0.client_id) .await? @@ -56,8 +58,8 @@ impl OAuth2Session { &self, ctx: &Context<'_>, ) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let browser_session = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let browser_session = repo .browser_session() .lookup(self.0.user_session_id) .await? @@ -68,8 +70,8 @@ impl OAuth2Session { /// User authorized for this session. pub async fn user(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let browser_session = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let browser_session = repo .browser_session() .lookup(self.0.user_session_id) .await? @@ -138,8 +140,8 @@ impl OAuth2Consent { /// OAuth 2.0 client for which the user granted access. pub async fn client(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let client = conn + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let client = repo .oauth2_client() .lookup(self.client_id) .await? diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index 249a0928..4a4c223b 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -16,7 +16,8 @@ use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; use mas_storage::{ - upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, Repository, + upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, PgRepository, + Repository, }; use sqlx::PgPool; @@ -102,9 +103,8 @@ impl UpstreamOAuth2Link { provider.clone() } else { // Fetch on-the-fly - let database = ctx.data::()?; - let mut conn = database.acquire().await?; - conn.upstream_oauth_provider() + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + repo.upstream_oauth_provider() .lookup(self.link.provider_id) .await? .context("Upstream OAuth 2.0 provider not found")? @@ -120,9 +120,8 @@ impl UpstreamOAuth2Link { user.clone() } else if let Some(user_id) = &self.link.user_id { // Fetch on-the-fly - let database = ctx.data::()?; - let mut conn = database.acquire().await?; - conn.user() + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + repo.user() .lookup(*user_id) .await? .context("User not found")? diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index b19a1ae1..9cd8d53b 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -20,8 +20,9 @@ use chrono::{DateTime, Utc}; use mas_storage::{ compat::CompatSsoLoginRepository, oauth2::OAuth2SessionRepository, + upstream_oauth2::UpstreamOAuthLinkRepository, user::{BrowserSessionRepository, UserEmailRepository}, - Repository, UpstreamOAuthLinkRepository, + PgRepository, Repository, }; use sqlx::PgPool; @@ -63,10 +64,9 @@ impl User { &self, ctx: &Context<'_>, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; - let mut conn = database.acquire().await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; - Ok(conn.user_email().get_primary(&self.0).await?.map(UserEmail)) + Ok(repo.user_email().get_primary(&self.0).await?.map(UserEmail)) } /// Get the list of compatibility SSO logins, chronologically sorted @@ -81,7 +81,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -89,7 +89,6 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::CompatSsoLogin)) .transpose()?; @@ -97,7 +96,7 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::CompatSsoLogin)) .transpose()?; - let page = conn + let page = repo .compat_sso_login() .list_paginated(&self.0, before_id, after_id, first, last) .await?; @@ -128,7 +127,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -136,7 +135,6 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::BrowserSession)) .transpose()?; @@ -144,7 +142,7 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::BrowserSession)) .transpose()?; - let page = conn + let page = repo .browser_session() .list_active_paginated(&self.0, before_id, after_id, first, last) .await?; @@ -175,7 +173,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -183,7 +181,6 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::UserEmail)) .transpose()?; @@ -191,7 +188,7 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::UserEmail)) .transpose()?; - let page = conn + let page = repo .user_email() .list_paginated(&self.0, before_id, after_id, first, last) .await?; @@ -226,7 +223,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -234,7 +231,6 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| x.extract_for_type(NodeType::OAuth2Session)) .transpose()?; @@ -242,7 +238,7 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::OAuth2Session)) .transpose()?; - let page = conn + let page = repo .oauth2_session() .list_paginated(&self.0, before_id, after_id, first, last) .await?; @@ -273,7 +269,7 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; query( after, @@ -281,7 +277,6 @@ impl User { first, last, |after, before, first, last| async move { - let mut conn = database.acquire().await?; let after_id = after .map(|x: OpaqueCursor| { x.extract_for_type(NodeType::UpstreamOAuth2Link) @@ -293,7 +288,7 @@ impl User { }) .transpose()?; - let page = conn + let page = repo .upstream_oauth_link() .list_paginated(&self.0, before_id, after_id, first, last) .await?; @@ -347,8 +342,8 @@ pub struct UserEmailsPagination(mas_data_model::User); impl UserEmailsPagination { /// Identifies the total count of items in the connection. async fn total_count(&self, ctx: &Context<'_>) -> Result { - let mut conn = ctx.data::()?.acquire().await?; - let count = conn.user_email().count(&self.0).await?; + let mut repo = PgRepository::from_pool(ctx.data::()?).await?; + let count = repo.user_email().count(&self.0).await?; Ok(count) } } diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index e7376f72..f344f7e0 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -22,11 +22,11 @@ use mas_storage::{ CompatSsoLoginRepository, }, user::{UserPasswordRepository, UserRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; -use sqlx::{PgPool, Postgres, Transaction}; +use sqlx::PgPool; use thiserror::Error; use zeroize::Zeroizing; @@ -199,14 +199,14 @@ pub(crate) async fn post( Json(input): Json, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session, user) = match input.credentials { Credentials::Password { identifier: Identifier::User { user }, password, - } => user_password_login(&password_manager, &mut txn, user, password).await?, + } => user_password_login(&password_manager, &mut repo, user, password).await?, - Credentials::Token { token } => token_login(&mut txn, &clock, &token).await?, + Credentials::Token { token } => token_login(&mut repo, &clock, &token).await?, _ => { return Err(RouteError::Unsupported); @@ -224,14 +224,14 @@ pub(crate) async fn post( }; let access_token = TokenType::CompatAccessToken.generate(&mut rng); - let access_token = txn + let access_token = repo .compat_access_token() .add(&mut rng, &clock, &session, access_token, expires_in) .await?; let refresh_token = if input.refresh_token { let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng); - let refresh_token = txn + let refresh_token = repo .compat_refresh_token() .add(&mut rng, &clock, &session, &access_token, refresh_token) .await?; @@ -240,7 +240,7 @@ pub(crate) async fn post( None }; - txn.commit().await?; + repo.save().await?; Ok(Json(ResponseBody { access_token: access_token.token, @@ -252,11 +252,11 @@ pub(crate) async fn post( } async fn token_login( - txn: &mut Transaction<'_, Postgres>, + repo: &mut PgRepository, clock: &Clock, token: &str, ) -> Result<(CompatSession, User), RouteError> { - let login = txn + let login = repo .compat_sso_login() .find_by_token(token) .await? @@ -300,40 +300,40 @@ async fn token_login( } }; - let session = txn + let session = repo .compat_session() .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; - let user = txn + let user = repo .user() .lookup(session.user_id) .await? .ok_or(RouteError::UserNotFound)?; - txn.compat_sso_login().exchange(clock, login).await?; + repo.compat_sso_login().exchange(clock, login).await?; Ok((session, user)) } async fn user_password_login( password_manager: &PasswordManager, - txn: &mut Transaction<'_, Postgres>, + repo: &mut PgRepository, username: String, password: String, ) -> Result<(CompatSession, User), RouteError> { let (clock, mut rng) = crate::clock_and_rng(); // Find the user - let user = txn + let user = repo .user() .find_by_username(&username) .await? .ok_or(RouteError::UserNotFound)?; // Lookup its password - let user_password = txn + let user_password = repo .user_password() .active(&user) .await? @@ -354,7 +354,7 @@ async fn user_password_login( if let Some((version, hashed_password)) = new_password_hash { // Save the upgraded password if needed - txn.user_password() + repo.user_password() .add( &mut rng, &clock, @@ -368,7 +368,7 @@ async fn user_password_login( // Now that the user credentials have been verified, start a new compat session let device = Device::generate(&mut rng); - let session = txn + let session = repo .compat_session() .add(&mut rng, &clock, &user, device) .await?; diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 33352424..7ca61ab2 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -31,7 +31,7 @@ use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_storage::{ compat::{CompatSessionRepository, CompatSsoLoginRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; @@ -60,12 +60,12 @@ pub async fn get( Query(params): Query, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -90,7 +90,7 @@ pub async fn get( return Ok((cookie_jar, destination.go()).into_response()); } - let login = conn + let login = repo .compat_sso_login() .lookup(id) .await? @@ -124,12 +124,12 @@ pub async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); cookie_jar.verify_form(clock.now(), form)?; - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -154,7 +154,7 @@ pub async fn post( return Ok((cookie_jar, destination.go()).into_response()); } - let login = txn + let login = repo .compat_sso_login() .lookup(id) .await? @@ -188,16 +188,16 @@ pub async fn post( }; let device = Device::generate(&mut rng); - let compat_session = txn + let compat_session = repo .compat_session() .add(&mut rng, &clock, &session.user, device) .await?; - txn.compat_sso_login() + repo.compat_sso_login() .fulfill(&clock, login, &compat_session) .await?; - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, Redirect::to(redirect_uri.as_str())).into_response()) } diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index 9c23b733..befd3e32 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -19,7 +19,7 @@ use axum::{ }; use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; -use mas_storage::{compat::CompatSsoLoginRepository, Repository}; +use mas_storage::{compat::CompatSsoLoginRepository, PgRepository, Repository}; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; @@ -80,8 +80,8 @@ pub async fn get( } let token = Alphanumeric.sample_string(&mut rng, 32); - let mut conn = pool.acquire().await?; - let login = conn + let mut repo = PgRepository::from_pool(&pool).await?; + let login = repo .compat_sso_login() .add(&mut rng, &clock, token, redirect_url) .await?; diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 25125c72..762f77b2 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -18,7 +18,7 @@ use hyper::StatusCode; use mas_data_model::TokenType; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use sqlx::PgPool; use thiserror::Error; @@ -72,7 +72,7 @@ pub(crate) async fn post( maybe_authorization: Option>>, ) -> Result { let clock = Clock::default(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; @@ -83,23 +83,23 @@ pub(crate) async fn post( return Err(RouteError::InvalidAuthorization); } - let token = txn + let token = repo .compat_access_token() .find_by_token(token) .await? .filter(|t| t.is_valid(clock.now())) .ok_or(RouteError::InvalidAuthorization)?; - let session = txn + let session = repo .compat_session() .lookup(token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::InvalidAuthorization)?; - txn.compat_session().finish(&clock, session).await?; + repo.compat_session().finish(&clock, session).await?; - txn.commit().await?; + repo.save().await?; Ok(Json(serde_json::json!({}))) } diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 7bfc940a..ea6d5d23 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -18,7 +18,7 @@ use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, - Repository, + PgRepository, Repository, }; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; @@ -92,7 +92,7 @@ pub(crate) async fn post( Json(input): Json, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let token_type = TokenType::check(&input.refresh_token)?; @@ -100,7 +100,7 @@ pub(crate) async fn post( return Err(RouteError::InvalidToken); } - let refresh_token = txn + let refresh_token = repo .compat_refresh_token() .find_by_token(&input.refresh_token) .await? @@ -110,7 +110,7 @@ pub(crate) async fn post( return Err(RouteError::RefreshTokenConsumed); } - let session = txn + let session = repo .compat_session() .lookup(refresh_token.session_id) .await? @@ -120,7 +120,7 @@ pub(crate) async fn post( return Err(RouteError::InvalidSession); } - let access_token = txn + let access_token = repo .compat_access_token() .lookup(refresh_token.access_token_id) .await? @@ -130,7 +130,7 @@ pub(crate) async fn post( let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng); let expires_in = Duration::minutes(5); - let new_access_token = txn + let new_access_token = repo .compat_access_token() .add( &mut rng, @@ -140,7 +140,7 @@ pub(crate) async fn post( Some(expires_in), ) .await?; - let new_refresh_token = txn + let new_refresh_token = repo .compat_refresh_token() .add( &mut rng, @@ -151,17 +151,17 @@ pub(crate) async fn post( ) .await?; - txn.compat_refresh_token() + repo.compat_refresh_token() .consume(&clock, refresh_token) .await?; if let Some(access_token) = access_token { - txn.compat_access_token() + repo.compat_access_token() .expire(&clock, access_token) .await?; } - txn.commit().await?; + repo.save().await?; Ok(Json(ResponseBody { access_token: new_access_token.token, diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index 2177388b..d3a610b6 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -28,6 +28,7 @@ use hyper::header::CACHE_CONTROL; use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_graphql::Schema; use mas_keystore::Encrypter; +use mas_storage::PgRepository; use sqlx::PgPool; use tracing::{info_span, Instrument}; @@ -67,8 +68,9 @@ pub async fn post( let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let (session_info, _cookie_jar) = cookie_jar.session_info(); - let mut conn = pool.acquire().await?; - let maybe_session = session_info.load_session(&mut conn).await?; + let mut repo = PgRepository::from_pool(&pool).await?; + let maybe_session = session_info.load_session(&mut repo).await?; + repo.cancel().await?; let mut request = async_graphql::http::receive_batch_body( content_type, @@ -117,8 +119,9 @@ pub async fn get( RawQuery(query): RawQuery, ) -> Result { let (session_info, _cookie_jar) = cookie_jar.session_info(); - let mut conn = pool.acquire().await?; - let maybe_session = session_info.load_session(&mut conn).await?; + let mut repo = PgRepository::from_pool(&pool).await?; + let maybe_session = session_info.load_session(&mut repo).await?; + repo.cancel().await?; let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 9f462c50..c983e79c 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -27,11 +27,11 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::Templates; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; -use sqlx::{PgPool, Postgres, Transaction}; +use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -84,13 +84,13 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let grant = txn + let grant = repo .oauth2_authorization_grant() .lookup(grant_id) .await? @@ -107,7 +107,7 @@ pub(crate) async fn get( return Ok((cookie_jar, mas_router::Login::and_then(continue_grant).go()).into_response()); }; - match complete(grant, session, &policy_factory, txn).await { + match complete(grant, session, &policy_factory, repo).await { Ok(params) => { let res = callback_destination.go(&templates, params).await?; Ok((cookie_jar, res).into_response()) @@ -159,7 +159,7 @@ pub(crate) async fn complete( grant: AuthorizationGrant, browser_session: BrowserSession, policy_factory: &PolicyFactory, - mut txn: Transaction<'_, Postgres>, + mut repo: PgRepository, ) -> Result>, GrantCompletionError> { let (clock, mut rng) = crate::clock_and_rng(); @@ -170,7 +170,7 @@ pub(crate) async fn complete( // Check if the authentication is fresh enough if !browser_session.was_authenticated_after(grant.max_auth_time()) { - txn.commit().await?; + repo.save().await?; return Err(GrantCompletionError::RequiresReauth); } @@ -184,13 +184,13 @@ pub(crate) async fn complete( return Err(GrantCompletionError::PolicyViolation); } - let client = txn + let client = repo .oauth2_client() .lookup(grant.client_id) .await? .ok_or(GrantCompletionError::NoSuchClient)?; - let current_consent = txn + let current_consent = repo .oauth2_client() .get_consent_for_user(&client, &browser_session.user) .await?; @@ -202,17 +202,17 @@ pub(crate) async fn complete( // Check if the client lacks consent *or* if consent was explicitely asked if lacks_consent || grant.requires_consent { - txn.commit().await?; + repo.save().await?; return Err(GrantCompletionError::RequiresConsent); } // All good, let's start the session - let session = txn + let session = repo .oauth2_session() .create_from_grant(&mut rng, &clock, &grant, &browser_session) .await?; - let grant = txn + let grant = repo .oauth2_authorization_grant() .fulfill(&clock, &session, grant) .await?; @@ -233,6 +233,6 @@ pub(crate) async fn complete( )); } - txn.commit().await?; + repo.save().await?; Ok(params) } diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index b33b6912..155f72f7 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -27,7 +27,7 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::Templates; use oauth2_types::{ @@ -139,10 +139,10 @@ pub(crate) async fn get( Form(params): Form, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; // First, figure out what client it is - let client = txn + let client = repo .oauth2_client() .find_by_client_id(¶ms.auth.client_id) .await? @@ -170,7 +170,7 @@ pub(crate) async fn get( let templates = templates.clone(); let callback_destination = callback_destination.clone(); async move { - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let prompt = params.auth.prompt.as_deref().unwrap_or_default(); // Check if the request/request_uri/registration params are used. If so, reply @@ -275,7 +275,7 @@ pub(crate) async fn get( let requires_consent = prompt.contains(&Prompt::Consent); - let grant = txn + let grant = repo .oauth2_authorization_grant() .add( &mut rng, @@ -302,7 +302,7 @@ pub(crate) async fn get( } None if prompt.contains(&Prompt::Create) => { // Client asked for a registration, show the registration prompt - txn.commit().await?; + repo.save().await?; mas_router::Register::and_then(continue_grant) .go() @@ -310,7 +310,7 @@ pub(crate) async fn get( } None => { // Other cases where we don't have a session, ask for a login - txn.commit().await?; + repo.save().await?; mas_router::Login::and_then(continue_grant) .go() @@ -323,7 +323,7 @@ pub(crate) async fn get( || prompt.contains(&Prompt::SelectAccount) => { // TODO: better pages here - txn.commit().await?; + repo.save().await?; mas_router::Reauth::and_then(continue_grant) .go() @@ -333,7 +333,7 @@ pub(crate) async fn get( // Else, we immediately try to complete the authorization grant Some(user_session) if prompt.contains(&Prompt::None) => { // With prompt=none, we should get back to the client immediately - match self::complete::complete(grant, user_session, &policy_factory, txn).await + match self::complete::complete(grant, user_session, &policy_factory, repo).await { Ok(params) => callback_destination.go(&templates, params).await?, Err(GrantCompletionError::RequiresConsent) => { @@ -372,7 +372,7 @@ pub(crate) async fn get( Some(user_session) => { let grant_id = grant.id; // Else, we show the relevant reauth/consent page if necessary - match self::complete::complete(grant, user_session, &policy_factory, txn).await + match self::complete::complete(grant, user_session, &policy_factory, repo).await { Ok(params) => callback_destination.go(&templates, params).await?, Err( diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 94bf1346..f3d4fd46 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -30,7 +30,7 @@ use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; use sqlx::PgPool; @@ -81,13 +81,13 @@ pub(crate) async fn get( Path(grant_id): Path, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let grant = conn + let grant = repo .oauth2_authorization_grant() .lookup(grant_id) .await? @@ -136,15 +136,15 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; cookie_jar.verify_form(clock.now(), form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; - let grant = txn + let grant = repo .oauth2_authorization_grant() .lookup(grant_id) .await? @@ -167,7 +167,7 @@ pub(crate) async fn post( return Err(RouteError::PolicyViolation); } - let client = txn + let client = repo .oauth2_client() .lookup(grant.client_id) .await? @@ -180,7 +180,7 @@ pub(crate) async fn post( .filter(|s| !s.starts_with("urn:matrix:org.matrix.msc2967.client:device:")) .cloned() .collect(); - txn.oauth2_client() + repo.oauth2_client() .give_consent_for_user( &mut rng, &clock, @@ -190,9 +190,11 @@ pub(crate) async fn post( ) .await?; - txn.oauth2_authorization_grant().give_consent(grant).await?; + repo.oauth2_authorization_grant() + .give_consent(grant) + .await?; - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, next.go_next()).into_response()) } diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index d032695a..2837928f 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -25,7 +25,7 @@ use mas_storage::{ compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, oauth2::{OAuth2AccessTokenRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository}, user::{BrowserSessionRepository, UserRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -130,12 +130,13 @@ pub(crate) async fn post( client_authorization: ClientAuthorization, ) -> Result { let clock = Clock::default(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let client = client_authorization .credentials - .fetch(&mut conn) - .await? + .fetch(&mut repo) + .await + .unwrap() .ok_or(RouteError::ClientNotFound)?; let method = match &client.token_endpoint_auth_method { @@ -166,14 +167,14 @@ pub(crate) async fn post( let reply = match token_type { TokenType::AccessToken => { - let token = conn + let token = repo .oauth2_access_token() .find_by_token(token) .await? .filter(|t| t.is_valid(clock.now())) .ok_or(RouteError::UnknownToken)?; - let session = conn + let session = repo .oauth2_session() .lookup(token.session_id) .await? @@ -181,7 +182,7 @@ pub(crate) async fn post( // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; - let browser_session = conn + let browser_session = repo .browser_session() .lookup(session.user_session_id) .await? @@ -205,14 +206,14 @@ pub(crate) async fn post( } TokenType::RefreshToken => { - let token = conn + let token = repo .oauth2_refresh_token() .find_by_token(token) .await? .filter(|t| t.is_valid()) .ok_or(RouteError::UnknownToken)?; - let session = conn + let session = repo .oauth2_session() .lookup(token.session_id) .await? @@ -220,7 +221,7 @@ pub(crate) async fn post( // XXX: is that the right error to bubble up? .ok_or(RouteError::UnknownToken)?; - let browser_session = conn + let browser_session = repo .browser_session() .lookup(session.user_session_id) .await? @@ -244,21 +245,21 @@ pub(crate) async fn post( } TokenType::CompatAccessToken => { - let access_token = conn + let access_token = repo .compat_access_token() .find_by_token(token) .await? .filter(|t| t.is_valid(clock.now())) .ok_or(RouteError::UnknownToken)?; - let session = conn + let session = repo .compat_session() .lookup(access_token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::UnknownToken)?; - let user = conn + let user = repo .user() .lookup(session.user_id) .await? @@ -285,21 +286,21 @@ pub(crate) async fn post( } TokenType::CompatRefreshToken => { - let refresh_token = conn + let refresh_token = repo .compat_refresh_token() .find_by_token(token) .await? .filter(|t| t.is_valid()) .ok_or(RouteError::UnknownToken)?; - let session = conn + let session = repo .compat_session() .lookup(refresh_token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::UnknownToken)?; - let user = conn + let user = repo .user() .lookup(session.user_id) .await? diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index a6ff6158..d6180f9a 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,7 +19,7 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::{oauth2::OAuth2ClientRepository, Repository}; +use mas_storage::{oauth2::OAuth2ClientRepository, PgRepository, Repository}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, registration::{ @@ -124,8 +124,7 @@ pub(crate) async fn post( return Err(RouteError::PolicyDenied(res.violations)); } - // Grab a txn - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method { Some( @@ -141,7 +140,7 @@ pub(crate) async fn post( _ => (None, None), }; - let client = txn + let client = repo .oauth2_client() .add( &mut rng, @@ -170,7 +169,7 @@ pub(crate) async fn post( ) .await?; - txn.commit().await?; + repo.save().await?; let response = ClientRegistrationResponse { client_id: client.client_id, diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 97f249c2..6365a0ad 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -37,7 +37,7 @@ use mas_storage::{ OAuth2RefreshTokenRepository, OAuth2SessionRepository, }, user::BrowserSessionRepository, - Repository, + PgRepository, Repository, }; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -49,7 +49,7 @@ use oauth2_types::{ }; use serde::Serialize; use serde_with::{serde_as, skip_serializing_none}; -use sqlx::{PgPool, Postgres, Transaction}; +use sqlx::PgPool; use thiserror::Error; use tracing::debug; use url::Url; @@ -166,11 +166,11 @@ pub(crate) async fn post( State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let client = client_authorization .credentials - .fetch(&mut txn) + .fetch(&mut repo) .await? .ok_or(RouteError::ClientNotFound)?; @@ -188,10 +188,10 @@ pub(crate) async fn post( let reply = match form { AccessTokenRequest::AuthorizationCode(grant) => { - authorization_code_grant(&grant, &client, &key_store, &url_builder, txn).await? + authorization_code_grant(&grant, &client, &key_store, &url_builder, repo).await? } AccessTokenRequest::RefreshToken(grant) => { - refresh_token_grant(&grant, &client, txn).await? + refresh_token_grant(&grant, &client, repo).await? } _ => { return Err(RouteError::InvalidGrant); @@ -211,11 +211,11 @@ async fn authorization_code_grant( client: &Client, key_store: &Keystore, url_builder: &UrlBuilder, - mut txn: Transaction<'_, Postgres>, + mut repo: PgRepository, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let authz_grant = txn + let authz_grant = repo .oauth2_authorization_grant() .find_by_code(&grant.code) .await? @@ -238,13 +238,13 @@ async fn authorization_code_grant( // Ending the session if the token was already exchanged more than 20s ago if now - exchanged_at > Duration::seconds(20) { debug!("Ending potentially compromised session"); - let session = txn + let session = repo .oauth2_session() .lookup(session_id) .await? .ok_or(RouteError::NoSuchOAuthSession)?; - txn.oauth2_session().finish(&clock, session).await?; - txn.commit().await?; + repo.oauth2_session().finish(&clock, session).await?; + repo.save().await?; } return Err(RouteError::InvalidGrant); @@ -266,7 +266,7 @@ async fn authorization_code_grant( } }; - let session = txn + let session = repo .oauth2_session() .lookup(session_id) .await? @@ -289,7 +289,7 @@ async fn authorization_code_grant( } }; - let browser_session = txn + let browser_session = repo .browser_session() .lookup(session.user_session_id) .await? @@ -299,12 +299,12 @@ async fn authorization_code_grant( let access_token_str = TokenType::AccessToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let access_token = txn + let access_token = repo .oauth2_access_token() .add(&mut rng, &clock, &session, access_token_str, ttl) .await?; - let refresh_token = txn + let refresh_token = repo .oauth2_refresh_token() .add(&mut rng, &clock, &session, &access_token, refresh_token_str) .await?; @@ -355,11 +355,11 @@ async fn authorization_code_grant( params = params.with_id_token(id_token); } - txn.oauth2_authorization_grant() + repo.oauth2_authorization_grant() .exchange(&clock, authz_grant) .await?; - txn.commit().await?; + repo.save().await?; Ok(params) } @@ -367,17 +367,17 @@ async fn authorization_code_grant( async fn refresh_token_grant( grant: &RefreshTokenGrant, client: &Client, - mut txn: Transaction<'_, Postgres>, + mut repo: PgRepository, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let refresh_token = txn + let refresh_token = repo .oauth2_refresh_token() .find_by_token(&grant.refresh_token) .await? .ok_or(RouteError::InvalidGrant)?; - let session = txn + let session = repo .oauth2_session() .lookup(refresh_token.session_id) .await? @@ -396,12 +396,12 @@ async fn refresh_token_grant( let access_token_str = TokenType::AccessToken.generate(&mut rng); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let new_access_token = txn + let new_access_token = repo .oauth2_access_token() .add(&mut rng, &clock, &session, access_token_str.clone(), ttl) .await?; - let new_refresh_token = txn + let new_refresh_token = repo .oauth2_refresh_token() .add( &mut rng, @@ -412,14 +412,14 @@ async fn refresh_token_grant( ) .await?; - let refresh_token = txn + let refresh_token = repo .oauth2_refresh_token() .consume(&clock, refresh_token) .await?; if let Some(access_token_id) = refresh_token.access_token_id { - if let Some(access_token) = txn.oauth2_access_token().lookup(access_token_id).await? { - txn.oauth2_access_token() + if let Some(access_token) = repo.oauth2_access_token().lookup(access_token_id).await? { + repo.oauth2_access_token() .revoke(&clock, access_token) .await?; } @@ -430,7 +430,7 @@ async fn refresh_token_grant( .with_refresh_token(new_refresh_token.refresh_token) .with_scope(session.scope); - txn.commit().await?; + repo.save().await?; Ok(params) } diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 49b6c5f1..a125c5dd 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -31,7 +31,7 @@ use mas_router::UrlBuilder; use mas_storage::{ oauth2::OAuth2ClientRepository, user::{BrowserSessionRepository, UserEmailRepository}, - Repository, + DatabaseError, PgRepository, Repository, }; use oauth2_types::scope; use serde::Serialize; @@ -64,7 +64,7 @@ pub enum RouteError { Internal(Box), #[error("failed to authenticate")] - AuthorizationVerificationError(#[from] AuthorizationVerificationError), + AuthorizationVerificationError(#[from] AuthorizationVerificationError), #[error("no suitable key found for signing")] InvalidSigningKey, @@ -102,11 +102,11 @@ pub async fn get( user_authorization: UserAuthorization, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; - let session = user_authorization.protected(&mut conn, clock.now()).await?; + let session = user_authorization.protected(&mut repo, clock.now()).await?; - let browser_session = conn + let browser_session = repo .browser_session() .lookup(session.user_session_id) .await? @@ -115,7 +115,7 @@ pub async fn get( let user = browser_session.user; let user_email = if session.scope.contains(&scope::EMAIL) { - conn.user_email().get_primary(&user).await? + repo.user_email().get_primary(&user).await? } else { None }; @@ -127,7 +127,7 @@ pub async fn get( email: user_email.map(|u| u.email), }; - let client = conn + let client = repo .oauth2_client() .lookup(session.client_id) .await? diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 178eba1a..bdd19b7b 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -24,7 +24,7 @@ use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; use mas_storage::{ upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, - Repository, + PgRepository, Repository, }; use sqlx::PgPool; use thiserror::Error; @@ -67,9 +67,9 @@ pub(crate) async fn get( ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; - let provider = txn + let provider = repo .upstream_oauth_provider() .lookup(provider_id) .await? @@ -100,7 +100,7 @@ pub(crate) async fn get( &mut rng, )?; - let session = txn + let session = repo .upstream_oauth_session() .add( &mut rng, @@ -116,7 +116,7 @@ pub(crate) async fn get( .add(session.id, provider.id, data.state, query.post_auth_action) .save(cookie_jar, clock.now()); - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, Redirect::temporary(url.as_str()))) } diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 8cb9a605..521efd7b 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -26,8 +26,11 @@ use mas_oidc_client::requests::{ }; use mas_router::{Route, UrlBuilder}; use mas_storage::{ - upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository}, - Repository, UpstreamOAuthLinkRepository, + upstream_oauth2::{ + UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, + UpstreamOAuthSessionRepository, + }, + PgRepository, Repository, }; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; @@ -129,9 +132,9 @@ pub(crate) async fn get( ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; - let provider = txn + let provider = repo .upstream_oauth_provider() .lookup(provider_id) .await? @@ -142,7 +145,7 @@ pub(crate) async fn get( .find_session(provider_id, ¶ms.state) .map_err(|_| RouteError::MissingCookie)?; - let session = txn + let session = repo .upstream_oauth_session() .lookup(session_id) .await? @@ -244,7 +247,7 @@ pub(crate) async fn get( let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?; // Look for an existing link - let maybe_link = txn + let maybe_link = repo .upstream_oauth_link() .find_by_subject(&provider, &subject) .await?; @@ -252,12 +255,12 @@ pub(crate) async fn get( let link = if let Some(link) = maybe_link { link } else { - txn.upstream_oauth_link() + repo.upstream_oauth_link() .add(&mut rng, &clock, &provider, subject) .await? }; - let session = txn + let session = repo .upstream_oauth_session() .complete_with_link(&clock, session, &link, response.id_token) .await?; @@ -266,7 +269,7 @@ pub(crate) async fn get( .add_link_to_session(session.id, link.id)? .save(cookie_jar, clock.now()); - txn.commit().await?; + repo.save().await?; Ok(( cookie_jar, diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 10e1f80e..18849be8 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -25,9 +25,9 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_storage::{ - upstream_oauth2::UpstreamOAuthSessionRepository, + upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserRepository}, - Repository, UpstreamOAuthLinkRepository, + PgRepository, Repository, }; use mas_templates::{ EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, @@ -99,7 +99,7 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, Path(link_id): Path, ) -> Result { - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (clock, mut rng) = crate::clock_and_rng(); let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); @@ -107,13 +107,13 @@ pub(crate) async fn get( .lookup_link(link_id) .map_err(|_| RouteError::MissingCookie)?; - let link = txn + let link = repo .upstream_oauth_link() .lookup(link_id) .await? .ok_or(RouteError::LinkNotFound)?; - let upstream_session = txn + let upstream_session = repo .upstream_oauth_session() .lookup(session_id) .await? @@ -131,24 +131,24 @@ pub(crate) async fn get( let (user_session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); - let maybe_user_session = user_session_info.load_session(&mut txn).await?; + let maybe_user_session = user_session_info.load_session(&mut repo).await?; let render = match (maybe_user_session, link.user_id) { (Some(session), Some(user_id)) if session.user.id == user_id => { // Session already linked, and link matches the currently logged // user. Mark the session as consumed and renew the authentication. - txn.upstream_oauth_session() + repo.upstream_oauth_session() .consume(&clock, upstream_session) .await?; - let session = txn + let session = repo .browser_session() .authenticate_with_upstream(&mut rng, &clock, session, &link) .await?; cookie_jar = cookie_jar.set_session(&session); - txn.commit().await?; + repo.save().await?; let ctx = EmptyContext .with_session(session) @@ -163,7 +163,7 @@ pub(crate) async fn get( // Session already linked, but link doesn't match the currently // logged user. Suggest logging out of the current user // and logging in with the new one - let user = txn + let user = repo .user() .lookup(user_id) .await? @@ -187,7 +187,7 @@ pub(crate) async fn get( (None, Some(user_id)) => { // Session linked, but user not logged in: do the login - let user = txn + let user = repo .user() .lookup(user_id) .await? @@ -216,8 +216,8 @@ pub(crate) async fn post( Path(link_id): Path, Form(form): Form>, ) -> Result { - let mut txn = pool.begin().await?; let (clock, mut rng) = crate::clock_and_rng(); + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); @@ -229,13 +229,13 @@ pub(crate) async fn post( post_auth_action: post_auth_action.cloned(), }; - let link = txn + let link = repo .upstream_oauth_link() .lookup(link_id) .await? .ok_or(RouteError::LinkNotFound)?; - let upstream_session = txn + let upstream_session = repo .upstream_oauth_session() .lookup(session_id) .await? @@ -252,11 +252,11 @@ pub(crate) async fn post( } let (user_session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_user_session = user_session_info.load_session(&mut txn).await?; + let maybe_user_session = user_session_info.load_session(&mut repo).await?; let session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { - txn.upstream_oauth_link() + repo.upstream_oauth_link() .associate_to_user(&link, &session.user) .await?; @@ -264,32 +264,32 @@ pub(crate) async fn post( } (None, Some(user_id), FormData::Login) => { - let user = txn + let user = repo .user() .lookup(user_id) .await? .ok_or(RouteError::UserNotFound)?; - txn.browser_session().add(&mut rng, &clock, &user).await? + repo.browser_session().add(&mut rng, &clock, &user).await? } (None, None, FormData::Register { username }) => { - let user = txn.user().add(&mut rng, &clock, username).await?; - txn.upstream_oauth_link() + let user = repo.user().add(&mut rng, &clock, username).await?; + repo.upstream_oauth_link() .associate_to_user(&link, &user) .await?; - txn.browser_session().add(&mut rng, &clock, &user).await? + repo.browser_session().add(&mut rng, &clock, &user).await? } _ => return Err(RouteError::InvalidFormAction), }; - txn.upstream_oauth_session() + repo.upstream_oauth_session() .consume(&clock, upstream_session) .await?; - let session = txn + let session = repo .browser_session() .authenticate_with_upstream(&mut rng, &clock, session, &link) .await?; @@ -299,7 +299,7 @@ pub(crate) async fn post( .save(cookie_jar, clock.now()); let cookie_jar = cookie_jar.set_session(&session); - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, post_auth_action.go_next())) } diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index c7cd2767..e0cc063d 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -24,7 +24,7 @@ use mas_axum_utils::{ use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Repository}; +use mas_storage::{user::UserEmailRepository, PgRepository, Repository}; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; @@ -43,12 +43,12 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -74,12 +74,12 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -88,7 +88,7 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = txn + let user_email = repo .user_email() .add(&mut rng, &clock, &session.user, form.email) .await?; @@ -101,7 +101,7 @@ pub(crate) async fn post( }; start_email_verification( &mailer, - &mut txn, + &mut repo, &mut rng, &clock, &session.user, @@ -109,7 +109,7 @@ pub(crate) async fn post( ) .await?; - txn.commit().await?; + repo.save().await?; Ok((cookie_jar, next.go()).into_response()) } diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index e6e1e341..3fda398a 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -28,11 +28,11 @@ use mas_data_model::{BrowserSession, User, UserEmail}; use mas_email::Mailer; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, Repository}; +use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository}; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; -use sqlx::{PgConnection, PgPool}; +use sqlx::PgPool; use tracing::info; pub mod add; @@ -54,14 +54,14 @@ pub(crate) async fn get( ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { - render(&mut rng, &clock, templates, session, cookie_jar, &mut conn).await + render(&mut rng, &clock, templates, session, cookie_jar, &mut repo).await } else { let login = mas_router::Login::default(); Ok((cookie_jar, login.go()).into_response()) @@ -74,11 +74,11 @@ async fn render( templates: Templates, session: BrowserSession, cookie_jar: PrivateCookieJar, - conn: &mut PgConnection, + repo: &mut impl Repository, ) -> Result { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), rng); - let emails = conn.user_email().all(&session.user).await?; + let emails = repo.user_email().all(&session.user).await?; let ctx = AccountEmailsContext::new(emails) .with_session(session) @@ -91,7 +91,7 @@ async fn render( async fn start_email_verification( mailer: &Mailer, - conn: &mut PgConnection, + repo: &mut impl Repository, mut rng: impl Rng + Send, clock: &Clock, user: &User, @@ -103,7 +103,7 @@ async fn start_email_verification( let address: Address = user_email.email.parse()?; - let verification = conn + let verification = repo .user_email() .add_verification_code(&mut rng, clock, &user_email, Duration::hours(8), code) .await?; @@ -130,11 +130,11 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let mut session = if let Some(session) = maybe_session { session @@ -147,21 +147,21 @@ pub(crate) async fn post( match form { ManagementForm::Add { email } => { - let email = txn + let email = repo .user_email() .add(&mut rng, &clock, &session.user, email) .await?; let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut txn, &mut rng, &clock, &session.user, email) + start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) .await?; - txn.commit().await?; + repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); } ManagementForm::ResendConfirmation { id } => { let id = id.parse()?; - let email = txn + let email = repo .user_email() .lookup(id) .await? @@ -172,15 +172,15 @@ pub(crate) async fn post( } let next = mas_router::AccountVerifyEmail::new(email.id); - start_email_verification(&mailer, &mut txn, &mut rng, &clock, &session.user, email) + start_email_verification(&mailer, &mut repo, &mut rng, &clock, &session.user, email) .await?; - txn.commit().await?; + repo.save().await?; return Ok((cookie_jar, next.go()).into_response()); } ManagementForm::Remove { id } => { let id = id.parse()?; - let email = txn + let email = repo .user_email() .lookup(id) .await? @@ -190,11 +190,11 @@ pub(crate) async fn post( return Err(anyhow!("Email not found").into()); } - txn.user_email().remove(email).await?; + repo.user_email().remove(email).await?; } ManagementForm::SetPrimary { id } => { let id = id.parse()?; - let email = txn + let email = repo .user_email() .lookup(id) .await? @@ -204,7 +204,7 @@ pub(crate) async fn post( return Err(anyhow!("Email not found").into()); } - txn.user_email().set_as_primary(&email).await?; + repo.user_email().set_as_primary(&email).await?; session.user.primary_user_email_id = Some(email.id); } }; @@ -215,11 +215,11 @@ pub(crate) async fn post( templates.clone(), session, cookie_jar, - &mut txn, + &mut repo, ) .await?; - txn.commit().await?; + repo.save().await?; Ok(reply) } diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 1192743e..085b9a33 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -24,7 +24,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::{user::UserEmailRepository, Clock, Repository}; +use mas_storage::{user::UserEmailRepository, Clock, PgRepository, Repository}; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; use sqlx::PgPool; @@ -45,12 +45,12 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -59,7 +59,7 @@ pub(crate) async fn get( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = conn + let user_email = repo .user_email() .lookup(id) .await? @@ -89,12 +89,12 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let clock = Clock::default(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -103,33 +103,33 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = txn + let user_email = repo .user_email() .lookup(id) .await? .filter(|u| u.user_id == session.user.id) .context("Could not find user email")?; - let verification = txn + let verification = repo .user_email() .find_verification_code(&clock, &user_email, &form.code) .await? .context("Invalid code")?; // TODO: display nice errors if the code was already consumed or expired - txn.user_email() + repo.user_email() .consume_verification_code(&clock, verification) .await?; if session.user.primary_user_email_id.is_none() { - txn.user_email().set_as_primary(&user_email).await?; + repo.user_email().set_as_primary(&user_email).await?; } - txn.user_email() + repo.user_email() .mark_as_verified(&clock, user_email) .await?; - txn.commit().await?; + repo.save().await?; let destination = query.go_next_or_default(&mas_router::AccountEmails); Ok((cookie_jar, destination).into_response()) diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 0188aef2..5017db00 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -25,7 +25,7 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::{AccountContext, TemplateContext, Templates}; use sqlx::PgPool; @@ -36,12 +36,12 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -50,9 +50,9 @@ pub(crate) async fn get( return Ok((cookie_jar, login.go()).into_response()); }; - let active_sessions = conn.browser_session().count_active(&session.user).await?; + let active_sessions = repo.browser_session().count_active(&session.user).await?; - let emails = conn.user_email().all(&session.user).await?; + let emails = repo.user_email().all(&session.user).await?; let ctx = AccountContext::new(active_sessions, emails) .with_session(session) diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 42c0194b..8d496432 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -27,7 +27,7 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use mas_templates::{EmptyContext, TemplateContext, Templates}; use rand::Rng; @@ -50,11 +50,11 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { render(&mut rng, &clock, templates, session, cookie_jar).await @@ -90,13 +90,13 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -105,7 +105,7 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let user_password = txn + let user_password = repo .user_password() .active(&session.user) .await? @@ -129,7 +129,7 @@ pub(crate) async fn post( } let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?; - let user_password = txn + let user_password = repo .user_password() .add( &mut rng, @@ -141,14 +141,14 @@ pub(crate) async fn post( ) .await?; - let session = txn + let session = repo .browser_session() .authenticate_with_password(&mut rng, &clock, session, &user_password) .await?; let reply = render(&mut rng, &clock, templates.clone(), session, cookie_jar).await?; - txn.commit().await?; + repo.save().await?; Ok(reply) } diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index 2471296d..49668dae 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -20,6 +20,7 @@ use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{csrf::CsrfExt, FancyError, SessionInfoExt}; use mas_keystore::Encrypter; use mas_router::UrlBuilder; +use mas_storage::PgRepository; use mas_templates::{IndexContext, TemplateContext, Templates}; use sqlx::PgPool; @@ -30,11 +31,11 @@ pub async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let session = session_info.load_session(&mut conn).await?; + let session = session_info.load_session(&mut repo).await?; let ctx = IndexContext::new(url_builder.oidc_discovery()) .maybe_with_session(session) diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 1ef5efbb..76ffa455 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -26,14 +26,14 @@ use mas_keystore::Encrypter; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, - Clock, Repository, + Clock, PgRepository, Repository, }; use mas_templates::{ FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, }; use rand::{CryptoRng, Rng}; use serde::{Deserialize, Serialize}; -use sqlx::{PgConnection, PgPool}; +use sqlx::PgPool; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -56,23 +56,23 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if maybe_session.is_some() { let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) } else { - let providers = conn.upstream_oauth_provider().all().await?; + let providers = repo.upstream_oauth_provider().all().await?; let content = render( LoginContext::default().with_upstrem_providers(providers), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -90,7 +90,7 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; @@ -112,14 +112,14 @@ pub(crate) async fn post( }; if !state.is_valid() { - let providers = conn.upstream_oauth_provider().all().await?; + let providers = repo.upstream_oauth_provider().all().await?; let content = render( LoginContext::default() .with_form_state(state) .with_upstrem_providers(providers), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -129,7 +129,7 @@ pub(crate) async fn post( match login( password_manager, - &mut conn, + &mut repo, rng, &clock, &form.username, @@ -138,6 +138,8 @@ pub(crate) async fn post( .await { Ok(session_info) => { + repo.save().await?; + let cookie_jar = cookie_jar.set_session(&session_info); let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) @@ -149,7 +151,7 @@ pub(crate) async fn post( LoginContext::default().with_form_state(state), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -162,7 +164,7 @@ pub(crate) async fn post( // TODO: move that logic elsewhere? async fn login( password_manager: PasswordManager, - conn: &mut PgConnection, + repo: &mut impl Repository, mut rng: impl Rng + CryptoRng + Send, clock: &Clock, username: &str, @@ -170,7 +172,7 @@ async fn login( ) -> Result { // XXX: we're loosing the error context here // First, lookup the user - let user = conn + let user = repo .user() .find_by_username(username) .await @@ -178,7 +180,7 @@ async fn login( .ok_or(FormError::InvalidCredentials)?; // And its password - let user_password = conn + let user_password = repo .user_password() .active(&user) .await @@ -200,7 +202,7 @@ async fn login( let user_password = if let Some((version, new_password_hash)) = new_password_hash { // Save the upgraded password - conn.user_password() + repo.user_password() .add( &mut rng, clock, @@ -216,14 +218,14 @@ async fn login( }; // Start a new session - let user_session = conn + let user_session = repo .browser_session() .add(&mut rng, clock, &user) .await .map_err(|_| FormError::Internal)?; // And mark it as authenticated by the password - let user_session = conn + let user_session = repo .browser_session() .authenticate_with_password(&mut rng, clock, user_session, &user_password) .await @@ -236,10 +238,10 @@ async fn render( ctx: LoginContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - conn: &mut PgConnection, + repo: &mut impl Repository, templates: &Templates, ) -> Result { - let next = action.load_context(conn).await?; + let next = action.load_context(repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 88c4a9c2..156e6afb 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -23,7 +23,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; -use mas_storage::{user::BrowserSessionRepository, Clock, Repository}; +use mas_storage::{user::BrowserSessionRepository, Clock, PgRepository, Repository}; use sqlx::PgPool; pub(crate) async fn post( @@ -32,20 +32,20 @@ pub(crate) async fn post( Form(form): Form>>, ) -> Result { let clock = Clock::default(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let (session_info, mut cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if let Some(session) = maybe_session { - txn.browser_session().finish(&clock, session).await?; + repo.browser_session().finish(&clock, session).await?; cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended()); } - txn.commit().await?; + repo.save().await?; let destination = if let Some(action) = form { action.go_next() diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 7911a930..aac51abd 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -26,7 +26,7 @@ use mas_keystore::Encrypter; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::{ReauthContext, TemplateContext, Templates}; use serde::Deserialize; @@ -48,12 +48,12 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -65,7 +65,7 @@ pub(crate) async fn get( }; let ctx = ReauthContext::default(); - let next = query.load_context(&mut conn).await?; + let next = query.load_context(&mut repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { @@ -86,13 +86,13 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut txn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; let session = if let Some(session) = maybe_session { session @@ -104,7 +104,7 @@ pub(crate) async fn post( }; // Load the user password - let user_password = txn + let user_password = repo .user_password() .active(&session.user) .await? @@ -125,7 +125,7 @@ pub(crate) async fn post( let user_password = if let Some((version, new_password_hash)) = new_password_hash { // Save the upgraded password - txn.user_password() + repo.user_password() .add( &mut rng, &clock, @@ -140,13 +140,13 @@ pub(crate) async fn post( }; // Mark the session as authenticated by the password - let session = txn + let session = repo .browser_session() .authenticate_with_password(&mut rng, &clock, session, &user_password) .await?; let cookie_jar = cookie_jar.set_session(&session); - txn.commit().await?; + repo.save().await?; let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index b2fe9fe0..a014eb9d 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -33,7 +33,7 @@ use mas_policy::PolicyFactory; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, - Repository, + PgRepository, Repository, }; use mas_templates::{ EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, @@ -41,7 +41,7 @@ use mas_templates::{ }; use rand::{distributions::Uniform, Rng}; use serde::{Deserialize, Serialize}; -use sqlx::{PgConnection, PgPool}; +use sqlx::PgPool; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -66,12 +66,12 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); - let maybe_session = session_info.load_session(&mut conn).await?; + let maybe_session = session_info.load_session(&mut repo).await?; if maybe_session.is_some() { let reply = query.go_next(); @@ -81,7 +81,7 @@ pub(crate) async fn get( RegisterContext::default(), query, csrf_token, - &mut conn, + &mut repo, &templates, ) .await?; @@ -102,7 +102,7 @@ pub(crate) async fn post( Form(form): Form>, ) -> Result { let (clock, mut rng) = crate::clock_and_rng(); - let mut txn = pool.begin().await?; + let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(clock.now(), form)?; @@ -114,7 +114,7 @@ pub(crate) async fn post( if form.username.is_empty() { state.add_error_on_field(RegisterFormField::Username, FieldError::Required); - } else if txn.user().exists(&form.username).await? { + } else if repo.user().exists(&form.username).await? { state.add_error_on_field(RegisterFormField::Username, FieldError::Exists); } @@ -177,7 +177,7 @@ pub(crate) async fn post( RegisterContext::default().with_form_state(state), query, csrf_token, - &mut txn, + &mut repo, &templates, ) .await?; @@ -185,15 +185,15 @@ pub(crate) async fn post( return Ok((cookie_jar, Html(content)).into_response()); } - let user = txn.user().add(&mut rng, &clock, form.username).await?; + let user = repo.user().add(&mut rng, &clock, form.username).await?; let password = Zeroizing::new(form.password.into_bytes()); let (version, hashed_password) = password_manager.hash(&mut rng, password).await?; - let user_password = txn + let user_password = repo .user_password() .add(&mut rng, &clock, &user, version, hashed_password, None) .await?; - let user_email = txn + let user_email = repo .user_email() .add(&mut rng, &clock, &user, form.email) .await?; @@ -205,7 +205,7 @@ pub(crate) async fn post( let address: Address = user_email.email.parse()?; - let verification = txn + let verification = repo .user_email() .add_verification_code(&mut rng, &clock, &user_email, Duration::hours(8), code) .await?; @@ -219,14 +219,14 @@ pub(crate) async fn post( let next = mas_router::AccountVerifyEmail::new(user_email.id).and_maybe(query.post_auth_action); - let session = txn.browser_session().add(&mut rng, &clock, &user).await?; + let session = repo.browser_session().add(&mut rng, &clock, &user).await?; - let session = txn + let session = repo .browser_session() .authenticate_with_password(&mut rng, &clock, session, &user_password) .await?; - txn.commit().await?; + repo.save().await?; let cookie_jar = cookie_jar.set_session(&session); Ok((cookie_jar, next.go()).into_response()) @@ -236,10 +236,10 @@ async fn render( ctx: RegisterContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, - conn: &mut PgConnection, + repo: &mut impl Repository, templates: &Templates, ) -> Result { - let next = action.load_context(conn).await?; + let next = action.load_context(repo).await?; let ctx = if let Some(next) = next { ctx.with_post_action(next) } else { diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index 57d53762..db3c3392 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -15,12 +15,13 @@ use anyhow::Context; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - compat::CompatSsoLoginRepository, oauth2::OAuth2AuthorizationGrantRepository, - upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, + compat::CompatSsoLoginRepository, + oauth2::OAuth2AuthorizationGrantRepository, + upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, + Repository, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; use serde::{Deserialize, Serialize}; -use sqlx::PgConnection; #[derive(Serialize, Deserialize, Default, Debug, Clone)] pub(crate) struct OptionalPostAuthAction { @@ -39,14 +40,14 @@ impl OptionalPostAuthAction { self.go_next_or_default(&mas_router::Index) } - pub async fn load_context( + pub async fn load_context( &self, - conn: &mut PgConnection, + repo: &mut R, ) -> anyhow::Result> { let Some(action) = self.post_auth_action.clone() else { return Ok(None) }; let ctx = match action { PostAuthAction::ContinueAuthorizationGrant { id } => { - let grant = conn + let grant = repo .oauth2_authorization_grant() .lookup(id) .await? @@ -56,7 +57,7 @@ impl OptionalPostAuthAction { } PostAuthAction::ContinueCompatSsoLogin { id } => { - let login = conn + let login = repo .compat_sso_login() .lookup(id) .await? @@ -68,13 +69,13 @@ impl OptionalPostAuthAction { PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword, PostAuthAction::LinkUpstream { id } => { - let link = conn + let link = repo .upstream_oauth_link() .lookup(id) .await? .context("Failed to load upstream OAuth 2.0 link")?; - let provider = conn + let provider = repo .upstream_oauth_provider() .lookup(link.provider_id) .await? diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 09a70023..97aeee24 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -183,7 +183,7 @@ pub(crate) mod tracing; pub mod upstream_oauth2; pub mod user; -pub use self::{repository::Repository, upstream_oauth2::UpstreamOAuthLinkRepository}; +pub use self::repository::{PgRepository, Repository}; /// Embedded migrations, allowing them to run on startup pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 91df9313..c57c5dcd 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -32,7 +32,7 @@ use crate::{ }; #[async_trait] -pub trait OAuth2AuthorizationGrantRepository { +pub trait OAuth2AuthorizationGrantRepository: Send + Sync { type Error; #[allow(clippy::too_many_arguments)] diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 9df2f61d..c28bc4ef 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -27,7 +27,7 @@ use crate::{ }; #[async_trait] -pub trait OAuth2SessionRepository { +pub trait OAuth2SessionRepository: Send + Sync { type Error; async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index b9bf5683..1fde4b41 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -12,89 +12,100 @@ // See the License for the specific language governing permissions and // limitations under the License. -use sqlx::{PgConnection, Postgres, Transaction}; +use sqlx::{PgPool, Postgres, Transaction}; use crate::{ compat::{ - PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, - PgCompatSsoLoginRepository, + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, + PgCompatSessionRepository, PgCompatSsoLoginRepository, }, oauth2::{ - PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository, - PgOAuth2ClientRepository, PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, + OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, + OAuth2RefreshTokenRepository, OAuth2SessionRepository, PgOAuth2AccessTokenRepository, + PgOAuth2AuthorizationGrantRepository, PgOAuth2ClientRepository, + PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository, }, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, - PgUpstreamOAuthSessionRepository, + PgUpstreamOAuthSessionRepository, UpstreamOAuthLinkRepository, + UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, user::{ - PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository, - PgUserRepository, + BrowserSessionRepository, PgBrowserSessionRepository, PgUserEmailRepository, + PgUserPasswordRepository, PgUserRepository, UserEmailRepository, UserPasswordRepository, + UserRepository, }, + DatabaseError, }; -pub trait Repository { - type UpstreamOAuthLinkRepository<'c> +pub trait Repository: Send { + type Error: std::error::Error + Send + Sync + 'static; + + type UpstreamOAuthLinkRepository<'c>: UpstreamOAuthLinkRepository + 'c where Self: 'c; - type UpstreamOAuthProviderRepository<'c> + type UpstreamOAuthProviderRepository<'c>: UpstreamOAuthProviderRepository + + 'c where Self: 'c; - type UpstreamOAuthSessionRepository<'c> + type UpstreamOAuthSessionRepository<'c>: UpstreamOAuthSessionRepository + + 'c where Self: 'c; - type UserRepository<'c> + type UserRepository<'c>: UserRepository + 'c where Self: 'c; - type UserEmailRepository<'c> + type UserEmailRepository<'c>: UserEmailRepository + 'c where Self: 'c; - type UserPasswordRepository<'c> + type UserPasswordRepository<'c>: UserPasswordRepository + 'c where Self: 'c; - type BrowserSessionRepository<'c> + type BrowserSessionRepository<'c>: BrowserSessionRepository + 'c where Self: 'c; - type OAuth2ClientRepository<'c> + type OAuth2ClientRepository<'c>: OAuth2ClientRepository + 'c where Self: 'c; - type OAuth2AuthorizationGrantRepository<'c> + type OAuth2AuthorizationGrantRepository<'c>: OAuth2AuthorizationGrantRepository + + 'c where Self: 'c; - type OAuth2SessionRepository<'c> + type OAuth2SessionRepository<'c>: OAuth2SessionRepository + 'c where Self: 'c; - type OAuth2AccessTokenRepository<'c> + type OAuth2AccessTokenRepository<'c>: OAuth2AccessTokenRepository + 'c where Self: 'c; - type OAuth2RefreshTokenRepository<'c> + type OAuth2RefreshTokenRepository<'c>: OAuth2RefreshTokenRepository + 'c where Self: 'c; - type CompatSessionRepository<'c> + type CompatSessionRepository<'c>: CompatSessionRepository + 'c where Self: 'c; - type CompatSsoLoginRepository<'c> + type CompatSsoLoginRepository<'c>: CompatSsoLoginRepository + 'c where Self: 'c; - type CompatAccessTokenRepository<'c> + type CompatAccessTokenRepository<'c>: CompatAccessTokenRepository + 'c where Self: 'c; - type CompatRefreshTokenRepository<'c> + type CompatRefreshTokenRepository<'c>: CompatRefreshTokenRepository + 'c where Self: 'c; @@ -116,7 +127,30 @@ pub trait Repository { fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>; } -impl Repository for PgConnection { +pub struct PgRepository { + txn: Transaction<'static, Postgres>, +} + +impl PgRepository { + pub async fn from_pool(pool: &PgPool) -> Result { + let txn = pool.begin().await?; + Ok(PgRepository { txn }) + } + + pub async fn save(self) -> Result<(), DatabaseError> { + self.txn.commit().await?; + Ok(()) + } + + pub async fn cancel(self) -> Result<(), DatabaseError> { + self.txn.rollback().await?; + Ok(()) + } +} + +impl Repository for PgRepository { + type Error = DatabaseError; + type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; @@ -135,149 +169,66 @@ impl Repository for PgConnection { type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { - PgUpstreamOAuthLinkRepository::new(self) + PgUpstreamOAuthLinkRepository::new(&mut self.txn) } fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { - PgUpstreamOAuthProviderRepository::new(self) + PgUpstreamOAuthProviderRepository::new(&mut self.txn) } fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { - PgUpstreamOAuthSessionRepository::new(self) + PgUpstreamOAuthSessionRepository::new(&mut self.txn) } fn user(&mut self) -> Self::UserRepository<'_> { - PgUserRepository::new(self) + PgUserRepository::new(&mut self.txn) } fn user_email(&mut self) -> Self::UserEmailRepository<'_> { - PgUserEmailRepository::new(self) + PgUserEmailRepository::new(&mut self.txn) } fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { - PgUserPasswordRepository::new(self) + PgUserPasswordRepository::new(&mut self.txn) } fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { - PgBrowserSessionRepository::new(self) + PgBrowserSessionRepository::new(&mut self.txn) } fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { - PgOAuth2ClientRepository::new(self) + PgOAuth2ClientRepository::new(&mut self.txn) } fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { - PgOAuth2AuthorizationGrantRepository::new(self) + PgOAuth2AuthorizationGrantRepository::new(&mut self.txn) } fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { - PgOAuth2SessionRepository::new(self) + PgOAuth2SessionRepository::new(&mut self.txn) } fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { - PgOAuth2AccessTokenRepository::new(self) + PgOAuth2AccessTokenRepository::new(&mut self.txn) } fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { - PgOAuth2RefreshTokenRepository::new(self) + PgOAuth2RefreshTokenRepository::new(&mut self.txn) } fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { - PgCompatSessionRepository::new(self) + PgCompatSessionRepository::new(&mut self.txn) } fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { - PgCompatSsoLoginRepository::new(self) + PgCompatSsoLoginRepository::new(&mut self.txn) } fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { - PgCompatAccessTokenRepository::new(self) + PgCompatAccessTokenRepository::new(&mut self.txn) } fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { - PgCompatRefreshTokenRepository::new(self) - } -} - -impl<'t> Repository for Transaction<'t, Postgres> { - type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; - type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; - type UpstreamOAuthSessionRepository<'c> = PgUpstreamOAuthSessionRepository<'c> where Self: 'c; - type UserRepository<'c> = PgUserRepository<'c> where Self: 'c; - type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; - type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; - type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; - type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; - type OAuth2AuthorizationGrantRepository<'c> = PgOAuth2AuthorizationGrantRepository<'c> where Self: 'c; - type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; - type OAuth2AccessTokenRepository<'c> = PgOAuth2AccessTokenRepository<'c> where Self: 'c; - type OAuth2RefreshTokenRepository<'c> = PgOAuth2RefreshTokenRepository<'c> where Self: 'c; - type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; - type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; - type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; - type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; - - fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { - PgUpstreamOAuthLinkRepository::new(self) - } - - fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { - PgUpstreamOAuthProviderRepository::new(self) - } - - fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_> { - PgUpstreamOAuthSessionRepository::new(self) - } - - fn user(&mut self) -> Self::UserRepository<'_> { - PgUserRepository::new(self) - } - - fn user_email(&mut self) -> Self::UserEmailRepository<'_> { - PgUserEmailRepository::new(self) - } - - fn user_password(&mut self) -> Self::UserPasswordRepository<'_> { - PgUserPasswordRepository::new(self) - } - - fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { - PgBrowserSessionRepository::new(self) - } - - fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { - PgOAuth2ClientRepository::new(self) - } - - fn oauth2_authorization_grant(&mut self) -> Self::OAuth2AuthorizationGrantRepository<'_> { - PgOAuth2AuthorizationGrantRepository::new(self) - } - - fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { - PgOAuth2SessionRepository::new(self) - } - - fn oauth2_access_token(&mut self) -> Self::OAuth2AccessTokenRepository<'_> { - PgOAuth2AccessTokenRepository::new(self) - } - - fn oauth2_refresh_token(&mut self) -> Self::OAuth2RefreshTokenRepository<'_> { - PgOAuth2RefreshTokenRepository::new(self) - } - - fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { - PgCompatSessionRepository::new(self) - } - - fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { - PgCompatSsoLoginRepository::new(self) - } - - fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { - PgCompatAccessTokenRepository::new(self) - } - - fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { - PgCompatRefreshTokenRepository::new(self) + PgCompatRefreshTokenRepository::new(&mut self.txn) } } diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index e195056c..d2a24731 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -29,20 +29,20 @@ mod tests { use sqlx::PgPool; use super::*; - use crate::{Clock, Repository}; + use crate::{Clock, PgRepository, Repository}; #[sqlx::test(migrator = "crate::MIGRATOR")] async fn test_repository(pool: PgPool) -> Result<(), Box> { let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); let clock = Clock::default(); - let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?; // The provider list should be empty at the start - let all_providers = conn.upstream_oauth_provider().all().await?; + let all_providers = repo.upstream_oauth_provider().all().await?; assert!(all_providers.is_empty()); // Let's add a provider - let provider = conn + let provider = repo .upstream_oauth_provider() .add( &mut rng, @@ -57,7 +57,7 @@ mod tests { .await?; // Look it up in the database - let provider = conn + let provider = repo .upstream_oauth_provider() .lookup(provider.id) .await? @@ -66,7 +66,7 @@ mod tests { assert_eq!(provider.client_id, "client-id"); // Start a session - let session = conn + let session = repo .upstream_oauth_session() .add( &mut rng, @@ -79,7 +79,7 @@ mod tests { .await?; // Look it up in the database - let session = conn + let session = repo .upstream_oauth_session() .lookup(session.id) .await? @@ -91,19 +91,19 @@ mod tests { assert!(!session.is_consumed()); // Create a link - let link = conn + let link = repo .upstream_oauth_link() .add(&mut rng, &clock, &provider, "a-subject".to_owned()) .await?; // We can look it up by its ID - conn.upstream_oauth_link() + repo.upstream_oauth_link() .lookup(link.id) .await? .expect("link to be found in database"); // or by its subject - let link = conn + let link = repo .upstream_oauth_link() .find_by_subject(&provider, "a-subject") .await? @@ -111,7 +111,7 @@ mod tests { assert_eq!(link.subject, "a-subject"); assert_eq!(link.provider_id, provider.id); - let session = conn + let session = repo .upstream_oauth_session() .complete_with_link(&clock, session, &link, None) .await?; @@ -119,7 +119,7 @@ mod tests { assert!(!session.is_consumed()); assert_eq!(session.link_id(), Some(link.id)); - let session = conn + let session = repo .upstream_oauth_session() .consume(&clock, session) .await?; diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index f4d11c6a..39a33b8d 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -14,7 +14,7 @@ //! Database-related tasks -use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, Repository}; +use mas_storage::{oauth2::OAuth2AccessTokenRepository, Clock, PgRepository, Repository}; use sqlx::{Pool, Postgres}; use tracing::{debug, error, info}; @@ -33,8 +33,8 @@ impl std::fmt::Debug for CleanupExpired { impl Task for CleanupExpired { async fn run(&self) { let res = async move { - let mut conn = self.0.acquire().await?; - conn.oauth2_access_token().cleanup_expired(&self.1).await + let mut repo = PgRepository::from_pool(&self.0).await?; + repo.oauth2_access_token().cleanup_expired(&self.1).await } .await;