You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Gate some crates behind features in mas-http
This commit is contained in:
@ -36,4 +36,4 @@ mas-storage = { path = "../storage" }
|
||||
mas-data-model = { path = "../data-model" }
|
||||
mas-jose = { path = "../jose" }
|
||||
mas-iana = { path = "../iana" }
|
||||
mas-http = { path = "../http" }
|
||||
mas-http = { path = "../http", features = ["client"] }
|
||||
|
@ -35,7 +35,7 @@ opentelemetry-zipkin = { version = "0.15.0", features = ["reqwest-client", "reqw
|
||||
mas-config = { path = "../config" }
|
||||
mas-email = { path = "../email" }
|
||||
mas-handlers = { path = "../handlers" }
|
||||
mas-http = { path = "../http" }
|
||||
mas-http = { path = "../http", features = ["axum"] }
|
||||
mas-policy = { path = "../policy" }
|
||||
mas-router = { path = "../router" }
|
||||
mas-static-files = { path = "../static-files" }
|
||||
|
@ -6,14 +6,14 @@ edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
axum = "0.5.13"
|
||||
axum = { version = "0.5.13", optional = true }
|
||||
bytes = "1.2.1"
|
||||
futures-util = "0.3.21"
|
||||
headers = "0.3.7"
|
||||
http = "0.2.8"
|
||||
http-body = "0.4.5"
|
||||
hyper = "0.14.20"
|
||||
hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false }
|
||||
hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false, optional = true }
|
||||
once_cell = "1.13.0"
|
||||
opentelemetry = "0.17.0"
|
||||
opentelemetry-http = "0.6.0"
|
||||
@ -23,14 +23,19 @@ serde = "1.0.142"
|
||||
serde_json = "1.0.83"
|
||||
serde_urlencoded = "0.7.1"
|
||||
thiserror = "1.0.32"
|
||||
tokio = { version = "1.20.1", features = ["sync", "parking_lot"] }
|
||||
tokio = { version = "1.20.1", optional = true }
|
||||
tower = { version = "0.4.13", features = ["timeout", "limit"] }
|
||||
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors"] }
|
||||
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] }
|
||||
tracing = "0.1.36"
|
||||
tracing-opentelemetry = "0.17.4"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1.0.62"
|
||||
serde = { version = "1.0.142", features = ["derive"] }
|
||||
tokio = { version = "1.20.1", features = ["macros"] }
|
||||
tokio = { version = "1.20.1", features = ["macros", "rt"] }
|
||||
tower = { version = "0.4.13", features = ["util"] }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
axum = ["dep:axum"]
|
||||
client = ["dep:hyper-rustls", "hyper/tcp", "tokio", "tokio/sync", "tokio/parking_lot"]
|
||||
|
130
crates/http/src/client.rs
Normal file
130
crates/http/src/client.rs
Normal file
@ -0,0 +1,130 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures_util::{FutureExt, TryFutureExt};
|
||||
use http::{Request, Response};
|
||||
use http_body::{combinators::BoxBody, Body};
|
||||
use hyper::{
|
||||
client::{connect::dns::GaiResolver, HttpConnector},
|
||||
Client,
|
||||
};
|
||||
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
|
||||
use thiserror::Error;
|
||||
use tokio::{sync::OnceCell, task::JoinError};
|
||||
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
|
||||
|
||||
use crate::{
|
||||
layers::{
|
||||
client::{ClientLayer, ClientResponse},
|
||||
otel::{TraceDns, TraceLayer},
|
||||
},
|
||||
BoxError, FutureService,
|
||||
};
|
||||
|
||||
/// A wrapper over a boxed error that implements ``std::error::Error``.
|
||||
/// This is helps converting to ``anyhow::Error`` with the `?` operator
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ClientError {
|
||||
#[error("failed to initialize HTTPS client")]
|
||||
Init(#[from] ClientInitError),
|
||||
|
||||
#[error(transparent)]
|
||||
Call(#[from] BoxError),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientInitError {
|
||||
#[error("failed to load system certificates")]
|
||||
CertificateLoad {
|
||||
#[from]
|
||||
inner: Arc<JoinError>, // That error is in an Arc to have the error implement Clone
|
||||
},
|
||||
}
|
||||
|
||||
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
|
||||
|
||||
async fn make_base_client<B, E>(
|
||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
|
||||
where
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||
E: Into<BoxError>,
|
||||
{
|
||||
// Trace DNS requests
|
||||
let resolver = ServiceBuilder::new()
|
||||
.layer(TraceLayer::dns())
|
||||
.service(GaiResolver::new());
|
||||
|
||||
let mut http = HttpConnector::new_with_resolver(resolver);
|
||||
http.enforce_http(false);
|
||||
|
||||
let tls_config = TLS_CONFIG
|
||||
.get_or_try_init(|| async move {
|
||||
// Load the TLS config once in a blocking task because loading the system
|
||||
// certificates can take a long time (~200ms) on macOS
|
||||
let span = tracing::info_span!("load_certificates");
|
||||
tokio::task::spawn_blocking(|| {
|
||||
let _span = span.entered();
|
||||
rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_native_roots()
|
||||
.with_no_client_auth()
|
||||
})
|
||||
.await
|
||||
})
|
||||
.await
|
||||
.map_err(|e| ClientInitError::from(Arc::new(e)))?;
|
||||
|
||||
let https = HttpsConnectorBuilder::new()
|
||||
.with_tls_config(tls_config.clone())
|
||||
.https_or_http()
|
||||
.enable_http1()
|
||||
.enable_http2()
|
||||
.wrap_connector(http);
|
||||
|
||||
// TODO: we should get the remote address here
|
||||
let client = Client::builder().build(https);
|
||||
|
||||
Ok::<_, ClientInitError>(client)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn client<B, E>(
|
||||
operation: &'static str,
|
||||
) -> BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>
|
||||
where
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
||||
E: Into<BoxError> + 'static,
|
||||
{
|
||||
let fut = make_base_client()
|
||||
// Map the error to a ClientError
|
||||
.map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e))))
|
||||
// Wrap it in an Shared (Arc) to be able to Clone it
|
||||
.shared();
|
||||
|
||||
let client: FutureService<_, _> = FutureService::new(fut);
|
||||
|
||||
let client = ServiceBuilder::new()
|
||||
// Convert the errors to ClientError to help dealing with them
|
||||
.map_err(ClientError::from)
|
||||
.map_response(|r: ClientResponse<hyper::Body>| {
|
||||
r.map(|body| body.map_err(ClientError::from).boxed())
|
||||
})
|
||||
.layer(ClientLayer::new(operation))
|
||||
.service(client);
|
||||
|
||||
client.boxed_clone()
|
||||
}
|
@ -12,17 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{borrow::Cow, net::SocketAddr};
|
||||
use std::borrow::Cow;
|
||||
|
||||
#[cfg(feature = "axum")]
|
||||
use axum::extract::{ConnectInfo, MatchedPath};
|
||||
use headers::{ContentLength, HeaderMapExt, Host, UserAgent};
|
||||
use http::{Method, Request, Version};
|
||||
#[cfg(feature = "client")]
|
||||
use hyper::client::connect::dns::Name;
|
||||
use opentelemetry::trace::{SpanBuilder, SpanKind};
|
||||
use opentelemetry_semantic_conventions::trace::{
|
||||
HTTP_FLAVOR, HTTP_HOST, HTTP_METHOD, HTTP_REQUEST_CONTENT_LENGTH, HTTP_ROUTE, HTTP_TARGET,
|
||||
HTTP_USER_AGENT, NET_HOST_NAME, NET_PEER_IP, NET_PEER_PORT, NET_TRANSPORT,
|
||||
};
|
||||
use opentelemetry_semantic_conventions::trace as SC;
|
||||
|
||||
pub trait MakeSpanBuilder<R> {
|
||||
fn make_span_builder(&self, request: &R) -> SpanBuilder;
|
||||
@ -117,24 +116,24 @@ impl SpanFromHttpRequest {
|
||||
impl<B> MakeSpanBuilder<Request<B>> for SpanFromHttpRequest {
|
||||
fn make_span_builder(&self, request: &Request<B>) -> SpanBuilder {
|
||||
let mut attributes = vec![
|
||||
HTTP_METHOD.string(http_method_str(request.method())),
|
||||
HTTP_FLAVOR.string(http_flavor(request.version())),
|
||||
HTTP_TARGET.string(request.uri().to_string()),
|
||||
SC::HTTP_METHOD.string(http_method_str(request.method())),
|
||||
SC::HTTP_FLAVOR.string(http_flavor(request.version())),
|
||||
SC::HTTP_TARGET.string(request.uri().to_string()),
|
||||
];
|
||||
|
||||
let headers = request.headers();
|
||||
|
||||
if let Some(host) = headers.typed_get::<Host>() {
|
||||
attributes.push(HTTP_HOST.string(host.to_string()));
|
||||
attributes.push(SC::HTTP_HOST.string(host.to_string()));
|
||||
}
|
||||
|
||||
if let Some(user_agent) = headers.typed_get::<UserAgent>() {
|
||||
attributes.push(HTTP_USER_AGENT.string(user_agent.to_string()));
|
||||
attributes.push(SC::HTTP_USER_AGENT.string(user_agent.to_string()));
|
||||
}
|
||||
|
||||
if let Some(ContentLength(content_length)) = headers.typed_get() {
|
||||
if let Ok(content_length) = content_length.try_into() {
|
||||
attributes.push(HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
|
||||
attributes.push(SC::HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
|
||||
}
|
||||
}
|
||||
|
||||
@ -144,42 +143,47 @@ impl<B> MakeSpanBuilder<Request<B>> for SpanFromHttpRequest {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "axum")]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SpanFromAxumRequest;
|
||||
|
||||
#[cfg(feature = "axum")]
|
||||
impl<B> MakeSpanBuilder<Request<B>> for SpanFromAxumRequest {
|
||||
fn make_span_builder(&self, request: &Request<B>) -> SpanBuilder {
|
||||
let mut attributes = vec![
|
||||
HTTP_METHOD.string(http_method_str(request.method())),
|
||||
HTTP_FLAVOR.string(http_flavor(request.version())),
|
||||
HTTP_TARGET.string(request.uri().to_string()),
|
||||
SC::HTTP_METHOD.string(http_method_str(request.method())),
|
||||
SC::HTTP_FLAVOR.string(http_flavor(request.version())),
|
||||
SC::HTTP_TARGET.string(request.uri().to_string()),
|
||||
];
|
||||
|
||||
let headers = request.headers();
|
||||
|
||||
if let Some(host) = headers.typed_get::<Host>() {
|
||||
attributes.push(HTTP_HOST.string(host.to_string()));
|
||||
attributes.push(SC::HTTP_HOST.string(host.to_string()));
|
||||
}
|
||||
|
||||
if let Some(user_agent) = headers.typed_get::<UserAgent>() {
|
||||
attributes.push(HTTP_USER_AGENT.string(user_agent.to_string()));
|
||||
attributes.push(SC::HTTP_USER_AGENT.string(user_agent.to_string()));
|
||||
}
|
||||
|
||||
if let Some(ContentLength(content_length)) = headers.typed_get() {
|
||||
if let Ok(content_length) = content_length.try_into() {
|
||||
attributes.push(HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
|
||||
attributes.push(SC::HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ConnectInfo(addr)) = request.extensions().get::<ConnectInfo<SocketAddr>>() {
|
||||
attributes.push(NET_TRANSPORT.string("ip_tcp"));
|
||||
attributes.push(NET_PEER_IP.string(addr.ip().to_string()));
|
||||
attributes.push(NET_PEER_PORT.i64(addr.port().into()));
|
||||
if let Some(ConnectInfo(addr)) = request
|
||||
.extensions()
|
||||
.get::<ConnectInfo<std::net::SocketAddr>>()
|
||||
{
|
||||
attributes.push(SC::NET_TRANSPORT.string("ip_tcp"));
|
||||
attributes.push(SC::NET_PEER_IP.string(addr.ip().to_string()));
|
||||
attributes.push(SC::NET_PEER_PORT.i64(addr.port().into()));
|
||||
}
|
||||
|
||||
let name = if let Some(path) = request.extensions().get::<MatchedPath>() {
|
||||
let path = path.as_str().to_owned();
|
||||
attributes.push(HTTP_ROUTE.string(path.clone()));
|
||||
attributes.push(SC::HTTP_ROUTE.string(path.clone()));
|
||||
path
|
||||
} else {
|
||||
request.uri().path().to_owned()
|
||||
@ -191,12 +195,14 @@ impl<B> MakeSpanBuilder<Request<B>> for SpanFromAxumRequest {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct SpanFromDnsRequest;
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
impl MakeSpanBuilder<Name> for SpanFromDnsRequest {
|
||||
fn make_span_builder(&self, request: &Name) -> SpanBuilder {
|
||||
let attributes = vec![NET_HOST_NAME.string(request.as_str().to_owned())];
|
||||
let attributes = vec![SC::NET_HOST_NAME.string(request.as_str().to_owned())];
|
||||
|
||||
SpanBuilder::from_name("resolve")
|
||||
.with_kind(SpanKind::Client)
|
||||
|
@ -37,6 +37,7 @@ pub type TraceHttpServer<S> = Trace<
|
||||
S,
|
||||
>;
|
||||
|
||||
#[cfg(feature = "axum")]
|
||||
pub type TraceAxumServerLayer = TraceLayer<
|
||||
ExtractFromHttpRequest,
|
||||
DefaultInjectContext,
|
||||
@ -45,6 +46,7 @@ pub type TraceAxumServerLayer = TraceLayer<
|
||||
DefaultOnError,
|
||||
>;
|
||||
|
||||
#[cfg(feature = "axum")]
|
||||
pub type TraceAxumServer<S> = Trace<
|
||||
ExtractFromHttpRequest,
|
||||
DefaultInjectContext,
|
||||
@ -71,6 +73,7 @@ pub type TraceHttpClient<S> = Trace<
|
||||
S,
|
||||
>;
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
pub type TraceDnsLayer = TraceLayer<
|
||||
DefaultExtractContext,
|
||||
DefaultInjectContext,
|
||||
@ -79,6 +82,7 @@ pub type TraceDnsLayer = TraceLayer<
|
||||
DefaultOnError,
|
||||
>;
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
pub type TraceDns<S> = Trace<
|
||||
DefaultExtractContext,
|
||||
DefaultInjectContext,
|
||||
@ -98,6 +102,7 @@ impl TraceHttpServerLayer {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "axum")]
|
||||
impl TraceAxumServerLayer {
|
||||
#[must_use]
|
||||
pub fn axum() -> Self {
|
||||
@ -126,6 +131,7 @@ impl TraceHttpClientLayer {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
impl TraceDnsLayer {
|
||||
#[must_use]
|
||||
pub fn dns() -> Self {
|
||||
|
@ -14,12 +14,10 @@
|
||||
|
||||
use headers::{ContentLength, HeaderMapExt};
|
||||
use http::Response;
|
||||
#[cfg(feature = "client")]
|
||||
use hyper::client::connect::HttpInfo;
|
||||
use opentelemetry::trace::SpanRef;
|
||||
use opentelemetry_semantic_conventions::trace::{
|
||||
HTTP_RESPONSE_CONTENT_LENGTH, HTTP_STATUS_CODE, NET_HOST_IP, NET_HOST_PORT, NET_PEER_IP,
|
||||
NET_PEER_PORT,
|
||||
};
|
||||
use opentelemetry_semantic_conventions::trace as SC;
|
||||
|
||||
pub trait OnResponse<R> {
|
||||
fn on_response(&self, span: &SpanRef<'_>, response: &R);
|
||||
@ -37,21 +35,22 @@ pub struct OnHttpResponse;
|
||||
|
||||
impl<B> OnResponse<Response<B>> for OnHttpResponse {
|
||||
fn on_response(&self, span: &SpanRef<'_>, response: &Response<B>) {
|
||||
span.set_attribute(HTTP_STATUS_CODE.i64(i64::from(response.status().as_u16())));
|
||||
span.set_attribute(SC::HTTP_STATUS_CODE.i64(i64::from(response.status().as_u16())));
|
||||
|
||||
if let Some(ContentLength(content_length)) = response.headers().typed_get() {
|
||||
if let Ok(content_length) = content_length.try_into() {
|
||||
span.set_attribute(HTTP_RESPONSE_CONTENT_LENGTH.i64(content_length));
|
||||
span.set_attribute(SC::HTTP_RESPONSE_CONTENT_LENGTH.i64(content_length));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
// Get local and remote address from hyper's HttpInfo injected by the
|
||||
// HttpConnector
|
||||
if let Some(info) = response.extensions().get::<HttpInfo>() {
|
||||
span.set_attribute(NET_PEER_IP.string(info.remote_addr().ip().to_string()));
|
||||
span.set_attribute(NET_PEER_PORT.i64(info.remote_addr().port().into()));
|
||||
span.set_attribute(NET_HOST_IP.string(info.local_addr().ip().to_string()));
|
||||
span.set_attribute(NET_HOST_PORT.i64(info.local_addr().port().into()));
|
||||
span.set_attribute(SC::NET_PEER_IP.string(info.remote_addr().ip().to_string()));
|
||||
span.set_attribute(SC::NET_PEER_PORT.i64(info.remote_addr().port().into()));
|
||||
span.set_attribute(SC::NET_HOST_IP.string(info.local_addr().ip().to_string()));
|
||||
span.set_attribute(SC::NET_HOST_PORT.i64(info.local_addr().port().into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -37,10 +37,14 @@ where
|
||||
type Service = BoxCloneService<Request<ReqBody>, Response<CompressionBody<ResBody>>, S::Error>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
ServiceBuilder::new()
|
||||
.compression()
|
||||
.layer(TraceLayer::axum())
|
||||
.service(inner)
|
||||
.boxed_clone()
|
||||
let builder = ServiceBuilder::new().compression();
|
||||
|
||||
#[cfg(feature = "axum")]
|
||||
let builder = builder.layer(TraceLayer::axum());
|
||||
|
||||
#[cfg(not(feature = "axum"))]
|
||||
let builder = builder.layer(TraceLayer::http_server());
|
||||
|
||||
builder.service(inner).boxed_clone()
|
||||
}
|
||||
}
|
||||
|
@ -24,30 +24,14 @@
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures_util::{FutureExt, TryFutureExt};
|
||||
use http::{Request, Response};
|
||||
use http_body::{combinators::BoxBody, Body};
|
||||
use hyper::{
|
||||
client::{connect::dns::GaiResolver, HttpConnector},
|
||||
Client,
|
||||
};
|
||||
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
|
||||
use thiserror::Error;
|
||||
use tokio::{sync::OnceCell, task::JoinError};
|
||||
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
|
||||
|
||||
use self::layers::{
|
||||
client::ClientResponse,
|
||||
otel::{TraceDns, TraceLayer},
|
||||
};
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
mod client;
|
||||
mod ext;
|
||||
mod future_service;
|
||||
mod layers;
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
pub use self::client::client;
|
||||
pub use self::{
|
||||
ext::{
|
||||
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
|
||||
@ -67,97 +51,3 @@ pub use self::{
|
||||
};
|
||||
|
||||
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
/// A wrapper over a boxed error that implements ``std::error::Error``.
|
||||
/// This is helps converting to ``anyhow::Error`` with the `?` operator
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ClientError {
|
||||
#[error("failed to initialize HTTPS client")]
|
||||
Init(#[from] ClientInitError),
|
||||
|
||||
#[error(transparent)]
|
||||
Call(#[from] BoxError),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientInitError {
|
||||
#[error("failed to load system certificates")]
|
||||
CertificateLoad {
|
||||
#[from]
|
||||
inner: Arc<JoinError>, // That error is in an Arc to have the error implement Clone
|
||||
},
|
||||
}
|
||||
|
||||
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
|
||||
|
||||
async fn make_base_client<B, E>(
|
||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
|
||||
where
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||
E: Into<BoxError>,
|
||||
{
|
||||
// Trace DNS requests
|
||||
let resolver = ServiceBuilder::new()
|
||||
.layer(TraceLayer::dns())
|
||||
.service(GaiResolver::new());
|
||||
|
||||
let mut http = HttpConnector::new_with_resolver(resolver);
|
||||
http.enforce_http(false);
|
||||
|
||||
let tls_config = TLS_CONFIG
|
||||
.get_or_try_init(|| async move {
|
||||
// Load the TLS config once in a blocking task because loading the system
|
||||
// certificates can take a long time (~200ms) on macOS
|
||||
let span = tracing::info_span!("load_certificates");
|
||||
tokio::task::spawn_blocking(|| {
|
||||
let _span = span.entered();
|
||||
rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_native_roots()
|
||||
.with_no_client_auth()
|
||||
})
|
||||
.await
|
||||
})
|
||||
.await
|
||||
.map_err(|e| ClientInitError::from(Arc::new(e)))?;
|
||||
|
||||
let https = HttpsConnectorBuilder::new()
|
||||
.with_tls_config(tls_config.clone())
|
||||
.https_or_http()
|
||||
.enable_http1()
|
||||
.enable_http2()
|
||||
.wrap_connector(http);
|
||||
|
||||
// TODO: we should get the remote address here
|
||||
let client = Client::builder().build(https);
|
||||
|
||||
Ok::<_, ClientInitError>(client)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn client<B, E>(
|
||||
operation: &'static str,
|
||||
) -> BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>
|
||||
where
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
||||
E: Into<BoxError> + 'static,
|
||||
{
|
||||
let fut = make_base_client()
|
||||
// Map the error to a ClientError
|
||||
.map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e))))
|
||||
// Wrap it in an Shared (Arc) to be able to Clone it
|
||||
.shared();
|
||||
|
||||
let client: FutureService<_, _> = FutureService::new(fut);
|
||||
|
||||
let client = ServiceBuilder::new()
|
||||
// Convert the errors to ClientError to help dealing with them
|
||||
.map_err(ClientError::from)
|
||||
.map_response(|r: ClientResponse<hyper::Body>| {
|
||||
r.map(|body| body.map_err(ClientError::from).boxed())
|
||||
})
|
||||
.layer(ClientLayer::new(operation))
|
||||
.service(client);
|
||||
|
||||
client.boxed_clone()
|
||||
}
|
||||
|
Reference in New Issue
Block a user