1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Implement a JSON tower layer

This will help requesting JSON APIs
This commit is contained in:
Quentin Gliech
2022-02-11 13:10:24 +01:00
parent 8c36e51176
commit 497a3e006e
11 changed files with 579 additions and 274 deletions

34
Cargo.lock generated
View File

@ -1228,9 +1228,9 @@ dependencies = [
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.19" version = "0.3.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba3dda0b6588335f360afc675d0564c17a77a2bda81ca178a4b6081bd86c7f0b" checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-sink", "futures-sink",
@ -1238,9 +1238,9 @@ dependencies = [
[[package]] [[package]]
name = "futures-core" name = "futures-core"
version = "0.3.19" version = "0.3.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0c8ff0461b82559810cdccfde3215c3f373807f5e5232b71479bff7bb2583d7" checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3"
[[package]] [[package]]
name = "futures-executor" name = "futures-executor"
@ -1266,15 +1266,15 @@ dependencies = [
[[package]] [[package]]
name = "futures-io" name = "futures-io"
version = "0.3.19" version = "0.3.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f9d34af5a1aac6fb380f735fe510746c38067c5bf16c7fd250280503c971b2" checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b"
[[package]] [[package]]
name = "futures-macro" name = "futures-macro"
version = "0.3.19" version = "0.3.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbd947adfffb0efc70599b3ddcf7b5597bb5fa9e245eb99f62b3a5f7bb8bd3c" checksum = "33c1e13800337f4d4d7a316bf45a567dbcb6ffe087f16424852d97e97a91f512"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1283,21 +1283,21 @@ dependencies = [
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.19" version = "0.3.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3055baccb68d74ff6480350f8d6eb8fcfa3aa11bdc1a1ae3afdd0514617d508" checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868"
[[package]] [[package]]
name = "futures-task" name = "futures-task"
version = "0.3.19" version = "0.3.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ee7c6485c30167ce4dfb83ac568a849fe53274c831081476ee13e0dce1aad72" checksum = "57c66a976bf5909d801bbef33416c41372779507e7a6b3a5e25e4749c58f776a"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.19" version = "0.3.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b5cf40b47a271f77a8b1bec03ca09044d99d2372c0de244e66430761127164" checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a"
dependencies = [ dependencies = [
"futures 0.1.31", "futures 0.1.31",
"futures-channel", "futures-channel",
@ -1877,6 +1877,7 @@ dependencies = [
"opentelemetry-zipkin", "opentelemetry-zipkin",
"reqwest", "reqwest",
"schemars", "schemars",
"serde_json",
"serde_yaml", "serde_yaml",
"tokio", "tokio",
"tower", "tower",
@ -1994,13 +1995,18 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytes 1.1.0", "bytes 1.1.0",
"futures-util",
"http", "http",
"http-body", "http-body",
"hyper", "hyper",
"hyper-rustls 0.23.0", "hyper-rustls 0.23.0",
"opentelemetry", "opentelemetry",
"opentelemetry-http", "opentelemetry-http",
"pin-project-lite",
"rustls 0.20.2", "rustls 0.20.2",
"serde",
"serde_json",
"thiserror",
"tokio", "tokio",
"tower", "tower",
"tower-http", "tower-http",

View File

@ -15,6 +15,7 @@ schemars = { version = "0.8.8", features = ["url", "chrono"] }
tower = { version = "0.4.11", features = ["full"] } tower = { version = "0.4.11", features = ["full"] }
hyper = { version = "0.14.16", features = ["full"] } hyper = { version = "0.14.16", features = ["full"] }
serde_yaml = "0.8.23" serde_yaml = "0.8.23"
serde_json = "1.0.78"
warp = "0.3.2" warp = "0.3.2"
url = "2.2.2" url = "2.2.2"
argon2 = { version = "0.3.3", features = ["password-hash"] } argon2 = { version = "0.3.3", features = ["password-hash"] }

View File

@ -13,7 +13,8 @@
// limitations under the License. // limitations under the License.
use clap::Parser; use clap::Parser;
use hyper::Uri; use hyper::{Response, Uri};
use mas_http::HttpServiceExt;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tower::{Service, ServiceExt}; use tower::{Service, ServiceExt};
@ -31,43 +32,81 @@ enum Subcommand {
#[clap(long, short = 'I')] #[clap(long, short = 'I')]
show_headers: bool, show_headers: bool,
/// Parse the response as JSON
#[clap(long, short = 'j')]
json: bool,
/// URI where to perform a GET request /// URI where to perform a GET request
url: Uri, url: Uri,
}, },
} }
fn print_headers(parts: &hyper::http::response::Parts) {
println!(
"{:?} {} {}",
parts.version,
parts.status.as_str(),
parts.status.canonical_reason().unwrap_or_default()
);
for (header, value) in &parts.headers {
println!("{}: {:?}", header, value);
}
println!();
}
impl Options { impl Options {
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn run(&self, _root: &super::Options) -> anyhow::Result<()> { pub async fn run(&self, _root: &super::Options) -> anyhow::Result<()> {
use Subcommand as SC; use Subcommand as SC;
match &self.subcommand { match &self.subcommand {
SC::Http { show_headers, url } => { SC::Http {
show_headers,
json: false,
url,
} => {
let mut client = mas_http::client("cli-debug-http").await?; let mut client = mas_http::client("cli-debug-http").await?;
let request = hyper::Request::builder() let request = hyper::Request::builder()
.uri(url) .uri(url)
.body(hyper::Body::empty())?; .body(hyper::Body::empty())?;
let mut response = client.ready().await?.call(request).await?; let response = client.ready().await?.call(request).await?;
let (parts, body) = response.into_parts();
if *show_headers { if *show_headers {
let status = response.status(); print_headers(&parts);
println!(
"{:?} {} {}",
response.version(),
status.as_str(),
status.canonical_reason().unwrap_or_default()
);
for (header, value) in response.headers() {
println!("{}: {:?}", header, value);
}
println!();
} }
let mut body = hyper::body::aggregate(response.body_mut()).await?;
let mut body = hyper::body::aggregate(body).await?;
let mut stdout = tokio::io::stdout(); let mut stdout = tokio::io::stdout();
stdout.write_all_buf(&mut body).await?; stdout.write_all_buf(&mut body).await?;
Ok(()) Ok(())
} }
SC::Http {
show_headers,
json: true,
url,
} => {
let mut client = mas_http::client("cli-debug-http").await?.json();
let request = hyper::Request::builder()
.uri(url)
.body(hyper::Body::empty())?;
let response: Response<serde_json::Value> =
client.ready().await?.call(request).await?;
let (parts, body) = response.into_parts();
if *show_headers {
print_headers(&parts);
}
let body = serde_json::to_string_pretty(&body)?;
println!("{}", body);
Ok(())
}
} }
} }
} }

View File

@ -8,14 +8,19 @@ license = "Apache-2.0"
[dependencies] [dependencies]
anyhow = "1.0.53" anyhow = "1.0.53"
bytes = "1.1.0" bytes = "1.1.0"
futures-util = "0.3.21"
http = "0.2.6" http = "0.2.6"
http-body = "0.4.4" http-body = "0.4.4"
hyper = "0.14.16" hyper = "0.14.16"
hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] } hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] }
opentelemetry = "0.17.0" opentelemetry = "0.17.0"
opentelemetry-http = "0.6.0" opentelemetry-http = "0.6.0"
pin-project-lite = "0.2.8"
rustls = "0.20.2" rustls = "0.20.2"
tokio = { version = "1.16.1", features = ["sync"] } serde = "1.0.136"
serde_json = "1.0.78"
thiserror = "1.0.30"
tokio = { version = "1.16.1", features = ["sync", "parking_lot"] }
tower = { version = "0.4.11", features = ["timeout", "limit"] } tower = { version = "0.4.11", features = ["timeout", "limit"] }
tower-http = { version = "0.2.1", features = ["follow-redirect", "decompression-full", "set-header", "trace", "compression-full"] } tower-http = { version = "0.2.1", features = ["follow-redirect", "decompression-full", "set-header", "trace", "compression-full"] }
tracing = "0.1.30" tracing = "0.1.30"

30
crates/http/src/ext.rs Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::layers::json::Json;
pub trait ServiceExt {
fn json<T>(self) -> Json<Self, T>
where
Self: Sized;
}
impl<S> ServiceExt for S {
fn json<T>(self) -> Json<Self, T>
where
Self: Sized,
{
Json::new(self)
}
}

View File

@ -0,0 +1,96 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{marker::PhantomData, time::Duration};
use http::{header::USER_AGENT, HeaderValue, Request, Response};
use http_body::combinators::BoxBody;
use tower::{
limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service,
ServiceBuilder, ServiceExt,
};
use tower_http::{
decompression::{DecompressionBody, DecompressionLayer},
follow_redirect::FollowRedirectLayer,
set_header::SetRequestHeaderLayer,
};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use super::trace::OtelTraceLayer;
static MAS_USER_AGENT: HeaderValue =
HeaderValue::from_static("matrix-authentication-service/0.0.1");
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
#[derive(Debug, Clone)]
pub struct ClientLayer<ReqBody> {
operation: &'static str,
_t: PhantomData<ReqBody>,
}
impl<B> ClientLayer<B> {
pub fn new(operation: &'static str) -> Self {
Self {
operation,
_t: PhantomData,
}
}
}
pub type ClientResponse<B> = Response<
DecompressionBody<BoxBody<<B as http_body::Body>::Data, <B as http_body::Body>::Error>>,
>;
impl<ReqBody, ResBody, S> Layer<S> for ClientLayer<ReqBody>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
ReqBody: http_body::Body + Default + Send + 'static,
ResBody: http_body::Body + Sync + Send + 'static,
ResBody::Error: std::fmt::Display + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxError>,
{
type Service = BoxCloneService<Request<ReqBody>, ClientResponse<ResBody>, BoxError>;
fn layer(&self, inner: S) -> Self::Service {
ServiceBuilder::new()
.layer(DecompressionLayer::new())
.map_response(|r: Response<_>| r.map(BoxBody::new))
.layer(SetRequestHeaderLayer::overriding(
USER_AGENT,
MAS_USER_AGENT.clone(),
))
// A trace that has the whole operation, with all the redirects, retries, rate limits
.layer(OtelTraceLayer::outer_client(self.operation))
.layer(ConcurrencyLimitLayer::new(10))
.layer(FollowRedirectLayer::new())
// A trace for each "real" http request
.layer(OtelTraceLayer::inner_client())
.layer(TimeoutLayer::new(Duration::from_secs(10)))
// Propagate the span context
.map_request(|mut r: Request<_>| {
// TODO: this seems to be broken
let cx = tracing::Span::current().context();
let mut injector = opentelemetry_http::HeaderInjector(r.headers_mut());
opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.inject_context(&cx, &mut injector)
});
r
})
.service(inner)
.boxed_clone()
}
}

View File

@ -0,0 +1,116 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{marker::PhantomData, task::Poll};
use futures_util::future::BoxFuture;
use http::{header::ACCEPT, HeaderValue, Request, Response};
use http_body::Body;
use serde::de::DeserializeOwned;
use thiserror::Error;
use tower::{Layer, Service};
#[derive(Debug, Error)]
pub enum Error<Service, Body> {
#[error("service")]
Service { inner: Service },
#[error("body")]
Body { inner: Body },
#[error("json")]
Json { inner: serde_json::Error },
}
impl<S, B> Error<S, B> {
fn service(source: S) -> Self {
Self::Service { inner: source }
}
fn body(source: B) -> Self {
Self::Body { inner: source }
}
fn json(source: serde_json::Error) -> Self {
Self::Json { inner: source }
}
}
pub struct Json<S, T> {
inner: S,
_t: PhantomData<T>,
}
impl<S, T> Json<S, T> {
pub const fn new(inner: S) -> Self {
Self {
inner,
_t: PhantomData,
}
}
}
impl<S, T, B, C> Service<Request<B>> for Json<S, T>
where
S: Service<Request<B>, Response = Response<C>>,
S::Future: Send + 'static,
C: Body + Send + 'static,
C::Data: Send + 'static,
T: DeserializeOwned,
{
type Error = Error<S::Error, C::Error>;
type Response = Response<T>;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Error::service)
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
req.headers_mut()
.insert(ACCEPT, HeaderValue::from_static("application/json"));
let fut = self.inner.call(req);
let fut = async {
let res = fut.await.map_err(Error::service)?;
let (parts, body) = res.into_parts();
futures_util::pin_mut!(body);
let bytes = hyper::body::to_bytes(&mut body)
.await
.map_err(Error::body)?;
let body = serde_json::from_slice(&bytes.to_vec()).map_err(Error::json)?;
let res = Response::from_parts(parts, body);
Ok(res)
};
Box::pin(fut)
}
}
#[derive(Default, Clone, Copy)]
pub struct JsonResponseLayer<T, ReqBody>(PhantomData<(T, ReqBody)>);
impl<ReqBody, ResBody, S, T> Layer<S> for JsonResponseLayer<T, ReqBody>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
T: serde::de::DeserializeOwned,
{
type Service = Json<S, T>;
fn layer(&self, inner: S) -> Self::Service {
Json::new(inner)
}
}

View File

@ -0,0 +1,18 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub(crate) mod client;
pub(crate) mod json;
pub(crate) mod server;
pub(crate) mod trace;

View File

@ -0,0 +1,56 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{marker::PhantomData, time::Duration};
use http::{Request, Response};
use http_body::combinators::BoxBody;
use tower::{
timeout::TimeoutLayer, util::BoxCloneService, Layer, Service, ServiceBuilder, ServiceExt,
};
use tower_http::compression::{CompressionBody, CompressionLayer};
use super::trace::OtelTraceLayer;
use crate::BoxError;
#[derive(Debug, Default)]
pub struct ServerLayer<ReqBody> {
_t: PhantomData<ReqBody>,
}
impl<ReqBody, ResBody, S> Layer<S> for ServerLayer<ReqBody>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
ReqBody: http_body::Body + 'static,
ResBody: http_body::Body + Sync + Send + 'static,
ResBody::Error: std::fmt::Display + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxError>,
{
type Service = BoxCloneService<
Request<ReqBody>,
Response<CompressionBody<BoxBody<ResBody::Data, ResBody::Error>>>,
BoxError,
>;
fn layer(&self, inner: S) -> Self::Service {
ServiceBuilder::new()
.layer(CompressionLayer::new())
.map_response(|r: Response<_>| r.map(BoxBody::new))
.layer(OtelTraceLayer::server())
.layer(TimeoutLayer::new(Duration::from_secs(10)))
.service(inner)
.boxed_clone()
}
}

View File

@ -0,0 +1,170 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
use http::{header::USER_AGENT, Request, Response, Version};
use opentelemetry::trace::TraceContextExt;
use opentelemetry_http::HeaderExtractor;
use tower::Layer;
use tower_http::{
classify::{ServerErrorsAsFailures, SharedClassifier},
trace::{DefaultOnRequest, MakeSpan, OnResponse, Trace},
};
use tracing::{field, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
#[derive(Debug, Clone, Copy)]
pub enum OtelTraceLayer {
OuterClient(&'static str),
InnerClient,
Server,
}
impl OtelTraceLayer {
pub const fn outer_client(operation: &'static str) -> Self {
Self::OuterClient(operation)
}
pub const fn inner_client() -> Self {
Self::InnerClient
}
pub const fn server() -> Self {
Self::Server
}
}
impl<S> Layer<S> for OtelTraceLayer {
type Service = Trace<
S,
SharedClassifier<ServerErrorsAsFailures>,
MakeOtelSpan,
DefaultOnRequest,
OtelOnResponse,
>;
fn layer(&self, inner: S) -> Self::Service {
let make_span = match self {
Self::OuterClient(o) => MakeOtelSpan::OuterClient(o),
Self::InnerClient => MakeOtelSpan::InnerClient,
Self::Server => MakeOtelSpan::Server,
};
Trace::new_for_http(inner)
.make_span_with(make_span)
.on_response(OtelOnResponse)
}
}
#[derive(Debug, Clone, Copy)]
pub enum MakeOtelSpan {
OuterClient(&'static str),
InnerClient,
Server,
}
impl<B> MakeSpan<B> for MakeOtelSpan {
fn make_span(&mut self, request: &Request<B>) -> Span {
// Extract the context from the headers
let headers = request.headers();
let version = match request.version() {
Version::HTTP_09 => "0.9",
Version::HTTP_10 => "1.0",
Version::HTTP_11 => "1.1",
Version::HTTP_2 => "2.0",
Version::HTTP_3 => "3.0",
_ => "",
};
let span = match self {
Self::OuterClient(operation) => {
tracing::info_span!(
"client_request",
otel.name = operation,
otel.kind = "internal",
otel.status_code = field::Empty,
http.method = %request.method(),
http.target = %request.uri(),
http.flavor = version,
http.status_code = field::Empty,
http.user_agent = field::Empty,
)
}
Self::InnerClient => {
tracing::info_span!(
"outgoing_request",
otel.kind = "client",
otel.status_code = field::Empty,
http.method = %request.method(),
http.target = %request.uri(),
http.flavor = version,
http.status_code = field::Empty,
http.user_agent = field::Empty,
)
}
Self::Server => {
let span = tracing::info_span!(
"incoming_request",
otel.kind = "server",
otel.status_code = field::Empty,
http.method = %request.method(),
http.target = %request.uri(),
http.flavor = version,
http.status_code = field::Empty,
http.user_agent = field::Empty,
);
// Extract the context from the headers for server spans
let headers = request.headers();
let extractor = HeaderExtractor(headers);
let cx = opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.extract(&extractor)
});
if cx.span().span_context().is_remote() {
span.set_parent(cx);
}
span
}
};
if let Some(user_agent) = headers.get(USER_AGENT).and_then(|s| s.to_str().ok()) {
span.record("http.user_agent", &user_agent);
}
span
}
}
#[derive(Debug, Clone, Default)]
pub struct OtelOnResponse;
impl<B> OnResponse<B> for OtelOnResponse {
fn on_response(self, response: &Response<B>, _latency: Duration, span: &Span) {
let s = response.status();
let status = if s.is_success() {
"ok"
} else if s.is_client_error() || s.is_server_error() {
"error"
} else {
"unset"
};
span.record("otel.status_code", &status);
span.record("http.status_code", &s.as_u16());
}
}

View File

@ -12,95 +12,24 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{marker::PhantomData, time::Duration};
use bytes::Bytes; use bytes::Bytes;
use http::{header::USER_AGENT, HeaderValue, Request, Response, Version}; use http::{Request, Response};
use http_body::{combinators::BoxBody, Body}; use http_body::Body;
use hyper::{client::HttpConnector, Client}; use hyper::{client::HttpConnector, Client};
use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder}; use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder};
use opentelemetry::trace::TraceContextExt; use layers::client::ClientResponse;
use opentelemetry_http::HeaderExtractor;
use tokio::sync::OnceCell; use tokio::sync::OnceCell;
use tower::{ use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service,
ServiceBuilder, ServiceExt, mod ext;
mod layers;
pub use self::{
ext::ServiceExt as HttpServiceExt,
layers::{client::ClientLayer, json::JsonResponseLayer, server::ServerLayer},
}; };
use tower_http::{
compression::{CompressionBody, CompressionLayer},
decompression::{DecompressionBody, DecompressionLayer},
follow_redirect::FollowRedirectLayer,
set_header::SetRequestHeaderLayer,
trace::{MakeSpan, OnResponse, TraceLayer},
};
use tracing::field;
use tracing_opentelemetry::OpenTelemetrySpanExt;
static MAS_USER_AGENT: HeaderValue = pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
HeaderValue::from_static("matrix-authentication-service/0.0.1");
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
#[derive(Debug, Clone)]
pub struct ClientLayer<ReqBody> {
operation: &'static str,
_t: PhantomData<ReqBody>,
}
impl<B> ClientLayer<B> {
fn new(operation: &'static str) -> Self {
Self {
operation,
_t: PhantomData,
}
}
}
type ClientResponse<B> = Response<
DecompressionBody<BoxBody<<B as http_body::Body>::Data, <B as http_body::Body>::Error>>,
>;
impl<ReqBody, ResBody, S> Layer<S> for ClientLayer<ReqBody>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
ReqBody: http_body::Body + Default + Send + 'static,
ResBody: http_body::Body + Sync + Send + 'static,
ResBody::Error: std::fmt::Display + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxError>,
{
type Service = BoxCloneService<Request<ReqBody>, ClientResponse<ResBody>, BoxError>;
fn layer(&self, inner: S) -> Self::Service {
ServiceBuilder::new()
.layer(DecompressionLayer::new())
.map_response(|r: Response<_>| r.map(BoxBody::new))
.layer(SetRequestHeaderLayer::overriding(
USER_AGENT,
MAS_USER_AGENT.clone(),
))
// A trace that has the whole operation, with all the redirects, retries, rate limits
.layer(MakeOtelSpan::outer_client(self.operation).http_layer())
.layer(ConcurrencyLimitLayer::new(10))
.layer(FollowRedirectLayer::new())
// A trace for each "real" http request
.layer(MakeOtelSpan::inner_client().http_layer())
.layer(TimeoutLayer::new(Duration::from_secs(10)))
// Propagate the span context
.map_request(|mut r: Request<_>| {
// TODO: this seems to be broken
let cx = tracing::Span::current().context();
let mut injector = opentelemetry_http::HeaderInjector(r.headers_mut());
opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.inject_context(&cx, &mut injector)
});
r
})
.service(inner)
.boxed_clone()
}
}
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new(); static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
@ -159,164 +88,3 @@ where
Ok(client) Ok(client)
} }
#[derive(Debug, Default)]
pub struct ServerLayer<ReqBody>(PhantomData<ReqBody>);
impl<ReqBody, ResBody, S> Layer<S> for ServerLayer<ReqBody>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
ReqBody: http_body::Body + 'static,
ResBody: http_body::Body + Sync + Send + 'static,
ResBody::Error: std::fmt::Display + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxError>,
{
type Service = BoxCloneService<
Request<ReqBody>,
Response<CompressionBody<BoxBody<ResBody::Data, ResBody::Error>>>,
BoxError,
>;
fn layer(&self, inner: S) -> Self::Service {
ServiceBuilder::new()
.layer(CompressionLayer::new())
.map_response(|r: Response<_>| r.map(BoxBody::new))
.layer(
TraceLayer::new_for_http()
.make_span_with(MakeOtelSpan::server())
.on_response(OtelOnResponse),
)
.layer(TimeoutLayer::new(Duration::from_secs(10)))
.service(inner)
.boxed_clone()
}
}
#[derive(Debug, Clone, Copy)]
pub enum MakeOtelSpan {
OuterClient(&'static str),
InnerClient,
Server,
}
impl MakeOtelSpan {
const fn outer_client(operation: &'static str) -> Self {
Self::OuterClient(operation)
}
const fn inner_client() -> Self {
Self::InnerClient
}
const fn server() -> Self {
Self::Server
}
fn http_layer(
self,
) -> TraceLayer<
tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
Self,
tower_http::trace::DefaultOnRequest,
OtelOnResponse,
> {
TraceLayer::new_for_http()
.make_span_with(self)
.on_response(OtelOnResponse)
}
}
impl<B> MakeSpan<B> for MakeOtelSpan {
fn make_span(&mut self, request: &Request<B>) -> tracing::Span {
// Extract the context from the headers
let headers = request.headers();
let version = match request.version() {
Version::HTTP_09 => "0.9",
Version::HTTP_10 => "1.0",
Version::HTTP_11 => "1.1",
Version::HTTP_2 => "2.0",
Version::HTTP_3 => "3.0",
_ => "",
};
let span = match self {
Self::OuterClient(operation) => {
tracing::info_span!(
"client_request",
otel.name = operation,
otel.kind = "internal",
otel.status_code = field::Empty,
http.method = %request.method(),
http.target = %request.uri(),
http.flavor = version,
http.status_code = field::Empty,
http.user_agent = field::Empty,
)
}
Self::InnerClient => {
tracing::info_span!(
"outgoing_request",
otel.kind = "client",
otel.status_code = field::Empty,
http.method = %request.method(),
http.target = %request.uri(),
http.flavor = version,
http.status_code = field::Empty,
http.user_agent = field::Empty,
)
}
Self::Server => {
let span = tracing::info_span!(
"incoming_request",
otel.kind = "server",
otel.status_code = field::Empty,
http.method = %request.method(),
http.target = %request.uri(),
http.flavor = version,
http.status_code = field::Empty,
http.user_agent = field::Empty,
);
// Extract the context from the headers for server spans
let headers = request.headers();
let extractor = HeaderExtractor(headers);
let cx = opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.extract(&extractor)
});
if cx.span().span_context().is_remote() {
span.set_parent(cx);
}
span
}
};
if let Some(user_agent) = headers.get(USER_AGENT).and_then(|s| s.to_str().ok()) {
span.record("http.user_agent", &user_agent);
}
span
}
}
#[derive(Debug, Clone, Default)]
pub struct OtelOnResponse;
impl<B> OnResponse<B> for OtelOnResponse {
fn on_response(self, response: &hyper::Response<B>, _latency: Duration, span: &tracing::Span) {
let s = response.status();
let status = if s.is_success() {
"ok"
} else if s.is_client_error() || s.is_server_error() {
"error"
} else {
"unset"
};
span.record("otel.status_code", &status);
span.record("http.status_code", &s.as_u16());
}
}