1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Use new tuple Layer impls instead of ServiceBuilder (#475)

Co-authored-by: Quentin Gliech <quenting@element.io>
This commit is contained in:
Jonas Platte
2022-10-17 16:48:12 +02:00
committed by GitHub
parent 51515358f7
commit cf6d5a076a
7 changed files with 62 additions and 127 deletions

View File

@ -26,7 +26,10 @@ use hyper::{
}; };
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use thiserror::Error; use thiserror::Error;
use tower::{util::BoxCloneService, Service, ServiceBuilder, ServiceExt}; use tower::{
util::{BoxCloneService, MapErrLayer, MapResponseLayer},
Layer, Service, ServiceExt,
};
use crate::{ use crate::{
layers::{ layers::{
@ -172,9 +175,7 @@ where
E: Into<BoxError>, E: Into<BoxError>,
{ {
// Trace DNS requests // Trace DNS requests
let resolver = ServiceBuilder::new() let resolver = TraceLayer::dns().layer(GaiResolver::new());
.layer(TraceLayer::dns())
.service(GaiResolver::new());
let roots = tls_roots().await?; let roots = tls_roots().await?;
let tls_config = rustls::ClientConfig::builder() let tls_config = rustls::ClientConfig::builder()
@ -228,15 +229,15 @@ where
{ {
let client = make_base_client().await?; let client = make_base_client().await?;
let client = ServiceBuilder::new() let layer = (
// Convert the errors to ClientError to help dealing with them // Convert the errors to ClientError to help dealing with them
.map_err(ClientError::from) MapErrLayer::new(ClientError::from),
.map_response(|r: ClientResponse<hyper::Body>| { MapResponseLayer::new(|r: ClientResponse<hyper::Body>| {
r.map(|body| body.map_err(ClientError::from).boxed()) r.map(|body| body.map_err(ClientError::from).boxed())
}) }),
.layer(ClientLayer::new(operation)) ClientLayer::new(operation),
.service(client) );
.boxed_clone(); let client = layer.layer(client).boxed_clone();
Ok(client) Ok(client)
} }

View File

@ -16,16 +16,13 @@ use std::ops::RangeBounds;
use http::{header::HeaderName, Request, StatusCode}; use http::{header::HeaderName, Request, StatusCode};
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use tower::{layer::util::Stack, Service, ServiceBuilder}; use tower::Service;
use tower_http::cors::CorsLayer; use tower_http::cors::CorsLayer;
use crate::layers::{ use crate::layers::{
body_to_bytes_response::{BodyToBytesResponse, BodyToBytesResponseLayer}, body_to_bytes_response::BodyToBytesResponse, bytes_to_body_request::BytesToBodyRequest,
bytes_to_body_request::{BytesToBodyRequest, BytesToBodyRequestLayer}, catch_http_codes::CatchHttpCodes, form_urlencoded_request::FormUrlencodedRequest,
catch_http_codes::{CatchHttpCodes, CatchHttpCodesLayer}, json_request::JsonRequest, json_response::JsonResponse,
form_urlencoded_request::{FormUrlencodedRequest, FormUrlencodedRequestLayer},
json_request::{JsonRequest, JsonRequestLayer},
json_response::{JsonResponse, JsonResponseLayer},
}; };
static PROPAGATOR_HEADERS: OnceCell<Vec<HeaderName>> = OnceCell::new(); static PROPAGATOR_HEADERS: OnceCell<Vec<HeaderName>> = OnceCell::new();
@ -107,65 +104,3 @@ pub trait ServiceExt<Body>: Sized {
} }
impl<S, B> ServiceExt<B> for S where S: Service<Request<B>> {} impl<S, B> ServiceExt<B> for S where S: Service<Request<B>> {}
pub trait ServiceBuilderExt<L>: Sized {
fn request_bytes_to_body(self) -> ServiceBuilder<Stack<BytesToBodyRequestLayer, L>>;
fn response_body_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesResponseLayer, L>>;
fn json_response<T>(self) -> ServiceBuilder<Stack<JsonResponseLayer<T>, L>>;
fn json_request<T>(self) -> ServiceBuilder<Stack<JsonRequestLayer<T>, L>>;
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>>;
fn catch_http_code<M>(
self,
status_code: StatusCode,
mapper: M,
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
where
M: Clone,
{
self.catch_http_codes(status_code..=status_code, mapper)
}
fn catch_http_codes<B, M>(
self,
bounds: B,
mapper: M,
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
where
B: RangeBounds<StatusCode>,
M: Clone;
}
impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
fn request_bytes_to_body(self) -> ServiceBuilder<Stack<BytesToBodyRequestLayer, L>> {
self.layer(BytesToBodyRequestLayer::default())
}
fn response_body_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesResponseLayer, L>> {
self.layer(BodyToBytesResponseLayer::default())
}
fn json_response<T>(self) -> ServiceBuilder<Stack<JsonResponseLayer<T>, L>> {
self.layer(JsonResponseLayer::default())
}
fn json_request<T>(self) -> ServiceBuilder<Stack<JsonRequestLayer<T>, L>> {
self.layer(JsonRequestLayer::default())
}
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>> {
self.layer(FormUrlencodedRequestLayer::default())
}
fn catch_http_codes<B, M>(
self,
bounds: B,
mapper: M,
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
where
B: RangeBounds<StatusCode>,
M: Clone,
{
self.layer(CatchHttpCodesLayer::new(bounds, mapper))
}
}

View File

@ -116,15 +116,21 @@ pub struct CatchHttpCodesLayer<M> {
mapper: M, mapper: M,
} }
impl<M> CatchHttpCodesLayer<M> { impl<M> CatchHttpCodesLayer<M>
where
M: Clone,
{
pub fn new<B>(bounds: B, mapper: M) -> Self pub fn new<B>(bounds: B, mapper: M) -> Self
where where
B: RangeBounds<StatusCode>, B: RangeBounds<StatusCode>,
M: Clone,
{ {
let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned()); let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
Self { bounds, mapper } Self { bounds, mapper }
} }
pub fn exact(status_code: StatusCode, mapper: M) -> Self {
Self::new(status_code..=status_code, mapper)
}
} }
impl<S, M> Layer<S> for CatchHttpCodesLayer<M> impl<S, M> Layer<S> for CatchHttpCodesLayer<M>

View File

@ -17,7 +17,7 @@ use std::{marker::PhantomData, time::Duration};
use http::{header::USER_AGENT, HeaderValue, Request, Response}; use http::{header::USER_AGENT, HeaderValue, Request, Response};
use tower::{ use tower::{
limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service, limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service,
ServiceBuilder, ServiceExt, ServiceExt,
}; };
use tower_http::{ use tower_http::{
decompression::{DecompressionBody, DecompressionLayer}, decompression::{DecompressionBody, DecompressionLayer},
@ -65,21 +65,19 @@ where
// - the TimeoutLayer // - the TimeoutLayer
// - the DecompressionLayer // - the DecompressionLayer
// Those layers do type erasure of the error. // Those layers do type erasure of the error.
ServiceBuilder::new() (
.layer(DecompressionLayer::new()) DecompressionLayer::new(),
.layer(SetRequestHeaderLayer::overriding( SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()),
USER_AGENT,
MAS_USER_AGENT.clone(),
))
// A trace that has the whole operation, with all the redirects, timeouts and rate // A trace that has the whole operation, with all the redirects, timeouts and rate
// limits in it // limits in it
.layer(TraceLayer::http_client(self.operation)) TraceLayer::http_client(self.operation),
.layer(ConcurrencyLimitLayer::new(10)) ConcurrencyLimitLayer::new(10),
.layer(FollowRedirectLayer::new()) FollowRedirectLayer::new(),
// A trace for each "real" http request // A trace for each "real" http request
.layer(TraceLayer::inner_http_client()) TraceLayer::inner_http_client(),
.layer(TimeoutLayer::new(Duration::from_secs(10))) TimeoutLayer::new(Duration::from_secs(10)),
.service(inner) )
.layer(inner)
.boxed_clone() .boxed_clone()
} }
} }

View File

@ -16,8 +16,8 @@ use std::marker::PhantomData;
use http::{Request, Response}; use http::{Request, Response};
use opentelemetry::KeyValue; use opentelemetry::KeyValue;
use tower::{util::BoxCloneService, Layer, Service, ServiceBuilder, ServiceExt}; use tower::{util::BoxCloneService, Layer, Service, ServiceExt};
use tower_http::{compression::CompressionBody, ServiceBuilderExt}; use tower_http::compression::{CompressionBody, CompressionLayer};
use super::otel::TraceLayer; use super::otel::TraceLayer;
@ -49,21 +49,18 @@ where
type Service = BoxCloneService<Request<ReqBody>, Response<CompressionBody<ResBody>>, S::Error>; type Service = BoxCloneService<Request<ReqBody>, Response<CompressionBody<ResBody>>, S::Error>;
fn layer(&self, inner: S) -> Self::Service { fn layer(&self, inner: S) -> Self::Service {
let builder = ServiceBuilder::new().compression(); let compression = CompressionLayer::new();
#[cfg(feature = "axum")] #[cfg(feature = "axum")]
let mut trace_layer = TraceLayer::axum(); let mut trace = TraceLayer::axum();
#[cfg(not(feature = "axum"))] #[cfg(not(feature = "axum"))]
let mut trace_layer = TraceLayer::http_server(); let mut trace = TraceLayer::http_server();
if let Some(name) = &self.listener_name { if let Some(name) = &self.listener_name {
trace_layer = trace = trace.with_static_attribute(KeyValue::new("listener", name.clone()));
trace_layer.with_static_attribute(KeyValue::new("listener", name.clone()));
} }
let builder = builder.layer(trace_layer); (compression, trace).layer(inner).boxed_clone()
builder.service(inner).boxed_clone()
} }
} }

View File

@ -32,10 +32,7 @@ mod layers;
#[cfg(feature = "client")] #[cfg(feature = "client")]
pub use self::client::{client, make_untraced_client}; pub use self::client::{client, make_untraced_client};
pub use self::{ pub use self::{
ext::{ ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
ServiceExt as HttpServiceExt,
},
layers::{ layers::{
body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer}, body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer},
bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer}, bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer},

View File

@ -18,10 +18,13 @@ use anyhow::{bail, Context};
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use headers::{ContentType, HeaderMapExt}; use headers::{ContentType, HeaderMapExt};
use http::{header::ACCEPT, HeaderValue, Request, Response, StatusCode}; use http::{header::ACCEPT, HeaderValue, Request, Response, StatusCode};
use mas_http::HttpServiceBuilderExt; use mas_http::{
BodyToBytesResponseLayer, BytesToBodyRequestLayer, CatchHttpCodesLayer,
FormUrlencodedRequestLayer, JsonRequestLayer, JsonResponseLayer,
};
use serde::Deserialize; use serde::Deserialize;
use thiserror::Error; use thiserror::Error;
use tower::{ServiceBuilder, ServiceExt}; use tower::{service_fn, Layer, ServiceExt};
#[derive(Debug, Error, Deserialize)] #[derive(Debug, Error, Deserialize)]
#[error("Error code in response: {error}")] #[error("Error code in response: {error}")]
@ -42,10 +45,11 @@ async fn test_http_errors() {
serde_json::from_reader(response.into_body().reader()).unwrap() serde_json::from_reader(response.into_body().reader()).unwrap()
} }
let svc = ServiceBuilder::new() let layer = (
.catch_http_code(StatusCode::BAD_REQUEST, mapper) CatchHttpCodesLayer::exact(StatusCode::BAD_REQUEST, mapper),
.response_body_to_bytes() BodyToBytesResponseLayer,
.service_fn(handle); );
let svc = layer.layer(service_fn(handle));
let request = Request::new(hyper::Body::empty()); let request = Request::new(hyper::Body::empty());
@ -79,10 +83,8 @@ async fn test_json_request_body() {
Ok(res) Ok(res)
} }
let svc = ServiceBuilder::new() let layer = (JsonRequestLayer::default(), BytesToBodyRequestLayer);
.json_request() let svc = layer.layer(service_fn(handle));
.request_bytes_to_body()
.service_fn(handle);
let request = Request::new(serde_json::json!({"hello": "world"})); let request = Request::new(serde_json::json!({"hello": "world"}));
@ -106,10 +108,8 @@ async fn test_json_response_body() {
Ok(res) Ok(res)
} }
let svc = ServiceBuilder::new() let layer = (JsonResponseLayer::default(), BodyToBytesResponseLayer);
.json_response() let svc = layer.layer(service_fn(handle));
.response_body_to_bytes()
.service_fn(handle);
let request = Request::new(hyper::Body::empty()); let request = Request::new(hyper::Body::empty());
@ -142,10 +142,11 @@ async fn test_urlencoded_request_body() {
Ok(res) Ok(res)
} }
let svc = ServiceBuilder::new() let layer = (
.form_urlencoded_request() FormUrlencodedRequestLayer::default(),
.request_bytes_to_body() BytesToBodyRequestLayer,
.service_fn(handle); );
let svc = layer.layer(service_fn(handle));
let request = Request::new(serde_json::json!({"hello": "world"})); let request = Request::new(serde_json::json!({"hello": "world"}));