You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Add CORS headers to API-like routes
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -2128,6 +2128,7 @@ dependencies = [
|
|||||||
"thiserror",
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower",
|
"tower",
|
||||||
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
"url",
|
"url",
|
||||||
]
|
]
|
||||||
@ -2143,6 +2144,7 @@ dependencies = [
|
|||||||
"http-body",
|
"http-body",
|
||||||
"hyper",
|
"hyper",
|
||||||
"hyper-rustls 0.23.0",
|
"hyper-rustls 0.23.0",
|
||||||
|
"once_cell",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"opentelemetry-http",
|
"opentelemetry-http",
|
||||||
"opentelemetry-semantic-conventions",
|
"opentelemetry-semantic-conventions",
|
||||||
|
@ -34,3 +34,5 @@ mas-storage = { path = "../storage" }
|
|||||||
mas-data-model = { path = "../data-model" }
|
mas-data-model = { path = "../data-model" }
|
||||||
mas-jose = { path = "../jose" }
|
mas-jose = { path = "../jose" }
|
||||||
mas-iana = { path = "../iana" }
|
mas-iana = { path = "../iana" }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
@ -40,7 +40,7 @@ pub fn setup(config: &TelemetryConfig) -> anyhow::Result<Option<Tracer>> {
|
|||||||
|
|
||||||
// The CORS filter needs to know what headers it should whitelist for
|
// The CORS filter needs to know what headers it should whitelist for
|
||||||
// CORS-protected requests.
|
// CORS-protected requests.
|
||||||
// TODO mas_warp_utils::filters::cors::set_propagator(&propagator);
|
mas_http::set_propagator(&propagator);
|
||||||
global::set_text_map_propagator(propagator);
|
global::set_text_map_propagator(propagator);
|
||||||
|
|
||||||
let tracer = tracer(&config.tracing.exporter)?;
|
let tracer = tracer(&config.tracing.exporter)?;
|
||||||
|
@ -22,6 +22,7 @@ anyhow = "1.0.56"
|
|||||||
# Web server
|
# Web server
|
||||||
hyper = { version = "0.14.18", features = ["full"] }
|
hyper = { version = "0.14.18", features = ["full"] }
|
||||||
tower = "0.4.12"
|
tower = "0.4.12"
|
||||||
|
tower-http = { version = "0.2.5", features = ["cors"] }
|
||||||
axum = "0.5.1"
|
axum = "0.5.1"
|
||||||
axum-macros = "0.2.0"
|
axum-macros = "0.2.0"
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
clippy::unused_async // Some axum handlers need that
|
clippy::unused_async // Some axum handlers need that
|
||||||
)]
|
)]
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::{sync::Arc, time::Duration};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
body::HttpBody,
|
body::HttpBody,
|
||||||
@ -27,12 +27,15 @@ use axum::{
|
|||||||
routing::{get, on, post, MethodFilter},
|
routing::{get, on, post, MethodFilter},
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
|
use hyper::header::AUTHORIZATION;
|
||||||
use mas_axum_utils::UrlBuilder;
|
use mas_axum_utils::UrlBuilder;
|
||||||
use mas_config::Encrypter;
|
use mas_config::Encrypter;
|
||||||
use mas_email::Mailer;
|
use mas_email::Mailer;
|
||||||
|
use mas_http::CorsLayerExt;
|
||||||
use mas_jose::StaticKeystore;
|
use mas_jose::StaticKeystore;
|
||||||
use mas_templates::Templates;
|
use mas_templates::Templates;
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
use tower_http::cors::{Any, CorsLayer};
|
||||||
|
|
||||||
mod health;
|
mod health;
|
||||||
mod oauth2;
|
mod oauth2;
|
||||||
@ -52,6 +55,33 @@ where
|
|||||||
<B as HttpBody>::Data: Send,
|
<B as HttpBody>::Data: Send,
|
||||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||||
{
|
{
|
||||||
|
// All those routes are API-like, with a common CORS layer
|
||||||
|
let api_router = Router::new()
|
||||||
|
.route(
|
||||||
|
"/.well-known/openid-configuration",
|
||||||
|
get(self::oauth2::discovery::get),
|
||||||
|
)
|
||||||
|
.route("/oauth2/keys.json", get(self::oauth2::keys::get))
|
||||||
|
.route(
|
||||||
|
"/oauth2/userinfo",
|
||||||
|
on(
|
||||||
|
MethodFilter::POST | MethodFilter::GET,
|
||||||
|
self::oauth2::userinfo::get,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/oauth2/introspect",
|
||||||
|
post(self::oauth2::introspection::post),
|
||||||
|
)
|
||||||
|
.route("/oauth2/token", post(self::oauth2::token::post))
|
||||||
|
.layer(
|
||||||
|
CorsLayer::new()
|
||||||
|
.allow_origin(Any)
|
||||||
|
.allow_methods(Any)
|
||||||
|
.allow_otel_headers([AUTHORIZATION])
|
||||||
|
.max_age(Duration::from_secs(60 * 60)),
|
||||||
|
);
|
||||||
|
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/", get(self::views::index::get))
|
.route("/", get(self::views::index::get))
|
||||||
.route("/health", get(self::health::get))
|
.route("/health", get(self::health::get))
|
||||||
@ -78,28 +108,12 @@ where
|
|||||||
"/account/emails",
|
"/account/emails",
|
||||||
get(self::views::account::emails::get).post(self::views::account::emails::post),
|
get(self::views::account::emails::get).post(self::views::account::emails::post),
|
||||||
)
|
)
|
||||||
.route(
|
|
||||||
"/.well-known/openid-configuration",
|
|
||||||
get(self::oauth2::discovery::get),
|
|
||||||
)
|
|
||||||
.route("/oauth2/keys.json", get(self::oauth2::keys::get))
|
|
||||||
.route(
|
|
||||||
"/oauth2/userinfo",
|
|
||||||
on(
|
|
||||||
MethodFilter::POST | MethodFilter::GET,
|
|
||||||
self::oauth2::userinfo::get,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/oauth2/introspect",
|
|
||||||
post(self::oauth2::introspection::post),
|
|
||||||
)
|
|
||||||
.route("/oauth2/token", post(self::oauth2::token::post))
|
|
||||||
.route("/oauth2/authorize", get(self::oauth2::authorization::get))
|
.route("/oauth2/authorize", get(self::oauth2::authorization::get))
|
||||||
.route(
|
.route(
|
||||||
"/oauth2/authorize/step",
|
"/oauth2/authorize/step",
|
||||||
get(self::oauth2::authorization::step_get),
|
get(self::oauth2::authorization::step_get),
|
||||||
)
|
)
|
||||||
|
.merge(api_router)
|
||||||
.fallback(mas_static_files::Assets)
|
.fallback(mas_static_files::Assets)
|
||||||
.layer(Extension(pool.clone()))
|
.layer(Extension(pool.clone()))
|
||||||
.layer(Extension(templates.clone()))
|
.layer(Extension(templates.clone()))
|
||||||
|
@ -13,6 +13,7 @@ http = "0.2.6"
|
|||||||
http-body = "0.4.4"
|
http-body = "0.4.4"
|
||||||
hyper = "0.14.18"
|
hyper = "0.14.18"
|
||||||
hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] }
|
hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] }
|
||||||
|
once_cell = "1.10.0"
|
||||||
opentelemetry = "0.17.0"
|
opentelemetry = "0.17.0"
|
||||||
opentelemetry-http = "0.6.0"
|
opentelemetry-http = "0.6.0"
|
||||||
opentelemetry-semantic-conventions = "0.9.0"
|
opentelemetry-semantic-conventions = "0.9.0"
|
||||||
@ -23,6 +24,6 @@ serde_json = "1.0.79"
|
|||||||
thiserror = "1.0.30"
|
thiserror = "1.0.30"
|
||||||
tokio = { version = "1.17.0", features = ["sync", "parking_lot"] }
|
tokio = { version = "1.17.0", features = ["sync", "parking_lot"] }
|
||||||
tower = { version = "0.4.12", features = ["timeout", "limit"] }
|
tower = { version = "0.4.12", features = ["timeout", "limit"] }
|
||||||
tower-http = { version = "0.2.5", features = ["follow-redirect", "decompression-full", "set-header", "compression-full"] }
|
tower-http = { version = "0.2.5", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors"] }
|
||||||
tracing = "0.1.32"
|
tracing = "0.1.32"
|
||||||
tracing-opentelemetry = "0.17.2"
|
tracing-opentelemetry = "0.17.2"
|
||||||
|
@ -12,8 +12,53 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use http::header::HeaderName;
|
||||||
|
use once_cell::sync::OnceCell;
|
||||||
|
use tower_http::cors::CorsLayer;
|
||||||
|
|
||||||
use crate::layers::json::Json;
|
use crate::layers::json::Json;
|
||||||
|
|
||||||
|
static PROPAGATOR_HEADERS: OnceCell<Vec<HeaderName>> = OnceCell::new();
|
||||||
|
|
||||||
|
/// Notify the CORS layer what opentelemetry propagators are being used. This
|
||||||
|
/// helps whitelisting headers in CORS requests.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// When called twice
|
||||||
|
pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
|
||||||
|
let headers = propagator
|
||||||
|
.fields()
|
||||||
|
.map(|h| HeaderName::try_from(h).unwrap())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
?headers,
|
||||||
|
"Headers allowed in CORS requests for trace propagators set"
|
||||||
|
);
|
||||||
|
PROPAGATOR_HEADERS
|
||||||
|
.set(headers)
|
||||||
|
.expect(concat!(module_path!(), "::set_propagator was called twice"));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CorsLayerExt {
|
||||||
|
#[must_use]
|
||||||
|
fn allow_otel_headers<H>(self, headers: H) -> Self
|
||||||
|
where
|
||||||
|
H: IntoIterator<Item = HeaderName>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CorsLayerExt for CorsLayer {
|
||||||
|
fn allow_otel_headers<H>(self, headers: H) -> Self
|
||||||
|
where
|
||||||
|
H: IntoIterator<Item = HeaderName>,
|
||||||
|
{
|
||||||
|
let base = PROPAGATOR_HEADERS.get().cloned().unwrap_or_default();
|
||||||
|
let headers: Vec<_> = headers.into_iter().chain(base.into_iter()).collect();
|
||||||
|
self.allow_headers(headers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait ServiceExt: Sized {
|
pub trait ServiceExt: Sized {
|
||||||
fn json<T>(self) -> Json<Self, T>;
|
fn json<T>(self) -> Json<Self, T>;
|
||||||
}
|
}
|
||||||
|
@ -47,7 +47,7 @@ mod future_service;
|
|||||||
mod layers;
|
mod layers;
|
||||||
|
|
||||||
pub use self::{
|
pub use self::{
|
||||||
ext::ServiceExt as HttpServiceExt,
|
ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
|
||||||
future_service::FutureService,
|
future_service::FutureService,
|
||||||
layers::{client::ClientLayer, json::JsonResponseLayer, otel, server::ServerLayer},
|
layers::{client::ClientLayer, json::JsonResponseLayer, otel, server::ServerLayer},
|
||||||
};
|
};
|
||||||
|
Reference in New Issue
Block a user