diff --git a/Cargo.lock b/Cargo.lock index 131a3b6f..52dc2ca3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -678,9 +678,9 @@ dependencies = [ [[package]] name = "axum-extra" -version = "0.4.0-rc.2" +version = "0.4.0-rc.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4236821ba5932aa38a08a45e4068ff2318c2882672c30ff26a433bc84b14e6" +checksum = "16a35dfc7e1c432f55bc4f5665926651cc34d169ed7db7b6c01a26a20abc47af" dependencies = [ "axum 0.6.0-rc.5", "bytes 1.2.1", @@ -698,9 +698,9 @@ dependencies = [ [[package]] name = "axum-macros" -version = "0.3.0-rc.2" +version = "0.3.0-rc.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414ca6cd8cbe767411488373833ebd9d07a6b470d841250d8d2b8bd69769ca0e" +checksum = "f2185fff4d6f14de84dcc01b0ff8eee2ac5331a962cf85e60e080ce7db724cc9" dependencies = [ "heck", "proc-macro2", diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index cd3ffccd..c6ec87e6 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -8,7 +8,7 @@ license = "Apache-2.0" [dependencies] async-trait = "0.1.58" axum = { version = "0.6.0-rc.5", features = ["headers"] } -axum-extra = { version = "0.4.0-rc.2", features = ["cookie-private"] } +axum-extra = { version = "0.4.0-rc.3", features = ["cookie-private"] } bincode = "1.3.3" chrono = "0.4.23" data-encoding = "2.3.2" diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index cbfc14ac..bc042de5 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -29,6 +29,7 @@ use mas_storage::MIGRATOR; use mas_tasks::TaskQueue; use mas_templates::Templates; use tokio::signal::unix::SignalKind; +use tower::Layer; use tracing::{error, info, log::warn}; #[derive(Parser, Debug, Default)] @@ -215,9 +216,8 @@ impl Options { }; // and build the router - let router = crate::server::build_router(state.clone(), &config.resources) - .layer(ServerLayer::new(config.name.clone())) - .into_service(); + let router = crate::server::build_router(state.clone(), &config.resources); + let router = ServerLayer::new(config.name.clone()).layer(router); // Display some informations about where we'll be serving connections let is_tls = config.tls.is_some(); diff --git a/crates/cli/src/server.rs b/crates/cli/src/server.rs index 98e12d35..4f135e78 100644 --- a/crates/cli/src/server.rs +++ b/crates/cli/src/server.rs @@ -19,7 +19,10 @@ use std::{ }; use anyhow::Context; -use axum::{body::HttpBody, error_handling::HandleErrorLayer, extract::FromRef, Extension, Router}; +use axum::{ + body::HttpBody, error_handling::HandleErrorLayer, extract::FromRef, Extension, Router, + RouterService, +}; use hyper::StatusCode; use listenfd::ListenFd; use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp}; @@ -33,14 +36,14 @@ use tower::Layer; use tower_http::services::ServeDir; #[allow(clippy::trait_duplication_in_bounds)] -pub fn build_router(state: AppState, resources: &[HttpResource]) -> Router +pub fn build_router(state: AppState, resources: &[HttpResource]) -> RouterService where B: HttpBody + Send + 'static, ::Data: Into + Send, ::Error: std::error::Error + Send + Sync, { let templates = Templates::from_ref(&state); - let mut router = Router::with_state(state); + let mut router = Router::new(); for resource in resources { router = match resource { @@ -106,7 +109,7 @@ where } } - router + router.with_state(state) } pub fn build_tls_server_config(config: &HttpTlsConfig) -> Result { diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index 094e9735..68166da3 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -22,8 +22,8 @@ hyper = { version = "0.14.23", features = ["full"] } tower = "0.4.13" tower-http = { version = "0.3.4", features = ["cors"] } axum = { version = "0.6.0-rc.5", features = ["ws"] } -axum-macros = "0.3.0-rc.2" -axum-extra = { version = "0.4.0-rc.2", features = ["cookie-private"] } +axum-macros = "0.3.0-rc.3" +axum-extra = { version = "0.4.0-rc.3", features = ["cookie-private"] } async-graphql = { version = "4.0.16", features = ["tracing", "apollo_tracing"] } diff --git a/crates/handlers/src/health.rs b/crates/handlers/src/health.rs index aec5822e..6322dffd 100644 --- a/crates/handlers/src/health.rs +++ b/crates/handlers/src/health.rs @@ -39,7 +39,7 @@ mod tests { #[sqlx::test(migrator = "mas_storage::MIGRATOR")] async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> { let state = crate::test_state(pool).await?; - let app = crate::router(state).into_service(); + let app = crate::healthcheck_router().with_state(state); let request = Request::builder().uri("/health").body(Body::empty())?; diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index aebac9c0..8c6d2e5d 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -58,15 +58,6 @@ pub use compat::MatrixHomeserver; pub use self::{app_state::AppState, graphql::schema as graphql_schema}; -#[must_use] -pub fn empty_router(state: S) -> Router -where - B: HttpBody + Send + 'static, - S: Clone + Send + Sync + 'static, -{ - Router::with_state(state) -} - #[must_use] pub fn healthcheck_router() -> Router where @@ -74,7 +65,7 @@ where S: Clone + Send + Sync + 'static, PgPool: FromRef, { - Router::inherit_state().route(mas_router::Healthcheck::route(), get(self::health::get)) + Router::new().route(mas_router::Healthcheck::route(), get(self::health::get)) } #[must_use] @@ -87,7 +78,7 @@ where mas_graphql::Schema: FromRef, Encrypter: FromRef, { - let mut router = Router::inherit_state() + let mut router = Router::new() .route( "/graphql", get(self::graphql::get).post(self::graphql::post), @@ -109,7 +100,7 @@ where Keystore: FromRef, UrlBuilder: FromRef, { - Router::inherit_state() + Router::new() .route( mas_router::OidcConfiguration::route(), get(self::oauth2::discovery::get), @@ -148,7 +139,7 @@ where Encrypter: FromRef, { // All those routes are API-like, with a common CORS layer - Router::inherit_state() + Router::new() .route( mas_router::OAuth2Keys::route(), get(self::oauth2::keys::get), @@ -199,7 +190,7 @@ where PgPool: FromRef, MatrixHomeserver: FromRef, { - Router::inherit_state() + Router::new() .route( mas_router::CompatLogin::route(), get(self::compat::login::get).post(self::compat::login::post), @@ -243,7 +234,7 @@ where Templates: FromRef, Mailer: FromRef, { - Router::inherit_state() + Router::new() .route( mas_router::ChangePasswordDiscovery::route(), get(|| async { mas_router::AccountPassword.go() }), @@ -324,9 +315,10 @@ where )) } +/* #[must_use] #[allow(clippy::trait_duplication_in_bounds)] -pub fn router(state: S) -> Router +pub fn router(state: S) -> RouterService where B: HttpBody + Send + 'static, ::Data: Into + Send, @@ -349,14 +341,16 @@ where let compat_router = compat_router(); let human_router = human_router(Templates::from_ref(&state)); - Router::with_state(state) + Router::new() .merge(healthcheck_router) .merge(discovery_router) .merge(human_router) .merge(api_router) .merge(graphql_router) .merge(compat_router) + .with_state(state) } +*/ #[cfg(test)] async fn test_state(pool: PgPool) -> Result {