1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Simplify the HTTP client building

Also supports loading the WebPKI roots instead of the native ones for
TLS
This commit is contained in:
Quentin Gliech
2022-09-15 16:00:33 +02:00
parent a663deb7e1
commit 7b819ffa8b
10 changed files with 216 additions and 148 deletions

View File

@ -13,21 +13,24 @@ headers = "0.3.8"
http = "0.2.8"
http-body = "0.4.5"
hyper = "0.14.20"
hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false, optional = true }
hyper-rustls = { version = "0.23.0", features = ["http1", "http2"], default-features = false, optional = true }
once_cell = "1.15.0"
opentelemetry = "0.17.0"
opentelemetry-http = "0.6.0"
opentelemetry-semantic-conventions = "0.9.0"
rustls = "0.20.6"
rustls = { version = "0.20.6", optional = true }
rustls-native-certs = { version = "0.6.2", optional = true }
serde = "1.0.145"
serde_json = "1.0.85"
serde_urlencoded = "0.7.1"
thiserror = "1.0.36"
tokio = { version = "1.21.1", optional = true }
tokio = { version = "1.21.1", features = ["sync", "parking_lot"], optional = true }
tower = { version = "0.4.13", features = ["timeout", "limit"] }
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] }
tracing = "0.1.36"
tracing-opentelemetry = "0.17.4"
webpki = { version = "0.22.0", optional = true }
webpki-roots = { version = "0.22.4", optional = true }
[dev-dependencies]
anyhow = "1.0.65"
@ -38,4 +41,12 @@ tower = { version = "0.4.13", features = ["util"] }
[features]
default = []
axum = ["dep:axum"]
client = ["dep:hyper-rustls", "hyper/tcp", "dep:tokio", "tokio?/sync", "tokio?/parking_lot"]
native-roots = ["dep:rustls-native-certs"]
webpki-roots = ["dep:webpki-roots"]
client = [
"dep:rustls",
"hyper/tcp",
"dep:hyper-rustls",
"dep:tokio",
"dep:webpki",
]

View File

@ -12,50 +12,158 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use std::{convert::Infallible, net::SocketAddr};
use bytes::Bytes;
use futures_util::{FutureExt, TryFutureExt};
use http::{Request, Response};
use http_body::{combinators::BoxBody, Body};
use hyper::{
client::{connect::dns::GaiResolver, HttpConnector},
client::{
connect::dns::{GaiResolver, Name},
HttpConnector,
},
Client,
};
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use thiserror::Error;
use tokio::{sync::OnceCell, task::JoinError};
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
use tower::{util::BoxCloneService, Service, ServiceBuilder, ServiceExt};
use crate::{
layers::{
client::{ClientLayer, ClientResponse},
otel::{TraceDns, TraceLayer},
},
BoxError, FutureService,
BoxError,
};
#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))]
compile_error!("enabling the 'client' feature requires also enabling the 'webpki-roots' or the 'native-roots' features");
#[cfg(all(feature = "webpki-roots", feature = "native-roots"))]
compile_error!("'webpki-roots' and 'native-roots' features are mutually exclusive");
#[cfg(feature = "native-roots")]
static NATIVE_TLS_ROOTS: tokio::sync::OnceCell<rustls::RootCertStore> =
tokio::sync::OnceCell::const_new();
#[cfg(feature = "native-roots")]
fn load_tls_roots_blocking() -> Result<rustls::RootCertStore, NativeRootsLoadError> {
let mut roots = rustls::RootCertStore::empty();
let certs = rustls_native_certs::load_native_certs()?;
for cert in certs {
let cert = rustls::Certificate(cert.0);
roots.add(&cert)?;
}
if roots.is_empty() {
return Err(NativeRootsLoadError::Empty);
}
Ok(roots)
}
#[cfg(feature = "native-roots")]
async fn tls_roots() -> Result<rustls::RootCertStore, NativeRootsInitError> {
NATIVE_TLS_ROOTS
.get_or_try_init(|| async move {
// Load the TLS config once in a blocking task because loading the system
// certificates can take a long time (~200ms) on macOS
let span = tracing::info_span!("load_tls_roots");
let roots = tokio::task::spawn_blocking(|| {
let _span = span.entered();
load_tls_roots_blocking()
})
.await??;
Ok(roots)
})
.await
.cloned()
}
#[cfg(feature = "webpki-roots")]
async fn tls_roots() -> Result<rustls::RootCertStore, Infallible> {
let mut roots = rustls::RootCertStore::empty();
roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
Ok(roots)
}
#[cfg(feature = "native-roots")]
#[derive(Error, Debug)]
#[error(transparent)]
pub enum NativeRootsInitError {
RootsLoadError(#[from] NativeRootsLoadError),
JoinError(#[from] tokio::task::JoinError),
}
/// A wrapper over a boxed error that implements ``std::error::Error``.
/// This is helps converting to ``anyhow::Error`` with the `?` operator
#[derive(Error, Debug)]
pub enum ClientError {
#[error("failed to initialize HTTPS client")]
Init(#[from] ClientInitError),
#[error(transparent)]
Call(#[from] BoxError),
#[error(transparent)]
pub struct ClientError {
#[from]
inner: BoxError,
}
#[derive(Error, Debug, Clone)]
pub enum ClientInitError {
#[error("failed to load system certificates")]
CertificateLoad {
#[from]
inner: Arc<JoinError>, // That error is in an Arc to have the error implement Clone
},
#[cfg(feature = "native-roots")]
#[error(transparent)]
TlsRootsInit(std::sync::Arc<NativeRootsInitError>),
}
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
#[cfg(feature = "native-roots")]
impl From<NativeRootsInitError> for ClientInitError {
fn from(inner: NativeRootsInitError) -> Self {
Self::TlsRootsInit(std::sync::Arc::new(inner))
}
}
impl From<Infallible> for ClientInitError {
fn from(_: Infallible) -> Self {
unreachable!()
}
}
#[cfg(feature = "native-roots")]
#[derive(Error, Debug)]
pub enum NativeRootsLoadError {
#[error("could not load root certificates")]
Io(#[from] std::io::Error),
#[error("invalid root certificate")]
Webpki(#[from] webpki::Error),
#[error("no root certificate loaded")]
Empty,
}
/// Create a basic Hyper HTTP & HTTPS client without any tracing
///
/// # Errors
///
/// Returns an error if it failed to load the TLS certificates
pub async fn make_untraced_client<B, E>(
) -> Result<hyper::Client<HttpsConnector<HttpConnector<GaiResolver>>, B>, ClientInitError>
where
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
E: Into<BoxError>,
{
let resolver = GaiResolver::new();
let roots = tls_roots().await?;
let tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
Ok(make_client(resolver, tls_config))
}
async fn make_base_client<B, E>(
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
@ -68,54 +176,57 @@ where
.layer(TraceLayer::dns())
.service(GaiResolver::new());
let roots = tls_roots().await?;
let tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
Ok(make_client(resolver, tls_config))
}
fn make_client<R, B, E>(
resolver: R,
tls_config: rustls::ClientConfig,
) -> hyper::Client<HttpsConnector<HttpConnector<R>>, B>
where
R: Service<Name> + Send + Sync + Clone + 'static,
R::Error: std::error::Error + Send + Sync,
R::Future: Send,
R::Response: Iterator<Item = SocketAddr>,
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
E: Into<BoxError>,
{
let mut http = HttpConnector::new_with_resolver(resolver);
http.enforce_http(false);
let tls_config = TLS_CONFIG
.get_or_try_init(|| async move {
// Load the TLS config once in a blocking task because loading the system
// certificates can take a long time (~200ms) on macOS
let span = tracing::info_span!("load_certificates");
tokio::task::spawn_blocking(|| {
let _span = span.entered();
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_native_roots()
.with_no_client_auth()
})
.await
})
.await
.map_err(|e| ClientInitError::from(Arc::new(e)))?;
let https = HttpsConnectorBuilder::new()
.with_tls_config(tls_config.clone())
.with_tls_config(tls_config)
.https_or_http()
.enable_http1()
.enable_http2()
.wrap_connector(http);
// TODO: we should get the remote address here
let client = Client::builder().build(https);
Ok::<_, ClientInitError>(client)
Client::builder().build(https)
}
#[must_use]
pub fn client<B, E>(
/// Create a traced HTTP client, with a default timeout, which follows redirects
/// and handles compression
///
/// # Errors
///
/// Returns an error if it failed to initialize
pub async fn client<B, E>(
operation: &'static str,
) -> BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>
) -> Result<
BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>,
ClientInitError,
>
where
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
E: Into<BoxError> + 'static,
{
let fut = make_base_client()
// Map the error to a ClientError
.map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e))))
// Wrap it in an Shared (Arc) to be able to Clone it
.shared();
let client: FutureService<_, _> = FutureService::new(fut);
let client = make_base_client().await?;
let client = ServiceBuilder::new()
// Convert the errors to ClientError to help dealing with them
@ -124,7 +235,8 @@ where
r.map(|body| body.map_err(ClientError::from).boxed())
})
.layer(ClientLayer::new(operation))
.service(client);
.service(client)
.boxed_clone();
client.boxed_clone()
Ok(client)
}

View File

@ -1,77 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A copy of [`tower::util::FutureService`] that also maps the future error to
//! help implementing [`Clone`] on the service
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use futures_util::ready;
use tower::Service;
#[derive(Clone, Debug)]
pub struct FutureService<F, S> {
state: State<F, S>,
}
impl<F, S> FutureService<F, S> {
#[must_use]
pub fn new(future: F) -> Self {
Self {
state: State::Future(future),
}
}
}
#[derive(Clone, Debug)]
enum State<F, S> {
Future(F),
Service(S),
}
impl<F, S, R, FE, E> Service<R> for FutureService<F, S>
where
F: Future<Output = Result<S, FE>> + Unpin,
S: Service<R, Error = E>,
E: From<FE>,
{
type Response = S::Response;
type Error = E;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
loop {
self.state = match &mut self.state {
State::Future(fut) => {
let fut = Pin::new(fut);
let svc = ready!(fut.poll(cx)?);
State::Service(svc)
}
State::Service(svc) => return svc.poll_ready(cx),
};
}
}
fn call(&mut self, req: R) -> Self::Future {
if let State::Service(svc) = &mut self.state {
svc.call(req)
} else {
panic!("FutureService::call was called before FutureService::poll_ready")
}
}
}

View File

@ -27,17 +27,15 @@
#[cfg(feature = "client")]
mod client;
mod ext;
mod future_service;
mod layers;
#[cfg(feature = "client")]
pub use self::client::client;
pub use self::client::{client, make_untraced_client};
pub use self::{
ext::{
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
ServiceExt as HttpServiceExt,
},
future_service::FutureService,
layers::{
body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer},
bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer},