You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
Simplify the HTTP client building
Also supports loading the WebPKI roots instead of the native ones for TLS
This commit is contained in:
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -1926,7 +1926,6 @@ dependencies = [
|
|||||||
"http",
|
"http",
|
||||||
"hyper",
|
"hyper",
|
||||||
"rustls 0.20.6",
|
"rustls 0.20.6",
|
||||||
"rustls-native-certs 0.6.2",
|
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls 0.23.4",
|
"tokio-rustls 0.23.4",
|
||||||
]
|
]
|
||||||
@@ -2459,6 +2458,7 @@ dependencies = [
|
|||||||
"opentelemetry-http",
|
"opentelemetry-http",
|
||||||
"opentelemetry-semantic-conventions",
|
"opentelemetry-semantic-conventions",
|
||||||
"rustls 0.20.6",
|
"rustls 0.20.6",
|
||||||
|
"rustls-native-certs 0.6.2",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
@@ -2468,6 +2468,8 @@ dependencies = [
|
|||||||
"tower-http",
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-opentelemetry",
|
"tracing-opentelemetry",
|
||||||
|
"webpki 0.22.0",
|
||||||
|
"webpki-roots",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@@ -36,3 +36,8 @@ mas-jose = { path = "../jose" }
|
|||||||
mas-keystore = { path = "../keystore" }
|
mas-keystore = { path = "../keystore" }
|
||||||
mas-storage = { path = "../storage" }
|
mas-storage = { path = "../storage" }
|
||||||
mas-templates = { path = "../templates" }
|
mas-templates = { path = "../templates" }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["native-roots"]
|
||||||
|
native-roots = ["mas-http/native-roots"]
|
||||||
|
webpki-roots = ["mas-http/webpki-roots"]
|
||||||
|
@@ -39,7 +39,7 @@ use serde::{de::DeserializeOwned, Deserialize};
|
|||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use sqlx::PgExecutor;
|
use sqlx::PgExecutor;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tower::ServiceExt;
|
use tower::{Service, ServiceExt};
|
||||||
|
|
||||||
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
|
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
|
||||||
|
|
||||||
@@ -177,12 +177,12 @@ async fn fetch_jwks(jwks: &JwksOrJwksUri) -> Result<PublicJsonWebKeySet, BoxErro
|
|||||||
.body(http_body::Empty::new())
|
.body(http_body::Empty::new())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let client = mas_http::client("fetch-jwks")
|
let mut client = mas_http::client("fetch-jwks")
|
||||||
|
.await?
|
||||||
.response_body_to_bytes()
|
.response_body_to_bytes()
|
||||||
.json_response::<PublicJsonWebKeySet>()
|
.json_response::<PublicJsonWebKeySet>();
|
||||||
.map_err(Box::new);
|
|
||||||
|
|
||||||
let response = client.oneshot(request).await?;
|
let response = client.ready().await?.call(request).await?;
|
||||||
|
|
||||||
Ok(response.into_body())
|
Ok(response.into_body())
|
||||||
}
|
}
|
||||||
|
@@ -34,7 +34,7 @@ opentelemetry-zipkin = { version = "0.15.0", features = ["reqwest-client", "reqw
|
|||||||
|
|
||||||
mas-config = { path = "../config" }
|
mas-config = { path = "../config" }
|
||||||
mas-email = { path = "../email" }
|
mas-email = { path = "../email" }
|
||||||
mas-handlers = { path = "../handlers" }
|
mas-handlers = { path = "../handlers", default-features = false }
|
||||||
mas-http = { path = "../http", features = ["axum"] }
|
mas-http = { path = "../http", features = ["axum"] }
|
||||||
mas-policy = { path = "../policy" }
|
mas-policy = { path = "../policy" }
|
||||||
mas-router = { path = "../router" }
|
mas-router = { path = "../router" }
|
||||||
@@ -47,8 +47,16 @@ mas-templates = { path = "../templates" }
|
|||||||
indoc = "1.0.7"
|
indoc = "1.0.7"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["otlp", "jaeger", "zipkin"]
|
default = ["otlp", "jaeger", "zipkin", "native-roots"]
|
||||||
|
|
||||||
|
# Use the native root certificates
|
||||||
|
native-roots = ["mas-http/native-roots", "mas-handlers/native-roots"]
|
||||||
|
# Use the webpki root certificates
|
||||||
|
webpki-roots = ["mas-http/webpki-roots", "mas-handlers/webpki-roots"]
|
||||||
|
|
||||||
|
# Read the builtin static files and templates from the source directory
|
||||||
dev = ["mas-templates/dev", "mas-static-files/dev"]
|
dev = ["mas-templates/dev", "mas-static-files/dev"]
|
||||||
|
|
||||||
# Enable OpenTelemetry OTLP exporter. Requires "protoc"
|
# Enable OpenTelemetry OTLP exporter. Requires "protoc"
|
||||||
otlp = ["opentelemetry-otlp"]
|
otlp = ["opentelemetry-otlp"]
|
||||||
# Enable OpenTelemetry Jaeger exporter and propagator.
|
# Enable OpenTelemetry Jaeger exporter and propagator.
|
||||||
|
@@ -72,7 +72,7 @@ impl Options {
|
|||||||
json: false,
|
json: false,
|
||||||
url,
|
url,
|
||||||
} => {
|
} => {
|
||||||
let mut client = mas_http::client("cli-debug-http");
|
let mut client = mas_http::client("cli-debug-http").await?;
|
||||||
let request = hyper::Request::builder()
|
let request = hyper::Request::builder()
|
||||||
.uri(url)
|
.uri(url)
|
||||||
.body(hyper::Body::empty())?;
|
.body(hyper::Body::empty())?;
|
||||||
@@ -97,6 +97,7 @@ impl Options {
|
|||||||
url,
|
url,
|
||||||
} => {
|
} => {
|
||||||
let mut client = mas_http::client("cli-debug-http")
|
let mut client = mas_http::client("cli-debug-http")
|
||||||
|
.await?
|
||||||
.response_body_to_bytes()
|
.response_body_to_bytes()
|
||||||
.json_response();
|
.json_response();
|
||||||
let request = hyper::Request::builder()
|
let request = hyper::Request::builder()
|
||||||
|
@@ -47,7 +47,7 @@ rand = "0.8.5"
|
|||||||
headers = "0.3.8"
|
headers = "0.3.8"
|
||||||
|
|
||||||
oauth2-types = { path = "../oauth2-types" }
|
oauth2-types = { path = "../oauth2-types" }
|
||||||
mas-axum-utils = { path = "../axum-utils" }
|
mas-axum-utils = { path = "../axum-utils", default-features = false }
|
||||||
mas-data-model = { path = "../data-model" }
|
mas-data-model = { path = "../data-model" }
|
||||||
mas-email = { path = "../email" }
|
mas-email = { path = "../email" }
|
||||||
mas-http = { path = "../http" }
|
mas-http = { path = "../http" }
|
||||||
@@ -61,3 +61,11 @@ mas-templates = { path = "../templates" }
|
|||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
indoc = "1.0.7"
|
indoc = "1.0.7"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["native-roots"]
|
||||||
|
|
||||||
|
# Use the native root certificates
|
||||||
|
native-roots = ["mas-axum-utils/native-roots", "mas-http/native-roots"]
|
||||||
|
# Use the webpki root certificates
|
||||||
|
webpki-roots = ["mas-axum-utils/webpki-roots", "mas-http/webpki-roots"]
|
||||||
|
@@ -13,21 +13,24 @@ headers = "0.3.8"
|
|||||||
http = "0.2.8"
|
http = "0.2.8"
|
||||||
http-body = "0.4.5"
|
http-body = "0.4.5"
|
||||||
hyper = "0.14.20"
|
hyper = "0.14.20"
|
||||||
hyper-rustls = { version = "0.23.0", features = ["http1", "http2", "rustls-native-certs"], default-features = false, optional = true }
|
hyper-rustls = { version = "0.23.0", features = ["http1", "http2"], default-features = false, optional = true }
|
||||||
once_cell = "1.15.0"
|
once_cell = "1.15.0"
|
||||||
opentelemetry = "0.17.0"
|
opentelemetry = "0.17.0"
|
||||||
opentelemetry-http = "0.6.0"
|
opentelemetry-http = "0.6.0"
|
||||||
opentelemetry-semantic-conventions = "0.9.0"
|
opentelemetry-semantic-conventions = "0.9.0"
|
||||||
rustls = "0.20.6"
|
rustls = { version = "0.20.6", optional = true }
|
||||||
|
rustls-native-certs = { version = "0.6.2", optional = true }
|
||||||
serde = "1.0.145"
|
serde = "1.0.145"
|
||||||
serde_json = "1.0.85"
|
serde_json = "1.0.85"
|
||||||
serde_urlencoded = "0.7.1"
|
serde_urlencoded = "0.7.1"
|
||||||
thiserror = "1.0.36"
|
thiserror = "1.0.36"
|
||||||
tokio = { version = "1.21.1", optional = true }
|
tokio = { version = "1.21.1", features = ["sync", "parking_lot"], optional = true }
|
||||||
tower = { version = "0.4.13", features = ["timeout", "limit"] }
|
tower = { version = "0.4.13", features = ["timeout", "limit"] }
|
||||||
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] }
|
tower-http = { version = "0.3.4", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] }
|
||||||
tracing = "0.1.36"
|
tracing = "0.1.36"
|
||||||
tracing-opentelemetry = "0.17.4"
|
tracing-opentelemetry = "0.17.4"
|
||||||
|
webpki = { version = "0.22.0", optional = true }
|
||||||
|
webpki-roots = { version = "0.22.4", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = "1.0.65"
|
anyhow = "1.0.65"
|
||||||
@@ -38,4 +41,12 @@ tower = { version = "0.4.13", features = ["util"] }
|
|||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
axum = ["dep:axum"]
|
axum = ["dep:axum"]
|
||||||
client = ["dep:hyper-rustls", "hyper/tcp", "dep:tokio", "tokio?/sync", "tokio?/parking_lot"]
|
native-roots = ["dep:rustls-native-certs"]
|
||||||
|
webpki-roots = ["dep:webpki-roots"]
|
||||||
|
client = [
|
||||||
|
"dep:rustls",
|
||||||
|
"hyper/tcp",
|
||||||
|
"dep:hyper-rustls",
|
||||||
|
"dep:tokio",
|
||||||
|
"dep:webpki",
|
||||||
|
]
|
||||||
|
@@ -12,50 +12,158 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::{convert::Infallible, net::SocketAddr};
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures_util::{FutureExt, TryFutureExt};
|
|
||||||
use http::{Request, Response};
|
use http::{Request, Response};
|
||||||
use http_body::{combinators::BoxBody, Body};
|
use http_body::{combinators::BoxBody, Body};
|
||||||
use hyper::{
|
use hyper::{
|
||||||
client::{connect::dns::GaiResolver, HttpConnector},
|
client::{
|
||||||
|
connect::dns::{GaiResolver, Name},
|
||||||
|
HttpConnector,
|
||||||
|
},
|
||||||
Client,
|
Client,
|
||||||
};
|
};
|
||||||
use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder};
|
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::{sync::OnceCell, task::JoinError};
|
use tower::{util::BoxCloneService, Service, ServiceBuilder, ServiceExt};
|
||||||
use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt};
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
layers::{
|
layers::{
|
||||||
client::{ClientLayer, ClientResponse},
|
client::{ClientLayer, ClientResponse},
|
||||||
otel::{TraceDns, TraceLayer},
|
otel::{TraceDns, TraceLayer},
|
||||||
},
|
},
|
||||||
BoxError, FutureService,
|
BoxError,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))]
|
||||||
|
compile_error!("enabling the 'client' feature requires also enabling the 'webpki-roots' or the 'native-roots' features");
|
||||||
|
|
||||||
|
#[cfg(all(feature = "webpki-roots", feature = "native-roots"))]
|
||||||
|
compile_error!("'webpki-roots' and 'native-roots' features are mutually exclusive");
|
||||||
|
|
||||||
|
#[cfg(feature = "native-roots")]
|
||||||
|
static NATIVE_TLS_ROOTS: tokio::sync::OnceCell<rustls::RootCertStore> =
|
||||||
|
tokio::sync::OnceCell::const_new();
|
||||||
|
|
||||||
|
#[cfg(feature = "native-roots")]
|
||||||
|
fn load_tls_roots_blocking() -> Result<rustls::RootCertStore, NativeRootsLoadError> {
|
||||||
|
let mut roots = rustls::RootCertStore::empty();
|
||||||
|
let certs = rustls_native_certs::load_native_certs()?;
|
||||||
|
for cert in certs {
|
||||||
|
let cert = rustls::Certificate(cert.0);
|
||||||
|
roots.add(&cert)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if roots.is_empty() {
|
||||||
|
return Err(NativeRootsLoadError::Empty);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(roots)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "native-roots")]
|
||||||
|
async fn tls_roots() -> Result<rustls::RootCertStore, NativeRootsInitError> {
|
||||||
|
NATIVE_TLS_ROOTS
|
||||||
|
.get_or_try_init(|| async move {
|
||||||
|
// Load the TLS config once in a blocking task because loading the system
|
||||||
|
// certificates can take a long time (~200ms) on macOS
|
||||||
|
let span = tracing::info_span!("load_tls_roots");
|
||||||
|
let roots = tokio::task::spawn_blocking(|| {
|
||||||
|
let _span = span.entered();
|
||||||
|
load_tls_roots_blocking()
|
||||||
|
})
|
||||||
|
.await??;
|
||||||
|
Ok(roots)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "webpki-roots")]
|
||||||
|
async fn tls_roots() -> Result<rustls::RootCertStore, Infallible> {
|
||||||
|
let mut roots = rustls::RootCertStore::empty();
|
||||||
|
roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||||
|
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||||
|
ta.subject,
|
||||||
|
ta.spki,
|
||||||
|
ta.name_constraints,
|
||||||
|
)
|
||||||
|
}));
|
||||||
|
Ok(roots)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "native-roots")]
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
#[error(transparent)]
|
||||||
|
pub enum NativeRootsInitError {
|
||||||
|
RootsLoadError(#[from] NativeRootsLoadError),
|
||||||
|
|
||||||
|
JoinError(#[from] tokio::task::JoinError),
|
||||||
|
}
|
||||||
|
|
||||||
/// A wrapper over a boxed error that implements ``std::error::Error``.
|
/// A wrapper over a boxed error that implements ``std::error::Error``.
|
||||||
/// This is helps converting to ``anyhow::Error`` with the `?` operator
|
/// This is helps converting to ``anyhow::Error`` with the `?` operator
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum ClientError {
|
#[error(transparent)]
|
||||||
#[error("failed to initialize HTTPS client")]
|
pub struct ClientError {
|
||||||
Init(#[from] ClientInitError),
|
#[from]
|
||||||
|
inner: BoxError,
|
||||||
#[error(transparent)]
|
|
||||||
Call(#[from] BoxError),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Error, Debug, Clone)]
|
#[derive(Error, Debug, Clone)]
|
||||||
pub enum ClientInitError {
|
pub enum ClientInitError {
|
||||||
#[error("failed to load system certificates")]
|
#[cfg(feature = "native-roots")]
|
||||||
CertificateLoad {
|
#[error(transparent)]
|
||||||
#[from]
|
TlsRootsInit(std::sync::Arc<NativeRootsInitError>),
|
||||||
inner: Arc<JoinError>, // That error is in an Arc to have the error implement Clone
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static TLS_CONFIG: OnceCell<rustls::ClientConfig> = OnceCell::const_new();
|
#[cfg(feature = "native-roots")]
|
||||||
|
impl From<NativeRootsInitError> for ClientInitError {
|
||||||
|
fn from(inner: NativeRootsInitError) -> Self {
|
||||||
|
Self::TlsRootsInit(std::sync::Arc::new(inner))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Infallible> for ClientInitError {
|
||||||
|
fn from(_: Infallible) -> Self {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "native-roots")]
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum NativeRootsLoadError {
|
||||||
|
#[error("could not load root certificates")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
|
||||||
|
#[error("invalid root certificate")]
|
||||||
|
Webpki(#[from] webpki::Error),
|
||||||
|
|
||||||
|
#[error("no root certificate loaded")]
|
||||||
|
Empty,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a basic Hyper HTTP & HTTPS client without any tracing
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if it failed to load the TLS certificates
|
||||||
|
pub async fn make_untraced_client<B, E>(
|
||||||
|
) -> Result<hyper::Client<HttpsConnector<HttpConnector<GaiResolver>>, B>, ClientInitError>
|
||||||
|
where
|
||||||
|
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||||
|
E: Into<BoxError>,
|
||||||
|
{
|
||||||
|
let resolver = GaiResolver::new();
|
||||||
|
let roots = tls_roots().await?;
|
||||||
|
let tls_config = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(roots)
|
||||||
|
.with_no_client_auth();
|
||||||
|
|
||||||
|
Ok(make_client(resolver, tls_config))
|
||||||
|
}
|
||||||
|
|
||||||
async fn make_base_client<B, E>(
|
async fn make_base_client<B, E>(
|
||||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
|
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
|
||||||
@@ -68,54 +176,57 @@ where
|
|||||||
.layer(TraceLayer::dns())
|
.layer(TraceLayer::dns())
|
||||||
.service(GaiResolver::new());
|
.service(GaiResolver::new());
|
||||||
|
|
||||||
|
let roots = tls_roots().await?;
|
||||||
|
let tls_config = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(roots)
|
||||||
|
.with_no_client_auth();
|
||||||
|
|
||||||
|
Ok(make_client(resolver, tls_config))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_client<R, B, E>(
|
||||||
|
resolver: R,
|
||||||
|
tls_config: rustls::ClientConfig,
|
||||||
|
) -> hyper::Client<HttpsConnector<HttpConnector<R>>, B>
|
||||||
|
where
|
||||||
|
R: Service<Name> + Send + Sync + Clone + 'static,
|
||||||
|
R::Error: std::error::Error + Send + Sync,
|
||||||
|
R::Future: Send,
|
||||||
|
R::Response: Iterator<Item = SocketAddr>,
|
||||||
|
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||||
|
E: Into<BoxError>,
|
||||||
|
{
|
||||||
let mut http = HttpConnector::new_with_resolver(resolver);
|
let mut http = HttpConnector::new_with_resolver(resolver);
|
||||||
http.enforce_http(false);
|
http.enforce_http(false);
|
||||||
|
|
||||||
let tls_config = TLS_CONFIG
|
|
||||||
.get_or_try_init(|| async move {
|
|
||||||
// Load the TLS config once in a blocking task because loading the system
|
|
||||||
// certificates can take a long time (~200ms) on macOS
|
|
||||||
let span = tracing::info_span!("load_certificates");
|
|
||||||
tokio::task::spawn_blocking(|| {
|
|
||||||
let _span = span.entered();
|
|
||||||
rustls::ClientConfig::builder()
|
|
||||||
.with_safe_defaults()
|
|
||||||
.with_native_roots()
|
|
||||||
.with_no_client_auth()
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.map_err(|e| ClientInitError::from(Arc::new(e)))?;
|
|
||||||
|
|
||||||
let https = HttpsConnectorBuilder::new()
|
let https = HttpsConnectorBuilder::new()
|
||||||
.with_tls_config(tls_config.clone())
|
.with_tls_config(tls_config)
|
||||||
.https_or_http()
|
.https_or_http()
|
||||||
.enable_http1()
|
.enable_http1()
|
||||||
.enable_http2()
|
.enable_http2()
|
||||||
.wrap_connector(http);
|
.wrap_connector(http);
|
||||||
|
|
||||||
// TODO: we should get the remote address here
|
Client::builder().build(https)
|
||||||
let client = Client::builder().build(https);
|
|
||||||
|
|
||||||
Ok::<_, ClientInitError>(client)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
/// Create a traced HTTP client, with a default timeout, which follows redirects
|
||||||
pub fn client<B, E>(
|
/// and handles compression
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if it failed to initialize
|
||||||
|
pub async fn client<B, E>(
|
||||||
operation: &'static str,
|
operation: &'static str,
|
||||||
) -> BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>
|
) -> Result<
|
||||||
|
BoxCloneService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>,
|
||||||
|
ClientInitError,
|
||||||
|
>
|
||||||
where
|
where
|
||||||
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
||||||
E: Into<BoxError> + 'static,
|
E: Into<BoxError> + 'static,
|
||||||
{
|
{
|
||||||
let fut = make_base_client()
|
let client = make_base_client().await?;
|
||||||
// Map the error to a ClientError
|
|
||||||
.map_ok(|s| s.map_err(|e| ClientError::from(BoxError::from(e))))
|
|
||||||
// Wrap it in an Shared (Arc) to be able to Clone it
|
|
||||||
.shared();
|
|
||||||
|
|
||||||
let client: FutureService<_, _> = FutureService::new(fut);
|
|
||||||
|
|
||||||
let client = ServiceBuilder::new()
|
let client = ServiceBuilder::new()
|
||||||
// Convert the errors to ClientError to help dealing with them
|
// Convert the errors to ClientError to help dealing with them
|
||||||
@@ -124,7 +235,8 @@ where
|
|||||||
r.map(|body| body.map_err(ClientError::from).boxed())
|
r.map(|body| body.map_err(ClientError::from).boxed())
|
||||||
})
|
})
|
||||||
.layer(ClientLayer::new(operation))
|
.layer(ClientLayer::new(operation))
|
||||||
.service(client);
|
.service(client)
|
||||||
|
.boxed_clone();
|
||||||
|
|
||||||
client.boxed_clone()
|
Ok(client)
|
||||||
}
|
}
|
||||||
|
@@ -1,77 +0,0 @@
|
|||||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
//! A copy of [`tower::util::FutureService`] that also maps the future error to
|
|
||||||
//! help implementing [`Clone`] on the service
|
|
||||||
|
|
||||||
use std::{
|
|
||||||
future::Future,
|
|
||||||
pin::Pin,
|
|
||||||
task::{Context, Poll},
|
|
||||||
};
|
|
||||||
|
|
||||||
use futures_util::ready;
|
|
||||||
use tower::Service;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct FutureService<F, S> {
|
|
||||||
state: State<F, S>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F, S> FutureService<F, S> {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(future: F) -> Self {
|
|
||||||
Self {
|
|
||||||
state: State::Future(future),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
enum State<F, S> {
|
|
||||||
Future(F),
|
|
||||||
Service(S),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F, S, R, FE, E> Service<R> for FutureService<F, S>
|
|
||||||
where
|
|
||||||
F: Future<Output = Result<S, FE>> + Unpin,
|
|
||||||
S: Service<R, Error = E>,
|
|
||||||
E: From<FE>,
|
|
||||||
{
|
|
||||||
type Response = S::Response;
|
|
||||||
type Error = E;
|
|
||||||
type Future = S::Future;
|
|
||||||
|
|
||||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
||||||
loop {
|
|
||||||
self.state = match &mut self.state {
|
|
||||||
State::Future(fut) => {
|
|
||||||
let fut = Pin::new(fut);
|
|
||||||
let svc = ready!(fut.poll(cx)?);
|
|
||||||
State::Service(svc)
|
|
||||||
}
|
|
||||||
State::Service(svc) => return svc.poll_ready(cx),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn call(&mut self, req: R) -> Self::Future {
|
|
||||||
if let State::Service(svc) = &mut self.state {
|
|
||||||
svc.call(req)
|
|
||||||
} else {
|
|
||||||
panic!("FutureService::call was called before FutureService::poll_ready")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@@ -27,17 +27,15 @@
|
|||||||
#[cfg(feature = "client")]
|
#[cfg(feature = "client")]
|
||||||
mod client;
|
mod client;
|
||||||
mod ext;
|
mod ext;
|
||||||
mod future_service;
|
|
||||||
mod layers;
|
mod layers;
|
||||||
|
|
||||||
#[cfg(feature = "client")]
|
#[cfg(feature = "client")]
|
||||||
pub use self::client::client;
|
pub use self::client::{client, make_untraced_client};
|
||||||
pub use self::{
|
pub use self::{
|
||||||
ext::{
|
ext::{
|
||||||
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
|
set_propagator, CorsLayerExt, ServiceBuilderExt as HttpServiceBuilderExt,
|
||||||
ServiceExt as HttpServiceExt,
|
ServiceExt as HttpServiceExt,
|
||||||
},
|
},
|
||||||
future_service::FutureService,
|
|
||||||
layers::{
|
layers::{
|
||||||
body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer},
|
body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer},
|
||||||
bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer},
|
bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer},
|
||||||
|
Reference in New Issue
Block a user