You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-21 23:00:50 +03:00
Add a layer to catch HTTP error codes
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2432,6 +2432,7 @@ dependencies = [
|
|||||||
name = "mas-http"
|
name = "mas-http"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"axum",
|
"axum",
|
||||||
"bytes 1.2.1",
|
"bytes 1.2.1",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
|||||||
@@ -28,3 +28,9 @@ tower = { version = "0.4.13", features = ["timeout", "limit"] }
|
|||||||
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors"] }
|
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors"] }
|
||||||
tracing = "0.1.36"
|
tracing = "0.1.36"
|
||||||
tracing-opentelemetry = "0.17.4"
|
tracing-opentelemetry = "0.17.4"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
anyhow = "1.0.62"
|
||||||
|
serde = { version = "1.0.142", features = ["derive"] }
|
||||||
|
tokio = { version = "1.20.1", features = ["macros"] }
|
||||||
|
tower = { version = "0.4.13", features = ["util"] }
|
||||||
|
|||||||
@@ -12,13 +12,16 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use http::header::HeaderName;
|
use std::ops::RangeBounds;
|
||||||
|
|
||||||
|
use http::{header::HeaderName, Request, StatusCode};
|
||||||
use once_cell::sync::OnceCell;
|
use once_cell::sync::OnceCell;
|
||||||
use tower::{layer::util::Stack, ServiceBuilder};
|
use tower::{layer::util::Stack, Service, ServiceBuilder};
|
||||||
use tower_http::cors::CorsLayer;
|
use tower_http::cors::CorsLayer;
|
||||||
|
|
||||||
use crate::layers::{
|
use crate::layers::{
|
||||||
body_to_bytes::{BodyToBytes, BodyToBytesLayer},
|
body_to_bytes::{BodyToBytes, BodyToBytesLayer},
|
||||||
|
catch_http_codes::{CatchHttpCodes, CatchHttpCodesLayer},
|
||||||
form_urlencoded_request::{FormUrlencodedRequest, FormUrlencodedRequestLayer},
|
form_urlencoded_request::{FormUrlencodedRequest, FormUrlencodedRequestLayer},
|
||||||
json_request::{JsonRequest, JsonRequestLayer},
|
json_request::{JsonRequest, JsonRequestLayer},
|
||||||
json_response::{JsonResponse, JsonResponseLayer},
|
json_response::{JsonResponse, JsonResponseLayer},
|
||||||
@@ -65,7 +68,7 @@ impl CorsLayerExt for CorsLayer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait ServiceExt: Sized {
|
pub trait ServiceExt<Body>: Sized {
|
||||||
fn response_body_to_bytes(self) -> BodyToBytes<Self> {
|
fn response_body_to_bytes(self) -> BodyToBytes<Self> {
|
||||||
BodyToBytes::new(self)
|
BodyToBytes::new(self)
|
||||||
}
|
}
|
||||||
@@ -81,19 +84,54 @@ pub trait ServiceExt: Sized {
|
|||||||
fn form_urlencoded_request<T>(self) -> FormUrlencodedRequest<Self, T> {
|
fn form_urlencoded_request<T>(self) -> FormUrlencodedRequest<Self, T> {
|
||||||
FormUrlencodedRequest::new(self)
|
FormUrlencodedRequest::new(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn catch_http_code<M>(self, status_code: StatusCode, mapper: M) -> CatchHttpCodes<Self, M>
|
||||||
|
where
|
||||||
|
M: Clone,
|
||||||
|
{
|
||||||
|
self.catch_http_codes(status_code..=status_code, mapper)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn catch_http_codes<B, M>(self, bounds: B, mapper: M) -> CatchHttpCodes<Self, M>
|
||||||
|
where
|
||||||
|
B: RangeBounds<StatusCode>,
|
||||||
|
M: Clone,
|
||||||
|
{
|
||||||
|
CatchHttpCodes::new(self, bounds, mapper)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> ServiceExt for S {}
|
impl<S, B> ServiceExt<B> for S where S: Service<Request<B>> {}
|
||||||
|
|
||||||
pub trait ServiceBuilderExt<L>: Sized {
|
pub trait ServiceBuilderExt<L>: Sized {
|
||||||
fn response_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesLayer, L>>;
|
fn response_body_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesLayer, L>>;
|
||||||
fn json_response<T>(self) -> ServiceBuilder<Stack<JsonResponseLayer<T>, L>>;
|
fn json_response<T>(self) -> ServiceBuilder<Stack<JsonResponseLayer<T>, L>>;
|
||||||
fn json_request<T>(self) -> ServiceBuilder<Stack<JsonRequestLayer<T>, L>>;
|
fn json_request<T>(self) -> ServiceBuilder<Stack<JsonRequestLayer<T>, L>>;
|
||||||
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>>;
|
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>>;
|
||||||
|
|
||||||
|
fn catch_http_code<M>(
|
||||||
|
self,
|
||||||
|
status_code: StatusCode,
|
||||||
|
mapper: M,
|
||||||
|
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
|
||||||
|
where
|
||||||
|
M: Clone,
|
||||||
|
{
|
||||||
|
self.catch_http_codes(status_code..=status_code, mapper)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn catch_http_codes<B, M>(
|
||||||
|
self,
|
||||||
|
bounds: B,
|
||||||
|
mapper: M,
|
||||||
|
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
|
||||||
|
where
|
||||||
|
B: RangeBounds<StatusCode>,
|
||||||
|
M: Clone;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
|
impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
|
||||||
fn response_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesLayer, L>> {
|
fn response_body_to_bytes(self) -> ServiceBuilder<Stack<BodyToBytesLayer, L>> {
|
||||||
self.layer(BodyToBytesLayer::default())
|
self.layer(BodyToBytesLayer::default())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,4 +146,16 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
|
|||||||
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>> {
|
fn form_urlencoded_request<T>(self) -> ServiceBuilder<Stack<FormUrlencodedRequestLayer<T>, L>> {
|
||||||
self.layer(FormUrlencodedRequestLayer::default())
|
self.layer(FormUrlencodedRequestLayer::default())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn catch_http_codes<B, M>(
|
||||||
|
self,
|
||||||
|
bounds: B,
|
||||||
|
mapper: M,
|
||||||
|
) -> ServiceBuilder<Stack<CatchHttpCodesLayer<M>, L>>
|
||||||
|
where
|
||||||
|
B: RangeBounds<StatusCode>,
|
||||||
|
M: Clone,
|
||||||
|
{
|
||||||
|
self.layer(CatchHttpCodesLayer::new(bounds, mapper))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
138
crates/http/src/layers/catch_http_codes.rs
Normal file
138
crates/http/src/layers/catch_http_codes.rs
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::ops::{Bound, RangeBounds};
|
||||||
|
|
||||||
|
use futures_util::FutureExt;
|
||||||
|
use http::{Request, Response, StatusCode};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tower::{Layer, Service};
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum Error<S, E> {
|
||||||
|
#[error(transparent)]
|
||||||
|
Service { inner: S },
|
||||||
|
|
||||||
|
#[error("request failed with status {status_code}")]
|
||||||
|
HttpError {
|
||||||
|
status_code: StatusCode,
|
||||||
|
#[source]
|
||||||
|
inner: E,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, E> Error<S, E> {
|
||||||
|
fn service(inner: S) -> Self {
|
||||||
|
Self::Service { inner }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn status_code(&self) -> Option<StatusCode> {
|
||||||
|
match self {
|
||||||
|
Self::Service { .. } => None,
|
||||||
|
Self::HttpError { status_code, .. } => Some(*status_code),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CatchHttpCodes<S, M> {
|
||||||
|
inner: S,
|
||||||
|
bounds: (Bound<StatusCode>, Bound<StatusCode>),
|
||||||
|
mapper: M,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, M> CatchHttpCodes<S, M> {
|
||||||
|
pub fn new<B>(inner: S, bounds: B, mapper: M) -> Self
|
||||||
|
where
|
||||||
|
B: RangeBounds<StatusCode>,
|
||||||
|
M: Clone,
|
||||||
|
{
|
||||||
|
let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
bounds,
|
||||||
|
mapper,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, M, E, ReqBody, ResBody> Service<Request<ReqBody>> for CatchHttpCodes<S, M>
|
||||||
|
where
|
||||||
|
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
|
||||||
|
{
|
||||||
|
type Error = Error<S::Error, E>;
|
||||||
|
type Response = Response<ResBody>;
|
||||||
|
type Future = futures_util::future::Map<
|
||||||
|
S::Future,
|
||||||
|
Box<
|
||||||
|
dyn Fn(Result<S::Response, S::Error>) -> Result<Self::Response, Self::Error>
|
||||||
|
+ Send
|
||||||
|
+ 'static,
|
||||||
|
>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
fn poll_ready(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||||
|
self.inner.poll_ready(cx).map_err(Error::service)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
|
||||||
|
let fut = self.inner.call(request);
|
||||||
|
let bounds = self.bounds;
|
||||||
|
let mapper = self.mapper.clone();
|
||||||
|
|
||||||
|
fut.map(Box::new(move |res: Result<S::Response, S::Error>| {
|
||||||
|
let response = res.map_err(Error::service)?;
|
||||||
|
let status_code = response.status();
|
||||||
|
|
||||||
|
if bounds.contains(&status_code) {
|
||||||
|
let inner = mapper(response);
|
||||||
|
Err(Error::HttpError { status_code, inner })
|
||||||
|
} else {
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct CatchHttpCodesLayer<M> {
|
||||||
|
bounds: (Bound<StatusCode>, Bound<StatusCode>),
|
||||||
|
mapper: M,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M> CatchHttpCodesLayer<M> {
|
||||||
|
pub fn new<B>(bounds: B, mapper: M) -> Self
|
||||||
|
where
|
||||||
|
B: RangeBounds<StatusCode>,
|
||||||
|
M: Clone,
|
||||||
|
{
|
||||||
|
let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
|
||||||
|
Self { bounds, mapper }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, M> Layer<S> for CatchHttpCodesLayer<M>
|
||||||
|
where
|
||||||
|
M: Clone,
|
||||||
|
{
|
||||||
|
type Service = CatchHttpCodes<S, M>;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
CatchHttpCodes::new(inner, self.bounds, self.mapper.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,10 +12,12 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
pub(crate) mod body_to_bytes;
|
pub mod body_to_bytes;
|
||||||
pub(crate) mod client;
|
pub mod catch_http_codes;
|
||||||
pub(crate) mod form_urlencoded_request;
|
pub mod form_urlencoded_request;
|
||||||
pub(crate) mod json_request;
|
pub mod json_request;
|
||||||
pub(crate) mod json_response;
|
pub mod json_response;
|
||||||
pub mod otel;
|
pub mod otel;
|
||||||
|
|
||||||
|
pub(crate) mod client;
|
||||||
pub(crate) mod server;
|
pub(crate) mod server;
|
||||||
|
|||||||
@@ -35,24 +35,34 @@ use hyper::{
|
|||||||
Client,
|
Client,
|
||||||
};
|
};
|
||||||
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
|
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
|
||||||
use layers::{
|
|
||||||
client::ClientResponse,
|
|
||||||
otel::{TraceDns, TraceLayer},
|
|
||||||
};
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::{sync::OnceCell, task::JoinError};
|
use tokio::{sync::OnceCell, task::JoinError};
|
||||||
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
|
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
|
||||||
|
|
||||||
|
use self::layers::{
|
||||||
|
client::ClientResponse,
|
||||||
|
otel::{TraceDns, TraceLayer},
|
||||||
|
};
|
||||||
|
|
||||||
mod ext;
|
mod ext;
|
||||||
mod future_service;
|
mod future_service;
|
||||||
mod layers;
|
mod layers;
|
||||||
|
|
||||||
pub use self::{
|
pub use self::{
|
||||||
ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
|
ext::{
|
||||||
|
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
|
||||||
|
ServiceExt as HttpServiceExt,
|
||||||
|
},
|
||||||
future_service::FutureService,
|
future_service::FutureService,
|
||||||
layers::{
|
layers::{
|
||||||
body_to_bytes::BodyToBytesLayer, client::ClientLayer, json_request::JsonRequestLayer,
|
body_to_bytes::{self, BodyToBytes, BodyToBytesLayer},
|
||||||
json_response::JsonResponseLayer, otel, server::ServerLayer,
|
catch_http_codes::{self, CatchHttpCodes, CatchHttpCodesLayer},
|
||||||
|
client::ClientLayer,
|
||||||
|
form_urlencoded_request::{self, FormUrlencodedRequest, FormUrlencodedRequestLayer},
|
||||||
|
json_request::{self, JsonRequest, JsonRequestLayer},
|
||||||
|
json_response::{self, JsonResponse, JsonResponseLayer},
|
||||||
|
otel,
|
||||||
|
server::ServerLayer,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
150
crates/http/tests/client_layers.rs
Normal file
150
crates/http/tests/client_layers.rs
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::convert::Infallible;
|
||||||
|
|
||||||
|
use anyhow::{bail, Context};
|
||||||
|
use bytes::{Buf, Bytes};
|
||||||
|
use headers::{ContentType, HeaderMapExt};
|
||||||
|
use http::{header::ACCEPT, HeaderValue, Request, Response, StatusCode};
|
||||||
|
use mas_http::HttpServiceBuilderExt;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tower::{ServiceBuilder, ServiceExt};
|
||||||
|
|
||||||
|
#[derive(Debug, Error, Deserialize)]
|
||||||
|
#[error("Error code in response: {error}")]
|
||||||
|
struct Error {
|
||||||
|
error: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_http_errors() {
|
||||||
|
async fn handle<B>(_request: Request<B>) -> Result<Response<String>, Infallible> {
|
||||||
|
let mut res = Response::new(r#"{"error": "invalid_request"}"#.to_owned());
|
||||||
|
*res.status_mut() = StatusCode::BAD_REQUEST;
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mapper(response: Response<Bytes>) -> Error {
|
||||||
|
serde_json::from_reader(response.into_body().reader()).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
let svc = ServiceBuilder::new()
|
||||||
|
.catch_http_code(StatusCode::BAD_REQUEST, mapper)
|
||||||
|
.response_body_to_bytes()
|
||||||
|
.service_fn(handle);
|
||||||
|
|
||||||
|
let request = Request::new(hyper::Body::empty());
|
||||||
|
|
||||||
|
let res = svc.oneshot(request).await;
|
||||||
|
let err = res.expect_err("the request should fail");
|
||||||
|
assert_eq!(err.status_code(), Some(StatusCode::BAD_REQUEST));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_json_request_body() {
|
||||||
|
async fn handle<B>(request: Request<B>) -> Result<Response<hyper::Body>, anyhow::Error>
|
||||||
|
where
|
||||||
|
B: http_body::Body + Send,
|
||||||
|
B::Error: std::error::Error + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
if request
|
||||||
|
.headers()
|
||||||
|
.typed_get::<ContentType>()
|
||||||
|
.context("Missing Content-Type header")?
|
||||||
|
!= ContentType::json()
|
||||||
|
{
|
||||||
|
bail!("Content-Type header is not application/json")
|
||||||
|
}
|
||||||
|
|
||||||
|
let bytes = hyper::body::to_bytes(request.into_body()).await?;
|
||||||
|
if bytes.to_vec() != br#"{"hello":"world"}"#.to_vec() {
|
||||||
|
bail!("Body mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
let res = Response::new(hyper::Body::empty());
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
let svc = ServiceBuilder::new().json_request().service_fn(handle);
|
||||||
|
|
||||||
|
let request = Request::new(serde_json::json!({"hello": "world"}));
|
||||||
|
|
||||||
|
let res = svc.oneshot(request).await;
|
||||||
|
res.expect("the request should succeed");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_json_response_body() {
|
||||||
|
async fn handle<B>(request: Request<B>) -> Result<Response<String>, anyhow::Error> {
|
||||||
|
if request
|
||||||
|
.headers()
|
||||||
|
.get(ACCEPT)
|
||||||
|
.context("Missing Accept header")?
|
||||||
|
!= HeaderValue::from_static("application/json")
|
||||||
|
{
|
||||||
|
bail!("Accept header is not application/json")
|
||||||
|
}
|
||||||
|
|
||||||
|
let res = Response::new(r#"{"hello": "world"}"#.to_owned());
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
let svc = ServiceBuilder::new()
|
||||||
|
.json_response()
|
||||||
|
.response_body_to_bytes()
|
||||||
|
.service_fn(handle);
|
||||||
|
|
||||||
|
let request = Request::new(hyper::Body::empty());
|
||||||
|
|
||||||
|
let res = svc.oneshot(request).await;
|
||||||
|
let response = res.expect("the request to succeed");
|
||||||
|
let body: serde_json::Value = response.into_body();
|
||||||
|
assert_eq!(body, serde_json::json!({"hello": "world"}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_urlencoded_request_body() {
|
||||||
|
async fn handle<B>(request: Request<B>) -> Result<Response<hyper::Body>, anyhow::Error>
|
||||||
|
where
|
||||||
|
B: http_body::Body + Send,
|
||||||
|
B::Error: std::error::Error + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
if request
|
||||||
|
.headers()
|
||||||
|
.typed_get::<ContentType>()
|
||||||
|
.context("Missing Content-Type header")?
|
||||||
|
!= ContentType::form_url_encoded()
|
||||||
|
{
|
||||||
|
bail!("Content-Type header is not application/x-form-urlencoded")
|
||||||
|
}
|
||||||
|
|
||||||
|
let bytes = hyper::body::to_bytes(request.into_body()).await?;
|
||||||
|
assert_eq!(bytes.to_vec(), br#"hello=world"#.to_vec());
|
||||||
|
|
||||||
|
let res = Response::new(hyper::Body::empty());
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
let svc = ServiceBuilder::new()
|
||||||
|
.form_urlencoded_request()
|
||||||
|
.service_fn(handle);
|
||||||
|
|
||||||
|
let request = Request::new(serde_json::json!({"hello": "world"}));
|
||||||
|
|
||||||
|
let res = svc.oneshot(request).await;
|
||||||
|
res.expect("the request to succeed");
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user