You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Better CORS filter to allow OTEL propagator headers
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -1464,6 +1464,8 @@ dependencies = [
|
||||
"mas-data-model",
|
||||
"mime",
|
||||
"oauth2-types",
|
||||
"once_cell",
|
||||
"opentelemetry",
|
||||
"password-hash",
|
||||
"pkcs8",
|
||||
"rand 0.8.4",
|
||||
|
@ -30,7 +30,12 @@ use opentelemetry_semantic_conventions as semcov;
|
||||
|
||||
pub fn setup(config: &TelemetryConfig) -> anyhow::Result<Option<Tracer>> {
|
||||
global::set_error_handler(|e| tracing::error!("{}", e))?;
|
||||
global::set_text_map_propagator(propagator());
|
||||
let propagator = propagator();
|
||||
|
||||
// The CORS filter needs to know what headers it should whitelist for
|
||||
// CORS-protected requests.
|
||||
mas_core::filters::cors::set_propagator(&propagator);
|
||||
global::set_text_map_propagator(propagator);
|
||||
|
||||
let tracer = tracer(&config.tracing)?;
|
||||
meter(&config.metrics)?;
|
||||
|
@ -67,6 +67,8 @@ cookie = "0.15.1"
|
||||
oauth2-types = { path = "../oauth2-types", features = ["sqlx_type"] }
|
||||
mas-config = { path = "../config" }
|
||||
mas-data-model = { path = "../data-model" }
|
||||
opentelemetry = "0.16.0"
|
||||
once_cell = "1.8.0"
|
||||
|
||||
[dependencies.jwt-compact]
|
||||
# Waiting on the next release because of the bump of the `rsa` dependency
|
||||
|
37
crates/core/src/filters/cors.rs
Normal file
37
crates/core/src/filters/cors.rs
Normal file
@ -0,0 +1,37 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Wrapper around [`warp::filters::cors`]
|
||||
|
||||
use std::string::ToString;
|
||||
|
||||
use once_cell::sync::OnceCell;
|
||||
|
||||
static PROPAGATOR_HEADERS: OnceCell<Vec<String>> = OnceCell::new();
|
||||
|
||||
/// Notify the CORS filter what opentelemetry propagators are being used. This
|
||||
/// helps whitelisting headers in CORS requests.
|
||||
pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
|
||||
PROPAGATOR_HEADERS
|
||||
.set(propagator.fields().map(ToString::to_string).collect())
|
||||
.expect(concat!(module_path!(), "::set_propagator was called twice"));
|
||||
}
|
||||
|
||||
/// Create a wrapping filter that exposes CORS behavior for a wrapped filter.
|
||||
#[must_use]
|
||||
pub fn cors() -> warp::filters::cors::Builder {
|
||||
warp::filters::cors::cors()
|
||||
.allow_any_origin()
|
||||
.allow_headers(PROPAGATOR_HEADERS.get().unwrap_or(&Vec::new()))
|
||||
}
|
@ -1,200 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{cmp::Reverse, future::Future, pin::Pin};
|
||||
|
||||
use mime::{Mime, STAR};
|
||||
use serde::Serialize;
|
||||
use tera::Context;
|
||||
use tide::{
|
||||
http::headers::{ACCEPT, LOCATION},
|
||||
Body, Request, StatusCode,
|
||||
};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{state::State, templates::common_context};
|
||||
|
||||
/// Get the weight parameter for a mime type from 0 to 1000
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
fn get_weight(mime: &Mime) -> usize {
|
||||
let q = mime
|
||||
.get_param("q")
|
||||
.map_or(1.0_f64, |q| q.as_str().parse().unwrap_or(0.0))
|
||||
.min(1.0)
|
||||
.max(0.0);
|
||||
|
||||
// Weight have a 3 digit precision so we can multiply by 1000 and cast to
|
||||
// int. Sign loss should not happen here because of the min/max up there and
|
||||
// truncation does not matter here.
|
||||
(q * 1000.0) as _
|
||||
}
|
||||
|
||||
/// Find what content type should be used for a given request
|
||||
fn preferred_mime_type<'a>(
|
||||
request: &Request<State>,
|
||||
supported_types: &'a [Mime],
|
||||
) -> Option<&'a Mime> {
|
||||
let accept = request.header(ACCEPT)?;
|
||||
// Parse the Accept header as a list of mime types with their associated
|
||||
// weight
|
||||
let accepted_types: Vec<(Mime, usize)> = {
|
||||
let v: Option<Vec<_>> = accept
|
||||
.into_iter()
|
||||
.flat_map(|value| value.as_str().split(','))
|
||||
.map(|mime| {
|
||||
mime.trim().parse().ok().map(|mime| {
|
||||
let q = get_weight(&mime);
|
||||
(mime, q)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let mut v = v?;
|
||||
v.sort_by_key(|(_, weight)| Reverse(*weight));
|
||||
v
|
||||
};
|
||||
|
||||
// For each supported content type, find out if it is accepted with what
|
||||
// weight and specificity
|
||||
let mut types: Vec<_> = supported_types
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, supported)| {
|
||||
accepted_types.iter().find_map(|(accepted, weight)| {
|
||||
if accepted.type_() == supported.type_()
|
||||
&& accepted.subtype() == supported.subtype()
|
||||
{
|
||||
// Accept: text/html
|
||||
Some((supported, *weight, 2_usize, index))
|
||||
} else if accepted.type_() == supported.type_() && accepted.subtype() == STAR {
|
||||
// Accept: text/*
|
||||
Some((supported, *weight, 1, index))
|
||||
} else if accepted.type_() == STAR && accepted.subtype() == STAR {
|
||||
// Accept: */*
|
||||
Some((supported, *weight, 0, index))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
types.sort_by_key(|(_, weight, specificity, index)| {
|
||||
(Reverse(*weight), Reverse(*specificity), *index)
|
||||
});
|
||||
|
||||
types.first().map(|(mime, _, _, _)| *mime)
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ErrorContext {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
code: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
details: Option<String>,
|
||||
}
|
||||
|
||||
impl ErrorContext {
|
||||
fn should_render(&self) -> bool {
|
||||
self.code.is_some() || self.description.is_some() || self.details.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn middleware<'a>(
|
||||
request: tide::Request<State>,
|
||||
next: tide::Next<'a, State>,
|
||||
) -> Pin<Box<dyn Future<Output = tide::Result> + Send + 'a>> {
|
||||
Box::pin(async {
|
||||
let content_type = preferred_mime_type(
|
||||
&request,
|
||||
&[mime::TEXT_PLAIN, mime::TEXT_HTML, mime::APPLICATION_JSON],
|
||||
);
|
||||
debug!("Content-Type from Accept: {:?}", content_type);
|
||||
|
||||
// TODO: We should not clone here
|
||||
let templates = request.state().templates().clone();
|
||||
|
||||
// TODO: This context should probably be comptuted somewhere else
|
||||
let pctx = common_context(&request).await?.clone();
|
||||
|
||||
let mut response = next.run(request).await;
|
||||
|
||||
// Find out what message should be displayed from the response status
|
||||
// code
|
||||
let (code, description) = match response.status() {
|
||||
StatusCode::NotFound => (Some("Not found".to_string()), None),
|
||||
StatusCode::MethodNotAllowed => (Some("Method not allowed".to_string()), None),
|
||||
StatusCode::Found
|
||||
| StatusCode::PermanentRedirect
|
||||
| StatusCode::TemporaryRedirect
|
||||
| StatusCode::SeeOther => {
|
||||
let description = response.header(LOCATION).map(|loc| format!("To {}", loc));
|
||||
(Some("Redirecting".to_string()), description)
|
||||
}
|
||||
StatusCode::InternalServerError => (Some("Internal server error".to_string()), None),
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
// If there is an error associated to the response, format it in a nice
|
||||
// way with a backtrace if we have one
|
||||
let details = response.take_error().map(|err| {
|
||||
format!(
|
||||
"{:?}{}",
|
||||
err,
|
||||
err.backtrace()
|
||||
.map(|bt| format!("\nBacktrace:\n{}", bt.to_string()))
|
||||
.unwrap_or_default()
|
||||
)
|
||||
});
|
||||
|
||||
let error_context = ErrorContext {
|
||||
code,
|
||||
description,
|
||||
details,
|
||||
};
|
||||
|
||||
// This is the case if one of the code, description or details is not
|
||||
// None
|
||||
if error_context.should_render() {
|
||||
match content_type {
|
||||
Some(c) if c == &mime::APPLICATION_JSON => {
|
||||
response.set_body(Body::from_json(&error_context)?);
|
||||
response.set_content_type("application/json");
|
||||
}
|
||||
Some(c) if c == &mime::TEXT_HTML => {
|
||||
let mut ctx = Context::from_serialize(&error_context)?;
|
||||
ctx.extend(pctx);
|
||||
response.set_body(templates.render("error.html", &ctx)?);
|
||||
response.set_content_type("text/html");
|
||||
}
|
||||
Some(c) if c == &mime::TEXT_PLAIN => {
|
||||
let mut ctx = Context::from_serialize(&error_context)?;
|
||||
ctx.extend(pctx);
|
||||
response.set_body(templates.render("error.txt", &ctx)?);
|
||||
response.set_content_type("text/plain");
|
||||
}
|
||||
_ => {
|
||||
response.set_body("Unsupported Content-Type in Accept header");
|
||||
response.set_content_type("text/plain");
|
||||
response.set_status(StatusCode::NotAcceptable);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
})
|
||||
}
|
@ -17,11 +17,11 @@
|
||||
#![allow(clippy::unused_async)] // Some warp filters need that
|
||||
#![deny(missing_docs)]
|
||||
|
||||
pub mod csrf;
|
||||
// mod errors;
|
||||
pub mod authenticate;
|
||||
pub mod client;
|
||||
pub mod cookies;
|
||||
pub mod cors;
|
||||
pub mod csrf;
|
||||
pub mod database;
|
||||
pub mod headers;
|
||||
pub mod session;
|
||||
|
@ -15,6 +15,7 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use hyper::Method;
|
||||
use mas_config::OAuth2Config;
|
||||
use oauth2_types::{
|
||||
oidc::Metadata,
|
||||
pkce::CodeChallengeMethod,
|
||||
@ -22,7 +23,7 @@ use oauth2_types::{
|
||||
};
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::config::OAuth2Config;
|
||||
use crate::filters::cors::cors;
|
||||
|
||||
pub(super) fn filter(
|
||||
config: &OAuth2Config,
|
||||
@ -87,15 +88,9 @@ pub(super) fn filter(
|
||||
code_challenge_methods_supported,
|
||||
};
|
||||
|
||||
// TODO: get the headers list from the global opentelemetry propagators
|
||||
let cors = warp::cors()
|
||||
.allow_method(Method::GET)
|
||||
.allow_any_origin()
|
||||
.allow_headers(["traceparent"]);
|
||||
|
||||
warp::path!(".well-known" / "openid-configuration").and(
|
||||
warp::get()
|
||||
.map(move || warp::reply::json(&metadata))
|
||||
.with(cors),
|
||||
.with(cors().allow_method(Method::GET)),
|
||||
)
|
||||
}
|
||||
|
@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::Utc;
|
||||
use hyper::Method;
|
||||
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint};
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use tracing::{info, warn};
|
||||
@ -23,6 +24,7 @@ use crate::{
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
client::{client_authentication, ClientAuthentication},
|
||||
cors::cors,
|
||||
database::connection,
|
||||
},
|
||||
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
|
||||
@ -33,12 +35,14 @@ pub fn filter(
|
||||
pool: &PgPool,
|
||||
oauth2_config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
warp::path!("oauth2" / "introspect")
|
||||
.and(warp::post())
|
||||
.and(connection(pool))
|
||||
.and(client_authentication(oauth2_config))
|
||||
.and_then(introspect)
|
||||
.recover(recover)
|
||||
warp::path!("oauth2" / "introspect").and(
|
||||
warp::post()
|
||||
.and(connection(pool))
|
||||
.and(client_authentication(oauth2_config))
|
||||
.and_then(introspect)
|
||||
.recover(recover)
|
||||
.with(cors().allow_method(Method::POST)),
|
||||
)
|
||||
}
|
||||
|
||||
const INACTIVE: IntrospectionResponse = IntrospectionResponse {
|
||||
|
@ -12,19 +12,20 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use hyper::Method;
|
||||
use mas_config::OAuth2Config;
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::config::OAuth2Config;
|
||||
use crate::filters::cors::cors;
|
||||
|
||||
pub(super) fn filter(
|
||||
config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let jwks = config.keys.to_public_jwks();
|
||||
|
||||
let cors = warp::cors().allow_any_origin();
|
||||
|
||||
warp::path!("oauth2" / "keys.json")
|
||||
.and(warp::get())
|
||||
.map(move || warp::reply::json(&jwks))
|
||||
.with(cors)
|
||||
warp::path!("oauth2" / "keys.json").and(
|
||||
warp::get()
|
||||
.map(move || warp::reply::json(&jwks))
|
||||
.with(cors().allow_method(Method::GET)),
|
||||
)
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ use anyhow::Context;
|
||||
use chrono::Duration;
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use headers::{CacheControl, Pragma};
|
||||
use hyper::StatusCode;
|
||||
use hyper::{Method, StatusCode};
|
||||
use jwt_compact::{Claims, Header, TimeOptions};
|
||||
use oauth2_types::{
|
||||
errors::{
|
||||
@ -44,6 +44,7 @@ use crate::{
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
client::{client_authentication, ClientAuthentication},
|
||||
cors::cors,
|
||||
database::connection,
|
||||
with_keys,
|
||||
},
|
||||
@ -92,14 +93,16 @@ pub fn filter(
|
||||
oauth2_config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let issuer = oauth2_config.issuer.clone();
|
||||
warp::path!("oauth2" / "token")
|
||||
.and(warp::post())
|
||||
.and(client_authentication(oauth2_config))
|
||||
.and(with_keys(oauth2_config))
|
||||
.and(warp::any().map(move || issuer.clone()))
|
||||
.and(connection(pool))
|
||||
.and_then(token)
|
||||
.recover(recover)
|
||||
warp::path!("oauth2" / "token").and(
|
||||
warp::post()
|
||||
.and(client_authentication(oauth2_config))
|
||||
.and(with_keys(oauth2_config))
|
||||
.and(warp::any().map(move || issuer.clone()))
|
||||
.and(connection(pool))
|
||||
.and_then(token)
|
||||
.recover(recover)
|
||||
.with(cors().allow_method(Method::POST)),
|
||||
)
|
||||
}
|
||||
|
||||
async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {
|
||||
|
@ -12,13 +12,17 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use hyper::Method;
|
||||
use serde::Serialize;
|
||||
use sqlx::PgPool;
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::OAuth2Config,
|
||||
filters::authenticate::{authentication, recover_unauthorized},
|
||||
filters::{
|
||||
authenticate::{authentication, recover_unauthorized},
|
||||
cors::cors,
|
||||
},
|
||||
storage::oauth2::access_token::OAuth2AccessTokenLookup,
|
||||
};
|
||||
|
||||
@ -31,11 +35,15 @@ pub(super) fn filter(
|
||||
pool: &PgPool,
|
||||
_config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
warp::path!("oauth2" / "userinfo")
|
||||
.and(warp::get().or(warp::post()).unify())
|
||||
.and(authentication(pool))
|
||||
.and_then(userinfo)
|
||||
.recover(recover_unauthorized)
|
||||
warp::path!("oauth2" / "userinfo").and(
|
||||
warp::get()
|
||||
.or(warp::post())
|
||||
.unify()
|
||||
.and(authentication(pool))
|
||||
.and_then(userinfo)
|
||||
.recover(recover_unauthorized)
|
||||
.with(cors().allow_methods([Method::GET, Method::POST])),
|
||||
)
|
||||
}
|
||||
|
||||
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> {
|
||||
|
Reference in New Issue
Block a user