From d94442f972ece0be4ff1a548a60fa2cb220e9746 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 12 Aug 2022 16:49:54 +0200 Subject: [PATCH] Layer to application/x-www-form-urlencoded bodies --- Cargo.lock | 1 + crates/http/Cargo.toml | 1 + crates/http/src/ext.rs | 10 ++ .../src/layers/form_urlencoded_request.rs | 122 ++++++++++++++++++ crates/http/src/layers/json_request.rs | 7 +- crates/http/src/layers/json_response.rs | 8 +- crates/http/src/layers/mod.rs | 1 + 7 files changed, 142 insertions(+), 8 deletions(-) create mode 100644 crates/http/src/layers/form_urlencoded_request.rs diff --git a/Cargo.lock b/Cargo.lock index 80981493..80c01bd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2447,6 +2447,7 @@ dependencies = [ "rustls 0.20.6", "serde", "serde_json", + "serde_urlencoded", "thiserror", "tokio", "tower", diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index 331ef679..b6f51559 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -21,6 +21,7 @@ opentelemetry-semantic-conventions = "0.9.0" rustls = "0.20.6" serde = "1.0.142" serde_json = "1.0.83" +serde_urlencoded = "0.7.1" thiserror = "1.0.32" tokio = { version = "1.20.1", features = ["sync", "parking_lot"] } tower = { version = "0.4.13", features = ["timeout", "limit"] } diff --git a/crates/http/src/ext.rs b/crates/http/src/ext.rs index 7862f53a..49d4f519 100644 --- a/crates/http/src/ext.rs +++ b/crates/http/src/ext.rs @@ -19,6 +19,7 @@ use tower_http::cors::CorsLayer; use crate::layers::{ body_to_bytes::{BodyToBytes, BodyToBytesLayer}, + form_urlencoded_request::{FormUrlencodedRequest, FormUrlencodedRequestLayer}, json_request::{JsonRequest, JsonRequestLayer}, json_response::{JsonResponse, JsonResponseLayer}, }; @@ -76,6 +77,10 @@ pub trait ServiceExt: Sized { fn json_request(self) -> JsonRequest { JsonRequest::new(self) } + + fn form_urlencoded_request(self) -> FormUrlencodedRequest { + FormUrlencodedRequest::new(self) + } } impl ServiceExt for S {} @@ -84,6 +89,7 @@ pub trait ServiceBuilderExt: Sized { fn response_to_bytes(self) -> ServiceBuilder>; fn json_response(self) -> ServiceBuilder, L>>; fn json_request(self) -> ServiceBuilder, L>>; + fn form_urlencoded_request(self) -> ServiceBuilder, L>>; } impl ServiceBuilderExt for ServiceBuilder { @@ -98,4 +104,8 @@ impl ServiceBuilderExt for ServiceBuilder { fn json_request(self) -> ServiceBuilder, L>> { self.layer(JsonRequestLayer::default()) } + + fn form_urlencoded_request(self) -> ServiceBuilder, L>> { + self.layer(FormUrlencodedRequestLayer::default()) + } } diff --git a/crates/http/src/layers/form_urlencoded_request.rs b/crates/http/src/layers/form_urlencoded_request.rs new file mode 100644 index 00000000..134501a2 --- /dev/null +++ b/crates/http/src/layers/form_urlencoded_request.rs @@ -0,0 +1,122 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{future::Ready, marker::PhantomData, task::Poll}; + +use bytes::Bytes; +use futures_util::{ + future::{Either, MapErr}, + FutureExt, TryFutureExt, +}; +use headers::{ContentType, HeaderMapExt}; +use http::Request; +use http_body::Full; +use serde::Serialize; +use thiserror::Error; +use tower::{Layer, Service}; + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Service { inner: Service }, + + #[error("could not serialize form payload")] + Serialize { + #[source] + inner: serde_urlencoded::ser::Error, + }, +} + +impl Error { + fn service(source: S) -> Self { + Self::Service { inner: source } + } + + fn serialize(source: serde_urlencoded::ser::Error) -> Self { + Self::Serialize { inner: source } + } +} + +#[derive(Clone)] +pub struct FormUrlencodedRequest { + inner: S, + _t: PhantomData, +} + +impl FormUrlencodedRequest { + pub const fn new(inner: S) -> Self { + Self { + inner, + _t: PhantomData, + } + } +} + +impl Service> for FormUrlencodedRequest +where + S: Service>>, + S::Future: Send + 'static, + S::Error: 'static, + T: Serialize, +{ + type Error = Error; + type Response = S::Response; + type Future = Either< + Ready>, + MapErr Self::Error>, + >; + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Error::service) + } + + fn call(&mut self, request: Request) -> Self::Future { + let (mut parts, body) = request.into_parts(); + + parts.headers.typed_insert(ContentType::form_url_encoded()); + + let body = match serde_urlencoded::to_string(&body) { + Ok(body) => Full::new(Bytes::from(body)), + Err(err) => return std::future::ready(Err(Error::serialize(err))).left_future(), + }; + + let request = Request::from_parts(parts, body); + + self.inner + .call(request) + .map_err(Error::service as fn(S::Error) -> Self::Error) + .right_future() + } +} + +#[derive(Clone, Copy)] +pub struct FormUrlencodedRequestLayer { + _t: PhantomData, +} + +impl Default for FormUrlencodedRequestLayer { + fn default() -> Self { + Self { + _t: PhantomData::default(), + } + } +} + +impl Layer for FormUrlencodedRequestLayer { + type Service = FormUrlencodedRequest; + + fn layer(&self, inner: S) -> Self::Service { + FormUrlencodedRequest::new(inner) + } +} diff --git a/crates/http/src/layers/json_request.rs b/crates/http/src/layers/json_request.rs index 74f15705..ad8e1cc3 100644 --- a/crates/http/src/layers/json_request.rs +++ b/crates/http/src/layers/json_request.rs @@ -19,7 +19,8 @@ use futures_util::{ future::{Either, MapErr}, FutureExt, TryFutureExt, }; -use http::{header::CONTENT_TYPE, HeaderValue, Request}; +use headers::{ContentType, HeaderMapExt}; +use http::Request; use http_body::Full; use serde::Serialize; use thiserror::Error; @@ -83,9 +84,7 @@ where fn call(&mut self, request: Request) -> Self::Future { let (mut parts, body) = request.into_parts(); - parts - .headers - .insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + parts.headers.typed_insert(ContentType::json()); let body = match serde_json::to_vec(&body) { Ok(body) => Full::new(Bytes::from(body)), diff --git a/crates/http/src/layers/json_response.rs b/crates/http/src/layers/json_response.rs index c0f565c2..e7213527 100644 --- a/crates/http/src/layers/json_response.rs +++ b/crates/http/src/layers/json_response.rs @@ -27,7 +27,7 @@ pub enum Error { Service { inner: Service }, #[error("could not parse JSON payload")] - Json { + Serialize { #[source] inner: serde_json::Error, }, @@ -38,8 +38,8 @@ impl Error { Self::Service { inner: source } } - fn json(source: serde_json::Error) -> Self { - Self::Json { inner: source } + fn serialize(source: serde_json::Error) -> Self { + Self::Serialize { inner: source } } } @@ -85,7 +85,7 @@ where let response = res.map_err(Error::service)?; let (parts, body) = response.into_parts(); - let body = serde_json::from_reader(body.reader()).map_err(Error::json)?; + let body = serde_json::from_reader(body.reader()).map_err(Error::serialize)?; let res = Response::from_parts(parts, body); Ok(res) diff --git a/crates/http/src/layers/mod.rs b/crates/http/src/layers/mod.rs index 8eefbb64..5f853743 100644 --- a/crates/http/src/layers/mod.rs +++ b/crates/http/src/layers/mod.rs @@ -14,6 +14,7 @@ pub(crate) mod body_to_bytes; pub(crate) mod client; +pub(crate) mod form_urlencoded_request; pub(crate) mod json_request; pub(crate) mod json_response; pub mod otel;