You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Use new tuple Layer impls instead of ServiceBuilder (#475)
Co-authored-by: Quentin Gliech <quenting@element.io>
This commit is contained in:
@ -26,7 +26,10 @@ use hyper::{
|
||||
};
|
||||
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
||||
use thiserror::Error;
|
||||
use tower::{util::BoxCloneService, Service, ServiceBuilder, ServiceExt};
|
||||
use tower::{
|
||||
util::{BoxCloneService, MapErrLayer, MapResponseLayer},
|
||||
Layer, Service, ServiceExt,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
layers::{
|
||||
@ -172,9 +175,7 @@ where
|
||||
E: Into<BoxError>,
|
||||
{
|
||||
// Trace DNS requests
|
||||
let resolver = ServiceBuilder::new()
|
||||
.layer(TraceLayer::dns())
|
||||
.service(GaiResolver::new());
|
||||
let resolver = TraceLayer::dns().layer(GaiResolver::new());
|
||||
|
||||
let roots = tls_roots().await?;
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
@ -228,15 +229,15 @@ where
|
||||
{
|
||||
let client = make_base_client().await?;
|
||||
|
||||
let client = ServiceBuilder::new()
|
||||
let layer = (
|
||||
// Convert the errors to ClientError to help dealing with them
|
||||
.map_err(ClientError::from)
|
||||
.map_response(|r: ClientResponse<hyper::Body>| {
|
||||
MapErrLayer::new(ClientError::from),
|
||||
MapResponseLayer::new(|r: ClientResponse<hyper::Body>| {
|
||||
r.map(|body| body.map_err(ClientError::from).boxed())
|
||||
})
|
||||
.layer(ClientLayer::new(operation))
|
||||
.service(client)
|
||||
.boxed_clone();
|
||||
}),
|
||||
ClientLayer::new(operation),
|
||||
);
|
||||
let client = layer.layer(client).boxed_clone();
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
@ -16,16 +16,13 @@ use std::ops::RangeBounds;
|
||||
|
||||
use http::{header::HeaderName, Request, StatusCode};
|
||||
use once_cell::sync::OnceCell;
|
||||
use tower::{layer::util::Stack, Service, ServiceBuilder};
|
||||
use tower::Service;
|
||||
use tower_http::cors::CorsLayer;
|
||||
|
||||
use crate::layers::{
|
||||
body_to_bytes_response::{BodyToBytesResponse, BodyToBytesResponseLayer},
|
||||
bytes_to_body_request::{BytesToBodyRequest, BytesToBodyRequestLayer},
|
||||
catch_http_codes::{CatchHttpCodes, CatchHttpCodesLayer},
|
||||
form_urlencoded_request::{FormUrlencodedRequest, FormUrlencodedRequestLayer},
|
||||
json_request::{JsonRequest, JsonRequestLayer},
|
||||
json_response::{JsonResponse, JsonResponseLayer},
|
||||
body_to_bytes_response::BodyToBytesResponse, bytes_to_body_request::BytesToBodyRequest,
|
||||
catch_http_codes::CatchHttpCodes, form_urlencoded_request::FormUrlencodedRequest,
|
||||
json_request::JsonRequest, json_response::JsonResponse,
|
||||
};
|
||||
|
||||
static PROPAGATOR_HEADERS: OnceCell<Vec<HeaderName>> = OnceCell::new();
|
||||
@ -107,65 +104,3 @@ pub trait ServiceExt<Body>: Sized {
|
||||
}
|
||||
|
||||
impl<S, B> ServiceExt<B> for S where S: Service<Request<B>> {}
|
||||
|
||||
pub trait ServiceBuilderExt<L>: Sized {
|
||||
fn request_bytes_to_body(self) -> ServiceBuilder<Stack<BytesToBodyRequestLayer, L>>;
|
||||
fn response_body_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesResponseLayer, L>>;
|
||||
fn json_response<T>(self) -> ServiceBuilder<Stack<JsonResponseLayer<T>, L>>;
|
||||
fn json_request<T>(self) -> ServiceBuilder<Stack<JsonRequestLayer<T>, L>>;
|
||||
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>>;
|
||||
|
||||
fn catch_http_code<M>(
|
||||
self,
|
||||
status_code: StatusCode,
|
||||
mapper: M,
|
||||
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
|
||||
where
|
||||
M: Clone,
|
||||
{
|
||||
self.catch_http_codes(status_code..=status_code, mapper)
|
||||
}
|
||||
|
||||
fn catch_http_codes<B, M>(
|
||||
self,
|
||||
bounds: B,
|
||||
mapper: M,
|
||||
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
|
||||
where
|
||||
B: RangeBounds<StatusCode>,
|
||||
M: Clone;
|
||||
}
|
||||
|
||||
impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
|
||||
fn request_bytes_to_body(self) -> ServiceBuilder<Stack<BytesToBodyRequestLayer, L>> {
|
||||
self.layer(BytesToBodyRequestLayer::default())
|
||||
}
|
||||
|
||||
fn response_body_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesResponseLayer, L>> {
|
||||
self.layer(BodyToBytesResponseLayer::default())
|
||||
}
|
||||
|
||||
fn json_response<T>(self) -> ServiceBuilder<Stack<JsonResponseLayer<T>, L>> {
|
||||
self.layer(JsonResponseLayer::default())
|
||||
}
|
||||
|
||||
fn json_request<T>(self) -> ServiceBuilder<Stack<JsonRequestLayer<T>, L>> {
|
||||
self.layer(JsonRequestLayer::default())
|
||||
}
|
||||
|
||||
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>> {
|
||||
self.layer(FormUrlencodedRequestLayer::default())
|
||||
}
|
||||
|
||||
fn catch_http_codes<B, M>(
|
||||
self,
|
||||
bounds: B,
|
||||
mapper: M,
|
||||
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
|
||||
where
|
||||
B: RangeBounds<StatusCode>,
|
||||
M: Clone,
|
||||
{
|
||||
self.layer(CatchHttpCodesLayer::new(bounds, mapper))
|
||||
}
|
||||
}
|
||||
|
@ -116,15 +116,21 @@ pub struct CatchHttpCodesLayer<M> {
|
||||
mapper: M,
|
||||
}
|
||||
|
||||
impl<M> CatchHttpCodesLayer<M> {
|
||||
impl<M> CatchHttpCodesLayer<M>
|
||||
where
|
||||
M: Clone,
|
||||
{
|
||||
pub fn new<B>(bounds: B, mapper: M) -> Self
|
||||
where
|
||||
B: RangeBounds<StatusCode>,
|
||||
M: Clone,
|
||||
{
|
||||
let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
|
||||
Self { bounds, mapper }
|
||||
}
|
||||
|
||||
pub fn exact(status_code: StatusCode, mapper: M) -> Self {
|
||||
Self::new(status_code..=status_code, mapper)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, M> Layer<S> for CatchHttpCodesLayer<M>
|
||||
|
@ -17,7 +17,7 @@ use std::{marker::PhantomData, time::Duration};
|
||||
use http::{header::USER_AGENT, HeaderValue, Request, Response};
|
||||
use tower::{
|
||||
limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service,
|
||||
ServiceBuilder, ServiceExt,
|
||||
ServiceExt,
|
||||
};
|
||||
use tower_http::{
|
||||
decompression::{DecompressionBody, DecompressionLayer},
|
||||
@ -65,21 +65,19 @@ where
|
||||
// - the TimeoutLayer
|
||||
// - the DecompressionLayer
|
||||
// Those layers do type erasure of the error.
|
||||
ServiceBuilder::new()
|
||||
.layer(DecompressionLayer::new())
|
||||
.layer(SetRequestHeaderLayer::overriding(
|
||||
USER_AGENT,
|
||||
MAS_USER_AGENT.clone(),
|
||||
))
|
||||
(
|
||||
DecompressionLayer::new(),
|
||||
SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()),
|
||||
// A trace that has the whole operation, with all the redirects, timeouts and rate
|
||||
// limits in it
|
||||
.layer(TraceLayer::http_client(self.operation))
|
||||
.layer(ConcurrencyLimitLayer::new(10))
|
||||
.layer(FollowRedirectLayer::new())
|
||||
TraceLayer::http_client(self.operation),
|
||||
ConcurrencyLimitLayer::new(10),
|
||||
FollowRedirectLayer::new(),
|
||||
// A trace for each "real" http request
|
||||
.layer(TraceLayer::inner_http_client())
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(10)))
|
||||
.service(inner)
|
||||
TraceLayer::inner_http_client(),
|
||||
TimeoutLayer::new(Duration::from_secs(10)),
|
||||
)
|
||||
.layer(inner)
|
||||
.boxed_clone()
|
||||
}
|
||||
}
|
||||
|
@ -16,8 +16,8 @@ use std::marker::PhantomData;
|
||||
|
||||
use http::{Request, Response};
|
||||
use opentelemetry::KeyValue;
|
||||
use tower::{util::BoxCloneService, Layer, Service, ServiceBuilder, ServiceExt};
|
||||
use tower_http::{compression::CompressionBody, ServiceBuilderExt};
|
||||
use tower::{util::BoxCloneService, Layer, Service, ServiceExt};
|
||||
use tower_http::compression::{CompressionBody, CompressionLayer};
|
||||
|
||||
use super::otel::TraceLayer;
|
||||
|
||||
@ -49,21 +49,18 @@ where
|
||||
type Service = BoxCloneService<Request<ReqBody>, Response<CompressionBody<ResBody>>, S::Error>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
let builder = ServiceBuilder::new().compression();
|
||||
let compression = CompressionLayer::new();
|
||||
|
||||
#[cfg(feature = "axum")]
|
||||
let mut trace_layer = TraceLayer::axum();
|
||||
let mut trace = TraceLayer::axum();
|
||||
|
||||
#[cfg(not(feature = "axum"))]
|
||||
let mut trace_layer = TraceLayer::http_server();
|
||||
let mut trace = TraceLayer::http_server();
|
||||
|
||||
if let Some(name) = &self.listener_name {
|
||||
trace_layer =
|
||||
trace_layer.with_static_attribute(KeyValue::new("listener", name.clone()));
|
||||
trace = trace.with_static_attribute(KeyValue::new("listener", name.clone()));
|
||||
}
|
||||
|
||||
let builder = builder.layer(trace_layer);
|
||||
|
||||
builder.service(inner).boxed_clone()
|
||||
(compression, trace).layer(inner).boxed_clone()
|
||||
}
|
||||
}
|
||||
|
@ -32,10 +32,7 @@ mod layers;
|
||||
#[cfg(feature = "client")]
|
||||
pub use self::client::{client, make_untraced_client};
|
||||
pub use self::{
|
||||
ext::{
|
||||
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
|
||||
ServiceExt as HttpServiceExt,
|
||||
},
|
||||
ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
|
||||
layers::{
|
||||
body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer},
|
||||
bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer},
|
||||
|
@ -18,10 +18,13 @@ use anyhow::{bail, Context};
|
||||
use bytes::{Buf, Bytes};
|
||||
use headers::{ContentType, HeaderMapExt};
|
||||
use http::{header::ACCEPT, HeaderValue, Request, Response, StatusCode};
|
||||
use mas_http::HttpServiceBuilderExt;
|
||||
use mas_http::{
|
||||
BodyToBytesResponseLayer, BytesToBodyRequestLayer, CatchHttpCodesLayer,
|
||||
FormUrlencodedRequestLayer, JsonRequestLayer, JsonResponseLayer,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
use tower::{ServiceBuilder, ServiceExt};
|
||||
use tower::{service_fn, Layer, ServiceExt};
|
||||
|
||||
#[derive(Debug, Error, Deserialize)]
|
||||
#[error("Error code in response: {error}")]
|
||||
@ -42,10 +45,11 @@ async fn test_http_errors() {
|
||||
serde_json::from_reader(response.into_body().reader()).unwrap()
|
||||
}
|
||||
|
||||
let svc = ServiceBuilder::new()
|
||||
.catch_http_code(StatusCode::BAD_REQUEST, mapper)
|
||||
.response_body_to_bytes()
|
||||
.service_fn(handle);
|
||||
let layer = (
|
||||
CatchHttpCodesLayer::exact(StatusCode::BAD_REQUEST, mapper),
|
||||
BodyToBytesResponseLayer,
|
||||
);
|
||||
let svc = layer.layer(service_fn(handle));
|
||||
|
||||
let request = Request::new(hyper::Body::empty());
|
||||
|
||||
@ -79,10 +83,8 @@ async fn test_json_request_body() {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
let svc = ServiceBuilder::new()
|
||||
.json_request()
|
||||
.request_bytes_to_body()
|
||||
.service_fn(handle);
|
||||
let layer = (JsonRequestLayer::default(), BytesToBodyRequestLayer);
|
||||
let svc = layer.layer(service_fn(handle));
|
||||
|
||||
let request = Request::new(serde_json::json!({"hello": "world"}));
|
||||
|
||||
@ -106,10 +108,8 @@ async fn test_json_response_body() {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
let svc = ServiceBuilder::new()
|
||||
.json_response()
|
||||
.response_body_to_bytes()
|
||||
.service_fn(handle);
|
||||
let layer = (JsonResponseLayer::default(), BodyToBytesResponseLayer);
|
||||
let svc = layer.layer(service_fn(handle));
|
||||
|
||||
let request = Request::new(hyper::Body::empty());
|
||||
|
||||
@ -142,10 +142,11 @@ async fn test_urlencoded_request_body() {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
let svc = ServiceBuilder::new()
|
||||
.form_urlencoded_request()
|
||||
.request_bytes_to_body()
|
||||
.service_fn(handle);
|
||||
let layer = (
|
||||
FormUrlencodedRequestLayer::default(),
|
||||
BytesToBodyRequestLayer,
|
||||
);
|
||||
let svc = layer.layer(service_fn(handle));
|
||||
|
||||
let request = Request::new(serde_json::json!({"hello": "world"}));
|
||||
|
||||
|
Reference in New Issue
Block a user