diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 6710dd35..706f21ee 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -41,6 +41,8 @@ use sqlx::PgExecutor; use thiserror::Error; use tower::{Service, ServiceExt}; +use crate::http_client_factory::HttpClientFactory; + static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; #[derive(Deserialize)] @@ -91,6 +93,7 @@ impl Credentials { #[tracing::instrument(skip_all, err)] pub async fn verify( &self, + http_client_factory: &HttpClientFactory, encrypter: &Encrypter, method: &OAuthClientAuthenticationMethod, client: &Client, @@ -132,7 +135,7 @@ impl Credentials { .as_ref() .ok_or(CredentialsVerificationError::InvalidClientConfig)?; - let jwks = fetch_jwks(jwks) + let jwks = fetch_jwks(http_client_factory, jwks) .await .map_err(|_| CredentialsVerificationError::JwksFetchFailed)?; @@ -166,7 +169,10 @@ impl Credentials { } } -async fn fetch_jwks(jwks: &JwksOrJwksUri) -> Result { +async fn fetch_jwks( + http_client_factory: &HttpClientFactory, + jwks: &JwksOrJwksUri, +) -> Result { let uri = match jwks { JwksOrJwksUri::Jwks(j) => return Ok(j.clone()), JwksOrJwksUri::JwksUri(u) => u, @@ -177,7 +183,8 @@ async fn fetch_jwks(jwks: &JwksOrJwksUri) -> Result(); diff --git a/crates/axum-utils/src/http_client_factory.rs b/crates/axum-utils/src/http_client_factory.rs new file mode 100644 index 00000000..25f5b155 --- /dev/null +++ b/crates/axum-utils/src/http_client_factory.rs @@ -0,0 +1,78 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use axum::body::Full; +use mas_http::{ + BodyToBytesResponseLayer, ClientInitError, ClientLayer, ClientService, HttpService, + TracedClient, +}; +use tokio::sync::Semaphore; +use tower::{ + util::{MapErrLayer, MapRequestLayer}, + BoxError, Layer, +}; + +#[derive(Debug, Clone)] +pub struct HttpClientFactory { + semaphore: Arc, +} + +impl HttpClientFactory { + #[must_use] + pub fn new(concurrency_limit: usize) -> Self { + Self { + semaphore: Arc::new(Semaphore::new(concurrency_limit)), + } + } + + /// Constructs a new HTTP client + /// + /// # Errors + /// + /// Returns an error if the client failed to initialise + pub async fn client( + &self, + operation: &'static str, + ) -> Result>, ClientInitError> + where + B: axum::body::HttpBody + Send + Sync + 'static, + B::Data: Send, + { + let client = mas_http::make_traced_client::().await?; + let layer = ClientLayer::with_semaphore(operation, self.semaphore.clone()); + Ok(layer.layer(client)) + } + + /// Constructs a new [`HttpService`], suitable for [`mas_oidc_client`] + /// + /// # Errors + /// + /// Returns an error if the client failed to initialise + pub async fn http_service( + &self, + operation: &'static str, + ) -> Result { + let client = self.client(operation).await?; + let client = ( + MapErrLayer::new(BoxError::from), + MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)), + BodyToBytesResponseLayer::default(), + ) + .layer(client); + + Ok(HttpService::new(client)) + } +} diff --git a/crates/axum-utils/src/lib.rs b/crates/axum-utils/src/lib.rs index 4897e0ba..d03fe4db 100644 --- a/crates/axum-utils/src/lib.rs +++ b/crates/axum-utils/src/lib.rs @@ -26,6 +26,7 @@ pub mod client_authorization; pub mod cookies; pub mod csrf; pub mod fancy_error; +pub mod http_client_factory; pub mod jwt; pub mod session; pub mod user_authorization; diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 00007d86..f15123d1 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -16,6 +16,7 @@ use anyhow::Context; use clap::Parser; use hyper::{Response, Uri}; use mas_config::PolicyConfig; +use mas_handlers::HttpClientFactory; use mas_http::HttpServiceExt; use mas_policy::PolicyFactory; use tokio::io::AsyncWriteExt; @@ -66,13 +67,14 @@ impl Options { #[tracing::instrument(skip_all)] pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { use Subcommand as SC; + let http_client_factory = HttpClientFactory::new(10); match &self.subcommand { SC::Http { show_headers, json: false, url, } => { - let mut client = mas_http::client("cli-debug-http").await?; + let mut client = http_client_factory.client("cli-debug-http").await?; let request = hyper::Request::builder() .uri(url) .body(hyper::Body::empty())?; @@ -96,7 +98,8 @@ impl Options { json: true, url, } => { - let mut client = mas_http::client("cli-debug-http") + let mut client = http_client_factory + .client("cli-debug-http") .await? .response_body_to_bytes() .json_response(); diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index bc042de5..cf4dd5fc 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -20,7 +20,7 @@ use futures_util::stream::{StreamExt, TryStreamExt}; use itertools::Itertools; use mas_config::RootConfig; use mas_email::Mailer; -use mas_handlers::{AppState, MatrixHomeserver}; +use mas_handlers::{AppState, HttpClientFactory, MatrixHomeserver}; use mas_http::ServerLayer; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_policy::PolicyFactory; @@ -187,6 +187,9 @@ impl Options { let graphql_schema = mas_handlers::graphql_schema(&pool); + // Maximum 50 outgoing HTTP requests at a time + let http_client_factory = HttpClientFactory::new(50); + let state = AppState { pool, templates, @@ -197,6 +200,7 @@ impl Options { homeserver, policy_factory, graphql_schema, + http_client_factory, }; let mut fd_manager = listenfd::ListenFd::from_env(); diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index fb8a149d..daea4ae0 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -15,6 +15,7 @@ use std::sync::Arc; use axum::extract::FromRef; +use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_email::Mailer; use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; @@ -35,6 +36,7 @@ pub struct AppState { pub homeserver: MatrixHomeserver, pub policy_factory: Arc, pub graphql_schema: mas_graphql::Schema, + pub http_client_factory: HttpClientFactory, } impl FromRef for PgPool { @@ -90,3 +92,8 @@ impl FromRef for Arc { input.policy_factory.clone() } } +impl FromRef for HttpClientFactory { + fn from_ref(input: &AppState) -> Self { + input.http_client_factory.clone() + } +} diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 974d849b..36d993bb 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -55,9 +55,9 @@ mod oauth2; mod upstream_oauth2; mod views; -pub use compat::MatrixHomeserver; +pub use mas_axum_utils::http_client_factory::HttpClientFactory; -pub use self::{app_state::AppState, graphql::schema as graphql_schema}; +pub use self::{app_state::AppState, compat::MatrixHomeserver, graphql::schema as graphql_schema}; #[must_use] pub fn healthcheck_router() -> Router @@ -138,6 +138,7 @@ where Arc: FromRef, PgPool: FromRef, Encrypter: FromRef, + HttpClientFactory: FromRef, { // All those routes are API-like, with a common CORS layer Router::new() @@ -235,6 +236,7 @@ where Templates: FromRef, Mailer: FromRef, Keystore: FromRef, + HttpClientFactory: FromRef, { Router::new() .route( @@ -363,6 +365,8 @@ async fn test_state(pool: PgPool) -> Result { let graphql_schema = graphql_schema(&pool); + let http_client_factory = HttpClientFactory::new(10); + Ok(AppState { pool, templates, @@ -373,6 +377,7 @@ async fn test_state(pool: PgPool) -> Result { homeserver, policy_factory, graphql_schema, + http_client_factory, }) } diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 75f898a3..2a122ec6 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -14,7 +14,10 @@ use axum::{extract::State, response::IntoResponse, Json}; use hyper::StatusCode; -use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerificationError}; +use mas_axum_utils::{ + client_authorization::{ClientAuthorization, CredentialsVerificationError}, + http_client_factory::HttpClientFactory, +}; use mas_data_model::{TokenFormatError, TokenType}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_keystore::Encrypter; @@ -155,6 +158,7 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse { #[allow(clippy::too_many_lines)] pub(crate) async fn post( + State(http_client_factory): State, State(pool): State, State(encrypter): State, client_authorization: ClientAuthorization, @@ -173,7 +177,7 @@ pub(crate) async fn post( client_authorization .credentials - .verify(&encrypter, method, &client) + .verify(&http_client_factory, &encrypter, method, &client) .await?; let form = if let Some(form) = client_authorization.form { diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index bf37121d..725e0f04 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -19,7 +19,10 @@ use axum::{extract::State, response::IntoResponse, Json}; use chrono::{DateTime, Duration, Utc}; use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma}; use hyper::StatusCode; -use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerificationError}; +use mas_axum_utils::{ + client_authorization::{ClientAuthorization, CredentialsVerificationError}, + http_client_factory::HttpClientFactory, +}; use mas_data_model::{AuthorizationGrantStage, Client, TokenType}; use mas_iana::jose::JsonWebSignatureAlg; use mas_jose::{ @@ -191,6 +194,7 @@ impl From for RouteError { #[tracing::instrument(skip_all, err)] pub(crate) async fn post( + State(http_client_factory): State, State(key_store): State, State(url_builder): State, State(pool): State, @@ -208,7 +212,7 @@ pub(crate) async fn post( client_authorization .credentials - .verify(&encrypter, method, &client) + .verify(&http_client_factory, &encrypter, method, &client) .await?; let form = client_authorization.form.ok_or(RouteError::BadRequest)?; diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index f473582b..c6b94049 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -18,6 +18,7 @@ use axum::{ }; use axum_extra::extract::{cookie::Cookie, PrivateCookieJar}; use hyper::StatusCode; +use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_http::ClientInitError; use mas_keystore::Encrypter; use mas_oidc_client::{ @@ -30,8 +31,6 @@ use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; -use super::http_service; - #[derive(Debug, Error)] pub(crate) enum RouteError { #[error("Provider not found")] @@ -89,6 +88,7 @@ impl IntoResponse for RouteError { } pub(crate) async fn get( + State(http_client_factory): State, State(pool): State, State(url_builder): State, cookie_jar: PrivateCookieJar, @@ -103,7 +103,9 @@ pub(crate) async fn get( .to_option()? .ok_or(RouteError::ProviderNotFound)?; - let http_service = http_service("upstream-discover").await?; + let http_service = http_client_factory + .http_service("upstream-discover") + .await?; // First, discover the provider let metadata = diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 11e6b2e9..632a150b 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -19,6 +19,7 @@ use axum::{ }; use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; +use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_http::ClientInitError; use mas_keystore::{Encrypter, Keystore}; use mas_oidc_client::{ @@ -33,7 +34,7 @@ use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; -use super::{client_credentials_for_provider, http_service, ProviderCredentialsError}; +use super::{client_credentials_for_provider, ProviderCredentialsError}; #[derive(Deserialize)] pub struct QueryParams { @@ -144,8 +145,9 @@ impl IntoResponse for RouteError { } } -#[allow(clippy::too_many_lines)] +#[allow(clippy::too_many_lines, clippy::too_many_arguments)] pub(crate) async fn get( + State(http_client_factory): State, State(pool): State, State(url_builder): State, State(encrypter): State, @@ -195,13 +197,19 @@ pub(crate) async fn get( CodeOrError::Code { code } => code, }; - let http_service = http_service("upstream-code-exchange").await?; + let http_service = http_client_factory + .http_service("upstream-discover") + .await?; // XXX: we shouldn't discover on-the-fly // Discover the provider let metadata = mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?; + let http_service = http_client_factory + .http_service("upstream-fetch-jwks") + .await?; + // Fetch the JWKS let jwks = mas_oidc_client::requests::jose::fetch_jwks(&http_service, metadata.jwks_uri()).await?; @@ -231,6 +239,10 @@ pub(crate) async fn get( client_id: &provider.client_id, }; + let http_service = http_client_factory + .http_service("upstream-exchange-code") + .await?; + let (response, _id_token) = mas_oidc_client::requests::authorization_code::access_token_with_authorization_code( &http_service, diff --git a/crates/handlers/src/upstream_oauth2/mod.rs b/crates/handlers/src/upstream_oauth2/mod.rs index 9147f6ee..4cb889bc 100644 --- a/crates/handlers/src/upstream_oauth2/mod.rs +++ b/crates/handlers/src/upstream_oauth2/mod.rs @@ -13,17 +13,11 @@ // limitations under the License. use anyhow::Context; -use axum::body::Full; use mas_data_model::UpstreamOAuthProvider; -use mas_http::{BodyToBytesResponseLayer, ClientInitError, ClientLayer, HttpService}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_keystore::{Encrypter, Keystore}; use mas_oidc_client::types::client_credentials::{ClientCredentials, JwtSigningMethod}; use thiserror::Error; -use tower::{ - util::{MapErrLayer, MapRequestLayer}, - BoxError, Layer, -}; use url::Url; pub(crate) mod authorize; @@ -101,15 +95,3 @@ fn client_credentials_for_provider( Ok(client_credentials) } - -async fn http_service(operation: &'static str) -> Result { - let client = ( - MapErrLayer::new(BoxError::from), - MapRequestLayer::new(|req: hyper::Request<_>| req.map(Full::new)), - BodyToBytesResponseLayer::default(), - ClientLayer::new(operation), - ) - .layer(mas_http::make_untraced_client().await?); - - Ok(HttpService::new(client)) -} diff --git a/crates/http/src/client.rs b/crates/http/src/client.rs index c20dbb21..19d9d726 100644 --- a/crates/http/src/client.rs +++ b/crates/http/src/client.rs @@ -22,13 +22,7 @@ use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use thiserror::Error; use tower::Layer; -use crate::{ - layers::{ - client::{ClientLayer, ClientService}, - otel::{TraceDns, TraceLayer}, - }, - BoxError, -}; +use crate::layers::otel::{TraceDns, TraceLayer}; #[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))] compile_error!("enabling the 'client' feature requires also enabling the 'webpki-roots' or the 'native-roots' features"); @@ -97,15 +91,6 @@ pub enum NativeRootsInitError { JoinError(#[from] tokio::task::JoinError), } -/// A wrapper over a boxed error that implements ``std::error::Error``. -/// This is helps converting to ``anyhow::Error`` with the `?` operator -#[derive(Error, Debug)] -#[error(transparent)] -pub struct ClientError { - #[from] - inner: BoxError, -} - #[derive(Error, Debug, Clone)] pub enum ClientInitError { #[cfg(feature = "native-roots")] @@ -121,8 +106,8 @@ impl From for ClientInitError { } impl From for ClientInitError { - fn from(_: Infallible) -> Self { - unreachable!() + fn from(e: Infallible) -> Self { + match e {} } } @@ -149,8 +134,8 @@ async fn make_tls_config() -> Result { Ok(tls_config) } -type UntracedClient = hyper::Client; -type TracedClient = hyper::Client; +pub type UntracedClient = hyper::Client; +pub type TracedClient = hyper::Client; /// Create a basic Hyper HTTP & HTTPS client without any tracing /// @@ -166,7 +151,12 @@ where Ok(Client::builder().build(https)) } -async fn make_traced_client() -> Result, ClientInitError> +/// Create a basic Hyper HTTP & HTTPS client which traces DNS requests +/// +/// # Errors +/// +/// Returns an error if it failed to load the TLS certificates +pub async fn make_traced_client() -> Result, ClientInitError> where B: http_body::Body + Send + 'static, B::Data: Send, @@ -175,8 +165,8 @@ where Ok(Client::builder().build(https)) } -type UntracedConnector = HttpsConnector>; -type TracedConnector = HttpsConnector>>; +pub type UntracedConnector = HttpsConnector>; +pub type TracedConnector = HttpsConnector>>; /// Create a traced HTTP and HTTPS connector /// @@ -214,23 +204,3 @@ fn make_connector( .enable_http2() .wrap_connector(http) } - -/// Create a traced HTTP client, with a default timeout, which follows redirects -/// -/// # Errors -/// -/// Returns an error if it failed to initialize -pub async fn client( - operation: &'static str, -) -> Result>, ClientInitError> -where - B: http_body::Body + Default + Send + 'static, - B::Data: Send, - B::Error: Into, -{ - let client = make_traced_client().await?; - - let client = ClientLayer::new(operation).layer(client); - - Ok(client) -} diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index 74165ced..d343f516 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -31,14 +31,17 @@ mod layers; mod service; #[cfg(feature = "client")] -pub use self::client::{client, make_traced_connector, make_untraced_client, ClientInitError}; +pub use self::client::{ + make_traced_client, make_traced_connector, make_untraced_client, ClientInitError, TracedClient, + TracedConnector, UntracedClient, UntracedConnector, +}; pub use self::{ ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt}, layers::{ body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer}, bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer}, catch_http_codes::{self, CatchHttpCodes, CatchHttpCodesLayer}, - client::ClientLayer, + client::{ClientLayer, ClientService}, form_urlencoded_request::{self, FormUrlencodedRequest, FormUrlencodedRequestLayer}, json_request::{self, JsonRequest, JsonRequestLayer}, json_response::{self, JsonResponse, JsonResponseLayer}, @@ -48,6 +51,4 @@ pub use self::{ service::{BoxCloneSyncService, HttpService}, }; -pub(crate) type BoxError = Box; - pub type EmptyBody = http_body::Empty;