You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
Make the HTTP client factory reuse the underlying client
This avoids duplicating clients, and makes it so that they all share the same connection pool.
This commit is contained in:
@@ -191,8 +191,7 @@ async fn fetch_jwks(
|
||||
.unwrap();
|
||||
|
||||
let mut client = http_client_factory
|
||||
.client()
|
||||
.await?
|
||||
.client("client.fetch_jwks")
|
||||
.response_body_to_bytes()
|
||||
.json_response::<PublicJsonWebKeySet>();
|
||||
|
||||
|
@@ -12,14 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::body::Full;
|
||||
use mas_http::{
|
||||
BodyToBytesResponseLayer, ClientInitError, ClientLayer, ClientService, HttpService,
|
||||
TracedClient,
|
||||
make_traced_connector, BodyToBytesResponseLayer, Client, ClientInitError, ClientLayer,
|
||||
ClientService, HttpService, TracedClient, TracedConnector,
|
||||
};
|
||||
use tokio::sync::Semaphore;
|
||||
use tower::{
|
||||
util::{MapErrLayer, MapRequestLayer},
|
||||
BoxError, Layer,
|
||||
@@ -27,15 +24,16 @@ use tower::{
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HttpClientFactory {
|
||||
semaphore: Arc<Semaphore>,
|
||||
traced_connector: TracedConnector,
|
||||
client_layer: ClientLayer,
|
||||
}
|
||||
|
||||
impl HttpClientFactory {
|
||||
#[must_use]
|
||||
pub fn new(concurrency_limit: usize) -> Self {
|
||||
Self {
|
||||
semaphore: Arc::new(Semaphore::new(concurrency_limit)),
|
||||
}
|
||||
pub async fn new() -> Result<Self, ClientInitError> {
|
||||
Ok(Self {
|
||||
traced_connector: make_traced_connector().await?,
|
||||
client_layer: ClientLayer::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Constructs a new HTTP client
|
||||
@@ -43,14 +41,16 @@ impl HttpClientFactory {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the client failed to initialise
|
||||
pub async fn client<B>(&self) -> Result<ClientService<TracedClient<B>>, ClientInitError>
|
||||
pub fn client<B>(&self, category: &'static str) -> ClientService<TracedClient<B>>
|
||||
where
|
||||
B: axum::body::HttpBody + Send,
|
||||
B::Data: Send,
|
||||
{
|
||||
let client = mas_http::make_traced_client::<B>().await?;
|
||||
let layer = ClientLayer::with_semaphore(self.semaphore.clone());
|
||||
Ok(layer.layer(client))
|
||||
let client = Client::builder().build(self.traced_connector.clone());
|
||||
self.client_layer
|
||||
.clone()
|
||||
.with_category(category)
|
||||
.layer(client)
|
||||
}
|
||||
|
||||
/// Constructs a new [`HttpService`], suitable for `mas-oidc-client`
|
||||
@@ -58,8 +58,8 @@ impl HttpClientFactory {
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the client failed to initialise
|
||||
pub async fn http_service(&self) -> Result<HttpService, ClientInitError> {
|
||||
let client = self.client().await?;
|
||||
pub fn http_service(&self, category: &'static str) -> HttpService {
|
||||
let client = self.client(category);
|
||||
let client = (
|
||||
MapErrLayer::new(BoxError::from),
|
||||
MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)),
|
||||
@@ -67,6 +67,6 @@ impl HttpClientFactory {
|
||||
)
|
||||
.layer(client);
|
||||
|
||||
Ok(HttpService::new(client))
|
||||
HttpService::new(client)
|
||||
}
|
||||
}
|
||||
|
@@ -67,7 +67,7 @@ impl Options {
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
|
||||
use Subcommand as SC;
|
||||
let http_client_factory = HttpClientFactory::new(10);
|
||||
let http_client_factory = HttpClientFactory::new().await?;
|
||||
match self.subcommand {
|
||||
SC::Http {
|
||||
show_headers,
|
||||
@@ -75,7 +75,7 @@ impl Options {
|
||||
url,
|
||||
} => {
|
||||
let _span = info_span!("cli.debug.http").entered();
|
||||
let mut client = http_client_factory.client().await?;
|
||||
let mut client = http_client_factory.client("debug");
|
||||
let request = hyper::Request::builder()
|
||||
.uri(url)
|
||||
.body(hyper::Body::empty())?;
|
||||
@@ -99,8 +99,7 @@ impl Options {
|
||||
} => {
|
||||
let _span = info_span!("cli.debug.http").entered();
|
||||
let mut client = http_client_factory
|
||||
.client()
|
||||
.await?
|
||||
.client("debug")
|
||||
.response_body_to_bytes()
|
||||
.json_response();
|
||||
let request = hyper::Request::builder()
|
||||
|
@@ -97,6 +97,8 @@ impl Options {
|
||||
// Load and compile the templates
|
||||
let templates = templates_from_config(&config.templates, &url_builder).await?;
|
||||
|
||||
let http_client_factory = HttpClientFactory::new().await?;
|
||||
|
||||
if !self.no_worker {
|
||||
let mailer = mailer_from_config(&config.email, &templates)?;
|
||||
mailer.test_connection().await?;
|
||||
@@ -105,15 +107,12 @@ impl Options {
|
||||
let mut rng = thread_rng();
|
||||
let worker_name = Alphanumeric.sample_string(&mut rng, 10);
|
||||
|
||||
// Maximum 50 outgoing HTTP requests at a time
|
||||
let http_client_factory = HttpClientFactory::new(50);
|
||||
|
||||
info!(worker_name, "Starting task worker");
|
||||
let conn = SynapseConnection::new(
|
||||
config.matrix.homeserver.clone(),
|
||||
config.matrix.endpoint.clone(),
|
||||
config.matrix.secret.clone(),
|
||||
http_client_factory,
|
||||
http_client_factory.clone(),
|
||||
);
|
||||
let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?;
|
||||
// TODO: grab the handle
|
||||
@@ -126,9 +125,6 @@ impl Options {
|
||||
|
||||
let password_manager = password_manager_from_config(&config.passwords).await?;
|
||||
|
||||
// Maximum 50 outgoing HTTP requests at a time
|
||||
let http_client_factory = HttpClientFactory::new(50);
|
||||
|
||||
// The upstream OIDC metadata cache
|
||||
let metadata_cache = MetadataCache::new();
|
||||
|
||||
|
@@ -49,7 +49,7 @@ impl Options {
|
||||
let mailer = mailer_from_config(&config.email, &templates)?;
|
||||
mailer.test_connection().await?;
|
||||
|
||||
let http_client_factory = HttpClientFactory::new(50);
|
||||
let http_client_factory = HttpClientFactory::new().await?;
|
||||
let conn = SynapseConnection::new(
|
||||
config.matrix.homeserver.clone(),
|
||||
config.matrix.endpoint.clone(),
|
||||
|
@@ -122,9 +122,7 @@ impl AppState {
|
||||
|
||||
let http_service = self
|
||||
.http_client_factory
|
||||
.http_service()
|
||||
.await
|
||||
.expect("Failed to create the HTTP service");
|
||||
.http_service("upstream_oauth2.metadata");
|
||||
|
||||
self.metadata_cache
|
||||
.warm_up_and_run(
|
||||
|
@@ -142,7 +142,7 @@ impl TestState {
|
||||
|
||||
let homeserver_connection = MockHomeserverConnection::new("example.com");
|
||||
|
||||
let http_client_factory = HttpClientFactory::new(10);
|
||||
let http_client_factory = HttpClientFactory::new().await?;
|
||||
|
||||
let site_config = SiteConfig::default();
|
||||
|
||||
|
@@ -84,7 +84,7 @@ pub(crate) async fn get(
|
||||
.await?
|
||||
.ok_or(RouteError::ProviderNotFound)?;
|
||||
|
||||
let http_service = http_client_factory.http_service().await?;
|
||||
let http_service = http_client_factory.http_service("upstream_oauth2.authorize");
|
||||
|
||||
// First, discover the provider
|
||||
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
|
||||
|
@@ -188,7 +188,7 @@ pub(crate) async fn get(
|
||||
CodeOrError::Code { code } => code,
|
||||
};
|
||||
|
||||
let http_service = http_client_factory.http_service().await?;
|
||||
let http_service = http_client_factory.http_service("upstream_oauth2.callback");
|
||||
|
||||
// Discover the provider
|
||||
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
|
||||
|
@@ -14,13 +14,11 @@
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use hyper::{
|
||||
client::{
|
||||
use hyper::client::{
|
||||
connect::dns::{GaiResolver, Name},
|
||||
HttpConnector,
|
||||
},
|
||||
Client,
|
||||
};
|
||||
pub use hyper::Client;
|
||||
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
||||
use mas_tower::{
|
||||
DurationRecorderLayer, DurationRecorderService, FnWrapper, InFlightCounterLayer,
|
||||
|
@@ -12,15 +12,17 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use std::time::Duration;
|
||||
|
||||
use headers::{ContentLength, HeaderMapExt, Host, UserAgent};
|
||||
use http::{header::USER_AGENT, HeaderValue, Request, Response};
|
||||
use hyper::client::connect::HttpInfo;
|
||||
use mas_tower::{
|
||||
EnrichSpan, MakeSpan, TraceContextLayer, TraceContextService, TraceLayer, TraceService,
|
||||
DurationRecorderLayer, DurationRecorderService, EnrichSpan, InFlightCounterLayer,
|
||||
InFlightCounterService, MakeSpan, MetricsAttributes, TraceContextLayer, TraceContextService,
|
||||
TraceLayer, TraceService,
|
||||
};
|
||||
use tokio::sync::Semaphore;
|
||||
use opentelemetry::KeyValue;
|
||||
use tower::{
|
||||
limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer},
|
||||
Layer,
|
||||
@@ -33,6 +35,8 @@ use tower_http::{
|
||||
use tracing::Span;
|
||||
|
||||
pub type ClientService<S> = SetRequestHeader<
|
||||
DurationRecorderService<
|
||||
InFlightCounterService<
|
||||
ConcurrencyLimit<
|
||||
FollowRedirect<
|
||||
TraceService<
|
||||
@@ -43,11 +47,19 @@ pub type ClientService<S> = SetRequestHeader<
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
OnRequestLabels,
|
||||
>,
|
||||
OnRequestLabels,
|
||||
OnResponseLabels,
|
||||
KeyValue,
|
||||
>,
|
||||
HeaderValue,
|
||||
>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MakeSpanForRequest;
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MakeSpanForRequest {
|
||||
category: Option<&'static str>,
|
||||
}
|
||||
|
||||
impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
|
||||
fn make_span(&self, request: &Request<B>) -> Span {
|
||||
@@ -58,6 +70,7 @@ impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
|
||||
.map(tracing::field::display);
|
||||
let content_length = headers.typed_get().map(|ContentLength(len)| len);
|
||||
let net_sock_peer_name = request.uri().host();
|
||||
let category = self.category.unwrap_or("UNSET");
|
||||
|
||||
tracing::info_span!(
|
||||
"http.client.request",
|
||||
@@ -78,6 +91,7 @@ impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
|
||||
"net.sock.host.port" = tracing::field::Empty,
|
||||
"user_agent.original" = user_agent,
|
||||
"rust.error" = tracing::field::Empty,
|
||||
"mas.category" = category,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -123,6 +137,42 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct OnRequestLabels {
|
||||
category: Option<&'static str>,
|
||||
}
|
||||
|
||||
impl<B> MetricsAttributes<Request<B>> for OnRequestLabels
|
||||
where
|
||||
B: 'static,
|
||||
{
|
||||
type Iter<'a> = std::array::IntoIter<KeyValue, 3>;
|
||||
fn attributes<'a>(&'a self, t: &'a Request<B>) -> Self::Iter<'a> {
|
||||
[
|
||||
KeyValue::new("http.request.method", t.method().as_str().to_owned()),
|
||||
KeyValue::new("network.protocol.name", "http"),
|
||||
KeyValue::new("mas.category", self.category.unwrap_or("UNSET")),
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct OnResponseLabels;
|
||||
|
||||
impl<B> MetricsAttributes<Response<B>> for OnResponseLabels
|
||||
where
|
||||
B: 'static,
|
||||
{
|
||||
type Iter<'a> = std::iter::Once<KeyValue>;
|
||||
fn attributes<'a>(&'a self, t: &'a Response<B>) -> Self::Iter<'a> {
|
||||
std::iter::once(KeyValue::new(
|
||||
"http.response.status_code",
|
||||
i64::from(t.status().as_u16()),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClientLayer {
|
||||
user_agent_layer: SetRequestHeaderLayer<HeaderValue>,
|
||||
@@ -131,6 +181,8 @@ pub struct ClientLayer {
|
||||
trace_layer: TraceLayer<MakeSpanForRequest, EnrichSpanOnResponse, EnrichSpanOnError>,
|
||||
trace_context_layer: TraceContextLayer,
|
||||
timeout_layer: TimeoutLayer,
|
||||
duration_recorder_layer: DurationRecorderLayer<OnRequestLabels, OnResponseLabels, KeyValue>,
|
||||
in_flight_counter_layer: InFlightCounterLayer<OnRequestLabels>,
|
||||
}
|
||||
|
||||
impl Default for ClientLayer {
|
||||
@@ -142,26 +194,45 @@ impl Default for ClientLayer {
|
||||
impl ClientLayer {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
let semaphore = Arc::new(Semaphore::new(10));
|
||||
Self::with_semaphore(semaphore)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_semaphore(semaphore: Arc<Semaphore>) -> Self {
|
||||
Self {
|
||||
user_agent_layer: SetRequestHeaderLayer::overriding(
|
||||
USER_AGENT,
|
||||
HeaderValue::from_static("matrix-authentication-service/0.0.1"),
|
||||
),
|
||||
concurrency_limit_layer: GlobalConcurrencyLimitLayer::with_semaphore(semaphore),
|
||||
concurrency_limit_layer: GlobalConcurrencyLimitLayer::new(10),
|
||||
follow_redirect_layer: FollowRedirectLayer::new(),
|
||||
trace_layer: TraceLayer::new(MakeSpanForRequest)
|
||||
trace_layer: TraceLayer::new(MakeSpanForRequest::default())
|
||||
.on_response(EnrichSpanOnResponse)
|
||||
.on_error(EnrichSpanOnError),
|
||||
trace_context_layer: TraceContextLayer::new(),
|
||||
timeout_layer: TimeoutLayer::new(Duration::from_secs(10)),
|
||||
duration_recorder_layer: DurationRecorderLayer::new("http.client.duration")
|
||||
.on_request(OnRequestLabels::default())
|
||||
.on_response(OnResponseLabels)
|
||||
.on_error(KeyValue::new("http.error", true)),
|
||||
in_flight_counter_layer: InFlightCounterLayer::new("http.client.active_requests")
|
||||
.on_request(OnRequestLabels::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_category(mut self, category: &'static str) -> Self {
|
||||
self.trace_layer = TraceLayer::new(MakeSpanForRequest {
|
||||
category: Some(category),
|
||||
})
|
||||
.on_response(EnrichSpanOnResponse)
|
||||
.on_error(EnrichSpanOnError);
|
||||
|
||||
self.duration_recorder_layer = self.duration_recorder_layer.on_request(OnRequestLabels {
|
||||
category: Some(category),
|
||||
});
|
||||
|
||||
self.in_flight_counter_layer = self.in_flight_counter_layer.on_request(OnRequestLabels {
|
||||
category: Some(category),
|
||||
});
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for ClientLayer
|
||||
@@ -173,6 +244,8 @@ where
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
(
|
||||
&self.user_agent_layer,
|
||||
&self.duration_recorder_layer,
|
||||
&self.in_flight_counter_layer,
|
||||
&self.concurrency_limit_layer,
|
||||
&self.follow_redirect_layer,
|
||||
&self.trace_layer,
|
||||
|
@@ -33,7 +33,7 @@ mod service;
|
||||
#[cfg(feature = "client")]
|
||||
pub use self::{
|
||||
client::{
|
||||
make_traced_client, make_traced_connector, make_untraced_client, ClientInitError,
|
||||
make_traced_client, make_traced_connector, make_untraced_client, Client, ClientInitError,
|
||||
TracedClient, TracedConnector, UntracedClient, UntracedConnector,
|
||||
},
|
||||
layers::client::{ClientLayer, ClientService},
|
||||
|
@@ -154,8 +154,7 @@ impl HomeserverConnection for SynapseConnection {
|
||||
async fn query_user(&self, mxid: &str) -> Result<MatrixUser, Self::Error> {
|
||||
let mut client = self
|
||||
.http_client_factory
|
||||
.client()
|
||||
.await?
|
||||
.client("homeserver.query_user")
|
||||
.response_body_to_bytes()
|
||||
.json_response();
|
||||
|
||||
@@ -218,8 +217,7 @@ impl HomeserverConnection for SynapseConnection {
|
||||
|
||||
let mut client = self
|
||||
.http_client_factory
|
||||
.client()
|
||||
.await?
|
||||
.client("homeserver.provision_user")
|
||||
.request_bytes_to_body()
|
||||
.json_request();
|
||||
|
||||
@@ -255,8 +253,7 @@ impl HomeserverConnection for SynapseConnection {
|
||||
async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
|
||||
let mut client = self
|
||||
.http_client_factory
|
||||
.client()
|
||||
.await?
|
||||
.client("homeserver.create_device")
|
||||
.request_bytes_to_body()
|
||||
.json_request();
|
||||
|
||||
@@ -284,7 +281,7 @@ impl HomeserverConnection for SynapseConnection {
|
||||
err(Display),
|
||||
)]
|
||||
async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> {
|
||||
let mut client = self.http_client_factory.client().await?;
|
||||
let mut client = self.http_client_factory.client("homeserver.delete_device");
|
||||
|
||||
let request = self
|
||||
.delete(&format!(
|
||||
@@ -314,8 +311,7 @@ impl HomeserverConnection for SynapseConnection {
|
||||
async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> {
|
||||
let mut client = self
|
||||
.http_client_factory
|
||||
.client()
|
||||
.await?
|
||||
.client("homeserver.delete_user")
|
||||
.request_bytes_to_body()
|
||||
.json_request();
|
||||
|
||||
@@ -345,8 +341,7 @@ impl HomeserverConnection for SynapseConnection {
|
||||
async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error> {
|
||||
let mut client = self
|
||||
.http_client_factory
|
||||
.client()
|
||||
.await?
|
||||
.client("homeserver.set_displayname")
|
||||
.request_bytes_to_body()
|
||||
.json_request();
|
||||
|
||||
|
@@ -40,7 +40,7 @@ impl InFlightCounterLayer {
|
||||
pub fn new(name: &'static str) -> Self {
|
||||
let counter = crate::meter()
|
||||
.i64_up_down_counter(name)
|
||||
.with_unit(Unit::new("ms"))
|
||||
.with_unit(Unit::new("{request}"))
|
||||
.with_description("The number of in-flight requests")
|
||||
.init();
|
||||
|
||||
|
@@ -69,6 +69,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<V, T, const N: usize> MetricsAttributes<T> for [V; N]
|
||||
where
|
||||
V: MetricsAttributes<T> + 'static,
|
||||
T: 'static,
|
||||
{
|
||||
type Iter<'a> = Box<dyn Iterator<Item = KeyValue> + 'a>;
|
||||
fn attributes<'a>(&'a self, t: &'a T) -> Self::Iter<'_> {
|
||||
Box::new(self.iter().flat_map(|v| v.attributes(t)))
|
||||
}
|
||||
}
|
||||
|
||||
impl<V, T> MetricsAttributes<T> for KV<V>
|
||||
where
|
||||
V: Into<Value> + Clone + 'static,
|
||||
|
Reference in New Issue
Block a user