diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index 6df07674..5b22c169 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -36,4 +36,4 @@ mas-storage = { path = "../storage" } mas-data-model = { path = "../data-model" } mas-jose = { path = "../jose" } mas-iana = { path = "../iana" } -mas-http = { path = "../http" } +mas-http = { path = "../http", features = ["client"] } diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 1eafc8f5..616972ca 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -35,7 +35,7 @@ opentelemetry-zipkin = { version = "0.15.0", features = ["reqwest-client", "reqw mas-config = { path = "../config" } mas-email = { path = "../email" } mas-handlers = { path = "../handlers" } -mas-http = { path = "../http" } +mas-http = { path = "../http", features = ["axum"] } mas-policy = { path = "../policy" } mas-router = { path = "../router" } mas-static-files = { path = "../static-files" } diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index ea7bd83c..9cd08163 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -6,14 +6,14 @@ edition = "2021" license = "Apache-2.0" [dependencies] -axum = "0.5.13" +axum = { version = "0.5.13", optional = true } bytes = "1.2.1" futures-util = "0.3.21" headers = "0.3.7" http = "0.2.8" http-body = "0.4.5" hyper = "0.14.20" -hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false } +hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false, optional = true } once_cell = "1.13.0" opentelemetry = "0.17.0" opentelemetry-http = "0.6.0" @@ -23,14 +23,19 @@ serde = "1.0.142" serde_json = "1.0.83" serde_urlencoded = "0.7.1" thiserror = "1.0.32" -tokio = { version = "1.20.1", features = ["sync", "parking_lot"] } +tokio = { version = "1.20.1", optional = true } tower = { version = "0.4.13", features = ["timeout", "limit"] } -tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors"] } +tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] } tracing = "0.1.36" tracing-opentelemetry = "0.17.4" [dev-dependencies] anyhow = "1.0.62" serde = { version = "1.0.142", features = ["derive"] } -tokio = { version = "1.20.1", features = ["macros"] } +tokio = { version = "1.20.1", features = ["macros", "rt"] } tower = { version = "0.4.13", features = ["util"] } + +[features] +default = [] +axum = ["dep:axum"] +client = ["dep:hyper-rustls", "hyper/tcp", "tokio", "tokio/sync", "tokio/parking_lot"] diff --git a/crates/http/src/client.rs b/crates/http/src/client.rs new file mode 100644 index 00000000..d181e052 --- /dev/null +++ b/crates/http/src/client.rs @@ -0,0 +1,130 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use bytes::Bytes; +use futures_util::{FutureExt, TryFutureExt}; +use http::{Request, Response}; +use http_body::{combinators::BoxBody, Body}; +use hyper::{ + client::{connect::dns::GaiResolver, HttpConnector}, + Client, +}; +use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder}; +use thiserror::Error; +use tokio::{sync::OnceCell, task::JoinError}; +use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt}; + +use crate::{ + layers::{ + client::{ClientLayer, ClientResponse}, + otel::{TraceDns, TraceLayer}, + }, + BoxError, FutureService, +}; + +/// A wrapper over a boxed error that implements ``std::error::Error``. +/// This is helps converting to ``anyhow::Error`` with the `?` operator +#[derive(Error, Debug)] +pub enum ClientError { + #[error("failed to initialize HTTPS client")] + Init(#[from] ClientInitError), + + #[error(transparent)] + Call(#[from] BoxError), +} + +#[derive(Error, Debug, Clone)] +pub enum ClientInitError { + #[error("failed to load system certificates")] + CertificateLoad { + #[from] + inner: Arc, // That error is in an Arc to have the error implement Clone + }, +} + +static TLS_CONFIG: OnceCell = OnceCell::const_new(); + +async fn make_base_client( +) -> Result>>, B>, ClientInitError> +where + B: http_body::Body + Send + 'static, + E: Into, +{ + // Trace DNS requests + let resolver = ServiceBuilder::new() + .layer(TraceLayer::dns()) + .service(GaiResolver::new()); + + let mut http = HttpConnector::new_with_resolver(resolver); + http.enforce_http(false); + + let tls_config = TLS_CONFIG + .get_or_try_init(|| async move { + // Load the TLS config once in a blocking task because loading the system + // certificates can take a long time (~200ms) on macOS + let span = tracing::info_span!("load_certificates"); + tokio::task::spawn_blocking(|| { + let _span = span.entered(); + rustls::ClientConfig::builder() + .with_safe_defaults() + .with_native_roots() + .with_no_client_auth() + }) + .await + }) + .await + .map_err(|e| ClientInitError::from(Arc::new(e)))?; + + let https = HttpsConnectorBuilder::new() + .with_tls_config(tls_config.clone()) + .https_or_http() + .enable_http1() + .enable_http2() + .wrap_connector(http); + + // TODO: we should get the remote address here + let client = Client::builder().build(https); + + Ok::<_, ClientInitError>(client) +} + +#[must_use] +pub fn client( + operation: &'static str, +) -> BoxCloneService, Response>, ClientError> +where + B: http_body::Body + Default + Send + 'static, + E: Into + 'static, +{ + let fut = make_base_client() + // Map the error to a ClientError + .map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e)))) + // Wrap it in an Shared (Arc) to be able to Clone it + .shared(); + + let client: FutureService<_, _> = FutureService::new(fut); + + let client = ServiceBuilder::new() + // Convert the errors to ClientError to help dealing with them + .map_err(ClientError::from) + .map_response(|r: ClientResponse| { + r.map(|body| body.map_err(ClientError::from).boxed()) + }) + .layer(ClientLayer::new(operation)) + .service(client); + + client.boxed_clone() +} diff --git a/crates/http/src/layers/otel/make_span_builder.rs b/crates/http/src/layers/otel/make_span_builder.rs index 7f772806..12ae324c 100644 --- a/crates/http/src/layers/otel/make_span_builder.rs +++ b/crates/http/src/layers/otel/make_span_builder.rs @@ -12,17 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{borrow::Cow, net::SocketAddr}; +use std::borrow::Cow; +#[cfg(feature = "axum")] use axum::extract::{ConnectInfo, MatchedPath}; use headers::{ContentLength, HeaderMapExt, Host, UserAgent}; use http::{Method, Request, Version}; +#[cfg(feature = "client")] use hyper::client::connect::dns::Name; use opentelemetry::trace::{SpanBuilder, SpanKind}; -use opentelemetry_semantic_conventions::trace::{ - HTTP_FLAVOR, HTTP_HOST, HTTP_METHOD, HTTP_REQUEST_CONTENT_LENGTH, HTTP_ROUTE, HTTP_TARGET, - HTTP_USER_AGENT, NET_HOST_NAME, NET_PEER_IP, NET_PEER_PORT, NET_TRANSPORT, -}; +use opentelemetry_semantic_conventions::trace as SC; pub trait MakeSpanBuilder { fn make_span_builder(&self, request: &R) -> SpanBuilder; @@ -117,24 +116,24 @@ impl SpanFromHttpRequest { impl MakeSpanBuilder> for SpanFromHttpRequest { fn make_span_builder(&self, request: &Request) -> SpanBuilder { let mut attributes = vec![ - HTTP_METHOD.string(http_method_str(request.method())), - HTTP_FLAVOR.string(http_flavor(request.version())), - HTTP_TARGET.string(request.uri().to_string()), + SC::HTTP_METHOD.string(http_method_str(request.method())), + SC::HTTP_FLAVOR.string(http_flavor(request.version())), + SC::HTTP_TARGET.string(request.uri().to_string()), ]; let headers = request.headers(); if let Some(host) = headers.typed_get::() { - attributes.push(HTTP_HOST.string(host.to_string())); + attributes.push(SC::HTTP_HOST.string(host.to_string())); } if let Some(user_agent) = headers.typed_get::() { - attributes.push(HTTP_USER_AGENT.string(user_agent.to_string())); + attributes.push(SC::HTTP_USER_AGENT.string(user_agent.to_string())); } if let Some(ContentLength(content_length)) = headers.typed_get() { if let Ok(content_length) = content_length.try_into() { - attributes.push(HTTP_REQUEST_CONTENT_LENGTH.i64(content_length)); + attributes.push(SC::HTTP_REQUEST_CONTENT_LENGTH.i64(content_length)); } } @@ -144,42 +143,47 @@ impl MakeSpanBuilder> for SpanFromHttpRequest { } } +#[cfg(feature = "axum")] #[derive(Debug, Clone)] pub struct SpanFromAxumRequest; +#[cfg(feature = "axum")] impl MakeSpanBuilder> for SpanFromAxumRequest { fn make_span_builder(&self, request: &Request) -> SpanBuilder { let mut attributes = vec![ - HTTP_METHOD.string(http_method_str(request.method())), - HTTP_FLAVOR.string(http_flavor(request.version())), - HTTP_TARGET.string(request.uri().to_string()), + SC::HTTP_METHOD.string(http_method_str(request.method())), + SC::HTTP_FLAVOR.string(http_flavor(request.version())), + SC::HTTP_TARGET.string(request.uri().to_string()), ]; let headers = request.headers(); if let Some(host) = headers.typed_get::() { - attributes.push(HTTP_HOST.string(host.to_string())); + attributes.push(SC::HTTP_HOST.string(host.to_string())); } if let Some(user_agent) = headers.typed_get::() { - attributes.push(HTTP_USER_AGENT.string(user_agent.to_string())); + attributes.push(SC::HTTP_USER_AGENT.string(user_agent.to_string())); } if let Some(ContentLength(content_length)) = headers.typed_get() { if let Ok(content_length) = content_length.try_into() { - attributes.push(HTTP_REQUEST_CONTENT_LENGTH.i64(content_length)); + attributes.push(SC::HTTP_REQUEST_CONTENT_LENGTH.i64(content_length)); } } - if let Some(ConnectInfo(addr)) = request.extensions().get::>() { - attributes.push(NET_TRANSPORT.string("ip_tcp")); - attributes.push(NET_PEER_IP.string(addr.ip().to_string())); - attributes.push(NET_PEER_PORT.i64(addr.port().into())); + if let Some(ConnectInfo(addr)) = request + .extensions() + .get::>() + { + attributes.push(SC::NET_TRANSPORT.string("ip_tcp")); + attributes.push(SC::NET_PEER_IP.string(addr.ip().to_string())); + attributes.push(SC::NET_PEER_PORT.i64(addr.port().into())); } let name = if let Some(path) = request.extensions().get::() { let path = path.as_str().to_owned(); - attributes.push(HTTP_ROUTE.string(path.clone())); + attributes.push(SC::HTTP_ROUTE.string(path.clone())); path } else { request.uri().path().to_owned() @@ -191,12 +195,14 @@ impl MakeSpanBuilder> for SpanFromAxumRequest { } } +#[cfg(feature = "client")] #[derive(Debug, Clone, Copy, Default)] pub struct SpanFromDnsRequest; +#[cfg(feature = "client")] impl MakeSpanBuilder for SpanFromDnsRequest { fn make_span_builder(&self, request: &Name) -> SpanBuilder { - let attributes = vec![NET_HOST_NAME.string(request.as_str().to_owned())]; + let attributes = vec![SC::NET_HOST_NAME.string(request.as_str().to_owned())]; SpanBuilder::from_name("resolve") .with_kind(SpanKind::Client) diff --git a/crates/http/src/layers/otel/mod.rs b/crates/http/src/layers/otel/mod.rs index fdf556bc..56c8b66f 100644 --- a/crates/http/src/layers/otel/mod.rs +++ b/crates/http/src/layers/otel/mod.rs @@ -37,6 +37,7 @@ pub type TraceHttpServer = Trace< S, >; +#[cfg(feature = "axum")] pub type TraceAxumServerLayer = TraceLayer< ExtractFromHttpRequest, DefaultInjectContext, @@ -45,6 +46,7 @@ pub type TraceAxumServerLayer = TraceLayer< DefaultOnError, >; +#[cfg(feature = "axum")] pub type TraceAxumServer = Trace< ExtractFromHttpRequest, DefaultInjectContext, @@ -71,6 +73,7 @@ pub type TraceHttpClient = Trace< S, >; +#[cfg(feature = "client")] pub type TraceDnsLayer = TraceLayer< DefaultExtractContext, DefaultInjectContext, @@ -79,6 +82,7 @@ pub type TraceDnsLayer = TraceLayer< DefaultOnError, >; +#[cfg(feature = "client")] pub type TraceDns = Trace< DefaultExtractContext, DefaultInjectContext, @@ -98,6 +102,7 @@ impl TraceHttpServerLayer { } } +#[cfg(feature = "axum")] impl TraceAxumServerLayer { #[must_use] pub fn axum() -> Self { @@ -126,6 +131,7 @@ impl TraceHttpClientLayer { } } +#[cfg(feature = "client")] impl TraceDnsLayer { #[must_use] pub fn dns() -> Self { diff --git a/crates/http/src/layers/otel/on_response.rs b/crates/http/src/layers/otel/on_response.rs index 5ebf8203..7fd5c5f8 100644 --- a/crates/http/src/layers/otel/on_response.rs +++ b/crates/http/src/layers/otel/on_response.rs @@ -14,12 +14,10 @@ use headers::{ContentLength, HeaderMapExt}; use http::Response; +#[cfg(feature = "client")] use hyper::client::connect::HttpInfo; use opentelemetry::trace::SpanRef; -use opentelemetry_semantic_conventions::trace::{ - HTTP_RESPONSE_CONTENT_LENGTH, HTTP_STATUS_CODE, NET_HOST_IP, NET_HOST_PORT, NET_PEER_IP, - NET_PEER_PORT, -}; +use opentelemetry_semantic_conventions::trace as SC; pub trait OnResponse { fn on_response(&self, span: &SpanRef<'_>, response: &R); @@ -37,21 +35,22 @@ pub struct OnHttpResponse; impl OnResponse> for OnHttpResponse { fn on_response(&self, span: &SpanRef<'_>, response: &Response) { - span.set_attribute(HTTP_STATUS_CODE.i64(i64::from(response.status().as_u16()))); + span.set_attribute(SC::HTTP_STATUS_CODE.i64(i64::from(response.status().as_u16()))); if let Some(ContentLength(content_length)) = response.headers().typed_get() { if let Ok(content_length) = content_length.try_into() { - span.set_attribute(HTTP_RESPONSE_CONTENT_LENGTH.i64(content_length)); + span.set_attribute(SC::HTTP_RESPONSE_CONTENT_LENGTH.i64(content_length)); } } + #[cfg(feature = "client")] // Get local and remote address from hyper's HttpInfo injected by the // HttpConnector if let Some(info) = response.extensions().get::() { - span.set_attribute(NET_PEER_IP.string(info.remote_addr().ip().to_string())); - span.set_attribute(NET_PEER_PORT.i64(info.remote_addr().port().into())); - span.set_attribute(NET_HOST_IP.string(info.local_addr().ip().to_string())); - span.set_attribute(NET_HOST_PORT.i64(info.local_addr().port().into())); + span.set_attribute(SC::NET_PEER_IP.string(info.remote_addr().ip().to_string())); + span.set_attribute(SC::NET_PEER_PORT.i64(info.remote_addr().port().into())); + span.set_attribute(SC::NET_HOST_IP.string(info.local_addr().ip().to_string())); + span.set_attribute(SC::NET_HOST_PORT.i64(info.local_addr().port().into())); } } } diff --git a/crates/http/src/layers/server.rs b/crates/http/src/layers/server.rs index c0a344bd..636cbb1e 100644 --- a/crates/http/src/layers/server.rs +++ b/crates/http/src/layers/server.rs @@ -37,10 +37,14 @@ where type Service = BoxCloneService, Response>, S::Error>; fn layer(&self, inner: S) -> Self::Service { - ServiceBuilder::new() - .compression() - .layer(TraceLayer::axum()) - .service(inner) - .boxed_clone() + let builder = ServiceBuilder::new().compression(); + + #[cfg(feature = "axum")] + let builder = builder.layer(TraceLayer::axum()); + + #[cfg(not(feature = "axum"))] + let builder = builder.layer(TraceLayer::http_server()); + + builder.service(inner).boxed_clone() } } diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index c737ac07..f0e3a7cb 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -24,30 +24,14 @@ #![warn(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] -use std::sync::Arc; - -use bytes::Bytes; -use futures_util::{FutureExt, TryFutureExt}; -use http::{Request, Response}; -use http_body::{combinators::BoxBody, Body}; -use hyper::{ - client::{connect::dns::GaiResolver, HttpConnector}, - Client, -}; -use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder}; -use thiserror::Error; -use tokio::{sync::OnceCell, task::JoinError}; -use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt}; - -use self::layers::{ - client::ClientResponse, - otel::{TraceDns, TraceLayer}, -}; - +#[cfg(feature = "client")] +mod client; mod ext; mod future_service; mod layers; +#[cfg(feature = "client")] +pub use self::client::client; pub use self::{ ext::{ set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt, @@ -67,97 +51,3 @@ pub use self::{ }; pub(crate) type BoxError = Box; - -/// A wrapper over a boxed error that implements ``std::error::Error``. -/// This is helps converting to ``anyhow::Error`` with the `?` operator -#[derive(Error, Debug)] -pub enum ClientError { - #[error("failed to initialize HTTPS client")] - Init(#[from] ClientInitError), - - #[error(transparent)] - Call(#[from] BoxError), -} - -#[derive(Error, Debug, Clone)] -pub enum ClientInitError { - #[error("failed to load system certificates")] - CertificateLoad { - #[from] - inner: Arc, // That error is in an Arc to have the error implement Clone - }, -} - -static TLS_CONFIG: OnceCell = OnceCell::const_new(); - -async fn make_base_client( -) -> Result>>, B>, ClientInitError> -where - B: http_body::Body + Send + 'static, - E: Into, -{ - // Trace DNS requests - let resolver = ServiceBuilder::new() - .layer(TraceLayer::dns()) - .service(GaiResolver::new()); - - let mut http = HttpConnector::new_with_resolver(resolver); - http.enforce_http(false); - - let tls_config = TLS_CONFIG - .get_or_try_init(|| async move { - // Load the TLS config once in a blocking task because loading the system - // certificates can take a long time (~200ms) on macOS - let span = tracing::info_span!("load_certificates"); - tokio::task::spawn_blocking(|| { - let _span = span.entered(); - rustls::ClientConfig::builder() - .with_safe_defaults() - .with_native_roots() - .with_no_client_auth() - }) - .await - }) - .await - .map_err(|e| ClientInitError::from(Arc::new(e)))?; - - let https = HttpsConnectorBuilder::new() - .with_tls_config(tls_config.clone()) - .https_or_http() - .enable_http1() - .enable_http2() - .wrap_connector(http); - - // TODO: we should get the remote address here - let client = Client::builder().build(https); - - Ok::<_, ClientInitError>(client) -} - -#[must_use] -pub fn client( - operation: &'static str, -) -> BoxCloneService, Response>, ClientError> -where - B: http_body::Body + Default + Send + 'static, - E: Into + 'static, -{ - let fut = make_base_client() - // Map the error to a ClientError - .map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e)))) - // Wrap it in an Shared (Arc) to be able to Clone it - .shared(); - - let client: FutureService<_, _> = FutureService::new(fut); - - let client = ServiceBuilder::new() - // Convert the errors to ClientError to help dealing with them - .map_err(ClientError::from) - .map_response(|r: ClientResponse| { - r.map(|body| body.map_err(ClientError::from).boxed()) - }) - .layer(ClientLayer::new(operation)) - .service(client); - - client.boxed_clone() -}