You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
Cleanup HTTP client building
This commit is contained in:
@@ -174,7 +174,7 @@ async fn fetch_jwks(jwks: &JwksOrJwksUri) -> Result<PublicJsonWebKeySet, BoxErro
|
|||||||
|
|
||||||
let request = http::Request::builder()
|
let request = http::Request::builder()
|
||||||
.uri(uri.as_str())
|
.uri(uri.as_str())
|
||||||
.body(http_body::Empty::new())
|
.body(mas_http::EmptyBody::new())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let mut client = mas_http::client("fetch-jwks")
|
let mut client = mas_http::client("fetch-jwks")
|
||||||
|
@@ -12,27 +12,22 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::{convert::Infallible, net::SocketAddr};
|
use std::convert::Infallible;
|
||||||
|
|
||||||
use bytes::Bytes;
|
|
||||||
use http::{Request, Response};
|
|
||||||
use hyper::{
|
use hyper::{
|
||||||
client::{
|
client::{connect::dns::GaiResolver, HttpConnector},
|
||||||
connect::dns::{GaiResolver, Name},
|
|
||||||
HttpConnector,
|
|
||||||
},
|
|
||||||
Client,
|
Client,
|
||||||
};
|
};
|
||||||
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tower::{Layer, Service};
|
use tower::Layer;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
layers::{
|
layers::{
|
||||||
client::ClientLayer,
|
client::{ClientLayer, ClientService},
|
||||||
otel::{TraceDns, TraceLayer},
|
otel::{TraceDns, TraceLayer},
|
||||||
},
|
},
|
||||||
BoxCloneSyncService, BoxError,
|
BoxError,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))]
|
#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))]
|
||||||
@@ -154,38 +149,41 @@ async fn make_tls_config() -> Result<rustls::ClientConfig, ClientInitError> {
|
|||||||
Ok(tls_config)
|
Ok(tls_config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UntracedClient<B> = hyper::Client<UntracedConnector, B>;
|
||||||
|
type TracedClient<B> = hyper::Client<TracedConnector, B>;
|
||||||
|
|
||||||
/// Create a basic Hyper HTTP & HTTPS client without any tracing
|
/// Create a basic Hyper HTTP & HTTPS client without any tracing
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns an error if it failed to load the TLS certificates
|
/// Returns an error if it failed to load the TLS certificates
|
||||||
pub async fn make_untraced_client<B, E>(
|
pub async fn make_untraced_client<B>() -> Result<UntracedClient<B>, ClientInitError>
|
||||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector<GaiResolver>>, B>, ClientInitError>
|
|
||||||
where
|
where
|
||||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
B: http_body::Body + Send + 'static,
|
||||||
E: Into<BoxError>,
|
B::Data: Send,
|
||||||
{
|
{
|
||||||
let https = make_untraced_connector().await?;
|
let https = make_untraced_connector().await?;
|
||||||
Ok(Client::builder().build(https))
|
Ok(Client::builder().build(https))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn make_traced_client<B, E>(
|
async fn make_traced_client<B>() -> Result<TracedClient<B>, ClientInitError>
|
||||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
|
|
||||||
where
|
where
|
||||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
B: http_body::Body + Send + 'static,
|
||||||
E: Into<BoxError>,
|
B::Data: Send,
|
||||||
{
|
{
|
||||||
let https = make_traced_connector().await?;
|
let https = make_traced_connector().await?;
|
||||||
Ok(Client::builder().build(https))
|
Ok(Client::builder().build(https))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UntracedConnector = HttpsConnector<HttpConnector<GaiResolver>>;
|
||||||
|
type TracedConnector = HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>;
|
||||||
|
|
||||||
/// Create a traced HTTP and HTTPS connector
|
/// Create a traced HTTP and HTTPS connector
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns an error if it failed to load the TLS certificates
|
/// Returns an error if it failed to load the TLS certificates
|
||||||
pub async fn make_traced_connector(
|
pub async fn make_traced_connector() -> Result<TracedConnector, ClientInitError>
|
||||||
) -> Result<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, ClientInitError>
|
|
||||||
where
|
where
|
||||||
{
|
{
|
||||||
// Trace DNS requests
|
// Trace DNS requests
|
||||||
@@ -194,8 +192,7 @@ where
|
|||||||
Ok(make_connector(resolver, tls_config))
|
Ok(make_connector(resolver, tls_config))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn make_untraced_connector(
|
async fn make_untraced_connector() -> Result<UntracedConnector, ClientInitError>
|
||||||
) -> Result<HttpsConnector<HttpConnector<GaiResolver>>, ClientInitError>
|
|
||||||
where
|
where
|
||||||
{
|
{
|
||||||
let resolver = GaiResolver::new();
|
let resolver = GaiResolver::new();
|
||||||
@@ -206,13 +203,7 @@ where
|
|||||||
fn make_connector<R>(
|
fn make_connector<R>(
|
||||||
resolver: R,
|
resolver: R,
|
||||||
tls_config: rustls::ClientConfig,
|
tls_config: rustls::ClientConfig,
|
||||||
) -> HttpsConnector<HttpConnector<R>>
|
) -> HttpsConnector<HttpConnector<R>> {
|
||||||
where
|
|
||||||
R: Service<Name> + Send + Sync + Clone + 'static,
|
|
||||||
R::Error: std::error::Error + Send + Sync,
|
|
||||||
R::Future: Send,
|
|
||||||
R::Response: Iterator<Item = SocketAddr>,
|
|
||||||
{
|
|
||||||
let mut http = HttpConnector::new_with_resolver(resolver);
|
let mut http = HttpConnector::new_with_resolver(resolver);
|
||||||
http.enforce_http(false);
|
http.enforce_http(false);
|
||||||
|
|
||||||
@@ -229,16 +220,17 @@ where
|
|||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns an error if it failed to initialize
|
/// Returns an error if it failed to initialize
|
||||||
pub async fn client<B, E>(
|
pub async fn client<B>(
|
||||||
operation: &'static str,
|
operation: &'static str,
|
||||||
) -> Result<BoxCloneSyncService<Request<B>, Response<hyper::Body>, hyper::Error>, ClientInitError>
|
) -> Result<ClientService<TracedClient<B>>, ClientInitError>
|
||||||
where
|
where
|
||||||
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
B: http_body::Body + Default + Send + 'static,
|
||||||
E: Into<BoxError> + 'static,
|
B::Data: Send,
|
||||||
|
B::Error: Into<BoxError>,
|
||||||
{
|
{
|
||||||
let client = make_traced_client().await?;
|
let client = make_traced_client().await?;
|
||||||
|
|
||||||
let client = ClientLayer::new(operation).layer(client);
|
let client = ClientLayer::new(operation).layer(client);
|
||||||
|
|
||||||
Ok(BoxCloneSyncService::new(client))
|
Ok(client)
|
||||||
}
|
}
|
||||||
|
@@ -12,12 +12,13 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::{marker::PhantomData, time::Duration};
|
use std::{sync::Arc, time::Duration};
|
||||||
|
|
||||||
use http::{header::USER_AGENT, HeaderValue, Request, Response};
|
use http::{header::USER_AGENT, HeaderValue};
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
use tower::{
|
use tower::{
|
||||||
limit::{ConcurrencyLimit, ConcurrencyLimitLayer},
|
limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer},
|
||||||
Layer, Service,
|
Layer,
|
||||||
};
|
};
|
||||||
use tower_http::{
|
use tower_http::{
|
||||||
follow_redirect::{FollowRedirect, FollowRedirectLayer},
|
follow_redirect::{FollowRedirect, FollowRedirectLayer},
|
||||||
@@ -26,56 +27,60 @@ use tower_http::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use super::otel::TraceLayer;
|
use super::otel::TraceLayer;
|
||||||
use crate::{otel::TraceHttpClient, BoxError};
|
use crate::otel::{TraceHttpClient, TraceHttpClientLayer};
|
||||||
|
|
||||||
static MAS_USER_AGENT: HeaderValue =
|
pub type ClientService<S> = SetRequestHeader<
|
||||||
HeaderValue::from_static("matrix-authentication-service/0.0.1");
|
TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
|
||||||
|
HeaderValue,
|
||||||
|
>;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ClientLayer<ReqBody> {
|
pub struct ClientLayer {
|
||||||
operation: &'static str,
|
user_agent_layer: SetRequestHeaderLayer<HeaderValue>,
|
||||||
_t: PhantomData<ReqBody>,
|
outer_trace_layer: TraceHttpClientLayer,
|
||||||
|
concurrency_limit_layer: GlobalConcurrencyLimitLayer,
|
||||||
|
follow_redirect_layer: FollowRedirectLayer,
|
||||||
|
inner_trace_layer: TraceHttpClientLayer,
|
||||||
|
timeout_layer: TimeoutLayer,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B> ClientLayer<B> {
|
impl ClientLayer {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(operation: &'static str) -> Self {
|
pub fn new(operation: &'static str) -> Self {
|
||||||
|
let semaphore = Arc::new(Semaphore::new(10));
|
||||||
|
Self::with_semaphore(operation, semaphore)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn with_semaphore(operation: &'static str, semaphore: Arc<Semaphore>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
operation,
|
user_agent_layer: SetRequestHeaderLayer::overriding(
|
||||||
_t: PhantomData,
|
USER_AGENT,
|
||||||
|
HeaderValue::from_static("matrix-authentication-service/0.0.1"),
|
||||||
|
),
|
||||||
|
outer_trace_layer: TraceLayer::http_client(operation),
|
||||||
|
concurrency_limit_layer: GlobalConcurrencyLimitLayer::with_semaphore(semaphore),
|
||||||
|
follow_redirect_layer: FollowRedirectLayer::new(),
|
||||||
|
inner_trace_layer: TraceLayer::inner_http_client(),
|
||||||
|
timeout_layer: TimeoutLayer::new(Duration::from_secs(10)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<ReqBody, ResBody, S, E> Layer<S> for ClientLayer<ReqBody>
|
impl<S> Layer<S> for ClientLayer
|
||||||
where
|
where
|
||||||
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E>
|
S: Clone,
|
||||||
+ Clone
|
|
||||||
+ Send
|
|
||||||
+ Sync
|
|
||||||
+ 'static,
|
|
||||||
ReqBody: http_body::Body + Default + Send + 'static,
|
|
||||||
ResBody: http_body::Body + Sync + Send + 'static,
|
|
||||||
S::Future: Send + 'static,
|
|
||||||
E: Into<BoxError>,
|
|
||||||
{
|
{
|
||||||
type Service = SetRequestHeader<
|
type Service = ClientService<S>;
|
||||||
TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
|
|
||||||
HeaderValue,
|
|
||||||
>;
|
|
||||||
|
|
||||||
fn layer(&self, inner: S) -> Self::Service {
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
// Note that all layers here just forward the error type.
|
|
||||||
(
|
(
|
||||||
SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()),
|
&self.user_agent_layer,
|
||||||
// A trace that has the whole operation, with all the redirects, timeouts and rate
|
&self.outer_trace_layer,
|
||||||
// limits in it
|
&self.concurrency_limit_layer,
|
||||||
TraceLayer::http_client(self.operation),
|
&self.follow_redirect_layer,
|
||||||
ConcurrencyLimitLayer::new(10),
|
&self.inner_trace_layer,
|
||||||
FollowRedirectLayer::new(),
|
&self.timeout_layer,
|
||||||
// A trace for each "real" http request
|
|
||||||
TraceLayer::inner_http_client(),
|
|
||||||
TimeoutLayer::new(Duration::from_secs(10)),
|
|
||||||
)
|
)
|
||||||
.layer(inner)
|
.layer(inner)
|
||||||
}
|
}
|
||||||
|
@@ -49,3 +49,5 @@ pub use self::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
|
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
|
||||||
|
|
||||||
|
pub type EmptyBody = http_body::Empty<bytes::Bytes>;
|
||||||
|
Reference in New Issue
Block a user