1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Make the HTTP client factory reuse the underlying client

This avoids duplicating clients, and makes it so that they all share the same connection pool.
This commit is contained in:
Quentin Gliech
2023-09-14 14:22:49 +02:00
parent f29e4adcfa
commit 54071c4969
15 changed files with 146 additions and 77 deletions

View File

@@ -191,8 +191,7 @@ async fn fetch_jwks(
.unwrap();
let mut client = http_client_factory
.client()
.await?
.client("client.fetch_jwks")
.response_body_to_bytes()
.json_response::<PublicJsonWebKeySet>();

View File

@@ -12,14 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use axum::body::Full;
use mas_http::{
BodyToBytesResponseLayer, ClientInitError, ClientLayer, ClientService, HttpService,
TracedClient,
make_traced_connector, BodyToBytesResponseLayer, Client, ClientInitError, ClientLayer,
ClientService, HttpService, TracedClient, TracedConnector,
};
use tokio::sync::Semaphore;
use tower::{
util::{MapErrLayer, MapRequestLayer},
BoxError, Layer,
@@ -27,15 +24,16 @@ use tower::{
#[derive(Debug, Clone)]
pub struct HttpClientFactory {
semaphore: Arc<Semaphore>,
traced_connector: TracedConnector,
client_layer: ClientLayer,
}
impl HttpClientFactory {
#[must_use]
pub fn new(concurrency_limit: usize) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(concurrency_limit)),
}
pub async fn new() -> Result<Self, ClientInitError> {
Ok(Self {
traced_connector: make_traced_connector().await?,
client_layer: ClientLayer::new(),
})
}
/// Constructs a new HTTP client
@@ -43,14 +41,16 @@ impl HttpClientFactory {
/// # Errors
///
/// Returns an error if the client failed to initialise
pub async fn client<B>(&self) -> Result<ClientService<TracedClient<B>>, ClientInitError>
pub fn client<B>(&self, category: &'static str) -> ClientService<TracedClient<B>>
where
B: axum::body::HttpBody + Send,
B::Data: Send,
{
let client = mas_http::make_traced_client::<B>().await?;
let layer = ClientLayer::with_semaphore(self.semaphore.clone());
Ok(layer.layer(client))
let client = Client::builder().build(self.traced_connector.clone());
self.client_layer
.clone()
.with_category(category)
.layer(client)
}
/// Constructs a new [`HttpService`], suitable for `mas-oidc-client`
@@ -58,8 +58,8 @@ impl HttpClientFactory {
/// # Errors
///
/// Returns an error if the client failed to initialise
pub async fn http_service(&self) -> Result<HttpService, ClientInitError> {
let client = self.client().await?;
pub fn http_service(&self, category: &'static str) -> HttpService {
let client = self.client(category);
let client = (
MapErrLayer::new(BoxError::from),
MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)),
@@ -67,6 +67,6 @@ impl HttpClientFactory {
)
.layer(client);
Ok(HttpService::new(client))
HttpService::new(client)
}
}