1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Add CORS headers to API-like routes

This commit is contained in:
Quentin Gliech
2022-04-07 16:25:42 +02:00
parent b43817e66c
commit bc24e30867
8 changed files with 86 additions and 21 deletions

2
Cargo.lock generated
View File

@ -2128,6 +2128,7 @@ dependencies = [
"thiserror", "thiserror",
"tokio", "tokio",
"tower", "tower",
"tower-http",
"tracing", "tracing",
"url", "url",
] ]
@ -2143,6 +2144,7 @@ dependencies = [
"http-body", "http-body",
"hyper", "hyper",
"hyper-rustls 0.23.0", "hyper-rustls 0.23.0",
"once_cell",
"opentelemetry", "opentelemetry",
"opentelemetry-http", "opentelemetry-http",
"opentelemetry-semantic-conventions", "opentelemetry-semantic-conventions",

View File

@ -34,3 +34,5 @@ mas-storage = { path = "../storage" }
mas-data-model = { path = "../data-model" } mas-data-model = { path = "../data-model" }
mas-jose = { path = "../jose" } mas-jose = { path = "../jose" }
mas-iana = { path = "../iana" } mas-iana = { path = "../iana" }
[features]

View File

@ -40,7 +40,7 @@ pub fn setup(config: &TelemetryConfig) -> anyhow::Result<Option<Tracer>> {
// The CORS filter needs to know what headers it should whitelist for // The CORS filter needs to know what headers it should whitelist for
// CORS-protected requests. // CORS-protected requests.
// TODO mas_warp_utils::filters::cors::set_propagator(&propagator); mas_http::set_propagator(&propagator);
global::set_text_map_propagator(propagator); global::set_text_map_propagator(propagator);
let tracer = tracer(&config.tracing.exporter)?; let tracer = tracer(&config.tracing.exporter)?;

View File

@ -22,6 +22,7 @@ anyhow = "1.0.56"
# Web server # Web server
hyper = { version = "0.14.18", features = ["full"] } hyper = { version = "0.14.18", features = ["full"] }
tower = "0.4.12" tower = "0.4.12"
tower-http = { version = "0.2.5", features = ["cors"] }
axum = "0.5.1" axum = "0.5.1"
axum-macros = "0.2.0" axum-macros = "0.2.0"

View File

@ -19,7 +19,7 @@
clippy::unused_async // Some axum handlers need that clippy::unused_async // Some axum handlers need that
)] )]
use std::sync::Arc; use std::{sync::Arc, time::Duration};
use axum::{ use axum::{
body::HttpBody, body::HttpBody,
@ -27,12 +27,15 @@ use axum::{
routing::{get, on, post, MethodFilter}, routing::{get, on, post, MethodFilter},
Router, Router,
}; };
use hyper::header::AUTHORIZATION;
use mas_axum_utils::UrlBuilder; use mas_axum_utils::UrlBuilder;
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_email::Mailer; use mas_email::Mailer;
use mas_http::CorsLayerExt;
use mas_jose::StaticKeystore; use mas_jose::StaticKeystore;
use mas_templates::Templates; use mas_templates::Templates;
use sqlx::PgPool; use sqlx::PgPool;
use tower_http::cors::{Any, CorsLayer};
mod health; mod health;
mod oauth2; mod oauth2;
@ -52,6 +55,33 @@ where
<B as HttpBody>::Data: Send, <B as HttpBody>::Data: Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync, <B as HttpBody>::Error: std::error::Error + Send + Sync,
{ {
// All those routes are API-like, with a common CORS layer
let api_router = Router::new()
.route(
"/.well-known/openid-configuration",
get(self::oauth2::discovery::get),
)
.route("/oauth2/keys.json", get(self::oauth2::keys::get))
.route(
"/oauth2/userinfo",
on(
MethodFilter::POST | MethodFilter::GET,
self::oauth2::userinfo::get,
),
)
.route(
"/oauth2/introspect",
post(self::oauth2::introspection::post),
)
.route("/oauth2/token", post(self::oauth2::token::post))
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_otel_headers([AUTHORIZATION])
.max_age(Duration::from_secs(60 * 60)),
);
Router::new() Router::new()
.route("/", get(self::views::index::get)) .route("/", get(self::views::index::get))
.route("/health", get(self::health::get)) .route("/health", get(self::health::get))
@ -78,28 +108,12 @@ where
"/account/emails", "/account/emails",
get(self::views::account::emails::get).post(self::views::account::emails::post), get(self::views::account::emails::get).post(self::views::account::emails::post),
) )
.route(
"/.well-known/openid-configuration",
get(self::oauth2::discovery::get),
)
.route("/oauth2/keys.json", get(self::oauth2::keys::get))
.route(
"/oauth2/userinfo",
on(
MethodFilter::POST | MethodFilter::GET,
self::oauth2::userinfo::get,
),
)
.route(
"/oauth2/introspect",
post(self::oauth2::introspection::post),
)
.route("/oauth2/token", post(self::oauth2::token::post))
.route("/oauth2/authorize", get(self::oauth2::authorization::get)) .route("/oauth2/authorize", get(self::oauth2::authorization::get))
.route( .route(
"/oauth2/authorize/step", "/oauth2/authorize/step",
get(self::oauth2::authorization::step_get), get(self::oauth2::authorization::step_get),
) )
.merge(api_router)
.fallback(mas_static_files::Assets) .fallback(mas_static_files::Assets)
.layer(Extension(pool.clone())) .layer(Extension(pool.clone()))
.layer(Extension(templates.clone())) .layer(Extension(templates.clone()))

View File

@ -13,6 +13,7 @@ http = "0.2.6"
http-body = "0.4.4" http-body = "0.4.4"
hyper = "0.14.18" hyper = "0.14.18"
hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] } hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] }
once_cell = "1.10.0"
opentelemetry = "0.17.0" opentelemetry = "0.17.0"
opentelemetry-http = "0.6.0" opentelemetry-http = "0.6.0"
opentelemetry-semantic-conventions = "0.9.0" opentelemetry-semantic-conventions = "0.9.0"
@ -23,6 +24,6 @@ serde_json = "1.0.79"
thiserror = "1.0.30" thiserror = "1.0.30"
tokio = { version = "1.17.0", features = ["sync", "parking_lot"] } tokio = { version = "1.17.0", features = ["sync", "parking_lot"] }
tower = { version = "0.4.12", features = ["timeout", "limit"] } tower = { version = "0.4.12", features = ["timeout", "limit"] }
tower-http = { version = "0.2.5", features = ["follow-redirect", "decompression-full", "set-header", "compression-full"] } tower-http = { version = "0.2.5", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors"] }
tracing = "0.1.32" tracing = "0.1.32"
tracing-opentelemetry = "0.17.2" tracing-opentelemetry = "0.17.2"

View File

@ -12,8 +12,53 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use http::header::HeaderName;
use once_cell::sync::OnceCell;
use tower_http::cors::CorsLayer;
use crate::layers::json::Json; use crate::layers::json::Json;
static PROPAGATOR_HEADERS: OnceCell<Vec<HeaderName>> = OnceCell::new();
/// Notify the CORS layer what opentelemetry propagators are being used. This
/// helps whitelisting headers in CORS requests.
///
/// # Panics
///
/// When called twice
pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
let headers = propagator
.fields()
.map(|h| HeaderName::try_from(h).unwrap())
.collect();
tracing::debug!(
?headers,
"Headers allowed in CORS requests for trace propagators set"
);
PROPAGATOR_HEADERS
.set(headers)
.expect(concat!(module_path!(), "::set_propagator was called twice"));
}
pub trait CorsLayerExt {
#[must_use]
fn allow_otel_headers<H>(self, headers: H) -> Self
where
H: IntoIterator<Item = HeaderName>;
}
impl CorsLayerExt for CorsLayer {
fn allow_otel_headers<H>(self, headers: H) -> Self
where
H: IntoIterator<Item = HeaderName>,
{
let base = PROPAGATOR_HEADERS.get().cloned().unwrap_or_default();
let headers: Vec<_> = headers.into_iter().chain(base.into_iter()).collect();
self.allow_headers(headers)
}
}
pub trait ServiceExt: Sized { pub trait ServiceExt: Sized {
fn json<T>(self) -> Json<Self, T>; fn json<T>(self) -> Json<Self, T>;
} }

View File

@ -47,7 +47,7 @@ mod future_service;
mod layers; mod layers;
pub use self::{ pub use self::{
ext::ServiceExt as HttpServiceExt, ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
future_service::FutureService, future_service::FutureService,
layers::{client::ClientLayer, json::JsonResponseLayer, otel, server::ServerLayer}, layers::{client::ClientLayer, json::JsonResponseLayer, otel, server::ServerLayer},
}; };