1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-20 12:02:22 +03:00

Upgrade axum to 0.6.0-rc.1

This commit is contained in:
Quentin Gliech
2022-09-05 12:15:51 +02:00
parent b15b2d0c21
commit fa47f6e150
37 changed files with 501 additions and 378 deletions

View File

@@ -23,7 +23,7 @@ use std::{convert::Infallible, sync::Arc, time::Duration};
use axum::{
body::HttpBody,
extract::Extension,
extract::FromRef,
response::{Html, IntoResponse},
routing::{get, on, post, MethodFilter},
Router,
@@ -37,9 +37,10 @@ use mas_policy::PolicyFactory;
use mas_router::{Route, UrlBuilder};
use mas_templates::{ErrorContext, Templates};
use sqlx::PgPool;
use tower::util::ThenLayer;
use tower::util::AndThenLayer;
use tower_http::cors::{Any, CorsLayer};
mod app_state;
mod compat;
mod health;
mod oauth2;
@@ -47,30 +48,24 @@ mod views;
pub use compat::MatrixHomeserver;
pub use self::app_state::AppState;
#[must_use]
#[allow(
clippy::too_many_lines,
clippy::missing_panics_doc,
clippy::too_many_arguments,
clippy::trait_duplication_in_bounds
)]
pub fn router<B>(
pool: &PgPool,
templates: &Templates,
key_store: &Keystore,
encrypter: &Encrypter,
mailer: &Mailer,
url_builder: &UrlBuilder,
homeserver: &MatrixHomeserver,
policy_factory: &Arc<PolicyFactory>,
) -> Router<B>
#[allow(clippy::trait_duplication_in_bounds)]
pub fn api_router<S, B>(state: Arc<S>) -> Router<S, B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Send + Sync + 'static,
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
PgPool: FromRef<S>,
Encrypter: FromRef<S>,
{
// All those routes are API-like, with a common CORS layer
let api_router = Router::new()
Router::with_state_arc(state)
.route(
mas_router::ChangePasswordDiscovery::route(),
get(|| async { mas_router::AccountPassword.go() }),
@@ -118,9 +113,21 @@ where
CONTENT_TYPE,
])
.max_age(Duration::from_secs(60 * 60)),
);
let compat_router = Router::new()
)
}
#[must_use]
#[allow(clippy::trait_duplication_in_bounds)]
pub fn compat_router<S, B>(state: Arc<S>) -> Router<S, B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Send + Sync + 'static,
UrlBuilder: FromRef<S>,
PgPool: FromRef<S>,
MatrixHomeserver: FromRef<S>,
{
Router::with_state_arc(state)
.route(
mas_router::CompatLogin::route(),
get(self::compat::login::get).post(self::compat::login::post),
@@ -146,106 +153,131 @@ where
HeaderName::from_static("x-requested-with"),
])
.max_age(Duration::from_secs(60 * 60)),
);
)
}
let human_router = {
let templates = templates.clone();
Router::new()
.route(mas_router::Index::route(), get(self::views::index::get))
.route(mas_router::Healthcheck::route(), get(self::health::get))
.route(
mas_router::Login::route(),
get(self::views::login::get).post(self::views::login::post),
)
.route(mas_router::Logout::route(), post(self::views::logout::post))
.route(
mas_router::Reauth::route(),
get(self::views::reauth::get).post(self::views::reauth::post),
)
.route(
mas_router::Register::route(),
get(self::views::register::get).post(self::views::register::post),
)
.route(mas_router::Account::route(), get(self::views::account::get))
.route(
mas_router::AccountPassword::route(),
get(self::views::account::password::get).post(self::views::account::password::post),
)
.route(
mas_router::AccountEmails::route(),
get(self::views::account::emails::get).post(self::views::account::emails::post),
)
.route(
mas_router::AccountVerifyEmail::route(),
get(self::views::account::emails::verify::get)
.post(self::views::account::emails::verify::post),
)
.route(
mas_router::AccountAddEmail::route(),
get(self::views::account::emails::add::get)
.post(self::views::account::emails::add::post),
)
.route(
mas_router::OAuth2AuthorizationEndpoint::route(),
get(self::oauth2::authorization::get),
)
.route(
mas_router::ContinueAuthorizationGrant::route(),
get(self::oauth2::authorization::complete::get),
)
.route(
mas_router::Consent::route(),
get(self::oauth2::consent::get).post(self::oauth2::consent::post),
)
.route(
mas_router::CompatLoginSsoRedirect::route(),
get(self::compat::login_sso_redirect::get),
)
.route(
mas_router::CompatLoginSsoRedirectIdp::route(),
get(self::compat::login_sso_redirect::get),
)
.route(
mas_router::CompatLoginSsoComplete::route(),
get(self::compat::login_sso_complete::get)
.post(self::compat::login_sso_complete::post),
)
.layer(ThenLayer::new(
move |result: Result<axum::response::Response, Infallible>| async move {
let response = result.unwrap();
if response.status().is_server_error() {
// Error responses should have an ErrorContext attached to them
let ext = response.extensions().get::<ErrorContext>();
if let Some(ctx) = ext {
if let Ok(res) = templates.render_error(ctx).await {
let (mut parts, _original_body) = response.into_parts();
parts.headers.remove(CONTENT_TYPE);
return Ok((parts, Html(res)).into_response());
}
#[must_use]
#[allow(clippy::trait_duplication_in_bounds)]
pub fn human_router<S, B>(state: Arc<S>) -> Router<S, B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Send + Sync + 'static,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
PgPool: FromRef<S>,
Encrypter: FromRef<S>,
Templates: FromRef<S>,
Mailer: FromRef<S>,
{
let templates = Templates::from_ref(&state);
Router::with_state_arc(state)
.route(mas_router::Index::route(), get(self::views::index::get))
.route(mas_router::Healthcheck::route(), get(self::health::get))
.route(
mas_router::Login::route(),
get(self::views::login::get).post(self::views::login::post),
)
.route(mas_router::Logout::route(), post(self::views::logout::post))
.route(
mas_router::Reauth::route(),
get(self::views::reauth::get).post(self::views::reauth::post),
)
.route(
mas_router::Register::route(),
get(self::views::register::get).post(self::views::register::post),
)
.route(mas_router::Account::route(), get(self::views::account::get))
.route(
mas_router::AccountPassword::route(),
get(self::views::account::password::get).post(self::views::account::password::post),
)
.route(
mas_router::AccountEmails::route(),
get(self::views::account::emails::get).post(self::views::account::emails::post),
)
.route(
mas_router::AccountVerifyEmail::route(),
get(self::views::account::emails::verify::get)
.post(self::views::account::emails::verify::post),
)
.route(
mas_router::AccountAddEmail::route(),
get(self::views::account::emails::add::get)
.post(self::views::account::emails::add::post),
)
.route(
mas_router::OAuth2AuthorizationEndpoint::route(),
get(self::oauth2::authorization::get),
)
.route(
mas_router::ContinueAuthorizationGrant::route(),
get(self::oauth2::authorization::complete::get),
)
.route(
mas_router::Consent::route(),
get(self::oauth2::consent::get).post(self::oauth2::consent::post),
)
.route(
mas_router::CompatLoginSsoRedirect::route(),
get(self::compat::login_sso_redirect::get),
)
.route(
mas_router::CompatLoginSsoRedirectIdp::route(),
get(self::compat::login_sso_redirect::get),
)
.route(
mas_router::CompatLoginSsoComplete::route(),
get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
)
.layer(AndThenLayer::new(
move |response: axum::response::Response| async move {
if response.status().is_server_error() {
// Error responses should have an ErrorContext attached to them
let ext = response.extensions().get::<ErrorContext>();
if let Some(ctx) = ext {
if let Ok(res) = templates.render_error(ctx).await {
let (mut parts, _original_body) = response.into_parts();
parts.headers.remove(CONTENT_TYPE);
return Ok((parts, Html(res)).into_response());
}
}
}
Ok(response)
},
))
};
Ok::<_, Infallible>(response)
},
))
}
human_router
.merge(api_router)
.merge(compat_router)
.layer(Extension(pool.clone()))
.layer(Extension(templates.clone()))
.layer(Extension(key_store.clone()))
.layer(Extension(encrypter.clone()))
.layer(Extension(url_builder.clone()))
.layer(Extension(mailer.clone()))
.layer(Extension(homeserver.clone()))
.layer(Extension(policy_factory.clone()))
#[must_use]
#[allow(clippy::trait_duplication_in_bounds)]
pub fn router<S, B>(state: S) -> Router<S, B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Send + Sync + 'static,
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
PgPool: FromRef<S>,
Encrypter: FromRef<S>,
Templates: FromRef<S>,
Mailer: FromRef<S>,
MatrixHomeserver: FromRef<S>,
{
let state = Arc::new(state);
let api_router = api_router(state.clone());
let compat_router = compat_router(state.clone());
let human_router = human_router(state);
human_router.merge(api_router).merge(compat_router)
}
#[cfg(test)]
async fn test_router(pool: &PgPool) -> Result<Router, anyhow::Error> {
async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
use mas_email::MailTransport;
let templates = Templates::load(None, true).await?;
@@ -265,14 +297,14 @@ async fn test_router(pool: &PgPool) -> Result<Router, anyhow::Error> {
let policy_factory = PolicyFactory::load_default(serde_json::json!({})).await?;
let policy_factory = Arc::new(policy_factory);
Ok(router(
Ok(Arc::new(AppState {
pool,
&templates,
&key_store,
&encrypter,
&mailer,
&url_builder,
&homeserver,
&policy_factory,
))
templates,
key_store,
encrypter,
url_builder,
mailer,
homeserver,
policy_factory,
}))
}