You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Use new tuple Layer impls instead of ServiceBuilder (#475)
Co-authored-by: Quentin Gliech <quenting@element.io>
This commit is contained in:
@ -26,7 +26,10 @@ use hyper::{
|
|||||||
};
|
};
|
||||||
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tower::{util::BoxCloneService, Service, ServiceBuilder, ServiceExt};
|
use tower::{
|
||||||
|
util::{BoxCloneService, MapErrLayer, MapResponseLayer},
|
||||||
|
Layer, Service, ServiceExt,
|
||||||
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
layers::{
|
layers::{
|
||||||
@ -172,9 +175,7 @@ where
|
|||||||
E: Into<BoxError>,
|
E: Into<BoxError>,
|
||||||
{
|
{
|
||||||
// Trace DNS requests
|
// Trace DNS requests
|
||||||
let resolver = ServiceBuilder::new()
|
let resolver = TraceLayer::dns().layer(GaiResolver::new());
|
||||||
.layer(TraceLayer::dns())
|
|
||||||
.service(GaiResolver::new());
|
|
||||||
|
|
||||||
let roots = tls_roots().await?;
|
let roots = tls_roots().await?;
|
||||||
let tls_config = rustls::ClientConfig::builder()
|
let tls_config = rustls::ClientConfig::builder()
|
||||||
@ -228,15 +229,15 @@ where
|
|||||||
{
|
{
|
||||||
let client = make_base_client().await?;
|
let client = make_base_client().await?;
|
||||||
|
|
||||||
let client = ServiceBuilder::new()
|
let layer = (
|
||||||
// Convert the errors to ClientError to help dealing with them
|
// Convert the errors to ClientError to help dealing with them
|
||||||
.map_err(ClientError::from)
|
MapErrLayer::new(ClientError::from),
|
||||||
.map_response(|r: ClientResponse<hyper::Body>| {
|
MapResponseLayer::new(|r: ClientResponse<hyper::Body>| {
|
||||||
r.map(|body| body.map_err(ClientError::from).boxed())
|
r.map(|body| body.map_err(ClientError::from).boxed())
|
||||||
})
|
}),
|
||||||
.layer(ClientLayer::new(operation))
|
ClientLayer::new(operation),
|
||||||
.service(client)
|
);
|
||||||
.boxed_clone();
|
let client = layer.layer(client).boxed_clone();
|
||||||
|
|
||||||
Ok(client)
|
Ok(client)
|
||||||
}
|
}
|
||||||
|
@ -16,16 +16,13 @@ use std::ops::RangeBounds;
|
|||||||
|
|
||||||
use http::{header::HeaderName, Request, StatusCode};
|
use http::{header::HeaderName, Request, StatusCode};
|
||||||
use once_cell::sync::OnceCell;
|
use once_cell::sync::OnceCell;
|
||||||
use tower::{layer::util::Stack, Service, ServiceBuilder};
|
use tower::Service;
|
||||||
use tower_http::cors::CorsLayer;
|
use tower_http::cors::CorsLayer;
|
||||||
|
|
||||||
use crate::layers::{
|
use crate::layers::{
|
||||||
body_to_bytes_response::{BodyToBytesResponse, BodyToBytesResponseLayer},
|
body_to_bytes_response::BodyToBytesResponse, bytes_to_body_request::BytesToBodyRequest,
|
||||||
bytes_to_body_request::{BytesToBodyRequest, BytesToBodyRequestLayer},
|
catch_http_codes::CatchHttpCodes, form_urlencoded_request::FormUrlencodedRequest,
|
||||||
catch_http_codes::{CatchHttpCodes, CatchHttpCodesLayer},
|
json_request::JsonRequest, json_response::JsonResponse,
|
||||||
form_urlencoded_request::{FormUrlencodedRequest, FormUrlencodedRequestLayer},
|
|
||||||
json_request::{JsonRequest, JsonRequestLayer},
|
|
||||||
json_response::{JsonResponse, JsonResponseLayer},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static PROPAGATOR_HEADERS: OnceCell<Vec<HeaderName>> = OnceCell::new();
|
static PROPAGATOR_HEADERS: OnceCell<Vec<HeaderName>> = OnceCell::new();
|
||||||
@ -107,65 +104,3 @@ pub trait ServiceExt<Body>: Sized {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<S, B> ServiceExt<B> for S where S: Service<Request<B>> {}
|
impl<S, B> ServiceExt<B> for S where S: Service<Request<B>> {}
|
||||||
|
|
||||||
pub trait ServiceBuilderExt<L>: Sized {
|
|
||||||
fn request_bytes_to_body(self) -> ServiceBuilder<Stack<BytesToBodyRequestLayer, L>>;
|
|
||||||
fn response_body_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesResponseLayer, L>>;
|
|
||||||
fn json_response<T>(self) -> ServiceBuilder<Stack<JsonResponseLayer<T>, L>>;
|
|
||||||
fn json_request<T>(self) -> ServiceBuilder<Stack<JsonRequestLayer<T>, L>>;
|
|
||||||
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>>;
|
|
||||||
|
|
||||||
fn catch_http_code<M>(
|
|
||||||
self,
|
|
||||||
status_code: StatusCode,
|
|
||||||
mapper: M,
|
|
||||||
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
|
|
||||||
where
|
|
||||||
M: Clone,
|
|
||||||
{
|
|
||||||
self.catch_http_codes(status_code..=status_code, mapper)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn catch_http_codes<B, M>(
|
|
||||||
self,
|
|
||||||
bounds: B,
|
|
||||||
mapper: M,
|
|
||||||
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
|
|
||||||
where
|
|
||||||
B: RangeBounds<StatusCode>,
|
|
||||||
M: Clone;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
|
|
||||||
fn request_bytes_to_body(self) -> ServiceBuilder<Stack<BytesToBodyRequestLayer, L>> {
|
|
||||||
self.layer(BytesToBodyRequestLayer::default())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn response_body_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesResponseLayer, L>> {
|
|
||||||
self.layer(BodyToBytesResponseLayer::default())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn json_response<T>(self) -> ServiceBuilder<Stack<JsonResponseLayer<T>, L>> {
|
|
||||||
self.layer(JsonResponseLayer::default())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn json_request<T>(self) -> ServiceBuilder<Stack<JsonRequestLayer<T>, L>> {
|
|
||||||
self.layer(JsonRequestLayer::default())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>> {
|
|
||||||
self.layer(FormUrlencodedRequestLayer::default())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn catch_http_codes<B, M>(
|
|
||||||
self,
|
|
||||||
bounds: B,
|
|
||||||
mapper: M,
|
|
||||||
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
|
|
||||||
where
|
|
||||||
B: RangeBounds<StatusCode>,
|
|
||||||
M: Clone,
|
|
||||||
{
|
|
||||||
self.layer(CatchHttpCodesLayer::new(bounds, mapper))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -116,15 +116,21 @@ pub struct CatchHttpCodesLayer<M> {
|
|||||||
mapper: M,
|
mapper: M,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M> CatchHttpCodesLayer<M> {
|
impl<M> CatchHttpCodesLayer<M>
|
||||||
|
where
|
||||||
|
M: Clone,
|
||||||
|
{
|
||||||
pub fn new<B>(bounds: B, mapper: M) -> Self
|
pub fn new<B>(bounds: B, mapper: M) -> Self
|
||||||
where
|
where
|
||||||
B: RangeBounds<StatusCode>,
|
B: RangeBounds<StatusCode>,
|
||||||
M: Clone,
|
|
||||||
{
|
{
|
||||||
let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
|
let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
|
||||||
Self { bounds, mapper }
|
Self { bounds, mapper }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn exact(status_code: StatusCode, mapper: M) -> Self {
|
||||||
|
Self::new(status_code..=status_code, mapper)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, M> Layer<S> for CatchHttpCodesLayer<M>
|
impl<S, M> Layer<S> for CatchHttpCodesLayer<M>
|
||||||
|
@ -17,7 +17,7 @@ use std::{marker::PhantomData, time::Duration};
|
|||||||
use http::{header::USER_AGENT, HeaderValue, Request, Response};
|
use http::{header::USER_AGENT, HeaderValue, Request, Response};
|
||||||
use tower::{
|
use tower::{
|
||||||
limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service,
|
limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::BoxCloneService, Layer, Service,
|
||||||
ServiceBuilder, ServiceExt,
|
ServiceExt,
|
||||||
};
|
};
|
||||||
use tower_http::{
|
use tower_http::{
|
||||||
decompression::{DecompressionBody, DecompressionLayer},
|
decompression::{DecompressionBody, DecompressionLayer},
|
||||||
@ -65,21 +65,19 @@ where
|
|||||||
// - the TimeoutLayer
|
// - the TimeoutLayer
|
||||||
// - the DecompressionLayer
|
// - the DecompressionLayer
|
||||||
// Those layers do type erasure of the error.
|
// Those layers do type erasure of the error.
|
||||||
ServiceBuilder::new()
|
(
|
||||||
.layer(DecompressionLayer::new())
|
DecompressionLayer::new(),
|
||||||
.layer(SetRequestHeaderLayer::overriding(
|
SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()),
|
||||||
USER_AGENT,
|
|
||||||
MAS_USER_AGENT.clone(),
|
|
||||||
))
|
|
||||||
// A trace that has the whole operation, with all the redirects, timeouts and rate
|
// A trace that has the whole operation, with all the redirects, timeouts and rate
|
||||||
// limits in it
|
// limits in it
|
||||||
.layer(TraceLayer::http_client(self.operation))
|
TraceLayer::http_client(self.operation),
|
||||||
.layer(ConcurrencyLimitLayer::new(10))
|
ConcurrencyLimitLayer::new(10),
|
||||||
.layer(FollowRedirectLayer::new())
|
FollowRedirectLayer::new(),
|
||||||
// A trace for each "real" http request
|
// A trace for each "real" http request
|
||||||
.layer(TraceLayer::inner_http_client())
|
TraceLayer::inner_http_client(),
|
||||||
.layer(TimeoutLayer::new(Duration::from_secs(10)))
|
TimeoutLayer::new(Duration::from_secs(10)),
|
||||||
.service(inner)
|
)
|
||||||
|
.layer(inner)
|
||||||
.boxed_clone()
|
.boxed_clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,8 +16,8 @@ use std::marker::PhantomData;
|
|||||||
|
|
||||||
use http::{Request, Response};
|
use http::{Request, Response};
|
||||||
use opentelemetry::KeyValue;
|
use opentelemetry::KeyValue;
|
||||||
use tower::{util::BoxCloneService, Layer, Service, ServiceBuilder, ServiceExt};
|
use tower::{util::BoxCloneService, Layer, Service, ServiceExt};
|
||||||
use tower_http::{compression::CompressionBody, ServiceBuilderExt};
|
use tower_http::compression::{CompressionBody, CompressionLayer};
|
||||||
|
|
||||||
use super::otel::TraceLayer;
|
use super::otel::TraceLayer;
|
||||||
|
|
||||||
@ -49,21 +49,18 @@ where
|
|||||||
type Service = BoxCloneService<Request<ReqBody>, Response<CompressionBody<ResBody>>, S::Error>;
|
type Service = BoxCloneService<Request<ReqBody>, Response<CompressionBody<ResBody>>, S::Error>;
|
||||||
|
|
||||||
fn layer(&self, inner: S) -> Self::Service {
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
let builder = ServiceBuilder::new().compression();
|
let compression = CompressionLayer::new();
|
||||||
|
|
||||||
#[cfg(feature = "axum")]
|
#[cfg(feature = "axum")]
|
||||||
let mut trace_layer = TraceLayer::axum();
|
let mut trace = TraceLayer::axum();
|
||||||
|
|
||||||
#[cfg(not(feature = "axum"))]
|
#[cfg(not(feature = "axum"))]
|
||||||
let mut trace_layer = TraceLayer::http_server();
|
let mut trace = TraceLayer::http_server();
|
||||||
|
|
||||||
if let Some(name) = &self.listener_name {
|
if let Some(name) = &self.listener_name {
|
||||||
trace_layer =
|
trace = trace.with_static_attribute(KeyValue::new("listener", name.clone()));
|
||||||
trace_layer.with_static_attribute(KeyValue::new("listener", name.clone()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let builder = builder.layer(trace_layer);
|
(compression, trace).layer(inner).boxed_clone()
|
||||||
|
|
||||||
builder.service(inner).boxed_clone()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -32,10 +32,7 @@ mod layers;
|
|||||||
#[cfg(feature = "client")]
|
#[cfg(feature = "client")]
|
||||||
pub use self::client::{client, make_untraced_client};
|
pub use self::client::{client, make_untraced_client};
|
||||||
pub use self::{
|
pub use self::{
|
||||||
ext::{
|
ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
|
||||||
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
|
|
||||||
ServiceExt as HttpServiceExt,
|
|
||||||
},
|
|
||||||
layers::{
|
layers::{
|
||||||
body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer},
|
body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer},
|
||||||
bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer},
|
bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer},
|
||||||
|
@ -18,10 +18,13 @@ use anyhow::{bail, Context};
|
|||||||
use bytes::{Buf, Bytes};
|
use bytes::{Buf, Bytes};
|
||||||
use headers::{ContentType, HeaderMapExt};
|
use headers::{ContentType, HeaderMapExt};
|
||||||
use http::{header::ACCEPT, HeaderValue, Request, Response, StatusCode};
|
use http::{header::ACCEPT, HeaderValue, Request, Response, StatusCode};
|
||||||
use mas_http::HttpServiceBuilderExt;
|
use mas_http::{
|
||||||
|
BodyToBytesResponseLayer, BytesToBodyRequestLayer, CatchHttpCodesLayer,
|
||||||
|
FormUrlencodedRequestLayer, JsonRequestLayer, JsonResponseLayer,
|
||||||
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tower::{ServiceBuilder, ServiceExt};
|
use tower::{service_fn, Layer, ServiceExt};
|
||||||
|
|
||||||
#[derive(Debug, Error, Deserialize)]
|
#[derive(Debug, Error, Deserialize)]
|
||||||
#[error("Error code in response: {error}")]
|
#[error("Error code in response: {error}")]
|
||||||
@ -42,10 +45,11 @@ async fn test_http_errors() {
|
|||||||
serde_json::from_reader(response.into_body().reader()).unwrap()
|
serde_json::from_reader(response.into_body().reader()).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
let svc = ServiceBuilder::new()
|
let layer = (
|
||||||
.catch_http_code(StatusCode::BAD_REQUEST, mapper)
|
CatchHttpCodesLayer::exact(StatusCode::BAD_REQUEST, mapper),
|
||||||
.response_body_to_bytes()
|
BodyToBytesResponseLayer,
|
||||||
.service_fn(handle);
|
);
|
||||||
|
let svc = layer.layer(service_fn(handle));
|
||||||
|
|
||||||
let request = Request::new(hyper::Body::empty());
|
let request = Request::new(hyper::Body::empty());
|
||||||
|
|
||||||
@ -79,10 +83,8 @@ async fn test_json_request_body() {
|
|||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
let svc = ServiceBuilder::new()
|
let layer = (JsonRequestLayer::default(), BytesToBodyRequestLayer);
|
||||||
.json_request()
|
let svc = layer.layer(service_fn(handle));
|
||||||
.request_bytes_to_body()
|
|
||||||
.service_fn(handle);
|
|
||||||
|
|
||||||
let request = Request::new(serde_json::json!({"hello": "world"}));
|
let request = Request::new(serde_json::json!({"hello": "world"}));
|
||||||
|
|
||||||
@ -106,10 +108,8 @@ async fn test_json_response_body() {
|
|||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
let svc = ServiceBuilder::new()
|
let layer = (JsonResponseLayer::default(), BodyToBytesResponseLayer);
|
||||||
.json_response()
|
let svc = layer.layer(service_fn(handle));
|
||||||
.response_body_to_bytes()
|
|
||||||
.service_fn(handle);
|
|
||||||
|
|
||||||
let request = Request::new(hyper::Body::empty());
|
let request = Request::new(hyper::Body::empty());
|
||||||
|
|
||||||
@ -142,10 +142,11 @@ async fn test_urlencoded_request_body() {
|
|||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
let svc = ServiceBuilder::new()
|
let layer = (
|
||||||
.form_urlencoded_request()
|
FormUrlencodedRequestLayer::default(),
|
||||||
.request_bytes_to_body()
|
BytesToBodyRequestLayer,
|
||||||
.service_fn(handle);
|
);
|
||||||
|
let svc = layer.layer(service_fn(handle));
|
||||||
|
|
||||||
let request = Request::new(serde_json::json!({"hello": "world"}));
|
let request = Request::new(serde_json::json!({"hello": "world"}));
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user