1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Add a global HTTP client factory

This commit is contained in:
Quentin Gliech
2022-11-23 13:18:48 +01:00
parent d514a8922c
commit 4227fa7a83
14 changed files with 163 additions and 83 deletions

View File

@ -41,6 +41,8 @@ use sqlx::PgExecutor;
use thiserror::Error; use thiserror::Error;
use tower::{Service, ServiceExt}; use tower::{Service, ServiceExt};
use crate::http_client_factory::HttpClientFactory;
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
#[derive(Deserialize)] #[derive(Deserialize)]
@ -91,6 +93,7 @@ impl Credentials {
#[tracing::instrument(skip_all, err)] #[tracing::instrument(skip_all, err)]
pub async fn verify<S: StorageBackend>( pub async fn verify<S: StorageBackend>(
&self, &self,
http_client_factory: &HttpClientFactory,
encrypter: &Encrypter, encrypter: &Encrypter,
method: &OAuthClientAuthenticationMethod, method: &OAuthClientAuthenticationMethod,
client: &Client<S>, client: &Client<S>,
@ -132,7 +135,7 @@ impl Credentials {
.as_ref() .as_ref()
.ok_or(CredentialsVerificationError::InvalidClientConfig)?; .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
let jwks = fetch_jwks(jwks) let jwks = fetch_jwks(http_client_factory, jwks)
.await .await
.map_err(|_| CredentialsVerificationError::JwksFetchFailed)?; .map_err(|_| CredentialsVerificationError::JwksFetchFailed)?;
@ -166,7 +169,10 @@ impl Credentials {
} }
} }
async fn fetch_jwks(jwks: &JwksOrJwksUri) -> Result<PublicJsonWebKeySet, BoxError> { async fn fetch_jwks(
http_client_factory: &HttpClientFactory,
jwks: &JwksOrJwksUri,
) -> Result<PublicJsonWebKeySet, BoxError> {
let uri = match jwks { let uri = match jwks {
JwksOrJwksUri::Jwks(j) => return Ok(j.clone()), JwksOrJwksUri::Jwks(j) => return Ok(j.clone()),
JwksOrJwksUri::JwksUri(u) => u, JwksOrJwksUri::JwksUri(u) => u,
@ -177,7 +183,8 @@ async fn fetch_jwks(jwks: &JwksOrJwksUri) -> Result<PublicJsonWebKeySet, BoxErro
.body(mas_http::EmptyBody::new()) .body(mas_http::EmptyBody::new())
.unwrap(); .unwrap();
let mut client = mas_http::client("fetch-jwks") let mut client = http_client_factory
.client("fetch-jwks")
.await? .await?
.response_body_to_bytes() .response_body_to_bytes()
.json_response::<PublicJsonWebKeySet>(); .json_response::<PublicJsonWebKeySet>();

View File

@ -0,0 +1,78 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use axum::body::Full;
use mas_http::{
BodyToBytesResponseLayer, ClientInitError, ClientLayer, ClientService, HttpService,
TracedClient,
};
use tokio::sync::Semaphore;
use tower::{
util::{MapErrLayer, MapRequestLayer},
BoxError, Layer,
};
#[derive(Debug, Clone)]
pub struct HttpClientFactory {
semaphore: Arc<Semaphore>,
}
impl HttpClientFactory {
#[must_use]
pub fn new(concurrency_limit: usize) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(concurrency_limit)),
}
}
/// Constructs a new HTTP client
///
/// # Errors
///
/// Returns an error if the client failed to initialise
pub async fn client<B>(
&self,
operation: &'static str,
) -> Result<ClientService<TracedClient<B>>, ClientInitError>
where
B: axum::body::HttpBody + Send + Sync + 'static,
B::Data: Send,
{
let client = mas_http::make_traced_client::<B>().await?;
let layer = ClientLayer::with_semaphore(operation, self.semaphore.clone());
Ok(layer.layer(client))
}
/// Constructs a new [`HttpService`], suitable for [`mas_oidc_client`]
///
/// # Errors
///
/// Returns an error if the client failed to initialise
pub async fn http_service(
&self,
operation: &'static str,
) -> Result<HttpService, ClientInitError> {
let client = self.client(operation).await?;
let client = (
MapErrLayer::new(BoxError::from),
MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)),
BodyToBytesResponseLayer::default(),
)
.layer(client);
Ok(HttpService::new(client))
}
}

View File

@ -26,6 +26,7 @@ pub mod client_authorization;
pub mod cookies; pub mod cookies;
pub mod csrf; pub mod csrf;
pub mod fancy_error; pub mod fancy_error;
pub mod http_client_factory;
pub mod jwt; pub mod jwt;
pub mod session; pub mod session;
pub mod user_authorization; pub mod user_authorization;

View File

@ -16,6 +16,7 @@ use anyhow::Context;
use clap::Parser; use clap::Parser;
use hyper::{Response, Uri}; use hyper::{Response, Uri};
use mas_config::PolicyConfig; use mas_config::PolicyConfig;
use mas_handlers::HttpClientFactory;
use mas_http::HttpServiceExt; use mas_http::HttpServiceExt;
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
@ -66,13 +67,14 @@ impl Options {
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
use Subcommand as SC; use Subcommand as SC;
let http_client_factory = HttpClientFactory::new(10);
match &self.subcommand { match &self.subcommand {
SC::Http { SC::Http {
show_headers, show_headers,
json: false, json: false,
url, url,
} => { } => {
let mut client = mas_http::client("cli-debug-http").await?; let mut client = http_client_factory.client("cli-debug-http").await?;
let request = hyper::Request::builder() let request = hyper::Request::builder()
.uri(url) .uri(url)
.body(hyper::Body::empty())?; .body(hyper::Body::empty())?;
@ -96,7 +98,8 @@ impl Options {
json: true, json: true,
url, url,
} => { } => {
let mut client = mas_http::client("cli-debug-http") let mut client = http_client_factory
.client("cli-debug-http")
.await? .await?
.response_body_to_bytes() .response_body_to_bytes()
.json_response(); .json_response();

View File

@ -20,7 +20,7 @@ use futures_util::stream::{StreamExt, TryStreamExt};
use itertools::Itertools; use itertools::Itertools;
use mas_config::RootConfig; use mas_config::RootConfig;
use mas_email::Mailer; use mas_email::Mailer;
use mas_handlers::{AppState, MatrixHomeserver}; use mas_handlers::{AppState, HttpClientFactory, MatrixHomeserver};
use mas_http::ServerLayer; use mas_http::ServerLayer;
use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
@ -187,6 +187,9 @@ impl Options {
let graphql_schema = mas_handlers::graphql_schema(&pool); let graphql_schema = mas_handlers::graphql_schema(&pool);
// Maximum 50 outgoing HTTP requests at a time
let http_client_factory = HttpClientFactory::new(50);
let state = AppState { let state = AppState {
pool, pool,
templates, templates,
@ -197,6 +200,7 @@ impl Options {
homeserver, homeserver,
policy_factory, policy_factory,
graphql_schema, graphql_schema,
http_client_factory,
}; };
let mut fd_manager = listenfd::ListenFd::from_env(); let mut fd_manager = listenfd::ListenFd::from_env();

View File

@ -15,6 +15,7 @@
use std::sync::Arc; use std::sync::Arc;
use axum::extract::FromRef; use axum::extract::FromRef;
use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_email::Mailer; use mas_email::Mailer;
use mas_keystore::{Encrypter, Keystore}; use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
@ -35,6 +36,7 @@ pub struct AppState {
pub homeserver: MatrixHomeserver, pub homeserver: MatrixHomeserver,
pub policy_factory: Arc<PolicyFactory>, pub policy_factory: Arc<PolicyFactory>,
pub graphql_schema: mas_graphql::Schema, pub graphql_schema: mas_graphql::Schema,
pub http_client_factory: HttpClientFactory,
} }
impl FromRef<AppState> for PgPool { impl FromRef<AppState> for PgPool {
@ -90,3 +92,8 @@ impl FromRef<AppState> for Arc<PolicyFactory> {
input.policy_factory.clone() input.policy_factory.clone()
} }
} }
impl FromRef<AppState> for HttpClientFactory {
fn from_ref(input: &AppState) -> Self {
input.http_client_factory.clone()
}
}

View File

@ -55,9 +55,9 @@ mod oauth2;
mod upstream_oauth2; mod upstream_oauth2;
mod views; mod views;
pub use compat::MatrixHomeserver; pub use mas_axum_utils::http_client_factory::HttpClientFactory;
pub use self::{app_state::AppState, graphql::schema as graphql_schema}; pub use self::{app_state::AppState, compat::MatrixHomeserver, graphql::schema as graphql_schema};
#[must_use] #[must_use]
pub fn healthcheck_router<S, B>() -> Router<S, B> pub fn healthcheck_router<S, B>() -> Router<S, B>
@ -138,6 +138,7 @@ where
Arc<PolicyFactory>: FromRef<S>, Arc<PolicyFactory>: FromRef<S>,
PgPool: FromRef<S>, PgPool: FromRef<S>,
Encrypter: FromRef<S>, Encrypter: FromRef<S>,
HttpClientFactory: FromRef<S>,
{ {
// All those routes are API-like, with a common CORS layer // All those routes are API-like, with a common CORS layer
Router::new() Router::new()
@ -235,6 +236,7 @@ where
Templates: FromRef<S>, Templates: FromRef<S>,
Mailer: FromRef<S>, Mailer: FromRef<S>,
Keystore: FromRef<S>, Keystore: FromRef<S>,
HttpClientFactory: FromRef<S>,
{ {
Router::new() Router::new()
.route( .route(
@ -363,6 +365,8 @@ async fn test_state(pool: PgPool) -> Result<AppState, anyhow::Error> {
let graphql_schema = graphql_schema(&pool); let graphql_schema = graphql_schema(&pool);
let http_client_factory = HttpClientFactory::new(10);
Ok(AppState { Ok(AppState {
pool, pool,
templates, templates,
@ -373,6 +377,7 @@ async fn test_state(pool: PgPool) -> Result<AppState, anyhow::Error> {
homeserver, homeserver,
policy_factory, policy_factory,
graphql_schema, graphql_schema,
http_client_factory,
}) })
} }

View File

@ -14,7 +14,10 @@
use axum::{extract::State, response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json};
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerificationError}; use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError},
http_client_factory::HttpClientFactory,
};
use mas_data_model::{TokenFormatError, TokenType}; use mas_data_model::{TokenFormatError, TokenType};
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
@ -155,6 +158,7 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub(crate) async fn post( pub(crate) async fn post(
State(http_client_factory): State<HttpClientFactory>,
State(pool): State<PgPool>, State(pool): State<PgPool>,
State(encrypter): State<Encrypter>, State(encrypter): State<Encrypter>,
client_authorization: ClientAuthorization<IntrospectionRequest>, client_authorization: ClientAuthorization<IntrospectionRequest>,
@ -173,7 +177,7 @@ pub(crate) async fn post(
client_authorization client_authorization
.credentials .credentials
.verify(&encrypter, method, &client) .verify(&http_client_factory, &encrypter, method, &client)
.await?; .await?;
let form = if let Some(form) = client_authorization.form { let form = if let Some(form) = client_authorization.form {

View File

@ -19,7 +19,10 @@ use axum::{extract::State, response::IntoResponse, Json};
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma}; use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerificationError}; use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError},
http_client_factory::HttpClientFactory,
};
use mas_data_model::{AuthorizationGrantStage, Client, TokenType}; use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
use mas_iana::jose::JsonWebSignatureAlg; use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::{ use mas_jose::{
@ -191,6 +194,7 @@ impl From<JwtSignatureError> for RouteError {
#[tracing::instrument(skip_all, err)] #[tracing::instrument(skip_all, err)]
pub(crate) async fn post( pub(crate) async fn post(
State(http_client_factory): State<HttpClientFactory>,
State(key_store): State<Keystore>, State(key_store): State<Keystore>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
State(pool): State<PgPool>, State(pool): State<PgPool>,
@ -208,7 +212,7 @@ pub(crate) async fn post(
client_authorization client_authorization
.credentials .credentials
.verify(&encrypter, method, &client) .verify(&http_client_factory, &encrypter, method, &client)
.await?; .await?;
let form = client_authorization.form.ok_or(RouteError::BadRequest)?; let form = client_authorization.form.ok_or(RouteError::BadRequest)?;

View File

@ -18,6 +18,7 @@ use axum::{
}; };
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar}; use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_http::ClientInitError; use mas_http::ClientInitError;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_oidc_client::{ use mas_oidc_client::{
@ -30,8 +31,6 @@ use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use super::http_service;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub(crate) enum RouteError { pub(crate) enum RouteError {
#[error("Provider not found")] #[error("Provider not found")]
@ -89,6 +88,7 @@ impl IntoResponse for RouteError {
} }
pub(crate) async fn get( pub(crate) async fn get(
State(http_client_factory): State<HttpClientFactory>,
State(pool): State<PgPool>, State(pool): State<PgPool>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
@ -103,7 +103,9 @@ pub(crate) async fn get(
.to_option()? .to_option()?
.ok_or(RouteError::ProviderNotFound)?; .ok_or(RouteError::ProviderNotFound)?;
let http_service = http_service("upstream-discover").await?; let http_service = http_client_factory
.http_service("upstream-discover")
.await?;
// First, discover the provider // First, discover the provider
let metadata = let metadata =

View File

@ -19,6 +19,7 @@ use axum::{
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_http::ClientInitError; use mas_http::ClientInitError;
use mas_keystore::{Encrypter, Keystore}; use mas_keystore::{Encrypter, Keystore};
use mas_oidc_client::{ use mas_oidc_client::{
@ -33,7 +34,7 @@ use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use super::{client_credentials_for_provider, http_service, ProviderCredentialsError}; use super::{client_credentials_for_provider, ProviderCredentialsError};
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct QueryParams { pub struct QueryParams {
@ -144,8 +145,9 @@ impl IntoResponse for RouteError {
} }
} }
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
pub(crate) async fn get( pub(crate) async fn get(
State(http_client_factory): State<HttpClientFactory>,
State(pool): State<PgPool>, State(pool): State<PgPool>,
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
State(encrypter): State<Encrypter>, State(encrypter): State<Encrypter>,
@ -195,13 +197,19 @@ pub(crate) async fn get(
CodeOrError::Code { code } => code, CodeOrError::Code { code } => code,
}; };
let http_service = http_service("upstream-code-exchange").await?; let http_service = http_client_factory
.http_service("upstream-discover")
.await?;
// XXX: we shouldn't discover on-the-fly // XXX: we shouldn't discover on-the-fly
// Discover the provider // Discover the provider
let metadata = let metadata =
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?; mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
let http_service = http_client_factory
.http_service("upstream-fetch-jwks")
.await?;
// Fetch the JWKS // Fetch the JWKS
let jwks = let jwks =
mas_oidc_client::requests::jose::fetch_jwks(&http_service, metadata.jwks_uri()).await?; mas_oidc_client::requests::jose::fetch_jwks(&http_service, metadata.jwks_uri()).await?;
@ -231,6 +239,10 @@ pub(crate) async fn get(
client_id: &provider.client_id, client_id: &provider.client_id,
}; };
let http_service = http_client_factory
.http_service("upstream-exchange-code")
.await?;
let (response, _id_token) = let (response, _id_token) =
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code( mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
&http_service, &http_service,

View File

@ -13,17 +13,11 @@
// limitations under the License. // limitations under the License.
use anyhow::Context; use anyhow::Context;
use axum::body::Full;
use mas_data_model::UpstreamOAuthProvider; use mas_data_model::UpstreamOAuthProvider;
use mas_http::{BodyToBytesResponseLayer, ClientInitError, ClientLayer, HttpService};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_keystore::{Encrypter, Keystore}; use mas_keystore::{Encrypter, Keystore};
use mas_oidc_client::types::client_credentials::{ClientCredentials, JwtSigningMethod}; use mas_oidc_client::types::client_credentials::{ClientCredentials, JwtSigningMethod};
use thiserror::Error; use thiserror::Error;
use tower::{
util::{MapErrLayer, MapRequestLayer},
BoxError, Layer,
};
use url::Url; use url::Url;
pub(crate) mod authorize; pub(crate) mod authorize;
@ -101,15 +95,3 @@ fn client_credentials_for_provider(
Ok(client_credentials) Ok(client_credentials)
} }
async fn http_service(operation: &'static str) -> Result<HttpService, ClientInitError> {
let client = (
MapErrLayer::new(BoxError::from),
MapRequestLayer::new(|req: hyper::Request<_>| req.map(Full::new)),
BodyToBytesResponseLayer::default(),
ClientLayer::new(operation),
)
.layer(mas_http::make_untraced_client().await?);
Ok(HttpService::new(client))
}

View File

@ -22,13 +22,7 @@ use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use thiserror::Error; use thiserror::Error;
use tower::Layer; use tower::Layer;
use crate::{ use crate::layers::otel::{TraceDns, TraceLayer};
layers::{
client::{ClientLayer, ClientService},
otel::{TraceDns, TraceLayer},
},
BoxError,
};
#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))] #[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))]
compile_error!("enabling the 'client' feature requires also enabling the 'webpki-roots' or the 'native-roots' features"); compile_error!("enabling the 'client' feature requires also enabling the 'webpki-roots' or the 'native-roots' features");
@ -97,15 +91,6 @@ pub enum NativeRootsInitError {
JoinError(#[from] tokio::task::JoinError), JoinError(#[from] tokio::task::JoinError),
} }
/// A wrapper over a boxed error that implements ``std::error::Error``.
/// This is helps converting to ``anyhow::Error`` with the `?` operator
#[derive(Error, Debug)]
#[error(transparent)]
pub struct ClientError {
#[from]
inner: BoxError,
}
#[derive(Error, Debug, Clone)] #[derive(Error, Debug, Clone)]
pub enum ClientInitError { pub enum ClientInitError {
#[cfg(feature = "native-roots")] #[cfg(feature = "native-roots")]
@ -121,8 +106,8 @@ impl From<NativeRootsInitError> for ClientInitError {
} }
impl From<Infallible> for ClientInitError { impl From<Infallible> for ClientInitError {
fn from(_: Infallible) -> Self { fn from(e: Infallible) -> Self {
unreachable!() match e {}
} }
} }
@ -149,8 +134,8 @@ async fn make_tls_config() -> Result<rustls::ClientConfig, ClientInitError> {
Ok(tls_config) Ok(tls_config)
} }
type UntracedClient<B> = hyper::Client<UntracedConnector, B>; pub type UntracedClient<B> = hyper::Client<UntracedConnector, B>;
type TracedClient<B> = hyper::Client<TracedConnector, B>; pub type TracedClient<B> = hyper::Client<TracedConnector, B>;
/// Create a basic Hyper HTTP & HTTPS client without any tracing /// Create a basic Hyper HTTP & HTTPS client without any tracing
/// ///
@ -166,7 +151,12 @@ where
Ok(Client::builder().build(https)) Ok(Client::builder().build(https))
} }
async fn make_traced_client<B>() -> Result<TracedClient<B>, ClientInitError> /// Create a basic Hyper HTTP & HTTPS client which traces DNS requests
///
/// # Errors
///
/// Returns an error if it failed to load the TLS certificates
pub async fn make_traced_client<B>() -> Result<TracedClient<B>, ClientInitError>
where where
B: http_body::Body + Send + 'static, B: http_body::Body + Send + 'static,
B::Data: Send, B::Data: Send,
@ -175,8 +165,8 @@ where
Ok(Client::builder().build(https)) Ok(Client::builder().build(https))
} }
type UntracedConnector = HttpsConnector<HttpConnector<GaiResolver>>; pub type UntracedConnector = HttpsConnector<HttpConnector<GaiResolver>>;
type TracedConnector = HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>; pub type TracedConnector = HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>;
/// Create a traced HTTP and HTTPS connector /// Create a traced HTTP and HTTPS connector
/// ///
@ -214,23 +204,3 @@ fn make_connector<R>(
.enable_http2() .enable_http2()
.wrap_connector(http) .wrap_connector(http)
} }
/// Create a traced HTTP client, with a default timeout, which follows redirects
///
/// # Errors
///
/// Returns an error if it failed to initialize
pub async fn client<B>(
operation: &'static str,
) -> Result<ClientService<TracedClient<B>>, ClientInitError>
where
B: http_body::Body + Default + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
let client = make_traced_client().await?;
let client = ClientLayer::new(operation).layer(client);
Ok(client)
}

View File

@ -31,14 +31,17 @@ mod layers;
mod service; mod service;
#[cfg(feature = "client")] #[cfg(feature = "client")]
pub use self::client::{client, make_traced_connector, make_untraced_client, ClientInitError}; pub use self::client::{
make_traced_client, make_traced_connector, make_untraced_client, ClientInitError, TracedClient,
TracedConnector, UntracedClient, UntracedConnector,
};
pub use self::{ pub use self::{
ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt}, ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
layers::{ layers::{
body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer}, body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer},
bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer}, bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer},
catch_http_codes::{self, CatchHttpCodes, CatchHttpCodesLayer}, catch_http_codes::{self, CatchHttpCodes, CatchHttpCodesLayer},
client::ClientLayer, client::{ClientLayer, ClientService},
form_urlencoded_request::{self, FormUrlencodedRequest, FormUrlencodedRequestLayer}, form_urlencoded_request::{self, FormUrlencodedRequest, FormUrlencodedRequestLayer},
json_request::{self, JsonRequest, JsonRequestLayer}, json_request::{self, JsonRequest, JsonRequestLayer},
json_response::{self, JsonResponse, JsonResponseLayer}, json_response::{self, JsonResponse, JsonResponseLayer},
@ -48,6 +51,4 @@ pub use self::{
service::{BoxCloneSyncService, HttpService}, service::{BoxCloneSyncService, HttpService},
}; };
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub type EmptyBody = http_body::Empty<bytes::Bytes>; pub type EmptyBody = http_body::Empty<bytes::Bytes>;