From d514a8922cab5111a2e0353f2517414824bc493a Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 23 Nov 2022 12:26:58 +0100 Subject: [PATCH] Cleanup HTTP client building --- crates/axum-utils/src/client_authorization.rs | 2 +- crates/http/src/client.rs | 60 ++++++-------- crates/http/src/layers/client.rs | 79 ++++++++++--------- crates/http/src/lib.rs | 2 + 4 files changed, 71 insertions(+), 72 deletions(-) diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 6835658f..6710dd35 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -174,7 +174,7 @@ async fn fetch_jwks(jwks: &JwksOrJwksUri) -> Result Result { Ok(tls_config) } +type UntracedClient = hyper::Client; +type TracedClient = hyper::Client; + /// Create a basic Hyper HTTP & HTTPS client without any tracing /// /// # Errors /// /// Returns an error if it failed to load the TLS certificates -pub async fn make_untraced_client( -) -> Result>, B>, ClientInitError> +pub async fn make_untraced_client() -> Result, ClientInitError> where - B: http_body::Body + Send + 'static, - E: Into, + B: http_body::Body + Send + 'static, + B::Data: Send, { let https = make_untraced_connector().await?; Ok(Client::builder().build(https)) } -async fn make_traced_client( -) -> Result>>, B>, ClientInitError> +async fn make_traced_client() -> Result, ClientInitError> where - B: http_body::Body + Send + 'static, - E: Into, + B: http_body::Body + Send + 'static, + B::Data: Send, { let https = make_traced_connector().await?; Ok(Client::builder().build(https)) } +type UntracedConnector = HttpsConnector>; +type TracedConnector = HttpsConnector>>; + /// Create a traced HTTP and HTTPS connector /// /// # Errors /// /// Returns an error if it failed to load the TLS certificates -pub async fn make_traced_connector( -) -> Result>>, ClientInitError> +pub async fn make_traced_connector() -> Result where { // Trace DNS requests @@ -194,8 +192,7 @@ where Ok(make_connector(resolver, tls_config)) } -async fn make_untraced_connector( -) -> Result>, ClientInitError> +async fn make_untraced_connector() -> Result where { let resolver = GaiResolver::new(); @@ -206,13 +203,7 @@ where fn make_connector( resolver: R, tls_config: rustls::ClientConfig, -) -> HttpsConnector> -where - R: Service + Send + Sync + Clone + 'static, - R::Error: std::error::Error + Send + Sync, - R::Future: Send, - R::Response: Iterator, -{ +) -> HttpsConnector> { let mut http = HttpConnector::new_with_resolver(resolver); http.enforce_http(false); @@ -229,16 +220,17 @@ where /// # Errors /// /// Returns an error if it failed to initialize -pub async fn client( +pub async fn client( operation: &'static str, -) -> Result, Response, hyper::Error>, ClientInitError> +) -> Result>, ClientInitError> where - B: http_body::Body + Default + Send + 'static, - E: Into + 'static, + B: http_body::Body + Default + Send + 'static, + B::Data: Send, + B::Error: Into, { let client = make_traced_client().await?; let client = ClientLayer::new(operation).layer(client); - Ok(BoxCloneSyncService::new(client)) + Ok(client) } diff --git a/crates/http/src/layers/client.rs b/crates/http/src/layers/client.rs index eb10dd5a..2e112178 100644 --- a/crates/http/src/layers/client.rs +++ b/crates/http/src/layers/client.rs @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{marker::PhantomData, time::Duration}; +use std::{sync::Arc, time::Duration}; -use http::{header::USER_AGENT, HeaderValue, Request, Response}; +use http::{header::USER_AGENT, HeaderValue}; +use tokio::sync::Semaphore; use tower::{ - limit::{ConcurrencyLimit, ConcurrencyLimitLayer}, - Layer, Service, + limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer}, + Layer, }; use tower_http::{ follow_redirect::{FollowRedirect, FollowRedirectLayer}, @@ -26,56 +27,60 @@ use tower_http::{ }; use super::otel::TraceLayer; -use crate::{otel::TraceHttpClient, BoxError}; +use crate::otel::{TraceHttpClient, TraceHttpClientLayer}; -static MAS_USER_AGENT: HeaderValue = - HeaderValue::from_static("matrix-authentication-service/0.0.1"); +pub type ClientService = SetRequestHeader< + TraceHttpClient>>>>, + HeaderValue, +>; #[derive(Debug, Clone)] -pub struct ClientLayer { - operation: &'static str, - _t: PhantomData, +pub struct ClientLayer { + user_agent_layer: SetRequestHeaderLayer, + outer_trace_layer: TraceHttpClientLayer, + concurrency_limit_layer: GlobalConcurrencyLimitLayer, + follow_redirect_layer: FollowRedirectLayer, + inner_trace_layer: TraceHttpClientLayer, + timeout_layer: TimeoutLayer, } -impl ClientLayer { +impl ClientLayer { #[must_use] pub fn new(operation: &'static str) -> Self { + let semaphore = Arc::new(Semaphore::new(10)); + Self::with_semaphore(operation, semaphore) + } + + #[must_use] + pub fn with_semaphore(operation: &'static str, semaphore: Arc) -> Self { Self { - operation, - _t: PhantomData, + user_agent_layer: SetRequestHeaderLayer::overriding( + USER_AGENT, + HeaderValue::from_static("matrix-authentication-service/0.0.1"), + ), + outer_trace_layer: TraceLayer::http_client(operation), + concurrency_limit_layer: GlobalConcurrencyLimitLayer::with_semaphore(semaphore), + follow_redirect_layer: FollowRedirectLayer::new(), + inner_trace_layer: TraceLayer::inner_http_client(), + timeout_layer: TimeoutLayer::new(Duration::from_secs(10)), } } } -impl Layer for ClientLayer +impl Layer for ClientLayer where - S: Service, Response = Response, Error = E> - + Clone - + Send - + Sync - + 'static, - ReqBody: http_body::Body + Default + Send + 'static, - ResBody: http_body::Body + Sync + Send + 'static, - S::Future: Send + 'static, - E: Into, + S: Clone, { - type Service = SetRequestHeader< - TraceHttpClient>>>>, - HeaderValue, - >; + type Service = ClientService; fn layer(&self, inner: S) -> Self::Service { - // Note that all layers here just forward the error type. ( - SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()), - // A trace that has the whole operation, with all the redirects, timeouts and rate - // limits in it - TraceLayer::http_client(self.operation), - ConcurrencyLimitLayer::new(10), - FollowRedirectLayer::new(), - // A trace for each "real" http request - TraceLayer::inner_http_client(), - TimeoutLayer::new(Duration::from_secs(10)), + &self.user_agent_layer, + &self.outer_trace_layer, + &self.concurrency_limit_layer, + &self.follow_redirect_layer, + &self.inner_trace_layer, + &self.timeout_layer, ) .layer(inner) } diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index 3b251725..74165ced 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -49,3 +49,5 @@ pub use self::{ }; pub(crate) type BoxError = Box; + +pub type EmptyBody = http_body::Empty;