// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #![deny(clippy::future_not_send)] #![allow( // Some axum handlers need that clippy::unused_async, // Because of how axum handlers work, we sometime have take many arguments clippy::too_many_arguments, // Code generated by tracing::instrument trigger this when returning an `impl Trait` // See https://github.com/tokio-rs/tracing/issues/2613 clippy::let_with_type_underscore, )] use std::{convert::Infallible, time::Duration}; use axum::{ body::{Bytes, HttpBody}, extract::{FromRef, FromRequestParts, OriginalUri, RawQuery, State}, http::Method, response::{Html, IntoResponse}, routing::{get, on, post, MethodFilter}, Router, }; use headers::HeaderName; use hyper::{ header::{ ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE, }, StatusCode, Version, }; use mas_axum_utils::{cookies::CookieJar, FancyError}; use mas_http::CorsLayerExt; use mas_keystore::{Encrypter, Keystore}; use mas_matrix::BoxHomeserverConnection; use mas_policy::Policy; use mas_router::{Route, UrlBuilder}; use mas_storage::{BoxClock, BoxRepository, BoxRng}; use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates}; use passwords::PasswordManager; use sqlx::PgPool; use tower::util::AndThenLayer; use tower_http::cors::{Any, CorsLayer}; mod compat; mod graphql; mod health; mod oauth2; pub mod passwords; pub mod upstream_oauth2; mod views; mod activity_tracker; mod preferred_language; mod site_config; #[cfg(test)] mod test_utils; /// Implement `From` for `RouteError`, for "internal server error" kind of /// errors. #[macro_export] macro_rules! impl_from_error_for_route { ($route_error:ty : $error:ty) => { impl From<$error> for $route_error { fn from(e: $error) -> Self { Self::Internal(Box::new(e)) } } }; ($error:ty) => { impl_from_error_for_route!(self::RouteError: $error); }; } pub use mas_axum_utils::{ cookies::CookieManager, http_client_factory::HttpClientFactory, ErrorWrapper, }; pub use self::{ activity_tracker::{ActivityTracker, Bound as BoundActivityTracker}, graphql::schema as graphql_schema, preferred_language::PreferredLanguage, site_config::SiteConfig, upstream_oauth2::cache::MetadataCache, }; pub fn healthcheck_router() -> Router where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, PgPool: FromRef, { Router::new().route(mas_router::Healthcheck::route(), get(self::health::get)) } pub fn graphql_router(playground: bool) -> Router where B: HttpBody + Send + 'static, ::Data: Into, ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, mas_graphql::Schema: FromRef, BoundActivityTracker: FromRequestParts, BoxRepository: FromRequestParts, BoxClock: FromRequestParts, Encrypter: FromRef, CookieJar: FromRequestParts, { let mut router = Router::new() .route( mas_router::GraphQL::route(), get(self::graphql::get).post(self::graphql::post), ) .layer( CorsLayer::new() .allow_origin(Any) .allow_methods(Any) .allow_otel_headers([ AUTHORIZATION, ACCEPT, ACCEPT_LANGUAGE, CONTENT_LANGUAGE, CONTENT_TYPE, ]), ); if playground { router = router.route( mas_router::GraphQLPlayground::route(), get(self::graphql::playground), ); } router } pub fn discovery_router() -> Router where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, Keystore: FromRef, UrlBuilder: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, { Router::new() .route( mas_router::OidcConfiguration::route(), get(self::oauth2::discovery::get), ) .route( mas_router::Webfinger::route(), get(self::oauth2::webfinger::get), ) .layer( CorsLayer::new() .allow_origin(Any) .allow_methods(Any) .allow_otel_headers([ AUTHORIZATION, ACCEPT, ACCEPT_LANGUAGE, CONTENT_LANGUAGE, CONTENT_TYPE, ]) .max_age(Duration::from_secs(60 * 60)), ) } pub fn api_router() -> Router where B: HttpBody + Send + 'static, ::Data: Send, ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, Keystore: FromRef, UrlBuilder: FromRef, BoxRepository: FromRequestParts, ActivityTracker: FromRequestParts, BoundActivityTracker: FromRequestParts, Encrypter: FromRef, HttpClientFactory: FromRef, SiteConfig: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, Policy: FromRequestParts, { // All those routes are API-like, with a common CORS layer Router::new() .route( mas_router::OAuth2Keys::route(), get(self::oauth2::keys::get), ) .route( mas_router::OidcUserinfo::route(), on( MethodFilter::POST | MethodFilter::GET, self::oauth2::userinfo::get, ), ) .route( mas_router::OAuth2Introspection::route(), post(self::oauth2::introspection::post), ) .route( mas_router::OAuth2Revocation::route(), post(self::oauth2::revoke::post), ) .route( mas_router::OAuth2TokenEndpoint::route(), post(self::oauth2::token::post), ) .route( mas_router::OAuth2RegistrationEndpoint::route(), post(self::oauth2::registration::post), ) .route( mas_router::OAuth2DeviceAuthorizationEndpoint::route(), post(self::oauth2::device::authorize::post), ) .layer( CorsLayer::new() .allow_origin(Any) .allow_methods(Any) .allow_otel_headers([ AUTHORIZATION, ACCEPT, ACCEPT_LANGUAGE, CONTENT_LANGUAGE, CONTENT_TYPE, ]) .max_age(Duration::from_secs(60 * 60)), ) } #[allow(clippy::trait_duplication_in_bounds)] pub fn compat_router() -> Router where B: HttpBody + Send + 'static, ::Data: Send, ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, SiteConfig: FromRef, BoxHomeserverConnection: FromRef, PasswordManager: FromRef, BoundActivityTracker: FromRequestParts, BoxRepository: FromRequestParts, BoxClock: FromRequestParts, BoxRng: FromRequestParts, { Router::new() .route( mas_router::CompatLogin::route(), get(self::compat::login::get).post(self::compat::login::post), ) .route( mas_router::CompatLogout::route(), post(self::compat::logout::post), ) .route( mas_router::CompatRefresh::route(), post(self::compat::refresh::post), ) .route( mas_router::CompatLoginSsoRedirect::route(), get(self::compat::login_sso_redirect::get), ) .route( mas_router::CompatLoginSsoRedirectIdp::route(), get(self::compat::login_sso_redirect::get), ) .route( mas_router::CompatLoginSsoRedirectSlash::route(), get(self::compat::login_sso_redirect::get), ) .layer( CorsLayer::new() .allow_origin(Any) .allow_methods(Any) .allow_otel_headers([ AUTHORIZATION, ACCEPT, ACCEPT_LANGUAGE, CONTENT_LANGUAGE, CONTENT_TYPE, HeaderName::from_static("x-requested-with"), ]) .max_age(Duration::from_secs(60 * 60)), ) } #[allow(clippy::too_many_lines)] pub fn human_router(templates: Templates) -> Router where B: HttpBody + Send + 'static, ::Data: Send, ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, PreferredLanguage: FromRequestParts, BoxRepository: FromRequestParts, CookieJar: FromRequestParts, BoundActivityTracker: FromRequestParts, Encrypter: FromRef, Templates: FromRef, Keystore: FromRef, HttpClientFactory: FromRef, PasswordManager: FromRef, MetadataCache: FromRef, SiteConfig: FromRef, BoxHomeserverConnection: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, Policy: FromRequestParts, { Router::new() // XXX: hard-coded redirect from /account to /account/ .route( "/account", get( |State(url_builder): State, RawQuery(query): RawQuery| async move { let prefix = url_builder.prefix().unwrap_or_default(); let route = mas_router::Account::route(); let destination = if let Some(query) = query { format!("{prefix}{route}?{query}") } else { format!("{prefix}{route}") }; axum::response::Redirect::to(&destination) }, ), ) .route(mas_router::Account::route(), get(self::views::app::get)) .route( mas_router::AccountWildcard::route(), get(self::views::app::get), ) .route( mas_router::ChangePasswordDiscovery::route(), get(|State(url_builder): State| async move { url_builder.redirect(&mas_router::AccountPassword) }), ) .route(mas_router::Index::route(), get(self::views::index::get)) .route( mas_router::Login::route(), get(self::views::login::get).post(self::views::login::post), ) .route(mas_router::Logout::route(), post(self::views::logout::post)) .route( mas_router::Reauth::route(), get(self::views::reauth::get).post(self::views::reauth::post), ) .route( mas_router::Register::route(), get(self::views::register::get).post(self::views::register::post), ) .route( mas_router::AccountPassword::route(), get(self::views::account::password::get).post(self::views::account::password::post), ) .route( mas_router::AccountVerifyEmail::route(), get(self::views::account::emails::verify::get) .post(self::views::account::emails::verify::post), ) .route( mas_router::AccountAddEmail::route(), get(self::views::account::emails::add::get) .post(self::views::account::emails::add::post), ) .route( mas_router::OAuth2AuthorizationEndpoint::route(), get(self::oauth2::authorization::get), ) .route( mas_router::ContinueAuthorizationGrant::route(), get(self::oauth2::authorization::complete::get), ) .route( mas_router::Consent::route(), get(self::oauth2::consent::get).post(self::oauth2::consent::post), ) .route( mas_router::CompatLoginSsoComplete::route(), get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post), ) .route( mas_router::UpstreamOAuth2Authorize::route(), get(self::upstream_oauth2::authorize::get), ) .route( mas_router::UpstreamOAuth2Callback::route(), get(self::upstream_oauth2::callback::get), ) .route( mas_router::UpstreamOAuth2Link::route(), get(self::upstream_oauth2::link::get).post(self::upstream_oauth2::link::post), ) .route( mas_router::DeviceCodeLink::route(), get(self::oauth2::device::link::get).post(self::oauth2::device::link::post), ) .route( mas_router::DeviceCodeConsent::route(), get(self::oauth2::device::consent::get).post(self::oauth2::device::consent::post), ) .layer(AndThenLayer::new( move |response: axum::response::Response| async move { if response.status().is_server_error() { // Error responses should have an ErrorContext attached to them let ext = response.extensions().get::(); if let Some(ctx) = ext { if let Ok(res) = templates.render_error(ctx) { let (mut parts, _original_body) = response.into_parts(); parts.headers.remove(CONTENT_TYPE); parts.headers.remove(CONTENT_LENGTH); return Ok((parts, Html(res)).into_response()); } } } Ok::<_, Infallible>(response) }, )) } /// The fallback handler for all routes that don't match anything else. /// /// # Errors /// /// Returns an error if the template rendering fails. pub async fn fallback( State(templates): State, OriginalUri(uri): OriginalUri, method: Method, version: Version, PreferredLanguage(locale): PreferredLanguage, ) -> Result { let ctx = NotFoundContext::new(&method, version, &uri).with_language(locale); // XXX: this should look at the Accept header and return JSON if requested let res = templates.render_not_found(&ctx)?; Ok((StatusCode::NOT_FOUND, Html(res))) }