You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-06 06:02:40 +03:00
Gate some crates behind features in mas-http
This commit is contained in:
@@ -36,4 +36,4 @@ mas-storage = { path = "../storage" }
|
|||||||
mas-data-model = { path = "../data-model" }
|
mas-data-model = { path = "../data-model" }
|
||||||
mas-jose = { path = "../jose" }
|
mas-jose = { path = "../jose" }
|
||||||
mas-iana = { path = "../iana" }
|
mas-iana = { path = "../iana" }
|
||||||
mas-http = { path = "../http" }
|
mas-http = { path = "../http", features = ["client"] }
|
||||||
|
@@ -35,7 +35,7 @@ opentelemetry-zipkin = { version = "0.15.0", features = ["reqwest-client", "reqw
|
|||||||
mas-config = { path = "../config" }
|
mas-config = { path = "../config" }
|
||||||
mas-email = { path = "../email" }
|
mas-email = { path = "../email" }
|
||||||
mas-handlers = { path = "../handlers" }
|
mas-handlers = { path = "../handlers" }
|
||||||
mas-http = { path = "../http" }
|
mas-http = { path = "../http", features = ["axum"] }
|
||||||
mas-policy = { path = "../policy" }
|
mas-policy = { path = "../policy" }
|
||||||
mas-router = { path = "../router" }
|
mas-router = { path = "../router" }
|
||||||
mas-static-files = { path = "../static-files" }
|
mas-static-files = { path = "../static-files" }
|
||||||
|
@@ -6,14 +6,14 @@ edition = "2021"
|
|||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
axum = "0.5.13"
|
axum = { version = "0.5.13", optional = true }
|
||||||
bytes = "1.2.1"
|
bytes = "1.2.1"
|
||||||
futures-util = "0.3.21"
|
futures-util = "0.3.21"
|
||||||
headers = "0.3.7"
|
headers = "0.3.7"
|
||||||
http = "0.2.8"
|
http = "0.2.8"
|
||||||
http-body = "0.4.5"
|
http-body = "0.4.5"
|
||||||
hyper = "0.14.20"
|
hyper = "0.14.20"
|
||||||
hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false }
|
hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false, optional = true }
|
||||||
once_cell = "1.13.0"
|
once_cell = "1.13.0"
|
||||||
opentelemetry = "0.17.0"
|
opentelemetry = "0.17.0"
|
||||||
opentelemetry-http = "0.6.0"
|
opentelemetry-http = "0.6.0"
|
||||||
@@ -23,14 +23,19 @@ serde = "1.0.142"
|
|||||||
serde_json = "1.0.83"
|
serde_json = "1.0.83"
|
||||||
serde_urlencoded = "0.7.1"
|
serde_urlencoded = "0.7.1"
|
||||||
thiserror = "1.0.32"
|
thiserror = "1.0.32"
|
||||||
tokio = { version = "1.20.1", features = ["sync", "parking_lot"] }
|
tokio = { version = "1.20.1", optional = true }
|
||||||
tower = { version = "0.4.13", features = ["timeout", "limit"] }
|
tower = { version = "0.4.13", features = ["timeout", "limit"] }
|
||||||
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors"] }
|
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] }
|
||||||
tracing = "0.1.36"
|
tracing = "0.1.36"
|
||||||
tracing-opentelemetry = "0.17.4"
|
tracing-opentelemetry = "0.17.4"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = "1.0.62"
|
anyhow = "1.0.62"
|
||||||
serde = { version = "1.0.142", features = ["derive"] }
|
serde = { version = "1.0.142", features = ["derive"] }
|
||||||
tokio = { version = "1.20.1", features = ["macros"] }
|
tokio = { version = "1.20.1", features = ["macros", "rt"] }
|
||||||
tower = { version = "0.4.13", features = ["util"] }
|
tower = { version = "0.4.13", features = ["util"] }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
axum = ["dep:axum"]
|
||||||
|
client = ["dep:hyper-rustls", "hyper/tcp", "tokio", "tokio/sync", "tokio/parking_lot"]
|
||||||
|
130
crates/http/src/client.rs
Normal file
130
crates/http/src/client.rs
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use bytes::Bytes;
|
||||||
|
use futures_util::{FutureExt, TryFutureExt};
|
||||||
|
use http::{Request, Response};
|
||||||
|
use http_body::{combinators::BoxBody, Body};
|
||||||
|
use hyper::{
|
||||||
|
client::{connect::dns::GaiResolver, HttpConnector},
|
||||||
|
Client,
|
||||||
|
};
|
||||||
|
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::{sync::OnceCell, task::JoinError};
|
||||||
|
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
layers::{
|
||||||
|
client::{ClientLayer, ClientResponse},
|
||||||
|
otel::{TraceDns, TraceLayer},
|
||||||
|
},
|
||||||
|
BoxError, FutureService,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A wrapper over a boxed error that implements ``std::error::Error``.
|
||||||
|
/// This is helps converting to ``anyhow::Error`` with the `?` operator
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum ClientError {
|
||||||
|
#[error("failed to initialize HTTPS client")]
|
||||||
|
Init(#[from] ClientInitError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Call(#[from] BoxError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug, Clone)]
|
||||||
|
pub enum ClientInitError {
|
||||||
|
#[error("failed to load system certificates")]
|
||||||
|
CertificateLoad {
|
||||||
|
#[from]
|
||||||
|
inner: Arc<JoinError>, // That error is in an Arc to have the error implement Clone
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
|
||||||
|
|
||||||
|
async fn make_base_client<B, E>(
|
||||||
|
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
|
||||||
|
where
|
||||||
|
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||||
|
E: Into<BoxError>,
|
||||||
|
{
|
||||||
|
// Trace DNS requests
|
||||||
|
let resolver = ServiceBuilder::new()
|
||||||
|
.layer(TraceLayer::dns())
|
||||||
|
.service(GaiResolver::new());
|
||||||
|
|
||||||
|
let mut http = HttpConnector::new_with_resolver(resolver);
|
||||||
|
http.enforce_http(false);
|
||||||
|
|
||||||
|
let tls_config = TLS_CONFIG
|
||||||
|
.get_or_try_init(|| async move {
|
||||||
|
// Load the TLS config once in a blocking task because loading the system
|
||||||
|
// certificates can take a long time (~200ms) on macOS
|
||||||
|
let span = tracing::info_span!("load_certificates");
|
||||||
|
tokio::task::spawn_blocking(|| {
|
||||||
|
let _span = span.entered();
|
||||||
|
rustls::ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_native_roots()
|
||||||
|
.with_no_client_auth()
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| ClientInitError::from(Arc::new(e)))?;
|
||||||
|
|
||||||
|
let https = HttpsConnectorBuilder::new()
|
||||||
|
.with_tls_config(tls_config.clone())
|
||||||
|
.https_or_http()
|
||||||
|
.enable_http1()
|
||||||
|
.enable_http2()
|
||||||
|
.wrap_connector(http);
|
||||||
|
|
||||||
|
// TODO: we should get the remote address here
|
||||||
|
let client = Client::builder().build(https);
|
||||||
|
|
||||||
|
Ok::<_, ClientInitError>(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn client<B, E>(
|
||||||
|
operation: &'static str,
|
||||||
|
) -> BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>
|
||||||
|
where
|
||||||
|
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
||||||
|
E: Into<BoxError> + 'static,
|
||||||
|
{
|
||||||
|
let fut = make_base_client()
|
||||||
|
// Map the error to a ClientError
|
||||||
|
.map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e))))
|
||||||
|
// Wrap it in an Shared (Arc) to be able to Clone it
|
||||||
|
.shared();
|
||||||
|
|
||||||
|
let client: FutureService<_, _> = FutureService::new(fut);
|
||||||
|
|
||||||
|
let client = ServiceBuilder::new()
|
||||||
|
// Convert the errors to ClientError to help dealing with them
|
||||||
|
.map_err(ClientError::from)
|
||||||
|
.map_response(|r: ClientResponse<hyper::Body>| {
|
||||||
|
r.map(|body| body.map_err(ClientError::from).boxed())
|
||||||
|
})
|
||||||
|
.layer(ClientLayer::new(operation))
|
||||||
|
.service(client);
|
||||||
|
|
||||||
|
client.boxed_clone()
|
||||||
|
}
|
@@ -12,17 +12,16 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::{borrow::Cow, net::SocketAddr};
|
use std::borrow::Cow;
|
||||||
|
|
||||||
|
#[cfg(feature = "axum")]
|
||||||
use axum::extract::{ConnectInfo, MatchedPath};
|
use axum::extract::{ConnectInfo, MatchedPath};
|
||||||
use headers::{ContentLength, HeaderMapExt, Host, UserAgent};
|
use headers::{ContentLength, HeaderMapExt, Host, UserAgent};
|
||||||
use http::{Method, Request, Version};
|
use http::{Method, Request, Version};
|
||||||
|
#[cfg(feature = "client")]
|
||||||
use hyper::client::connect::dns::Name;
|
use hyper::client::connect::dns::Name;
|
||||||
use opentelemetry::trace::{SpanBuilder, SpanKind};
|
use opentelemetry::trace::{SpanBuilder, SpanKind};
|
||||||
use opentelemetry_semantic_conventions::trace::{
|
use opentelemetry_semantic_conventions::trace as SC;
|
||||||
HTTP_FLAVOR, HTTP_HOST, HTTP_METHOD, HTTP_REQUEST_CONTENT_LENGTH, HTTP_ROUTE, HTTP_TARGET,
|
|
||||||
HTTP_USER_AGENT, NET_HOST_NAME, NET_PEER_IP, NET_PEER_PORT, NET_TRANSPORT,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub trait MakeSpanBuilder<R> {
|
pub trait MakeSpanBuilder<R> {
|
||||||
fn make_span_builder(&self, request: &R) -> SpanBuilder;
|
fn make_span_builder(&self, request: &R) -> SpanBuilder;
|
||||||
@@ -117,24 +116,24 @@ impl SpanFromHttpRequest {
|
|||||||
impl<B> MakeSpanBuilder<Request<B>> for SpanFromHttpRequest {
|
impl<B> MakeSpanBuilder<Request<B>> for SpanFromHttpRequest {
|
||||||
fn make_span_builder(&self, request: &Request<B>) -> SpanBuilder {
|
fn make_span_builder(&self, request: &Request<B>) -> SpanBuilder {
|
||||||
let mut attributes = vec![
|
let mut attributes = vec![
|
||||||
HTTP_METHOD.string(http_method_str(request.method())),
|
SC::HTTP_METHOD.string(http_method_str(request.method())),
|
||||||
HTTP_FLAVOR.string(http_flavor(request.version())),
|
SC::HTTP_FLAVOR.string(http_flavor(request.version())),
|
||||||
HTTP_TARGET.string(request.uri().to_string()),
|
SC::HTTP_TARGET.string(request.uri().to_string()),
|
||||||
];
|
];
|
||||||
|
|
||||||
let headers = request.headers();
|
let headers = request.headers();
|
||||||
|
|
||||||
if let Some(host) = headers.typed_get::<Host>() {
|
if let Some(host) = headers.typed_get::<Host>() {
|
||||||
attributes.push(HTTP_HOST.string(host.to_string()));
|
attributes.push(SC::HTTP_HOST.string(host.to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(user_agent) = headers.typed_get::<UserAgent>() {
|
if let Some(user_agent) = headers.typed_get::<UserAgent>() {
|
||||||
attributes.push(HTTP_USER_AGENT.string(user_agent.to_string()));
|
attributes.push(SC::HTTP_USER_AGENT.string(user_agent.to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ContentLength(content_length)) = headers.typed_get() {
|
if let Some(ContentLength(content_length)) = headers.typed_get() {
|
||||||
if let Ok(content_length) = content_length.try_into() {
|
if let Ok(content_length) = content_length.try_into() {
|
||||||
attributes.push(HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
|
attributes.push(SC::HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,42 +143,47 @@ impl<B> MakeSpanBuilder<Request<B>> for SpanFromHttpRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "axum")]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct SpanFromAxumRequest;
|
pub struct SpanFromAxumRequest;
|
||||||
|
|
||||||
|
#[cfg(feature = "axum")]
|
||||||
impl<B> MakeSpanBuilder<Request<B>> for SpanFromAxumRequest {
|
impl<B> MakeSpanBuilder<Request<B>> for SpanFromAxumRequest {
|
||||||
fn make_span_builder(&self, request: &Request<B>) -> SpanBuilder {
|
fn make_span_builder(&self, request: &Request<B>) -> SpanBuilder {
|
||||||
let mut attributes = vec![
|
let mut attributes = vec![
|
||||||
HTTP_METHOD.string(http_method_str(request.method())),
|
SC::HTTP_METHOD.string(http_method_str(request.method())),
|
||||||
HTTP_FLAVOR.string(http_flavor(request.version())),
|
SC::HTTP_FLAVOR.string(http_flavor(request.version())),
|
||||||
HTTP_TARGET.string(request.uri().to_string()),
|
SC::HTTP_TARGET.string(request.uri().to_string()),
|
||||||
];
|
];
|
||||||
|
|
||||||
let headers = request.headers();
|
let headers = request.headers();
|
||||||
|
|
||||||
if let Some(host) = headers.typed_get::<Host>() {
|
if let Some(host) = headers.typed_get::<Host>() {
|
||||||
attributes.push(HTTP_HOST.string(host.to_string()));
|
attributes.push(SC::HTTP_HOST.string(host.to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(user_agent) = headers.typed_get::<UserAgent>() {
|
if let Some(user_agent) = headers.typed_get::<UserAgent>() {
|
||||||
attributes.push(HTTP_USER_AGENT.string(user_agent.to_string()));
|
attributes.push(SC::HTTP_USER_AGENT.string(user_agent.to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ContentLength(content_length)) = headers.typed_get() {
|
if let Some(ContentLength(content_length)) = headers.typed_get() {
|
||||||
if let Ok(content_length) = content_length.try_into() {
|
if let Ok(content_length) = content_length.try_into() {
|
||||||
attributes.push(HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
|
attributes.push(SC::HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ConnectInfo(addr)) = request.extensions().get::<ConnectInfo<SocketAddr>>() {
|
if let Some(ConnectInfo(addr)) = request
|
||||||
attributes.push(NET_TRANSPORT.string("ip_tcp"));
|
.extensions()
|
||||||
attributes.push(NET_PEER_IP.string(addr.ip().to_string()));
|
.get::<ConnectInfo<std::net::SocketAddr>>()
|
||||||
attributes.push(NET_PEER_PORT.i64(addr.port().into()));
|
{
|
||||||
|
attributes.push(SC::NET_TRANSPORT.string("ip_tcp"));
|
||||||
|
attributes.push(SC::NET_PEER_IP.string(addr.ip().to_string()));
|
||||||
|
attributes.push(SC::NET_PEER_PORT.i64(addr.port().into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let name = if let Some(path) = request.extensions().get::<MatchedPath>() {
|
let name = if let Some(path) = request.extensions().get::<MatchedPath>() {
|
||||||
let path = path.as_str().to_owned();
|
let path = path.as_str().to_owned();
|
||||||
attributes.push(HTTP_ROUTE.string(path.clone()));
|
attributes.push(SC::HTTP_ROUTE.string(path.clone()));
|
||||||
path
|
path
|
||||||
} else {
|
} else {
|
||||||
request.uri().path().to_owned()
|
request.uri().path().to_owned()
|
||||||
@@ -191,12 +195,14 @@ impl<B> MakeSpanBuilder<Request<B>> for SpanFromAxumRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "client")]
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
pub struct SpanFromDnsRequest;
|
pub struct SpanFromDnsRequest;
|
||||||
|
|
||||||
|
#[cfg(feature = "client")]
|
||||||
impl MakeSpanBuilder<Name> for SpanFromDnsRequest {
|
impl MakeSpanBuilder<Name> for SpanFromDnsRequest {
|
||||||
fn make_span_builder(&self, request: &Name) -> SpanBuilder {
|
fn make_span_builder(&self, request: &Name) -> SpanBuilder {
|
||||||
let attributes = vec![NET_HOST_NAME.string(request.as_str().to_owned())];
|
let attributes = vec![SC::NET_HOST_NAME.string(request.as_str().to_owned())];
|
||||||
|
|
||||||
SpanBuilder::from_name("resolve")
|
SpanBuilder::from_name("resolve")
|
||||||
.with_kind(SpanKind::Client)
|
.with_kind(SpanKind::Client)
|
||||||
|
@@ -37,6 +37,7 @@ pub type TraceHttpServer<S> = Trace<
|
|||||||
S,
|
S,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
#[cfg(feature = "axum")]
|
||||||
pub type TraceAxumServerLayer = TraceLayer<
|
pub type TraceAxumServerLayer = TraceLayer<
|
||||||
ExtractFromHttpRequest,
|
ExtractFromHttpRequest,
|
||||||
DefaultInjectContext,
|
DefaultInjectContext,
|
||||||
@@ -45,6 +46,7 @@ pub type TraceAxumServerLayer = TraceLayer<
|
|||||||
DefaultOnError,
|
DefaultOnError,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
#[cfg(feature = "axum")]
|
||||||
pub type TraceAxumServer<S> = Trace<
|
pub type TraceAxumServer<S> = Trace<
|
||||||
ExtractFromHttpRequest,
|
ExtractFromHttpRequest,
|
||||||
DefaultInjectContext,
|
DefaultInjectContext,
|
||||||
@@ -71,6 +73,7 @@ pub type TraceHttpClient<S> = Trace<
|
|||||||
S,
|
S,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
#[cfg(feature = "client")]
|
||||||
pub type TraceDnsLayer = TraceLayer<
|
pub type TraceDnsLayer = TraceLayer<
|
||||||
DefaultExtractContext,
|
DefaultExtractContext,
|
||||||
DefaultInjectContext,
|
DefaultInjectContext,
|
||||||
@@ -79,6 +82,7 @@ pub type TraceDnsLayer = TraceLayer<
|
|||||||
DefaultOnError,
|
DefaultOnError,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
#[cfg(feature = "client")]
|
||||||
pub type TraceDns<S> = Trace<
|
pub type TraceDns<S> = Trace<
|
||||||
DefaultExtractContext,
|
DefaultExtractContext,
|
||||||
DefaultInjectContext,
|
DefaultInjectContext,
|
||||||
@@ -98,6 +102,7 @@ impl TraceHttpServerLayer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "axum")]
|
||||||
impl TraceAxumServerLayer {
|
impl TraceAxumServerLayer {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn axum() -> Self {
|
pub fn axum() -> Self {
|
||||||
@@ -126,6 +131,7 @@ impl TraceHttpClientLayer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "client")]
|
||||||
impl TraceDnsLayer {
|
impl TraceDnsLayer {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn dns() -> Self {
|
pub fn dns() -> Self {
|
||||||
|
@@ -14,12 +14,10 @@
|
|||||||
|
|
||||||
use headers::{ContentLength, HeaderMapExt};
|
use headers::{ContentLength, HeaderMapExt};
|
||||||
use http::Response;
|
use http::Response;
|
||||||
|
#[cfg(feature = "client")]
|
||||||
use hyper::client::connect::HttpInfo;
|
use hyper::client::connect::HttpInfo;
|
||||||
use opentelemetry::trace::SpanRef;
|
use opentelemetry::trace::SpanRef;
|
||||||
use opentelemetry_semantic_conventions::trace::{
|
use opentelemetry_semantic_conventions::trace as SC;
|
||||||
HTTP_RESPONSE_CONTENT_LENGTH, HTTP_STATUS_CODE, NET_HOST_IP, NET_HOST_PORT, NET_PEER_IP,
|
|
||||||
NET_PEER_PORT,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub trait OnResponse<R> {
|
pub trait OnResponse<R> {
|
||||||
fn on_response(&self, span: &SpanRef<'_>, response: &R);
|
fn on_response(&self, span: &SpanRef<'_>, response: &R);
|
||||||
@@ -37,21 +35,22 @@ pub struct OnHttpResponse;
|
|||||||
|
|
||||||
impl<B> OnResponse<Response<B>> for OnHttpResponse {
|
impl<B> OnResponse<Response<B>> for OnHttpResponse {
|
||||||
fn on_response(&self, span: &SpanRef<'_>, response: &Response<B>) {
|
fn on_response(&self, span: &SpanRef<'_>, response: &Response<B>) {
|
||||||
span.set_attribute(HTTP_STATUS_CODE.i64(i64::from(response.status().as_u16())));
|
span.set_attribute(SC::HTTP_STATUS_CODE.i64(i64::from(response.status().as_u16())));
|
||||||
|
|
||||||
if let Some(ContentLength(content_length)) = response.headers().typed_get() {
|
if let Some(ContentLength(content_length)) = response.headers().typed_get() {
|
||||||
if let Ok(content_length) = content_length.try_into() {
|
if let Ok(content_length) = content_length.try_into() {
|
||||||
span.set_attribute(HTTP_RESPONSE_CONTENT_LENGTH.i64(content_length));
|
span.set_attribute(SC::HTTP_RESPONSE_CONTENT_LENGTH.i64(content_length));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "client")]
|
||||||
// Get local and remote address from hyper's HttpInfo injected by the
|
// Get local and remote address from hyper's HttpInfo injected by the
|
||||||
// HttpConnector
|
// HttpConnector
|
||||||
if let Some(info) = response.extensions().get::<HttpInfo>() {
|
if let Some(info) = response.extensions().get::<HttpInfo>() {
|
||||||
span.set_attribute(NET_PEER_IP.string(info.remote_addr().ip().to_string()));
|
span.set_attribute(SC::NET_PEER_IP.string(info.remote_addr().ip().to_string()));
|
||||||
span.set_attribute(NET_PEER_PORT.i64(info.remote_addr().port().into()));
|
span.set_attribute(SC::NET_PEER_PORT.i64(info.remote_addr().port().into()));
|
||||||
span.set_attribute(NET_HOST_IP.string(info.local_addr().ip().to_string()));
|
span.set_attribute(SC::NET_HOST_IP.string(info.local_addr().ip().to_string()));
|
||||||
span.set_attribute(NET_HOST_PORT.i64(info.local_addr().port().into()));
|
span.set_attribute(SC::NET_HOST_PORT.i64(info.local_addr().port().into()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -37,10 +37,14 @@ where
|
|||||||
type Service = BoxCloneService<Request<ReqBody>, Response<CompressionBody<ResBody>>, S::Error>;
|
type Service = BoxCloneService<Request<ReqBody>, Response<CompressionBody<ResBody>>, S::Error>;
|
||||||
|
|
||||||
fn layer(&self, inner: S) -> Self::Service {
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
ServiceBuilder::new()
|
let builder = ServiceBuilder::new().compression();
|
||||||
.compression()
|
|
||||||
.layer(TraceLayer::axum())
|
#[cfg(feature = "axum")]
|
||||||
.service(inner)
|
let builder = builder.layer(TraceLayer::axum());
|
||||||
.boxed_clone()
|
|
||||||
|
#[cfg(not(feature = "axum"))]
|
||||||
|
let builder = builder.layer(TraceLayer::http_server());
|
||||||
|
|
||||||
|
builder.service(inner).boxed_clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -24,30 +24,14 @@
|
|||||||
#![warn(clippy::pedantic)]
|
#![warn(clippy::pedantic)]
|
||||||
#![allow(clippy::module_name_repetitions)]
|
#![allow(clippy::module_name_repetitions)]
|
||||||
|
|
||||||
use std::sync::Arc;
|
#[cfg(feature = "client")]
|
||||||
|
mod client;
|
||||||
use bytes::Bytes;
|
|
||||||
use futures_util::{FutureExt, TryFutureExt};
|
|
||||||
use http::{Request, Response};
|
|
||||||
use http_body::{combinators::BoxBody, Body};
|
|
||||||
use hyper::{
|
|
||||||
client::{connect::dns::GaiResolver, HttpConnector},
|
|
||||||
Client,
|
|
||||||
};
|
|
||||||
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokio::{sync::OnceCell, task::JoinError};
|
|
||||||
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
|
|
||||||
|
|
||||||
use self::layers::{
|
|
||||||
client::ClientResponse,
|
|
||||||
otel::{TraceDns, TraceLayer},
|
|
||||||
};
|
|
||||||
|
|
||||||
mod ext;
|
mod ext;
|
||||||
mod future_service;
|
mod future_service;
|
||||||
mod layers;
|
mod layers;
|
||||||
|
|
||||||
|
#[cfg(feature = "client")]
|
||||||
|
pub use self::client::client;
|
||||||
pub use self::{
|
pub use self::{
|
||||||
ext::{
|
ext::{
|
||||||
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
|
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
|
||||||
@@ -67,97 +51,3 @@ pub use self::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
|
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
|
||||||
|
|
||||||
/// A wrapper over a boxed error that implements ``std::error::Error``.
|
|
||||||
/// This is helps converting to ``anyhow::Error`` with the `?` operator
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
pub enum ClientError {
|
|
||||||
#[error("failed to initialize HTTPS client")]
|
|
||||||
Init(#[from] ClientInitError),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
Call(#[from] BoxError),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug, Clone)]
|
|
||||||
pub enum ClientInitError {
|
|
||||||
#[error("failed to load system certificates")]
|
|
||||||
CertificateLoad {
|
|
||||||
#[from]
|
|
||||||
inner: Arc<JoinError>, // That error is in an Arc to have the error implement Clone
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
|
|
||||||
|
|
||||||
async fn make_base_client<B, E>(
|
|
||||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
|
|
||||||
where
|
|
||||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
|
||||||
E: Into<BoxError>,
|
|
||||||
{
|
|
||||||
// Trace DNS requests
|
|
||||||
let resolver = ServiceBuilder::new()
|
|
||||||
.layer(TraceLayer::dns())
|
|
||||||
.service(GaiResolver::new());
|
|
||||||
|
|
||||||
let mut http = HttpConnector::new_with_resolver(resolver);
|
|
||||||
http.enforce_http(false);
|
|
||||||
|
|
||||||
let tls_config = TLS_CONFIG
|
|
||||||
.get_or_try_init(|| async move {
|
|
||||||
// Load the TLS config once in a blocking task because loading the system
|
|
||||||
// certificates can take a long time (~200ms) on macOS
|
|
||||||
let span = tracing::info_span!("load_certificates");
|
|
||||||
tokio::task::spawn_blocking(|| {
|
|
||||||
let _span = span.entered();
|
|
||||||
rustls::ClientConfig::builder()
|
|
||||||
.with_safe_defaults()
|
|
||||||
.with_native_roots()
|
|
||||||
.with_no_client_auth()
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.map_err(|e| ClientInitError::from(Arc::new(e)))?;
|
|
||||||
|
|
||||||
let https = HttpsConnectorBuilder::new()
|
|
||||||
.with_tls_config(tls_config.clone())
|
|
||||||
.https_or_http()
|
|
||||||
.enable_http1()
|
|
||||||
.enable_http2()
|
|
||||||
.wrap_connector(http);
|
|
||||||
|
|
||||||
// TODO: we should get the remote address here
|
|
||||||
let client = Client::builder().build(https);
|
|
||||||
|
|
||||||
Ok::<_, ClientInitError>(client)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn client<B, E>(
|
|
||||||
operation: &'static str,
|
|
||||||
) -> BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>
|
|
||||||
where
|
|
||||||
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
|
||||||
E: Into<BoxError> + 'static,
|
|
||||||
{
|
|
||||||
let fut = make_base_client()
|
|
||||||
// Map the error to a ClientError
|
|
||||||
.map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e))))
|
|
||||||
// Wrap it in an Shared (Arc) to be able to Clone it
|
|
||||||
.shared();
|
|
||||||
|
|
||||||
let client: FutureService<_, _> = FutureService::new(fut);
|
|
||||||
|
|
||||||
let client = ServiceBuilder::new()
|
|
||||||
// Convert the errors to ClientError to help dealing with them
|
|
||||||
.map_err(ClientError::from)
|
|
||||||
.map_response(|r: ClientResponse<hyper::Body>| {
|
|
||||||
r.map(|body| body.map_err(ClientError::from).boxed())
|
|
||||||
})
|
|
||||||
.layer(ClientLayer::new(operation))
|
|
||||||
.service(client);
|
|
||||||
|
|
||||||
client.boxed_clone()
|
|
||||||
}
|
|
||||||
|
Reference in New Issue
Block a user