1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Cleanup HTTP client building

This commit is contained in:
Quentin Gliech
2022-11-23 12:26:58 +01:00
parent 16088fc11c
commit d514a8922c
4 changed files with 71 additions and 72 deletions

View File

@ -12,27 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{convert::Infallible, net::SocketAddr};
use std::convert::Infallible;
use bytes::Bytes;
use http::{Request, Response};
use hyper::{
client::{
connect::dns::{GaiResolver, Name},
HttpConnector,
},
client::{connect::dns::GaiResolver, HttpConnector},
Client,
};
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use thiserror::Error;
use tower::{Layer, Service};
use tower::Layer;
use crate::{
layers::{
client::ClientLayer,
client::{ClientLayer, ClientService},
otel::{TraceDns, TraceLayer},
},
BoxCloneSyncService, BoxError,
BoxError,
};
#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))]
@ -154,38 +149,41 @@ async fn make_tls_config() -> Result<rustls::ClientConfig, ClientInitError> {
Ok(tls_config)
}
type UntracedClient<B> = hyper::Client<UntracedConnector, B>;
type TracedClient<B> = hyper::Client<TracedConnector, B>;
/// Create a basic Hyper HTTP & HTTPS client without any tracing
///
/// # Errors
///
/// Returns an error if it failed to load the TLS certificates
pub async fn make_untraced_client<B, E>(
) -> Result<hyper::Client<HttpsConnector<HttpConnector<GaiResolver>>, B>, ClientInitError>
pub async fn make_untraced_client<B>() -> Result<UntracedClient<B>, ClientInitError>
where
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
E: Into<BoxError>,
B: http_body::Body + Send + 'static,
B::Data: Send,
{
let https = make_untraced_connector().await?;
Ok(Client::builder().build(https))
}
async fn make_traced_client<B, E>(
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
async fn make_traced_client<B>() -> Result<TracedClient<B>, ClientInitError>
where
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
E: Into<BoxError>,
B: http_body::Body + Send + 'static,
B::Data: Send,
{
let https = make_traced_connector().await?;
Ok(Client::builder().build(https))
}
type UntracedConnector = HttpsConnector<HttpConnector<GaiResolver>>;
type TracedConnector = HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>;
/// Create a traced HTTP and HTTPS connector
///
/// # Errors
///
/// Returns an error if it failed to load the TLS certificates
pub async fn make_traced_connector(
) -> Result<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, ClientInitError>
pub async fn make_traced_connector() -> Result<TracedConnector, ClientInitError>
where
{
// Trace DNS requests
@ -194,8 +192,7 @@ where
Ok(make_connector(resolver, tls_config))
}
async fn make_untraced_connector(
) -> Result<HttpsConnector<HttpConnector<GaiResolver>>, ClientInitError>
async fn make_untraced_connector() -> Result<UntracedConnector, ClientInitError>
where
{
let resolver = GaiResolver::new();
@ -206,13 +203,7 @@ where
fn make_connector<R>(
resolver: R,
tls_config: rustls::ClientConfig,
) -> HttpsConnector<HttpConnector<R>>
where
R: Service<Name> + Send + Sync + Clone + 'static,
R::Error: std::error::Error + Send + Sync,
R::Future: Send,
R::Response: Iterator<Item = SocketAddr>,
{
) -> HttpsConnector<HttpConnector<R>> {
let mut http = HttpConnector::new_with_resolver(resolver);
http.enforce_http(false);
@ -229,16 +220,17 @@ where
/// # Errors
///
/// Returns an error if it failed to initialize
pub async fn client<B, E>(
pub async fn client<B>(
operation: &'static str,
) -> Result<BoxCloneSyncService<Request<B>, Response<hyper::Body>, hyper::Error>, ClientInitError>
) -> Result<ClientService<TracedClient<B>>, ClientInitError>
where
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
E: Into<BoxError> + 'static,
B: http_body::Body + Default + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
let client = make_traced_client().await?;
let client = ClientLayer::new(operation).layer(client);
Ok(BoxCloneSyncService::new(client))
Ok(client)
}

View File

@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{marker::PhantomData, time::Duration};
use std::{sync::Arc, time::Duration};
use http::{header::USER_AGENT, HeaderValue, Request, Response};
use http::{header::USER_AGENT, HeaderValue};
use tokio::sync::Semaphore;
use tower::{
limit::{ConcurrencyLimit, ConcurrencyLimitLayer},
Layer, Service,
limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer},
Layer,
};
use tower_http::{
follow_redirect::{FollowRedirect, FollowRedirectLayer},
@ -26,56 +27,60 @@ use tower_http::{
};
use super::otel::TraceLayer;
use crate::{otel::TraceHttpClient, BoxError};
use crate::otel::{TraceHttpClient, TraceHttpClientLayer};
static MAS_USER_AGENT: HeaderValue =
HeaderValue::from_static("matrix-authentication-service/0.0.1");
pub type ClientService<S> = SetRequestHeader<
TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
HeaderValue,
>;
#[derive(Debug, Clone)]
pub struct ClientLayer<ReqBody> {
operation: &'static str,
_t: PhantomData<ReqBody>,
pub struct ClientLayer {
user_agent_layer: SetRequestHeaderLayer<HeaderValue>,
outer_trace_layer: TraceHttpClientLayer,
concurrency_limit_layer: GlobalConcurrencyLimitLayer,
follow_redirect_layer: FollowRedirectLayer,
inner_trace_layer: TraceHttpClientLayer,
timeout_layer: TimeoutLayer,
}
impl<B> ClientLayer<B> {
impl ClientLayer {
#[must_use]
pub fn new(operation: &'static str) -> Self {
let semaphore = Arc::new(Semaphore::new(10));
Self::with_semaphore(operation, semaphore)
}
#[must_use]
pub fn with_semaphore(operation: &'static str, semaphore: Arc<Semaphore>) -> Self {
Self {
operation,
_t: PhantomData,
user_agent_layer: SetRequestHeaderLayer::overriding(
USER_AGENT,
HeaderValue::from_static("matrix-authentication-service/0.0.1"),
),
outer_trace_layer: TraceLayer::http_client(operation),
concurrency_limit_layer: GlobalConcurrencyLimitLayer::with_semaphore(semaphore),
follow_redirect_layer: FollowRedirectLayer::new(),
inner_trace_layer: TraceLayer::inner_http_client(),
timeout_layer: TimeoutLayer::new(Duration::from_secs(10)),
}
}
}
impl<ReqBody, ResBody, S, E> Layer<S> for ClientLayer<ReqBody>
impl<S> Layer<S> for ClientLayer
where
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E>
+ Clone
+ Send
+ Sync
+ 'static,
ReqBody: http_body::Body + Default + Send + 'static,
ResBody: http_body::Body + Sync + Send + 'static,
S::Future: Send + 'static,
E: Into<BoxError>,
S: Clone,
{
type Service = SetRequestHeader<
TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
HeaderValue,
>;
type Service = ClientService<S>;
fn layer(&self, inner: S) -> Self::Service {
// Note that all layers here just forward the error type.
(
SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()),
// A trace that has the whole operation, with all the redirects, timeouts and rate
// limits in it
TraceLayer::http_client(self.operation),
ConcurrencyLimitLayer::new(10),
FollowRedirectLayer::new(),
// A trace for each "real" http request
TraceLayer::inner_http_client(),
TimeoutLayer::new(Duration::from_secs(10)),
&self.user_agent_layer,
&self.outer_trace_layer,
&self.concurrency_limit_layer,
&self.follow_redirect_layer,
&self.inner_trace_layer,
&self.timeout_layer,
)
.layer(inner)
}

View File

@ -49,3 +49,5 @@ pub use self::{
};
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub type EmptyBody = http_body::Empty<bytes::Bytes>;