1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Gate some crates behind features in mas-http

This commit is contained in:
Quentin Gliech
2022-08-17 12:20:09 +02:00
parent 9fe541f7b6
commit 185ff622f9
9 changed files with 199 additions and 159 deletions

View File

@ -36,4 +36,4 @@ mas-storage = { path = "../storage" }
mas-data-model = { path = "../data-model" }
mas-jose = { path = "../jose" }
mas-iana = { path = "../iana" }
mas-http = { path = "../http" }
mas-http = { path = "../http", features = ["client"] }

View File

@ -35,7 +35,7 @@ opentelemetry-zipkin = { version = "0.15.0", features = ["reqwest-client", "reqw
mas-config = { path = "../config" }
mas-email = { path = "../email" }
mas-handlers = { path = "../handlers" }
mas-http = { path = "../http" }
mas-http = { path = "../http", features = ["axum"] }
mas-policy = { path = "../policy" }
mas-router = { path = "../router" }
mas-static-files = { path = "../static-files" }

View File

@ -6,14 +6,14 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
axum = "0.5.13"
axum = { version = "0.5.13", optional = true }
bytes = "1.2.1"
futures-util = "0.3.21"
headers = "0.3.7"
http = "0.2.8"
http-body = "0.4.5"
hyper = "0.14.20"
hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false }
hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false, optional = true }
once_cell = "1.13.0"
opentelemetry = "0.17.0"
opentelemetry-http = "0.6.0"
@ -23,14 +23,19 @@ serde = "1.0.142"
serde_json = "1.0.83"
serde_urlencoded = "0.7.1"
thiserror = "1.0.32"
tokio = { version = "1.20.1", features = ["sync", "parking_lot"] }
tokio = { version = "1.20.1", optional = true }
tower = { version = "0.4.13", features = ["timeout", "limit"] }
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors"] }
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] }
tracing = "0.1.36"
tracing-opentelemetry = "0.17.4"
[dev-dependencies]
anyhow = "1.0.62"
serde = { version = "1.0.142", features = ["derive"] }
tokio = { version = "1.20.1", features = ["macros"] }
tokio = { version = "1.20.1", features = ["macros", "rt"] }
tower = { version = "0.4.13", features = ["util"] }
[features]
default = []
axum = ["dep:axum"]
client = ["dep:hyper-rustls", "hyper/tcp", "tokio", "tokio/sync", "tokio/parking_lot"]

130
crates/http/src/client.rs Normal file
View File

@ -0,0 +1,130 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use bytes::Bytes;
use futures_util::{FutureExt, TryFutureExt};
use http::{Request, Response};
use http_body::{combinators::BoxBody, Body};
use hyper::{
client::{connect::dns::GaiResolver, HttpConnector},
Client,
};
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
use thiserror::Error;
use tokio::{sync::OnceCell, task::JoinError};
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
use crate::{
layers::{
client::{ClientLayer, ClientResponse},
otel::{TraceDns, TraceLayer},
},
BoxError, FutureService,
};
/// A wrapper over a boxed error that implements ``std::error::Error``.
/// This is helps converting to ``anyhow::Error`` with the `?` operator
#[derive(Error, Debug)]
pub enum ClientError {
#[error("failed to initialize HTTPS client")]
Init(#[from] ClientInitError),
#[error(transparent)]
Call(#[from] BoxError),
}
#[derive(Error, Debug, Clone)]
pub enum ClientInitError {
#[error("failed to load system certificates")]
CertificateLoad {
#[from]
inner: Arc<JoinError>, // That error is in an Arc to have the error implement Clone
},
}
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
async fn make_base_client<B, E>(
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
where
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
E: Into<BoxError>,
{
// Trace DNS requests
let resolver = ServiceBuilder::new()
.layer(TraceLayer::dns())
.service(GaiResolver::new());
let mut http = HttpConnector::new_with_resolver(resolver);
http.enforce_http(false);
let tls_config = TLS_CONFIG
.get_or_try_init(|| async move {
// Load the TLS config once in a blocking task because loading the system
// certificates can take a long time (~200ms) on macOS
let span = tracing::info_span!("load_certificates");
tokio::task::spawn_blocking(|| {
let _span = span.entered();
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_native_roots()
.with_no_client_auth()
})
.await
})
.await
.map_err(|e| ClientInitError::from(Arc::new(e)))?;
let https = HttpsConnectorBuilder::new()
.with_tls_config(tls_config.clone())
.https_or_http()
.enable_http1()
.enable_http2()
.wrap_connector(http);
// TODO: we should get the remote address here
let client = Client::builder().build(https);
Ok::<_, ClientInitError>(client)
}
#[must_use]
pub fn client<B, E>(
operation: &'static str,
) -> BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>
where
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
E: Into<BoxError> + 'static,
{
let fut = make_base_client()
// Map the error to a ClientError
.map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e))))
// Wrap it in an Shared (Arc) to be able to Clone it
.shared();
let client: FutureService<_, _> = FutureService::new(fut);
let client = ServiceBuilder::new()
// Convert the errors to ClientError to help dealing with them
.map_err(ClientError::from)
.map_response(|r: ClientResponse<hyper::Body>| {
r.map(|body| body.map_err(ClientError::from).boxed())
})
.layer(ClientLayer::new(operation))
.service(client);
client.boxed_clone()
}

View File

@ -12,17 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{borrow::Cow, net::SocketAddr};
use std::borrow::Cow;
#[cfg(feature = "axum")]
use axum::extract::{ConnectInfo, MatchedPath};
use headers::{ContentLength, HeaderMapExt, Host, UserAgent};
use http::{Method, Request, Version};
#[cfg(feature = "client")]
use hyper::client::connect::dns::Name;
use opentelemetry::trace::{SpanBuilder, SpanKind};
use opentelemetry_semantic_conventions::trace::{
HTTP_FLAVOR, HTTP_HOST, HTTP_METHOD, HTTP_REQUEST_CONTENT_LENGTH, HTTP_ROUTE, HTTP_TARGET,
HTTP_USER_AGENT, NET_HOST_NAME, NET_PEER_IP, NET_PEER_PORT, NET_TRANSPORT,
};
use opentelemetry_semantic_conventions::trace as SC;
pub trait MakeSpanBuilder<R> {
fn make_span_builder(&self, request: &R) -> SpanBuilder;
@ -117,24 +116,24 @@ impl SpanFromHttpRequest {
impl<B> MakeSpanBuilder<Request<B>> for SpanFromHttpRequest {
fn make_span_builder(&self, request: &Request<B>) -> SpanBuilder {
let mut attributes = vec![
HTTP_METHOD.string(http_method_str(request.method())),
HTTP_FLAVOR.string(http_flavor(request.version())),
HTTP_TARGET.string(request.uri().to_string()),
SC::HTTP_METHOD.string(http_method_str(request.method())),
SC::HTTP_FLAVOR.string(http_flavor(request.version())),
SC::HTTP_TARGET.string(request.uri().to_string()),
];
let headers = request.headers();
if let Some(host) = headers.typed_get::<Host>() {
attributes.push(HTTP_HOST.string(host.to_string()));
attributes.push(SC::HTTP_HOST.string(host.to_string()));
}
if let Some(user_agent) = headers.typed_get::<UserAgent>() {
attributes.push(HTTP_USER_AGENT.string(user_agent.to_string()));
attributes.push(SC::HTTP_USER_AGENT.string(user_agent.to_string()));
}
if let Some(ContentLength(content_length)) = headers.typed_get() {
if let Ok(content_length) = content_length.try_into() {
attributes.push(HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
attributes.push(SC::HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
}
}
@ -144,42 +143,47 @@ impl<B> MakeSpanBuilder<Request<B>> for SpanFromHttpRequest {
}
}
#[cfg(feature = "axum")]
#[derive(Debug, Clone)]
pub struct SpanFromAxumRequest;
#[cfg(feature = "axum")]
impl<B> MakeSpanBuilder<Request<B>> for SpanFromAxumRequest {
fn make_span_builder(&self, request: &Request<B>) -> SpanBuilder {
let mut attributes = vec![
HTTP_METHOD.string(http_method_str(request.method())),
HTTP_FLAVOR.string(http_flavor(request.version())),
HTTP_TARGET.string(request.uri().to_string()),
SC::HTTP_METHOD.string(http_method_str(request.method())),
SC::HTTP_FLAVOR.string(http_flavor(request.version())),
SC::HTTP_TARGET.string(request.uri().to_string()),
];
let headers = request.headers();
if let Some(host) = headers.typed_get::<Host>() {
attributes.push(HTTP_HOST.string(host.to_string()));
attributes.push(SC::HTTP_HOST.string(host.to_string()));
}
if let Some(user_agent) = headers.typed_get::<UserAgent>() {
attributes.push(HTTP_USER_AGENT.string(user_agent.to_string()));
attributes.push(SC::HTTP_USER_AGENT.string(user_agent.to_string()));
}
if let Some(ContentLength(content_length)) = headers.typed_get() {
if let Ok(content_length) = content_length.try_into() {
attributes.push(HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
attributes.push(SC::HTTP_REQUEST_CONTENT_LENGTH.i64(content_length));
}
}
if let Some(ConnectInfo(addr)) = request.extensions().get::<ConnectInfo<SocketAddr>>() {
attributes.push(NET_TRANSPORT.string("ip_tcp"));
attributes.push(NET_PEER_IP.string(addr.ip().to_string()));
attributes.push(NET_PEER_PORT.i64(addr.port().into()));
if let Some(ConnectInfo(addr)) = request
.extensions()
.get::<ConnectInfo<std::net::SocketAddr>>()
{
attributes.push(SC::NET_TRANSPORT.string("ip_tcp"));
attributes.push(SC::NET_PEER_IP.string(addr.ip().to_string()));
attributes.push(SC::NET_PEER_PORT.i64(addr.port().into()));
}
let name = if let Some(path) = request.extensions().get::<MatchedPath>() {
let path = path.as_str().to_owned();
attributes.push(HTTP_ROUTE.string(path.clone()));
attributes.push(SC::HTTP_ROUTE.string(path.clone()));
path
} else {
request.uri().path().to_owned()
@ -191,12 +195,14 @@ impl<B> MakeSpanBuilder<Request<B>> for SpanFromAxumRequest {
}
}
#[cfg(feature = "client")]
#[derive(Debug, Clone, Copy, Default)]
pub struct SpanFromDnsRequest;
#[cfg(feature = "client")]
impl MakeSpanBuilder<Name> for SpanFromDnsRequest {
fn make_span_builder(&self, request: &Name) -> SpanBuilder {
let attributes = vec![NET_HOST_NAME.string(request.as_str().to_owned())];
let attributes = vec![SC::NET_HOST_NAME.string(request.as_str().to_owned())];
SpanBuilder::from_name("resolve")
.with_kind(SpanKind::Client)

View File

@ -37,6 +37,7 @@ pub type TraceHttpServer<S> = Trace<
S,
>;
#[cfg(feature = "axum")]
pub type TraceAxumServerLayer = TraceLayer<
ExtractFromHttpRequest,
DefaultInjectContext,
@ -45,6 +46,7 @@ pub type TraceAxumServerLayer = TraceLayer<
DefaultOnError,
>;
#[cfg(feature = "axum")]
pub type TraceAxumServer<S> = Trace<
ExtractFromHttpRequest,
DefaultInjectContext,
@ -71,6 +73,7 @@ pub type TraceHttpClient<S> = Trace<
S,
>;
#[cfg(feature = "client")]
pub type TraceDnsLayer = TraceLayer<
DefaultExtractContext,
DefaultInjectContext,
@ -79,6 +82,7 @@ pub type TraceDnsLayer = TraceLayer<
DefaultOnError,
>;
#[cfg(feature = "client")]
pub type TraceDns<S> = Trace<
DefaultExtractContext,
DefaultInjectContext,
@ -98,6 +102,7 @@ impl TraceHttpServerLayer {
}
}
#[cfg(feature = "axum")]
impl TraceAxumServerLayer {
#[must_use]
pub fn axum() -> Self {
@ -126,6 +131,7 @@ impl TraceHttpClientLayer {
}
}
#[cfg(feature = "client")]
impl TraceDnsLayer {
#[must_use]
pub fn dns() -> Self {

View File

@ -14,12 +14,10 @@
use headers::{ContentLength, HeaderMapExt};
use http::Response;
#[cfg(feature = "client")]
use hyper::client::connect::HttpInfo;
use opentelemetry::trace::SpanRef;
use opentelemetry_semantic_conventions::trace::{
HTTP_RESPONSE_CONTENT_LENGTH, HTTP_STATUS_CODE, NET_HOST_IP, NET_HOST_PORT, NET_PEER_IP,
NET_PEER_PORT,
};
use opentelemetry_semantic_conventions::trace as SC;
pub trait OnResponse<R> {
fn on_response(&self, span: &SpanRef<'_>, response: &R);
@ -37,21 +35,22 @@ pub struct OnHttpResponse;
impl<B> OnResponse<Response<B>> for OnHttpResponse {
fn on_response(&self, span: &SpanRef<'_>, response: &Response<B>) {
span.set_attribute(HTTP_STATUS_CODE.i64(i64::from(response.status().as_u16())));
span.set_attribute(SC::HTTP_STATUS_CODE.i64(i64::from(response.status().as_u16())));
if let Some(ContentLength(content_length)) = response.headers().typed_get() {
if let Ok(content_length) = content_length.try_into() {
span.set_attribute(HTTP_RESPONSE_CONTENT_LENGTH.i64(content_length));
span.set_attribute(SC::HTTP_RESPONSE_CONTENT_LENGTH.i64(content_length));
}
}
#[cfg(feature = "client")]
// Get local and remote address from hyper's HttpInfo injected by the
// HttpConnector
if let Some(info) = response.extensions().get::<HttpInfo>() {
span.set_attribute(NET_PEER_IP.string(info.remote_addr().ip().to_string()));
span.set_attribute(NET_PEER_PORT.i64(info.remote_addr().port().into()));
span.set_attribute(NET_HOST_IP.string(info.local_addr().ip().to_string()));
span.set_attribute(NET_HOST_PORT.i64(info.local_addr().port().into()));
span.set_attribute(SC::NET_PEER_IP.string(info.remote_addr().ip().to_string()));
span.set_attribute(SC::NET_PEER_PORT.i64(info.remote_addr().port().into()));
span.set_attribute(SC::NET_HOST_IP.string(info.local_addr().ip().to_string()));
span.set_attribute(SC::NET_HOST_PORT.i64(info.local_addr().port().into()));
}
}
}

View File

@ -37,10 +37,14 @@ where
type Service = BoxCloneService<Request<ReqBody>, Response<CompressionBody<ResBody>>, S::Error>;
fn layer(&self, inner: S) -> Self::Service {
ServiceBuilder::new()
.compression()
.layer(TraceLayer::axum())
.service(inner)
.boxed_clone()
let builder = ServiceBuilder::new().compression();
#[cfg(feature = "axum")]
let builder = builder.layer(TraceLayer::axum());
#[cfg(not(feature = "axum"))]
let builder = builder.layer(TraceLayer::http_server());
builder.service(inner).boxed_clone()
}
}

View File

@ -24,30 +24,14 @@
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
use std::sync::Arc;
use bytes::Bytes;
use futures_util::{FutureExt, TryFutureExt};
use http::{Request, Response};
use http_body::{combinators::BoxBody, Body};
use hyper::{
client::{connect::dns::GaiResolver, HttpConnector},
Client,
};
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
use thiserror::Error;
use tokio::{sync::OnceCell, task::JoinError};
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
use self::layers::{
client::ClientResponse,
otel::{TraceDns, TraceLayer},
};
#[cfg(feature = "client")]
mod client;
mod ext;
mod future_service;
mod layers;
#[cfg(feature = "client")]
pub use self::client::client;
pub use self::{
ext::{
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
@ -67,97 +51,3 @@ pub use self::{
};
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// A wrapper over a boxed error that implements ``std::error::Error``.
/// This is helps converting to ``anyhow::Error`` with the `?` operator
#[derive(Error, Debug)]
pub enum ClientError {
#[error("failed to initialize HTTPS client")]
Init(#[from] ClientInitError),
#[error(transparent)]
Call(#[from] BoxError),
}
#[derive(Error, Debug, Clone)]
pub enum ClientInitError {
#[error("failed to load system certificates")]
CertificateLoad {
#[from]
inner: Arc<JoinError>, // That error is in an Arc to have the error implement Clone
},
}
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
async fn make_base_client<B, E>(
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
where
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
E: Into<BoxError>,
{
// Trace DNS requests
let resolver = ServiceBuilder::new()
.layer(TraceLayer::dns())
.service(GaiResolver::new());
let mut http = HttpConnector::new_with_resolver(resolver);
http.enforce_http(false);
let tls_config = TLS_CONFIG
.get_or_try_init(|| async move {
// Load the TLS config once in a blocking task because loading the system
// certificates can take a long time (~200ms) on macOS
let span = tracing::info_span!("load_certificates");
tokio::task::spawn_blocking(|| {
let _span = span.entered();
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_native_roots()
.with_no_client_auth()
})
.await
})
.await
.map_err(|e| ClientInitError::from(Arc::new(e)))?;
let https = HttpsConnectorBuilder::new()
.with_tls_config(tls_config.clone())
.https_or_http()
.enable_http1()
.enable_http2()
.wrap_connector(http);
// TODO: we should get the remote address here
let client = Client::builder().build(https);
Ok::<_, ClientInitError>(client)
}
#[must_use]
pub fn client<B, E>(
operation: &'static str,
) -> BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>
where
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
E: Into<BoxError> + 'static,
{
let fut = make_base_client()
// Map the error to a ClientError
.map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e))))
// Wrap it in an Shared (Arc) to be able to Clone it
.shared();
let client: FutureService<_, _> = FutureService::new(fut);
let client = ServiceBuilder::new()
// Convert the errors to ClientError to help dealing with them
.map_err(ClientError::from)
.map_response(|r: ClientResponse<hyper::Body>| {
r.map(|body| body.map_err(ClientError::from).boxed())
})
.layer(ClientLayer::new(operation))
.service(client);
client.boxed_clone()
}