1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Cleanup HTTP client building

This commit is contained in:
Quentin Gliech
2022-11-23 12:26:58 +01:00
parent 16088fc11c
commit d514a8922c
4 changed files with 71 additions and 72 deletions

View File

@@ -174,7 +174,7 @@ async fn fetch_jwks(jwks: &JwksOrJwksUri) -> Result<PublicJsonWebKeySet, BoxErro
let request = http::Request::builder() let request = http::Request::builder()
.uri(uri.as_str()) .uri(uri.as_str())
.body(http_body::Empty::new()) .body(mas_http::EmptyBody::new())
.unwrap(); .unwrap();
let mut client = mas_http::client("fetch-jwks") let mut client = mas_http::client("fetch-jwks")

View File

@@ -12,27 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{convert::Infallible, net::SocketAddr}; use std::convert::Infallible;
use bytes::Bytes;
use http::{Request, Response};
use hyper::{ use hyper::{
client::{ client::{connect::dns::GaiResolver, HttpConnector},
connect::dns::{GaiResolver, Name},
HttpConnector,
},
Client, Client,
}; };
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use thiserror::Error; use thiserror::Error;
use tower::{Layer, Service}; use tower::Layer;
use crate::{ use crate::{
layers::{ layers::{
client::ClientLayer, client::{ClientLayer, ClientService},
otel::{TraceDns, TraceLayer}, otel::{TraceDns, TraceLayer},
}, },
BoxCloneSyncService, BoxError, BoxError,
}; };
#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))] #[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))]
@@ -154,38 +149,41 @@ async fn make_tls_config() -> Result<rustls::ClientConfig, ClientInitError> {
Ok(tls_config) Ok(tls_config)
} }
type UntracedClient<B> = hyper::Client<UntracedConnector, B>;
type TracedClient<B> = hyper::Client<TracedConnector, B>;
/// Create a basic Hyper HTTP & HTTPS client without any tracing /// Create a basic Hyper HTTP & HTTPS client without any tracing
/// ///
/// # Errors /// # Errors
/// ///
/// Returns an error if it failed to load the TLS certificates /// Returns an error if it failed to load the TLS certificates
pub async fn make_untraced_client<B, E>( pub async fn make_untraced_client<B>() -> Result<UntracedClient<B>, ClientInitError>
) -> Result<hyper::Client<HttpsConnector<HttpConnector<GaiResolver>>, B>, ClientInitError>
where where
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static, B: http_body::Body + Send + 'static,
E: Into<BoxError>, B::Data: Send,
{ {
let https = make_untraced_connector().await?; let https = make_untraced_connector().await?;
Ok(Client::builder().build(https)) Ok(Client::builder().build(https))
} }
async fn make_traced_client<B, E>( async fn make_traced_client<B>() -> Result<TracedClient<B>, ClientInitError>
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
where where
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static, B: http_body::Body + Send + 'static,
E: Into<BoxError>, B::Data: Send,
{ {
let https = make_traced_connector().await?; let https = make_traced_connector().await?;
Ok(Client::builder().build(https)) Ok(Client::builder().build(https))
} }
type UntracedConnector = HttpsConnector<HttpConnector<GaiResolver>>;
type TracedConnector = HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>;
/// Create a traced HTTP and HTTPS connector /// Create a traced HTTP and HTTPS connector
/// ///
/// # Errors /// # Errors
/// ///
/// Returns an error if it failed to load the TLS certificates /// Returns an error if it failed to load the TLS certificates
pub async fn make_traced_connector( pub async fn make_traced_connector() -> Result<TracedConnector, ClientInitError>
) -> Result<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, ClientInitError>
where where
{ {
// Trace DNS requests // Trace DNS requests
@@ -194,8 +192,7 @@ where
Ok(make_connector(resolver, tls_config)) Ok(make_connector(resolver, tls_config))
} }
async fn make_untraced_connector( async fn make_untraced_connector() -> Result<UntracedConnector, ClientInitError>
) -> Result<HttpsConnector<HttpConnector<GaiResolver>>, ClientInitError>
where where
{ {
let resolver = GaiResolver::new(); let resolver = GaiResolver::new();
@@ -206,13 +203,7 @@ where
fn make_connector<R>( fn make_connector<R>(
resolver: R, resolver: R,
tls_config: rustls::ClientConfig, tls_config: rustls::ClientConfig,
) -> HttpsConnector<HttpConnector<R>> ) -> HttpsConnector<HttpConnector<R>> {
where
R: Service<Name> + Send + Sync + Clone + 'static,
R::Error: std::error::Error + Send + Sync,
R::Future: Send,
R::Response: Iterator<Item = SocketAddr>,
{
let mut http = HttpConnector::new_with_resolver(resolver); let mut http = HttpConnector::new_with_resolver(resolver);
http.enforce_http(false); http.enforce_http(false);
@@ -229,16 +220,17 @@ where
/// # Errors /// # Errors
/// ///
/// Returns an error if it failed to initialize /// Returns an error if it failed to initialize
pub async fn client<B, E>( pub async fn client<B>(
operation: &'static str, operation: &'static str,
) -> Result<BoxCloneSyncService<Request<B>, Response<hyper::Body>, hyper::Error>, ClientInitError> ) -> Result<ClientService<TracedClient<B>>, ClientInitError>
where where
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static, B: http_body::Body + Default + Send + 'static,
E: Into<BoxError> + 'static, B::Data: Send,
B::Error: Into<BoxError>,
{ {
let client = make_traced_client().await?; let client = make_traced_client().await?;
let client = ClientLayer::new(operation).layer(client); let client = ClientLayer::new(operation).layer(client);
Ok(BoxCloneSyncService::new(client)) Ok(client)
} }

View File

@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{marker::PhantomData, time::Duration}; use std::{sync::Arc, time::Duration};
use http::{header::USER_AGENT, HeaderValue, Request, Response}; use http::{header::USER_AGENT, HeaderValue};
use tokio::sync::Semaphore;
use tower::{ use tower::{
limit::{ConcurrencyLimit, ConcurrencyLimitLayer}, limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer},
Layer, Service, Layer,
}; };
use tower_http::{ use tower_http::{
follow_redirect::{FollowRedirect, FollowRedirectLayer}, follow_redirect::{FollowRedirect, FollowRedirectLayer},
@@ -26,56 +27,60 @@ use tower_http::{
}; };
use super::otel::TraceLayer; use super::otel::TraceLayer;
use crate::{otel::TraceHttpClient, BoxError}; use crate::otel::{TraceHttpClient, TraceHttpClientLayer};
static MAS_USER_AGENT: HeaderValue = pub type ClientService<S> = SetRequestHeader<
HeaderValue::from_static("matrix-authentication-service/0.0.1"); TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
HeaderValue,
>;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ClientLayer<ReqBody> { pub struct ClientLayer {
operation: &'static str, user_agent_layer: SetRequestHeaderLayer<HeaderValue>,
_t: PhantomData<ReqBody>, outer_trace_layer: TraceHttpClientLayer,
concurrency_limit_layer: GlobalConcurrencyLimitLayer,
follow_redirect_layer: FollowRedirectLayer,
inner_trace_layer: TraceHttpClientLayer,
timeout_layer: TimeoutLayer,
} }
impl<B> ClientLayer<B> { impl ClientLayer {
#[must_use] #[must_use]
pub fn new(operation: &'static str) -> Self { pub fn new(operation: &'static str) -> Self {
let semaphore = Arc::new(Semaphore::new(10));
Self::with_semaphore(operation, semaphore)
}
#[must_use]
pub fn with_semaphore(operation: &'static str, semaphore: Arc<Semaphore>) -> Self {
Self { Self {
operation, user_agent_layer: SetRequestHeaderLayer::overriding(
_t: PhantomData, USER_AGENT,
HeaderValue::from_static("matrix-authentication-service/0.0.1"),
),
outer_trace_layer: TraceLayer::http_client(operation),
concurrency_limit_layer: GlobalConcurrencyLimitLayer::with_semaphore(semaphore),
follow_redirect_layer: FollowRedirectLayer::new(),
inner_trace_layer: TraceLayer::inner_http_client(),
timeout_layer: TimeoutLayer::new(Duration::from_secs(10)),
} }
} }
} }
impl<ReqBody, ResBody, S, E> Layer<S> for ClientLayer<ReqBody> impl<S> Layer<S> for ClientLayer
where where
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E> S: Clone,
+ Clone
+ Send
+ Sync
+ 'static,
ReqBody: http_body::Body + Default + Send + 'static,
ResBody: http_body::Body + Sync + Send + 'static,
S::Future: Send + 'static,
E: Into<BoxError>,
{ {
type Service = SetRequestHeader< type Service = ClientService<S>;
TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
HeaderValue,
>;
fn layer(&self, inner: S) -> Self::Service { fn layer(&self, inner: S) -> Self::Service {
// Note that all layers here just forward the error type.
( (
SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()), &self.user_agent_layer,
// A trace that has the whole operation, with all the redirects, timeouts and rate &self.outer_trace_layer,
// limits in it &self.concurrency_limit_layer,
TraceLayer::http_client(self.operation), &self.follow_redirect_layer,
ConcurrencyLimitLayer::new(10), &self.inner_trace_layer,
FollowRedirectLayer::new(), &self.timeout_layer,
// A trace for each "real" http request
TraceLayer::inner_http_client(),
TimeoutLayer::new(Duration::from_secs(10)),
) )
.layer(inner) .layer(inner)
} }

View File

@@ -49,3 +49,5 @@ pub use self::{
}; };
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>; pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub type EmptyBody = http_body::Empty<bytes::Bytes>;