You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
Upgrade axum to 0.6.0-rc.1
This commit is contained in:
@@ -23,7 +23,7 @@ use std::{convert::Infallible, sync::Arc, time::Duration};
|
||||
|
||||
use axum::{
|
||||
body::HttpBody,
|
||||
extract::Extension,
|
||||
extract::FromRef,
|
||||
response::{Html, IntoResponse},
|
||||
routing::{get, on, post, MethodFilter},
|
||||
Router,
|
||||
@@ -37,9 +37,10 @@ use mas_policy::PolicyFactory;
|
||||
use mas_router::{Route, UrlBuilder};
|
||||
use mas_templates::{ErrorContext, Templates};
|
||||
use sqlx::PgPool;
|
||||
use tower::util::ThenLayer;
|
||||
use tower::util::AndThenLayer;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
|
||||
mod app_state;
|
||||
mod compat;
|
||||
mod health;
|
||||
mod oauth2;
|
||||
@@ -47,30 +48,24 @@ mod views;
|
||||
|
||||
pub use compat::MatrixHomeserver;
|
||||
|
||||
pub use self::app_state::AppState;
|
||||
|
||||
#[must_use]
|
||||
#[allow(
|
||||
clippy::too_many_lines,
|
||||
clippy::missing_panics_doc,
|
||||
clippy::too_many_arguments,
|
||||
clippy::trait_duplication_in_bounds
|
||||
)]
|
||||
pub fn router<B>(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
key_store: &Keystore,
|
||||
encrypter: &Encrypter,
|
||||
mailer: &Mailer,
|
||||
url_builder: &UrlBuilder,
|
||||
homeserver: &MatrixHomeserver,
|
||||
policy_factory: &Arc<PolicyFactory>,
|
||||
) -> Router<B>
|
||||
#[allow(clippy::trait_duplication_in_bounds)]
|
||||
pub fn api_router<S, B>(state: Arc<S>) -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
<B as HttpBody>::Data: Send,
|
||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||
S: Send + Sync + 'static,
|
||||
Keystore: FromRef<S>,
|
||||
UrlBuilder: FromRef<S>,
|
||||
Arc<PolicyFactory>: FromRef<S>,
|
||||
PgPool: FromRef<S>,
|
||||
Encrypter: FromRef<S>,
|
||||
{
|
||||
// All those routes are API-like, with a common CORS layer
|
||||
let api_router = Router::new()
|
||||
Router::with_state_arc(state)
|
||||
.route(
|
||||
mas_router::ChangePasswordDiscovery::route(),
|
||||
get(|| async { mas_router::AccountPassword.go() }),
|
||||
@@ -118,9 +113,21 @@ where
|
||||
CONTENT_TYPE,
|
||||
])
|
||||
.max_age(Duration::from_secs(60 * 60)),
|
||||
);
|
||||
|
||||
let compat_router = Router::new()
|
||||
)
|
||||
}
|
||||
#[must_use]
|
||||
#[allow(clippy::trait_duplication_in_bounds)]
|
||||
pub fn compat_router<S, B>(state: Arc<S>) -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
<B as HttpBody>::Data: Send,
|
||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||
S: Send + Sync + 'static,
|
||||
UrlBuilder: FromRef<S>,
|
||||
PgPool: FromRef<S>,
|
||||
MatrixHomeserver: FromRef<S>,
|
||||
{
|
||||
Router::with_state_arc(state)
|
||||
.route(
|
||||
mas_router::CompatLogin::route(),
|
||||
get(self::compat::login::get).post(self::compat::login::post),
|
||||
@@ -146,106 +153,131 @@ where
|
||||
HeaderName::from_static("x-requested-with"),
|
||||
])
|
||||
.max_age(Duration::from_secs(60 * 60)),
|
||||
);
|
||||
)
|
||||
}
|
||||
|
||||
let human_router = {
|
||||
let templates = templates.clone();
|
||||
Router::new()
|
||||
.route(mas_router::Index::route(), get(self::views::index::get))
|
||||
.route(mas_router::Healthcheck::route(), get(self::health::get))
|
||||
.route(
|
||||
mas_router::Login::route(),
|
||||
get(self::views::login::get).post(self::views::login::post),
|
||||
)
|
||||
.route(mas_router::Logout::route(), post(self::views::logout::post))
|
||||
.route(
|
||||
mas_router::Reauth::route(),
|
||||
get(self::views::reauth::get).post(self::views::reauth::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::Register::route(),
|
||||
get(self::views::register::get).post(self::views::register::post),
|
||||
)
|
||||
.route(mas_router::Account::route(), get(self::views::account::get))
|
||||
.route(
|
||||
mas_router::AccountPassword::route(),
|
||||
get(self::views::account::password::get).post(self::views::account::password::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::AccountEmails::route(),
|
||||
get(self::views::account::emails::get).post(self::views::account::emails::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::AccountVerifyEmail::route(),
|
||||
get(self::views::account::emails::verify::get)
|
||||
.post(self::views::account::emails::verify::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::AccountAddEmail::route(),
|
||||
get(self::views::account::emails::add::get)
|
||||
.post(self::views::account::emails::add::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::OAuth2AuthorizationEndpoint::route(),
|
||||
get(self::oauth2::authorization::get),
|
||||
)
|
||||
.route(
|
||||
mas_router::ContinueAuthorizationGrant::route(),
|
||||
get(self::oauth2::authorization::complete::get),
|
||||
)
|
||||
.route(
|
||||
mas_router::Consent::route(),
|
||||
get(self::oauth2::consent::get).post(self::oauth2::consent::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::CompatLoginSsoRedirect::route(),
|
||||
get(self::compat::login_sso_redirect::get),
|
||||
)
|
||||
.route(
|
||||
mas_router::CompatLoginSsoRedirectIdp::route(),
|
||||
get(self::compat::login_sso_redirect::get),
|
||||
)
|
||||
.route(
|
||||
mas_router::CompatLoginSsoComplete::route(),
|
||||
get(self::compat::login_sso_complete::get)
|
||||
.post(self::compat::login_sso_complete::post),
|
||||
)
|
||||
.layer(ThenLayer::new(
|
||||
move |result: Result<axum::response::Response, Infallible>| async move {
|
||||
let response = result.unwrap();
|
||||
|
||||
if response.status().is_server_error() {
|
||||
// Error responses should have an ErrorContext attached to them
|
||||
let ext = response.extensions().get::<ErrorContext>();
|
||||
if let Some(ctx) = ext {
|
||||
if let Ok(res) = templates.render_error(ctx).await {
|
||||
let (mut parts, _original_body) = response.into_parts();
|
||||
parts.headers.remove(CONTENT_TYPE);
|
||||
return Ok((parts, Html(res)).into_response());
|
||||
}
|
||||
#[must_use]
|
||||
#[allow(clippy::trait_duplication_in_bounds)]
|
||||
pub fn human_router<S, B>(state: Arc<S>) -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
<B as HttpBody>::Data: Send,
|
||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||
S: Send + Sync + 'static,
|
||||
UrlBuilder: FromRef<S>,
|
||||
Arc<PolicyFactory>: FromRef<S>,
|
||||
PgPool: FromRef<S>,
|
||||
Encrypter: FromRef<S>,
|
||||
Templates: FromRef<S>,
|
||||
Mailer: FromRef<S>,
|
||||
{
|
||||
let templates = Templates::from_ref(&state);
|
||||
Router::with_state_arc(state)
|
||||
.route(mas_router::Index::route(), get(self::views::index::get))
|
||||
.route(mas_router::Healthcheck::route(), get(self::health::get))
|
||||
.route(
|
||||
mas_router::Login::route(),
|
||||
get(self::views::login::get).post(self::views::login::post),
|
||||
)
|
||||
.route(mas_router::Logout::route(), post(self::views::logout::post))
|
||||
.route(
|
||||
mas_router::Reauth::route(),
|
||||
get(self::views::reauth::get).post(self::views::reauth::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::Register::route(),
|
||||
get(self::views::register::get).post(self::views::register::post),
|
||||
)
|
||||
.route(mas_router::Account::route(), get(self::views::account::get))
|
||||
.route(
|
||||
mas_router::AccountPassword::route(),
|
||||
get(self::views::account::password::get).post(self::views::account::password::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::AccountEmails::route(),
|
||||
get(self::views::account::emails::get).post(self::views::account::emails::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::AccountVerifyEmail::route(),
|
||||
get(self::views::account::emails::verify::get)
|
||||
.post(self::views::account::emails::verify::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::AccountAddEmail::route(),
|
||||
get(self::views::account::emails::add::get)
|
||||
.post(self::views::account::emails::add::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::OAuth2AuthorizationEndpoint::route(),
|
||||
get(self::oauth2::authorization::get),
|
||||
)
|
||||
.route(
|
||||
mas_router::ContinueAuthorizationGrant::route(),
|
||||
get(self::oauth2::authorization::complete::get),
|
||||
)
|
||||
.route(
|
||||
mas_router::Consent::route(),
|
||||
get(self::oauth2::consent::get).post(self::oauth2::consent::post),
|
||||
)
|
||||
.route(
|
||||
mas_router::CompatLoginSsoRedirect::route(),
|
||||
get(self::compat::login_sso_redirect::get),
|
||||
)
|
||||
.route(
|
||||
mas_router::CompatLoginSsoRedirectIdp::route(),
|
||||
get(self::compat::login_sso_redirect::get),
|
||||
)
|
||||
.route(
|
||||
mas_router::CompatLoginSsoComplete::route(),
|
||||
get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
|
||||
)
|
||||
.layer(AndThenLayer::new(
|
||||
move |response: axum::response::Response| async move {
|
||||
if response.status().is_server_error() {
|
||||
// Error responses should have an ErrorContext attached to them
|
||||
let ext = response.extensions().get::<ErrorContext>();
|
||||
if let Some(ctx) = ext {
|
||||
if let Ok(res) = templates.render_error(ctx).await {
|
||||
let (mut parts, _original_body) = response.into_parts();
|
||||
parts.headers.remove(CONTENT_TYPE);
|
||||
return Ok((parts, Html(res)).into_response());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
},
|
||||
))
|
||||
};
|
||||
Ok::<_, Infallible>(response)
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
human_router
|
||||
.merge(api_router)
|
||||
.merge(compat_router)
|
||||
.layer(Extension(pool.clone()))
|
||||
.layer(Extension(templates.clone()))
|
||||
.layer(Extension(key_store.clone()))
|
||||
.layer(Extension(encrypter.clone()))
|
||||
.layer(Extension(url_builder.clone()))
|
||||
.layer(Extension(mailer.clone()))
|
||||
.layer(Extension(homeserver.clone()))
|
||||
.layer(Extension(policy_factory.clone()))
|
||||
#[must_use]
|
||||
#[allow(clippy::trait_duplication_in_bounds)]
|
||||
pub fn router<S, B>(state: S) -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
<B as HttpBody>::Data: Send,
|
||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||
S: Send + Sync + 'static,
|
||||
Keystore: FromRef<S>,
|
||||
UrlBuilder: FromRef<S>,
|
||||
Arc<PolicyFactory>: FromRef<S>,
|
||||
PgPool: FromRef<S>,
|
||||
Encrypter: FromRef<S>,
|
||||
Templates: FromRef<S>,
|
||||
Mailer: FromRef<S>,
|
||||
MatrixHomeserver: FromRef<S>,
|
||||
{
|
||||
let state = Arc::new(state);
|
||||
|
||||
let api_router = api_router(state.clone());
|
||||
let compat_router = compat_router(state.clone());
|
||||
let human_router = human_router(state);
|
||||
|
||||
human_router.merge(api_router).merge(compat_router)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn test_router(pool: &PgPool) -> Result<Router, anyhow::Error> {
|
||||
async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
|
||||
use mas_email::MailTransport;
|
||||
|
||||
let templates = Templates::load(None, true).await?;
|
||||
@@ -265,14 +297,14 @@ async fn test_router(pool: &PgPool) -> Result<Router, anyhow::Error> {
|
||||
let policy_factory = PolicyFactory::load_default(serde_json::json!({})).await?;
|
||||
let policy_factory = Arc::new(policy_factory);
|
||||
|
||||
Ok(router(
|
||||
Ok(Arc::new(AppState {
|
||||
pool,
|
||||
&templates,
|
||||
&key_store,
|
||||
&encrypter,
|
||||
&mailer,
|
||||
&url_builder,
|
||||
&homeserver,
|
||||
&policy_factory,
|
||||
))
|
||||
templates,
|
||||
key_store,
|
||||
encrypter,
|
||||
url_builder,
|
||||
mailer,
|
||||
homeserver,
|
||||
policy_factory,
|
||||
}))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user