1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Make the HTTP client factory reuse the underlying client

This avoids duplicating clients, and makes it so that they all share the same connection pool.
This commit is contained in:
Quentin Gliech
2023-09-14 14:22:49 +02:00
parent f29e4adcfa
commit 54071c4969
15 changed files with 146 additions and 77 deletions

View File

@@ -191,8 +191,7 @@ async fn fetch_jwks(
.unwrap();
let mut client = http_client_factory
.client()
.await?
.client("client.fetch_jwks")
.response_body_to_bytes()
.json_response::<PublicJsonWebKeySet>();

View File

@@ -12,14 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use axum::body::Full;
use mas_http::{
BodyToBytesResponseLayer, ClientInitError, ClientLayer, ClientService, HttpService,
TracedClient,
make_traced_connector, BodyToBytesResponseLayer, Client, ClientInitError, ClientLayer,
ClientService, HttpService, TracedClient, TracedConnector,
};
use tokio::sync::Semaphore;
use tower::{
util::{MapErrLayer, MapRequestLayer},
BoxError, Layer,
@@ -27,15 +24,16 @@ use tower::{
#[derive(Debug, Clone)]
pub struct HttpClientFactory {
semaphore: Arc<Semaphore>,
traced_connector: TracedConnector,
client_layer: ClientLayer,
}
impl HttpClientFactory {
#[must_use]
pub fn new(concurrency_limit: usize) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(concurrency_limit)),
}
pub async fn new() -> Result<Self, ClientInitError> {
Ok(Self {
traced_connector: make_traced_connector().await?,
client_layer: ClientLayer::new(),
})
}
/// Constructs a new HTTP client
@@ -43,14 +41,16 @@ impl HttpClientFactory {
/// # Errors
///
/// Returns an error if the client failed to initialise
pub async fn client<B>(&self) -> Result<ClientService<TracedClient<B>>, ClientInitError>
pub fn client<B>(&self, category: &'static str) -> ClientService<TracedClient<B>>
where
B: axum::body::HttpBody + Send,
B::Data: Send,
{
let client = mas_http::make_traced_client::<B>().await?;
let layer = ClientLayer::with_semaphore(self.semaphore.clone());
Ok(layer.layer(client))
let client = Client::builder().build(self.traced_connector.clone());
self.client_layer
.clone()
.with_category(category)
.layer(client)
}
/// Constructs a new [`HttpService`], suitable for `mas-oidc-client`
@@ -58,8 +58,8 @@ impl HttpClientFactory {
/// # Errors
///
/// Returns an error if the client failed to initialise
pub async fn http_service(&self) -> Result<HttpService, ClientInitError> {
let client = self.client().await?;
pub fn http_service(&self, category: &'static str) -> HttpService {
let client = self.client(category);
let client = (
MapErrLayer::new(BoxError::from),
MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)),
@@ -67,6 +67,6 @@ impl HttpClientFactory {
)
.layer(client);
Ok(HttpService::new(client))
HttpService::new(client)
}
}

View File

@@ -67,7 +67,7 @@ impl Options {
#[tracing::instrument(skip_all)]
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
use Subcommand as SC;
let http_client_factory = HttpClientFactory::new(10);
let http_client_factory = HttpClientFactory::new().await?;
match self.subcommand {
SC::Http {
show_headers,
@@ -75,7 +75,7 @@ impl Options {
url,
} => {
let _span = info_span!("cli.debug.http").entered();
let mut client = http_client_factory.client().await?;
let mut client = http_client_factory.client("debug");
let request = hyper::Request::builder()
.uri(url)
.body(hyper::Body::empty())?;
@@ -99,8 +99,7 @@ impl Options {
} => {
let _span = info_span!("cli.debug.http").entered();
let mut client = http_client_factory
.client()
.await?
.client("debug")
.response_body_to_bytes()
.json_response();
let request = hyper::Request::builder()

View File

@@ -97,6 +97,8 @@ impl Options {
// Load and compile the templates
let templates = templates_from_config(&config.templates, &url_builder).await?;
let http_client_factory = HttpClientFactory::new().await?;
if !self.no_worker {
let mailer = mailer_from_config(&config.email, &templates)?;
mailer.test_connection().await?;
@@ -105,15 +107,12 @@ impl Options {
let mut rng = thread_rng();
let worker_name = Alphanumeric.sample_string(&mut rng, 10);
// Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50);
info!(worker_name, "Starting task worker");
let conn = SynapseConnection::new(
config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(),
config.matrix.secret.clone(),
http_client_factory,
http_client_factory.clone(),
);
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?;
// TODO: grab the handle
@@ -126,9 +125,6 @@ impl Options {
let password_manager = password_manager_from_config(&config.passwords).await?;
// Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50);
// The upstream OIDC metadata cache
let metadata_cache = MetadataCache::new();

View File

@@ -49,7 +49,7 @@ impl Options {
let mailer = mailer_from_config(&config.email, &templates)?;
mailer.test_connection().await?;
let http_client_factory = HttpClientFactory::new(50);
let http_client_factory = HttpClientFactory::new().await?;
let conn = SynapseConnection::new(
config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(),

View File

@@ -122,9 +122,7 @@ impl AppState {
let http_service = self
.http_client_factory
.http_service()
.await
.expect("Failed to create the HTTP service");
.http_service("upstream_oauth2.metadata");
self.metadata_cache
.warm_up_and_run(

View File

@@ -142,7 +142,7 @@ impl TestState {
let homeserver_connection = MockHomeserverConnection::new("example.com");
let http_client_factory = HttpClientFactory::new(10);
let http_client_factory = HttpClientFactory::new().await?;
let site_config = SiteConfig::default();

View File

@@ -84,7 +84,7 @@ pub(crate) async fn get(
.await?
.ok_or(RouteError::ProviderNotFound)?;
let http_service = http_client_factory.http_service().await?;
let http_service = http_client_factory.http_service("upstream_oauth2.authorize");
// First, discover the provider
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;

View File

@@ -188,7 +188,7 @@ pub(crate) async fn get(
CodeOrError::Code { code } => code,
};
let http_service = http_client_factory.http_service().await?;
let http_service = http_client_factory.http_service("upstream_oauth2.callback");
// Discover the provider
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;

View File

@@ -14,13 +14,11 @@
use std::convert::Infallible;
use hyper::{
client::{
connect::dns::{GaiResolver, Name},
HttpConnector,
},
Client,
use hyper::client::{
connect::dns::{GaiResolver, Name},
HttpConnector,
};
pub use hyper::Client;
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use mas_tower::{
DurationRecorderLayer, DurationRecorderService, FnWrapper, InFlightCounterLayer,

View File

@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{sync::Arc, time::Duration};
use std::time::Duration;
use headers::{ContentLength, HeaderMapExt, Host, UserAgent};
use http::{header::USER_AGENT, HeaderValue, Request, Response};
use hyper::client::connect::HttpInfo;
use mas_tower::{
EnrichSpan, MakeSpan, TraceContextLayer, TraceContextService, TraceLayer, TraceService,
DurationRecorderLayer, DurationRecorderService, EnrichSpan, InFlightCounterLayer,
InFlightCounterService, MakeSpan, MetricsAttributes, TraceContextLayer, TraceContextService,
TraceLayer, TraceService,
};
use tokio::sync::Semaphore;
use opentelemetry::KeyValue;
use tower::{
limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer},
Layer,
@@ -33,21 +35,31 @@ use tower_http::{
use tracing::Span;
pub type ClientService<S> = SetRequestHeader<
ConcurrencyLimit<
FollowRedirect<
TraceService<
TraceContextService<Timeout<S>>,
MakeSpanForRequest,
EnrichSpanOnResponse,
EnrichSpanOnError,
DurationRecorderService<
InFlightCounterService<
ConcurrencyLimit<
FollowRedirect<
TraceService<
TraceContextService<Timeout<S>>,
MakeSpanForRequest,
EnrichSpanOnResponse,
EnrichSpanOnError,
>,
>,
>,
OnRequestLabels,
>,
OnRequestLabels,
OnResponseLabels,
KeyValue,
>,
HeaderValue,
>;
#[derive(Debug, Clone)]
pub struct MakeSpanForRequest;
#[derive(Debug, Clone, Default)]
pub struct MakeSpanForRequest {
category: Option<&'static str>,
}
impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
fn make_span(&self, request: &Request<B>) -> Span {
@@ -58,6 +70,7 @@ impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
.map(tracing::field::display);
let content_length = headers.typed_get().map(|ContentLength(len)| len);
let net_sock_peer_name = request.uri().host();
let category = self.category.unwrap_or("UNSET");
tracing::info_span!(
"http.client.request",
@@ -78,6 +91,7 @@ impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
"net.sock.host.port" = tracing::field::Empty,
"user_agent.original" = user_agent,
"rust.error" = tracing::field::Empty,
"mas.category" = category,
)
}
}
@@ -123,6 +137,42 @@ where
}
}
#[derive(Debug, Clone, Default)]
pub struct OnRequestLabels {
category: Option<&'static str>,
}
impl<B> MetricsAttributes<Request<B>> for OnRequestLabels
where
B: 'static,
{
type Iter<'a> = std::array::IntoIter<KeyValue, 3>;
fn attributes<'a>(&'a self, t: &'a Request<B>) -> Self::Iter<'a> {
[
KeyValue::new("http.request.method", t.method().as_str().to_owned()),
KeyValue::new("network.protocol.name", "http"),
KeyValue::new("mas.category", self.category.unwrap_or("UNSET")),
]
.into_iter()
}
}
#[derive(Debug, Clone, Default)]
pub struct OnResponseLabels;
impl<B> MetricsAttributes<Response<B>> for OnResponseLabels
where
B: 'static,
{
type Iter<'a> = std::iter::Once<KeyValue>;
fn attributes<'a>(&'a self, t: &'a Response<B>) -> Self::Iter<'a> {
std::iter::once(KeyValue::new(
"http.response.status_code",
i64::from(t.status().as_u16()),
))
}
}
#[derive(Debug, Clone)]
pub struct ClientLayer {
user_agent_layer: SetRequestHeaderLayer<HeaderValue>,
@@ -131,6 +181,8 @@ pub struct ClientLayer {
trace_layer: TraceLayer<MakeSpanForRequest, EnrichSpanOnResponse, EnrichSpanOnError>,
trace_context_layer: TraceContextLayer,
timeout_layer: TimeoutLayer,
duration_recorder_layer: DurationRecorderLayer<OnRequestLabels, OnResponseLabels, KeyValue>,
in_flight_counter_layer: InFlightCounterLayer<OnRequestLabels>,
}
impl Default for ClientLayer {
@@ -142,26 +194,45 @@ impl Default for ClientLayer {
impl ClientLayer {
#[must_use]
pub fn new() -> Self {
let semaphore = Arc::new(Semaphore::new(10));
Self::with_semaphore(semaphore)
}
#[must_use]
pub fn with_semaphore(semaphore: Arc<Semaphore>) -> Self {
Self {
user_agent_layer: SetRequestHeaderLayer::overriding(
USER_AGENT,
HeaderValue::from_static("matrix-authentication-service/0.0.1"),
),
concurrency_limit_layer: GlobalConcurrencyLimitLayer::with_semaphore(semaphore),
concurrency_limit_layer: GlobalConcurrencyLimitLayer::new(10),
follow_redirect_layer: FollowRedirectLayer::new(),
trace_layer: TraceLayer::new(MakeSpanForRequest)
trace_layer: TraceLayer::new(MakeSpanForRequest::default())
.on_response(EnrichSpanOnResponse)
.on_error(EnrichSpanOnError),
trace_context_layer: TraceContextLayer::new(),
timeout_layer: TimeoutLayer::new(Duration::from_secs(10)),
duration_recorder_layer: DurationRecorderLayer::new("http.client.duration")
.on_request(OnRequestLabels::default())
.on_response(OnResponseLabels)
.on_error(KeyValue::new("http.error", true)),
in_flight_counter_layer: InFlightCounterLayer::new("http.client.active_requests")
.on_request(OnRequestLabels::default()),
}
}
#[must_use]
pub fn with_category(mut self, category: &'static str) -> Self {
self.trace_layer = TraceLayer::new(MakeSpanForRequest {
category: Some(category),
})
.on_response(EnrichSpanOnResponse)
.on_error(EnrichSpanOnError);
self.duration_recorder_layer = self.duration_recorder_layer.on_request(OnRequestLabels {
category: Some(category),
});
self.in_flight_counter_layer = self.in_flight_counter_layer.on_request(OnRequestLabels {
category: Some(category),
});
self
}
}
impl<S> Layer<S> for ClientLayer
@@ -173,6 +244,8 @@ where
fn layer(&self, inner: S) -> Self::Service {
(
&self.user_agent_layer,
&self.duration_recorder_layer,
&self.in_flight_counter_layer,
&self.concurrency_limit_layer,
&self.follow_redirect_layer,
&self.trace_layer,

View File

@@ -33,7 +33,7 @@ mod service;
#[cfg(feature = "client")]
pub use self::{
client::{
make_traced_client, make_traced_connector, make_untraced_client, ClientInitError,
make_traced_client, make_traced_connector, make_untraced_client, Client, ClientInitError,
TracedClient, TracedConnector, UntracedClient, UntracedConnector,
},
layers::client::{ClientLayer, ClientService},

View File

@@ -154,8 +154,7 @@ impl HomeserverConnection for SynapseConnection {
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error> {
let mut client = self
.http_client_factory
.client()
.await?
.client("homeserver.query_user")
.response_body_to_bytes()
.json_response();
@@ -218,8 +217,7 @@ impl HomeserverConnection for SynapseConnection {
let mut client = self
.http_client_factory
.client()
.await?
.client("homeserver.provision_user")
.request_bytes_to_body()
.json_request();
@@ -255,8 +253,7 @@ impl HomeserverConnection for SynapseConnection {
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
let mut client = self
.http_client_factory
.client()
.await?
.client("homeserver.create_device")
.request_bytes_to_body()
.json_request();
@@ -284,7 +281,7 @@ impl HomeserverConnection for SynapseConnection {
err(Display),
)]
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
let mut client = self.http_client_factory.client().await?;
let mut client = self.http_client_factory.client("homeserver.delete_device");
let request = self
.delete(&format!(
@@ -314,8 +311,7 @@ impl HomeserverConnection for SynapseConnection {
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> {
let mut client = self
.http_client_factory
.client()
.await?
.client("homeserver.delete_user")
.request_bytes_to_body()
.json_request();
@@ -345,8 +341,7 @@ impl HomeserverConnection for SynapseConnection {
async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error> {
let mut client = self
.http_client_factory
.client()
.await?
.client("homeserver.set_displayname")
.request_bytes_to_body()
.json_request();

View File

@@ -40,7 +40,7 @@ impl InFlightCounterLayer {
pub fn new(name: &'static str) -> Self {
let counter = crate::meter()
.i64_up_down_counter(name)
.with_unit(Unit::new("ms"))
.with_unit(Unit::new("{request}"))
.with_description("The number of in-flight requests")
.init();

View File

@@ -69,6 +69,17 @@ where
}
}
impl<V, T, const N: usize> MetricsAttributes<T> for [V; N]
where
V: MetricsAttributes<T> + 'static,
T: 'static,
{
type Iter<'a> = Box<dyn Iterator<Item = KeyValue> + 'a>;
fn attributes<'a>(&'a self, t: &'a T) -> Self::Iter<'_> {
Box::new(self.iter().flat_map(|v| v.attributes(t)))
}
}
impl<V, T> MetricsAttributes<T> for KV<V>
where
V: Into<Value> + Clone + 'static,