1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-21 23:00:50 +03:00

Use rustls-platform-verifier for cert validation

This simplifies by removing the mutually exclusive `native-roots` and
`webpki-roots` features with something that is suitable for all
platforms.
This commit is contained in:
Quentin Gliech
2024-03-06 11:23:42 +01:00
parent 58d91f91d2
commit 6eb6209bd8
25 changed files with 173 additions and 258 deletions

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::Infallible;
use hyper::client::{
connect::dns::{GaiResolver, Name},
HttpConnector,
@@ -24,143 +22,21 @@ use mas_tower::{
DurationRecorderLayer, DurationRecorderService, FnWrapper, InFlightCounterLayer,
InFlightCounterService, TraceLayer, TraceService,
};
use thiserror::Error;
use tower::Layer;
use tracing::Span;
#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))]
compile_error!("enabling the 'client' feature requires also enabling the 'webpki-roots' or the 'native-roots' features");
#[cfg(all(feature = "webpki-roots", feature = "native-roots"))]
compile_error!("'webpki-roots' and 'native-roots' features are mutually exclusive");
#[cfg(feature = "native-roots")]
static NATIVE_TLS_ROOTS: tokio::sync::OnceCell<rustls::RootCertStore> =
tokio::sync::OnceCell::const_new();
#[cfg(feature = "native-roots")]
fn load_tls_roots_blocking() -> Result<rustls::RootCertStore, NativeRootsLoadError> {
let mut roots = rustls::RootCertStore::empty();
let certs = rustls_native_certs::load_native_certs()?;
for cert in certs {
roots.add(cert)?;
}
if roots.is_empty() {
return Err(NativeRootsLoadError::Empty);
}
Ok(roots)
}
#[cfg(feature = "native-roots")]
async fn tls_roots() -> Result<rustls::RootCertStore, NativeRootsInitError> {
NATIVE_TLS_ROOTS
.get_or_try_init(|| async move {
// Load the TLS config once in a blocking task because loading the system
// certificates can take a long time (~200ms) on macOS
let span = tracing::info_span!("load_tls_roots");
let roots = tokio::task::spawn_blocking(|| {
let _span = span.entered();
load_tls_roots_blocking()
})
.await??;
Ok(roots)
})
.await
.cloned()
}
#[cfg(feature = "webpki-roots")]
#[allow(clippy::unused_async)]
async fn tls_roots() -> Result<rustls::RootCertStore, Infallible> {
let root_store = rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
};
Ok(root_store)
}
#[cfg(feature = "native-roots")]
#[derive(Error, Debug)]
#[error(transparent)]
pub enum NativeRootsInitError {
RootsLoadError(#[from] NativeRootsLoadError),
JoinError(#[from] tokio::task::JoinError),
}
#[derive(Error, Debug, Clone)]
pub enum ClientInitError {
#[cfg(feature = "native-roots")]
#[error(transparent)]
TlsRootsInit(std::sync::Arc<NativeRootsInitError>),
}
#[cfg(feature = "native-roots")]
impl From<NativeRootsInitError> for ClientInitError {
fn from(inner: NativeRootsInitError) -> Self {
Self::TlsRootsInit(std::sync::Arc::new(inner))
}
}
impl From<Infallible> for ClientInitError {
fn from(e: Infallible) -> Self {
match e {}
}
}
#[cfg(feature = "native-roots")]
#[derive(Error, Debug)]
pub enum NativeRootsLoadError {
#[error("could not load root certificates")]
Io(#[from] std::io::Error),
#[error("invalid root certificate")]
Rustls(#[from] rustls::Error),
#[error("no root certificate loaded")]
Empty,
}
async fn make_tls_config() -> Result<rustls::ClientConfig, ClientInitError> {
let roots = tls_roots().await?;
let tls_config = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
Ok(tls_config)
}
pub type UntracedClient<B> = hyper::Client<UntracedConnector, B>;
pub type TracedClient<B> = hyper::Client<TracedConnector, B>;
/// Create a basic Hyper HTTP & HTTPS client without any tracing
///
/// # Errors
///
/// Returns an error if it failed to load the TLS certificates
pub async fn make_untraced_client<B>() -> Result<UntracedClient<B>, ClientInitError>
#[must_use]
pub fn make_untraced_client<B>() -> UntracedClient<B>
where
B: http_body::Body + Send + 'static,
B::Data: Send,
{
let https = make_untraced_connector().await?;
Ok(Client::builder().build(https))
}
/// Create a basic Hyper HTTP & HTTPS client which traces DNS requests
///
/// # Errors
///
/// Returns an error if it failed to load the TLS certificates
pub async fn make_traced_client<B>() -> Result<TracedClient<B>, ClientInitError>
where
B: http_body::Body + Send,
B::Data: Send,
{
let https = make_traced_connector().await?;
Ok(Client::builder().build(https))
let https = make_untraced_connector();
Client::builder().build(https)
}
pub type TraceResolver<S> =
@@ -169,11 +45,8 @@ pub type UntracedConnector = HttpsConnector<HttpConnector<GaiResolver>>;
pub type TracedConnector = HttpsConnector<HttpConnector<TraceResolver<GaiResolver>>>;
/// Create a traced HTTP and HTTPS connector
///
/// # Errors
///
/// Returns an error if it failed to load the TLS certificates
pub async fn make_traced_connector() -> Result<TracedConnector, ClientInitError>
#[must_use]
pub fn make_traced_connector() -> TracedConnector
where
{
let in_flight_counter = InFlightCounterLayer::new("dns.resolve.active_requests");
@@ -190,16 +63,16 @@ where
let resolver = (in_flight_counter, duration_recorder, trace_layer).layer(GaiResolver::new());
let tls_config = make_tls_config().await?;
Ok(make_connector(resolver, tls_config))
let tls_config = rustls_platform_verifier::tls_config();
make_connector(resolver, tls_config)
}
async fn make_untraced_connector() -> Result<UntracedConnector, ClientInitError>
fn make_untraced_connector() -> UntracedConnector
where
{
let resolver = GaiResolver::new();
let tls_config = make_tls_config().await?;
Ok(make_connector(resolver, tls_config))
let tls_config = rustls_platform_verifier::tls_config();
make_connector(resolver, tls_config)
}
fn make_connector<R>(

View File

@@ -12,10 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::RangeBounds;
use std::{ops::RangeBounds, sync::OnceLock};
use http::{header::HeaderName, Request, StatusCode};
use once_cell::sync::OnceCell;
use tower::Service;
use tower_http::cors::CorsLayer;
@@ -25,7 +24,7 @@ use crate::layers::{
json_request::JsonRequest, json_response::JsonResponse,
};
static PROPAGATOR_HEADERS: OnceCell<Vec<HeaderName>> = OnceCell::new();
static PROPAGATOR_HEADERS: OnceLock<Vec<HeaderName>> = OnceLock::new();
/// Notify the CORS layer what opentelemetry propagators are being used. This
/// helps whitelisting headers in CORS requests.

View File

@@ -26,8 +26,8 @@ mod service;
#[cfg(feature = "client")]
pub use self::{
client::{
make_traced_client, make_traced_connector, make_untraced_client, Client, ClientInitError,
TracedClient, TracedConnector, UntracedClient, UntracedConnector,
make_traced_connector, make_untraced_client, Client, TracedClient, TracedConnector,
UntracedClient, UntracedConnector,
},
layers::client::{ClientLayer, ClientService},
};