You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
460 lines
15 KiB
Rust
460 lines
15 KiB
Rust
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#![deny(clippy::future_not_send)]
|
|
#![allow(
|
|
// Some axum handlers need that
|
|
clippy::unused_async,
|
|
// Because of how axum handlers work, we sometimes have to take many arguments
|
|
clippy::too_many_arguments,
|
|
// Code generated by tracing::instrument triggers this when returning an `impl Trait`
|
|
// See https://github.com/tokio-rs/tracing/issues/2613
|
|
clippy::let_with_type_underscore,
|
|
)]
|
|
|
|
use std::{convert::Infallible, time::Duration};
|
|
|
|
use axum::{
|
|
body::{Bytes, HttpBody},
|
|
extract::{FromRef, FromRequestParts, OriginalUri, RawQuery, State},
|
|
http::Method,
|
|
response::{Html, IntoResponse},
|
|
routing::{get, on, post, MethodFilter},
|
|
Router,
|
|
};
|
|
use headers::HeaderName;
|
|
use hyper::{
|
|
header::{
|
|
ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE,
|
|
},
|
|
StatusCode, Version,
|
|
};
|
|
use mas_axum_utils::{cookies::CookieJar, FancyError};
|
|
use mas_http::CorsLayerExt;
|
|
use mas_keystore::{Encrypter, Keystore};
|
|
use mas_matrix::BoxHomeserverConnection;
|
|
use mas_policy::Policy;
|
|
use mas_router::{Route, UrlBuilder};
|
|
use mas_storage::{BoxClock, BoxRepository, BoxRng};
|
|
use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates};
|
|
use passwords::PasswordManager;
|
|
use sqlx::PgPool;
|
|
use tower::util::AndThenLayer;
|
|
use tower_http::cors::{Any, CorsLayer};
|
|
|
|
mod compat;
|
|
mod graphql;
|
|
mod health;
|
|
mod oauth2;
|
|
pub mod passwords;
|
|
pub mod upstream_oauth2;
|
|
mod views;
|
|
|
|
mod activity_tracker;
|
|
mod preferred_language;
|
|
mod site_config;
|
|
#[cfg(test)]
|
|
mod test_utils;
|
|
|
|
/// Implement `From<E>` for `RouteError`, for "internal server error" kind of
/// errors.
///
/// The generated impl boxes the source error into the `Internal` variant of
/// the target route error type. That variant must accept a `Box` built from
/// the source error.
///
/// Two forms are accepted:
///
/// - `impl_from_error_for_route!(MyRouteError: SomeError)` implements
///   `From<SomeError> for MyRouteError`;
/// - `impl_from_error_for_route!(SomeError)` is shorthand for the above,
///   targeting the `RouteError` type of the invoking module
///   (`self::RouteError`).
#[macro_export]
macro_rules! impl_from_error_for_route {
    ($route_error:ty : $error:ty) => {
        impl From<$error> for $route_error {
            fn from(e: $error) -> Self {
                Self::Internal(Box::new(e))
            }
        }
    };
    ($error:ty) => {
        // Recurse through `$crate` so this resolves even when the caller
        // invoked the macro by path, without importing it by name
        $crate::impl_from_error_for_route!(self::RouteError: $error);
    };
}
|
|
|
|
pub use mas_axum_utils::{
|
|
cookies::CookieManager, http_client_factory::HttpClientFactory, ErrorWrapper,
|
|
};
|
|
|
|
pub use self::{
|
|
activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
|
|
graphql::schema as graphql_schema,
|
|
preferred_language::PreferredLanguage,
|
|
site_config::SiteConfig,
|
|
upstream_oauth2::cache::MetadataCache,
|
|
};
|
|
|
|
pub fn healthcheck_router<S, B>() -> Router<S, B>
|
|
where
|
|
B: HttpBody + Send + 'static,
|
|
S: Clone + Send + Sync + 'static,
|
|
PgPool: FromRef<S>,
|
|
{
|
|
Router::new().route(mas_router::Healthcheck::route(), get(self::health::get))
|
|
}
|
|
|
|
pub fn graphql_router<S, B>(playground: bool) -> Router<S, B>
|
|
where
|
|
B: HttpBody + Send + 'static,
|
|
<B as HttpBody>::Data: Into<Bytes>,
|
|
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
|
S: Clone + Send + Sync + 'static,
|
|
mas_graphql::Schema: FromRef<S>,
|
|
BoundActivityTracker: FromRequestParts<S>,
|
|
BoxRepository: FromRequestParts<S>,
|
|
BoxClock: FromRequestParts<S>,
|
|
Encrypter: FromRef<S>,
|
|
CookieJar: FromRequestParts<S>,
|
|
{
|
|
let mut router = Router::new()
|
|
.route(
|
|
mas_router::GraphQL::route(),
|
|
get(self::graphql::get).post(self::graphql::post),
|
|
)
|
|
.layer(
|
|
CorsLayer::new()
|
|
.allow_origin(Any)
|
|
.allow_methods(Any)
|
|
.allow_otel_headers([
|
|
AUTHORIZATION,
|
|
ACCEPT,
|
|
ACCEPT_LANGUAGE,
|
|
CONTENT_LANGUAGE,
|
|
CONTENT_TYPE,
|
|
]),
|
|
);
|
|
|
|
if playground {
|
|
router = router.route(
|
|
mas_router::GraphQLPlayground::route(),
|
|
get(self::graphql::playground),
|
|
);
|
|
}
|
|
|
|
router
|
|
}
|
|
|
|
pub fn discovery_router<S, B>() -> Router<S, B>
|
|
where
|
|
B: HttpBody + Send + 'static,
|
|
S: Clone + Send + Sync + 'static,
|
|
Keystore: FromRef<S>,
|
|
UrlBuilder: FromRef<S>,
|
|
BoxClock: FromRequestParts<S>,
|
|
BoxRng: FromRequestParts<S>,
|
|
{
|
|
Router::new()
|
|
.route(
|
|
mas_router::OidcConfiguration::route(),
|
|
get(self::oauth2::discovery::get),
|
|
)
|
|
.route(
|
|
mas_router::Webfinger::route(),
|
|
get(self::oauth2::webfinger::get),
|
|
)
|
|
.layer(
|
|
CorsLayer::new()
|
|
.allow_origin(Any)
|
|
.allow_methods(Any)
|
|
.allow_otel_headers([
|
|
AUTHORIZATION,
|
|
ACCEPT,
|
|
ACCEPT_LANGUAGE,
|
|
CONTENT_LANGUAGE,
|
|
CONTENT_TYPE,
|
|
])
|
|
.max_age(Duration::from_secs(60 * 60)),
|
|
)
|
|
}
|
|
|
|
pub fn api_router<S, B>() -> Router<S, B>
|
|
where
|
|
B: HttpBody + Send + 'static,
|
|
<B as HttpBody>::Data: Send,
|
|
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
|
S: Clone + Send + Sync + 'static,
|
|
Keystore: FromRef<S>,
|
|
UrlBuilder: FromRef<S>,
|
|
BoxRepository: FromRequestParts<S>,
|
|
ActivityTracker: FromRequestParts<S>,
|
|
BoundActivityTracker: FromRequestParts<S>,
|
|
Encrypter: FromRef<S>,
|
|
HttpClientFactory: FromRef<S>,
|
|
SiteConfig: FromRef<S>,
|
|
BoxClock: FromRequestParts<S>,
|
|
BoxRng: FromRequestParts<S>,
|
|
Policy: FromRequestParts<S>,
|
|
{
|
|
// All those routes are API-like, with a common CORS layer
|
|
Router::new()
|
|
.route(
|
|
mas_router::OAuth2Keys::route(),
|
|
get(self::oauth2::keys::get),
|
|
)
|
|
.route(
|
|
mas_router::OidcUserinfo::route(),
|
|
on(
|
|
MethodFilter::POST | MethodFilter::GET,
|
|
self::oauth2::userinfo::get,
|
|
),
|
|
)
|
|
.route(
|
|
mas_router::OAuth2Introspection::route(),
|
|
post(self::oauth2::introspection::post),
|
|
)
|
|
.route(
|
|
mas_router::OAuth2Revocation::route(),
|
|
post(self::oauth2::revoke::post),
|
|
)
|
|
.route(
|
|
mas_router::OAuth2TokenEndpoint::route(),
|
|
post(self::oauth2::token::post),
|
|
)
|
|
.route(
|
|
mas_router::OAuth2RegistrationEndpoint::route(),
|
|
post(self::oauth2::registration::post),
|
|
)
|
|
.route(
|
|
mas_router::OAuth2DeviceAuthorizationEndpoint::route(),
|
|
post(self::oauth2::device::authorize::post),
|
|
)
|
|
.layer(
|
|
CorsLayer::new()
|
|
.allow_origin(Any)
|
|
.allow_methods(Any)
|
|
.allow_otel_headers([
|
|
AUTHORIZATION,
|
|
ACCEPT,
|
|
ACCEPT_LANGUAGE,
|
|
CONTENT_LANGUAGE,
|
|
CONTENT_TYPE,
|
|
])
|
|
.max_age(Duration::from_secs(60 * 60)),
|
|
)
|
|
}
|
|
|
|
#[allow(clippy::trait_duplication_in_bounds)]
|
|
pub fn compat_router<S, B>() -> Router<S, B>
|
|
where
|
|
B: HttpBody + Send + 'static,
|
|
<B as HttpBody>::Data: Send,
|
|
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
|
S: Clone + Send + Sync + 'static,
|
|
UrlBuilder: FromRef<S>,
|
|
SiteConfig: FromRef<S>,
|
|
BoxHomeserverConnection: FromRef<S>,
|
|
PasswordManager: FromRef<S>,
|
|
BoundActivityTracker: FromRequestParts<S>,
|
|
BoxRepository: FromRequestParts<S>,
|
|
BoxClock: FromRequestParts<S>,
|
|
BoxRng: FromRequestParts<S>,
|
|
{
|
|
Router::new()
|
|
.route(
|
|
mas_router::CompatLogin::route(),
|
|
get(self::compat::login::get).post(self::compat::login::post),
|
|
)
|
|
.route(
|
|
mas_router::CompatLogout::route(),
|
|
post(self::compat::logout::post),
|
|
)
|
|
.route(
|
|
mas_router::CompatRefresh::route(),
|
|
post(self::compat::refresh::post),
|
|
)
|
|
.route(
|
|
mas_router::CompatLoginSsoRedirect::route(),
|
|
get(self::compat::login_sso_redirect::get),
|
|
)
|
|
.route(
|
|
mas_router::CompatLoginSsoRedirectIdp::route(),
|
|
get(self::compat::login_sso_redirect::get),
|
|
)
|
|
.route(
|
|
mas_router::CompatLoginSsoRedirectSlash::route(),
|
|
get(self::compat::login_sso_redirect::get),
|
|
)
|
|
.layer(
|
|
CorsLayer::new()
|
|
.allow_origin(Any)
|
|
.allow_methods(Any)
|
|
.allow_otel_headers([
|
|
AUTHORIZATION,
|
|
ACCEPT,
|
|
ACCEPT_LANGUAGE,
|
|
CONTENT_LANGUAGE,
|
|
CONTENT_TYPE,
|
|
HeaderName::from_static("x-requested-with"),
|
|
])
|
|
.max_age(Duration::from_secs(60 * 60)),
|
|
)
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines)]
|
|
pub fn human_router<S, B>(templates: Templates) -> Router<S, B>
|
|
where
|
|
B: HttpBody + Send + 'static,
|
|
<B as HttpBody>::Data: Send,
|
|
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
|
S: Clone + Send + Sync + 'static,
|
|
UrlBuilder: FromRef<S>,
|
|
PreferredLanguage: FromRequestParts<S>,
|
|
BoxRepository: FromRequestParts<S>,
|
|
CookieJar: FromRequestParts<S>,
|
|
BoundActivityTracker: FromRequestParts<S>,
|
|
Encrypter: FromRef<S>,
|
|
Templates: FromRef<S>,
|
|
Keystore: FromRef<S>,
|
|
HttpClientFactory: FromRef<S>,
|
|
PasswordManager: FromRef<S>,
|
|
MetadataCache: FromRef<S>,
|
|
SiteConfig: FromRef<S>,
|
|
BoxHomeserverConnection: FromRef<S>,
|
|
BoxClock: FromRequestParts<S>,
|
|
BoxRng: FromRequestParts<S>,
|
|
Policy: FromRequestParts<S>,
|
|
{
|
|
Router::new()
|
|
// XXX: hard-coded redirect from /account to /account/
|
|
.route(
|
|
"/account",
|
|
get(
|
|
|State(url_builder): State<UrlBuilder>, RawQuery(query): RawQuery| async move {
|
|
let prefix = url_builder.prefix().unwrap_or_default();
|
|
let route = mas_router::Account::route();
|
|
let destination = if let Some(query) = query {
|
|
format!("{prefix}{route}?{query}")
|
|
} else {
|
|
format!("{prefix}{route}")
|
|
};
|
|
|
|
axum::response::Redirect::to(&destination)
|
|
},
|
|
),
|
|
)
|
|
.route(mas_router::Account::route(), get(self::views::app::get))
|
|
.route(
|
|
mas_router::AccountWildcard::route(),
|
|
get(self::views::app::get),
|
|
)
|
|
.route(
|
|
mas_router::ChangePasswordDiscovery::route(),
|
|
get(|State(url_builder): State<UrlBuilder>| async move {
|
|
url_builder.redirect(&mas_router::AccountPassword)
|
|
}),
|
|
)
|
|
.route(mas_router::Index::route(), get(self::views::index::get))
|
|
.route(
|
|
mas_router::Login::route(),
|
|
get(self::views::login::get).post(self::views::login::post),
|
|
)
|
|
.route(mas_router::Logout::route(), post(self::views::logout::post))
|
|
.route(
|
|
mas_router::Reauth::route(),
|
|
get(self::views::reauth::get).post(self::views::reauth::post),
|
|
)
|
|
.route(
|
|
mas_router::Register::route(),
|
|
get(self::views::register::get).post(self::views::register::post),
|
|
)
|
|
.route(
|
|
mas_router::AccountPassword::route(),
|
|
get(self::views::account::password::get).post(self::views::account::password::post),
|
|
)
|
|
.route(
|
|
mas_router::AccountVerifyEmail::route(),
|
|
get(self::views::account::emails::verify::get)
|
|
.post(self::views::account::emails::verify::post),
|
|
)
|
|
.route(
|
|
mas_router::AccountAddEmail::route(),
|
|
get(self::views::account::emails::add::get)
|
|
.post(self::views::account::emails::add::post),
|
|
)
|
|
.route(
|
|
mas_router::OAuth2AuthorizationEndpoint::route(),
|
|
get(self::oauth2::authorization::get),
|
|
)
|
|
.route(
|
|
mas_router::ContinueAuthorizationGrant::route(),
|
|
get(self::oauth2::authorization::complete::get),
|
|
)
|
|
.route(
|
|
mas_router::Consent::route(),
|
|
get(self::oauth2::consent::get).post(self::oauth2::consent::post),
|
|
)
|
|
.route(
|
|
mas_router::CompatLoginSsoComplete::route(),
|
|
get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
|
|
)
|
|
.route(
|
|
mas_router::UpstreamOAuth2Authorize::route(),
|
|
get(self::upstream_oauth2::authorize::get),
|
|
)
|
|
.route(
|
|
mas_router::UpstreamOAuth2Callback::route(),
|
|
get(self::upstream_oauth2::callback::get),
|
|
)
|
|
.route(
|
|
mas_router::UpstreamOAuth2Link::route(),
|
|
get(self::upstream_oauth2::link::get).post(self::upstream_oauth2::link::post),
|
|
)
|
|
.route(
|
|
mas_router::DeviceCodeLink::route(),
|
|
get(self::oauth2::device::link::get).post(self::oauth2::device::link::post),
|
|
)
|
|
.route(
|
|
mas_router::DeviceCodeConsent::route(),
|
|
get(self::oauth2::device::consent::get).post(self::oauth2::device::consent::post),
|
|
)
|
|
.layer(AndThenLayer::new(
|
|
move |response: axum::response::Response| async move {
|
|
if response.status().is_server_error() {
|
|
// Error responses should have an ErrorContext attached to them
|
|
let ext = response.extensions().get::<ErrorContext>();
|
|
if let Some(ctx) = ext {
|
|
if let Ok(res) = templates.render_error(ctx) {
|
|
let (mut parts, _original_body) = response.into_parts();
|
|
parts.headers.remove(CONTENT_TYPE);
|
|
parts.headers.remove(CONTENT_LENGTH);
|
|
return Ok((parts, Html(res)).into_response());
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok::<_, Infallible>(response)
|
|
},
|
|
))
|
|
}
|
|
|
|
/// The fallback handler for all routes that don't match anything else.
|
|
///
|
|
/// # Errors
|
|
///
|
|
/// Returns an error if the template rendering fails.
|
|
pub async fn fallback(
|
|
State(templates): State<Templates>,
|
|
OriginalUri(uri): OriginalUri,
|
|
method: Method,
|
|
version: Version,
|
|
PreferredLanguage(locale): PreferredLanguage,
|
|
) -> Result<impl IntoResponse, FancyError> {
|
|
let ctx = NotFoundContext::new(&method, version, &uri).with_language(locale);
|
|
// XXX: this should look at the Accept header and return JSON if requested
|
|
|
|
let res = templates.render_not_found(&ctx)?;
|
|
|
|
Ok((StatusCode::NOT_FOUND, Html(res)))
|
|
}
|