diff --git a/Cargo.lock b/Cargo.lock index c1cb4313..68f30a4c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2128,6 +2128,7 @@ dependencies = [ "thiserror", "tokio", "tower", + "tower-http", "tracing", "url", ] @@ -2143,6 +2144,7 @@ dependencies = [ "http-body", "hyper", "hyper-rustls 0.23.0", + "once_cell", "opentelemetry", "opentelemetry-http", "opentelemetry-semantic-conventions", diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index 7d70da38..96aa5277 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -34,3 +34,5 @@ mas-storage = { path = "../storage" } mas-data-model = { path = "../data-model" } mas-jose = { path = "../jose" } mas-iana = { path = "../iana" } + +[features] diff --git a/crates/cli/src/telemetry.rs b/crates/cli/src/telemetry.rs index fe67b3b9..77a7baa5 100644 --- a/crates/cli/src/telemetry.rs +++ b/crates/cli/src/telemetry.rs @@ -40,7 +40,7 @@ pub fn setup(config: &TelemetryConfig) -> anyhow::Result> { // The CORS filter needs to know what headers it should whitelist for // CORS-protected requests. - // TODO mas_warp_utils::filters::cors::set_propagator(&propagator); + mas_http::set_propagator(&propagator); global::set_text_map_propagator(propagator); let tracer = tracer(&config.tracing.exporter)?; diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index 7eaca346..bd49d5b2 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -22,6 +22,7 @@ anyhow = "1.0.56" # Web server hyper = { version = "0.14.18", features = ["full"] } tower = "0.4.12" +tower-http = { version = "0.2.5", features = ["cors"] } axum = "0.5.1" axum-macros = "0.2.0" diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 8eada93d..f75006a0 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -19,7 +19,7 @@ clippy::unused_async // Some axum handlers need that )] -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use axum::{ body::HttpBody, @@ -27,12 +27,15 @@ use axum::{ routing::{get, on, post, MethodFilter}, Router, }; +use hyper::header::AUTHORIZATION; use mas_axum_utils::UrlBuilder; use mas_config::Encrypter; use mas_email::Mailer; +use mas_http::CorsLayerExt; use mas_jose::StaticKeystore; use mas_templates::Templates; use sqlx::PgPool; +use tower_http::cors::{Any, CorsLayer}; mod health; mod oauth2; @@ -52,6 +55,33 @@ where ::Data: Send, ::Error: std::error::Error + Send + Sync, { + // All those routes are API-like, with a common CORS layer + let api_router = Router::new() + .route( + "/.well-known/openid-configuration", + get(self::oauth2::discovery::get), + ) + .route("/oauth2/keys.json", get(self::oauth2::keys::get)) + .route( + "/oauth2/userinfo", + on( + MethodFilter::POST | MethodFilter::GET, + self::oauth2::userinfo::get, + ), + ) + .route( + "/oauth2/introspect", + post(self::oauth2::introspection::post), + ) + .route("/oauth2/token", post(self::oauth2::token::post)) + .layer( + CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_otel_headers([AUTHORIZATION]) + .max_age(Duration::from_secs(60 * 60)), + ); + Router::new() .route("/", get(self::views::index::get)) .route("/health", get(self::health::get)) @@ -78,28 +108,12 @@ where "/account/emails", get(self::views::account::emails::get).post(self::views::account::emails::post), ) - .route( - "/.well-known/openid-configuration", - get(self::oauth2::discovery::get), - ) - .route("/oauth2/keys.json", get(self::oauth2::keys::get)) - .route( - "/oauth2/userinfo", - on( - MethodFilter::POST | MethodFilter::GET, - self::oauth2::userinfo::get, - ), - ) - .route( - "/oauth2/introspect", - post(self::oauth2::introspection::post), - ) - .route("/oauth2/token", post(self::oauth2::token::post)) .route("/oauth2/authorize", get(self::oauth2::authorization::get)) .route( "/oauth2/authorize/step", get(self::oauth2::authorization::step_get), ) + .merge(api_router) .fallback(mas_static_files::Assets) .layer(Extension(pool.clone())) .layer(Extension(templates.clone())) diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index 2d832f1f..77f5039a 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -13,6 +13,7 @@ http = "0.2.6" http-body = "0.4.4" hyper = "0.14.18" hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] } +once_cell = "1.10.0" opentelemetry = "0.17.0" opentelemetry-http = "0.6.0" opentelemetry-semantic-conventions = "0.9.0" @@ -23,6 +24,6 @@ serde_json = "1.0.79" thiserror = "1.0.30" tokio = { version = "1.17.0", features = ["sync", "parking_lot"] } tower = { version = "0.4.12", features = ["timeout", "limit"] } -tower-http = { version = "0.2.5", features = ["follow-redirect", "decompression-full", "set-header", "compression-full"] } +tower-http = { version = "0.2.5", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors"] } tracing = "0.1.32" tracing-opentelemetry = "0.17.2" diff --git a/crates/http/src/ext.rs b/crates/http/src/ext.rs index d0efe4a9..cd155d2b 100644 --- a/crates/http/src/ext.rs +++ b/crates/http/src/ext.rs @@ -12,8 +12,53 @@ // See the License for the specific language governing permissions and // limitations under the License. +use http::header::HeaderName; +use once_cell::sync::OnceCell; +use tower_http::cors::CorsLayer; + use crate::layers::json::Json; +static PROPAGATOR_HEADERS: OnceCell> = OnceCell::new(); + +/// Notify the CORS layer what opentelemetry propagators are being used. This +/// helps whitelisting headers in CORS requests. +/// +/// # Panics +/// +/// When called twice +pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) { + let headers = propagator + .fields() + .map(|h| HeaderName::try_from(h).unwrap()) + .collect(); + + tracing::debug!( + ?headers, + "Headers allowed in CORS requests for trace propagators set" + ); + PROPAGATOR_HEADERS + .set(headers) + .expect(concat!(module_path!(), "::set_propagator was called twice")); +} + +pub trait CorsLayerExt { + #[must_use] + fn allow_otel_headers(self, headers: H) -> Self + where + H: IntoIterator; +} + +impl CorsLayerExt for CorsLayer { + fn allow_otel_headers(self, headers: H) -> Self + where + H: IntoIterator, + { + let base = PROPAGATOR_HEADERS.get().cloned().unwrap_or_default(); + let headers: Vec<_> = headers.into_iter().chain(base.into_iter()).collect(); + self.allow_headers(headers) + } +} + pub trait ServiceExt: Sized { fn json(self) -> Json; } diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index f39a4a9e..6a9a9cb0 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -47,7 +47,7 @@ mod future_service; mod layers; pub use self::{ - ext::ServiceExt as HttpServiceExt, + ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt}, future_service::FutureService, layers::{client::ClientLayer, json::JsonResponseLayer, otel, server::ServerLayer}, };