1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Better CORS filter to allow OTEL propagator headers

This commit is contained in:
Quentin Gliech
2021-10-14 18:29:32 +02:00
parent e630279b54
commit 29f3edd833
11 changed files with 96 additions and 239 deletions

2
Cargo.lock generated
View File

@ -1464,6 +1464,8 @@ dependencies = [
"mas-data-model",
"mime",
"oauth2-types",
"once_cell",
"opentelemetry",
"password-hash",
"pkcs8",
"rand 0.8.4",

View File

@ -30,7 +30,12 @@ use opentelemetry_semantic_conventions as semcov;
pub fn setup(config: &TelemetryConfig) -> anyhow::Result<Option<Tracer>> {
global::set_error_handler(|e| tracing::error!("{}", e))?;
global::set_text_map_propagator(propagator());
let propagator = propagator();
// The CORS filter needs to know what headers it should whitelist for
// CORS-protected requests.
mas_core::filters::cors::set_propagator(&propagator);
global::set_text_map_propagator(propagator);
let tracer = tracer(&config.tracing)?;
meter(&config.metrics)?;

View File

@ -67,6 +67,8 @@ cookie = "0.15.1"
oauth2-types = { path = "../oauth2-types", features = ["sqlx_type"] }
mas-config = { path = "../config" }
mas-data-model = { path = "../data-model" }
opentelemetry = "0.16.0"
once_cell = "1.8.0"
[dependencies.jwt-compact]
# Waiting on the next release because of the bump of the `rsa` dependency

View File

@ -0,0 +1,37 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Wrapper around [`warp::filters::cors`]
use std::string::ToString;
use once_cell::sync::OnceCell;
static PROPAGATOR_HEADERS: OnceCell<Vec<String>> = OnceCell::new();
/// Notify the CORS filter what opentelemetry propagators are being used. This
/// helps whitelisting headers in CORS requests.
pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
PROPAGATOR_HEADERS
.set(propagator.fields().map(ToString::to_string).collect())
.expect(concat!(module_path!(), "::set_propagator was called twice"));
}
/// Create a wrapping filter that exposes CORS behavior for a wrapped filter.
#[must_use]
pub fn cors() -> warp::filters::cors::Builder {
warp::filters::cors::cors()
.allow_any_origin()
.allow_headers(PROPAGATOR_HEADERS.get().unwrap_or(&Vec::new()))
}

View File

@ -1,200 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{cmp::Reverse, future::Future, pin::Pin};
use mime::{Mime, STAR};
use serde::Serialize;
use tera::Context;
use tide::{
http::headers::{ACCEPT, LOCATION},
Body, Request, StatusCode,
};
use tracing::debug;
use crate::{state::State, templates::common_context};
/// Get the weight parameter for a mime type from 0 to 1000
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
fn get_weight(mime: &Mime) -> usize {
let q = mime
.get_param("q")
.map_or(1.0_f64, |q| q.as_str().parse().unwrap_or(0.0))
.min(1.0)
.max(0.0);
// Weight have a 3 digit precision so we can multiply by 1000 and cast to
// int. Sign loss should not happen here because of the min/max up there and
// truncation does not matter here.
(q * 1000.0) as _
}
/// Find what content type should be used for a given request
fn preferred_mime_type<'a>(
request: &Request<State>,
supported_types: &'a [Mime],
) -> Option<&'a Mime> {
let accept = request.header(ACCEPT)?;
// Parse the Accept header as a list of mime types with their associated
// weight
let accepted_types: Vec<(Mime, usize)> = {
let v: Option<Vec<_>> = accept
.into_iter()
.flat_map(|value| value.as_str().split(','))
.map(|mime| {
mime.trim().parse().ok().map(|mime| {
let q = get_weight(&mime);
(mime, q)
})
})
.collect();
let mut v = v?;
v.sort_by_key(|(_, weight)| Reverse(*weight));
v
};
// For each supported content type, find out if it is accepted with what
// weight and specificity
let mut types: Vec<_> = supported_types
.iter()
.enumerate()
.filter_map(|(index, supported)| {
accepted_types.iter().find_map(|(accepted, weight)| {
if accepted.type_() == supported.type_()
&& accepted.subtype() == supported.subtype()
{
// Accept: text/html
Some((supported, *weight, 2_usize, index))
} else if accepted.type_() == supported.type_() && accepted.subtype() == STAR {
// Accept: text/*
Some((supported, *weight, 1, index))
} else if accepted.type_() == STAR && accepted.subtype() == STAR {
// Accept: */*
Some((supported, *weight, 0, index))
} else {
None
}
})
})
.collect();
types.sort_by_key(|(_, weight, specificity, index)| {
(Reverse(*weight), Reverse(*specificity), *index)
});
types.first().map(|(mime, _, _, _)| *mime)
}
#[derive(Serialize)]
struct ErrorContext {
#[serde(skip_serializing_if = "Option::is_none")]
code: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
details: Option<String>,
}
impl ErrorContext {
fn should_render(&self) -> bool {
self.code.is_some() || self.description.is_some() || self.details.is_some()
}
}
pub fn middleware<'a>(
request: tide::Request<State>,
next: tide::Next<'a, State>,
) -> Pin<Box<dyn Future<Output = tide::Result> + Send + 'a>> {
Box::pin(async {
let content_type = preferred_mime_type(
&request,
&[mime::TEXT_PLAIN, mime::TEXT_HTML, mime::APPLICATION_JSON],
);
debug!("Content-Type from Accept: {:?}", content_type);
// TODO: We should not clone here
let templates = request.state().templates().clone();
// TODO: This context should probably be comptuted somewhere else
let pctx = common_context(&request).await?.clone();
let mut response = next.run(request).await;
// Find out what message should be displayed from the response status
// code
let (code, description) = match response.status() {
StatusCode::NotFound => (Some("Not found".to_string()), None),
StatusCode::MethodNotAllowed => (Some("Method not allowed".to_string()), None),
StatusCode::Found
| StatusCode::PermanentRedirect
| StatusCode::TemporaryRedirect
| StatusCode::SeeOther => {
let description = response.header(LOCATION).map(|loc| format!("To {}", loc));
(Some("Redirecting".to_string()), description)
}
StatusCode::InternalServerError => (Some("Internal server error".to_string()), None),
_ => (None, None),
};
// If there is an error associated to the response, format it in a nice
// way with a backtrace if we have one
let details = response.take_error().map(|err| {
format!(
"{:?}{}",
err,
err.backtrace()
.map(|bt| format!("\nBacktrace:\n{}", bt.to_string()))
.unwrap_or_default()
)
});
let error_context = ErrorContext {
code,
description,
details,
};
// This is the case if one of the code, description or details is not
// None
if error_context.should_render() {
match content_type {
Some(c) if c == &mime::APPLICATION_JSON => {
response.set_body(Body::from_json(&error_context)?);
response.set_content_type("application/json");
}
Some(c) if c == &mime::TEXT_HTML => {
let mut ctx = Context::from_serialize(&error_context)?;
ctx.extend(pctx);
response.set_body(templates.render("error.html", &ctx)?);
response.set_content_type("text/html");
}
Some(c) if c == &mime::TEXT_PLAIN => {
let mut ctx = Context::from_serialize(&error_context)?;
ctx.extend(pctx);
response.set_body(templates.render("error.txt", &ctx)?);
response.set_content_type("text/plain");
}
_ => {
response.set_body("Unsupported Content-Type in Accept header");
response.set_content_type("text/plain");
response.set_status(StatusCode::NotAcceptable);
}
}
}
Ok(response)
})
}

View File

@ -17,11 +17,11 @@
#![allow(clippy::unused_async)] // Some warp filters need that
#![deny(missing_docs)]
pub mod csrf;
// mod errors;
pub mod authenticate;
pub mod client;
pub mod cookies;
pub mod cors;
pub mod csrf;
pub mod database;
pub mod headers;
pub mod session;

View File

@ -15,6 +15,7 @@
use std::collections::HashSet;
use hyper::Method;
use mas_config::OAuth2Config;
use oauth2_types::{
oidc::Metadata,
pkce::CodeChallengeMethod,
@ -22,7 +23,7 @@ use oauth2_types::{
};
use warp::{Filter, Rejection, Reply};
use crate::config::OAuth2Config;
use crate::filters::cors::cors;
pub(super) fn filter(
config: &OAuth2Config,
@ -87,15 +88,9 @@ pub(super) fn filter(
code_challenge_methods_supported,
};
// TODO: get the headers list from the global opentelemetry propagators
let cors = warp::cors()
.allow_method(Method::GET)
.allow_any_origin()
.allow_headers(["traceparent"]);
warp::path!(".well-known" / "openid-configuration").and(
warp::get()
.map(move || warp::reply::json(&metadata))
.with(cors),
.with(cors().allow_method(Method::GET)),
)
}

View File

@ -13,6 +13,7 @@
// limitations under the License.
use chrono::Utc;
use hyper::Method;
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use tracing::{info, warn};
@ -23,6 +24,7 @@ use crate::{
errors::WrapError,
filters::{
client::{client_authentication, ClientAuthentication},
cors::cors,
database::connection,
},
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
@ -33,12 +35,14 @@ pub fn filter(
pool: &PgPool,
oauth2_config: &OAuth2Config,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::path!("oauth2" / "introspect")
.and(warp::post())
warp::path!("oauth2" / "introspect").and(
warp::post()
.and(connection(pool))
.and(client_authentication(oauth2_config))
.and_then(introspect)
.recover(recover)
.with(cors().allow_method(Method::POST)),
)
}
const INACTIVE: IntrospectionResponse = IntrospectionResponse {

View File

@ -12,19 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use hyper::Method;
use mas_config::OAuth2Config;
use warp::{Filter, Rejection, Reply};
use crate::config::OAuth2Config;
use crate::filters::cors::cors;
pub(super) fn filter(
config: &OAuth2Config,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let jwks = config.keys.to_public_jwks();
let cors = warp::cors().allow_any_origin();
warp::path!("oauth2" / "keys.json")
.and(warp::get())
warp::path!("oauth2" / "keys.json").and(
warp::get()
.map(move || warp::reply::json(&jwks))
.with(cors)
.with(cors().allow_method(Method::GET)),
)
}

View File

@ -16,7 +16,7 @@ use anyhow::Context;
use chrono::Duration;
use data_encoding::BASE64URL_NOPAD;
use headers::{CacheControl, Pragma};
use hyper::StatusCode;
use hyper::{Method, StatusCode};
use jwt_compact::{Claims, Header, TimeOptions};
use oauth2_types::{
errors::{
@ -44,6 +44,7 @@ use crate::{
errors::WrapError,
filters::{
client::{client_authentication, ClientAuthentication},
cors::cors,
database::connection,
with_keys,
},
@ -92,14 +93,16 @@ pub fn filter(
oauth2_config: &OAuth2Config,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let issuer = oauth2_config.issuer.clone();
warp::path!("oauth2" / "token")
.and(warp::post())
warp::path!("oauth2" / "token").and(
warp::post()
.and(client_authentication(oauth2_config))
.and(with_keys(oauth2_config))
.and(warp::any().map(move || issuer.clone()))
.and(connection(pool))
.and_then(token)
.recover(recover)
.with(cors().allow_method(Method::POST)),
)
}
async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {

View File

@ -12,13 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use hyper::Method;
use serde::Serialize;
use sqlx::PgPool;
use warp::{Filter, Rejection, Reply};
use crate::{
config::OAuth2Config,
filters::authenticate::{authentication, recover_unauthorized},
filters::{
authenticate::{authentication, recover_unauthorized},
cors::cors,
},
storage::oauth2::access_token::OAuth2AccessTokenLookup,
};
@ -31,11 +35,15 @@ pub(super) fn filter(
pool: &PgPool,
_config: &OAuth2Config,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::path!("oauth2" / "userinfo")
.and(warp::get().or(warp::post()).unify())
warp::path!("oauth2" / "userinfo").and(
warp::get()
.or(warp::post())
.unify()
.and(authentication(pool))
.and_then(userinfo)
.recover(recover_unauthorized)
.with(cors().allow_methods([Method::GET, Method::POST])),
)
}
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> {