1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-20 12:02:22 +03:00

Simplify error handling in user-facing routes

This commit is contained in:
Quentin Gliech
2022-05-10 16:51:12 +02:00
parent 2cba5e7ad2
commit ca7b26cf18
17 changed files with 192 additions and 383 deletions

View File

@@ -19,11 +19,12 @@
clippy::unused_async // Some axum handlers need that
)]
use std::{sync::Arc, time::Duration};
use std::{convert::Infallible, sync::Arc, time::Duration};
use axum::{
body::HttpBody,
extract::Extension,
response::{Html, IntoResponse},
routing::{get, on, post, MethodFilter},
Router,
};
@@ -33,8 +34,9 @@ use mas_email::Mailer;
use mas_http::CorsLayerExt;
use mas_jose::StaticKeystore;
use mas_router::{Route, UrlBuilder};
use mas_templates::Templates;
use mas_templates::{ErrorContext, Templates};
use sqlx::PgPool;
use tower::util::ThenLayer;
use tower_http::cors::{Any, CorsLayer};
mod health;
@@ -42,6 +44,7 @@ mod oauth2;
mod views;
#[must_use]
#[allow(clippy::too_many_lines, clippy::missing_panics_doc)]
pub fn router<B>(
pool: &PgPool,
templates: &Templates,
@@ -102,47 +105,71 @@ where
.max_age(Duration::from_secs(60 * 60)),
);
Router::new()
.route(mas_router::Index::route(), get(self::views::index::get))
.route(mas_router::Healthcheck::route(), get(self::health::get))
.route(
mas_router::Login::route(),
get(self::views::login::get).post(self::views::login::post),
)
.route(mas_router::Logout::route(), post(self::views::logout::post))
.route(
mas_router::Reauth::route(),
get(self::views::reauth::get).post(self::views::reauth::post),
)
.route(
mas_router::Register::route(),
get(self::views::register::get).post(self::views::register::post),
)
.route(
mas_router::VerifyEmail::route(),
get(self::views::verify::get),
)
.route(mas_router::Account::route(), get(self::views::account::get))
.route(
mas_router::AccountPassword::route(),
get(self::views::account::password::get).post(self::views::account::password::post),
)
.route(
mas_router::AccountEmails::route(),
get(self::views::account::emails::get).post(self::views::account::emails::post),
)
.route(
mas_router::OAuth2AuthorizationEndpoint::route(),
get(self::oauth2::authorization::get),
)
.route(
mas_router::ContinueAuthorizationGrant::route(),
get(self::oauth2::authorization::complete::get),
)
.route(
mas_router::Consent::route(),
get(self::oauth2::consent::get).post(self::oauth2::consent::post),
)
let human_router = {
let templates = templates.clone();
Router::new()
.route(mas_router::Index::route(), get(self::views::index::get))
.route(mas_router::Healthcheck::route(), get(self::health::get))
.route(
mas_router::Login::route(),
get(self::views::login::get).post(self::views::login::post),
)
.route(mas_router::Logout::route(), post(self::views::logout::post))
.route(
mas_router::Reauth::route(),
get(self::views::reauth::get).post(self::views::reauth::post),
)
.route(
mas_router::Register::route(),
get(self::views::register::get).post(self::views::register::post),
)
.route(
mas_router::VerifyEmail::route(),
get(self::views::verify::get),
)
.route(mas_router::Account::route(), get(self::views::account::get))
.route(
mas_router::AccountPassword::route(),
get(self::views::account::password::get).post(self::views::account::password::post),
)
.route(
mas_router::AccountEmails::route(),
get(self::views::account::emails::get).post(self::views::account::emails::post),
)
.route(
mas_router::OAuth2AuthorizationEndpoint::route(),
get(self::oauth2::authorization::get),
)
.route(
mas_router::ContinueAuthorizationGrant::route(),
get(self::oauth2::authorization::complete::get),
)
.route(
mas_router::Consent::route(),
get(self::oauth2::consent::get).post(self::oauth2::consent::post),
)
.layer(ThenLayer::new(
move |result: Result<axum::response::Response, Infallible>| async move {
let response = result.unwrap();
if response.status().is_server_error() {
// Error responses should have an ErrorContext attached to them
let ext = response.extensions().get::<ErrorContext>();
if let Some(ctx) = ext {
if let Ok(res) = templates.render_error(ctx).await {
let (mut parts, _original_body) = response.into_parts();
parts.headers.remove(CONTENT_TYPE);
return Ok((parts, Html(res)).into_response());
}
}
}
Ok(response)
},
))
};
human_router
.merge(api_router)
.layer(Extension(pool.clone()))
.layer(Extension(templates.clone()))