You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Cleanup HTTP client building
This commit is contained in:
@ -12,27 +12,22 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{convert::Infallible, net::SocketAddr};
|
||||
use std::convert::Infallible;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http::{Request, Response};
|
||||
use hyper::{
|
||||
client::{
|
||||
connect::dns::{GaiResolver, Name},
|
||||
HttpConnector,
|
||||
},
|
||||
client::{connect::dns::GaiResolver, HttpConnector},
|
||||
Client,
|
||||
};
|
||||
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
||||
use thiserror::Error;
|
||||
use tower::{Layer, Service};
|
||||
use tower::Layer;
|
||||
|
||||
use crate::{
|
||||
layers::{
|
||||
client::ClientLayer,
|
||||
client::{ClientLayer, ClientService},
|
||||
otel::{TraceDns, TraceLayer},
|
||||
},
|
||||
BoxCloneSyncService, BoxError,
|
||||
BoxError,
|
||||
};
|
||||
|
||||
#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))]
|
||||
@ -154,38 +149,41 @@ async fn make_tls_config() -> Result<rustls::ClientConfig, ClientInitError> {
|
||||
Ok(tls_config)
|
||||
}
|
||||
|
||||
type UntracedClient<B> = hyper::Client<UntracedConnector, B>;
|
||||
type TracedClient<B> = hyper::Client<TracedConnector, B>;
|
||||
|
||||
/// Create a basic Hyper HTTP & HTTPS client without any tracing
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if it failed to load the TLS certificates
|
||||
pub async fn make_untraced_client<B, E>(
|
||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector<GaiResolver>>, B>, ClientInitError>
|
||||
pub async fn make_untraced_client<B>() -> Result<UntracedClient<B>, ClientInitError>
|
||||
where
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||
E: Into<BoxError>,
|
||||
B: http_body::Body + Send + 'static,
|
||||
B::Data: Send,
|
||||
{
|
||||
let https = make_untraced_connector().await?;
|
||||
Ok(Client::builder().build(https))
|
||||
}
|
||||
|
||||
async fn make_traced_client<B, E>(
|
||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
|
||||
async fn make_traced_client<B>() -> Result<TracedClient<B>, ClientInitError>
|
||||
where
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||
E: Into<BoxError>,
|
||||
B: http_body::Body + Send + 'static,
|
||||
B::Data: Send,
|
||||
{
|
||||
let https = make_traced_connector().await?;
|
||||
Ok(Client::builder().build(https))
|
||||
}
|
||||
|
||||
type UntracedConnector = HttpsConnector<HttpConnector<GaiResolver>>;
|
||||
type TracedConnector = HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>;
|
||||
|
||||
/// Create a traced HTTP and HTTPS connector
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if it failed to load the TLS certificates
|
||||
pub async fn make_traced_connector(
|
||||
) -> Result<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, ClientInitError>
|
||||
pub async fn make_traced_connector() -> Result<TracedConnector, ClientInitError>
|
||||
where
|
||||
{
|
||||
// Trace DNS requests
|
||||
@ -194,8 +192,7 @@ where
|
||||
Ok(make_connector(resolver, tls_config))
|
||||
}
|
||||
|
||||
async fn make_untraced_connector(
|
||||
) -> Result<HttpsConnector<HttpConnector<GaiResolver>>, ClientInitError>
|
||||
async fn make_untraced_connector() -> Result<UntracedConnector, ClientInitError>
|
||||
where
|
||||
{
|
||||
let resolver = GaiResolver::new();
|
||||
@ -206,13 +203,7 @@ where
|
||||
fn make_connector<R>(
|
||||
resolver: R,
|
||||
tls_config: rustls::ClientConfig,
|
||||
) -> HttpsConnector<HttpConnector<R>>
|
||||
where
|
||||
R: Service<Name> + Send + Sync + Clone + 'static,
|
||||
R::Error: std::error::Error + Send + Sync,
|
||||
R::Future: Send,
|
||||
R::Response: Iterator<Item = SocketAddr>,
|
||||
{
|
||||
) -> HttpsConnector<HttpConnector<R>> {
|
||||
let mut http = HttpConnector::new_with_resolver(resolver);
|
||||
http.enforce_http(false);
|
||||
|
||||
@ -229,16 +220,17 @@ where
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if it failed to initialize
|
||||
pub async fn client<B, E>(
|
||||
pub async fn client<B>(
|
||||
operation: &'static str,
|
||||
) -> Result<BoxCloneSyncService<Request<B>, Response<hyper::Body>, hyper::Error>, ClientInitError>
|
||||
) -> Result<ClientService<TracedClient<B>>, ClientInitError>
|
||||
where
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
||||
E: Into<BoxError> + 'static,
|
||||
B: http_body::Body + Default + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<BoxError>,
|
||||
{
|
||||
let client = make_traced_client().await?;
|
||||
|
||||
let client = ClientLayer::new(operation).layer(client);
|
||||
|
||||
Ok(BoxCloneSyncService::new(client))
|
||||
Ok(client)
|
||||
}
|
||||
|
@ -12,12 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{marker::PhantomData, time::Duration};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use http::{header::USER_AGENT, HeaderValue, Request, Response};
|
||||
use http::{header::USER_AGENT, HeaderValue};
|
||||
use tokio::sync::Semaphore;
|
||||
use tower::{
|
||||
limit::{ConcurrencyLimit, ConcurrencyLimitLayer},
|
||||
Layer, Service,
|
||||
limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer},
|
||||
Layer,
|
||||
};
|
||||
use tower_http::{
|
||||
follow_redirect::{FollowRedirect, FollowRedirectLayer},
|
||||
@ -26,56 +27,60 @@ use tower_http::{
|
||||
};
|
||||
|
||||
use super::otel::TraceLayer;
|
||||
use crate::{otel::TraceHttpClient, BoxError};
|
||||
use crate::otel::{TraceHttpClient, TraceHttpClientLayer};
|
||||
|
||||
static MAS_USER_AGENT: HeaderValue =
|
||||
HeaderValue::from_static("matrix-authentication-service/0.0.1");
|
||||
pub type ClientService<S> = SetRequestHeader<
|
||||
TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
|
||||
HeaderValue,
|
||||
>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClientLayer<ReqBody> {
|
||||
operation: &'static str,
|
||||
_t: PhantomData<ReqBody>,
|
||||
pub struct ClientLayer {
|
||||
user_agent_layer: SetRequestHeaderLayer<HeaderValue>,
|
||||
outer_trace_layer: TraceHttpClientLayer,
|
||||
concurrency_limit_layer: GlobalConcurrencyLimitLayer,
|
||||
follow_redirect_layer: FollowRedirectLayer,
|
||||
inner_trace_layer: TraceHttpClientLayer,
|
||||
timeout_layer: TimeoutLayer,
|
||||
}
|
||||
|
||||
impl<B> ClientLayer<B> {
|
||||
impl ClientLayer {
|
||||
#[must_use]
|
||||
pub fn new(operation: &'static str) -> Self {
|
||||
let semaphore = Arc::new(Semaphore::new(10));
|
||||
Self::with_semaphore(operation, semaphore)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_semaphore(operation: &'static str, semaphore: Arc<Semaphore>) -> Self {
|
||||
Self {
|
||||
operation,
|
||||
_t: PhantomData,
|
||||
user_agent_layer: SetRequestHeaderLayer::overriding(
|
||||
USER_AGENT,
|
||||
HeaderValue::from_static("matrix-authentication-service/0.0.1"),
|
||||
),
|
||||
outer_trace_layer: TraceLayer::http_client(operation),
|
||||
concurrency_limit_layer: GlobalConcurrencyLimitLayer::with_semaphore(semaphore),
|
||||
follow_redirect_layer: FollowRedirectLayer::new(),
|
||||
inner_trace_layer: TraceLayer::inner_http_client(),
|
||||
timeout_layer: TimeoutLayer::new(Duration::from_secs(10)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<ReqBody, ResBody, S, E> Layer<S> for ClientLayer<ReqBody>
|
||||
impl<S> Layer<S> for ClientLayer
|
||||
where
|
||||
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static,
|
||||
ReqBody: http_body::Body + Default + Send + 'static,
|
||||
ResBody: http_body::Body + Sync + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
E: Into<BoxError>,
|
||||
S: Clone,
|
||||
{
|
||||
type Service = SetRequestHeader<
|
||||
TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
|
||||
HeaderValue,
|
||||
>;
|
||||
type Service = ClientService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
// Note that all layers here just forward the error type.
|
||||
(
|
||||
SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()),
|
||||
// A trace that has the whole operation, with all the redirects, timeouts and rate
|
||||
// limits in it
|
||||
TraceLayer::http_client(self.operation),
|
||||
ConcurrencyLimitLayer::new(10),
|
||||
FollowRedirectLayer::new(),
|
||||
// A trace for each "real" http request
|
||||
TraceLayer::inner_http_client(),
|
||||
TimeoutLayer::new(Duration::from_secs(10)),
|
||||
&self.user_agent_layer,
|
||||
&self.outer_trace_layer,
|
||||
&self.concurrency_limit_layer,
|
||||
&self.follow_redirect_layer,
|
||||
&self.inner_trace_layer,
|
||||
&self.timeout_layer,
|
||||
)
|
||||
.layer(inner)
|
||||
}
|
||||
|
@ -49,3 +49,5 @@ pub use self::{
|
||||
};
|
||||
|
||||
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
pub type EmptyBody = http_body::Empty<bytes::Bytes>;
|
||||
|
Reference in New Issue
Block a user