From 497a3e006ebbb9949f57c505c3ab0a1c8f616fa4 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 11 Feb 2022 13:10:24 +0100 Subject: [PATCH] Implement a JSON tower layer This will help requesting JSON APIs --- Cargo.lock | 34 ++-- crates/cli/Cargo.toml | 1 + crates/cli/src/commands/debug.rs | 69 +++++++-- crates/http/Cargo.toml | 7 +- crates/http/src/ext.rs | 30 ++++ crates/http/src/layers/client.rs | 96 ++++++++++++ crates/http/src/layers/json.rs | 116 ++++++++++++++ crates/http/src/layers/mod.rs | 18 +++ crates/http/src/layers/server.rs | 56 +++++++ crates/http/src/layers/trace.rs | 170 ++++++++++++++++++++ crates/http/src/lib.rs | 256 ++----------------------------- 11 files changed, 579 insertions(+), 274 deletions(-) create mode 100644 crates/http/src/ext.rs create mode 100644 crates/http/src/layers/client.rs create mode 100644 crates/http/src/layers/json.rs create mode 100644 crates/http/src/layers/mod.rs create mode 100644 crates/http/src/layers/server.rs create mode 100644 crates/http/src/layers/trace.rs diff --git a/Cargo.lock b/Cargo.lock index d3414d60..de5f4b21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1228,9 +1228,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.19" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3dda0b6588335f360afc675d0564c17a77a2bda81ca178a4b6081bd86c7f0b" +checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010" dependencies = [ "futures-core", "futures-sink", @@ -1238,9 +1238,9 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.19" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0c8ff0461b82559810cdccfde3215c3f373807f5e5232b71479bff7bb2583d7" +checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" [[package]] name = "futures-executor" @@ -1266,15 +1266,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.19" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f9d34af5a1aac6fb380f735fe510746c38067c5bf16c7fd250280503c971b2" +checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" [[package]] name = "futures-macro" -version = "0.3.19" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dbd947adfffb0efc70599b3ddcf7b5597bb5fa9e245eb99f62b3a5f7bb8bd3c" +checksum = "33c1e13800337f4d4d7a316bf45a567dbcb6ffe087f16424852d97e97a91f512" dependencies = [ "proc-macro2", "quote", @@ -1283,21 +1283,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.19" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3055baccb68d74ff6480350f8d6eb8fcfa3aa11bdc1a1ae3afdd0514617d508" +checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868" [[package]] name = "futures-task" -version = "0.3.19" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ee7c6485c30167ce4dfb83ac568a849fe53274c831081476ee13e0dce1aad72" +checksum = "57c66a976bf5909d801bbef33416c41372779507e7a6b3a5e25e4749c58f776a" [[package]] name = "futures-util" -version = "0.3.19" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b5cf40b47a271f77a8b1bec03ca09044d99d2372c0de244e66430761127164" +checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" dependencies = [ "futures 0.1.31", "futures-channel", @@ -1877,6 +1877,7 @@ dependencies = [ "opentelemetry-zipkin", "reqwest", "schemars", + "serde_json", "serde_yaml", "tokio", "tower", @@ -1994,13 +1995,18 @@ version = "0.1.0" dependencies = [ "anyhow", "bytes 1.1.0", + "futures-util", "http", "http-body", "hyper", "hyper-rustls 0.23.0", "opentelemetry", "opentelemetry-http", + "pin-project-lite", "rustls 0.20.2", + "serde", + "serde_json", + "thiserror", "tokio", "tower", "tower-http", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 17549369..d947fa1e 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -15,6 +15,7 @@ schemars = { version = "0.8.8", features = ["url", "chrono"] } tower = { version = "0.4.11", features = ["full"] } hyper = { version = "0.14.16", features = ["full"] } serde_yaml = "0.8.23" +serde_json = "1.0.78" warp = "0.3.2" url = "2.2.2" argon2 = { version = "0.3.3", features = ["password-hash"] } diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 33aaf755..1f369949 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -13,7 +13,8 @@ // limitations under the License. use clap::Parser; -use hyper::Uri; +use hyper::{Response, Uri}; +use mas_http::HttpServiceExt; use tokio::io::AsyncWriteExt; use tower::{Service, ServiceExt}; @@ -31,43 +32,81 @@ enum Subcommand { #[clap(long, short = 'I')] show_headers: bool, + /// Parse the response as JSON + #[clap(long, short = 'j')] + json: bool, + /// URI where to perform a GET request url: Uri, }, } +fn print_headers(parts: &hyper::http::response::Parts) { + println!( + "{:?} {} {}", + parts.version, + parts.status.as_str(), + parts.status.canonical_reason().unwrap_or_default() + ); + + for (header, value) in &parts.headers { + println!("{}: {:?}", header, value); + } + println!(); +} + impl Options { #[tracing::instrument(skip_all)] pub async fn run(&self, _root: &super::Options) -> anyhow::Result<()> { use Subcommand as SC; match &self.subcommand { - SC::Http { show_headers, url } => { + SC::Http { + show_headers, + json: false, + url, + } => { let mut client = mas_http::client("cli-debug-http").await?; let request = hyper::Request::builder() .uri(url) .body(hyper::Body::empty())?; - let mut response = client.ready().await?.call(request).await?; + let response = client.ready().await?.call(request).await?; + let (parts, body) = response.into_parts(); if *show_headers { - let status = response.status(); - println!( - "{:?} {} {}", - response.version(), - status.as_str(), - status.canonical_reason().unwrap_or_default() - ); - for (header, value) in response.headers() { - println!("{}: {:?}", header, value); - } - println!(); + print_headers(&parts); } - let mut body = hyper::body::aggregate(response.body_mut()).await?; + + let mut body = hyper::body::aggregate(body).await?; let mut stdout = tokio::io::stdout(); stdout.write_all_buf(&mut body).await?; Ok(()) } + + SC::Http { + show_headers, + json: true, + url, + } => { + let mut client = mas_http::client("cli-debug-http").await?.json(); + let request = hyper::Request::builder() + .uri(url) + .body(hyper::Body::empty())?; + + let response: Response = + client.ready().await?.call(request).await?; + let (parts, body) = response.into_parts(); + + if *show_headers { + print_headers(&parts); + } + + let body = serde_json::to_string_pretty(&body)?; + println!("{}", body); + + Ok(()) + } } } } diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index 0d87273f..87a3676c 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -8,14 +8,19 @@ license = "Apache-2.0" [dependencies] anyhow = "1.0.53" bytes = "1.1.0" +futures-util = "0.3.21" http = "0.2.6" http-body = "0.4.4" hyper = "0.14.16" hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] } opentelemetry = "0.17.0" opentelemetry-http = "0.6.0" +pin-project-lite = "0.2.8" rustls = "0.20.2" -tokio = { version = "1.16.1", features = ["sync"] } +serde = "1.0.136" +serde_json = "1.0.78" +thiserror = "1.0.30" +tokio = { version = "1.16.1", features = ["sync", "parking_lot"] } tower = { version = "0.4.11", features = ["timeout", "limit"] } tower-http = { version = "0.2.1", features = ["follow-redirect", "decompression-full", "set-header", "trace", "compression-full"] } tracing = "0.1.30" diff --git a/crates/http/src/ext.rs b/crates/http/src/ext.rs new file mode 100644 index 00000000..bffb976f --- /dev/null +++ b/crates/http/src/ext.rs @@ -0,0 +1,30 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::layers::json::Json; + +pub trait ServiceExt { + fn json(self) -> Json + where + Self: Sized; +} + +impl ServiceExt for S { + fn json(self) -> Json + where + Self: Sized, + { + Json::new(self) + } +} diff --git a/crates/http/src/layers/client.rs b/crates/http/src/layers/client.rs new file mode 100644 index 00000000..90b5bf97 --- /dev/null +++ b/crates/http/src/layers/client.rs @@ -0,0 +1,96 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{marker::PhantomData, time::Duration}; + +use http::{header::USER_AGENT, HeaderValue, Request, Response}; +use http_body::combinators::BoxBody; +use tower::{ + limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service, + ServiceBuilder, ServiceExt, +}; +use tower_http::{ + decompression::{DecompressionBody, DecompressionLayer}, + follow_redirect::FollowRedirectLayer, + set_header::SetRequestHeaderLayer, +}; +use tracing_opentelemetry::OpenTelemetrySpanExt; + +use super::trace::OtelTraceLayer; + +static MAS_USER_AGENT: HeaderValue = + HeaderValue::from_static("matrix-authentication-service/0.0.1"); + +type BoxError = Box; + +#[derive(Debug, Clone)] +pub struct ClientLayer { + operation: &'static str, + _t: PhantomData, +} + +impl ClientLayer { + pub fn new(operation: &'static str) -> Self { + Self { + operation, + _t: PhantomData, + } + } +} + +pub type ClientResponse = Response< + DecompressionBody::Data, ::Error>>, +>; + +impl Layer for ClientLayer +where + S: Service, Response = Response> + Clone + Send + 'static, + ReqBody: http_body::Body + Default + Send + 'static, + ResBody: http_body::Body + Sync + Send + 'static, + ResBody::Error: std::fmt::Display + 'static, + S::Future: Send + 'static, + S::Error: Into, +{ + type Service = BoxCloneService, ClientResponse, BoxError>; + + fn layer(&self, inner: S) -> Self::Service { + ServiceBuilder::new() + .layer(DecompressionLayer::new()) + .map_response(|r: Response<_>| r.map(BoxBody::new)) + .layer(SetRequestHeaderLayer::overriding( + USER_AGENT, + MAS_USER_AGENT.clone(), + )) + // A trace that has the whole operation, with all the redirects, retries, rate limits + .layer(OtelTraceLayer::outer_client(self.operation)) + .layer(ConcurrencyLimitLayer::new(10)) + .layer(FollowRedirectLayer::new()) + // A trace for each "real" http request + .layer(OtelTraceLayer::inner_client()) + .layer(TimeoutLayer::new(Duration::from_secs(10))) + // Propagate the span context + .map_request(|mut r: Request<_>| { + // TODO: this seems to be broken + let cx = tracing::Span::current().context(); + let mut injector = opentelemetry_http::HeaderInjector(r.headers_mut()); + opentelemetry::global::get_text_map_propagator(|propagator| { + propagator.inject_context(&cx, &mut injector) + }); + + r + }) + .service(inner) + .boxed_clone() + } +} diff --git a/crates/http/src/layers/json.rs b/crates/http/src/layers/json.rs new file mode 100644 index 00000000..57bdb622 --- /dev/null +++ b/crates/http/src/layers/json.rs @@ -0,0 +1,116 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{marker::PhantomData, task::Poll}; + +use futures_util::future::BoxFuture; +use http::{header::ACCEPT, HeaderValue, Request, Response}; +use http_body::Body; +use serde::de::DeserializeOwned; +use thiserror::Error; +use tower::{Layer, Service}; + +#[derive(Debug, Error)] +pub enum Error { + #[error("service")] + Service { inner: Service }, + #[error("body")] + Body { inner: Body }, + #[error("json")] + Json { inner: serde_json::Error }, +} + +impl Error { + fn service(source: S) -> Self { + Self::Service { inner: source } + } + + fn body(source: B) -> Self { + Self::Body { inner: source } + } + + fn json(source: serde_json::Error) -> Self { + Self::Json { inner: source } + } +} + +pub struct Json { + inner: S, + _t: PhantomData, +} + +impl Json { + pub const fn new(inner: S) -> Self { + Self { + inner, + _t: PhantomData, + } + } +} + +impl Service> for Json +where + S: Service, Response = Response>, + S::Future: Send + 'static, + C: Body + Send + 'static, + C::Data: Send + 'static, + T: DeserializeOwned, +{ + type Error = Error; + type Response = Response; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Error::service) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + req.headers_mut() + .insert(ACCEPT, HeaderValue::from_static("application/json")); + + let fut = self.inner.call(req); + + let fut = async { + let res = fut.await.map_err(Error::service)?; + let (parts, body) = res.into_parts(); + + futures_util::pin_mut!(body); + let bytes = hyper::body::to_bytes(&mut body) + .await + .map_err(Error::body)?; + + let body = serde_json::from_slice(&bytes.to_vec()).map_err(Error::json)?; + + let res = Response::from_parts(parts, body); + Ok(res) + }; + + Box::pin(fut) + } +} + +#[derive(Default, Clone, Copy)] +pub struct JsonResponseLayer(PhantomData<(T, ReqBody)>); + +impl Layer for JsonResponseLayer +where + S: Service, Response = Response>, + T: serde::de::DeserializeOwned, +{ + type Service = Json; + + fn layer(&self, inner: S) -> Self::Service { + Json::new(inner) + } +} diff --git a/crates/http/src/layers/mod.rs b/crates/http/src/layers/mod.rs new file mode 100644 index 00000000..2fd44508 --- /dev/null +++ b/crates/http/src/layers/mod.rs @@ -0,0 +1,18 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub(crate) mod client; +pub(crate) mod json; +pub(crate) mod server; +pub(crate) mod trace; diff --git a/crates/http/src/layers/server.rs b/crates/http/src/layers/server.rs new file mode 100644 index 00000000..1528564b --- /dev/null +++ b/crates/http/src/layers/server.rs @@ -0,0 +1,56 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{marker::PhantomData, time::Duration}; + +use http::{Request, Response}; +use http_body::combinators::BoxBody; +use tower::{ + timeout::TimeoutLayer, util::BoxCloneService, Layer, Service, ServiceBuilder, ServiceExt, +}; +use tower_http::compression::{CompressionBody, CompressionLayer}; + +use super::trace::OtelTraceLayer; +use crate::BoxError; + +#[derive(Debug, Default)] +pub struct ServerLayer { + _t: PhantomData, +} + +impl Layer for ServerLayer +where + S: Service, Response = Response> + Clone + Send + 'static, + ReqBody: http_body::Body + 'static, + ResBody: http_body::Body + Sync + Send + 'static, + ResBody::Error: std::fmt::Display + 'static, + S::Future: Send + 'static, + S::Error: Into, +{ + type Service = BoxCloneService< + Request, + Response>>, + BoxError, + >; + + fn layer(&self, inner: S) -> Self::Service { + ServiceBuilder::new() + .layer(CompressionLayer::new()) + .map_response(|r: Response<_>| r.map(BoxBody::new)) + .layer(OtelTraceLayer::server()) + .layer(TimeoutLayer::new(Duration::from_secs(10))) + .service(inner) + .boxed_clone() + } +} diff --git a/crates/http/src/layers/trace.rs b/crates/http/src/layers/trace.rs new file mode 100644 index 00000000..84aea725 --- /dev/null +++ b/crates/http/src/layers/trace.rs @@ -0,0 +1,170 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::time::Duration; + +use http::{header::USER_AGENT, Request, Response, Version}; +use opentelemetry::trace::TraceContextExt; +use opentelemetry_http::HeaderExtractor; +use tower::Layer; +use tower_http::{ + classify::{ServerErrorsAsFailures, SharedClassifier}, + trace::{DefaultOnRequest, MakeSpan, OnResponse, Trace}, +}; +use tracing::{field, Span}; +use tracing_opentelemetry::OpenTelemetrySpanExt; + +#[derive(Debug, Clone, Copy)] +pub enum OtelTraceLayer { + OuterClient(&'static str), + InnerClient, + Server, +} + +impl OtelTraceLayer { + pub const fn outer_client(operation: &'static str) -> Self { + Self::OuterClient(operation) + } + + pub const fn inner_client() -> Self { + Self::InnerClient + } + + pub const fn server() -> Self { + Self::Server + } +} + +impl Layer for OtelTraceLayer { + type Service = Trace< + S, + SharedClassifier, + MakeOtelSpan, + DefaultOnRequest, + OtelOnResponse, + >; + + fn layer(&self, inner: S) -> Self::Service { + let make_span = match self { + Self::OuterClient(o) => MakeOtelSpan::OuterClient(o), + Self::InnerClient => MakeOtelSpan::InnerClient, + Self::Server => MakeOtelSpan::Server, + }; + + Trace::new_for_http(inner) + .make_span_with(make_span) + .on_response(OtelOnResponse) + } +} + +#[derive(Debug, Clone, Copy)] +pub enum MakeOtelSpan { + OuterClient(&'static str), + InnerClient, + Server, +} + +impl MakeSpan for MakeOtelSpan { + fn make_span(&mut self, request: &Request) -> Span { + // Extract the context from the headers + let headers = request.headers(); + + let version = match request.version() { + Version::HTTP_09 => "0.9", + Version::HTTP_10 => "1.0", + Version::HTTP_11 => "1.1", + Version::HTTP_2 => "2.0", + Version::HTTP_3 => "3.0", + _ => "", + }; + + let span = match self { + Self::OuterClient(operation) => { + tracing::info_span!( + "client_request", + otel.name = operation, + otel.kind = "internal", + otel.status_code = field::Empty, + http.method = %request.method(), + http.target = %request.uri(), + http.flavor = version, + http.status_code = field::Empty, + http.user_agent = field::Empty, + ) + } + Self::InnerClient => { + tracing::info_span!( + "outgoing_request", + otel.kind = "client", + otel.status_code = field::Empty, + http.method = %request.method(), + http.target = %request.uri(), + http.flavor = version, + http.status_code = field::Empty, + http.user_agent = field::Empty, + ) + } + Self::Server => { + let span = tracing::info_span!( + "incoming_request", + otel.kind = "server", + otel.status_code = field::Empty, + http.method = %request.method(), + http.target = %request.uri(), + http.flavor = version, + http.status_code = field::Empty, + http.user_agent = field::Empty, + ); + + // Extract the context from the headers for server spans + let headers = request.headers(); + let extractor = HeaderExtractor(headers); + + let cx = opentelemetry::global::get_text_map_propagator(|propagator| { + propagator.extract(&extractor) + }); + + if cx.span().span_context().is_remote() { + span.set_parent(cx); + } + + span + } + }; + + if let Some(user_agent) = headers.get(USER_AGENT).and_then(|s| s.to_str().ok()) { + span.record("http.user_agent", &user_agent); + } + + span + } +} + +#[derive(Debug, Clone, Default)] +pub struct OtelOnResponse; + +impl OnResponse for OtelOnResponse { + fn on_response(self, response: &Response, _latency: Duration, span: &Span) { + let s = response.status(); + let status = if s.is_success() { + "ok" + } else if s.is_client_error() || s.is_server_error() { + "error" + } else { + "unset" + }; + span.record("otel.status_code", &status); + span.record("http.status_code", &s.as_u16()); + } +} diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index 5848b778..20fab110 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -12,95 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{marker::PhantomData, time::Duration}; - use bytes::Bytes; -use http::{header::USER_AGENT, HeaderValue, Request, Response, Version}; -use http_body::{combinators::BoxBody, Body}; +use http::{Request, Response}; +use http_body::Body; use hyper::{client::HttpConnector, Client}; use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder}; -use opentelemetry::trace::TraceContextExt; -use opentelemetry_http::HeaderExtractor; +use layers::client::ClientResponse; use tokio::sync::OnceCell; -use tower::{ - limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service, - ServiceBuilder, ServiceExt, +use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt}; + +mod ext; +mod layers; + +pub use self::{ + ext::ServiceExt as HttpServiceExt, + layers::{client::ClientLayer, json::JsonResponseLayer, server::ServerLayer}, }; -use tower_http::{ - compression::{CompressionBody, CompressionLayer}, - decompression::{DecompressionBody, DecompressionLayer}, - follow_redirect::FollowRedirectLayer, - set_header::SetRequestHeaderLayer, - trace::{MakeSpan, OnResponse, TraceLayer}, -}; -use tracing::field; -use tracing_opentelemetry::OpenTelemetrySpanExt; -static MAS_USER_AGENT: HeaderValue = - HeaderValue::from_static("matrix-authentication-service/0.0.1"); - -type BoxError = Box; - -#[derive(Debug, Clone)] -pub struct ClientLayer { - operation: &'static str, - _t: PhantomData, -} - -impl ClientLayer { - fn new(operation: &'static str) -> Self { - Self { - operation, - _t: PhantomData, - } - } -} - -type ClientResponse = Response< - DecompressionBody::Data, ::Error>>, ->; - -impl Layer for ClientLayer -where - S: Service, Response = Response> + Clone + Send + 'static, - ReqBody: http_body::Body + Default + Send + 'static, - ResBody: http_body::Body + Sync + Send + 'static, - ResBody::Error: std::fmt::Display + 'static, - S::Future: Send + 'static, - S::Error: Into, -{ - type Service = BoxCloneService, ClientResponse, BoxError>; - - fn layer(&self, inner: S) -> Self::Service { - ServiceBuilder::new() - .layer(DecompressionLayer::new()) - .map_response(|r: Response<_>| r.map(BoxBody::new)) - .layer(SetRequestHeaderLayer::overriding( - USER_AGENT, - MAS_USER_AGENT.clone(), - )) - // A trace that has the whole operation, with all the redirects, retries, rate limits - .layer(MakeOtelSpan::outer_client(self.operation).http_layer()) - .layer(ConcurrencyLimitLayer::new(10)) - .layer(FollowRedirectLayer::new()) - // A trace for each "real" http request - .layer(MakeOtelSpan::inner_client().http_layer()) - .layer(TimeoutLayer::new(Duration::from_secs(10))) - // Propagate the span context - .map_request(|mut r: Request<_>| { - // TODO: this seems to be broken - let cx = tracing::Span::current().context(); - let mut injector = opentelemetry_http::HeaderInjector(r.headers_mut()); - opentelemetry::global::get_text_map_propagator(|propagator| { - propagator.inject_context(&cx, &mut injector) - }); - - r - }) - .service(inner) - .boxed_clone() - } -} +pub(crate) type BoxError = Box; static TLS_CONFIG: OnceCell = OnceCell::const_new(); @@ -159,164 +88,3 @@ where Ok(client) } - -#[derive(Debug, Default)] -pub struct ServerLayer(PhantomData); - -impl Layer for ServerLayer -where - S: Service, Response = Response> + Clone + Send + 'static, - ReqBody: http_body::Body + 'static, - ResBody: http_body::Body + Sync + Send + 'static, - ResBody::Error: std::fmt::Display + 'static, - S::Future: Send + 'static, - S::Error: Into, -{ - type Service = BoxCloneService< - Request, - Response>>, - BoxError, - >; - - fn layer(&self, inner: S) -> Self::Service { - ServiceBuilder::new() - .layer(CompressionLayer::new()) - .map_response(|r: Response<_>| r.map(BoxBody::new)) - .layer( - TraceLayer::new_for_http() - .make_span_with(MakeOtelSpan::server()) - .on_response(OtelOnResponse), - ) - .layer(TimeoutLayer::new(Duration::from_secs(10))) - .service(inner) - .boxed_clone() - } -} - -#[derive(Debug, Clone, Copy)] -pub enum MakeOtelSpan { - OuterClient(&'static str), - InnerClient, - Server, -} - -impl MakeOtelSpan { - const fn outer_client(operation: &'static str) -> Self { - Self::OuterClient(operation) - } - - const fn inner_client() -> Self { - Self::InnerClient - } - - const fn server() -> Self { - Self::Server - } - - fn http_layer( - self, - ) -> TraceLayer< - tower_http::classify::SharedClassifier, - Self, - tower_http::trace::DefaultOnRequest, - OtelOnResponse, - > { - TraceLayer::new_for_http() - .make_span_with(self) - .on_response(OtelOnResponse) - } -} - -impl MakeSpan for MakeOtelSpan { - fn make_span(&mut self, request: &Request) -> tracing::Span { - // Extract the context from the headers - let headers = request.headers(); - - let version = match request.version() { - Version::HTTP_09 => "0.9", - Version::HTTP_10 => "1.0", - Version::HTTP_11 => "1.1", - Version::HTTP_2 => "2.0", - Version::HTTP_3 => "3.0", - _ => "", - }; - - let span = match self { - Self::OuterClient(operation) => { - tracing::info_span!( - "client_request", - otel.name = operation, - otel.kind = "internal", - otel.status_code = field::Empty, - http.method = %request.method(), - http.target = %request.uri(), - http.flavor = version, - http.status_code = field::Empty, - http.user_agent = field::Empty, - ) - } - Self::InnerClient => { - tracing::info_span!( - "outgoing_request", - otel.kind = "client", - otel.status_code = field::Empty, - http.method = %request.method(), - http.target = %request.uri(), - http.flavor = version, - http.status_code = field::Empty, - http.user_agent = field::Empty, - ) - } - Self::Server => { - let span = tracing::info_span!( - "incoming_request", - otel.kind = "server", - otel.status_code = field::Empty, - http.method = %request.method(), - http.target = %request.uri(), - http.flavor = version, - http.status_code = field::Empty, - http.user_agent = field::Empty, - ); - - // Extract the context from the headers for server spans - let headers = request.headers(); - let extractor = HeaderExtractor(headers); - - let cx = opentelemetry::global::get_text_map_propagator(|propagator| { - propagator.extract(&extractor) - }); - - if cx.span().span_context().is_remote() { - span.set_parent(cx); - } - - span - } - }; - - if let Some(user_agent) = headers.get(USER_AGENT).and_then(|s| s.to_str().ok()) { - span.record("http.user_agent", &user_agent); - } - - span - } -} - -#[derive(Debug, Clone, Default)] -pub struct OtelOnResponse; - -impl OnResponse for OtelOnResponse { - fn on_response(self, response: &hyper::Response, _latency: Duration, span: &tracing::Span) { - let s = response.status(); - let status = if s.is_success() { - "ok" - } else if s.is_client_error() || s.is_server_error() { - "error" - } else { - "unset" - }; - span.record("otel.status_code", &status); - span.record("http.status_code", &s.as_u16()); - } -}