diff --git a/crates/http/src/ext.rs b/crates/http/src/ext.rs index cab40b40..034e991b 100644 --- a/crates/http/src/ext.rs +++ b/crates/http/src/ext.rs @@ -20,7 +20,8 @@ use tower::{layer::util::Stack, Service, ServiceBuilder}; use tower_http::cors::CorsLayer; use crate::layers::{ - body_to_bytes::{BodyToBytes, BodyToBytesLayer}, + body_to_bytes_response::{BodyToBytesResponse, BodyToBytesResponseLayer}, + bytes_to_body_request::{BytesToBodyRequest, BytesToBodyRequestLayer}, catch_http_codes::{CatchHttpCodes, CatchHttpCodesLayer}, form_urlencoded_request::{FormUrlencodedRequest, FormUrlencodedRequestLayer}, json_request::{JsonRequest, JsonRequestLayer}, @@ -69,8 +70,12 @@ impl CorsLayerExt for CorsLayer { } pub trait ServiceExt: Sized { - fn response_body_to_bytes(self) -> BodyToBytes { - BodyToBytes::new(self) + fn request_bytes_to_body(self) -> BytesToBodyRequest { + BytesToBodyRequest::new(self) + } + + fn response_body_to_bytes(self) -> BodyToBytesResponse { + BodyToBytesResponse::new(self) } fn json_response(self) -> JsonResponse { @@ -104,7 +109,8 @@ pub trait ServiceExt: Sized { impl ServiceExt for S where S: Service> {} pub trait ServiceBuilderExt: Sized { - fn response_body_to_bytes(self) -> ServiceBuilder>; + fn request_bytes_to_body(self) -> ServiceBuilder>; + fn response_body_to_bytes(self) -> ServiceBuilder>; fn json_response(self) -> ServiceBuilder, L>>; fn json_request(self) -> ServiceBuilder, L>>; fn form_urlencoded_request(self) -> ServiceBuilder, L>>; @@ -131,8 +137,12 @@ pub trait ServiceBuilderExt: Sized { } impl ServiceBuilderExt for ServiceBuilder { - fn response_body_to_bytes(self) -> ServiceBuilder> { - self.layer(BodyToBytesLayer::default()) + fn request_bytes_to_body(self) -> ServiceBuilder> { + self.layer(BytesToBodyRequestLayer::default()) + } + + fn response_body_to_bytes(self) -> ServiceBuilder> { + self.layer(BodyToBytesResponseLayer::default()) } fn json_response(self) -> ServiceBuilder, L>> { diff --git a/crates/http/src/layers/body_to_bytes.rs b/crates/http/src/layers/body_to_bytes_response.rs similarity index 90% rename from crates/http/src/layers/body_to_bytes.rs rename to crates/http/src/layers/body_to_bytes_response.rs index b2f833d8..648a71ed 100644 --- a/crates/http/src/layers/body_to_bytes.rs +++ b/crates/http/src/layers/body_to_bytes_response.rs @@ -39,17 +39,17 @@ impl Error { } #[derive(Clone)] -pub struct BodyToBytes { +pub struct BodyToBytesResponse { inner: S, } -impl BodyToBytes { +impl BodyToBytesResponse { pub const fn new(inner: S) -> Self { Self { inner } } } -impl Service> for BodyToBytes +impl Service> for BodyToBytesResponse where S: Service, Response = Response>, S::Future: Send + 'static, @@ -85,12 +85,12 @@ where } #[derive(Default, Clone, Copy)] -pub struct BodyToBytesLayer; +pub struct BodyToBytesResponseLayer; -impl Layer for BodyToBytesLayer { - type Service = BodyToBytes; +impl Layer for BodyToBytesResponseLayer { + type Service = BodyToBytesResponse; fn layer(&self, inner: S) -> Self::Service { - BodyToBytes::new(inner) + BodyToBytesResponse::new(inner) } } diff --git a/crates/http/src/layers/bytes_to_body_request.rs b/crates/http/src/layers/bytes_to_body_request.rs new file mode 100644 index 00000000..8c8586c2 --- /dev/null +++ b/crates/http/src/layers/bytes_to_body_request.rs @@ -0,0 +1,66 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use http::Request; +use http_body::Full; +use tower::{Layer, Service}; + +#[derive(Clone)] +pub struct BytesToBodyRequest { + inner: S, +} + +impl BytesToBodyRequest { + pub const fn new(inner: S) -> Self { + Self { inner } + } +} + +impl Service> for BytesToBodyRequest +where + S: Service>>, + S::Future: Send + 'static, +{ + type Error = S::Error; + type Response = S::Response; + type Future = S::Future; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + let (parts, body) = request.into_parts(); + let body = Full::new(body); + + let request = Request::from_parts(parts, body); + + self.inner.call(request) + } +} + +#[derive(Default, Clone, Copy)] +pub struct BytesToBodyRequestLayer; + +impl Layer for BytesToBodyRequestLayer { + type Service = BytesToBodyRequest; + + fn layer(&self, inner: S) -> Self::Service { + BytesToBodyRequest::new(inner) + } +} diff --git a/crates/http/src/layers/form_urlencoded_request.rs b/crates/http/src/layers/form_urlencoded_request.rs index 134501a2..1d414573 100644 --- a/crates/http/src/layers/form_urlencoded_request.rs +++ b/crates/http/src/layers/form_urlencoded_request.rs @@ -21,7 +21,6 @@ use futures_util::{ }; use headers::{ContentType, HeaderMapExt}; use http::Request; -use http_body::Full; use serde::Serialize; use thiserror::Error; use tower::{Layer, Service}; @@ -65,7 +64,7 @@ impl FormUrlencodedRequest { impl Service> for FormUrlencodedRequest where - S: Service>>, + S: Service>, S::Future: Send + 'static, S::Error: 'static, T: Serialize, @@ -87,7 +86,7 @@ where parts.headers.typed_insert(ContentType::form_url_encoded()); let body = match serde_urlencoded::to_string(&body) { - Ok(body) => Full::new(Bytes::from(body)), + Ok(body) => Bytes::from(body), Err(err) => return std::future::ready(Err(Error::serialize(err))).left_future(), }; diff --git a/crates/http/src/layers/json_request.rs b/crates/http/src/layers/json_request.rs index 5fea858c..52d2fb3f 100644 --- a/crates/http/src/layers/json_request.rs +++ b/crates/http/src/layers/json_request.rs @@ -21,7 +21,6 @@ use futures_util::{ }; use headers::{ContentType, HeaderMapExt}; use http::Request; -use http_body::Full; use serde::Serialize; use thiserror::Error; use tower::{Layer, Service}; @@ -65,7 +64,7 @@ impl JsonRequest { impl Service> for JsonRequest where - S: Service>>, + S: Service>, S::Future: Send + 'static, S::Error: 'static, T: Serialize, @@ -87,7 +86,7 @@ where parts.headers.typed_insert(ContentType::json()); let body = match serde_json::to_vec(&body) { - Ok(body) => Full::new(Bytes::from(body)), + Ok(body) => Bytes::from(body), Err(err) => return std::future::ready(Err(Error::serialize(err))).left_future(), }; diff --git a/crates/http/src/layers/mod.rs b/crates/http/src/layers/mod.rs index 09dae847..a1e08fcc 100644 --- a/crates/http/src/layers/mod.rs +++ b/crates/http/src/layers/mod.rs @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub mod body_to_bytes; +pub mod body_to_bytes_response; +pub mod bytes_to_body_request; pub mod catch_http_codes; pub mod form_urlencoded_request; pub mod json_request; diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index f0e3a7cb..ce00c50a 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -39,7 +39,8 @@ pub use self::{ }, future_service::FutureService, layers::{ - body_to_bytes::{self, BodyToBytes, BodyToBytesLayer}, + body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer}, + bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer}, catch_http_codes::{self, CatchHttpCodes, CatchHttpCodesLayer}, client::ClientLayer, form_urlencoded_request::{self, FormUrlencodedRequest, FormUrlencodedRequestLayer}, diff --git a/crates/http/tests/client_layers.rs b/crates/http/tests/client_layers.rs index e68104d1..e0ecde3d 100644 --- a/crates/http/tests/client_layers.rs +++ b/crates/http/tests/client_layers.rs @@ -79,7 +79,10 @@ async fn test_json_request_body() { Ok(res) } - let svc = ServiceBuilder::new().json_request().service_fn(handle); + let svc = ServiceBuilder::new() + .json_request() + .request_bytes_to_body() + .service_fn(handle); let request = Request::new(serde_json::json!({"hello": "world"})); @@ -141,6 +144,7 @@ async fn test_urlencoded_request_body() { let svc = ServiceBuilder::new() .form_urlencoded_request() + .request_bytes_to_body() .service_fn(handle); let request = Request::new(serde_json::json!({"hello": "world"}));