diff --git a/Cargo.lock b/Cargo.lock index bd5060f5..bf1f68c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1464,6 +1464,8 @@ dependencies = [ "mas-data-model", "mime", "oauth2-types", + "once_cell", + "opentelemetry", "password-hash", "pkcs8", "rand 0.8.4", diff --git a/crates/cli/src/telemetry.rs b/crates/cli/src/telemetry.rs index 8067fc79..e6a1f23e 100644 --- a/crates/cli/src/telemetry.rs +++ b/crates/cli/src/telemetry.rs @@ -30,7 +30,12 @@ use opentelemetry_semantic_conventions as semcov; pub fn setup(config: &TelemetryConfig) -> anyhow::Result> { global::set_error_handler(|e| tracing::error!("{}", e))?; - global::set_text_map_propagator(propagator()); + let propagator = propagator(); + + // The CORS filter needs to know what headers it should whitelist for + // CORS-protected requests. + mas_core::filters::cors::set_propagator(&propagator); + global::set_text_map_propagator(propagator); let tracer = tracer(&config.tracing)?; meter(&config.metrics)?; diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index e0c61faa..b823ef3e 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -67,6 +67,8 @@ cookie = "0.15.1" oauth2-types = { path = "../oauth2-types", features = ["sqlx_type"] } mas-config = { path = "../config" } mas-data-model = { path = "../data-model" } +opentelemetry = "0.16.0" +once_cell = "1.8.0" [dependencies.jwt-compact] # Waiting on the next release because of the bump of the `rsa` dependency diff --git a/crates/core/src/filters/cors.rs b/crates/core/src/filters/cors.rs new file mode 100644 index 00000000..f1afd86e --- /dev/null +++ b/crates/core/src/filters/cors.rs @@ -0,0 +1,37 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Wrapper around [`warp::filters::cors`] + +use std::string::ToString; + +use once_cell::sync::OnceCell; + +static PROPAGATOR_HEADERS: OnceCell> = OnceCell::new(); + +/// Notify the CORS filter what opentelemetry propagators are being used. This +/// helps whitelisting headers in CORS requests. +pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) { + PROPAGATOR_HEADERS + .set(propagator.fields().map(ToString::to_string).collect()) + .expect(concat!(module_path!(), "::set_propagator was called twice")); +} + +/// Create a wrapping filter that exposes CORS behavior for a wrapped filter. +#[must_use] +pub fn cors() -> warp::filters::cors::Builder { + warp::filters::cors::cors() + .allow_any_origin() + .allow_headers(PROPAGATOR_HEADERS.get().unwrap_or(&Vec::new())) +} diff --git a/crates/core/src/filters/errors.rs b/crates/core/src/filters/errors.rs deleted file mode 100644 index 728f700c..00000000 --- a/crates/core/src/filters/errors.rs +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{cmp::Reverse, future::Future, pin::Pin}; - -use mime::{Mime, STAR}; -use serde::Serialize; -use tera::Context; -use tide::{ - http::headers::{ACCEPT, LOCATION}, - Body, Request, StatusCode, -}; -use tracing::debug; - -use crate::{state::State, templates::common_context}; - -/// Get the weight parameter for a mime type from 0 to 1000 -#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] -fn get_weight(mime: &Mime) -> usize { - let q = mime - .get_param("q") - .map_or(1.0_f64, |q| q.as_str().parse().unwrap_or(0.0)) - .min(1.0) - .max(0.0); - - // Weight have a 3 digit precision so we can multiply by 1000 and cast to - // int. Sign loss should not happen here because of the min/max up there and - // truncation does not matter here. - (q * 1000.0) as _ -} - -/// Find what content type should be used for a given request -fn preferred_mime_type<'a>( - request: &Request, - supported_types: &'a [Mime], -) -> Option<&'a Mime> { - let accept = request.header(ACCEPT)?; - // Parse the Accept header as a list of mime types with their associated - // weight - let accepted_types: Vec<(Mime, usize)> = { - let v: Option> = accept - .into_iter() - .flat_map(|value| value.as_str().split(',')) - .map(|mime| { - mime.trim().parse().ok().map(|mime| { - let q = get_weight(&mime); - (mime, q) - }) - }) - .collect(); - let mut v = v?; - v.sort_by_key(|(_, weight)| Reverse(*weight)); - v - }; - - // For each supported content type, find out if it is accepted with what - // weight and specificity - let mut types: Vec<_> = supported_types - .iter() - .enumerate() - .filter_map(|(index, supported)| { - accepted_types.iter().find_map(|(accepted, weight)| { - if accepted.type_() == supported.type_() - && accepted.subtype() == supported.subtype() - { - // Accept: text/html - Some((supported, *weight, 2_usize, index)) - } else if accepted.type_() == supported.type_() && accepted.subtype() == STAR { - // Accept: text/* - Some((supported, *weight, 1, index)) - } else if accepted.type_() == STAR && accepted.subtype() == STAR { - // Accept: */* - Some((supported, *weight, 0, index)) - } else { - None - } - }) - }) - .collect(); - - types.sort_by_key(|(_, weight, specificity, index)| { - (Reverse(*weight), Reverse(*specificity), *index) - }); - - types.first().map(|(mime, _, _, _)| *mime) -} - -#[derive(Serialize)] -struct ErrorContext { - #[serde(skip_serializing_if = "Option::is_none")] - code: Option, - - #[serde(skip_serializing_if = "Option::is_none")] - description: Option, - - #[serde(skip_serializing_if = "Option::is_none")] - details: Option, -} - -impl ErrorContext { - fn should_render(&self) -> bool { - self.code.is_some() || self.description.is_some() || self.details.is_some() - } -} - -pub fn middleware<'a>( - request: tide::Request, - next: tide::Next<'a, State>, -) -> Pin + Send + 'a>> { - Box::pin(async { - let content_type = preferred_mime_type( - &request, - &[mime::TEXT_PLAIN, mime::TEXT_HTML, mime::APPLICATION_JSON], - ); - debug!("Content-Type from Accept: {:?}", content_type); - - // TODO: We should not clone here - let templates = request.state().templates().clone(); - - // TODO: This context should probably be comptuted somewhere else - let pctx = common_context(&request).await?.clone(); - - let mut response = next.run(request).await; - - // Find out what message should be displayed from the response status - // code - let (code, description) = match response.status() { - StatusCode::NotFound => (Some("Not found".to_string()), None), - StatusCode::MethodNotAllowed => (Some("Method not allowed".to_string()), None), - StatusCode::Found - | StatusCode::PermanentRedirect - | StatusCode::TemporaryRedirect - | StatusCode::SeeOther => { - let description = response.header(LOCATION).map(|loc| format!("To {}", loc)); - (Some("Redirecting".to_string()), description) - } - StatusCode::InternalServerError => (Some("Internal server error".to_string()), None), - _ => (None, None), - }; - - // If there is an error associated to the response, format it in a nice - // way with a backtrace if we have one - let details = response.take_error().map(|err| { - format!( - "{:?}{}", - err, - err.backtrace() - .map(|bt| format!("\nBacktrace:\n{}", bt.to_string())) - .unwrap_or_default() - ) - }); - - let error_context = ErrorContext { - code, - description, - details, - }; - - // This is the case if one of the code, description or details is not - // None - if error_context.should_render() { - match content_type { - Some(c) if c == &mime::APPLICATION_JSON => { - response.set_body(Body::from_json(&error_context)?); - response.set_content_type("application/json"); - } - Some(c) if c == &mime::TEXT_HTML => { - let mut ctx = Context::from_serialize(&error_context)?; - ctx.extend(pctx); - response.set_body(templates.render("error.html", &ctx)?); - response.set_content_type("text/html"); - } - Some(c) if c == &mime::TEXT_PLAIN => { - let mut ctx = Context::from_serialize(&error_context)?; - ctx.extend(pctx); - response.set_body(templates.render("error.txt", &ctx)?); - response.set_content_type("text/plain"); - } - _ => { - response.set_body("Unsupported Content-Type in Accept header"); - response.set_content_type("text/plain"); - response.set_status(StatusCode::NotAcceptable); - } - } - } - - Ok(response) - }) -} diff --git a/crates/core/src/filters/mod.rs b/crates/core/src/filters/mod.rs index 97340ce3..f75694fa 100644 --- a/crates/core/src/filters/mod.rs +++ b/crates/core/src/filters/mod.rs @@ -17,11 +17,11 @@ #![allow(clippy::unused_async)] // Some warp filters need that #![deny(missing_docs)] -pub mod csrf; -// mod errors; pub mod authenticate; pub mod client; pub mod cookies; +pub mod cors; +pub mod csrf; pub mod database; pub mod headers; pub mod session; diff --git a/crates/core/src/handlers/oauth2/discovery.rs b/crates/core/src/handlers/oauth2/discovery.rs index 766216ae..64d227fe 100644 --- a/crates/core/src/handlers/oauth2/discovery.rs +++ b/crates/core/src/handlers/oauth2/discovery.rs @@ -15,6 +15,7 @@ use std::collections::HashSet; use hyper::Method; +use mas_config::OAuth2Config; use oauth2_types::{ oidc::Metadata, pkce::CodeChallengeMethod, @@ -22,7 +23,7 @@ use oauth2_types::{ }; use warp::{Filter, Rejection, Reply}; -use crate::config::OAuth2Config; +use crate::filters::cors::cors; pub(super) fn filter( config: &OAuth2Config, @@ -87,15 +88,9 @@ pub(super) fn filter( code_challenge_methods_supported, }; - // TODO: get the headers list from the global opentelemetry propagators - let cors = warp::cors() - .allow_method(Method::GET) - .allow_any_origin() - .allow_headers(["traceparent"]); - warp::path!(".well-known" / "openid-configuration").and( warp::get() .map(move || warp::reply::json(&metadata)) - .with(cors), + .with(cors().allow_method(Method::GET)), ) } diff --git a/crates/core/src/handlers/oauth2/introspection.rs b/crates/core/src/handlers/oauth2/introspection.rs index 9806774e..8b9320a2 100644 --- a/crates/core/src/handlers/oauth2/introspection.rs +++ b/crates/core/src/handlers/oauth2/introspection.rs @@ -13,6 +13,7 @@ // limitations under the License. use chrono::Utc; +use hyper::Method; use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint}; use sqlx::{pool::PoolConnection, PgPool, Postgres}; use tracing::{info, warn}; @@ -23,6 +24,7 @@ use crate::{ errors::WrapError, filters::{ client::{client_authentication, ClientAuthentication}, + cors::cors, database::connection, }, storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token}, @@ -33,12 +35,14 @@ pub fn filter( pool: &PgPool, oauth2_config: &OAuth2Config, ) -> impl Filter + Clone + Send + Sync + 'static { - warp::path!("oauth2" / "introspect") - .and(warp::post()) - .and(connection(pool)) - .and(client_authentication(oauth2_config)) - .and_then(introspect) - .recover(recover) + warp::path!("oauth2" / "introspect").and( + warp::post() + .and(connection(pool)) + .and(client_authentication(oauth2_config)) + .and_then(introspect) + .recover(recover) + .with(cors().allow_method(Method::POST)), + ) } const INACTIVE: IntrospectionResponse = IntrospectionResponse { diff --git a/crates/core/src/handlers/oauth2/keys.rs b/crates/core/src/handlers/oauth2/keys.rs index 06cd9867..08fe8748 100644 --- a/crates/core/src/handlers/oauth2/keys.rs +++ b/crates/core/src/handlers/oauth2/keys.rs @@ -12,19 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. +use hyper::Method; +use mas_config::OAuth2Config; use warp::{Filter, Rejection, Reply}; -use crate::config::OAuth2Config; +use crate::filters::cors::cors; pub(super) fn filter( config: &OAuth2Config, ) -> impl Filter + Clone + Send + Sync + 'static { let jwks = config.keys.to_public_jwks(); - let cors = warp::cors().allow_any_origin(); - - warp::path!("oauth2" / "keys.json") - .and(warp::get()) - .map(move || warp::reply::json(&jwks)) - .with(cors) + warp::path!("oauth2" / "keys.json").and( + warp::get() + .map(move || warp::reply::json(&jwks)) + .with(cors().allow_method(Method::GET)), + ) } diff --git a/crates/core/src/handlers/oauth2/token.rs b/crates/core/src/handlers/oauth2/token.rs index 2b401275..d51d043c 100644 --- a/crates/core/src/handlers/oauth2/token.rs +++ b/crates/core/src/handlers/oauth2/token.rs @@ -16,7 +16,7 @@ use anyhow::Context; use chrono::Duration; use data_encoding::BASE64URL_NOPAD; use headers::{CacheControl, Pragma}; -use hyper::StatusCode; +use hyper::{Method, StatusCode}; use jwt_compact::{Claims, Header, TimeOptions}; use oauth2_types::{ errors::{ @@ -44,6 +44,7 @@ use crate::{ errors::WrapError, filters::{ client::{client_authentication, ClientAuthentication}, + cors::cors, database::connection, with_keys, }, @@ -92,14 +93,16 @@ pub fn filter( oauth2_config: &OAuth2Config, ) -> impl Filter + Clone + Send + Sync + 'static { let issuer = oauth2_config.issuer.clone(); - warp::path!("oauth2" / "token") - .and(warp::post()) - .and(client_authentication(oauth2_config)) - .and(with_keys(oauth2_config)) - .and(warp::any().map(move || issuer.clone())) - .and(connection(pool)) - .and_then(token) - .recover(recover) + warp::path!("oauth2" / "token").and( + warp::post() + .and(client_authentication(oauth2_config)) + .and(with_keys(oauth2_config)) + .and(warp::any().map(move || issuer.clone())) + .and(connection(pool)) + .and_then(token) + .recover(recover) + .with(cors().allow_method(Method::POST)), + ) } async fn recover(rejection: Rejection) -> Result { diff --git a/crates/core/src/handlers/oauth2/userinfo.rs b/crates/core/src/handlers/oauth2/userinfo.rs index fdac9ed4..9d0cafb9 100644 --- a/crates/core/src/handlers/oauth2/userinfo.rs +++ b/crates/core/src/handlers/oauth2/userinfo.rs @@ -12,13 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +use hyper::Method; use serde::Serialize; use sqlx::PgPool; use warp::{Filter, Rejection, Reply}; use crate::{ config::OAuth2Config, - filters::authenticate::{authentication, recover_unauthorized}, + filters::{ + authenticate::{authentication, recover_unauthorized}, + cors::cors, + }, storage::oauth2::access_token::OAuth2AccessTokenLookup, }; @@ -31,11 +35,15 @@ pub(super) fn filter( pool: &PgPool, _config: &OAuth2Config, ) -> impl Filter + Clone + Send + Sync + 'static { - warp::path!("oauth2" / "userinfo") - .and(warp::get().or(warp::post()).unify()) - .and(authentication(pool)) - .and_then(userinfo) - .recover(recover_unauthorized) + warp::path!("oauth2" / "userinfo").and( + warp::get() + .or(warp::post()) + .unify() + .and(authentication(pool)) + .and_then(userinfo) + .recover(recover_unauthorized) + .with(cors().allow_methods([Method::GET, Method::POST])), + ) } async fn userinfo(token: OAuth2AccessTokenLookup) -> Result {