diff --git a/Cargo.lock b/Cargo.lock index c4d81ffb..a1c80d6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -516,9 +516,9 @@ dependencies = [ [[package]] name = "axum" -version = "0.5.15" +version = "0.6.0-rc.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de18bc5f2e9df8f52da03856bf40e29b747de5a84e43aefff90e3dc4a21529b" +checksum = "d49958d54e0bab71947eb00a33175eb9164ccc0ea4c262d2139c5f8899a3616e" dependencies = [ "async-trait", "axum-core", @@ -548,9 +548,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.2.7" +version = "0.3.0-rc.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4f44a0e6200e9d11a1cdc989e4b358f6e3d354fbf48478f345a17f4e43f8635" +checksum = "5e52ebadfce2f1e7fec9b2dd920952477ffeac9f07a5c492c0cbba2bb22cd294" dependencies = [ "async-trait", "bytes 1.2.1", @@ -562,9 +562,9 @@ dependencies = [ [[package]] name = "axum-extra" -version = "0.3.7" +version = "0.4.0-rc.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69034b3b0fd97923eee2ce8a47540edb21e07f48f87f67d44bb4271cec622bdb" +checksum = "090ae29ae83a40882fb99bce421d8dce4a819325a013bb301dad4cc4b74ab40c" dependencies = [ "axum", "bytes 1.2.1", @@ -2730,9 +2730,9 @@ checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" [[package]] name = "matchit" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" +checksum = "3dfc802da7b1cf80aefffa0c7b2f77247c8b32206cc83c270b61264f5b360a80" [[package]] name = "md-5" diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index 04e86476..8e7295aa 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -7,8 +7,8 @@ license = "Apache-2.0" [dependencies] async-trait = "0.1.57" -axum = { version = "0.5.15", features = ["headers"] } -axum-extra = { version = "0.3.7", features = ["cookie-private"] } +axum = { version = "0.6.0-rc.1", features = ["headers"] } +axum-extra = { version = "0.4.0-rc.1", features = ["cookie-private"] } bincode = "1.3.3" chrono = "0.4.22" data-encoding = "2.3.2" diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index e2bafdc1..94ea1079 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -19,13 +19,13 @@ use axum::{ body::HttpBody, extract::{ rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason}, - Form, FromRequest, RequestParts, TypedHeader, + Form, FromRequest, FromRequestParts, TypedHeader, }, response::IntoResponse, BoxError, }; use headers::{authorization::Basic, Authorization}; -use http::StatusCode; +use http::{Request, StatusCode}; use mas_data_model::{Client, JwksOrJwksUri, StorageBackend}; use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; @@ -234,18 +234,23 @@ impl IntoResponse for ClientAuthorizationError { } #[async_trait] -impl FromRequest for ClientAuthorization +impl FromRequest for ClientAuthorization where - B: Send + HttpBody, - B::Data: Send, - B::Error: std::error::Error + Send + Sync + 'static, F: DeserializeOwned, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into, + S: Send + Sync, { type Rejection = ClientAuthorizationError; #[allow(clippy::too_many_lines)] - async fn from_request(req: &mut RequestParts) -> Result { - let header = TypedHeader::>::from_request(req).await; + async fn from_request(req: Request, state: &S) -> Result { + // Split the request into parts so we can extract some headers + let (mut parts, body) = req.into_parts(); + + let header = + TypedHeader::>::from_request_parts(&mut parts, state).await; // Take the Authorization header let credentials_from_header = match header { @@ -258,6 +263,9 @@ where }, }; + // Reconstruct the request from the parts + let req = Request::from_parts(parts, body); + // Take the form value let ( client_id_from_form, @@ -265,7 +273,7 @@ where client_assertion_type, client_assertion, form, - ) = match Form::>::from_request(req).await { + ) = match Form::>::from_request(req, state).await { Ok(Form(form)) => ( form.client_id, form.client_secret, @@ -385,19 +393,17 @@ mod tests { #[tokio::test] async fn none_test() { - let mut req = RequestParts::new( - Request::builder() - .method(Method::POST) - .header( - http::header::CONTENT_TYPE, - mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), - ) - .body(Full::::new("client_id=client-id&foo=bar".into())) - .unwrap(), - ); + let req = Request::builder() + .method(Method::POST) + .header( + http::header::CONTENT_TYPE, + mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), + ) + .body(Full::::new("client_id=client-id&foo=bar".into())) + .unwrap(); assert_eq!( - ClientAuthorization::::from_request(&mut req) + ClientAuthorization::::from_request(req, &()) .await .unwrap(), ClientAuthorization { @@ -411,23 +417,21 @@ mod tests { #[tokio::test] async fn client_secret_basic_test() { - let mut req = RequestParts::new( - Request::builder() - .method(Method::POST) - .header( - http::header::CONTENT_TYPE, - mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), - ) - .header( - http::header::AUTHORIZATION, - "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=", - ) - .body(Full::::new("foo=bar".into())) - .unwrap(), - ); + let req = Request::builder() + .method(Method::POST) + .header( + http::header::CONTENT_TYPE, + mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), + ) + .header( + http::header::AUTHORIZATION, + "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=", + ) + .body(Full::::new("foo=bar".into())) + .unwrap(); assert_eq!( - ClientAuthorization::::from_request(&mut req) + ClientAuthorization::::from_request(req, &()) .await .unwrap(), ClientAuthorization { @@ -440,23 +444,21 @@ mod tests { ); // client_id in both header and body - let mut req = RequestParts::new( - Request::builder() - .method(Method::POST) - .header( - http::header::CONTENT_TYPE, - mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), - ) - .header( - http::header::AUTHORIZATION, - "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=", - ) - .body(Full::::new("client_id=client-id&foo=bar".into())) - .unwrap(), - ); + let req = Request::builder() + .method(Method::POST) + .header( + http::header::CONTENT_TYPE, + mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), + ) + .header( + http::header::AUTHORIZATION, + "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=", + ) + .body(Full::::new("client_id=client-id&foo=bar".into())) + .unwrap(); assert_eq!( - ClientAuthorization::::from_request(&mut req) + ClientAuthorization::::from_request(req, &()) .await .unwrap(), ClientAuthorization { @@ -469,62 +471,56 @@ mod tests { ); // client_id in both header and body mismatch - let mut req = RequestParts::new( - Request::builder() - .method(Method::POST) - .header( - http::header::CONTENT_TYPE, - mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), - ) - .header( - http::header::AUTHORIZATION, - "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=", - ) - .body(Full::::new("client_id=mismatch-id&foo=bar".into())) - .unwrap(), - ); + let req = Request::builder() + .method(Method::POST) + .header( + http::header::CONTENT_TYPE, + mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), + ) + .header( + http::header::AUTHORIZATION, + "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=", + ) + .body(Full::::new("client_id=mismatch-id&foo=bar".into())) + .unwrap(); assert!(matches!( - ClientAuthorization::::from_request(&mut req).await, + ClientAuthorization::::from_request(req, &()).await, Err(ClientAuthorizationError::ClientIdMismatch { .. }), )); // Invalid header - let mut req = RequestParts::new( - Request::builder() - .method(Method::POST) - .header( - http::header::CONTENT_TYPE, - mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), - ) - .header(http::header::AUTHORIZATION, "Basic invalid") - .body(Full::::new("foo=bar".into())) - .unwrap(), - ); + let req = Request::builder() + .method(Method::POST) + .header( + http::header::CONTENT_TYPE, + mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), + ) + .header(http::header::AUTHORIZATION, "Basic invalid") + .body(Full::::new("foo=bar".into())) + .unwrap(); assert!(matches!( - ClientAuthorization::::from_request(&mut req).await, + ClientAuthorization::::from_request(req, &()).await, Err(ClientAuthorizationError::InvalidHeader), )); } #[tokio::test] async fn client_secret_post_test() { - let mut req = RequestParts::new( - Request::builder() - .method(Method::POST) - .header( - http::header::CONTENT_TYPE, - mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), - ) - .body(Full::::new( - "client_id=client-id&client_secret=client-secret&foo=bar".into(), - )) - .unwrap(), - ); + let req = Request::builder() + .method(Method::POST) + .header( + http::header::CONTENT_TYPE, + mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), + ) + .body(Full::::new( + "client_id=client-id&client_secret=client-secret&foo=bar".into(), + )) + .unwrap(); assert_eq!( - ClientAuthorization::::from_request(&mut req) + ClientAuthorization::::from_request(req, &()) .await .unwrap(), ClientAuthorization { @@ -546,18 +542,16 @@ mod tests { JWT_BEARER_CLIENT_ASSERTION, jwt, )); - let mut req = RequestParts::new( - Request::builder() - .method(Method::POST) - .header( - http::header::CONTENT_TYPE, - mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), - ) - .body(Full::new(body)) - .unwrap(), - ); + let req = Request::builder() + .method(Method::POST) + .header( + http::header::CONTENT_TYPE, + mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), + ) + .body(Full::new(body)) + .unwrap(); - let authz = ClientAuthorization::::from_request(&mut req) + let authz = ClientAuthorization::::from_request(req, &()) .await .unwrap(); assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"}))); diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index 1947e8b5..9e253f6f 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -19,12 +19,13 @@ use axum::{ body::HttpBody, extract::{ rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason}, - Form, FromRequest, TypedHeader, + Form, FromRequest, FromRequestParts, TypedHeader, }, response::{IntoResponse, Response}, + BoxError, }; use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName}; -use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, StatusCode}; +use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode}; use mas_data_model::Session; use mas_storage::{ oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError}, @@ -275,19 +276,20 @@ impl IntoResponse for AuthorizationVerificationError { } #[async_trait] -impl FromRequest for UserAuthorization +impl FromRequest for UserAuthorization where - B: Send + HttpBody, - B::Data: Send, - B::Error: Error + Send + Sync + 'static, F: DeserializeOwned, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into, + S: Send + Sync, { type Rejection = UserAuthorizationError; - async fn from_request( - req: &mut axum::extract::RequestParts, - ) -> Result { - let header = TypedHeader::>::from_request(req).await; + async fn from_request(req: Request, state: &S) -> Result { + let (mut parts, body) = req.into_parts(); + let header = + TypedHeader::>::from_request_parts(&mut parts, state).await; // Take the Authorization header let token_from_header = match header { @@ -300,18 +302,21 @@ where }, }; + let req = Request::from_parts(parts, body); + // Take the form value - let (token_from_form, form) = match Form::>::from_request(req).await { - Ok(Form(form)) => (form.access_token, Some(form.inner)), - // If it is not a form, continue - Err(FormRejection::InvalidFormContentType(_err)) => (None, None), - // If the form could not be read, return a Bad Request error - Err(FormRejection::FailedToDeserializeQueryString(err)) => { - return Err(UserAuthorizationError::BadForm(err)) - } - // Other errors (body read twice, byte stream broke) return an internal error - Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))), - }; + let (token_from_form, form) = + match Form::>::from_request(req, state).await { + Ok(Form(form)) => (form.access_token, Some(form.inner)), + // If it is not a form, continue + Err(FormRejection::InvalidFormContentType(_err)) => (None, None), + // If the form could not be read, return a Bad Request error + Err(FormRejection::FailedToDeserializeQueryString(err)) => { + return Err(UserAuthorizationError::BadForm(err)) + } + // Other errors (body read twice, byte stream broke) return an internal error + Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))), + }; let access_token = match (token_from_header, token_from_form) { // Ensure the token should not be in both the form and the access token diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index b336456c..e1fee88c 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -24,7 +24,7 @@ use futures::stream::{StreamExt, TryStreamExt}; use hyper::Server; use mas_config::RootConfig; use mas_email::Mailer; -use mas_handlers::MatrixHomeserver; +use mas_handlers::{AppState, MatrixHomeserver}; use mas_http::ServerLayer; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; @@ -174,8 +174,6 @@ impl Options { .key_store() .await .context("could not import keys from config")?; - // Wrap the key store in an Arc - let key_store = Arc::new(key_store); let encrypter = config.secrets.encrypter(); @@ -236,18 +234,20 @@ impl Options { .context("could not watch for templates changes")?; } - let router = mas_handlers::router( - &pool, - &templates, - &key_store, - &encrypter, - &mailer, - &url_builder, - &homeserver, - &policy_factory, - ) - .fallback(static_files) - .layer(ServerLayer::default()); + let state = AppState { + pool, + templates, + key_store, + encrypter, + url_builder, + mailer, + homeserver, + policy_factory, + }; + + let router = mas_handlers::router(state) + .fallback_service(static_files) + .layer(ServerLayer::default()); info!("Listening on http://{}", listener.local_addr().unwrap()); diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index 6c1b2d42..4d206867 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -20,9 +20,9 @@ anyhow = "1.0.64" hyper = { version = "0.14.20", features = ["full"] } tower = "0.4.13" tower-http = { version = "0.3.4", features = ["cors"] } -axum = "0.5.15" +axum = "0.6.0-rc.1" axum-macros = "0.2.3" -axum-extra = { version = "0.3.7", features = ["cookie-private"] } +axum-extra = { version = "0.4.0-rc.1", features = ["cookie-private"] } # Emails lettre = { version = "0.10.1", default-features = false, features = ["builder"] } diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs new file mode 100644 index 00000000..1b5c8019 --- /dev/null +++ b/crates/handlers/src/app_state.rs @@ -0,0 +1,85 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use axum::extract::FromRef; +use mas_email::Mailer; +use mas_keystore::{Encrypter, Keystore}; +use mas_policy::PolicyFactory; +use mas_router::UrlBuilder; +use mas_templates::Templates; +use sqlx::PgPool; + +use crate::MatrixHomeserver; + +#[derive(Clone)] +pub struct AppState { + pub pool: PgPool, + pub templates: Templates, + pub key_store: Keystore, + pub encrypter: Encrypter, + pub url_builder: UrlBuilder, + pub mailer: Mailer, + pub homeserver: MatrixHomeserver, + pub policy_factory: Arc, +} + +impl FromRef for PgPool { + fn from_ref(input: &AppState) -> Self { + input.pool.clone() + } +} + +impl FromRef for Templates { + fn from_ref(input: &AppState) -> Self { + input.templates.clone() + } +} + +impl FromRef for Keystore { + fn from_ref(input: &AppState) -> Self { + input.key_store.clone() + } +} + +impl FromRef for Encrypter { + fn from_ref(input: &AppState) -> Self { + input.encrypter.clone() + } +} + +impl FromRef for UrlBuilder { + fn from_ref(input: &AppState) -> Self { + input.url_builder.clone() + } +} + +impl FromRef for Mailer { + fn from_ref(input: &AppState) -> Self { + input.mailer.clone() + } +} + +impl FromRef for MatrixHomeserver { + fn from_ref(input: &AppState) -> Self { + input.homeserver.clone() + } +} + +impl FromRef for Arc { + fn from_ref(input: &AppState) -> Self { + input.policy_factory.clone() + } +} diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index c827dcee..f456ab7a 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{response::IntoResponse, Extension, Json}; +use axum::{extract::State, response::IntoResponse, Json}; use chrono::{Duration, Utc}; use hyper::StatusCode; use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType}; @@ -197,8 +197,8 @@ impl IntoResponse for RouteError { #[tracing::instrument(skip_all, err)] pub(crate) async fn post( - Extension(pool): Extension, - Extension(homeserver): Extension, + State(pool): State, + State(homeserver): State, Json(input): Json, ) -> Result { let mut txn = pool.begin().await?; diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 7ae1f314..3f38cc1a 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -16,9 +16,8 @@ use std::collections::HashMap; use axum::{ - extract::{Form, Path, Query}, + extract::{Form, Path, Query, State}, response::{Html, IntoResponse, Redirect, Response}, - Extension, }; use axum_extra::extract::PrivateCookieJar; use chrono::{Duration, Utc}; @@ -50,8 +49,8 @@ pub struct Params { } pub async fn get( - Extension(pool): Extension, - Extension(templates): Extension, + State(pool): State, + State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, Query(params): Query, @@ -114,12 +113,12 @@ pub async fn get( } pub async fn post( - Extension(pool): Extension, - Extension(templates): Extension, + State(pool): State, + State(templates): State, cookie_jar: PrivateCookieJar, Path(id): Path, - Form(form): Form>, Query(params): Query, + Form(form): Form>, ) -> Result { let mut txn = pool.begin().await?; diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index 3ce8b8a0..2a4c4676 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -13,7 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::Query, response::IntoResponse, Extension}; +use axum::{ + extract::{Query, State}, + response::IntoResponse, +}; use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; use mas_storage::compat::insert_compat_sso_login; @@ -63,8 +66,8 @@ impl IntoResponse for RouteError { #[tracing::instrument(skip(pool, url_builder), err)] pub async fn get( - Extension(pool): Extension, - Extension(url_builder): Extension, + State(pool): State, + State(url_builder): State, Query(params): Query, ) -> Result { // Check the redirectUrl parameter diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 613389e8..36e64c47 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{response::IntoResponse, Extension, Json, TypedHeader}; +use axum::{extract::State, response::IntoResponse, Json, TypedHeader}; use headers::{authorization::Bearer, Authorization}; use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; @@ -64,7 +64,7 @@ impl From for RouteError { } pub(crate) async fn post( - Extension(pool): Extension, + State(pool): State, maybe_authorization: Option>>, ) -> Result { let mut conn = pool.acquire().await?; diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 085a2a50..3b2636a3 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{response::IntoResponse, Extension, Json}; +use axum::{extract::State, response::IntoResponse, Json}; use chrono::Duration; use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; @@ -96,7 +96,7 @@ pub struct ResponseBody { } pub(crate) async fn post( - Extension(pool): Extension, + State(pool): State, Json(input): Json, ) -> Result { let mut txn = pool.begin().await?; diff --git a/crates/handlers/src/health.rs b/crates/handlers/src/health.rs index 578b6ecd..da13c286 100644 --- a/crates/handlers/src/health.rs +++ b/crates/handlers/src/health.rs @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::Extension, response::IntoResponse}; +use axum::{extract::State, response::IntoResponse}; use mas_axum_utils::FancyError; use sqlx::PgPool; use tracing::{info_span, Instrument}; -pub async fn get(Extension(pool): Extension) -> Result { +pub async fn get(State(pool): State) -> Result { let mut conn = pool.acquire().await?; sqlx::query("SELECT $1") @@ -38,7 +38,9 @@ mod tests { #[sqlx::test(migrator = "mas_storage::MIGRATOR")] async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> { - let app = crate::test_router(&pool).await?; + let state = crate::test_state(pool).await?; + let app = crate::api_router(state); + let request = Request::builder().uri("/health").body(Body::empty())?; let response = app.oneshot(request).await?; diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index ee19643a..7913d268 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -23,7 +23,7 @@ use std::{convert::Infallible, sync::Arc, time::Duration}; use axum::{ body::HttpBody, - extract::Extension, + extract::FromRef, response::{Html, IntoResponse}, routing::{get, on, post, MethodFilter}, Router, @@ -37,9 +37,10 @@ use mas_policy::PolicyFactory; use mas_router::{Route, UrlBuilder}; use mas_templates::{ErrorContext, Templates}; use sqlx::PgPool; -use tower::util::ThenLayer; +use tower::util::AndThenLayer; use tower_http::cors::{Any, CorsLayer}; +mod app_state; mod compat; mod health; mod oauth2; @@ -47,30 +48,24 @@ mod views; pub use compat::MatrixHomeserver; +pub use self::app_state::AppState; + #[must_use] -#[allow( - clippy::too_many_lines, - clippy::missing_panics_doc, - clippy::too_many_arguments, - clippy::trait_duplication_in_bounds -)] -pub fn router( - pool: &PgPool, - templates: &Templates, - key_store: &Keystore, - encrypter: &Encrypter, - mailer: &Mailer, - url_builder: &UrlBuilder, - homeserver: &MatrixHomeserver, - policy_factory: &Arc, -) -> Router +#[allow(clippy::trait_duplication_in_bounds)] +pub fn api_router(state: Arc) -> Router where B: HttpBody + Send + 'static, ::Data: Send, ::Error: std::error::Error + Send + Sync, + S: Send + Sync + 'static, + Keystore: FromRef, + UrlBuilder: FromRef, + Arc: FromRef, + PgPool: FromRef, + Encrypter: FromRef, { // All those routes are API-like, with a common CORS layer - let api_router = Router::new() + Router::with_state_arc(state) .route( mas_router::ChangePasswordDiscovery::route(), get(|| async { mas_router::AccountPassword.go() }), @@ -118,9 +113,21 @@ where CONTENT_TYPE, ]) .max_age(Duration::from_secs(60 * 60)), - ); - - let compat_router = Router::new() + ) +} +#[must_use] +#[allow(clippy::trait_duplication_in_bounds)] +pub fn compat_router(state: Arc) -> Router +where + B: HttpBody + Send + 'static, + ::Data: Send, + ::Error: std::error::Error + Send + Sync, + S: Send + Sync + 'static, + UrlBuilder: FromRef, + PgPool: FromRef, + MatrixHomeserver: FromRef, +{ + Router::with_state_arc(state) .route( mas_router::CompatLogin::route(), get(self::compat::login::get).post(self::compat::login::post), @@ -146,106 +153,131 @@ where HeaderName::from_static("x-requested-with"), ]) .max_age(Duration::from_secs(60 * 60)), - ); + ) +} - let human_router = { - let templates = templates.clone(); - Router::new() - .route(mas_router::Index::route(), get(self::views::index::get)) - .route(mas_router::Healthcheck::route(), get(self::health::get)) - .route( - mas_router::Login::route(), - get(self::views::login::get).post(self::views::login::post), - ) - .route(mas_router::Logout::route(), post(self::views::logout::post)) - .route( - mas_router::Reauth::route(), - get(self::views::reauth::get).post(self::views::reauth::post), - ) - .route( - mas_router::Register::route(), - get(self::views::register::get).post(self::views::register::post), - ) - .route(mas_router::Account::route(), get(self::views::account::get)) - .route( - mas_router::AccountPassword::route(), - get(self::views::account::password::get).post(self::views::account::password::post), - ) - .route( - mas_router::AccountEmails::route(), - get(self::views::account::emails::get).post(self::views::account::emails::post), - ) - .route( - mas_router::AccountVerifyEmail::route(), - get(self::views::account::emails::verify::get) - .post(self::views::account::emails::verify::post), - ) - .route( - mas_router::AccountAddEmail::route(), - get(self::views::account::emails::add::get) - .post(self::views::account::emails::add::post), - ) - .route( - mas_router::OAuth2AuthorizationEndpoint::route(), - get(self::oauth2::authorization::get), - ) - .route( - mas_router::ContinueAuthorizationGrant::route(), - get(self::oauth2::authorization::complete::get), - ) - .route( - mas_router::Consent::route(), - get(self::oauth2::consent::get).post(self::oauth2::consent::post), - ) - .route( - mas_router::CompatLoginSsoRedirect::route(), - get(self::compat::login_sso_redirect::get), - ) - .route( - mas_router::CompatLoginSsoRedirectIdp::route(), - get(self::compat::login_sso_redirect::get), - ) - .route( - mas_router::CompatLoginSsoComplete::route(), - get(self::compat::login_sso_complete::get) - .post(self::compat::login_sso_complete::post), - ) - .layer(ThenLayer::new( - move |result: Result| async move { - let response = result.unwrap(); - - if response.status().is_server_error() { - // Error responses should have an ErrorContext attached to them - let ext = response.extensions().get::(); - if let Some(ctx) = ext { - if let Ok(res) = templates.render_error(ctx).await { - let (mut parts, _original_body) = response.into_parts(); - parts.headers.remove(CONTENT_TYPE); - return Ok((parts, Html(res)).into_response()); - } +#[must_use] +#[allow(clippy::trait_duplication_in_bounds)] +pub fn human_router(state: Arc) -> Router +where + B: HttpBody + Send + 'static, + ::Data: Send, + ::Error: std::error::Error + Send + Sync, + S: Send + Sync + 'static, + UrlBuilder: FromRef, + Arc: FromRef, + PgPool: FromRef, + Encrypter: FromRef, + Templates: FromRef, + Mailer: FromRef, +{ + let templates = Templates::from_ref(&state); + Router::with_state_arc(state) + .route(mas_router::Index::route(), get(self::views::index::get)) + .route(mas_router::Healthcheck::route(), get(self::health::get)) + .route( + mas_router::Login::route(), + get(self::views::login::get).post(self::views::login::post), + ) + .route(mas_router::Logout::route(), post(self::views::logout::post)) + .route( + mas_router::Reauth::route(), + get(self::views::reauth::get).post(self::views::reauth::post), + ) + .route( + mas_router::Register::route(), + get(self::views::register::get).post(self::views::register::post), + ) + .route(mas_router::Account::route(), get(self::views::account::get)) + .route( + mas_router::AccountPassword::route(), + get(self::views::account::password::get).post(self::views::account::password::post), + ) + .route( + mas_router::AccountEmails::route(), + get(self::views::account::emails::get).post(self::views::account::emails::post), + ) + .route( + mas_router::AccountVerifyEmail::route(), + get(self::views::account::emails::verify::get) + .post(self::views::account::emails::verify::post), + ) + .route( + mas_router::AccountAddEmail::route(), + get(self::views::account::emails::add::get) + .post(self::views::account::emails::add::post), + ) + .route( + mas_router::OAuth2AuthorizationEndpoint::route(), + get(self::oauth2::authorization::get), + ) + .route( + mas_router::ContinueAuthorizationGrant::route(), + get(self::oauth2::authorization::complete::get), + ) + .route( + mas_router::Consent::route(), + get(self::oauth2::consent::get).post(self::oauth2::consent::post), + ) + .route( + mas_router::CompatLoginSsoRedirect::route(), + get(self::compat::login_sso_redirect::get), + ) + .route( + mas_router::CompatLoginSsoRedirectIdp::route(), + get(self::compat::login_sso_redirect::get), + ) + .route( + mas_router::CompatLoginSsoComplete::route(), + get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post), + ) + .layer(AndThenLayer::new( + move |response: axum::response::Response| async move { + if response.status().is_server_error() { + // Error responses should have an ErrorContext attached to them + let ext = response.extensions().get::(); + if let Some(ctx) = ext { + if let Ok(res) = templates.render_error(ctx).await { + let (mut parts, _original_body) = response.into_parts(); + parts.headers.remove(CONTENT_TYPE); + return Ok((parts, Html(res)).into_response()); } } + } - Ok(response) - }, - )) - }; + Ok::<_, Infallible>(response) + }, + )) +} - human_router - .merge(api_router) - .merge(compat_router) - .layer(Extension(pool.clone())) - .layer(Extension(templates.clone())) - .layer(Extension(key_store.clone())) - .layer(Extension(encrypter.clone())) - .layer(Extension(url_builder.clone())) - .layer(Extension(mailer.clone())) - .layer(Extension(homeserver.clone())) - .layer(Extension(policy_factory.clone())) +#[must_use] +#[allow(clippy::trait_duplication_in_bounds)] +pub fn router(state: S) -> Router +where + B: HttpBody + Send + 'static, + ::Data: Send, + ::Error: std::error::Error + Send + Sync, + S: Send + Sync + 'static, + Keystore: FromRef, + UrlBuilder: FromRef, + Arc: FromRef, + PgPool: FromRef, + Encrypter: FromRef, + Templates: FromRef, + Mailer: FromRef, + MatrixHomeserver: FromRef, +{ + let state = Arc::new(state); + + let api_router = api_router(state.clone()); + let compat_router = compat_router(state.clone()); + let human_router = human_router(state); + + human_router.merge(api_router).merge(compat_router) } #[cfg(test)] -async fn test_router(pool: &PgPool) -> Result { +async fn test_state(pool: PgPool) -> Result, anyhow::Error> { use mas_email::MailTransport; let templates = Templates::load(None, true).await?; @@ -265,14 +297,14 @@ async fn test_router(pool: &PgPool) -> Result { let policy_factory = PolicyFactory::load_default(serde_json::json!({})).await?; let policy_factory = Arc::new(policy_factory); - Ok(router( + Ok(Arc::new(AppState { pool, - &templates, - &key_store, - &encrypter, - &mailer, - &url_builder, - &homeserver, - &policy_factory, - )) + templates, + key_store, + encrypter, + url_builder, + mailer, + homeserver, + policy_factory, + })) } diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 3385f371..8f416ec7 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -16,9 +16,8 @@ use std::sync::Arc; use anyhow::anyhow; use axum::{ - extract::Path, + extract::{Path, State}, response::{IntoResponse, Response}, - Extension, }; use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; @@ -104,9 +103,9 @@ impl From for RouteError { } pub(crate) async fn get( - Extension(policy_factory): Extension>, - Extension(templates): Extension, - Extension(pool): Extension, + State(policy_factory): State>, + State(templates): State, + State(pool): State, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index c327e3b8..f8bf4182 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use anyhow::{anyhow, Context}; use axum::{ - extract::{Extension, Form}, + extract::{Form, State}, response::{IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; @@ -156,9 +156,9 @@ fn resolve_response_mode( #[allow(clippy::too_many_lines)] pub(crate) async fn get( - Extension(policy_factory): Extension>, - Extension(templates): Extension, - Extension(pool): Extension, + State(policy_factory): State>, + State(templates): State, + State(pool): State, cookie_jar: PrivateCookieJar, Form(params): Form, ) -> Result { diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index dd7cf0f9..e1e5b05b 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use anyhow::Context; use axum::{ - extract::{Extension, Form, Path}, + extract::{Form, Path, State}, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; @@ -50,9 +50,9 @@ impl IntoResponse for RouteError { } pub(crate) async fn get( - Extension(policy_factory): Extension>, - Extension(templates): Extension, - Extension(pool): Extension, + State(policy_factory): State>, + State(templates): State, + State(pool): State, cookie_jar: PrivateCookieJar, Path(grant_id): Path, ) -> Result { @@ -112,8 +112,8 @@ pub(crate) async fn get( } pub(crate) async fn post( - Extension(policy_factory): Extension>, - Extension(pool): Extension, + State(policy_factory): State>, + State(pool): State, cookie_jar: PrivateCookieJar, Path(grant_id): Path, Form(form): Form>, diff --git a/crates/handlers/src/oauth2/discovery.rs b/crates/handlers/src/oauth2/discovery.rs index 0d333de8..f2983ab9 100644 --- a/crates/handlers/src/oauth2/discovery.rs +++ b/crates/handlers/src/oauth2/discovery.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::Extension, response::IntoResponse, Json}; +use axum::{extract::State, response::IntoResponse, Json}; use mas_iana::{ jose::JsonWebSignatureAlg, oauth::{ @@ -30,8 +30,8 @@ use oauth2_types::{ #[allow(clippy::too_many_lines)] pub(crate) async fn get( - Extension(key_store): Extension, - Extension(url_builder): Extension, + State(key_store): State, + State(url_builder): State, ) -> impl IntoResponse { // This is how clients can authenticate let client_auth_methods_supported = Some(vec![ diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index d60989c0..9d6b11db 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::Extension, response::IntoResponse, Json}; +use axum::{extract::State, response::IntoResponse, Json}; use hyper::StatusCode; use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerificationError}; use mas_data_model::{TokenFormatError, TokenType}; @@ -154,8 +154,8 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse { #[tracing::instrument(skip_all, err)] pub(crate) async fn post( - Extension(pool): Extension, - Extension(encrypter): Extension, + State(pool): State, + State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { let mut conn = pool.acquire().await?; diff --git a/crates/handlers/src/oauth2/keys.rs b/crates/handlers/src/oauth2/keys.rs index 3c766786..68d70f97 100644 --- a/crates/handlers/src/oauth2/keys.rs +++ b/crates/handlers/src/oauth2/keys.rs @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::Extension, response::IntoResponse, Json}; +use axum::{extract::State, response::IntoResponse, Json}; use mas_keystore::Keystore; -pub(crate) async fn get(Extension(key_store): Extension) -> impl IntoResponse { +pub(crate) async fn get(State(key_store): State) -> impl IntoResponse { let jwks = key_store.public_jwks(); Json(jwks) } diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index 302dfcbc..9c48922e 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -14,7 +14,7 @@ use std::sync::Arc; -use axum::{response::IntoResponse, Extension, Json}; +use axum::{extract::State, response::IntoResponse, Json}; use hyper::StatusCode; use mas_policy::{PolicyFactory, Violation}; use mas_storage::oauth2::client::insert_client; @@ -105,8 +105,8 @@ impl IntoResponse for RouteError { #[tracing::instrument(skip_all, err)] pub(crate) async fn post( - Extension(pool): Extension, - Extension(policy_factory): Extension>, + State(pool): State, + State(policy_factory): State>, Json(body): Json, ) -> Result { info!(?body, "Client registration"); diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 9ff32070..e03b5d8e 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -15,7 +15,7 @@ use std::collections::HashMap; use anyhow::Context; -use axum::{extract::Extension, response::IntoResponse, Json}; +use axum::{extract::State, response::IntoResponse, Json}; use chrono::{DateTime, Duration, Utc}; use data_encoding::BASE64URL_NOPAD; use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma}; @@ -188,11 +188,11 @@ impl From for RouteError { #[tracing::instrument(skip_all, err)] pub(crate) async fn post( + State(key_store): State, + State(url_builder): State, + State(pool): State, + State(encrypter): State, client_authorization: ClientAuthorization, - Extension(key_store): Extension, - Extension(url_builder): Extension, - Extension(pool): Extension, - Extension(encrypter): Extension, ) -> Result { let mut txn = pool.begin().await?; diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 6ec88daa..aa1130f1 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -14,7 +14,7 @@ use anyhow::Context; use axum::{ - extract::Extension, + extract::State, response::{IntoResponse, Response}, Json, }; @@ -48,9 +48,9 @@ struct SignedUserInfo { } pub async fn get( - Extension(url_builder): Extension, - Extension(pool): Extension, - Extension(key_store): Extension, + State(url_builder): State, + State(pool): State, + State(key_store): State, user_authorization: UserAuthorization, ) -> Result { // TODO: error handling diff --git a/crates/handlers/src/oauth2/webfinger.rs b/crates/handlers/src/oauth2/webfinger.rs index 833223c5..fb6b8dee 100644 --- a/crates/handlers/src/oauth2/webfinger.rs +++ b/crates/handlers/src/oauth2/webfinger.rs @@ -12,7 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{extract::Query, response::IntoResponse, Extension, Json, TypedHeader}; +use axum::{ + extract::{Query, State}, + response::IntoResponse, + Json, TypedHeader, +}; use headers::ContentType; use mas_router::UrlBuilder; use oauth2_types::webfinger::WebFingerResponse; @@ -33,7 +37,7 @@ fn jrd() -> mime::Mime { pub(crate) async fn get( Query(params): Query, - Extension(url_builder): Extension, + State(url_builder): State, ) -> impl IntoResponse { // TODO: should we validate the subject? let subject = params.resource; diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 760b6f4b..fe11ff3f 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -13,7 +13,7 @@ // limitations under the License. use axum::{ - extract::{Extension, Form, Query}, + extract::{Form, Query, State}, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; @@ -38,8 +38,8 @@ pub struct EmailForm { } pub(crate) async fn get( - Extension(templates): Extension, - Extension(pool): Extension, + State(templates): State, + State(pool): State, cookie_jar: PrivateCookieJar, ) -> Result { let mut conn = pool.begin().await?; @@ -66,8 +66,8 @@ pub(crate) async fn get( } pub(crate) async fn post( - Extension(pool): Extension, - Extension(mailer): Extension, + State(pool): State, + State(mailer): State, cookie_jar: PrivateCookieJar, Query(query): Query, Form(form): Form>, diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 89223e18..5ac9d862 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -13,7 +13,7 @@ // limitations under the License. use axum::{ - extract::{Extension, Form}, + extract::{Form, State}, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; @@ -52,8 +52,8 @@ pub enum ManagementForm { } pub(crate) async fn get( - Extension(templates): Extension, - Extension(pool): Extension, + State(templates): State, + State(pool): State, cookie_jar: PrivateCookieJar, ) -> Result { let mut conn = pool.acquire().await?; @@ -118,9 +118,9 @@ async fn start_email_verification( } pub(crate) async fn post( - Extension(templates): Extension, - Extension(pool): Extension, - Extension(mailer): Extension, + State(templates): State, + State(pool): State, + State(mailer): State, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 8d6f1f59..cebcc772 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -13,7 +13,7 @@ // limitations under the License. use axum::{ - extract::{Extension, Form, Path, Query}, + extract::{Form, Path, Query, State}, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; @@ -40,8 +40,8 @@ pub struct CodeForm { } pub(crate) async fn get( - Extension(templates): Extension, - Extension(pool): Extension, + State(templates): State, + State(pool): State, Query(query): Query, Path(id): Path, cookie_jar: PrivateCookieJar, @@ -78,7 +78,7 @@ pub(crate) async fn get( } pub(crate) async fn post( - Extension(pool): Extension, + State(pool): State, cookie_jar: PrivateCookieJar, Query(query): Query, Path(id): Path, diff --git a/crates/handlers/src/views/account/mod.rs b/crates/handlers/src/views/account/mod.rs index f8e09a6f..b35e6b04 100644 --- a/crates/handlers/src/views/account/mod.rs +++ b/crates/handlers/src/views/account/mod.rs @@ -16,7 +16,7 @@ pub mod emails; pub mod password; use axum::{ - extract::Extension, + extract::State, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; @@ -28,8 +28,8 @@ use mas_templates::{AccountContext, TemplateContext, Templates}; use sqlx::PgPool; pub(crate) async fn get( - Extension(templates): Extension, - Extension(pool): Extension, + State(templates): State, + State(pool): State, cookie_jar: PrivateCookieJar, ) -> Result { let mut conn = pool.acquire().await?; diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 5afc66ba..55fcecd4 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -14,7 +14,7 @@ use argon2::Argon2; use axum::{ - extract::{Extension, Form}, + extract::{Form, State}, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; @@ -41,8 +41,8 @@ pub struct ChangeForm { } pub(crate) async fn get( - Extension(templates): Extension, - Extension(pool): Extension, + State(templates): State, + State(pool): State, cookie_jar: PrivateCookieJar, ) -> Result { let mut conn = pool.acquire().await?; @@ -76,8 +76,8 @@ async fn render( } pub(crate) async fn post( - Extension(templates): Extension, - Extension(pool): Extension, + State(templates): State, + State(pool): State, cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index d2b09fbd..66ef34bc 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -13,7 +13,7 @@ // limitations under the License. use axum::{ - extract::Extension, + extract::State, response::{Html, IntoResponse}, }; use axum_extra::extract::PrivateCookieJar; @@ -24,9 +24,9 @@ use mas_templates::{IndexContext, TemplateContext, Templates}; use sqlx::PgPool; pub async fn get( - Extension(templates): Extension, - Extension(url_builder): Extension, - Extension(pool): Extension, + State(templates): State, + State(url_builder): State, + State(pool): State, cookie_jar: PrivateCookieJar, ) -> Result { let mut conn = pool.acquire().await?; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 6d1493c2..b43932b4 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -13,7 +13,7 @@ // limitations under the License. use axum::{ - extract::{Extension, Form, Query}, + extract::{Form, Query, State}, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; @@ -44,8 +44,8 @@ impl ToFormState for LoginForm { #[tracing::instrument(skip(templates, pool, cookie_jar))] pub(crate) async fn get( - Extension(templates): Extension, - Extension(pool): Extension, + State(templates): State, + State(pool): State, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { @@ -74,8 +74,8 @@ pub(crate) async fn get( } pub(crate) async fn post( - Extension(templates): Extension, - Extension(pool): Extension, + State(templates): State, + State(pool): State, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 9318f159..34d4a69c 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -13,7 +13,7 @@ // limitations under the License. use axum::{ - extract::{Extension, Form}, + extract::{Form, State}, response::IntoResponse, }; use axum_extra::extract::PrivateCookieJar; @@ -27,7 +27,7 @@ use mas_storage::user::end_session; use sqlx::PgPool; pub(crate) async fn post( - Extension(pool): Extension, + State(pool): State, cookie_jar: PrivateCookieJar, Form(form): Form>>, ) -> Result { diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 15939009..8cb2463f 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -13,7 +13,7 @@ // limitations under the License. use axum::{ - extract::{Extension, Form, Query}, + extract::{Form, Query, State}, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; @@ -36,8 +36,8 @@ pub(crate) struct ReauthForm { } pub(crate) async fn get( - Extension(templates): Extension, - Extension(pool): Extension, + State(templates): State, + State(pool): State, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { @@ -75,7 +75,7 @@ pub(crate) async fn get( } pub(crate) async fn post( - Extension(pool): Extension, + State(pool): State, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index e09c6ebe..1417abc6 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -18,7 +18,7 @@ use std::{str::FromStr, sync::Arc}; use argon2::Argon2; use axum::{ - extract::{Extension, Form, Query}, + extract::{Form, Query, State}, response::{Html, IntoResponse, Response}, }; use axum_extra::extract::PrivateCookieJar; @@ -57,8 +57,8 @@ impl ToFormState for RegisterForm { } pub(crate) async fn get( - Extension(templates): Extension, - Extension(pool): Extension, + State(templates): State, + State(pool): State, Query(query): Query, cookie_jar: PrivateCookieJar, ) -> Result { @@ -87,10 +87,10 @@ pub(crate) async fn get( } pub(crate) async fn post( - Extension(mailer): Extension, - Extension(policy_factory): Extension>, - Extension(templates): Extension, - Extension(pool): Extension, + State(mailer): State, + State(policy_factory): State>, + State(templates): State, + State(pool): State, Query(query): Query, cookie_jar: PrivateCookieJar, Form(form): Form>, diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index 357bd55c..d8083d31 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" license = "Apache-2.0" [dependencies] -axum = { version = "0.5.15", optional = true } +axum = { version = "0.6.0-rc.1", optional = true } bytes = "1.2.1" futures-util = "0.3.24" headers = "0.3.8" diff --git a/crates/router/Cargo.toml b/crates/router/Cargo.toml index 798e864e..f6867573 100644 --- a/crates/router/Cargo.toml +++ b/crates/router/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" license = "Apache-2.0" [dependencies] -axum = { version = "0.5.15", default-features = false } +axum = { version = "0.6.0-rc.1", default-features = false } serde = { version = "1.0.144", features = ["derive"] } serde_urlencoded = "0.7.1" serde_with = "2.0.0" diff --git a/crates/static-files/Cargo.toml b/crates/static-files/Cargo.toml index 53d4a3eb..df887afa 100644 --- a/crates/static-files/Cargo.toml +++ b/crates/static-files/Cargo.toml @@ -9,8 +9,8 @@ license = "Apache-2.0" dev = [] [dependencies] -axum = "0.5.15" -headers = "0.3.8" +axum = { version = "0.6.0-rc.1", features = ["headers"] } +headers = "0.3.7" http = "0.2.8" http-body = "0.4.5" mime_guess = "2.0.4"