You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Simplify the HTTP client building
Also supports loading the WebPKI roots instead of the native ones for TLS
This commit is contained in:
@ -13,21 +13,24 @@ headers = "0.3.8"
|
||||
http = "0.2.8"
|
||||
http-body = "0.4.5"
|
||||
hyper = "0.14.20"
|
||||
hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false, optional = true }
|
||||
hyper-rustls = { version = "0.23.0", features = ["http1", "http2"], default-features = false, optional = true }
|
||||
once_cell = "1.15.0"
|
||||
opentelemetry = "0.17.0"
|
||||
opentelemetry-http = "0.6.0"
|
||||
opentelemetry-semantic-conventions = "0.9.0"
|
||||
rustls = "0.20.6"
|
||||
rustls = { version = "0.20.6", optional = true }
|
||||
rustls-native-certs = { version = "0.6.2", optional = true }
|
||||
serde = "1.0.145"
|
||||
serde_json = "1.0.85"
|
||||
serde_urlencoded = "0.7.1"
|
||||
thiserror = "1.0.36"
|
||||
tokio = { version = "1.21.1", optional = true }
|
||||
tokio = { version = "1.21.1", features = ["sync", "parking_lot"], optional = true }
|
||||
tower = { version = "0.4.13", features = ["timeout", "limit"] }
|
||||
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] }
|
||||
tracing = "0.1.36"
|
||||
tracing-opentelemetry = "0.17.4"
|
||||
webpki = { version = "0.22.0", optional = true }
|
||||
webpki-roots = { version = "0.22.4", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1.0.65"
|
||||
@ -38,4 +41,12 @@ tower = { version = "0.4.13", features = ["util"] }
|
||||
[features]
|
||||
default = []
|
||||
axum = ["dep:axum"]
|
||||
client = ["dep:hyper-rustls", "hyper/tcp", "dep:tokio", "tokio?/sync", "tokio?/parking_lot"]
|
||||
native-roots = ["dep:rustls-native-certs"]
|
||||
webpki-roots = ["dep:webpki-roots"]
|
||||
client = [
|
||||
"dep:rustls",
|
||||
"hyper/tcp",
|
||||
"dep:hyper-rustls",
|
||||
"dep:tokio",
|
||||
"dep:webpki",
|
||||
]
|
||||
|
@ -12,50 +12,158 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::{convert::Infallible, net::SocketAddr};
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures_util::{FutureExt, TryFutureExt};
|
||||
use http::{Request, Response};
|
||||
use http_body::{combinators::BoxBody, Body};
|
||||
use hyper::{
|
||||
client::{connect::dns::GaiResolver, HttpConnector},
|
||||
client::{
|
||||
connect::dns::{GaiResolver, Name},
|
||||
HttpConnector,
|
||||
},
|
||||
Client,
|
||||
};
|
||||
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
|
||||
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
||||
use thiserror::Error;
|
||||
use tokio::{sync::OnceCell, task::JoinError};
|
||||
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
|
||||
use tower::{util::BoxCloneService, Service, ServiceBuilder, ServiceExt};
|
||||
|
||||
use crate::{
|
||||
layers::{
|
||||
client::{ClientLayer, ClientResponse},
|
||||
otel::{TraceDns, TraceLayer},
|
||||
},
|
||||
BoxError, FutureService,
|
||||
BoxError,
|
||||
};
|
||||
|
||||
#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))]
|
||||
compile_error!("enabling the 'client' feature requires also enabling the 'webpki-roots' or the 'native-roots' features");
|
||||
|
||||
#[cfg(all(feature = "webpki-roots", feature = "native-roots"))]
|
||||
compile_error!("'webpki-roots' and 'native-roots' features are mutually exclusive");
|
||||
|
||||
#[cfg(feature = "native-roots")]
|
||||
static NATIVE_TLS_ROOTS: tokio::sync::OnceCell<rustls::RootCertStore> =
|
||||
tokio::sync::OnceCell::const_new();
|
||||
|
||||
#[cfg(feature = "native-roots")]
|
||||
fn load_tls_roots_blocking() -> Result<rustls::RootCertStore, NativeRootsLoadError> {
|
||||
let mut roots = rustls::RootCertStore::empty();
|
||||
let certs = rustls_native_certs::load_native_certs()?;
|
||||
for cert in certs {
|
||||
let cert = rustls::Certificate(cert.0);
|
||||
roots.add(&cert)?;
|
||||
}
|
||||
|
||||
if roots.is_empty() {
|
||||
return Err(NativeRootsLoadError::Empty);
|
||||
}
|
||||
|
||||
Ok(roots)
|
||||
}
|
||||
|
||||
#[cfg(feature = "native-roots")]
|
||||
async fn tls_roots() -> Result<rustls::RootCertStore, NativeRootsInitError> {
|
||||
NATIVE_TLS_ROOTS
|
||||
.get_or_try_init(|| async move {
|
||||
// Load the TLS config once in a blocking task because loading the system
|
||||
// certificates can take a long time (~200ms) on macOS
|
||||
let span = tracing::info_span!("load_tls_roots");
|
||||
let roots = tokio::task::spawn_blocking(|| {
|
||||
let _span = span.entered();
|
||||
load_tls_roots_blocking()
|
||||
})
|
||||
.await??;
|
||||
Ok(roots)
|
||||
})
|
||||
.await
|
||||
.cloned()
|
||||
}
|
||||
|
||||
#[cfg(feature = "webpki-roots")]
|
||||
async fn tls_roots() -> Result<rustls::RootCertStore, Infallible> {
|
||||
let mut roots = rustls::RootCertStore::empty();
|
||||
roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
}));
|
||||
Ok(roots)
|
||||
}
|
||||
|
||||
#[cfg(feature = "native-roots")]
|
||||
#[derive(Error, Debug)]
|
||||
#[error(transparent)]
|
||||
pub enum NativeRootsInitError {
|
||||
RootsLoadError(#[from] NativeRootsLoadError),
|
||||
|
||||
JoinError(#[from] tokio::task::JoinError),
|
||||
}
|
||||
|
||||
/// A wrapper over a boxed error that implements ``std::error::Error``.
|
||||
/// This is helps converting to ``anyhow::Error`` with the `?` operator
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ClientError {
|
||||
#[error("failed to initialize HTTPS client")]
|
||||
Init(#[from] ClientInitError),
|
||||
|
||||
#[error(transparent)]
|
||||
Call(#[from] BoxError),
|
||||
#[error(transparent)]
|
||||
pub struct ClientError {
|
||||
#[from]
|
||||
inner: BoxError,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientInitError {
|
||||
#[error("failed to load system certificates")]
|
||||
CertificateLoad {
|
||||
#[from]
|
||||
inner: Arc<JoinError>, // That error is in an Arc to have the error implement Clone
|
||||
},
|
||||
#[cfg(feature = "native-roots")]
|
||||
#[error(transparent)]
|
||||
TlsRootsInit(std::sync::Arc<NativeRootsInitError>),
|
||||
}
|
||||
|
||||
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
|
||||
#[cfg(feature = "native-roots")]
|
||||
impl From<NativeRootsInitError> for ClientInitError {
|
||||
fn from(inner: NativeRootsInitError) -> Self {
|
||||
Self::TlsRootsInit(std::sync::Arc::new(inner))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Infallible> for ClientInitError {
|
||||
fn from(_: Infallible) -> Self {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "native-roots")]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum NativeRootsLoadError {
|
||||
#[error("could not load root certificates")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("invalid root certificate")]
|
||||
Webpki(#[from] webpki::Error),
|
||||
|
||||
#[error("no root certificate loaded")]
|
||||
Empty,
|
||||
}
|
||||
|
||||
/// Create a basic Hyper HTTP & HTTPS client without any tracing
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if it failed to load the TLS certificates
|
||||
pub async fn make_untraced_client<B, E>(
|
||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector<GaiResolver>>, B>, ClientInitError>
|
||||
where
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||
E: Into<BoxError>,
|
||||
{
|
||||
let resolver = GaiResolver::new();
|
||||
let roots = tls_roots().await?;
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(roots)
|
||||
.with_no_client_auth();
|
||||
|
||||
Ok(make_client(resolver, tls_config))
|
||||
}
|
||||
|
||||
async fn make_base_client<B, E>(
|
||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
|
||||
@ -68,54 +176,57 @@ where
|
||||
.layer(TraceLayer::dns())
|
||||
.service(GaiResolver::new());
|
||||
|
||||
let roots = tls_roots().await?;
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(roots)
|
||||
.with_no_client_auth();
|
||||
|
||||
Ok(make_client(resolver, tls_config))
|
||||
}
|
||||
|
||||
fn make_client<R, B, E>(
|
||||
resolver: R,
|
||||
tls_config: rustls::ClientConfig,
|
||||
) -> hyper::Client<HttpsConnector<HttpConnector<R>>, B>
|
||||
where
|
||||
R: Service<Name> + Send + Sync + Clone + 'static,
|
||||
R::Error: std::error::Error + Send + Sync,
|
||||
R::Future: Send,
|
||||
R::Response: Iterator<Item = SocketAddr>,
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||
E: Into<BoxError>,
|
||||
{
|
||||
let mut http = HttpConnector::new_with_resolver(resolver);
|
||||
http.enforce_http(false);
|
||||
|
||||
let tls_config = TLS_CONFIG
|
||||
.get_or_try_init(|| async move {
|
||||
// Load the TLS config once in a blocking task because loading the system
|
||||
// certificates can take a long time (~200ms) on macOS
|
||||
let span = tracing::info_span!("load_certificates");
|
||||
tokio::task::spawn_blocking(|| {
|
||||
let _span = span.entered();
|
||||
rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_native_roots()
|
||||
.with_no_client_auth()
|
||||
})
|
||||
.await
|
||||
})
|
||||
.await
|
||||
.map_err(|e| ClientInitError::from(Arc::new(e)))?;
|
||||
|
||||
let https = HttpsConnectorBuilder::new()
|
||||
.with_tls_config(tls_config.clone())
|
||||
.with_tls_config(tls_config)
|
||||
.https_or_http()
|
||||
.enable_http1()
|
||||
.enable_http2()
|
||||
.wrap_connector(http);
|
||||
|
||||
// TODO: we should get the remote address here
|
||||
let client = Client::builder().build(https);
|
||||
|
||||
Ok::<_, ClientInitError>(client)
|
||||
Client::builder().build(https)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn client<B, E>(
|
||||
/// Create a traced HTTP client, with a default timeout, which follows redirects
|
||||
/// and handles compression
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if it failed to initialize
|
||||
pub async fn client<B, E>(
|
||||
operation: &'static str,
|
||||
) -> BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>
|
||||
) -> Result<
|
||||
BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>,
|
||||
ClientInitError,
|
||||
>
|
||||
where
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
||||
E: Into<BoxError> + 'static,
|
||||
{
|
||||
let fut = make_base_client()
|
||||
// Map the error to a ClientError
|
||||
.map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e))))
|
||||
// Wrap it in an Shared (Arc) to be able to Clone it
|
||||
.shared();
|
||||
|
||||
let client: FutureService<_, _> = FutureService::new(fut);
|
||||
let client = make_base_client().await?;
|
||||
|
||||
let client = ServiceBuilder::new()
|
||||
// Convert the errors to ClientError to help dealing with them
|
||||
@ -124,7 +235,8 @@ where
|
||||
r.map(|body| body.map_err(ClientError::from).boxed())
|
||||
})
|
||||
.layer(ClientLayer::new(operation))
|
||||
.service(client);
|
||||
.service(client)
|
||||
.boxed_clone();
|
||||
|
||||
client.boxed_clone()
|
||||
Ok(client)
|
||||
}
|
||||
|
@ -1,77 +0,0 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! A copy of [`tower::util::FutureService`] that also maps the future error to
|
||||
//! help implementing [`Clone`] on the service
|
||||
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures_util::ready;
|
||||
use tower::Service;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct FutureService<F, S> {
|
||||
state: State<F, S>,
|
||||
}
|
||||
|
||||
impl<F, S> FutureService<F, S> {
|
||||
#[must_use]
|
||||
pub fn new(future: F) -> Self {
|
||||
Self {
|
||||
state: State::Future(future),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum State<F, S> {
|
||||
Future(F),
|
||||
Service(S),
|
||||
}
|
||||
|
||||
impl<F, S, R, FE, E> Service<R> for FutureService<F, S>
|
||||
where
|
||||
F: Future<Output = Result<S, FE>> + Unpin,
|
||||
S: Service<R, Error = E>,
|
||||
E: From<FE>,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = E;
|
||||
type Future = S::Future;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
loop {
|
||||
self.state = match &mut self.state {
|
||||
State::Future(fut) => {
|
||||
let fut = Pin::new(fut);
|
||||
let svc = ready!(fut.poll(cx)?);
|
||||
State::Service(svc)
|
||||
}
|
||||
State::Service(svc) => return svc.poll_ready(cx),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn call(&mut self, req: R) -> Self::Future {
|
||||
if let State::Service(svc) = &mut self.state {
|
||||
svc.call(req)
|
||||
} else {
|
||||
panic!("FutureService::call was called before FutureService::poll_ready")
|
||||
}
|
||||
}
|
||||
}
|
@ -27,17 +27,15 @@
|
||||
#[cfg(feature = "client")]
|
||||
mod client;
|
||||
mod ext;
|
||||
mod future_service;
|
||||
mod layers;
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
pub use self::client::client;
|
||||
pub use self::client::{client, make_untraced_client};
|
||||
pub use self::{
|
||||
ext::{
|
||||
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
|
||||
ServiceExt as HttpServiceExt,
|
||||
},
|
||||
future_service::FutureService,
|
||||
layers::{
|
||||
body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer},
|
||||
bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer},
|
||||
|
Reference in New Issue
Block a user