diff --git a/crates/http/src/client.rs b/crates/http/src/client.rs index c280566e..5190a28e 100644 --- a/crates/http/src/client.rs +++ b/crates/http/src/client.rs @@ -26,7 +26,10 @@ use hyper::{ }; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use thiserror::Error; -use tower::{util::BoxCloneService, Service, ServiceBuilder, ServiceExt}; +use tower::{ + util::{BoxCloneService, MapErrLayer, MapResponseLayer}, + Layer, Service, ServiceExt, +}; use crate::{ layers::{ @@ -172,9 +175,7 @@ where E: Into, { // Trace DNS requests - let resolver = ServiceBuilder::new() - .layer(TraceLayer::dns()) - .service(GaiResolver::new()); + let resolver = TraceLayer::dns().layer(GaiResolver::new()); let roots = tls_roots().await?; let tls_config = rustls::ClientConfig::builder() @@ -228,15 +229,15 @@ where { let client = make_base_client().await?; - let client = ServiceBuilder::new() + let layer = ( // Convert the errors to ClientError to help dealing with them - .map_err(ClientError::from) - .map_response(|r: ClientResponse| { + MapErrLayer::new(ClientError::from), + MapResponseLayer::new(|r: ClientResponse| { r.map(|body| body.map_err(ClientError::from).boxed()) - }) - .layer(ClientLayer::new(operation)) - .service(client) - .boxed_clone(); + }), + ClientLayer::new(operation), + ); + let client = layer.layer(client).boxed_clone(); Ok(client) } diff --git a/crates/http/src/ext.rs b/crates/http/src/ext.rs index 034e991b..831515db 100644 --- a/crates/http/src/ext.rs +++ b/crates/http/src/ext.rs @@ -16,16 +16,13 @@ use std::ops::RangeBounds; use http::{header::HeaderName, Request, StatusCode}; use once_cell::sync::OnceCell; -use tower::{layer::util::Stack, Service, ServiceBuilder}; +use tower::Service; use tower_http::cors::CorsLayer; use crate::layers::{ - body_to_bytes_response::{BodyToBytesResponse, BodyToBytesResponseLayer}, - bytes_to_body_request::{BytesToBodyRequest, BytesToBodyRequestLayer}, - catch_http_codes::{CatchHttpCodes, CatchHttpCodesLayer}, - form_urlencoded_request::{FormUrlencodedRequest, FormUrlencodedRequestLayer}, - json_request::{JsonRequest, JsonRequestLayer}, - json_response::{JsonResponse, JsonResponseLayer}, + body_to_bytes_response::BodyToBytesResponse, bytes_to_body_request::BytesToBodyRequest, + catch_http_codes::CatchHttpCodes, form_urlencoded_request::FormUrlencodedRequest, + json_request::JsonRequest, json_response::JsonResponse, }; static PROPAGATOR_HEADERS: OnceCell> = OnceCell::new(); @@ -107,65 +104,3 @@ pub trait ServiceExt: Sized { } impl ServiceExt for S where S: Service> {} - -pub trait ServiceBuilderExt: Sized { - fn request_bytes_to_body(self) -> ServiceBuilder>; - fn response_body_to_bytes(self) -> ServiceBuilder>; - fn json_response(self) -> ServiceBuilder, L>>; - fn json_request(self) -> ServiceBuilder, L>>; - fn form_urlencoded_request(self) -> ServiceBuilder, L>>; - - fn catch_http_code( - self, - status_code: StatusCode, - mapper: M, - ) -> ServiceBuilder, L>> - where - M: Clone, - { - self.catch_http_codes(status_code..=status_code, mapper) - } - - fn catch_http_codes( - self, - bounds: B, - mapper: M, - ) -> ServiceBuilder, L>> - where - B: RangeBounds, - M: Clone; -} - -impl ServiceBuilderExt for ServiceBuilder { - fn request_bytes_to_body(self) -> ServiceBuilder> { - self.layer(BytesToBodyRequestLayer::default()) - } - - fn response_body_to_bytes(self) -> ServiceBuilder> { - self.layer(BodyToBytesResponseLayer::default()) - } - - fn json_response(self) -> ServiceBuilder, L>> { - self.layer(JsonResponseLayer::default()) - } - - fn json_request(self) -> ServiceBuilder, L>> { - self.layer(JsonRequestLayer::default()) - } - - fn form_urlencoded_request(self) -> ServiceBuilder, L>> { - self.layer(FormUrlencodedRequestLayer::default()) - } - - fn catch_http_codes( - self, - bounds: B, - mapper: M, - ) -> ServiceBuilder, L>> - where - B: RangeBounds, - M: Clone, - { - self.layer(CatchHttpCodesLayer::new(bounds, mapper)) - } -} diff --git a/crates/http/src/layers/catch_http_codes.rs b/crates/http/src/layers/catch_http_codes.rs index 6618be79..aaeb0a46 100644 --- a/crates/http/src/layers/catch_http_codes.rs +++ b/crates/http/src/layers/catch_http_codes.rs @@ -116,15 +116,21 @@ pub struct CatchHttpCodesLayer { mapper: M, } -impl CatchHttpCodesLayer { +impl CatchHttpCodesLayer +where + M: Clone, +{ pub fn new(bounds: B, mapper: M) -> Self where B: RangeBounds, - M: Clone, { let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned()); Self { bounds, mapper } } + + pub fn exact(status_code: StatusCode, mapper: M) -> Self { + Self::new(status_code..=status_code, mapper) + } } impl Layer for CatchHttpCodesLayer diff --git a/crates/http/src/layers/client.rs b/crates/http/src/layers/client.rs index 8d0f429d..c6082d9a 100644 --- a/crates/http/src/layers/client.rs +++ b/crates/http/src/layers/client.rs @@ -17,7 +17,7 @@ use std::{marker::PhantomData, time::Duration}; use http::{header::USER_AGENT, HeaderValue, Request, Response}; use tower::{ limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service, - ServiceBuilder, ServiceExt, + ServiceExt, }; use tower_http::{ decompression::{DecompressionBody, DecompressionLayer}, @@ -65,21 +65,19 @@ where // - the TimeoutLayer // - the DecompressionLayer // Those layers do type erasure of the error. - ServiceBuilder::new() - .layer(DecompressionLayer::new()) - .layer(SetRequestHeaderLayer::overriding( - USER_AGENT, - MAS_USER_AGENT.clone(), - )) + ( + DecompressionLayer::new(), + SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()), // A trace that has the whole operation, with all the redirects, timeouts and rate // limits in it - .layer(TraceLayer::http_client(self.operation)) - .layer(ConcurrencyLimitLayer::new(10)) - .layer(FollowRedirectLayer::new()) + TraceLayer::http_client(self.operation), + ConcurrencyLimitLayer::new(10), + FollowRedirectLayer::new(), // A trace for each "real" http request - .layer(TraceLayer::inner_http_client()) - .layer(TimeoutLayer::new(Duration::from_secs(10))) - .service(inner) + TraceLayer::inner_http_client(), + TimeoutLayer::new(Duration::from_secs(10)), + ) + .layer(inner) .boxed_clone() } } diff --git a/crates/http/src/layers/server.rs b/crates/http/src/layers/server.rs index c4d0e9cb..74617ea6 100644 --- a/crates/http/src/layers/server.rs +++ b/crates/http/src/layers/server.rs @@ -16,8 +16,8 @@ use std::marker::PhantomData; use http::{Request, Response}; use opentelemetry::KeyValue; -use tower::{util::BoxCloneService, Layer, Service, ServiceBuilder, ServiceExt}; -use tower_http::{compression::CompressionBody, ServiceBuilderExt}; +use tower::{util::BoxCloneService, Layer, Service, ServiceExt}; +use tower_http::compression::{CompressionBody, CompressionLayer}; use super::otel::TraceLayer; @@ -49,21 +49,18 @@ where type Service = BoxCloneService, Response>, S::Error>; fn layer(&self, inner: S) -> Self::Service { - let builder = ServiceBuilder::new().compression(); + let compression = CompressionLayer::new(); #[cfg(feature = "axum")] - let mut trace_layer = TraceLayer::axum(); + let mut trace = TraceLayer::axum(); #[cfg(not(feature = "axum"))] - let mut trace_layer = TraceLayer::http_server(); + let mut trace = TraceLayer::http_server(); if let Some(name) = &self.listener_name { - trace_layer = - trace_layer.with_static_attribute(KeyValue::new("listener", name.clone())); + trace = trace.with_static_attribute(KeyValue::new("listener", name.clone())); } - let builder = builder.layer(trace_layer); - - builder.service(inner).boxed_clone() + (compression, trace).layer(inner).boxed_clone() } } diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index dd0b71fb..7f05f83f 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -32,10 +32,7 @@ mod layers; #[cfg(feature = "client")] pub use self::client::{client, make_untraced_client}; pub use self::{ - ext::{ - set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt, - ServiceExt as HttpServiceExt, - }, + ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt}, layers::{ body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer}, bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer}, diff --git a/crates/http/tests/client_layers.rs b/crates/http/tests/client_layers.rs index e0ecde3d..071f5da1 100644 --- a/crates/http/tests/client_layers.rs +++ b/crates/http/tests/client_layers.rs @@ -18,10 +18,13 @@ use anyhow::{bail, Context}; use bytes::{Buf, Bytes}; use headers::{ContentType, HeaderMapExt}; use http::{header::ACCEPT, HeaderValue, Request, Response, StatusCode}; -use mas_http::HttpServiceBuilderExt; +use mas_http::{ + BodyToBytesResponseLayer, BytesToBodyRequestLayer, CatchHttpCodesLayer, + FormUrlencodedRequestLayer, JsonRequestLayer, JsonResponseLayer, +}; use serde::Deserialize; use thiserror::Error; -use tower::{ServiceBuilder, ServiceExt}; +use tower::{service_fn, Layer, ServiceExt}; #[derive(Debug, Error, Deserialize)] #[error("Error code in response: {error}")] @@ -42,10 +45,11 @@ async fn test_http_errors() { serde_json::from_reader(response.into_body().reader()).unwrap() } - let svc = ServiceBuilder::new() - .catch_http_code(StatusCode::BAD_REQUEST, mapper) - .response_body_to_bytes() - .service_fn(handle); + let layer = ( + CatchHttpCodesLayer::exact(StatusCode::BAD_REQUEST, mapper), + BodyToBytesResponseLayer, + ); + let svc = layer.layer(service_fn(handle)); let request = Request::new(hyper::Body::empty()); @@ -79,10 +83,8 @@ async fn test_json_request_body() { Ok(res) } - let svc = ServiceBuilder::new() - .json_request() - .request_bytes_to_body() - .service_fn(handle); + let layer = (JsonRequestLayer::default(), BytesToBodyRequestLayer); + let svc = layer.layer(service_fn(handle)); let request = Request::new(serde_json::json!({"hello": "world"})); @@ -106,10 +108,8 @@ async fn test_json_response_body() { Ok(res) } - let svc = ServiceBuilder::new() - .json_response() - .response_body_to_bytes() - .service_fn(handle); + let layer = (JsonResponseLayer::default(), BodyToBytesResponseLayer); + let svc = layer.layer(service_fn(handle)); let request = Request::new(hyper::Body::empty()); @@ -142,10 +142,11 @@ async fn test_urlencoded_request_body() { Ok(res) } - let svc = ServiceBuilder::new() - .form_urlencoded_request() - .request_bytes_to_body() - .service_fn(handle); + let layer = ( + FormUrlencodedRequestLayer::default(), + BytesToBodyRequestLayer, + ); + let svc = layer.layer(service_fn(handle)); let request = Request::new(serde_json::json!({"hello": "world"}));