diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index d281709d..43ec86c1 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -191,8 +191,7 @@ async fn fetch_jwks( .unwrap(); let mut client = http_client_factory - .client() - .await? + .client("client.fetch_jwks") .response_body_to_bytes() .json_response::(); diff --git a/crates/axum-utils/src/http_client_factory.rs b/crates/axum-utils/src/http_client_factory.rs index e6644407..16337a1a 100644 --- a/crates/axum-utils/src/http_client_factory.rs +++ b/crates/axum-utils/src/http_client_factory.rs @@ -12,14 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use axum::body::Full; use mas_http::{ - BodyToBytesResponseLayer, ClientInitError, ClientLayer, ClientService, HttpService, - TracedClient, + make_traced_connector, BodyToBytesResponseLayer, Client, ClientInitError, ClientLayer, + ClientService, HttpService, TracedClient, TracedConnector, }; -use tokio::sync::Semaphore; use tower::{ util::{MapErrLayer, MapRequestLayer}, BoxError, Layer, @@ -27,15 +24,16 @@ use tower::{ #[derive(Debug, Clone)] pub struct HttpClientFactory { - semaphore: Arc, + traced_connector: TracedConnector, + client_layer: ClientLayer, } impl HttpClientFactory { - #[must_use] - pub fn new(concurrency_limit: usize) -> Self { - Self { - semaphore: Arc::new(Semaphore::new(concurrency_limit)), - } + pub async fn new() -> Result { + Ok(Self { + traced_connector: make_traced_connector().await?, + client_layer: ClientLayer::new(), + }) } /// Constructs a new HTTP client @@ -43,14 +41,16 @@ impl HttpClientFactory { /// # Errors /// /// Returns an error if the client failed to initialise - pub async fn client(&self) -> Result>, ClientInitError> + pub fn client(&self, category: &'static str) -> ClientService> where B: axum::body::HttpBody + Send, B::Data: Send, { - let client = mas_http::make_traced_client::().await?; - let layer = ClientLayer::with_semaphore(self.semaphore.clone()); - Ok(layer.layer(client)) + let client = Client::builder().build(self.traced_connector.clone()); + self.client_layer + .clone() + .with_category(category) + .layer(client) } /// Constructs a new [`HttpService`], suitable for `mas-oidc-client` @@ -58,8 +58,8 @@ impl HttpClientFactory { /// # Errors /// /// Returns an error if the client failed to initialise - pub async fn http_service(&self) -> Result { - let client = self.client().await?; + pub fn http_service(&self, category: &'static str) -> HttpService { + let client = self.client(category); let client = ( MapErrLayer::new(BoxError::from), MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)), @@ -67,6 +67,6 @@ impl HttpClientFactory { ) .layer(client); - Ok(HttpService::new(client)) + HttpService::new(client) } } diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 127a541b..d435b01c 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -67,7 +67,7 @@ impl Options { #[tracing::instrument(skip_all)] pub async fn run(self, root: &super::Options) -> anyhow::Result<()> { use Subcommand as SC; - let http_client_factory = HttpClientFactory::new(10); + let http_client_factory = HttpClientFactory::new().await?; match self.subcommand { SC::Http { show_headers, @@ -75,7 +75,7 @@ impl Options { url, } => { let _span = info_span!("cli.debug.http").entered(); - let mut client = http_client_factory.client().await?; + let mut client = http_client_factory.client("debug"); let request = hyper::Request::builder() .uri(url) .body(hyper::Body::empty())?; @@ -99,8 +99,7 @@ impl Options { } => { let _span = info_span!("cli.debug.http").entered(); let mut client = http_client_factory - .client() - .await? + .client("debug") .response_body_to_bytes() .json_response(); let request = hyper::Request::builder() diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index f005c452..4766d7e9 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -97,6 +97,8 @@ impl Options { // Load and compile the templates let templates = templates_from_config(&config.templates, &url_builder).await?; + let http_client_factory = HttpClientFactory::new().await?; + if !self.no_worker { let mailer = mailer_from_config(&config.email, &templates)?; mailer.test_connection().await?; @@ -105,15 +107,12 @@ impl Options { let mut rng = thread_rng(); let worker_name = Alphanumeric.sample_string(&mut rng, 10); - // Maximum 50 outgoing HTTP requests at a time - let http_client_factory = HttpClientFactory::new(50); - info!(worker_name, "Starting task worker"); let conn = SynapseConnection::new( config.matrix.homeserver.clone(), config.matrix.endpoint.clone(), config.matrix.secret.clone(), - http_client_factory, + http_client_factory.clone(), ); let monitor = mas_tasks::init(&worker_name, &pool, &mailer, conn).await?; // TODO: grab the handle @@ -126,9 +125,6 @@ impl Options { let password_manager = password_manager_from_config(&config.passwords).await?; - // Maximum 50 outgoing HTTP requests at a time - let http_client_factory = HttpClientFactory::new(50); - // The upstream OIDC metadata cache let metadata_cache = MetadataCache::new(); diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index e55dc27f..25e2bce7 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -49,7 +49,7 @@ impl Options { let mailer = mailer_from_config(&config.email, &templates)?; mailer.test_connection().await?; - let http_client_factory = HttpClientFactory::new(50); + let http_client_factory = HttpClientFactory::new().await?; let conn = SynapseConnection::new( config.matrix.homeserver.clone(), config.matrix.endpoint.clone(), diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index b887af94..d53e3188 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -122,9 +122,7 @@ impl AppState { let http_service = self .http_client_factory - .http_service() - .await - .expect("Failed to create the HTTP service"); + .http_service("upstream_oauth2.metadata"); self.metadata_cache .warm_up_and_run( diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index e5c0bad7..ec75e290 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -142,7 +142,7 @@ impl TestState { let homeserver_connection = MockHomeserverConnection::new("example.com"); - let http_client_factory = HttpClientFactory::new(10); + let http_client_factory = HttpClientFactory::new().await?; let site_config = SiteConfig::default(); diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index abf40513..580afcfe 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -84,7 +84,7 @@ pub(crate) async fn get( .await? .ok_or(RouteError::ProviderNotFound)?; - let http_service = http_client_factory.http_service().await?; + let http_service = http_client_factory.http_service("upstream_oauth2.authorize"); // First, discover the provider let metadata = metadata_cache.get(&http_service, &provider.issuer).await?; diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 58129712..41d90ab3 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -188,7 +188,7 @@ pub(crate) async fn get( CodeOrError::Code { code } => code, }; - let http_service = http_client_factory.http_service().await?; + let http_service = http_client_factory.http_service("upstream_oauth2.callback"); // Discover the provider let metadata = metadata_cache.get(&http_service, &provider.issuer).await?; diff --git a/crates/http/src/client.rs b/crates/http/src/client.rs index c63ce8c0..ee23b77c 100644 --- a/crates/http/src/client.rs +++ b/crates/http/src/client.rs @@ -14,13 +14,11 @@ use std::convert::Infallible; -use hyper::{ - client::{ - connect::dns::{GaiResolver, Name}, - HttpConnector, - }, - Client, +use hyper::client::{ + connect::dns::{GaiResolver, Name}, + HttpConnector, }; +pub use hyper::Client; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use mas_tower::{ DurationRecorderLayer, DurationRecorderService, FnWrapper, InFlightCounterLayer, diff --git a/crates/http/src/layers/client.rs b/crates/http/src/layers/client.rs index a44eff7e..19ace214 100644 --- a/crates/http/src/layers/client.rs +++ b/crates/http/src/layers/client.rs @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use headers::{ContentLength, HeaderMapExt, Host, UserAgent}; use http::{header::USER_AGENT, HeaderValue, Request, Response}; use hyper::client::connect::HttpInfo; use mas_tower::{ - EnrichSpan, MakeSpan, TraceContextLayer, TraceContextService, TraceLayer, TraceService, + DurationRecorderLayer, DurationRecorderService, EnrichSpan, InFlightCounterLayer, + InFlightCounterService, MakeSpan, MetricsAttributes, TraceContextLayer, TraceContextService, + TraceLayer, TraceService, }; -use tokio::sync::Semaphore; +use opentelemetry::KeyValue; use tower::{ limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer}, Layer, @@ -33,21 +35,31 @@ use tower_http::{ use tracing::Span; pub type ClientService = SetRequestHeader< - ConcurrencyLimit< - FollowRedirect< - TraceService< - TraceContextService>, - MakeSpanForRequest, - EnrichSpanOnResponse, - EnrichSpanOnError, + DurationRecorderService< + InFlightCounterService< + ConcurrencyLimit< + FollowRedirect< + TraceService< + TraceContextService>, + MakeSpanForRequest, + EnrichSpanOnResponse, + EnrichSpanOnError, + >, + >, >, + OnRequestLabels, >, + OnRequestLabels, + OnResponseLabels, + KeyValue, >, HeaderValue, >; -#[derive(Debug, Clone)] -pub struct MakeSpanForRequest; +#[derive(Debug, Clone, Default)] +pub struct MakeSpanForRequest { + category: Option<&'static str>, +} impl MakeSpan> for MakeSpanForRequest { fn make_span(&self, request: &Request) -> Span { @@ -58,6 +70,7 @@ impl MakeSpan> for MakeSpanForRequest { .map(tracing::field::display); let content_length = headers.typed_get().map(|ContentLength(len)| len); let net_sock_peer_name = request.uri().host(); + let category = self.category.unwrap_or("UNSET"); tracing::info_span!( "http.client.request", @@ -78,6 +91,7 @@ impl MakeSpan> for MakeSpanForRequest { "net.sock.host.port" = tracing::field::Empty, "user_agent.original" = user_agent, "rust.error" = tracing::field::Empty, + "mas.category" = category, ) } } @@ -123,6 +137,42 @@ where } } +#[derive(Debug, Clone, Default)] +pub struct OnRequestLabels { + category: Option<&'static str>, +} + +impl MetricsAttributes> for OnRequestLabels +where + B: 'static, +{ + type Iter<'a> = std::array::IntoIter; + fn attributes<'a>(&'a self, t: &'a Request) -> Self::Iter<'a> { + [ + KeyValue::new("http.request.method", t.method().as_str().to_owned()), + KeyValue::new("network.protocol.name", "http"), + KeyValue::new("mas.category", self.category.unwrap_or("UNSET")), + ] + .into_iter() + } +} + +#[derive(Debug, Clone, Default)] +pub struct OnResponseLabels; + +impl MetricsAttributes> for OnResponseLabels +where + B: 'static, +{ + type Iter<'a> = std::iter::Once; + fn attributes<'a>(&'a self, t: &'a Response) -> Self::Iter<'a> { + std::iter::once(KeyValue::new( + "http.response.status_code", + i64::from(t.status().as_u16()), + )) + } +} + #[derive(Debug, Clone)] pub struct ClientLayer { user_agent_layer: SetRequestHeaderLayer, @@ -131,6 +181,8 @@ pub struct ClientLayer { trace_layer: TraceLayer, trace_context_layer: TraceContextLayer, timeout_layer: TimeoutLayer, + duration_recorder_layer: DurationRecorderLayer, + in_flight_counter_layer: InFlightCounterLayer, } impl Default for ClientLayer { @@ -142,26 +194,45 @@ impl Default for ClientLayer { impl ClientLayer { #[must_use] pub fn new() -> Self { - let semaphore = Arc::new(Semaphore::new(10)); - Self::with_semaphore(semaphore) - } - - #[must_use] - pub fn with_semaphore(semaphore: Arc) -> Self { Self { user_agent_layer: SetRequestHeaderLayer::overriding( USER_AGENT, HeaderValue::from_static("matrix-authentication-service/0.0.1"), ), - concurrency_limit_layer: GlobalConcurrencyLimitLayer::with_semaphore(semaphore), + concurrency_limit_layer: GlobalConcurrencyLimitLayer::new(10), follow_redirect_layer: FollowRedirectLayer::new(), - trace_layer: TraceLayer::new(MakeSpanForRequest) + trace_layer: TraceLayer::new(MakeSpanForRequest::default()) .on_response(EnrichSpanOnResponse) .on_error(EnrichSpanOnError), trace_context_layer: TraceContextLayer::new(), timeout_layer: TimeoutLayer::new(Duration::from_secs(10)), + duration_recorder_layer: DurationRecorderLayer::new("http.client.duration") + .on_request(OnRequestLabels::default()) + .on_response(OnResponseLabels) + .on_error(KeyValue::new("http.error", true)), + in_flight_counter_layer: InFlightCounterLayer::new("http.client.active_requests") + .on_request(OnRequestLabels::default()), } } + + #[must_use] + pub fn with_category(mut self, category: &'static str) -> Self { + self.trace_layer = TraceLayer::new(MakeSpanForRequest { + category: Some(category), + }) + .on_response(EnrichSpanOnResponse) + .on_error(EnrichSpanOnError); + + self.duration_recorder_layer = self.duration_recorder_layer.on_request(OnRequestLabels { + category: Some(category), + }); + + self.in_flight_counter_layer = self.in_flight_counter_layer.on_request(OnRequestLabels { + category: Some(category), + }); + + self + } } impl Layer for ClientLayer @@ -173,6 +244,8 @@ where fn layer(&self, inner: S) -> Self::Service { ( &self.user_agent_layer, + &self.duration_recorder_layer, + &self.in_flight_counter_layer, &self.concurrency_limit_layer, &self.follow_redirect_layer, &self.trace_layer, diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index 83c11c9f..9179e034 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -33,7 +33,7 @@ mod service; #[cfg(feature = "client")] pub use self::{ client::{ - make_traced_client, make_traced_connector, make_untraced_client, ClientInitError, + make_traced_client, make_traced_connector, make_untraced_client, Client, ClientInitError, TracedClient, TracedConnector, UntracedClient, UntracedConnector, }, layers::client::{ClientLayer, ClientService}, diff --git a/crates/matrix-synapse/src/lib.rs b/crates/matrix-synapse/src/lib.rs index e2a62020..0f8fe46a 100644 --- a/crates/matrix-synapse/src/lib.rs +++ b/crates/matrix-synapse/src/lib.rs @@ -154,8 +154,7 @@ impl HomeserverConnection for SynapseConnection { async fn query_user(&self, mxid: &str) -> Result { let mut client = self .http_client_factory - .client() - .await? + .client("homeserver.query_user") .response_body_to_bytes() .json_response(); @@ -218,8 +217,7 @@ impl HomeserverConnection for SynapseConnection { let mut client = self .http_client_factory - .client() - .await? + .client("homeserver.provision_user") .request_bytes_to_body() .json_request(); @@ -255,8 +253,7 @@ impl HomeserverConnection for SynapseConnection { async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { let mut client = self .http_client_factory - .client() - .await? + .client("homeserver.create_device") .request_bytes_to_body() .json_request(); @@ -284,7 +281,7 @@ impl HomeserverConnection for SynapseConnection { err(Display), )] async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { - let mut client = self.http_client_factory.client().await?; + let mut client = self.http_client_factory.client("homeserver.delete_device"); let request = self .delete(&format!( @@ -314,8 +311,7 @@ impl HomeserverConnection for SynapseConnection { async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { let mut client = self .http_client_factory - .client() - .await? + .client("homeserver.delete_user") .request_bytes_to_body() .json_request(); @@ -345,8 +341,7 @@ impl HomeserverConnection for SynapseConnection { async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error> { let mut client = self .http_client_factory - .client() - .await? + .client("homeserver.set_displayname") .request_bytes_to_body() .json_request(); diff --git a/crates/tower/src/metrics/in_flight.rs b/crates/tower/src/metrics/in_flight.rs index 82b3c90f..2aa1408c 100644 --- a/crates/tower/src/metrics/in_flight.rs +++ b/crates/tower/src/metrics/in_flight.rs @@ -40,7 +40,7 @@ impl InFlightCounterLayer { pub fn new(name: &'static str) -> Self { let counter = crate::meter() .i64_up_down_counter(name) - .with_unit(Unit::new("ms")) + .with_unit(Unit::new("{request}")) .with_description("The number of in-flight requests") .init(); diff --git a/crates/tower/src/metrics/make_attributes.rs b/crates/tower/src/metrics/make_attributes.rs index ede398b1..2df48e0b 100644 --- a/crates/tower/src/metrics/make_attributes.rs +++ b/crates/tower/src/metrics/make_attributes.rs @@ -69,6 +69,17 @@ where } } +impl MetricsAttributes for [V; N] +where + V: MetricsAttributes + 'static, + T: 'static, +{ + type Iter<'a> = Box + 'a>; + fn attributes<'a>(&'a self, t: &'a T) -> Self::Iter<'_> { + Box::new(self.iter().flat_map(|v| v.attributes(t))) + } +} + impl MetricsAttributes for KV where V: Into + Clone + 'static,