1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Proper HTTP client

This commit is contained in:
Quentin Gliech
2022-02-10 15:33:19 +01:00
parent 2df40762a2
commit 8c36e51176
6 changed files with 306 additions and 100 deletions

4
Cargo.lock generated
View File

@ -1992,6 +1992,7 @@ dependencies = [
name = "mas-http" name = "mas-http"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow",
"bytes 1.1.0", "bytes 1.1.0",
"http", "http",
"http-body", "http-body",
@ -1999,9 +2000,12 @@ dependencies = [
"hyper-rustls 0.23.0", "hyper-rustls 0.23.0",
"opentelemetry", "opentelemetry",
"opentelemetry-http", "opentelemetry-http",
"rustls 0.20.2",
"tokio",
"tower", "tower",
"tower-http", "tower-http",
"tracing", "tracing",
"tracing-opentelemetry",
] ]
[[package]] [[package]]

View File

@ -0,0 +1,73 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use clap::Parser;
use hyper::Uri;
use tokio::io::AsyncWriteExt;
use tower::{Service, ServiceExt};
#[derive(Parser, Debug)]
pub(super) struct Options {
#[clap(subcommand)]
subcommand: Subcommand,
}
#[derive(Parser, Debug)]
enum Subcommand {
/// Perform an HTTP request with the default HTTP client
Http {
/// Show response headers
#[clap(long, short = 'I')]
show_headers: bool,
/// URI where to perform a GET request
url: Uri,
},
}
impl Options {
#[tracing::instrument(skip_all)]
pub async fn run(&self, _root: &super::Options) -> anyhow::Result<()> {
use Subcommand as SC;
match &self.subcommand {
SC::Http { show_headers, url } => {
let mut client = mas_http::client("cli-debug-http").await?;
let request = hyper::Request::builder()
.uri(url)
.body(hyper::Body::empty())?;
let mut response = client.ready().await?.call(request).await?;
if *show_headers {
let status = response.status();
println!(
"{:?} {} {}",
response.version(),
status.as_str(),
status.canonical_reason().unwrap_or_default()
);
for (header, value) in response.headers() {
println!("{}: {:?}", header, value);
}
println!();
}
let mut body = hyper::body::aggregate(response.body_mut()).await?;
let mut stdout = tokio::io::stdout();
stdout.write_all_buf(&mut body).await?;
Ok(())
}
}
}
}

View File

@ -20,6 +20,7 @@ use mas_config::ConfigurationSection;
mod config; mod config;
mod database; mod database;
mod debug;
mod manage; mod manage;
mod server; mod server;
mod templates; mod templates;
@ -40,6 +41,9 @@ enum Subcommand {
/// Templates-related commands /// Templates-related commands
Templates(self::templates::Options), Templates(self::templates::Options),
/// Debug utilities
Debug(self::debug::Options),
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -67,6 +71,7 @@ impl Options {
Some(S::Server(c)) => c.run(self).await, Some(S::Server(c)) => c.run(self).await,
Some(S::Manage(c)) => c.run(self).await, Some(S::Manage(c)) => c.run(self).await,
Some(S::Templates(c)) => c.run(self).await, Some(S::Templates(c)) => c.run(self).await,
Some(S::Debug(c)) => c.run(self).await,
None => self::server::Options::default().run(self).await, None => self::server::Options::default().run(self).await,
} }
} }

View File

@ -27,7 +27,7 @@ use mas_email::{MailTransport, Mailer};
use mas_storage::MIGRATOR; use mas_storage::MIGRATOR;
use mas_tasks::TaskQueue; use mas_tasks::TaskQueue;
use mas_templates::Templates; use mas_templates::Templates;
use tower::make::Shared; use tower::{make::Shared, Layer};
use tracing::{error, info}; use tracing::{error, info};
#[derive(Parser, Debug, Default)] #[derive(Parser, Debug, Default)]
@ -211,7 +211,7 @@ impl Options {
let warp_service = warp::service(root); let warp_service = warp::service(root);
let service = mas_http::server(warp_service); let service = mas_http::ServerLayer::default().layer(warp_service);
info!("Listening on http://{}", listener.local_addr().unwrap()); info!("Listening on http://{}", listener.local_addr().unwrap());

View File

@ -6,6 +6,7 @@ edition = "2021"
license = "Apache-2.0" license = "Apache-2.0"
[dependencies] [dependencies]
anyhow = "1.0.53"
bytes = "1.1.0" bytes = "1.1.0"
http = "0.2.6" http = "0.2.6"
http-body = "0.4.4" http-body = "0.4.4"
@ -13,6 +14,9 @@ hyper = "0.14.16"
hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] } hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] }
opentelemetry = "0.17.0" opentelemetry = "0.17.0"
opentelemetry-http = "0.6.0" opentelemetry-http = "0.6.0"
rustls = "0.20.2"
tokio = { version = "1.16.1", features = ["sync"] }
tower = { version = "0.4.11", features = ["timeout", "limit"] } tower = { version = "0.4.11", features = ["timeout", "limit"] }
tower-http = { version = "0.2.1", features = ["follow-redirect", "decompression-full", "set-header", "trace"] } tower-http = { version = "0.2.1", features = ["follow-redirect", "decompression-full", "set-header", "trace", "compression-full"] }
tracing = "0.1.30" tracing = "0.1.30"
tracing-opentelemetry = "0.17.0"

View File

@ -12,141 +12,223 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::time::Duration; use std::{marker::PhantomData, time::Duration};
use bytes::Bytes;
use http::{header::USER_AGENT, HeaderValue, Request, Response, Version}; use http::{header::USER_AGENT, HeaderValue, Request, Response, Version};
use http_body::combinators::BoxBody; use http_body::{combinators::BoxBody, Body};
use hyper::{client::HttpConnector, Client}; use hyper::{client::HttpConnector, Client};
use hyper_rustls::HttpsConnectorBuilder; use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder};
use opentelemetry::trace::TraceContextExt; use opentelemetry::trace::TraceContextExt;
use opentelemetry_http::HeaderExtractor; use opentelemetry_http::HeaderExtractor;
use tokio::sync::OnceCell;
use tower::{ use tower::{
limit::ConcurrencyLimitLayer, limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service,
timeout::TimeoutLayer, ServiceBuilder, ServiceExt,
util::{BoxCloneService, BoxService},
BoxError, Service, ServiceBuilder, ServiceExt,
}; };
use tower_http::{ use tower_http::{
compression::{CompressionBody, CompressionLayer},
decompression::{DecompressionBody, DecompressionLayer},
follow_redirect::FollowRedirectLayer, follow_redirect::FollowRedirectLayer,
set_header::SetRequestHeaderLayer, set_header::SetRequestHeaderLayer,
trace::{MakeSpan, OnResponse, TraceLayer}, trace::{MakeSpan, OnResponse, TraceLayer},
}; };
use tracing::field; use tracing::field;
use tracing_opentelemetry::OpenTelemetrySpanExt;
static MAS_USER_AGENT: HeaderValue = static MAS_USER_AGENT: HeaderValue =
HeaderValue::from_static("matrix-authentication-service/0.0.1"); HeaderValue::from_static("matrix-authentication-service/0.0.1");
type Body = BoxBody<bytes::Bytes, BoxError>; type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
pub fn client( #[derive(Debug, Clone)]
pub struct ClientLayer<ReqBody> {
operation: &'static str, operation: &'static str,
) -> BoxService< _t: PhantomData<ReqBody>,
Request<Body>, }
Response<impl http_body::Body<Data = bytes::Bytes, Error = hyper::Error>>,
BoxError, impl<B> ClientLayer<B> {
> { fn new(operation: &'static str) -> Self {
Self {
operation,
_t: PhantomData,
}
}
}
type ClientResponse<B> = Response<
DecompressionBody<BoxBody<<B as http_body::Body>::Data, <B as http_body::Body>::Error>>,
>;
impl<ReqBody, ResBody, S> Layer<S> for ClientLayer<ReqBody>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
ReqBody: http_body::Body + Default + Send + 'static,
ResBody: http_body::Body + Sync + Send + 'static,
ResBody::Error: std::fmt::Display + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxError>,
{
type Service = BoxCloneService<Request<ReqBody>, ClientResponse<ResBody>, BoxError>;
fn layer(&self, inner: S) -> Self::Service {
ServiceBuilder::new()
.layer(DecompressionLayer::new())
.map_response(|r: Response<_>| r.map(BoxBody::new))
.layer(SetRequestHeaderLayer::overriding(
USER_AGENT,
MAS_USER_AGENT.clone(),
))
// A trace that has the whole operation, with all the redirects, retries, rate limits
.layer(MakeOtelSpan::outer_client(self.operation).http_layer())
.layer(ConcurrencyLimitLayer::new(10))
.layer(FollowRedirectLayer::new())
// A trace for each "real" http request
.layer(MakeOtelSpan::inner_client().http_layer())
.layer(TimeoutLayer::new(Duration::from_secs(10)))
// Propagate the span context
.map_request(|mut r: Request<_>| {
// TODO: this seems to be broken
let cx = tracing::Span::current().context();
let mut injector = opentelemetry_http::HeaderInjector(r.headers_mut());
opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.inject_context(&cx, &mut injector)
});
r
})
.service(inner)
.boxed_clone()
}
}
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
pub async fn client<B, E>(
operation: &'static str,
) -> anyhow::Result<
BoxCloneService<
Request<B>,
Response<impl http_body::Body<Data = bytes::Bytes, Error = anyhow::Error>>,
anyhow::Error,
>,
>
where
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
E: Into<BoxError>,
{
// TODO: we could probably hook a tracing DNS resolver there
let mut http = HttpConnector::new(); let mut http = HttpConnector::new();
http.enforce_http(false); http.enforce_http(false);
let tls_config = TLS_CONFIG
.get_or_try_init(|| async move {
// Load the TLS config once in a blocking task because loading the system
// certificates can take a long time (~200ms) on macOS
let span = tracing::info_span!("load_certificates");
tokio::task::spawn_blocking(|| {
let _span = span.entered();
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_native_roots()
.with_no_client_auth()
})
.await
})
.await?;
let https = HttpsConnectorBuilder::new() let https = HttpsConnectorBuilder::new()
.with_native_roots() .with_tls_config(tls_config.clone())
.https_or_http() .https_or_http()
.enable_http1() .enable_http1()
.enable_http2() .enable_http2()
.wrap_connector(http); .wrap_connector(http);
// TODO: we should get the remote address here
let client = Client::builder().build(https); let client = Client::builder().build(https);
ServiceBuilder::new() let client = ServiceBuilder::new()
.layer( // Convert the errors to anyhow::Error for convenience
TraceLayer::new_for_http() .map_err(|e: BoxError| anyhow::anyhow!(e))
.make_span_with(MakeOtelSpan::client(operation)) .map_response(|r: ClientResponse<hyper::Body>| {
.on_response(OtelOnResponse), r.map(|body| body.map_err(|e: BoxError| anyhow::anyhow!(e)))
) })
.layer(TimeoutLayer::new(Duration::from_secs(10))) .layer(ClientLayer::new(operation))
.layer(FollowRedirectLayer::new())
.layer(ConcurrencyLimitLayer::new(10))
.layer(SetRequestHeaderLayer::overriding(
USER_AGENT,
MAS_USER_AGENT.clone(),
))
.service(client) .service(client)
.boxed() .boxed_clone();
Ok(client)
} }
#[allow(clippy::type_complexity)] #[derive(Debug, Default)]
pub fn server<ReqBody, ResBody, S>( pub struct ServerLayer<ReqBody>(PhantomData<ReqBody>);
service: S,
) -> BoxCloneService<Request<ReqBody>, Response<BoxBody<ResBody::Data, ResBody::Error>>, BoxError> impl<ReqBody, ResBody, S> Layer<S> for ServerLayer<ReqBody>
where where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static, S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
ReqBody: http_body::Body + 'static, ReqBody: http_body::Body + 'static,
ResBody: http_body::Body + Sync + Send + 'static, ResBody: http_body::Body + Sync + Send + 'static,
ResBody::Error: std::fmt::Display + 'static, ResBody::Error: std::fmt::Display + 'static,
S::Future: Send + 'static, S::Future: Send + 'static,
S::Error: Into<BoxError> + 'static, S::Error: Into<BoxError>,
{ {
ServiceBuilder::new() type Service = BoxCloneService<
.map_response(|r: Response<_>| r.map(BoxBody::new)) Request<ReqBody>,
.layer( Response<CompressionBody<BoxBody<ResBody::Data, ResBody::Error>>>,
TraceLayer::new_for_http() BoxError,
.make_span_with(MakeOtelSpan::server()) >;
.on_response(OtelOnResponse),
) fn layer(&self, inner: S) -> Self::Service {
.layer(TimeoutLayer::new(Duration::from_secs(10))) ServiceBuilder::new()
.service(service) .layer(CompressionLayer::new())
.boxed_clone() .map_response(|r: Response<_>| r.map(BoxBody::new))
.layer(
TraceLayer::new_for_http()
.make_span_with(MakeOtelSpan::server())
.on_response(OtelOnResponse),
)
.layer(TimeoutLayer::new(Duration::from_secs(10)))
.service(inner)
.boxed_clone()
}
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Copy)]
pub struct MakeOtelSpan { pub enum MakeOtelSpan {
operation: Option<&'static str>, OuterClient(&'static str),
kind: &'static str, InnerClient,
extract: bool, Server,
} }
impl MakeOtelSpan { impl MakeOtelSpan {
fn client(operation: &'static str) -> Self { const fn outer_client(operation: &'static str) -> Self {
Self { Self::OuterClient(operation)
operation: Some(operation),
extract: false,
kind: "client",
}
} }
fn server() -> Self { const fn inner_client() -> Self {
Self { Self::InnerClient
operation: None, }
extract: true,
kind: "server", const fn server() -> Self {
} Self::Server
}
fn http_layer(
self,
) -> TraceLayer<
tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
Self,
tower_http::trace::DefaultOnRequest,
OtelOnResponse,
> {
TraceLayer::new_for_http()
.make_span_with(self)
.on_response(OtelOnResponse)
} }
} }
impl<B> MakeSpan<B> for MakeOtelSpan { impl<B> MakeSpan<B> for MakeOtelSpan {
fn make_span(&mut self, request: &Request<B>) -> tracing::Span { fn make_span(&mut self, request: &Request<B>) -> tracing::Span {
let cx = if self.extract {
// Extract the context from the headers
let headers = request.headers();
let extractor = HeaderExtractor(headers);
let cx = opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.extract(&extractor)
});
if cx.span().span_context().is_remote() {
cx
} else {
opentelemetry::Context::new()
}
} else {
opentelemetry::Context::current()
};
// Attach the context so when the request span is created it gets properly
// parented
let _guard = cx.attach();
// Extract the context from the headers // Extract the context from the headers
let headers = request.headers(); let headers = request.headers();
@ -159,21 +241,59 @@ impl<B> MakeSpan<B> for MakeOtelSpan {
_ => "", _ => "",
}; };
let span = tracing::info_span!( let span = match self {
"request", Self::OuterClient(operation) => {
otel.name = field::Empty, tracing::info_span!(
otel.kind = self.kind, "client_request",
otel.status_code = field::Empty, otel.name = operation,
http.method = %request.method(), otel.kind = "internal",
http.target = %request.uri(), otel.status_code = field::Empty,
http.flavor = version, http.method = %request.method(),
http.status_code = field::Empty, http.target = %request.uri(),
http.user_agent = field::Empty, http.flavor = version,
); http.status_code = field::Empty,
http.user_agent = field::Empty,
)
}
Self::InnerClient => {
tracing::info_span!(
"outgoing_request",
otel.kind = "client",
otel.status_code = field::Empty,
http.method = %request.method(),
http.target = %request.uri(),
http.flavor = version,
http.status_code = field::Empty,
http.user_agent = field::Empty,
)
}
Self::Server => {
let span = tracing::info_span!(
"incoming_request",
otel.kind = "server",
otel.status_code = field::Empty,
http.method = %request.method(),
http.target = %request.uri(),
http.flavor = version,
http.status_code = field::Empty,
http.user_agent = field::Empty,
);
if let Some(operation) = &self.operation { // Extract the context from the headers for server spans
span.record("otel.name", operation); let headers = request.headers();
} let extractor = HeaderExtractor(headers);
let cx = opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.extract(&extractor)
});
if cx.span().span_context().is_remote() {
span.set_parent(cx);
}
span
}
};
if let Some(user_agent) = headers.get(USER_AGENT).and_then(|s| s.to_str().ok()) { if let Some(user_agent) = headers.get(USER_AGENT).and_then(|s| s.to_str().ok()) {
span.record("http.user_agent", &user_agent); span.record("http.user_agent", &user_agent);