You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
Many improvements to the mas-http crate
- make `mas_http::client` implement Service directly instead of being an async function - a Get layer that makes a Service<Uri> - better error sources in the JSON layer - make the client have a proper error type
This commit is contained in:
@@ -6,7 +6,6 @@ edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.53"
|
||||
bytes = "1.1.0"
|
||||
futures-util = "0.3.21"
|
||||
http = "0.2.6"
|
||||
@@ -15,7 +14,6 @@ hyper = "0.14.16"
|
||||
hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] }
|
||||
opentelemetry = "0.17.0"
|
||||
opentelemetry-http = "0.6.0"
|
||||
pin-project-lite = "0.2.8"
|
||||
rustls = "0.20.2"
|
||||
serde = "1.0.136"
|
||||
serde_json = "1.0.78"
|
||||
|
@@ -12,19 +12,20 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::layers::json::Json;
|
||||
use crate::layers::{get::Get, json::Json};
|
||||
|
||||
pub trait ServiceExt {
|
||||
fn json<T>(self) -> Json<Self, T>
|
||||
where
|
||||
Self: Sized;
|
||||
pub trait ServiceExt: Sized {
|
||||
fn json<T>(self) -> Json<Self, T>;
|
||||
|
||||
fn get(self) -> Get<Self>;
|
||||
}
|
||||
|
||||
impl<S> ServiceExt for S {
|
||||
fn json<T>(self) -> Json<Self, T>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
fn json<T>(self) -> Json<Self, T> {
|
||||
Json::new(self)
|
||||
}
|
||||
|
||||
fn get(self) -> Get<Self> {
|
||||
Get::new(self)
|
||||
}
|
||||
}
|
||||
|
77
crates/http/src/future_service.rs
Normal file
77
crates/http/src/future_service.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! A copy of [`tower::util::FutureService`] that also maps the future error to
|
||||
//! help implementing [`Clone`] on the service
|
||||
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures_util::ready;
|
||||
use tower::Service;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct FutureService<F, S> {
|
||||
state: State<F, S>,
|
||||
}
|
||||
|
||||
impl<F, S> FutureService<F, S> {
|
||||
#[must_use]
|
||||
pub fn new(future: F) -> Self {
|
||||
Self {
|
||||
state: State::Future(future),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum State<F, S> {
|
||||
Future(F),
|
||||
Service(S),
|
||||
}
|
||||
|
||||
impl<F, S, R, FE, E> Service<R> for FutureService<F, S>
|
||||
where
|
||||
F: Future<Output = Result<S, FE>> + Unpin,
|
||||
S: Service<R, Error = E>,
|
||||
E: From<FE>,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = E;
|
||||
type Future = S::Future;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
loop {
|
||||
self.state = match &mut self.state {
|
||||
State::Future(fut) => {
|
||||
let fut = Pin::new(fut);
|
||||
let svc = ready!(fut.poll(cx)?);
|
||||
State::Service(svc)
|
||||
}
|
||||
State::Service(svc) => return svc.poll_ready(cx),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn call(&mut self, req: R) -> Self::Future {
|
||||
if let State::Service(svc) = &mut self.state {
|
||||
svc.call(req)
|
||||
} else {
|
||||
panic!("FutureService::call was called before FutureService::poll_ready")
|
||||
}
|
||||
}
|
||||
}
|
@@ -32,7 +32,7 @@ use super::trace::OtelTraceLayer;
|
||||
static MAS_USER_AGENT: HeaderValue =
|
||||
HeaderValue::from_static("matrix-authentication-service/0.0.1");
|
||||
|
||||
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
|
||||
type BoxError = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClientLayer<ReqBody> {
|
||||
@@ -41,6 +41,7 @@ pub struct ClientLayer<ReqBody> {
|
||||
}
|
||||
|
||||
impl<B> ClientLayer<B> {
|
||||
#[must_use]
|
||||
pub fn new(operation: &'static str) -> Self {
|
||||
Self {
|
||||
operation,
|
||||
@@ -65,6 +66,13 @@ where
|
||||
type Service = BoxCloneService<Request<ReqBody>, ClientResponse<ResBody>, BoxError>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
// Note that most layers here just forward the error type. Two notables
|
||||
// exceptions are:
|
||||
// - the TimeoutLayer
|
||||
// - the DecompressionLayer
|
||||
// Those layers do type erasure of the error.
|
||||
// The body is also type-erased because of the DecompressionLayer.
|
||||
|
||||
ServiceBuilder::new()
|
||||
.layer(DecompressionLayer::new())
|
||||
.map_response(|r: Response<_>| r.map(BoxBody::new))
|
||||
@@ -85,7 +93,7 @@ where
|
||||
let cx = tracing::Span::current().context();
|
||||
let mut injector = opentelemetry_http::HeaderInjector(r.headers_mut());
|
||||
opentelemetry::global::get_text_map_propagator(|propagator| {
|
||||
propagator.inject_context(&cx, &mut injector)
|
||||
propagator.inject_context(&cx, &mut injector);
|
||||
});
|
||||
|
||||
r
|
||||
|
66
crates/http/src/layers/get.rs
Normal file
66
crates/http/src/layers/get.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use http::{Request, Uri};
|
||||
use tower::{Layer, Service};
|
||||
|
||||
pub struct Get<S> {
|
||||
inner: S,
|
||||
}
|
||||
|
||||
impl<S> Get<S> {
|
||||
pub const fn new(inner: S) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Service<Uri> for Get<S>
|
||||
where
|
||||
S: Service<Request<http_body::Empty<()>>>,
|
||||
{
|
||||
type Error = S::Error;
|
||||
type Response = S::Response;
|
||||
type Future = S::Future;
|
||||
|
||||
fn poll_ready(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Uri) -> Self::Future {
|
||||
let body = http_body::Empty::new();
|
||||
let req = Request::builder()
|
||||
.method("GET")
|
||||
.uri(req)
|
||||
.body(body)
|
||||
.unwrap();
|
||||
self.inner.call(req)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Copy)]
|
||||
pub struct GetLayer;
|
||||
|
||||
impl<S> Layer<S> for GetLayer
|
||||
where
|
||||
S: Service<Request<http_body::Empty<()>>>,
|
||||
{
|
||||
type Service = Get<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
Get::new(inner)
|
||||
}
|
||||
}
|
@@ -23,12 +23,20 @@ use tower::{Layer, Service};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error<Service, Body> {
|
||||
#[error("service")]
|
||||
#[error(transparent)]
|
||||
Service { inner: Service },
|
||||
#[error("body")]
|
||||
Body { inner: Body },
|
||||
#[error("json")]
|
||||
Json { inner: serde_json::Error },
|
||||
|
||||
#[error("failed to fully read the request body")]
|
||||
Body {
|
||||
#[source]
|
||||
inner: Body,
|
||||
},
|
||||
|
||||
#[error("could not parse JSON payload")]
|
||||
Json {
|
||||
#[source]
|
||||
inner: serde_json::Error,
|
||||
},
|
||||
}
|
||||
|
||||
impl<S, B> Error<S, B> {
|
||||
@@ -75,15 +83,16 @@ where
|
||||
self.inner.poll_ready(cx).map_err(Error::service)
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||
req.headers_mut()
|
||||
fn call(&mut self, mut request: Request<B>) -> Self::Future {
|
||||
request
|
||||
.headers_mut()
|
||||
.insert(ACCEPT, HeaderValue::from_static("application/json"));
|
||||
|
||||
let fut = self.inner.call(req);
|
||||
let fut = self.inner.call(request);
|
||||
|
||||
let fut = async {
|
||||
let res = fut.await.map_err(Error::service)?;
|
||||
let (parts, body) = res.into_parts();
|
||||
let response = fut.await.map_err(Error::service)?;
|
||||
let (parts, body) = response.into_parts();
|
||||
|
||||
futures_util::pin_mut!(body);
|
||||
let bytes = hyper::body::to_bytes(&mut body)
|
||||
|
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
pub(crate) mod client;
|
||||
pub(crate) mod get;
|
||||
pub(crate) mod json;
|
||||
pub(crate) mod server;
|
||||
pub(crate) mod trace;
|
||||
|
@@ -12,38 +12,68 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! [`tower`] layers and services to help building HTTP client and servers
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(
|
||||
clippy::all,
|
||||
rustdoc::missing_crate_level_docs,
|
||||
rustdoc::broken_intra_doc_links
|
||||
)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures_util::{FutureExt, TryFutureExt};
|
||||
use http::{Request, Response};
|
||||
use http_body::Body;
|
||||
use http_body::{combinators::BoxBody, Body};
|
||||
use hyper::{client::HttpConnector, Client};
|
||||
use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder};
|
||||
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
|
||||
use layers::client::ClientResponse;
|
||||
use tokio::sync::OnceCell;
|
||||
use thiserror::Error;
|
||||
use tokio::{sync::OnceCell, task::JoinError};
|
||||
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
|
||||
|
||||
mod ext;
|
||||
mod future_service;
|
||||
mod layers;
|
||||
|
||||
pub use self::{
|
||||
ext::ServiceExt as HttpServiceExt,
|
||||
future_service::FutureService,
|
||||
layers::{client::ClientLayer, json::JsonResponseLayer, server::ServerLayer},
|
||||
};
|
||||
|
||||
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
|
||||
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
/// A wrapper over a boxed error that implements ``std::error::Error``.
|
||||
/// This is helps converting to ``anyhow::Error`` with the `?` operator
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ClientError {
|
||||
#[error("failed to initialize HTTPS client")]
|
||||
Init(#[from] ClientInitError),
|
||||
|
||||
#[error(transparent)]
|
||||
Call(#[from] BoxError),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientInitError {
|
||||
#[error("failed to load system certificates")]
|
||||
CertificateLoad {
|
||||
#[from]
|
||||
inner: Arc<JoinError>, // That error is in an Arc to have the error implement Clone
|
||||
},
|
||||
}
|
||||
|
||||
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
|
||||
|
||||
pub async fn client<B, E>(
|
||||
operation: &'static str,
|
||||
) -> anyhow::Result<
|
||||
BoxCloneService<
|
||||
Request<B>,
|
||||
Response<impl http_body::Body<Data = bytes::Bytes, Error = anyhow::Error>>,
|
||||
anyhow::Error,
|
||||
>,
|
||||
>
|
||||
async fn make_base_client<B, E>(
|
||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector>, B>, ClientInitError>
|
||||
where
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||
E: Into<BoxError>,
|
||||
{
|
||||
// TODO: we could probably hook a tracing DNS resolver there
|
||||
@@ -64,7 +94,8 @@ where
|
||||
})
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
.await
|
||||
.map_err(|e| ClientInitError::from(Arc::new(e)))?;
|
||||
|
||||
let https = HttpsConnectorBuilder::new()
|
||||
.with_tls_config(tls_config.clone())
|
||||
@@ -76,15 +107,33 @@ where
|
||||
// TODO: we should get the remote address here
|
||||
let client = Client::builder().build(https);
|
||||
|
||||
Ok::<_, ClientInitError>(client)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn client<B, E: 'static>(
|
||||
operation: &'static str,
|
||||
) -> BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>
|
||||
where
|
||||
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
||||
E: Into<BoxError>,
|
||||
{
|
||||
let fut = make_base_client()
|
||||
// Map the error to a ClientError
|
||||
.map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e))))
|
||||
// Wrap it in an Shared (Arc) to be able to Clone it
|
||||
.shared();
|
||||
|
||||
let client: FutureService<_, _> = FutureService::new(fut);
|
||||
|
||||
let client = ServiceBuilder::new()
|
||||
// Convert the errors to anyhow::Error for convenience
|
||||
.map_err(|e: BoxError| anyhow::anyhow!(e))
|
||||
// Convert the errors to ClientError to help dealing with them
|
||||
.map_err(ClientError::from)
|
||||
.map_response(|r: ClientResponse<hyper::Body>| {
|
||||
r.map(|body| body.map_err(|e: BoxError| anyhow::anyhow!(e)))
|
||||
r.map(|body| body.map_err(ClientError::from).boxed())
|
||||
})
|
||||
.layer(ClientLayer::new(operation))
|
||||
.service(client)
|
||||
.boxed_clone();
|
||||
.service(client);
|
||||
|
||||
Ok(client)
|
||||
client.boxed_clone()
|
||||
}
|
||||
|
Reference in New Issue
Block a user