1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Use new tuple Layer impls instead of ServiceBuilder (#475)

Co-authored-by: Quentin Gliech <quenting@element.io>
This commit is contained in:
Jonas Platte
2022-10-17 16:48:12 +02:00
committed by GitHub
parent 51515358f7
commit cf6d5a076a
7 changed files with 62 additions and 127 deletions

View File

@ -26,7 +26,10 @@ use hyper::{
};
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use thiserror::Error;
use tower::{util::BoxCloneService, Service, ServiceBuilder, ServiceExt};
use tower::{
util::{BoxCloneService, MapErrLayer, MapResponseLayer},
Layer, Service, ServiceExt,
};
use crate::{
layers::{
@ -172,9 +175,7 @@ where
E: Into<BoxError>,
{
// Trace DNS requests
let resolver = ServiceBuilder::new()
.layer(TraceLayer::dns())
.service(GaiResolver::new());
let resolver = TraceLayer::dns().layer(GaiResolver::new());
let roots = tls_roots().await?;
let tls_config = rustls::ClientConfig::builder()
@ -228,15 +229,15 @@ where
{
let client = make_base_client().await?;
let client = ServiceBuilder::new()
let layer = (
// Convert the errors to ClientError to help dealing with them
.map_err(ClientError::from)
.map_response(|r: ClientResponse<hyper::Body>| {
MapErrLayer::new(ClientError::from),
MapResponseLayer::new(|r: ClientResponse<hyper::Body>| {
r.map(|body| body.map_err(ClientError::from).boxed())
})
.layer(ClientLayer::new(operation))
.service(client)
.boxed_clone();
}),
ClientLayer::new(operation),
);
let client = layer.layer(client).boxed_clone();
Ok(client)
}

View File

@ -16,16 +16,13 @@ use std::ops::RangeBounds;
use http::{header::HeaderName, Request, StatusCode};
use once_cell::sync::OnceCell;
use tower::{layer::util::Stack, Service, ServiceBuilder};
use tower::Service;
use tower_http::cors::CorsLayer;
use crate::layers::{
body_to_bytes_response::{BodyToBytesResponse, BodyToBytesResponseLayer},
bytes_to_body_request::{BytesToBodyRequest, BytesToBodyRequestLayer},
catch_http_codes::{CatchHttpCodes, CatchHttpCodesLayer},
form_urlencoded_request::{FormUrlencodedRequest, FormUrlencodedRequestLayer},
json_request::{JsonRequest, JsonRequestLayer},
json_response::{JsonResponse, JsonResponseLayer},
body_to_bytes_response::BodyToBytesResponse, bytes_to_body_request::BytesToBodyRequest,
catch_http_codes::CatchHttpCodes, form_urlencoded_request::FormUrlencodedRequest,
json_request::JsonRequest, json_response::JsonResponse,
};
static PROPAGATOR_HEADERS: OnceCell<Vec<HeaderName>> = OnceCell::new();
@ -107,65 +104,3 @@ pub trait ServiceExt<Body>: Sized {
}
impl<S, B> ServiceExt<B> for S where S: Service<Request<B>> {}
pub trait ServiceBuilderExt<L>: Sized {
fn request_bytes_to_body(self) -> ServiceBuilder<Stack<BytesToBodyRequestLayer, L>>;
fn response_body_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesResponseLayer, L>>;
fn json_response<T>(self) -> ServiceBuilder<Stack<JsonResponseLayer<T>, L>>;
fn json_request<T>(self) -> ServiceBuilder<Stack<JsonRequestLayer<T>, L>>;
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>>;
fn catch_http_code<M>(
self,
status_code: StatusCode,
mapper: M,
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
where
M: Clone,
{
self.catch_http_codes(status_code..=status_code, mapper)
}
fn catch_http_codes<B, M>(
self,
bounds: B,
mapper: M,
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
where
B: RangeBounds<StatusCode>,
M: Clone;
}
impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
fn request_bytes_to_body(self) -> ServiceBuilder<Stack<BytesToBodyRequestLayer, L>> {
self.layer(BytesToBodyRequestLayer::default())
}
fn response_body_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesResponseLayer, L>> {
self.layer(BodyToBytesResponseLayer::default())
}
fn json_response<T>(self) -> ServiceBuilder<Stack<JsonResponseLayer<T>, L>> {
self.layer(JsonResponseLayer::default())
}
fn json_request<T>(self) -> ServiceBuilder<Stack<JsonRequestLayer<T>, L>> {
self.layer(JsonRequestLayer::default())
}
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>> {
self.layer(FormUrlencodedRequestLayer::default())
}
fn catch_http_codes<B, M>(
self,
bounds: B,
mapper: M,
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
where
B: RangeBounds<StatusCode>,
M: Clone,
{
self.layer(CatchHttpCodesLayer::new(bounds, mapper))
}
}

View File

@ -116,15 +116,21 @@ pub struct CatchHttpCodesLayer<M> {
mapper: M,
}
impl<M> CatchHttpCodesLayer<M> {
impl<M> CatchHttpCodesLayer<M>
where
M: Clone,
{
pub fn new<B>(bounds: B, mapper: M) -> Self
where
B: RangeBounds<StatusCode>,
M: Clone,
{
let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
Self { bounds, mapper }
}
pub fn exact(status_code: StatusCode, mapper: M) -> Self {
Self::new(status_code..=status_code, mapper)
}
}
impl<S, M> Layer<S> for CatchHttpCodesLayer<M>

View File

@ -17,7 +17,7 @@ use std::{marker::PhantomData, time::Duration};
use http::{header::USER_AGENT, HeaderValue, Request, Response};
use tower::{
limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service,
ServiceBuilder, ServiceExt,
ServiceExt,
};
use tower_http::{
decompression::{DecompressionBody, DecompressionLayer},
@ -65,21 +65,19 @@ where
// - the TimeoutLayer
// - the DecompressionLayer
// Those layers do type erasure of the error.
ServiceBuilder::new()
.layer(DecompressionLayer::new())
.layer(SetRequestHeaderLayer::overriding(
USER_AGENT,
MAS_USER_AGENT.clone(),
))
(
DecompressionLayer::new(),
SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()),
// A trace that has the whole operation, with all the redirects, timeouts and rate
// limits in it
.layer(TraceLayer::http_client(self.operation))
.layer(ConcurrencyLimitLayer::new(10))
.layer(FollowRedirectLayer::new())
TraceLayer::http_client(self.operation),
ConcurrencyLimitLayer::new(10),
FollowRedirectLayer::new(),
// A trace for each "real" http request
.layer(TraceLayer::inner_http_client())
.layer(TimeoutLayer::new(Duration::from_secs(10)))
.service(inner)
TraceLayer::inner_http_client(),
TimeoutLayer::new(Duration::from_secs(10)),
)
.layer(inner)
.boxed_clone()
}
}

View File

@ -16,8 +16,8 @@ use std::marker::PhantomData;
use http::{Request, Response};
use opentelemetry::KeyValue;
use tower::{util::BoxCloneService, Layer, Service, ServiceBuilder, ServiceExt};
use tower_http::{compression::CompressionBody, ServiceBuilderExt};
use tower::{util::BoxCloneService, Layer, Service, ServiceExt};
use tower_http::compression::{CompressionBody, CompressionLayer};
use super::otel::TraceLayer;
@ -49,21 +49,18 @@ where
type Service = BoxCloneService<Request<ReqBody>, Response<CompressionBody<ResBody>>, S::Error>;
fn layer(&self, inner: S) -> Self::Service {
let builder = ServiceBuilder::new().compression();
let compression = CompressionLayer::new();
#[cfg(feature = "axum")]
let mut trace_layer = TraceLayer::axum();
let mut trace = TraceLayer::axum();
#[cfg(not(feature = "axum"))]
let mut trace_layer = TraceLayer::http_server();
let mut trace = TraceLayer::http_server();
if let Some(name) = &self.listener_name {
trace_layer =
trace_layer.with_static_attribute(KeyValue::new("listener", name.clone()));
trace = trace.with_static_attribute(KeyValue::new("listener", name.clone()));
}
let builder = builder.layer(trace_layer);
builder.service(inner).boxed_clone()
(compression, trace).layer(inner).boxed_clone()
}
}

View File

@ -32,10 +32,7 @@ mod layers;
#[cfg(feature = "client")]
pub use self::client::{client, make_untraced_client};
pub use self::{
ext::{
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
ServiceExt as HttpServiceExt,
},
ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
layers::{
body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer},
bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer},

View File

@ -18,10 +18,13 @@ use anyhow::{bail, Context};
use bytes::{Buf, Bytes};
use headers::{ContentType, HeaderMapExt};
use http::{header::ACCEPT, HeaderValue, Request, Response, StatusCode};
use mas_http::HttpServiceBuilderExt;
use mas_http::{
BodyToBytesResponseLayer, BytesToBodyRequestLayer, CatchHttpCodesLayer,
FormUrlencodedRequestLayer, JsonRequestLayer, JsonResponseLayer,
};
use serde::Deserialize;
use thiserror::Error;
use tower::{ServiceBuilder, ServiceExt};
use tower::{service_fn, Layer, ServiceExt};
#[derive(Debug, Error, Deserialize)]
#[error("Error code in response: {error}")]
@ -42,10 +45,11 @@ async fn test_http_errors() {
serde_json::from_reader(response.into_body().reader()).unwrap()
}
let svc = ServiceBuilder::new()
.catch_http_code(StatusCode::BAD_REQUEST, mapper)
.response_body_to_bytes()
.service_fn(handle);
let layer = (
CatchHttpCodesLayer::exact(StatusCode::BAD_REQUEST, mapper),
BodyToBytesResponseLayer,
);
let svc = layer.layer(service_fn(handle));
let request = Request::new(hyper::Body::empty());
@ -79,10 +83,8 @@ async fn test_json_request_body() {
Ok(res)
}
let svc = ServiceBuilder::new()
.json_request()
.request_bytes_to_body()
.service_fn(handle);
let layer = (JsonRequestLayer::default(), BytesToBodyRequestLayer);
let svc = layer.layer(service_fn(handle));
let request = Request::new(serde_json::json!({"hello": "world"}));
@ -106,10 +108,8 @@ async fn test_json_response_body() {
Ok(res)
}
let svc = ServiceBuilder::new()
.json_response()
.response_body_to_bytes()
.service_fn(handle);
let layer = (JsonResponseLayer::default(), BodyToBytesResponseLayer);
let svc = layer.layer(service_fn(handle));
let request = Request::new(hyper::Body::empty());
@ -142,10 +142,11 @@ async fn test_urlencoded_request_body() {
Ok(res)
}
let svc = ServiceBuilder::new()
.form_urlencoded_request()
.request_bytes_to_body()
.service_fn(handle);
let layer = (
FormUrlencodedRequestLayer::default(),
BytesToBodyRequestLayer,
);
let svc = layer.layer(service_fn(handle));
let request = Request::new(serde_json::json!({"hello": "world"}));