diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index ad8c407e..5c7acf34 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -190,8 +190,8 @@ impl TokenType { /// use rand::thread_rng; /// use mas_data_model::TokenType::{AccessToken, RefreshToken}; /// - /// AccessToken.generate(thread_rng()); - /// RefreshToken.generate(thread_rng()); + /// AccessToken.generate(&mut thread_rng()); + /// RefreshToken.generate(&mut thread_rng()); /// ``` pub fn generate(self, rng: &mut (impl RngCore + ?Sized)) -> String { let random_part: String = rng diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index 45446e10..4271f896 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -17,16 +17,20 @@ use std::{convert::Infallible, sync::Arc}; use axum::{ async_trait, extract::{FromRef, FromRequestParts}, + response::IntoResponse, }; +use hyper::StatusCode; use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_email::Mailer; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRng, SystemClock}; +use mas_storage_pg::PgRepository; use mas_templates::Templates; use rand::SeedableRng; use sqlx::PgPool; +use thiserror::Error; use crate::{passwords::PasswordManager, MatrixHomeserver}; @@ -140,3 +144,26 @@ impl FromRequestParts for BoxRng { Ok(Box::new(rng)) } } + +#[derive(Debug, Error)] +#[error(transparent)] +pub struct RepositoryError(#[from] mas_storage_pg::DatabaseError); + +impl IntoResponse for RepositoryError { + fn into_response(self) -> axum::response::Response { + (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response() + } +} + +#[async_trait] +impl FromRequestParts for PgRepository { + type Rejection = RepositoryError; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &AppState, + ) -> Result { + let repo = PgRepository::from_pool(&state.pool).await?; + Ok(repo) + } +} diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 279b575d..f76cda71 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -28,7 +28,6 @@ use mas_storage_pg::PgRepository; use rand::{CryptoRng, RngCore}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; -use sqlx::PgPool; use thiserror::Error; use zeroize::Zeroizing; @@ -197,11 +196,10 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - State(pool): State, + mut repo: PgRepository, State(homeserver): State, Json(input): Json, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; let (session, user) = match input.credentials { Credentials::Password { identifier: Identifier::User { user }, diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 6201b0c6..602b4d80 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -36,7 +36,6 @@ use mas_storage::{ use mas_storage_pg::PgRepository; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; -use sqlx::PgPool; use ulid::Ulid; #[derive(Serialize)] @@ -56,14 +55,12 @@ pub struct Params { pub async fn get( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, Query(params): Query, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); @@ -120,15 +117,13 @@ pub async fn get( pub async fn post( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, Query(params): Query, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); cookie_jar.verify_form(&clock, form)?; diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index a8063141..d8ef0fb2 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -24,7 +24,6 @@ use mas_storage_pg::PgRepository; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; -use sqlx::PgPool; use thiserror::Error; use url::Url; @@ -60,7 +59,7 @@ impl IntoResponse for RouteError { pub async fn get( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(url_builder): State, Query(params): Query, ) -> Result { @@ -79,7 +78,6 @@ pub async fn get( } let token = Alphanumeric.sample_string(&mut rng, 32); - let mut repo = PgRepository::from_pool(&pool).await?; let login = repo .compat_sso_login() .add(&mut rng, &clock, token, redirect_url) diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index bfc767fa..e1ef02be 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::State, response::IntoResponse, Json, TypedHeader}; +use axum::{response::IntoResponse, Json, TypedHeader}; use headers::{authorization::Bearer, Authorization}; use hyper::StatusCode; use mas_data_model::TokenType; @@ -21,7 +21,6 @@ use mas_storage::{ BoxClock, Clock, Repository, }; use mas_storage_pg::PgRepository; -use sqlx::PgPool; use thiserror::Error; use super::MatrixError; @@ -69,11 +68,9 @@ impl IntoResponse for RouteError { pub(crate) async fn post( clock: BoxClock, - State(pool): State, + mut repo: PgRepository, maybe_authorization: Option>>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; let token = authorization.token(); diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 868be9db..6b90464e 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::State, response::IntoResponse, Json}; +use axum::{response::IntoResponse, Json}; use chrono::Duration; use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; @@ -23,7 +23,6 @@ use mas_storage::{ use mas_storage_pg::PgRepository; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; -use sqlx::PgPool; use thiserror::Error; use super::MatrixError; @@ -90,11 +89,9 @@ pub struct ResponseBody { pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, Json(input): Json, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let token_type = TokenType::check(&input.refresh_token)?; if token_type != TokenType::CompatRefreshToken { diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 4d9dcbcd..48ca5560 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -21,7 +21,10 @@ )] #![warn(clippy::pedantic)] #![allow( - clippy::unused_async // Some axum handlers need that + // Some axum handlers need that + clippy::unused_async, + // Because of how axum handlers work, we sometime have take many arguments + clippy::too_many_arguments, )] use std::{convert::Infallible, sync::Arc, time::Duration}; @@ -41,6 +44,7 @@ use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::{Route, UrlBuilder}; use mas_storage::{BoxClock, BoxRng}; +use mas_storage_pg::PgRepository; use mas_templates::{ErrorContext, Templates}; use passwords::PasswordManager; use sqlx::PgPool; @@ -154,7 +158,7 @@ where Keystore: FromRef, UrlBuilder: FromRef, Arc: FromRef, - PgPool: FromRef, + PgRepository: FromRequestParts, Encrypter: FromRef, HttpClientFactory: FromRef, BoxClock: FromRequestParts, @@ -209,7 +213,7 @@ where ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, - PgPool: FromRef, + PgRepository: FromRequestParts, MatrixHomeserver: FromRef, PasswordManager: FromRef, BoxClock: FromRequestParts, @@ -254,7 +258,7 @@ where S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, Arc: FromRef, - PgPool: FromRef, + PgRepository: FromRequestParts, Encrypter: FromRef, Templates: FromRef, Mailer: FromRef, @@ -358,7 +362,7 @@ where } #[cfg(test)] -async fn test_state(pool: PgPool) -> Result { +async fn test_state(pool: sqlx::PgPool) -> Result { use mas_email::MailTransport; use crate::passwords::Hasher; diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 934ba088..c17fb9f1 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -32,7 +32,6 @@ use mas_storage::{ use mas_storage_pg::PgRepository; use mas_templates::Templates; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -82,12 +81,10 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 60753482..30efcaa3 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -39,7 +39,6 @@ use oauth2_types::{ }; use rand::{distributions::Alphanumeric, Rng}; use serde::Deserialize; -use sqlx::PgPool; use thiserror::Error; use self::{callback::CallbackDestination, complete::GrantCompletionError}; @@ -136,12 +135,10 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Form(params): Form, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - // First, figure out what client it is let client = repo .oauth2_client() diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 86c832fb..c83dca03 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -34,7 +34,6 @@ use mas_storage::{ }; use mas_storage_pg::PgRepository; use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates}; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -78,12 +77,10 @@ pub(crate) async fn get( clock: BoxClock, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -133,13 +130,11 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(policy_factory): State>, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Path(grant_id): Path, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index d8e64fa0..65e48e06 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -33,7 +33,6 @@ use oauth2_types::{ requests::{IntrospectionRequest, IntrospectionResponse}, scope::ScopeToken, }; -use sqlx::PgPool; use thiserror::Error; use crate::impl_from_error_for_route; @@ -126,12 +125,10 @@ const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc pub(crate) async fn post( clock: BoxClock, State(http_client_factory): State, - State(pool): State, + mut repo: PgRepository, State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let client = client_authorization .credentials .fetch(&mut repo) diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index da043b8b..129f636f 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -28,7 +28,6 @@ use oauth2_types::{ }, }; use rand::distributions::{Alphanumeric, DistString}; -use sqlx::PgPool; use thiserror::Error; use tracing::info; @@ -109,7 +108,7 @@ impl IntoResponse for RouteError { pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(policy_factory): State>, State(encrypter): State, Json(body): Json, @@ -125,8 +124,6 @@ pub(crate) async fn post( return Err(RouteError::PolicyDenied(res.violations)); } - let mut repo = PgRepository::from_pool(&pool).await?; - let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method { Some( OAuthClientAuthenticationMethod::ClientSecretJwt diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 3fe916d8..ed566261 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -50,7 +50,6 @@ use oauth2_types::{ }; use serde::Serialize; use serde_with::{serde_as, skip_serializing_none}; -use sqlx::PgPool; use thiserror::Error; use tracing::debug; use url::Url; @@ -164,12 +163,10 @@ pub(crate) async fn post( State(http_client_factory): State, State(key_store): State, State(url_builder): State, - State(pool): State, + mut repo: PgRepository, State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let client = client_authorization .credentials .fetch(&mut repo) diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 9d60ac1f..eb9e1cc2 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -37,7 +37,6 @@ use mas_storage_pg::PgRepository; use oauth2_types::scope; use serde::Serialize; use serde_with::skip_serializing_none; -use sqlx::PgPool; use thiserror::Error; use crate::impl_from_error_for_route; @@ -101,12 +100,10 @@ pub async fn get( mut rng: BoxRng, clock: BoxClock, State(url_builder): State, - State(pool): State, + mut repo: PgRepository, State(key_store): State, user_authorization: UserAuthorization, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let session = user_authorization.protected(&mut repo, &clock).await?; let browser_session = repo diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index d6649317..ff47084b 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -27,7 +27,6 @@ use mas_storage::{ BoxClock, BoxRng, Repository, }; use mas_storage_pg::PgRepository; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -61,14 +60,12 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(http_client_factory): State, - State(pool): State, + mut repo: PgRepository, State(url_builder): State, cookie_jar: PrivateCookieJar, Path(provider_id): Path, Query(query): Query, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let provider = repo .upstream_oauth_provider() .lookup(provider_id) diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index fd66af09..b324cfb2 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -35,7 +35,6 @@ use mas_storage::{ use mas_storage_pg::PgRepository; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -124,7 +123,7 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(http_client_factory): State, - State(pool): State, + mut repo: PgRepository, State(url_builder): State, State(encrypter): State, State(keystore): State, @@ -132,8 +131,6 @@ pub(crate) async fn get( Path(provider_id): Path, Query(params): Query, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let provider = repo .upstream_oauth_provider() .lookup(provider_id) diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index d318fc3e..bdd5df1f 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -35,7 +35,6 @@ use mas_templates::{ UpstreamSuggestLink, }; use serde::Deserialize; -use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -96,12 +95,11 @@ pub(crate) enum FormData { pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(templates): State, cookie_jar: PrivateCookieJar, Path(link_id): Path, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); let (session_id, _post_auth_action) = sessions_cookie .lookup_link(link_id) @@ -213,12 +211,11 @@ pub(crate) async fn get( pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Path(link_id): Path, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; let form = cookie_jar.verify_form(&clock, form)?; let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 1c8c6665..64218e3a 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -28,7 +28,6 @@ use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; use mas_storage_pg::PgRepository; use mas_templates::{EmailAddContext, TemplateContext, Templates}; use serde::Deserialize; -use sqlx::PgPool; use super::start_email_verification; use crate::views::shared::OptionalPostAuthAction; @@ -42,11 +41,9 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -71,14 +68,12 @@ pub(crate) async fn get( pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - State(pool): State, + mut repo: PgRepository, State(mailer): State, cookie_jar: PrivateCookieJar, Query(query): Query, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 10772b87..fd2f2981 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -33,7 +33,6 @@ use mas_storage_pg::PgRepository; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use rand::{distributions::Uniform, Rng}; use serde::Deserialize; -use sqlx::PgPool; use tracing::info; pub mod add; @@ -52,11 +51,9 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -127,13 +124,11 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, State(mailer): State, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 644810e5..e330c944 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -28,7 +28,6 @@ use mas_storage::{user::UserEmailRepository, BoxClock, BoxRng, Repository}; use mas_storage_pg::PgRepository; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; -use sqlx::PgPool; use ulid::Ulid; use crate::views::shared::OptionalPostAuthAction; @@ -42,13 +41,11 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, Path(id): Path, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -85,14 +82,12 @@ pub(crate) async fn get( pub(crate) async fn post( clock: BoxClock, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Query(query): Query, Path(id): Path, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index 660c1416..76ea5667 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -29,17 +29,14 @@ use mas_storage::{ }; use mas_storage_pg::PgRepository; use mas_templates::{AccountContext, TemplateContext, Templates}; -use sqlx::PgPool; pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index a9f17123..4fa86eae 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -33,7 +33,6 @@ use mas_storage_pg::PgRepository; use mas_templates::{EmptyContext, TemplateContext, Templates}; use rand::Rng; use serde::Deserialize; -use sqlx::PgPool; use zeroize::Zeroizing; use crate::passwords::PasswordManager; @@ -49,11 +48,9 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (session_info, cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; @@ -89,12 +86,10 @@ pub(crate) async fn post( clock: BoxClock, State(password_manager): State, State(templates): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index ffe500e7..d4322eef 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -23,18 +23,15 @@ use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRng}; use mas_storage_pg::PgRepository; use mas_templates::{IndexContext, TemplateContext, Templates}; -use sqlx::PgPool; pub async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, State(url_builder): State, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); let session = session_info.load_session(&mut repo).await?; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index d8abcb46..b245b597 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -34,7 +34,6 @@ use mas_templates::{ }; use rand::{CryptoRng, Rng}; use serde::{Deserialize, Serialize}; -use sqlx::PgPool; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -54,12 +53,10 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -88,13 +85,11 @@ pub(crate) async fn post( clock: BoxClock, State(password_manager): State, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index f8491cb9..9cdc93f0 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{ - extract::{Form, State}, - response::IntoResponse, -}; +use axum::{extract::Form, response::IntoResponse}; use axum_extra::extract::PrivateCookieJar; use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, @@ -25,16 +22,13 @@ use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; use mas_storage::{user::BrowserSessionRepository, BoxClock, Repository}; use mas_storage_pg::PgRepository; -use sqlx::PgPool; pub(crate) async fn post( clock: BoxClock, - State(pool): State, + mut repo: PgRepository, cookie_jar: PrivateCookieJar, Form(form): Form>>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (session_info, mut cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 9c2330a3..ced97902 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -31,7 +31,6 @@ use mas_storage::{ use mas_storage_pg::PgRepository; use mas_templates::{ReauthContext, TemplateContext, Templates}; use serde::Deserialize; -use sqlx::PgPool; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -46,12 +45,10 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -84,13 +81,11 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (session_info, cookie_jar) = cookie_jar.session_info(); diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index a8fc7bae..68cf5c49 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -42,7 +42,6 @@ use mas_templates::{ }; use rand::{distributions::Uniform, Rng}; use serde::{Deserialize, Serialize}; -use sqlx::PgPool; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; @@ -64,12 +63,10 @@ pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -100,13 +97,11 @@ pub(crate) async fn post( State(mailer): State, State(policy_factory): State>, State(templates): State, - State(pool): State, + mut repo: PgRepository, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { - let mut repo = PgRepository::from_pool(&pool).await?; - let form = cookie_jar.verify_form(&clock, form)?; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index 097bca74..b3b88232 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -90,7 +90,7 @@ async fn test_user_email_repo(pool: PgPool) { // The user email should not exist yet assert!(repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .is_none()); @@ -111,7 +111,7 @@ async fn test_user_email_repo(pool: PgPool) { assert!(repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .is_some()); @@ -181,7 +181,7 @@ async fn test_user_email_repo(pool: PgPool) { // Reload the user_email let user_email = repo .user_email() - .find(&user, &EMAIL) + .find(&user, EMAIL) .await .unwrap() .expect("user email was not found");