// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use std::convert::Infallible; use hyper::{ client::{ connect::dns::{GaiResolver, Name}, HttpConnector, }, Client, }; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use mas_tower::{ DurationRecorderLayer, DurationRecorderService, FnWrapper, InFlightCounterLayer, InFlightCounterService, TraceLayer, TraceService, }; use thiserror::Error; use tower::Layer; use tracing::Span; #[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))] compile_error!("enabling the 'client' feature requires also enabling the 'webpki-roots' or the 'native-roots' features"); #[cfg(all(feature = "webpki-roots", feature = "native-roots"))] compile_error!("'webpki-roots' and 'native-roots' features are mutually exclusive"); #[cfg(feature = "native-roots")] static NATIVE_TLS_ROOTS: tokio::sync::OnceCell = tokio::sync::OnceCell::const_new(); #[cfg(feature = "native-roots")] fn load_tls_roots_blocking() -> Result { let mut roots = rustls::RootCertStore::empty(); let certs = rustls_native_certs::load_native_certs()?; for cert in certs { let cert = rustls::Certificate(cert.0); roots.add(&cert)?; } if roots.is_empty() { return Err(NativeRootsLoadError::Empty); } Ok(roots) } #[cfg(feature = "native-roots")] async fn tls_roots() -> Result { NATIVE_TLS_ROOTS .get_or_try_init(|| async move { // Load the TLS config once in a blocking task because loading the system // certificates can take a long time (~200ms) on macOS let span = tracing::info_span!("load_tls_roots"); let roots = tokio::task::spawn_blocking(|| { let _span = span.entered(); load_tls_roots_blocking() }) .await??; Ok(roots) }) .await .cloned() } #[cfg(feature = "webpki-roots")] #[allow(clippy::unused_async)] async fn tls_roots() -> Result { let mut roots = rustls::RootCertStore::empty(); roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( ta.subject, ta.spki, ta.name_constraints, ) })); Ok(roots) } #[cfg(feature = "native-roots")] #[derive(Error, Debug)] #[error(transparent)] pub enum NativeRootsInitError { RootsLoadError(#[from] NativeRootsLoadError), JoinError(#[from] tokio::task::JoinError), } #[derive(Error, Debug, Clone)] pub enum ClientInitError { #[cfg(feature = "native-roots")] #[error(transparent)] TlsRootsInit(std::sync::Arc), } #[cfg(feature = "native-roots")] impl From for ClientInitError { fn from(inner: NativeRootsInitError) -> Self { Self::TlsRootsInit(std::sync::Arc::new(inner)) } } impl From for ClientInitError { fn from(e: Infallible) -> Self { match e {} } } #[cfg(feature = "native-roots")] #[derive(Error, Debug)] pub enum NativeRootsLoadError { #[error("could not load root certificates")] Io(#[from] std::io::Error), #[error("invalid root certificate")] Rustls(#[from] rustls::Error), #[error("no root certificate loaded")] Empty, } async fn make_tls_config() -> Result { let roots = tls_roots().await?; let tls_config = rustls::ClientConfig::builder() .with_safe_defaults() .with_root_certificates(roots) .with_no_client_auth(); Ok(tls_config) } pub type UntracedClient = hyper::Client; pub type TracedClient = hyper::Client; /// Create a basic Hyper HTTP & HTTPS client without any tracing /// /// # Errors /// /// Returns an error if it failed to load the TLS certificates pub async fn make_untraced_client() -> Result, ClientInitError> where B: http_body::Body + Send + 'static, B::Data: Send, { let https = make_untraced_connector().await?; Ok(Client::builder().build(https)) } /// Create a basic Hyper HTTP & HTTPS client which traces DNS requests /// /// # Errors /// /// Returns an error if it failed to load the TLS certificates pub async fn make_traced_client() -> Result, ClientInitError> where B: http_body::Body + Send, B::Data: Send, { let https = make_traced_connector().await?; Ok(Client::builder().build(https)) } pub type TraceResolver = InFlightCounterService Span>>>>; pub type UntracedConnector = HttpsConnector>; pub type TracedConnector = HttpsConnector>>; /// Create a traced HTTP and HTTPS connector /// /// # Errors /// /// Returns an error if it failed to load the TLS certificates pub async fn make_traced_connector() -> Result where { let in_flight_counter = InFlightCounterLayer::new("dns.resolve.active_requests"); let duration_recorder = DurationRecorderLayer::new("dns.resolve.duration"); let trace_layer = TraceLayer::from_fn( (|request: &Name| { tracing::info_span!( "dns.resolve", "otel.kind" = "client", "net.host.name" = %request, ) }) as fn(&Name) -> Span, ); let resolver = (in_flight_counter, duration_recorder, trace_layer).layer(GaiResolver::new()); let tls_config = make_tls_config().await?; Ok(make_connector(resolver, tls_config)) } async fn make_untraced_connector() -> Result where { let resolver = GaiResolver::new(); let tls_config = make_tls_config().await?; Ok(make_connector(resolver, tls_config)) } fn make_connector( resolver: R, tls_config: rustls::ClientConfig, ) -> HttpsConnector> { let mut http = HttpConnector::new_with_resolver(resolver); http.enforce_http(false); HttpsConnectorBuilder::new() .with_tls_config(tls_config) .https_or_http() .enable_http1() .enable_http2() .wrap_connector(http) }