1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Make the HTTP client factory reuse the underlying client

This avoids duplicating clients, and makes it so that they all share the same connection pool.
This commit is contained in:
Quentin Gliech
2023-09-14 14:22:49 +02:00
parent f29e4adcfa
commit 54071c4969
15 changed files with 146 additions and 77 deletions

View File

@@ -191,8 +191,7 @@ async fn fetch_jwks(
.unwrap(); .unwrap();
let mut client = http_client_factory let mut client = http_client_factory
.client() .client("client.fetch_jwks")
.await?
.response_body_to_bytes() .response_body_to_bytes()
.json_response::<PublicJsonWebKeySet>(); .json_response::<PublicJsonWebKeySet>();

View File

@@ -12,14 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::Arc;
use axum::body::Full; use axum::body::Full;
use mas_http::{ use mas_http::{
BodyToBytesResponseLayer, ClientInitError, ClientLayer, ClientService, HttpService, make_traced_connector, BodyToBytesResponseLayer, Client, ClientInitError, ClientLayer,
TracedClient, ClientService, HttpService, TracedClient, TracedConnector,
}; };
use tokio::sync::Semaphore;
use tower::{ use tower::{
util::{MapErrLayer, MapRequestLayer}, util::{MapErrLayer, MapRequestLayer},
BoxError, Layer, BoxError, Layer,
@@ -27,15 +24,16 @@ use tower::{
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct HttpClientFactory { pub struct HttpClientFactory {
semaphore: Arc<Semaphore>, traced_connector: TracedConnector,
client_layer: ClientLayer,
} }
impl HttpClientFactory { impl HttpClientFactory {
#[must_use] pub async fn new() -> Result<Self, ClientInitError> {
pub fn new(concurrency_limit: usize) -> Self { Ok(Self {
Self { traced_connector: make_traced_connector().await?,
semaphore: Arc::new(Semaphore::new(concurrency_limit)), client_layer: ClientLayer::new(),
} })
} }
/// Constructs a new HTTP client /// Constructs a new HTTP client
@@ -43,14 +41,16 @@ impl HttpClientFactory {
/// # Errors /// # Errors
/// ///
/// Returns an error if the client failed to initialise /// Returns an error if the client failed to initialise
pub async fn client<B>(&self) -> Result<ClientService<TracedClient<B>>, ClientInitError> pub fn client<B>(&self, category: &'static str) -> ClientService<TracedClient<B>>
where where
B: axum::body::HttpBody + Send, B: axum::body::HttpBody + Send,
B::Data: Send, B::Data: Send,
{ {
let client = mas_http::make_traced_client::<B>().await?; let client = Client::builder().build(self.traced_connector.clone());
let layer = ClientLayer::with_semaphore(self.semaphore.clone()); self.client_layer
Ok(layer.layer(client)) .clone()
.with_category(category)
.layer(client)
} }
/// Constructs a new [`HttpService`], suitable for `mas-oidc-client` /// Constructs a new [`HttpService`], suitable for `mas-oidc-client`
@@ -58,8 +58,8 @@ impl HttpClientFactory {
/// # Errors /// # Errors
/// ///
/// Returns an error if the client failed to initialise /// Returns an error if the client failed to initialise
pub async fn http_service(&self) -> Result<HttpService, ClientInitError> { pub fn http_service(&self, category: &'static str) -> HttpService {
let client = self.client().await?; let client = self.client(category);
let client = ( let client = (
MapErrLayer::new(BoxError::from), MapErrLayer::new(BoxError::from),
MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)), MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)),
@@ -67,6 +67,6 @@ impl HttpClientFactory {
) )
.layer(client); .layer(client);
Ok(HttpService::new(client)) HttpService::new(client)
} }
} }

View File

@@ -67,7 +67,7 @@ impl Options {
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> { pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
use Subcommand as SC; use Subcommand as SC;
let http_client_factory = HttpClientFactory::new(10); let http_client_factory = HttpClientFactory::new().await?;
match self.subcommand { match self.subcommand {
SC::Http { SC::Http {
show_headers, show_headers,
@@ -75,7 +75,7 @@ impl Options {
url, url,
} => { } => {
let _span = info_span!("cli.debug.http").entered(); let _span = info_span!("cli.debug.http").entered();
let mut client = http_client_factory.client().await?; let mut client = http_client_factory.client("debug");
let request = hyper::Request::builder() let request = hyper::Request::builder()
.uri(url) .uri(url)
.body(hyper::Body::empty())?; .body(hyper::Body::empty())?;
@@ -99,8 +99,7 @@ impl Options {
} => { } => {
let _span = info_span!("cli.debug.http").entered(); let _span = info_span!("cli.debug.http").entered();
let mut client = http_client_factory let mut client = http_client_factory
.client() .client("debug")
.await?
.response_body_to_bytes() .response_body_to_bytes()
.json_response(); .json_response();
let request = hyper::Request::builder() let request = hyper::Request::builder()

View File

@@ -97,6 +97,8 @@ impl Options {
// Load and compile the templates // Load and compile the templates
let templates = templates_from_config(&config.templates, &url_builder).await?; let templates = templates_from_config(&config.templates, &url_builder).await?;
let http_client_factory = HttpClientFactory::new().await?;
if !self.no_worker { if !self.no_worker {
let mailer = mailer_from_config(&config.email, &templates)?; let mailer = mailer_from_config(&config.email, &templates)?;
mailer.test_connection().await?; mailer.test_connection().await?;
@@ -105,15 +107,12 @@ impl Options {
let mut rng = thread_rng(); let mut rng = thread_rng();
let worker_name = Alphanumeric.sample_string(&mut rng, 10); let worker_name = Alphanumeric.sample_string(&mut rng, 10);
// Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50);
info!(worker_name, "Starting task worker"); info!(worker_name, "Starting task worker");
let conn = SynapseConnection::new( let conn = SynapseConnection::new(
config.matrix.homeserver.clone(), config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(), config.matrix.endpoint.clone(),
config.matrix.secret.clone(), config.matrix.secret.clone(),
http_client_factory, http_client_factory.clone(),
); );
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?; let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?;
// TODO: grab the handle // TODO: grab the handle
@@ -126,9 +125,6 @@ impl Options {
let password_manager = password_manager_from_config(&config.passwords).await?; let password_manager = password_manager_from_config(&config.passwords).await?;
// Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50);
// The upstream OIDC metadata cache // The upstream OIDC metadata cache
let metadata_cache = MetadataCache::new(); let metadata_cache = MetadataCache::new();

View File

@@ -49,7 +49,7 @@ impl Options {
let mailer = mailer_from_config(&config.email, &templates)?; let mailer = mailer_from_config(&config.email, &templates)?;
mailer.test_connection().await?; mailer.test_connection().await?;
let http_client_factory = HttpClientFactory::new(50); let http_client_factory = HttpClientFactory::new().await?;
let conn = SynapseConnection::new( let conn = SynapseConnection::new(
config.matrix.homeserver.clone(), config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(), config.matrix.endpoint.clone(),

View File

@@ -122,9 +122,7 @@ impl AppState {
let http_service = self let http_service = self
.http_client_factory .http_client_factory
.http_service() .http_service("upstream_oauth2.metadata");
.await
.expect("Failed to create the HTTP service");
self.metadata_cache self.metadata_cache
.warm_up_and_run( .warm_up_and_run(

View File

@@ -142,7 +142,7 @@ impl TestState {
let homeserver_connection = MockHomeserverConnection::new("example.com"); let homeserver_connection = MockHomeserverConnection::new("example.com");
let http_client_factory = HttpClientFactory::new(10); let http_client_factory = HttpClientFactory::new().await?;
let site_config = SiteConfig::default(); let site_config = SiteConfig::default();

View File

@@ -84,7 +84,7 @@ pub(crate) async fn get(
.await? .await?
.ok_or(RouteError::ProviderNotFound)?; .ok_or(RouteError::ProviderNotFound)?;
let http_service = http_client_factory.http_service().await?; let http_service = http_client_factory.http_service("upstream_oauth2.authorize");
// First, discover the provider // First, discover the provider
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?; let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;

View File

@@ -188,7 +188,7 @@ pub(crate) async fn get(
CodeOrError::Code { code } => code, CodeOrError::Code { code } => code,
}; };
let http_service = http_client_factory.http_service().await?; let http_service = http_client_factory.http_service("upstream_oauth2.callback");
// Discover the provider // Discover the provider
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?; let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;

View File

@@ -14,13 +14,11 @@
use std::convert::Infallible; use std::convert::Infallible;
use hyper::{ use hyper::client::{
client::{
connect::dns::{GaiResolver, Name}, connect::dns::{GaiResolver, Name},
HttpConnector, HttpConnector,
},
Client,
}; };
pub use hyper::Client;
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use mas_tower::{ use mas_tower::{
DurationRecorderLayer, DurationRecorderService, FnWrapper, InFlightCounterLayer, DurationRecorderLayer, DurationRecorderService, FnWrapper, InFlightCounterLayer,

View File

@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{sync::Arc, time::Duration}; use std::time::Duration;
use headers::{ContentLength, HeaderMapExt, Host, UserAgent}; use headers::{ContentLength, HeaderMapExt, Host, UserAgent};
use http::{header::USER_AGENT, HeaderValue, Request, Response}; use http::{header::USER_AGENT, HeaderValue, Request, Response};
use hyper::client::connect::HttpInfo; use hyper::client::connect::HttpInfo;
use mas_tower::{ use mas_tower::{
EnrichSpan, MakeSpan, TraceContextLayer, TraceContextService, TraceLayer, TraceService, DurationRecorderLayer, DurationRecorderService, EnrichSpan, InFlightCounterLayer,
InFlightCounterService, MakeSpan, MetricsAttributes, TraceContextLayer, TraceContextService,
TraceLayer, TraceService,
}; };
use tokio::sync::Semaphore; use opentelemetry::KeyValue;
use tower::{ use tower::{
limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer}, limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer},
Layer, Layer,
@@ -33,6 +35,8 @@ use tower_http::{
use tracing::Span; use tracing::Span;
pub type ClientService<S> = SetRequestHeader< pub type ClientService<S> = SetRequestHeader<
DurationRecorderService<
InFlightCounterService<
ConcurrencyLimit< ConcurrencyLimit<
FollowRedirect< FollowRedirect<
TraceService< TraceService<
@@ -43,11 +47,19 @@ pub type ClientService<S> = SetRequestHeader<
>, >,
>, >,
>, >,
OnRequestLabels,
>,
OnRequestLabels,
OnResponseLabels,
KeyValue,
>,
HeaderValue, HeaderValue,
>; >;
#[derive(Debug, Clone)] #[derive(Debug, Clone, Default)]
pub struct MakeSpanForRequest; pub struct MakeSpanForRequest {
category: Option<&'static str>,
}
impl<B> MakeSpan<Request<B>> for MakeSpanForRequest { impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
fn make_span(&self, request: &Request<B>) -> Span { fn make_span(&self, request: &Request<B>) -> Span {
@@ -58,6 +70,7 @@ impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
.map(tracing::field::display); .map(tracing::field::display);
let content_length = headers.typed_get().map(|ContentLength(len)| len); let content_length = headers.typed_get().map(|ContentLength(len)| len);
let net_sock_peer_name = request.uri().host(); let net_sock_peer_name = request.uri().host();
let category = self.category.unwrap_or("UNSET");
tracing::info_span!( tracing::info_span!(
"http.client.request", "http.client.request",
@@ -78,6 +91,7 @@ impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
"net.sock.host.port" = tracing::field::Empty, "net.sock.host.port" = tracing::field::Empty,
"user_agent.original" = user_agent, "user_agent.original" = user_agent,
"rust.error" = tracing::field::Empty, "rust.error" = tracing::field::Empty,
"mas.category" = category,
) )
} }
} }
@@ -123,6 +137,42 @@ where
} }
} }
#[derive(Debug, Clone, Default)]
pub struct OnRequestLabels {
category: Option<&'static str>,
}
impl<B> MetricsAttributes<Request<B>> for OnRequestLabels
where
B: 'static,
{
type Iter<'a> = std::array::IntoIter<KeyValue, 3>;
fn attributes<'a>(&'a self, t: &'a Request<B>) -> Self::Iter<'a> {
[
KeyValue::new("http.request.method", t.method().as_str().to_owned()),
KeyValue::new("network.protocol.name", "http"),
KeyValue::new("mas.category", self.category.unwrap_or("UNSET")),
]
.into_iter()
}
}
#[derive(Debug, Clone, Default)]
pub struct OnResponseLabels;
impl<B> MetricsAttributes<Response<B>> for OnResponseLabels
where
B: 'static,
{
type Iter<'a> = std::iter::Once<KeyValue>;
fn attributes<'a>(&'a self, t: &'a Response<B>) -> Self::Iter<'a> {
std::iter::once(KeyValue::new(
"http.response.status_code",
i64::from(t.status().as_u16()),
))
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ClientLayer { pub struct ClientLayer {
user_agent_layer: SetRequestHeaderLayer<HeaderValue>, user_agent_layer: SetRequestHeaderLayer<HeaderValue>,
@@ -131,6 +181,8 @@ pub struct ClientLayer {
trace_layer: TraceLayer<MakeSpanForRequest, EnrichSpanOnResponse, EnrichSpanOnError>, trace_layer: TraceLayer<MakeSpanForRequest, EnrichSpanOnResponse, EnrichSpanOnError>,
trace_context_layer: TraceContextLayer, trace_context_layer: TraceContextLayer,
timeout_layer: TimeoutLayer, timeout_layer: TimeoutLayer,
duration_recorder_layer: DurationRecorderLayer<OnRequestLabels, OnResponseLabels, KeyValue>,
in_flight_counter_layer: InFlightCounterLayer<OnRequestLabels>,
} }
impl Default for ClientLayer { impl Default for ClientLayer {
@@ -142,26 +194,45 @@ impl Default for ClientLayer {
impl ClientLayer { impl ClientLayer {
#[must_use] #[must_use]
pub fn new() -> Self { pub fn new() -> Self {
let semaphore = Arc::new(Semaphore::new(10));
Self::with_semaphore(semaphore)
}
#[must_use]
pub fn with_semaphore(semaphore: Arc<Semaphore>) -> Self {
Self { Self {
user_agent_layer: SetRequestHeaderLayer::overriding( user_agent_layer: SetRequestHeaderLayer::overriding(
USER_AGENT, USER_AGENT,
HeaderValue::from_static("matrix-authentication-service/0.0.1"), HeaderValue::from_static("matrix-authentication-service/0.0.1"),
), ),
concurrency_limit_layer: GlobalConcurrencyLimitLayer::with_semaphore(semaphore), concurrency_limit_layer: GlobalConcurrencyLimitLayer::new(10),
follow_redirect_layer: FollowRedirectLayer::new(), follow_redirect_layer: FollowRedirectLayer::new(),
trace_layer: TraceLayer::new(MakeSpanForRequest) trace_layer: TraceLayer::new(MakeSpanForRequest::default())
.on_response(EnrichSpanOnResponse) .on_response(EnrichSpanOnResponse)
.on_error(EnrichSpanOnError), .on_error(EnrichSpanOnError),
trace_context_layer: TraceContextLayer::new(), trace_context_layer: TraceContextLayer::new(),
timeout_layer: TimeoutLayer::new(Duration::from_secs(10)), timeout_layer: TimeoutLayer::new(Duration::from_secs(10)),
duration_recorder_layer: DurationRecorderLayer::new("http.client.duration")
.on_request(OnRequestLabels::default())
.on_response(OnResponseLabels)
.on_error(KeyValue::new("http.error", true)),
in_flight_counter_layer: InFlightCounterLayer::new("http.client.active_requests")
.on_request(OnRequestLabels::default()),
} }
} }
#[must_use]
pub fn with_category(mut self, category: &'static str) -> Self {
self.trace_layer = TraceLayer::new(MakeSpanForRequest {
category: Some(category),
})
.on_response(EnrichSpanOnResponse)
.on_error(EnrichSpanOnError);
self.duration_recorder_layer = self.duration_recorder_layer.on_request(OnRequestLabels {
category: Some(category),
});
self.in_flight_counter_layer = self.in_flight_counter_layer.on_request(OnRequestLabels {
category: Some(category),
});
self
}
} }
impl<S> Layer<S> for ClientLayer impl<S> Layer<S> for ClientLayer
@@ -173,6 +244,8 @@ where
fn layer(&self, inner: S) -> Self::Service { fn layer(&self, inner: S) -> Self::Service {
( (
&self.user_agent_layer, &self.user_agent_layer,
&self.duration_recorder_layer,
&self.in_flight_counter_layer,
&self.concurrency_limit_layer, &self.concurrency_limit_layer,
&self.follow_redirect_layer, &self.follow_redirect_layer,
&self.trace_layer, &self.trace_layer,

View File

@@ -33,7 +33,7 @@ mod service;
#[cfg(feature = "client")] #[cfg(feature = "client")]
pub use self::{ pub use self::{
client::{ client::{
make_traced_client, make_traced_connector, make_untraced_client, ClientInitError, make_traced_client, make_traced_connector, make_untraced_client, Client, ClientInitError,
TracedClient, TracedConnector, UntracedClient, UntracedConnector, TracedClient, TracedConnector, UntracedClient, UntracedConnector,
}, },
layers::client::{ClientLayer, ClientService}, layers::client::{ClientLayer, ClientService},

View File

@@ -154,8 +154,7 @@ impl HomeserverConnection for SynapseConnection {
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error> { async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error> {
let mut client = self let mut client = self
.http_client_factory .http_client_factory
.client() .client("homeserver.query_user")
.await?
.response_body_to_bytes() .response_body_to_bytes()
.json_response(); .json_response();
@@ -218,8 +217,7 @@ impl HomeserverConnection for SynapseConnection {
let mut client = self let mut client = self
.http_client_factory .http_client_factory
.client() .client("homeserver.provision_user")
.await?
.request_bytes_to_body() .request_bytes_to_body()
.json_request(); .json_request();
@@ -255,8 +253,7 @@ impl HomeserverConnection for SynapseConnection {
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
let mut client = self let mut client = self
.http_client_factory .http_client_factory
.client() .client("homeserver.create_device")
.await?
.request_bytes_to_body() .request_bytes_to_body()
.json_request(); .json_request();
@@ -284,7 +281,7 @@ impl HomeserverConnection for SynapseConnection {
err(Display), err(Display),
)] )]
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
let mut client = self.http_client_factory.client().await?; let mut client = self.http_client_factory.client("homeserver.delete_device");
let request = self let request = self
.delete(&format!( .delete(&format!(
@@ -314,8 +311,7 @@ impl HomeserverConnection for SynapseConnection {
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> {
let mut client = self let mut client = self
.http_client_factory .http_client_factory
.client() .client("homeserver.delete_user")
.await?
.request_bytes_to_body() .request_bytes_to_body()
.json_request(); .json_request();
@@ -345,8 +341,7 @@ impl HomeserverConnection for SynapseConnection {
async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error> { async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error> {
let mut client = self let mut client = self
.http_client_factory .http_client_factory
.client() .client("homeserver.set_displayname")
.await?
.request_bytes_to_body() .request_bytes_to_body()
.json_request(); .json_request();

View File

@@ -40,7 +40,7 @@ impl InFlightCounterLayer {
pub fn new(name: &'static str) -> Self { pub fn new(name: &'static str) -> Self {
let counter = crate::meter() let counter = crate::meter()
.i64_up_down_counter(name) .i64_up_down_counter(name)
.with_unit(Unit::new("ms")) .with_unit(Unit::new("{request}"))
.with_description("The number of in-flight requests") .with_description("The number of in-flight requests")
.init(); .init();

View File

@@ -69,6 +69,17 @@ where
} }
} }
impl<V, T, const N: usize> MetricsAttributes<T> for [V; N]
where
V: MetricsAttributes<T> + 'static,
T: 'static,
{
type Iter<'a> = Box<dyn Iterator<Item = KeyValue> + 'a>;
fn attributes<'a>(&'a self, t: &'a T) -> Self::Iter<'_> {
Box::new(self.iter().flat_map(|v| v.attributes(t)))
}
}
impl<V, T> MetricsAttributes<T> for KV<V> impl<V, T> MetricsAttributes<T> for KV<V>
where where
V: Into<Value> + Clone + 'static, V: Into<Value> + Clone + 'static,