// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use std::{convert::Infallible, net::SocketAddr}; use bytes::Bytes; use http::{Request, Response}; use http_body::{combinators::BoxBody, Body}; use hyper::{ client::{ connect::dns::{GaiResolver, Name}, HttpConnector, }, Client, }; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use thiserror::Error; use tower::{ util::{BoxCloneService, MapErrLayer, MapResponseLayer}, Layer, Service, ServiceExt, }; use crate::{ layers::{ client::{ClientLayer, ClientResponse}, otel::{TraceDns, TraceLayer}, }, BoxError, }; #[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))] compile_error!("enabling the 'client' feature requires also enabling the 'webpki-roots' or the 'native-roots' features"); #[cfg(all(feature = "webpki-roots", feature = "native-roots"))] compile_error!("'webpki-roots' and 'native-roots' features are mutually exclusive"); #[cfg(feature = "native-roots")] static NATIVE_TLS_ROOTS: tokio::sync::OnceCell = tokio::sync::OnceCell::const_new(); #[cfg(feature = "native-roots")] fn load_tls_roots_blocking() -> Result { let mut roots = rustls::RootCertStore::empty(); let certs = rustls_native_certs::load_native_certs()?; for cert in certs { let cert = rustls::Certificate(cert.0); roots.add(&cert)?; } if roots.is_empty() { return Err(NativeRootsLoadError::Empty); } Ok(roots) } #[cfg(feature = "native-roots")] async fn tls_roots() -> Result { NATIVE_TLS_ROOTS .get_or_try_init(|| async move { // Load the TLS config once in a blocking task because loading the system // certificates can take a long time (~200ms) on macOS let span = tracing::info_span!("load_tls_roots"); let roots = tokio::task::spawn_blocking(|| { let _span = span.entered(); load_tls_roots_blocking() }) .await??; Ok(roots) }) .await .cloned() } #[cfg(feature = "webpki-roots")] #[allow(clippy::unused_async)] async fn tls_roots() -> Result { let mut roots = rustls::RootCertStore::empty(); roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( ta.subject, ta.spki, ta.name_constraints, ) })); Ok(roots) } #[cfg(feature = "native-roots")] #[derive(Error, Debug)] #[error(transparent)] pub enum NativeRootsInitError { RootsLoadError(#[from] NativeRootsLoadError), JoinError(#[from] tokio::task::JoinError), } /// A wrapper over a boxed error that implements ``std::error::Error``. /// This is helps converting to ``anyhow::Error`` with the `?` operator #[derive(Error, Debug)] #[error(transparent)] pub struct ClientError { #[from] inner: BoxError, } #[derive(Error, Debug, Clone)] pub enum ClientInitError { #[cfg(feature = "native-roots")] #[error(transparent)] TlsRootsInit(std::sync::Arc), } #[cfg(feature = "native-roots")] impl From for ClientInitError { fn from(inner: NativeRootsInitError) -> Self { Self::TlsRootsInit(std::sync::Arc::new(inner)) } } impl From for ClientInitError { fn from(_: Infallible) -> Self { unreachable!() } } #[cfg(feature = "native-roots")] #[derive(Error, Debug)] pub enum NativeRootsLoadError { #[error("could not load root certificates")] Io(#[from] std::io::Error), #[error("invalid root certificate")] Webpki(#[from] webpki::Error), #[error("no root certificate loaded")] Empty, } async fn make_tls_config() -> Result { let roots = tls_roots().await?; let tls_config = rustls::ClientConfig::builder() .with_safe_defaults() .with_root_certificates(roots) .with_no_client_auth(); Ok(tls_config) } /// Create a basic Hyper HTTP & HTTPS client without any tracing /// /// # Errors /// /// Returns an error if it failed to load the TLS certificates pub async fn make_untraced_client( ) -> Result>, B>, ClientInitError> where B: http_body::Body + Send + 'static, E: Into, { let https = make_untraced_connector().await?; Ok(Client::builder().build(https)) } async fn make_traced_client( ) -> Result>>, B>, ClientInitError> where B: http_body::Body + Send + 'static, E: Into, { let https = make_traced_connector().await?; Ok(Client::builder().build(https)) } /// Create a traced HTTP and HTTPS connector /// /// # Errors /// /// Returns an error if it failed to load the TLS certificates pub async fn make_traced_connector( ) -> Result>>, ClientInitError> where { // Trace DNS requests let resolver = TraceLayer::dns().layer(GaiResolver::new()); let tls_config = make_tls_config().await?; Ok(make_connector(resolver, tls_config)) } async fn make_untraced_connector( ) -> Result>, ClientInitError> where { let resolver = GaiResolver::new(); let tls_config = make_tls_config().await?; Ok(make_connector(resolver, tls_config)) } fn make_connector( resolver: R, tls_config: rustls::ClientConfig, ) -> HttpsConnector> where R: Service + Send + Sync + Clone + 'static, R::Error: std::error::Error + Send + Sync, R::Future: Send, R::Response: Iterator, { let mut http = HttpConnector::new_with_resolver(resolver); http.enforce_http(false); HttpsConnectorBuilder::new() .with_tls_config(tls_config) .https_or_http() .enable_http1() .enable_http2() .wrap_connector(http) } /// Create a traced HTTP client, with a default timeout, which follows redirects /// and handles compression /// /// # Errors /// /// Returns an error if it failed to initialize pub async fn client( operation: &'static str, ) -> Result< BoxCloneService, Response>, ClientError>, ClientInitError, > where B: http_body::Body + Default + Send + 'static, E: Into + 'static, { let client = make_traced_client().await?; let layer = ( // Convert the errors to ClientError to help dealing with them MapErrLayer::new(ClientError::from), MapResponseLayer::new(|r: ClientResponse| { r.map(|body| body.map_err(ClientError::from).boxed()) }), ClientLayer::new(operation), ); let client = layer.layer(client).boxed_clone(); Ok(client) }