You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Use rustls-platform-verifier for cert validation
This simplifies by removing the mutually exclusive `native-roots` and `webpki-roots` features with something that is suitable for all platforms.
This commit is contained in:
@ -12,44 +12,38 @@ repository.workspace = true
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.6.20", optional = true }
|
||||
bytes = "1.5.0"
|
||||
futures-util = "0.3.30"
|
||||
headers = "0.3.9"
|
||||
http.workspace = true
|
||||
http-body = "0.4.5"
|
||||
hyper = "0.14.27"
|
||||
hyper-rustls = { version = "0.25.0", features = ["http1", "http2"], default-features = false, optional = true }
|
||||
once_cell = "1.19.0"
|
||||
hyper-rustls = { workspace = true, optional = true }
|
||||
opentelemetry.workspace = true
|
||||
rustls = { version = "0.22.2", optional = true }
|
||||
rustls-native-certs = { version = "0.7.0", optional = true }
|
||||
rustls = { workspace = true, optional = true }
|
||||
rustls-platform-verifier = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_urlencoded = "0.7.1"
|
||||
thiserror.workspace = true
|
||||
tokio = { version = "1.35.1", features = ["sync", "parking_lot"], optional = true }
|
||||
tower = { version = "0.4.13", features = ["util"] }
|
||||
tower.workspace = true
|
||||
tower-http = { version = "0.4.4", features = ["cors"] }
|
||||
tracing.workspace = true
|
||||
tracing-opentelemetry.workspace = true
|
||||
webpki-roots = { version = "0.26.0", optional = true }
|
||||
|
||||
mas-tower.workspace = true
|
||||
mas-tower = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow.workspace = true
|
||||
tokio = { version = "1.35.1", features = ["macros", "rt"] }
|
||||
|
||||
[features]
|
||||
axum = ["dep:axum"]
|
||||
native-roots = ["dep:rustls-native-certs"]
|
||||
webpki-roots = ["dep:webpki-roots"]
|
||||
client = [
|
||||
"dep:mas-tower",
|
||||
"dep:rustls",
|
||||
"hyper/tcp",
|
||||
"dep:hyper-rustls",
|
||||
"dep:tokio",
|
||||
"dep:rustls-platform-verifier",
|
||||
"tower/limit",
|
||||
"tower-http/timeout",
|
||||
"tower-http/follow-redirect",
|
||||
|
@ -12,8 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use hyper::client::{
|
||||
connect::dns::{GaiResolver, Name},
|
||||
HttpConnector,
|
||||
@ -24,143 +22,21 @@ use mas_tower::{
|
||||
DurationRecorderLayer, DurationRecorderService, FnWrapper, InFlightCounterLayer,
|
||||
InFlightCounterService, TraceLayer, TraceService,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tower::Layer;
|
||||
use tracing::Span;
|
||||
|
||||
#[cfg(all(not(feature = "webpki-roots"), not(feature = "native-roots")))]
|
||||
compile_error!("enabling the 'client' feature requires also enabling the 'webpki-roots' or the 'native-roots' features");
|
||||
|
||||
#[cfg(all(feature = "webpki-roots", feature = "native-roots"))]
|
||||
compile_error!("'webpki-roots' and 'native-roots' features are mutually exclusive");
|
||||
|
||||
#[cfg(feature = "native-roots")]
|
||||
static NATIVE_TLS_ROOTS: tokio::sync::OnceCell<rustls::RootCertStore> =
|
||||
tokio::sync::OnceCell::const_new();
|
||||
|
||||
#[cfg(feature = "native-roots")]
|
||||
fn load_tls_roots_blocking() -> Result<rustls::RootCertStore, NativeRootsLoadError> {
|
||||
let mut roots = rustls::RootCertStore::empty();
|
||||
let certs = rustls_native_certs::load_native_certs()?;
|
||||
for cert in certs {
|
||||
roots.add(cert)?;
|
||||
}
|
||||
|
||||
if roots.is_empty() {
|
||||
return Err(NativeRootsLoadError::Empty);
|
||||
}
|
||||
|
||||
Ok(roots)
|
||||
}
|
||||
|
||||
#[cfg(feature = "native-roots")]
|
||||
async fn tls_roots() -> Result<rustls::RootCertStore, NativeRootsInitError> {
|
||||
NATIVE_TLS_ROOTS
|
||||
.get_or_try_init(|| async move {
|
||||
// Load the TLS config once in a blocking task because loading the system
|
||||
// certificates can take a long time (~200ms) on macOS
|
||||
let span = tracing::info_span!("load_tls_roots");
|
||||
let roots = tokio::task::spawn_blocking(|| {
|
||||
let _span = span.entered();
|
||||
load_tls_roots_blocking()
|
||||
})
|
||||
.await??;
|
||||
Ok(roots)
|
||||
})
|
||||
.await
|
||||
.cloned()
|
||||
}
|
||||
|
||||
#[cfg(feature = "webpki-roots")]
|
||||
#[allow(clippy::unused_async)]
|
||||
async fn tls_roots() -> Result<rustls::RootCertStore, Infallible> {
|
||||
let root_store = rustls::RootCertStore {
|
||||
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
|
||||
};
|
||||
|
||||
Ok(root_store)
|
||||
}
|
||||
|
||||
#[cfg(feature = "native-roots")]
|
||||
#[derive(Error, Debug)]
|
||||
#[error(transparent)]
|
||||
pub enum NativeRootsInitError {
|
||||
RootsLoadError(#[from] NativeRootsLoadError),
|
||||
|
||||
JoinError(#[from] tokio::task::JoinError),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientInitError {
|
||||
#[cfg(feature = "native-roots")]
|
||||
#[error(transparent)]
|
||||
TlsRootsInit(std::sync::Arc<NativeRootsInitError>),
|
||||
}
|
||||
|
||||
#[cfg(feature = "native-roots")]
|
||||
impl From<NativeRootsInitError> for ClientInitError {
|
||||
fn from(inner: NativeRootsInitError) -> Self {
|
||||
Self::TlsRootsInit(std::sync::Arc::new(inner))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Infallible> for ClientInitError {
|
||||
fn from(e: Infallible) -> Self {
|
||||
match e {}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "native-roots")]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum NativeRootsLoadError {
|
||||
#[error("could not load root certificates")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("invalid root certificate")]
|
||||
Rustls(#[from] rustls::Error),
|
||||
|
||||
#[error("no root certificate loaded")]
|
||||
Empty,
|
||||
}
|
||||
|
||||
async fn make_tls_config() -> Result<rustls::ClientConfig, ClientInitError> {
|
||||
let roots = tls_roots().await?;
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.with_root_certificates(roots)
|
||||
.with_no_client_auth();
|
||||
|
||||
Ok(tls_config)
|
||||
}
|
||||
|
||||
pub type UntracedClient<B> = hyper::Client<UntracedConnector, B>;
|
||||
pub type TracedClient<B> = hyper::Client<TracedConnector, B>;
|
||||
|
||||
/// Create a basic Hyper HTTP & HTTPS client without any tracing
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if it failed to load the TLS certificates
|
||||
pub async fn make_untraced_client<B>() -> Result<UntracedClient<B>, ClientInitError>
|
||||
#[must_use]
|
||||
pub fn make_untraced_client<B>() -> UntracedClient<B>
|
||||
where
|
||||
B: http_body::Body + Send + 'static,
|
||||
B::Data: Send,
|
||||
{
|
||||
let https = make_untraced_connector().await?;
|
||||
Ok(Client::builder().build(https))
|
||||
}
|
||||
|
||||
/// Create a basic Hyper HTTP & HTTPS client which traces DNS requests
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if it failed to load the TLS certificates
|
||||
pub async fn make_traced_client<B>() -> Result<TracedClient<B>, ClientInitError>
|
||||
where
|
||||
B: http_body::Body + Send,
|
||||
B::Data: Send,
|
||||
{
|
||||
let https = make_traced_connector().await?;
|
||||
Ok(Client::builder().build(https))
|
||||
let https = make_untraced_connector();
|
||||
Client::builder().build(https)
|
||||
}
|
||||
|
||||
pub type TraceResolver<S> =
|
||||
@ -169,11 +45,8 @@ pub type UntracedConnector = HttpsConnector<HttpConnector<GaiResolver>>;
|
||||
pub type TracedConnector = HttpsConnector<HttpConnector<TraceResolver<GaiResolver>>>;
|
||||
|
||||
/// Create a traced HTTP and HTTPS connector
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if it failed to load the TLS certificates
|
||||
pub async fn make_traced_connector() -> Result<TracedConnector, ClientInitError>
|
||||
#[must_use]
|
||||
pub fn make_traced_connector() -> TracedConnector
|
||||
where
|
||||
{
|
||||
let in_flight_counter = InFlightCounterLayer::new("dns.resolve.active_requests");
|
||||
@ -190,16 +63,16 @@ where
|
||||
|
||||
let resolver = (in_flight_counter, duration_recorder, trace_layer).layer(GaiResolver::new());
|
||||
|
||||
let tls_config = make_tls_config().await?;
|
||||
Ok(make_connector(resolver, tls_config))
|
||||
let tls_config = rustls_platform_verifier::tls_config();
|
||||
make_connector(resolver, tls_config)
|
||||
}
|
||||
|
||||
async fn make_untraced_connector() -> Result<UntracedConnector, ClientInitError>
|
||||
fn make_untraced_connector() -> UntracedConnector
|
||||
where
|
||||
{
|
||||
let resolver = GaiResolver::new();
|
||||
let tls_config = make_tls_config().await?;
|
||||
Ok(make_connector(resolver, tls_config))
|
||||
let tls_config = rustls_platform_verifier::tls_config();
|
||||
make_connector(resolver, tls_config)
|
||||
}
|
||||
|
||||
fn make_connector<R>(
|
||||
|
@ -12,10 +12,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::ops::RangeBounds;
|
||||
use std::{ops::RangeBounds, sync::OnceLock};
|
||||
|
||||
use http::{header::HeaderName, Request, StatusCode};
|
||||
use once_cell::sync::OnceCell;
|
||||
use tower::Service;
|
||||
use tower_http::cors::CorsLayer;
|
||||
|
||||
@ -25,7 +24,7 @@ use crate::layers::{
|
||||
json_request::JsonRequest, json_response::JsonResponse,
|
||||
};
|
||||
|
||||
static PROPAGATOR_HEADERS: OnceCell<Vec<HeaderName>> = OnceCell::new();
|
||||
static PROPAGATOR_HEADERS: OnceLock<Vec<HeaderName>> = OnceLock::new();
|
||||
|
||||
/// Notify the CORS layer what opentelemetry propagators are being used. This
|
||||
/// helps whitelisting headers in CORS requests.
|
||||
|
@ -26,8 +26,8 @@ mod service;
|
||||
#[cfg(feature = "client")]
|
||||
pub use self::{
|
||||
client::{
|
||||
make_traced_client, make_traced_connector, make_untraced_client, Client, ClientInitError,
|
||||
TracedClient, TracedConnector, UntracedClient, UntracedConnector,
|
||||
make_traced_connector, make_untraced_client, Client, TracedClient, TracedConnector,
|
||||
UntracedClient, UntracedConnector,
|
||||
},
|
||||
layers::client::{ClientLayer, ClientService},
|
||||
};
|
||||
|
Reference in New Issue
Block a user