1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-23 11:02:35 +03:00

Make the HTTP client factory reuse the underlying client

This avoids duplicating clients, and makes it so that they all share the same connection pool.
This commit is contained in:
Quentin Gliech
2023-09-14 14:22:49 +02:00
parent f29e4adcfa
commit 54071c4969
15 changed files with 146 additions and 77 deletions

View File

@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{sync::Arc, time::Duration};
use std::time::Duration;
use headers::{ContentLength, HeaderMapExt, Host, UserAgent};
use http::{header::USER_AGENT, HeaderValue, Request, Response};
use hyper::client::connect::HttpInfo;
use mas_tower::{
EnrichSpan, MakeSpan, TraceContextLayer, TraceContextService, TraceLayer, TraceService,
DurationRecorderLayer, DurationRecorderService, EnrichSpan, InFlightCounterLayer,
InFlightCounterService, MakeSpan, MetricsAttributes, TraceContextLayer, TraceContextService,
TraceLayer, TraceService,
};
use tokio::sync::Semaphore;
use opentelemetry::KeyValue;
use tower::{
limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer},
Layer,
@@ -33,21 +35,31 @@ use tower_http::{
use tracing::Span;
pub type ClientService<S> = SetRequestHeader<
ConcurrencyLimit<
FollowRedirect<
TraceService<
TraceContextService<Timeout<S>>,
MakeSpanForRequest,
EnrichSpanOnResponse,
EnrichSpanOnError,
DurationRecorderService<
InFlightCounterService<
ConcurrencyLimit<
FollowRedirect<
TraceService<
TraceContextService<Timeout<S>>,
MakeSpanForRequest,
EnrichSpanOnResponse,
EnrichSpanOnError,
>,
>,
>,
OnRequestLabels,
>,
OnRequestLabels,
OnResponseLabels,
KeyValue,
>,
HeaderValue,
>;
#[derive(Debug, Clone)]
pub struct MakeSpanForRequest;
#[derive(Debug, Clone, Default)]
pub struct MakeSpanForRequest {
category: Option<&'static str>,
}
impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
fn make_span(&self, request: &Request<B>) -> Span {
@@ -58,6 +70,7 @@ impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
.map(tracing::field::display);
let content_length = headers.typed_get().map(|ContentLength(len)| len);
let net_sock_peer_name = request.uri().host();
let category = self.category.unwrap_or("UNSET");
tracing::info_span!(
"http.client.request",
@@ -78,6 +91,7 @@ impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
"net.sock.host.port" = tracing::field::Empty,
"user_agent.original" = user_agent,
"rust.error" = tracing::field::Empty,
"mas.category" = category,
)
}
}
@@ -123,6 +137,42 @@ where
}
}
#[derive(Debug, Clone, Default)]
pub struct OnRequestLabels {
category: Option<&'static str>,
}
impl<B> MetricsAttributes<Request<B>> for OnRequestLabels
where
B: 'static,
{
type Iter<'a> = std::array::IntoIter<KeyValue, 3>;
fn attributes<'a>(&'a self, t: &'a Request<B>) -> Self::Iter<'a> {
[
KeyValue::new("http.request.method", t.method().as_str().to_owned()),
KeyValue::new("network.protocol.name", "http"),
KeyValue::new("mas.category", self.category.unwrap_or("UNSET")),
]
.into_iter()
}
}
#[derive(Debug, Clone, Default)]
pub struct OnResponseLabels;
impl<B> MetricsAttributes<Response<B>> for OnResponseLabels
where
B: 'static,
{
type Iter<'a> = std::iter::Once<KeyValue>;
fn attributes<'a>(&'a self, t: &'a Response<B>) -> Self::Iter<'a> {
std::iter::once(KeyValue::new(
"http.response.status_code",
i64::from(t.status().as_u16()),
))
}
}
#[derive(Debug, Clone)]
pub struct ClientLayer {
user_agent_layer: SetRequestHeaderLayer<HeaderValue>,
@@ -131,6 +181,8 @@ pub struct ClientLayer {
trace_layer: TraceLayer<MakeSpanForRequest, EnrichSpanOnResponse, EnrichSpanOnError>,
trace_context_layer: TraceContextLayer,
timeout_layer: TimeoutLayer,
duration_recorder_layer: DurationRecorderLayer<OnRequestLabels, OnResponseLabels, KeyValue>,
in_flight_counter_layer: InFlightCounterLayer<OnRequestLabels>,
}
impl Default for ClientLayer {
@@ -142,26 +194,45 @@ impl Default for ClientLayer {
impl ClientLayer {
#[must_use]
pub fn new() -> Self {
let semaphore = Arc::new(Semaphore::new(10));
Self::with_semaphore(semaphore)
}
#[must_use]
pub fn with_semaphore(semaphore: Arc<Semaphore>) -> Self {
Self {
user_agent_layer: SetRequestHeaderLayer::overriding(
USER_AGENT,
HeaderValue::from_static("matrix-authentication-service/0.0.1"),
),
concurrency_limit_layer: GlobalConcurrencyLimitLayer::with_semaphore(semaphore),
concurrency_limit_layer: GlobalConcurrencyLimitLayer::new(10),
follow_redirect_layer: FollowRedirectLayer::new(),
trace_layer: TraceLayer::new(MakeSpanForRequest)
trace_layer: TraceLayer::new(MakeSpanForRequest::default())
.on_response(EnrichSpanOnResponse)
.on_error(EnrichSpanOnError),
trace_context_layer: TraceContextLayer::new(),
timeout_layer: TimeoutLayer::new(Duration::from_secs(10)),
duration_recorder_layer: DurationRecorderLayer::new("http.client.duration")
.on_request(OnRequestLabels::default())
.on_response(OnResponseLabels)
.on_error(KeyValue::new("http.error", true)),
in_flight_counter_layer: InFlightCounterLayer::new("http.client.active_requests")
.on_request(OnRequestLabels::default()),
}
}
#[must_use]
pub fn with_category(mut self, category: &'static str) -> Self {
self.trace_layer = TraceLayer::new(MakeSpanForRequest {
category: Some(category),
})
.on_response(EnrichSpanOnResponse)
.on_error(EnrichSpanOnError);
self.duration_recorder_layer = self.duration_recorder_layer.on_request(OnRequestLabels {
category: Some(category),
});
self.in_flight_counter_layer = self.in_flight_counter_layer.on_request(OnRequestLabels {
category: Some(category),
});
self
}
}
impl<S> Layer<S> for ClientLayer
@@ -173,6 +244,8 @@ where
fn layer(&self, inner: S) -> Self::Service {
(
&self.user_agent_layer,
&self.duration_recorder_layer,
&self.in_flight_counter_layer,
&self.concurrency_limit_layer,
&self.follow_redirect_layer,
&self.trace_layer,