1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-23 11:02:35 +03:00

Cleanup HTTP client building

This commit is contained in:
Quentin Gliech
2022-11-23 12:26:58 +01:00
parent 16088fc11c
commit d514a8922c
4 changed files with 71 additions and 72 deletions

View File

@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{marker::PhantomData, time::Duration};
use std::{sync::Arc, time::Duration};
use http::{header::USER_AGENT, HeaderValue, Request, Response};
use http::{header::USER_AGENT, HeaderValue};
use tokio::sync::Semaphore;
use tower::{
limit::{ConcurrencyLimit, ConcurrencyLimitLayer},
Layer, Service,
limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer},
Layer,
};
use tower_http::{
follow_redirect::{FollowRedirect, FollowRedirectLayer},
@@ -26,56 +27,60 @@ use tower_http::{
};
use super::otel::TraceLayer;
use crate::{otel::TraceHttpClient, BoxError};
use crate::otel::{TraceHttpClient, TraceHttpClientLayer};
static MAS_USER_AGENT: HeaderValue =
HeaderValue::from_static("matrix-authentication-service/0.0.1");
pub type ClientService<S> = SetRequestHeader<
TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
HeaderValue,
>;
#[derive(Debug, Clone)]
pub struct ClientLayer<ReqBody> {
operation: &'static str,
_t: PhantomData<ReqBody>,
pub struct ClientLayer {
user_agent_layer: SetRequestHeaderLayer<HeaderValue>,
outer_trace_layer: TraceHttpClientLayer,
concurrency_limit_layer: GlobalConcurrencyLimitLayer,
follow_redirect_layer: FollowRedirectLayer,
inner_trace_layer: TraceHttpClientLayer,
timeout_layer: TimeoutLayer,
}
impl<B> ClientLayer<B> {
impl ClientLayer {
#[must_use]
pub fn new(operation: &'static str) -> Self {
let semaphore = Arc::new(Semaphore::new(10));
Self::with_semaphore(operation, semaphore)
}
#[must_use]
pub fn with_semaphore(operation: &'static str, semaphore: Arc<Semaphore>) -> Self {
Self {
operation,
_t: PhantomData,
user_agent_layer: SetRequestHeaderLayer::overriding(
USER_AGENT,
HeaderValue::from_static("matrix-authentication-service/0.0.1"),
),
outer_trace_layer: TraceLayer::http_client(operation),
concurrency_limit_layer: GlobalConcurrencyLimitLayer::with_semaphore(semaphore),
follow_redirect_layer: FollowRedirectLayer::new(),
inner_trace_layer: TraceLayer::inner_http_client(),
timeout_layer: TimeoutLayer::new(Duration::from_secs(10)),
}
}
}
impl<ReqBody, ResBody, S, E> Layer<S> for ClientLayer<ReqBody>
impl<S> Layer<S> for ClientLayer
where
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E>
+ Clone
+ Send
+ Sync
+ 'static,
ReqBody: http_body::Body + Default + Send + 'static,
ResBody: http_body::Body + Sync + Send + 'static,
S::Future: Send + 'static,
E: Into<BoxError>,
S: Clone,
{
type Service = SetRequestHeader<
TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
HeaderValue,
>;
type Service = ClientService<S>;
fn layer(&self, inner: S) -> Self::Service {
// Note that all layers here just forward the error type.
(
SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()),
// A trace that has the whole operation, with all the redirects, timeouts and rate
// limits in it
TraceLayer::http_client(self.operation),
ConcurrencyLimitLayer::new(10),
FollowRedirectLayer::new(),
// A trace for each "real" http request
TraceLayer::inner_http_client(),
TimeoutLayer::new(Duration::from_secs(10)),
&self.user_agent_layer,
&self.outer_trace_layer,
&self.concurrency_limit_layer,
&self.follow_redirect_layer,
&self.inner_trace_layer,
&self.timeout_layer,
)
.layer(inner)
}