From 7b819ffa8bf53c371f27ab9e07a96d395228c33b Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 15 Sep 2022 16:00:33 +0200 Subject: [PATCH] Simplify the HTTP client building Also supports loading the WebPKI roots instead of the native ones for TLS --- Cargo.lock | 4 +- crates/axum-utils/Cargo.toml | 5 + crates/axum-utils/src/client_authorization.rs | 10 +- crates/cli/Cargo.toml | 12 +- crates/cli/src/commands/debug.rs | 3 +- crates/handlers/Cargo.toml | 12 +- crates/http/Cargo.toml | 19 +- crates/http/src/client.rs | 218 +++++++++++++----- crates/http/src/future_service.rs | 77 ------- crates/http/src/lib.rs | 4 +- 10 files changed, 216 insertions(+), 148 deletions(-) delete mode 100644 crates/http/src/future_service.rs diff --git a/Cargo.lock b/Cargo.lock index 0f1bb02d..1bdd9d8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1926,7 +1926,6 @@ dependencies = [ "http", "hyper", "rustls 0.20.6", - "rustls-native-certs 0.6.2", "tokio", "tokio-rustls 0.23.4", ] @@ -2459,6 +2458,7 @@ dependencies = [ "opentelemetry-http", "opentelemetry-semantic-conventions", "rustls 0.20.6", + "rustls-native-certs 0.6.2", "serde", "serde_json", "serde_urlencoded", @@ -2468,6 +2468,8 @@ dependencies = [ "tower-http", "tracing", "tracing-opentelemetry", + "webpki 0.22.0", + "webpki-roots", ] [[package]] diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index b82ada1d..eb1045b0 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -36,3 +36,8 @@ mas-jose = { path = "../jose" } mas-keystore = { path = "../keystore" } mas-storage = { path = "../storage" } mas-templates = { path = "../templates" } + +[features] +default = ["native-roots"] +native-roots = ["mas-http/native-roots"] +webpki-roots = ["mas-http/webpki-roots"] diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 94ea1079..e6379688 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -39,7 +39,7 @@ use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; use sqlx::PgExecutor; use thiserror::Error; -use tower::ServiceExt; +use tower::{Service, ServiceExt}; static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; @@ -177,12 +177,12 @@ async fn fetch_jwks(jwks: &JwksOrJwksUri) -> Result() - .map_err(Box::new); + .json_response::(); - let response = client.oneshot(request).await?; + let response = client.ready().await?.call(request).await?; Ok(response.into_body()) } diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 2f00a349..7da7ed4c 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -34,7 +34,7 @@ opentelemetry-zipkin = { version = "0.15.0", features = ["reqwest-client", "reqw mas-config = { path = "../config" } mas-email = { path = "../email" } -mas-handlers = { path = "../handlers" } +mas-handlers = { path = "../handlers", default-features = false } mas-http = { path = "../http", features = ["axum"] } mas-policy = { path = "../policy" } mas-router = { path = "../router" } @@ -47,8 +47,16 @@ mas-templates = { path = "../templates" } indoc = "1.0.7" [features] -default = ["otlp", "jaeger", "zipkin"] +default = ["otlp", "jaeger", "zipkin", "native-roots"] + +# Use the native root certificates +native-roots = ["mas-http/native-roots", "mas-handlers/native-roots"] +# Use the webpki root certificates +webpki-roots = ["mas-http/webpki-roots", "mas-handlers/webpki-roots"] + +# Read the builtin static files and templates from the source directory dev = ["mas-templates/dev", "mas-static-files/dev"] + # Enable OpenTelemetry OTLP exporter. Requires "protoc" otlp = ["opentelemetry-otlp"] # Enable OpenTelemetry Jaeger exporter and propagator. diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 06cf5c05..3325f175 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -72,7 +72,7 @@ impl Options { json: false, url, } => { - let mut client = mas_http::client("cli-debug-http"); + let mut client = mas_http::client("cli-debug-http").await?; let request = hyper::Request::builder() .uri(url) .body(hyper::Body::empty())?; @@ -97,6 +97,7 @@ impl Options { url, } => { let mut client = mas_http::client("cli-debug-http") + .await? .response_body_to_bytes() .json_response(); let request = hyper::Request::builder() diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index d9fcaf94..3cac5ae8 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -47,10 +47,10 @@ rand = "0.8.5" headers = "0.3.8" oauth2-types = { path = "../oauth2-types" } -mas-axum-utils = { path = "../axum-utils" } +mas-axum-utils = { path = "../axum-utils", default-features = false } mas-data-model = { path = "../data-model" } mas-email = { path = "../email" } -mas-http = { path = "../http" } +mas-http = { path = "../http" } mas-iana = { path = "../iana" } mas-jose = { path = "../jose" } mas-keystore = { path = "../keystore" } @@ -61,3 +61,11 @@ mas-templates = { path = "../templates" } [dev-dependencies] indoc = "1.0.7" + +[features] +default = ["native-roots"] + +# Use the native root certificates +native-roots = ["mas-axum-utils/native-roots", "mas-http/native-roots"] +# Use the webpki root certificates +webpki-roots = ["mas-axum-utils/webpki-roots", "mas-http/webpki-roots"] diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index 2bf90932..64772bd8 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -13,21 +13,24 @@ headers = "0.3.8" http = "0.2.8" http-body = "0.4.5" hyper = "0.14.20" -hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false, optional = true } +hyper-rustls = { version = "0.23.0", features = ["http1", "http2"], default-features = false, optional = true } once_cell = "1.15.0" opentelemetry = "0.17.0" opentelemetry-http = "0.6.0" opentelemetry-semantic-conventions = "0.9.0" -rustls = "0.20.6" +rustls = { version = "0.20.6", optional = true } +rustls-native-certs = { version = "0.6.2", optional = true } serde = "1.0.145" serde_json = "1.0.85" serde_urlencoded = "0.7.1" thiserror = "1.0.36" -tokio = { version = "1.21.1", optional = true } +tokio = { version = "1.21.1", features = ["sync", "parking_lot"], optional = true } tower = { version = "0.4.13", features = ["timeout", "limit"] } tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] } tracing = "0.1.36" tracing-opentelemetry = "0.17.4" +webpki = { version = "0.22.0", optional = true } +webpki-roots = { version = "0.22.4", optional = true } [dev-dependencies] anyhow = "1.0.65" @@ -38,4 +41,12 @@ tower = { version = "0.4.13", features = ["util"] } [features] default = [] axum = ["dep:axum"] -client = ["dep:hyper-rustls", "hyper/tcp", "dep:tokio", "tokio?/sync", "tokio?/parking_lot"] +native-roots = ["dep:rustls-native-certs"] +webpki-roots = ["dep:webpki-roots"] +client = [ + "dep:rustls", + "hyper/tcp", + "dep:hyper-rustls", + "dep:tokio", + "dep:webpki", +] diff --git a/crates/http/src/client.rs b/crates/http/src/client.rs index d181e052..c280566e 100644 --- a/crates/http/src/client.rs +++ b/crates/http/src/client.rs @@ -12,50 +12,158 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; +use std::{convert::Infallible, net::SocketAddr}; use bytes::Bytes; -use futures_util::{FutureExt, TryFutureExt}; use http::{Request, Response}; use http_body::{combinators::BoxBody, Body}; use hyper::{ - client::{connect::dns::GaiResolver, HttpConnector}, + client::{ + connect::dns::{GaiResolver, Name}, + HttpConnector, + }, Client, }; -use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder}; +use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use thiserror::Error; -use tokio::{sync::OnceCell, task::JoinError}; -use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt}; +use tower::{util::BoxCloneService, Service, ServiceBuilder, ServiceExt}; use crate::{ layers::{ client::{ClientLayer, ClientResponse}, otel::{TraceDns, TraceLayer}, }, - BoxError, FutureService, + BoxError, }; +#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))] +compile_error!("enabling the 'client' feature requires also enabling the 'webpki-roots' or the 'native-roots' features"); + +#[cfg(all(feature = "webpki-roots", feature = "native-roots"))] +compile_error!("'webpki-roots' and 'native-roots' features are mutually exclusive"); + +#[cfg(feature = "native-roots")] +static NATIVE_TLS_ROOTS: tokio::sync::OnceCell = + tokio::sync::OnceCell::const_new(); + +#[cfg(feature = "native-roots")] +fn load_tls_roots_blocking() -> Result { + let mut roots = rustls::RootCertStore::empty(); + let certs = rustls_native_certs::load_native_certs()?; + for cert in certs { + let cert = rustls::Certificate(cert.0); + roots.add(&cert)?; + } + + if roots.is_empty() { + return Err(NativeRootsLoadError::Empty); + } + + Ok(roots) +} + +#[cfg(feature = "native-roots")] +async fn tls_roots() -> Result { + NATIVE_TLS_ROOTS + .get_or_try_init(|| async move { + // Load the TLS config once in a blocking task because loading the system + // certificates can take a long time (~200ms) on macOS + let span = tracing::info_span!("load_tls_roots"); + let roots = tokio::task::spawn_blocking(|| { + let _span = span.entered(); + load_tls_roots_blocking() + }) + .await??; + Ok(roots) + }) + .await + .cloned() +} + +#[cfg(feature = "webpki-roots")] +async fn tls_roots() -> Result { + let mut roots = rustls::RootCertStore::empty(); + roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + Ok(roots) +} + +#[cfg(feature = "native-roots")] +#[derive(Error, Debug)] +#[error(transparent)] +pub enum NativeRootsInitError { + RootsLoadError(#[from] NativeRootsLoadError), + + JoinError(#[from] tokio::task::JoinError), +} + /// A wrapper over a boxed error that implements ``std::error::Error``. /// This is helps converting to ``anyhow::Error`` with the `?` operator #[derive(Error, Debug)] -pub enum ClientError { - #[error("failed to initialize HTTPS client")] - Init(#[from] ClientInitError), - - #[error(transparent)] - Call(#[from] BoxError), +#[error(transparent)] +pub struct ClientError { + #[from] + inner: BoxError, } #[derive(Error, Debug, Clone)] pub enum ClientInitError { - #[error("failed to load system certificates")] - CertificateLoad { - #[from] - inner: Arc, // That error is in an Arc to have the error implement Clone - }, + #[cfg(feature = "native-roots")] + #[error(transparent)] + TlsRootsInit(std::sync::Arc), } -static TLS_CONFIG: OnceCell = OnceCell::const_new(); +#[cfg(feature = "native-roots")] +impl From for ClientInitError { + fn from(inner: NativeRootsInitError) -> Self { + Self::TlsRootsInit(std::sync::Arc::new(inner)) + } +} + +impl From for ClientInitError { + fn from(_: Infallible) -> Self { + unreachable!() + } +} + +#[cfg(feature = "native-roots")] +#[derive(Error, Debug)] +pub enum NativeRootsLoadError { + #[error("could not load root certificates")] + Io(#[from] std::io::Error), + + #[error("invalid root certificate")] + Webpki(#[from] webpki::Error), + + #[error("no root certificate loaded")] + Empty, +} + +/// Create a basic Hyper HTTP & HTTPS client without any tracing +/// +/// # Errors +/// +/// Returns an error if it failed to load the TLS certificates +pub async fn make_untraced_client( +) -> Result>, B>, ClientInitError> +where + B: http_body::Body + Send + 'static, + E: Into, +{ + let resolver = GaiResolver::new(); + let roots = tls_roots().await?; + let tls_config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots) + .with_no_client_auth(); + + Ok(make_client(resolver, tls_config)) +} async fn make_base_client( ) -> Result>>, B>, ClientInitError> @@ -68,54 +176,57 @@ where .layer(TraceLayer::dns()) .service(GaiResolver::new()); + let roots = tls_roots().await?; + let tls_config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots) + .with_no_client_auth(); + + Ok(make_client(resolver, tls_config)) +} + +fn make_client( + resolver: R, + tls_config: rustls::ClientConfig, +) -> hyper::Client>, B> +where + R: Service + Send + Sync + Clone + 'static, + R::Error: std::error::Error + Send + Sync, + R::Future: Send, + R::Response: Iterator, + B: http_body::Body + Send + 'static, + E: Into, +{ let mut http = HttpConnector::new_with_resolver(resolver); http.enforce_http(false); - let tls_config = TLS_CONFIG - .get_or_try_init(|| async move { - // Load the TLS config once in a blocking task because loading the system - // certificates can take a long time (~200ms) on macOS - let span = tracing::info_span!("load_certificates"); - tokio::task::spawn_blocking(|| { - let _span = span.entered(); - rustls::ClientConfig::builder() - .with_safe_defaults() - .with_native_roots() - .with_no_client_auth() - }) - .await - }) - .await - .map_err(|e| ClientInitError::from(Arc::new(e)))?; - let https = HttpsConnectorBuilder::new() - .with_tls_config(tls_config.clone()) + .with_tls_config(tls_config) .https_or_http() .enable_http1() .enable_http2() .wrap_connector(http); - // TODO: we should get the remote address here - let client = Client::builder().build(https); - - Ok::<_, ClientInitError>(client) + Client::builder().build(https) } -#[must_use] -pub fn client( +/// Create a traced HTTP client, with a default timeout, which follows redirects +/// and handles compression +/// +/// # Errors +/// +/// Returns an error if it failed to initialize +pub async fn client( operation: &'static str, -) -> BoxCloneService, Response>, ClientError> +) -> Result< + BoxCloneService, Response>, ClientError>, + ClientInitError, +> where B: http_body::Body + Default + Send + 'static, E: Into + 'static, { - let fut = make_base_client() - // Map the error to a ClientError - .map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e)))) - // Wrap it in an Shared (Arc) to be able to Clone it - .shared(); - - let client: FutureService<_, _> = FutureService::new(fut); + let client = make_base_client().await?; let client = ServiceBuilder::new() // Convert the errors to ClientError to help dealing with them @@ -124,7 +235,8 @@ where r.map(|body| body.map_err(ClientError::from).boxed()) }) .layer(ClientLayer::new(operation)) - .service(client); + .service(client) + .boxed_clone(); - client.boxed_clone() + Ok(client) } diff --git a/crates/http/src/future_service.rs b/crates/http/src/future_service.rs deleted file mode 100644 index 214022ba..00000000 --- a/crates/http/src/future_service.rs +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! A copy of [`tower::util::FutureService`] that also maps the future error to -//! help implementing [`Clone`] on the service - -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use futures_util::ready; -use tower::Service; - -#[derive(Clone, Debug)] -pub struct FutureService { - state: State, -} - -impl FutureService { - #[must_use] - pub fn new(future: F) -> Self { - Self { - state: State::Future(future), - } - } -} - -#[derive(Clone, Debug)] -enum State { - Future(F), - Service(S), -} - -impl Service for FutureService -where - F: Future> + Unpin, - S: Service, - E: From, -{ - type Response = S::Response; - type Error = E; - type Future = S::Future; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - self.state = match &mut self.state { - State::Future(fut) => { - let fut = Pin::new(fut); - let svc = ready!(fut.poll(cx)?); - State::Service(svc) - } - State::Service(svc) => return svc.poll_ready(cx), - }; - } - } - - fn call(&mut self, req: R) -> Self::Future { - if let State::Service(svc) = &mut self.state { - svc.call(req) - } else { - panic!("FutureService::call was called before FutureService::poll_ready") - } - } -} diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index ce00c50a..dd0b71fb 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -27,17 +27,15 @@ #[cfg(feature = "client")] mod client; mod ext; -mod future_service; mod layers; #[cfg(feature = "client")] -pub use self::client::client; +pub use self::client::{client, make_untraced_client}; pub use self::{ ext::{ set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt, ServiceExt as HttpServiceExt, }, - future_service::FutureService, layers::{ body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer}, bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer},