1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Proper HTTP client

This commit is contained in:
Quentin Gliech
2022-02-10 15:33:19 +01:00
parent 2df40762a2
commit 8c36e51176
6 changed files with 306 additions and 100 deletions

4
Cargo.lock generated
View File

@ -1992,6 +1992,7 @@ dependencies = [
name = "mas-http"
version = "0.1.0"
dependencies = [
"anyhow",
"bytes 1.1.0",
"http",
"http-body",
@ -1999,9 +2000,12 @@ dependencies = [
"hyper-rustls 0.23.0",
"opentelemetry",
"opentelemetry-http",
"rustls 0.20.2",
"tokio",
"tower",
"tower-http",
"tracing",
"tracing-opentelemetry",
]
[[package]]

View File

@ -0,0 +1,73 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use clap::Parser;
use hyper::Uri;
use tokio::io::AsyncWriteExt;
use tower::{Service, ServiceExt};
#[derive(Parser, Debug)]
pub(super) struct Options {
#[clap(subcommand)]
subcommand: Subcommand,
}
#[derive(Parser, Debug)]
enum Subcommand {
/// Perform an HTTP request with the default HTTP client
Http {
/// Show response headers
#[clap(long, short = 'I')]
show_headers: bool,
/// URI where to perform a GET request
url: Uri,
},
}
impl Options {
#[tracing::instrument(skip_all)]
pub async fn run(&self, _root: &super::Options) -> anyhow::Result<()> {
use Subcommand as SC;
match &self.subcommand {
SC::Http { show_headers, url } => {
let mut client = mas_http::client("cli-debug-http").await?;
let request = hyper::Request::builder()
.uri(url)
.body(hyper::Body::empty())?;
let mut response = client.ready().await?.call(request).await?;
if *show_headers {
let status = response.status();
println!(
"{:?} {} {}",
response.version(),
status.as_str(),
status.canonical_reason().unwrap_or_default()
);
for (header, value) in response.headers() {
println!("{}: {:?}", header, value);
}
println!();
}
let mut body = hyper::body::aggregate(response.body_mut()).await?;
let mut stdout = tokio::io::stdout();
stdout.write_all_buf(&mut body).await?;
Ok(())
}
}
}
}

View File

@ -20,6 +20,7 @@ use mas_config::ConfigurationSection;
mod config;
mod database;
mod debug;
mod manage;
mod server;
mod templates;
@ -40,6 +41,9 @@ enum Subcommand {
/// Templates-related commands
Templates(self::templates::Options),
/// Debug utilities
Debug(self::debug::Options),
}
#[derive(Parser, Debug)]
@ -67,6 +71,7 @@ impl Options {
Some(S::Server(c)) => c.run(self).await,
Some(S::Manage(c)) => c.run(self).await,
Some(S::Templates(c)) => c.run(self).await,
Some(S::Debug(c)) => c.run(self).await,
None => self::server::Options::default().run(self).await,
}
}

View File

@ -27,7 +27,7 @@ use mas_email::{MailTransport, Mailer};
use mas_storage::MIGRATOR;
use mas_tasks::TaskQueue;
use mas_templates::Templates;
use tower::make::Shared;
use tower::{make::Shared, Layer};
use tracing::{error, info};
#[derive(Parser, Debug, Default)]
@ -211,7 +211,7 @@ impl Options {
let warp_service = warp::service(root);
let service = mas_http::server(warp_service);
let service = mas_http::ServerLayer::default().layer(warp_service);
info!("Listening on http://{}", listener.local_addr().unwrap());

View File

@ -6,6 +6,7 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
anyhow = "1.0.53"
bytes = "1.1.0"
http = "0.2.6"
http-body = "0.4.4"
@ -13,6 +14,9 @@ hyper = "0.14.16"
hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] }
opentelemetry = "0.17.0"
opentelemetry-http = "0.6.0"
rustls = "0.20.2"
tokio = { version = "1.16.1", features = ["sync"] }
tower = { version = "0.4.11", features = ["timeout", "limit"] }
tower-http = { version = "0.2.1", features = ["follow-redirect", "decompression-full", "set-header", "trace"] }
tower-http = { version = "0.2.1", features = ["follow-redirect", "decompression-full", "set-header", "trace", "compression-full"] }
tracing = "0.1.30"
tracing-opentelemetry = "0.17.0"

View File

@ -12,81 +12,175 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
use std::{marker::PhantomData, time::Duration};
use bytes::Bytes;
use http::{header::USER_AGENT, HeaderValue, Request, Response, Version};
use http_body::combinators::BoxBody;
use http_body::{combinators::BoxBody, Body};
use hyper::{client::HttpConnector, Client};
use hyper_rustls::HttpsConnectorBuilder;
use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder};
use opentelemetry::trace::TraceContextExt;
use opentelemetry_http::HeaderExtractor;
use tokio::sync::OnceCell;
use tower::{
limit::ConcurrencyLimitLayer,
timeout::TimeoutLayer,
util::{BoxCloneService, BoxService},
BoxError, Service, ServiceBuilder, ServiceExt,
limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service,
ServiceBuilder, ServiceExt,
};
use tower_http::{
compression::{CompressionBody, CompressionLayer},
decompression::{DecompressionBody, DecompressionLayer},
follow_redirect::FollowRedirectLayer,
set_header::SetRequestHeaderLayer,
trace::{MakeSpan, OnResponse, TraceLayer},
};
use tracing::field;
use tracing_opentelemetry::OpenTelemetrySpanExt;
static MAS_USER_AGENT: HeaderValue =
HeaderValue::from_static("matrix-authentication-service/0.0.1");
type Body = BoxBody<bytes::Bytes, BoxError>;
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
pub fn client(
#[derive(Debug, Clone)]
pub struct ClientLayer<ReqBody> {
operation: &'static str,
) -> BoxService<
Request<Body>,
Response<impl http_body::Body<Data = bytes::Bytes, Error = hyper::Error>>,
BoxError,
> {
_t: PhantomData<ReqBody>,
}
impl<B> ClientLayer<B> {
fn new(operation: &'static str) -> Self {
Self {
operation,
_t: PhantomData,
}
}
}
type ClientResponse<B> = Response<
DecompressionBody<BoxBody<<B as http_body::Body>::Data, <B as http_body::Body>::Error>>,
>;
impl<ReqBody, ResBody, S> Layer<S> for ClientLayer<ReqBody>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
ReqBody: http_body::Body + Default + Send + 'static,
ResBody: http_body::Body + Sync + Send + 'static,
ResBody::Error: std::fmt::Display + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxError>,
{
type Service = BoxCloneService<Request<ReqBody>, ClientResponse<ResBody>, BoxError>;
fn layer(&self, inner: S) -> Self::Service {
ServiceBuilder::new()
.layer(DecompressionLayer::new())
.map_response(|r: Response<_>| r.map(BoxBody::new))
.layer(SetRequestHeaderLayer::overriding(
USER_AGENT,
MAS_USER_AGENT.clone(),
))
// A trace that has the whole operation, with all the redirects, retries, rate limits
.layer(MakeOtelSpan::outer_client(self.operation).http_layer())
.layer(ConcurrencyLimitLayer::new(10))
.layer(FollowRedirectLayer::new())
// A trace for each "real" http request
.layer(MakeOtelSpan::inner_client().http_layer())
.layer(TimeoutLayer::new(Duration::from_secs(10)))
// Propagate the span context
.map_request(|mut r: Request<_>| {
// TODO: this seems to be broken
let cx = tracing::Span::current().context();
let mut injector = opentelemetry_http::HeaderInjector(r.headers_mut());
opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.inject_context(&cx, &mut injector)
});
r
})
.service(inner)
.boxed_clone()
}
}
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
pub async fn client<B, E>(
operation: &'static str,
) -> anyhow::Result<
BoxCloneService<
Request<B>,
Response<impl http_body::Body<Data = bytes::Bytes, Error = anyhow::Error>>,
anyhow::Error,
>,
>
where
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
E: Into<BoxError>,
{
// TODO: we could probably hook a tracing DNS resolver there
let mut http = HttpConnector::new();
http.enforce_http(false);
let https = HttpsConnectorBuilder::new()
let tls_config = TLS_CONFIG
.get_or_try_init(|| async move {
// Load the TLS config once in a blocking task because loading the system
// certificates can take a long time (~200ms) on macOS
let span = tracing::info_span!("load_certificates");
tokio::task::spawn_blocking(|| {
let _span = span.entered();
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_native_roots()
.with_no_client_auth()
})
.await
})
.await?;
let https = HttpsConnectorBuilder::new()
.with_tls_config(tls_config.clone())
.https_or_http()
.enable_http1()
.enable_http2()
.wrap_connector(http);
// TODO: we should get the remote address here
let client = Client::builder().build(https);
ServiceBuilder::new()
.layer(
TraceLayer::new_for_http()
.make_span_with(MakeOtelSpan::client(operation))
.on_response(OtelOnResponse),
)
.layer(TimeoutLayer::new(Duration::from_secs(10)))
.layer(FollowRedirectLayer::new())
.layer(ConcurrencyLimitLayer::new(10))
.layer(SetRequestHeaderLayer::overriding(
USER_AGENT,
MAS_USER_AGENT.clone(),
))
let client = ServiceBuilder::new()
// Convert the errors to anyhow::Error for convenience
.map_err(|e: BoxError| anyhow::anyhow!(e))
.map_response(|r: ClientResponse<hyper::Body>| {
r.map(|body| body.map_err(|e: BoxError| anyhow::anyhow!(e)))
})
.layer(ClientLayer::new(operation))
.service(client)
.boxed()
.boxed_clone();
Ok(client)
}
#[allow(clippy::type_complexity)]
pub fn server<ReqBody, ResBody, S>(
service: S,
) -> BoxCloneService<Request<ReqBody>, Response<BoxBody<ResBody::Data, ResBody::Error>>, BoxError>
#[derive(Debug, Default)]
pub struct ServerLayer<ReqBody>(PhantomData<ReqBody>);
impl<ReqBody, ResBody, S> Layer<S> for ServerLayer<ReqBody>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
ReqBody: http_body::Body + 'static,
ResBody: http_body::Body + Sync + Send + 'static,
ResBody::Error: std::fmt::Display + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxError> + 'static,
S::Error: Into<BoxError>,
{
type Service = BoxCloneService<
Request<ReqBody>,
Response<CompressionBody<BoxBody<ResBody::Data, ResBody::Error>>>,
BoxError,
>;
fn layer(&self, inner: S) -> Self::Service {
ServiceBuilder::new()
.layer(CompressionLayer::new())
.map_response(|r: Response<_>| r.map(BoxBody::new))
.layer(
TraceLayer::new_for_http()
@ -94,59 +188,47 @@ where
.on_response(OtelOnResponse),
)
.layer(TimeoutLayer::new(Duration::from_secs(10)))
.service(service)
.service(inner)
.boxed_clone()
}
}
#[derive(Debug, Clone, Default)]
pub struct MakeOtelSpan {
operation: Option<&'static str>,
kind: &'static str,
extract: bool,
#[derive(Debug, Clone, Copy)]
pub enum MakeOtelSpan {
OuterClient(&'static str),
InnerClient,
Server,
}
impl MakeOtelSpan {
fn client(operation: &'static str) -> Self {
Self {
operation: Some(operation),
extract: false,
kind: "client",
}
const fn outer_client(operation: &'static str) -> Self {
Self::OuterClient(operation)
}
fn server() -> Self {
Self {
operation: None,
extract: true,
kind: "server",
const fn inner_client() -> Self {
Self::InnerClient
}
const fn server() -> Self {
Self::Server
}
fn http_layer(
self,
) -> TraceLayer<
tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
Self,
tower_http::trace::DefaultOnRequest,
OtelOnResponse,
> {
TraceLayer::new_for_http()
.make_span_with(self)
.on_response(OtelOnResponse)
}
}
impl<B> MakeSpan<B> for MakeOtelSpan {
fn make_span(&mut self, request: &Request<B>) -> tracing::Span {
let cx = if self.extract {
// Extract the context from the headers
let headers = request.headers();
let extractor = HeaderExtractor(headers);
let cx = opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.extract(&extractor)
});
if cx.span().span_context().is_remote() {
cx
} else {
opentelemetry::Context::new()
}
} else {
opentelemetry::Context::current()
};
// Attach the context so when the request span is created it gets properly
// parented
let _guard = cx.attach();
// Extract the context from the headers
let headers = request.headers();
@ -159,10 +241,36 @@ impl<B> MakeSpan<B> for MakeOtelSpan {
_ => "",
};
let span = match self {
Self::OuterClient(operation) => {
tracing::info_span!(
"client_request",
otel.name = operation,
otel.kind = "internal",
otel.status_code = field::Empty,
http.method = %request.method(),
http.target = %request.uri(),
http.flavor = version,
http.status_code = field::Empty,
http.user_agent = field::Empty,
)
}
Self::InnerClient => {
tracing::info_span!(
"outgoing_request",
otel.kind = "client",
otel.status_code = field::Empty,
http.method = %request.method(),
http.target = %request.uri(),
http.flavor = version,
http.status_code = field::Empty,
http.user_agent = field::Empty,
)
}
Self::Server => {
let span = tracing::info_span!(
"request",
otel.name = field::Empty,
otel.kind = self.kind,
"incoming_request",
otel.kind = "server",
otel.status_code = field::Empty,
http.method = %request.method(),
http.target = %request.uri(),
@ -171,10 +279,22 @@ impl<B> MakeSpan<B> for MakeOtelSpan {
http.user_agent = field::Empty,
);
if let Some(operation) = &self.operation {
span.record("otel.name", operation);
// Extract the context from the headers for server spans
let headers = request.headers();
let extractor = HeaderExtractor(headers);
let cx = opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.extract(&extractor)
});
if cx.span().span_context().is_remote() {
span.set_parent(cx);
}
span
}
};
if let Some(user_agent) = headers.get(USER_AGENT).and_then(|s| s.to_str().ok()) {
span.record("http.user_agent", &user_agent);
}