1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-06 05:42:30 +03:00
Files
authentication-service/crates/http/tests/client_layers.rs
2022-10-17 16:48:12 +02:00

156 lines
4.9 KiB
Rust

// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::Infallible;
use anyhow::{bail, Context};
use bytes::{Buf, Bytes};
use headers::{ContentType, HeaderMapExt};
use http::{header::ACCEPT, HeaderValue, Request, Response, StatusCode};
use mas_http::{
BodyToBytesResponseLayer, BytesToBodyRequestLayer, CatchHttpCodesLayer,
FormUrlencodedRequestLayer, JsonRequestLayer, JsonResponseLayer,
};
use serde::Deserialize;
use thiserror::Error;
use tower::{service_fn, Layer, ServiceExt};
/// Error payload used by `test_http_errors`: the response body is parsed as
/// JSON into this type by that test's `mapper` function.
///
/// The `Display` output is `Error code in response: {error}`.
#[derive(Debug, Error, Deserialize)]
#[error("Error code in response: {error}")]
struct Error {
// Populated from the `error` key of the deserialized JSON document.
error: String,
}
#[tokio::test]
async fn test_http_errors() {
async fn handle<B>(_request: Request<B>) -> Result<Response<String>, Infallible> {
let mut res = Response::new(r#"{"error": "invalid_request"}"#.to_owned());
*res.status_mut() = StatusCode::BAD_REQUEST;
Ok(res)
}
fn mapper(response: Response<Bytes>) -> Error {
serde_json::from_reader(response.into_body().reader()).unwrap()
}
let layer = (
CatchHttpCodesLayer::exact(StatusCode::BAD_REQUEST, mapper),
BodyToBytesResponseLayer,
);
let svc = layer.layer(service_fn(handle));
let request = Request::new(hyper::Body::empty());
let res = svc.oneshot(request).await;
let err = res.expect_err("the request should fail");
assert_eq!(err.status_code(), Some(StatusCode::BAD_REQUEST));
}
/// Sends a `serde_json::Value` through `JsonRequestLayer` stacked on
/// `BytesToBodyRequestLayer` and lets the inner service verify what it
/// receives: a JSON `Content-Type` header and the serialized body bytes.
#[tokio::test]
async fn test_json_request_body() {
    // Inner service: fails unless the request carries the expected header
    // and body, otherwise answers with an empty response.
    async fn handle<B>(request: Request<B>) -> Result<Response<hyper::Body>, anyhow::Error>
    where
        B: http_body::Body + Send,
        B::Error: std::error::Error + Send + Sync + 'static,
    {
        let content_type = request
            .headers()
            .typed_get::<ContentType>()
            .context("Missing Content-Type header")?;
        if content_type != ContentType::json() {
            bail!("Content-Type header is not application/json")
        }

        let body = hyper::body::to_bytes(request.into_body()).await?;
        if body[..] != br#"{"hello":"world"}"#[..] {
            bail!("Body mismatch")
        }

        Ok(Response::new(hyper::Body::empty()))
    }

    let stack = (JsonRequestLayer::default(), BytesToBodyRequestLayer);
    let service = stack.layer(service_fn(handle));

    let payload = serde_json::json!({"hello": "world"});
    let outcome = service.oneshot(Request::new(payload)).await;

    outcome.expect("the request should succeed");
}
/// Sends a request through `JsonResponseLayer` stacked on
/// `BodyToBytesResponseLayer`; the inner service checks for an
/// `Accept: application/json` header and the test checks that the response
/// body comes back as a parsed `serde_json::Value`.
#[tokio::test]
async fn test_json_response_body() {
    // Inner service: fails unless the `Accept` header is `application/json`,
    // otherwise answers with a JSON document as a plain string body.
    async fn handle<B>(request: Request<B>) -> Result<Response<String>, anyhow::Error> {
        let accept = request
            .headers()
            .get(ACCEPT)
            .context("Missing Accept header")?;
        if accept != HeaderValue::from_static("application/json") {
            bail!("Accept header is not application/json")
        }

        Ok(Response::new(r#"{"hello": "world"}"#.to_owned()))
    }

    let stack = (JsonResponseLayer::default(), BodyToBytesResponseLayer);
    let service = stack.layer(service_fn(handle));

    let outcome = service
        .oneshot(Request::new(hyper::Body::empty()))
        .await;

    let reply = outcome.expect("the request to succeed");
    let parsed: serde_json::Value = reply.into_body();
    assert_eq!(parsed, serde_json::json!({"hello": "world"}));
}
/// Sends a `serde_json::Value` through `FormUrlencodedRequestLayer` stacked
/// on `BytesToBodyRequestLayer` and lets the inner service verify what it
/// receives: a form-urlencoded `Content-Type` header and the body
/// `hello=world`.
#[tokio::test]
async fn test_urlencoded_request_body() {
    // Inner service: fails unless the request carries the expected header,
    // asserts on the body bytes, then answers with an empty response.
    async fn handle<B>(request: Request<B>) -> Result<Response<hyper::Body>, anyhow::Error>
    where
        B: http_body::Body + Send,
        B::Error: std::error::Error + Send + Sync + 'static,
    {
        let content_type = request
            .headers()
            .typed_get::<ContentType>()
            .context("Missing Content-Type header")?;
        if content_type != ContentType::form_url_encoded() {
            // The message names the media type actually being checked
            // (it previously said `application/x-form-urlencoded`).
            bail!("Content-Type header is not application/x-www-form-urlencoded")
        }

        let bytes = hyper::body::to_bytes(request.into_body()).await?;
        // Compare as slices; no need to copy both sides into `Vec`s.
        assert_eq!(&bytes[..], b"hello=world");

        Ok(Response::new(hyper::Body::empty()))
    }

    let layer = (
        FormUrlencodedRequestLayer::default(),
        BytesToBodyRequestLayer,
    );
    let svc = layer.layer(service_fn(handle));

    let request = Request::new(serde_json::json!({"hello": "world"}));
    let res = svc.oneshot(request).await;

    // Same wording as the JSON request test's success expectation.
    res.expect("the request should succeed");
}