// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use std::{marker::PhantomData, task::Poll}; use futures_util::future::BoxFuture; use http::{header::ACCEPT, HeaderValue, Request, Response}; use http_body::Body; use serde::de::DeserializeOwned; use thiserror::Error; use tower::{Layer, Service}; #[derive(Debug, Error)] pub enum Error { #[error(transparent)] Service { inner: Service }, #[error("failed to fully read the request body")] Body { #[source] inner: Body, }, #[error("could not parse JSON payload")] Json { #[source] inner: serde_json::Error, }, } impl Error { fn service(source: S) -> Self { Self::Service { inner: source } } fn body(source: B) -> Self { Self::Body { inner: source } } fn json(source: serde_json::Error) -> Self { Self::Json { inner: source } } } #[derive(Clone)] pub struct Json { inner: S, _t: PhantomData, } impl Json { pub const fn new(inner: S) -> Self { Self { inner, _t: PhantomData, } } } impl Service> for Json where S: Service, Response = Response>, S::Future: Send + 'static, C: Body + Send + 'static, C::Data: Send + 'static, T: DeserializeOwned, { type Error = Error; type Response = Response; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { self.inner.poll_ready(cx).map_err(Error::service) } fn call(&mut self, mut request: Request) -> Self::Future { request .headers_mut() .insert(ACCEPT, HeaderValue::from_static("application/json")); let fut = self.inner.call(request); let fut = async { let response = fut.await.map_err(Error::service)?; let (parts, body) = response.into_parts(); futures_util::pin_mut!(body); let bytes = hyper::body::to_bytes(&mut body) .await .map_err(Error::body)?; let body = serde_json::from_slice(&bytes).map_err(Error::json)?; let res = Response::from_parts(parts, body); Ok(res) }; Box::pin(fut) } } #[derive(Default, Clone, Copy)] pub struct JsonResponseLayer(PhantomData<(T, ReqBody)>); impl Layer for JsonResponseLayer where S: Service, Response = Response>, T: serde::de::DeserializeOwned, { type Service = Json; fn layer(&self, inner: S) -> Self::Service { Json::new(inner) } }