// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use std::ops::{Bound, RangeBounds}; use futures_util::FutureExt; use http::{Request, Response, StatusCode}; use thiserror::Error; use tower::{Layer, Service}; #[derive(Debug, Error)] pub enum Error { #[error(transparent)] Service { inner: S }, #[error("request failed with status {status_code}")] HttpError { status_code: StatusCode, #[source] inner: E, }, } impl Error { fn service(inner: S) -> Self { Self::Service { inner } } pub fn status_code(&self) -> Option { match self { Self::Service { .. } => None, Self::HttpError { status_code, .. } => Some(*status_code), } } } #[derive(Clone)] pub struct CatchHttpCodes { inner: S, bounds: (Bound, Bound), mapper: M, } impl CatchHttpCodes { pub fn new(inner: S, bounds: B, mapper: M) -> Self where B: RangeBounds, M: Clone, { let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned()); Self { inner, bounds, mapper, } } } impl Service> for CatchHttpCodes where S: Service, Response = Response>, S::Future: Send + 'static, M: Fn(Response) -> E + Send + Clone + 'static, { type Error = Error; type Response = Response; type Future = futures_util::future::Map< S::Future, Box< dyn Fn(Result) -> Result + Send + 'static, >, >; fn poll_ready( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { self.inner.poll_ready(cx).map_err(Error::service) } fn call(&mut self, request: Request) -> Self::Future { let fut = self.inner.call(request); let bounds = self.bounds; let mapper = self.mapper.clone(); fut.map(Box::new(move |res: Result| { let response = res.map_err(Error::service)?; let status_code = response.status(); if bounds.contains(&status_code) { let inner = mapper(response); Err(Error::HttpError { status_code, inner }) } else { Ok(response) } })) } } #[derive(Clone)] pub struct CatchHttpCodesLayer { bounds: (Bound, Bound), mapper: M, } impl CatchHttpCodesLayer where M: Clone, { pub fn new(bounds: B, mapper: M) -> Self where B: RangeBounds, { let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned()); Self { bounds, mapper } } pub fn exact(status_code: StatusCode, mapper: M) -> Self { Self::new(status_code..=status_code, mapper) } } impl Layer for CatchHttpCodesLayer where M: Clone, { type Service = CatchHttpCodes; fn layer(&self, inner: S) -> Self::Service { CatchHttpCodes::new(inner, self.bounds, self.mapper.clone()) } }