1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Many improvements to the mas-http crate

- make `mas_http::client` implement Service directly instead of being
   an async function
 - a Get layer that makes a Service<Uri>
 - better error sources in the JSON layer
 - make the client have a proper error type
This commit is contained in:
Quentin Gliech
2022-02-15 08:28:25 +01:00
parent 497a3e006e
commit c5858e6ed5
10 changed files with 260 additions and 53 deletions

12
Cargo.lock generated
View File

@ -1549,9 +1549,9 @@ checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]]
name = "httparse"
version = "1.5.1"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acd94fdbe1d4ff688b67b04eee2e17bd50995534a61539e45adfefb45e5e5503"
checksum = "9100414882e15fb7feccb4897e5f0ff0ff1ca7d1a86a23208ada4d7a18e6c6c4"
[[package]]
name = "httpdate"
@ -1567,9 +1567,9 @@ checksum = "02296996cb8796d7c6e3bc2d9211b7802812d36999a51bb754123ead7d37d026"
[[package]]
name = "hyper"
version = "0.14.16"
version = "0.14.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7ec3e62bdc98a2f0393a5048e4c30ef659440ea6e0e572965103e72bd836f55"
checksum = "043f0e083e9901b6cc658a77d1eb86f4fc650bbb977a4337dd63192826aa85dd"
dependencies = [
"bytes 1.1.0",
"futures-channel",
@ -1580,7 +1580,7 @@ dependencies = [
"http-body",
"httparse",
"httpdate",
"itoa 0.4.8",
"itoa 1.0.1",
"pin-project-lite",
"socket2",
"tokio",
@ -1993,7 +1993,6 @@ dependencies = [
name = "mas-http"
version = "0.1.0"
dependencies = [
"anyhow",
"bytes 1.1.0",
"futures-util",
"http",
@ -2002,7 +2001,6 @@ dependencies = [
"hyper-rustls 0.23.0",
"opentelemetry",
"opentelemetry-http",
"pin-project-lite",
"rustls 0.20.2",
"serde",
"serde_json",

View File

@ -65,7 +65,7 @@ impl Options {
json: false,
url,
} => {
let mut client = mas_http::client("cli-debug-http").await?;
let mut client = mas_http::client("cli-debug-http");
let request = hyper::Request::builder()
.uri(url)
.body(hyper::Body::empty())?;
@ -89,7 +89,7 @@ impl Options {
json: true,
url,
} => {
let mut client = mas_http::client("cli-debug-http").await?.json();
let mut client = mas_http::client("cli-debug-http").json();
let request = hyper::Request::builder()
.uri(url)
.body(hyper::Body::empty())?;

View File

@ -6,7 +6,6 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
anyhow = "1.0.53"
bytes = "1.1.0"
futures-util = "0.3.21"
http = "0.2.6"
@ -15,7 +14,6 @@ hyper = "0.14.16"
hyper-rustls = { version = "0.23.0", features = ["http1", "http2"] }
opentelemetry = "0.17.0"
opentelemetry-http = "0.6.0"
pin-project-lite = "0.2.8"
rustls = "0.20.2"
serde = "1.0.136"
serde_json = "1.0.78"

View File

@ -12,19 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::layers::json::Json;
use crate::layers::{get::Get, json::Json};
pub trait ServiceExt {
fn json<T>(self) -> Json<Self, T>
where
Self: Sized;
pub trait ServiceExt: Sized {
fn json<T>(self) -> Json<Self, T>;
fn get(self) -> Get<Self>;
}
impl<S> ServiceExt for S {
fn json<T>(self) -> Json<Self, T>
where
Self: Sized,
{
fn json<T>(self) -> Json<Self, T> {
Json::new(self)
}
fn get(self) -> Get<Self> {
Get::new(self)
}
}

View File

@ -0,0 +1,77 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A copy of [`tower::util::FutureService`] that also maps the future error to
//! help implementing [`Clone`] on the service
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use futures_util::ready;
use tower::Service;
#[derive(Clone, Debug)]
pub struct FutureService<F, S> {
state: State<F, S>,
}
impl<F, S> FutureService<F, S> {
#[must_use]
pub fn new(future: F) -> Self {
Self {
state: State::Future(future),
}
}
}
#[derive(Clone, Debug)]
enum State<F, S> {
Future(F),
Service(S),
}
impl<F, S, R, FE, E> Service<R> for FutureService<F, S>
where
F: Future<Output = Result<S, FE>> + Unpin,
S: Service<R, Error = E>,
E: From<FE>,
{
type Response = S::Response;
type Error = E;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
loop {
self.state = match &mut self.state {
State::Future(fut) => {
let fut = Pin::new(fut);
let svc = ready!(fut.poll(cx)?);
State::Service(svc)
}
State::Service(svc) => return svc.poll_ready(cx),
};
}
}
fn call(&mut self, req: R) -> Self::Future {
if let State::Service(svc) = &mut self.state {
svc.call(req)
} else {
panic!("FutureService::call was called before FutureService::poll_ready")
}
}
}

View File

@ -32,7 +32,7 @@ use super::trace::OtelTraceLayer;
static MAS_USER_AGENT: HeaderValue =
HeaderValue::from_static("matrix-authentication-service/0.0.1");
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
type BoxError = Box<dyn std::error::Error + Send + Sync>;
#[derive(Debug, Clone)]
pub struct ClientLayer<ReqBody> {
@ -41,6 +41,7 @@ pub struct ClientLayer<ReqBody> {
}
impl<B> ClientLayer<B> {
#[must_use]
pub fn new(operation: &'static str) -> Self {
Self {
operation,
@ -65,6 +66,13 @@ where
type Service = BoxCloneService<Request<ReqBody>, ClientResponse<ResBody>, BoxError>;
fn layer(&self, inner: S) -> Self::Service {
// Note that most layers here just forward the error type. Two notables
// exceptions are:
// - the TimeoutLayer
// - the DecompressionLayer
// Those layers do type erasure of the error.
// The body is also type-erased because of the DecompressionLayer.
ServiceBuilder::new()
.layer(DecompressionLayer::new())
.map_response(|r: Response<_>| r.map(BoxBody::new))
@ -85,7 +93,7 @@ where
let cx = tracing::Span::current().context();
let mut injector = opentelemetry_http::HeaderInjector(r.headers_mut());
opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.inject_context(&cx, &mut injector)
propagator.inject_context(&cx, &mut injector);
});
r

View File

@ -0,0 +1,66 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use http::{Request, Uri};
use tower::{Layer, Service};
pub struct Get<S> {
inner: S,
}
impl<S> Get<S> {
pub const fn new(inner: S) -> Self {
Self { inner }
}
}
impl<S> Service<Uri> for Get<S>
where
S: Service<Request<http_body::Empty<()>>>,
{
type Error = S::Error;
type Response = S::Response;
type Future = S::Future;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Uri) -> Self::Future {
let body = http_body::Empty::new();
let req = Request::builder()
.method("GET")
.uri(req)
.body(body)
.unwrap();
self.inner.call(req)
}
}
#[derive(Default, Clone, Copy)]
pub struct GetLayer;
impl<S> Layer<S> for GetLayer
where
S: Service<Request<http_body::Empty<()>>>,
{
type Service = Get<S>;
fn layer(&self, inner: S) -> Self::Service {
Get::new(inner)
}
}

View File

@ -23,12 +23,20 @@ use tower::{Layer, Service};
#[derive(Debug, Error)]
pub enum Error<Service, Body> {
#[error("service")]
#[error(transparent)]
Service { inner: Service },
#[error("body")]
Body { inner: Body },
#[error("json")]
Json { inner: serde_json::Error },
#[error("failed to fully read the request body")]
Body {
#[source]
inner: Body,
},
#[error("could not parse JSON payload")]
Json {
#[source]
inner: serde_json::Error,
},
}
impl<S, B> Error<S, B> {
@ -75,15 +83,16 @@ where
self.inner.poll_ready(cx).map_err(Error::service)
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
req.headers_mut()
fn call(&mut self, mut request: Request<B>) -> Self::Future {
request
.headers_mut()
.insert(ACCEPT, HeaderValue::from_static("application/json"));
let fut = self.inner.call(req);
let fut = self.inner.call(request);
let fut = async {
let res = fut.await.map_err(Error::service)?;
let (parts, body) = res.into_parts();
let response = fut.await.map_err(Error::service)?;
let (parts, body) = response.into_parts();
futures_util::pin_mut!(body);
let bytes = hyper::body::to_bytes(&mut body)

View File

@ -13,6 +13,7 @@
// limitations under the License.
pub(crate) mod client;
pub(crate) mod get;
pub(crate) mod json;
pub(crate) mod server;
pub(crate) mod trace;

View File

@ -12,38 +12,68 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! [`tower`] layers and services to help building HTTP client and servers
#![forbid(unsafe_code)]
#![deny(
clippy::all,
rustdoc::missing_crate_level_docs,
rustdoc::broken_intra_doc_links
)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
use std::sync::Arc;
use bytes::Bytes;
use futures_util::{FutureExt, TryFutureExt};
use http::{Request, Response};
use http_body::Body;
use http_body::{combinators::BoxBody, Body};
use hyper::{client::HttpConnector, Client};
use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder};
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
use layers::client::ClientResponse;
use tokio::sync::OnceCell;
use thiserror::Error;
use tokio::{sync::OnceCell, task::JoinError};
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
mod ext;
mod future_service;
mod layers;
pub use self::{
ext::ServiceExt as HttpServiceExt,
future_service::FutureService,
layers::{client::ClientLayer, json::JsonResponseLayer, server::ServerLayer},
};
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// A wrapper over a boxed error that implements ``std::error::Error``.
/// This is helps converting to ``anyhow::Error`` with the `?` operator
#[derive(Error, Debug)]
pub enum ClientError {
#[error("failed to initialize HTTPS client")]
Init(#[from] ClientInitError),
#[error(transparent)]
Call(#[from] BoxError),
}
#[derive(Error, Debug, Clone)]
pub enum ClientInitError {
#[error("failed to load system certificates")]
CertificateLoad {
#[from]
inner: Arc<JoinError>, // That error is in an Arc to have the error implement Clone
},
}
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
pub async fn client<B, E>(
operation: &'static str,
) -> anyhow::Result<
BoxCloneService<
Request<B>,
Response<impl http_body::Body<Data = bytes::Bytes, Error = anyhow::Error>>,
anyhow::Error,
>,
>
async fn make_base_client<B, E>(
) -> Result<hyper::Client<HttpsConnector<HttpConnector>, B>, ClientInitError>
where
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
E: Into<BoxError>,
{
// TODO: we could probably hook a tracing DNS resolver there
@ -64,7 +94,8 @@ where
})
.await
})
.await?;
.await
.map_err(|e| ClientInitError::from(Arc::new(e)))?;
let https = HttpsConnectorBuilder::new()
.with_tls_config(tls_config.clone())
@ -76,15 +107,33 @@ where
// TODO: we should get the remote address here
let client = Client::builder().build(https);
Ok::<_, ClientInitError>(client)
}
#[must_use]
pub fn client<B, E: 'static>(
operation: &'static str,
) -> BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>
where
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
E: Into<BoxError>,
{
let fut = make_base_client()
// Map the error to a ClientError
.map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e))))
// Wrap it in an Shared (Arc) to be able to Clone it
.shared();
let client: FutureService<_, _> = FutureService::new(fut);
let client = ServiceBuilder::new()
// Convert the errors to anyhow::Error for convenience
.map_err(|e: BoxError| anyhow::anyhow!(e))
// Convert the errors to ClientError to help dealing with them
.map_err(ClientError::from)
.map_response(|r: ClientResponse<hyper::Body>| {
r.map(|body| body.map_err(|e: BoxError| anyhow::anyhow!(e)))
r.map(|body| body.map_err(ClientError::from).boxed())
})
.layer(ClientLayer::new(operation))
.service(client)
.boxed_clone();
.service(client);
Ok(client)
client.boxed_clone()
}