1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Trace AWS operations & share TLS connector with mas-http

This commit is contained in:
Quentin Gliech
2022-11-03 17:45:49 +01:00
parent b5fd54bbf4
commit a414936484
9 changed files with 125 additions and 50 deletions

9
Cargo.lock generated
View File

@ -390,8 +390,6 @@ dependencies = [
"http", "http",
"http-body", "http-body",
"hyper", "hyper",
"hyper-rustls",
"lazy_static",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
"tower", "tower",
@ -415,8 +413,6 @@ dependencies = [
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"pin-utils", "pin-utils",
"tokio",
"tokio-util 0.7.4",
"tracing", "tracing",
] ]
@ -1980,9 +1976,7 @@ checksum = "d87c48c02e0dc5e3b849a2041db3029fd066650f8f717c07bf8ed78ccb895cac"
dependencies = [ dependencies = [
"http", "http",
"hyper", "hyper",
"log",
"rustls", "rustls",
"rustls-native-certs",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
] ]
@ -2486,7 +2480,10 @@ dependencies = [
"async-trait", "async-trait",
"aws-config", "aws-config",
"aws-sdk-sesv2", "aws-sdk-sesv2",
"aws-smithy-async",
"aws-smithy-client",
"lettre", "lettre",
"mas-http",
"mas-templates", "mas-templates",
"tokio", "tokio",
"tracing", "tracing",

View File

@ -195,7 +195,7 @@ impl EmailTransportConfig {
.context("failed to build SMTP transport") .context("failed to build SMTP transport")
} }
EmailTransportConfig::Sendmail { command } => Ok(MailTransport::sendmail(command)), EmailTransportConfig::Sendmail { command } => Ok(MailTransport::sendmail(command)),
EmailTransportConfig::AwsSes => Ok(MailTransport::aws_ses().await), EmailTransportConfig::AwsSes => Ok(MailTransport::aws_ses().await?),
} }
} }
} }

View File

@ -10,10 +10,14 @@ anyhow = "1.0.66"
async-trait = "0.1.58" async-trait = "0.1.58"
tokio = { version = "1.21.2", features = ["macros"] } tokio = { version = "1.21.2", features = ["macros"] }
tracing = "0.1.37" tracing = "0.1.37"
aws-sdk-sesv2 = "0.21.0"
aws-config = "0.51.0" aws-sdk-sesv2 = { version = "0.21.0", default-features = false }
aws-config = { version = "0.51.0", default-features = false }
aws-smithy-client = { version = "0.51.0", default-features = false, features = ["client-hyper"] }
aws-smithy-async = { version = "0.51.0", default-features = false, features = ["rt-tokio"] }
mas-templates = { path = "../templates" } mas-templates = { path = "../templates" }
mas-http = { path = "../http" }
[dependencies.lettre] [dependencies.lettre]
version = "0.10.1" version = "0.10.1"

View File

@ -12,13 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use aws_config::provider_config::ProviderConfig;
use aws_sdk_sesv2::{ use aws_sdk_sesv2::{
middleware::DefaultMiddleware,
model::{EmailContent, RawMessage}, model::{EmailContent, RawMessage},
types::Blob, types::Blob,
Client, Client,
}; };
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_client::erase::{DynConnector, DynMiddleware};
use lettre::{address::Envelope, AsyncTransport}; use lettre::{address::Envelope, AsyncTransport};
use mas_http::{otel::TraceLayer, ClientInitError};
/// An asynchronous email transport that sends email via the AWS Simple Email /// An asynchronous email transport that sends email via the AWS Simple Email
/// Service v2 API /// Service v2 API
@ -28,17 +35,47 @@ pub struct Transport {
impl Transport { impl Transport {
/// Construct a [`Transport`] from the environment /// Construct a [`Transport`] from the environment
pub async fn from_env() -> Self { ///
let config = aws_config::from_env().load().await; /// # Errors
let config = aws_sdk_sesv2::Config::from(&config); ///
Self::new(config) /// Returns an error if the HTTP client failed to initialize
} pub async fn from_env() -> Result<Self, ClientInitError> {
let sleep = Arc::new(TokioSleep::new());
/// Constructs a [`Transport`] from a given AWS SES SDK config // Create the TCP connector from mas-http. This way we share the root
#[must_use] // certificate loader with it
pub fn new(config: aws_sdk_sesv2::Config) -> Self { let http_connector = mas_http::make_traced_connector()
let client = Client::from_conf(config); .await
Self { client } .expect("failed to create HTTPS connector");
let http_connector = aws_smithy_client::hyper_ext::Adapter::builder()
.sleep_impl(sleep.clone())
.build(http_connector);
let http_connector = DynConnector::new(http_connector);
// Middleware to add tracing to AWS SDK operations
let middleware = DynMiddleware::new((
TraceLayer::with_namespace("aws_sdk")
.make_span_builder(mas_http::otel::DefaultMakeSpanBuilder::new("aws_sdk"))
.on_error(mas_http::otel::DebugOnError),
DefaultMiddleware::default(),
));
// Use that connector for discovering the config
let config = ProviderConfig::default().with_http_connector(http_connector.clone());
let config = aws_config::from_env().configure(config).load().await;
let config = aws_sdk_sesv2::Config::from(&config);
// As well as for the client itself
let client = aws_smithy_client::Client::builder()
.sleep_impl(sleep)
.connector(http_connector)
.middleware(middleware)
.build_dyn();
let client = Client::with_config(client, config);
Ok(Self { client })
} }
} }

View File

@ -25,6 +25,7 @@ use lettre::{
}, },
AsyncTransport, Tokio1Executor, AsyncTransport, Tokio1Executor,
}; };
use mas_http::ClientInitError;
pub mod aws_ses; pub mod aws_ses;
@ -101,8 +102,13 @@ impl Transport {
} }
/// Construct a AWS SES transport /// Construct a AWS SES transport
pub async fn aws_ses() -> Self { ///
Self::new(TransportInner::AwsSes(aws_ses::Transport::from_env().await)) /// # Errors
///
/// Returns an error if the HTTP client failed to initialize
pub async fn aws_ses() -> Result<Self, ClientInitError> {
let transport = aws_ses::Transport::from_env().await?;
Ok(Self::new(TransportInner::AwsSes(transport)))
} }
} }

View File

@ -148,6 +148,16 @@ pub enum NativeRootsLoadError {
Empty, Empty,
} }
async fn make_tls_config() -> Result<rustls::ClientConfig, ClientInitError> {
let roots = tls_roots().await?;
let tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
Ok(tls_config)
}
/// Create a basic Hyper HTTP & HTTPS client without any tracing /// Create a basic Hyper HTTP & HTTPS client without any tracing
/// ///
/// # Errors /// # Errors
@ -159,57 +169,63 @@ where
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static, B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
E: Into<BoxError>, E: Into<BoxError>,
{ {
let resolver = GaiResolver::new(); let https = make_untraced_connector().await?;
let roots = tls_roots().await?; Ok(Client::builder().build(https))
let tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
Ok(make_client(resolver, tls_config))
} }
async fn make_base_client<B, E>( async fn make_traced_client<B, E>(
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError> ) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
where where
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static, B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
E: Into<BoxError>, E: Into<BoxError>,
{ {
// Trace DNS requests let https = make_traced_connector().await?;
let resolver = TraceLayer::dns().layer(GaiResolver::new()); Ok(Client::builder().build(https))
let roots = tls_roots().await?;
let tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
Ok(make_client(resolver, tls_config))
} }
fn make_client<R, B, E>( /// Create a traced HTTP and HTTPS connector
///
/// # Errors
///
/// Returns an error if it failed to load the TLS certificates
pub async fn make_traced_connector(
) -> Result<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, ClientInitError>
where
{
// Trace DNS requests
let resolver = TraceLayer::dns().layer(GaiResolver::new());
let tls_config = make_tls_config().await?;
Ok(make_connector(resolver, tls_config))
}
async fn make_untraced_connector(
) -> Result<HttpsConnector<HttpConnector<GaiResolver>>, ClientInitError>
where
{
let resolver = GaiResolver::new();
let tls_config = make_tls_config().await?;
Ok(make_connector(resolver, tls_config))
}
fn make_connector<R>(
resolver: R, resolver: R,
tls_config: rustls::ClientConfig, tls_config: rustls::ClientConfig,
) -> hyper::Client<HttpsConnector<HttpConnector<R>>, B> ) -> HttpsConnector<HttpConnector<R>>
where where
R: Service<Name> + Send + Sync + Clone + 'static, R: Service<Name> + Send + Sync + Clone + 'static,
R::Error: std::error::Error + Send + Sync, R::Error: std::error::Error + Send + Sync,
R::Future: Send, R::Future: Send,
R::Response: Iterator<Item = SocketAddr>, R::Response: Iterator<Item = SocketAddr>,
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
E: Into<BoxError>,
{ {
let mut http = HttpConnector::new_with_resolver(resolver); let mut http = HttpConnector::new_with_resolver(resolver);
http.enforce_http(false); http.enforce_http(false);
let https = HttpsConnectorBuilder::new() HttpsConnectorBuilder::new()
.with_tls_config(tls_config) .with_tls_config(tls_config)
.https_or_http() .https_or_http()
.enable_http1() .enable_http1()
.enable_http2() .enable_http2()
.wrap_connector(http); .wrap_connector(http)
Client::builder().build(https)
} }
/// Create a traced HTTP client, with a default timeout, which follows redirects /// Create a traced HTTP client, with a default timeout, which follows redirects
@ -228,7 +244,7 @@ where
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static, B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
E: Into<BoxError> + 'static, E: Into<BoxError> + 'static,
{ {
let client = make_base_client().await?; let client = make_traced_client().await?;
let layer = ( let layer = (
// Convert the errors to ClientError to help dealing with them // Convert the errors to ClientError to help dealing with them

View File

@ -31,3 +31,16 @@ where
span.add_event("exception".to_owned(), attributes); span.add_event("exception".to_owned(), attributes);
} }
} }
#[derive(Debug, Clone, Copy, Default)]
pub struct DebugOnError;
impl<E> OnError<E> for DebugOnError
where
E: std::fmt::Debug,
{
fn on_error(&self, span: &SpanRef<'_>, _metrics_labels: &mut Vec<KeyValue>, err: &E) {
let attributes = vec![EXCEPTION_MESSAGE.string(format!("{err:?}"))];
span.add_event("exception".to_owned(), attributes);
}
}

View File

@ -30,7 +30,7 @@ mod ext;
mod layers; mod layers;
#[cfg(feature = "client")] #[cfg(feature = "client")]
pub use self::client::{client, make_untraced_client}; pub use self::client::{client, make_traced_connector, make_untraced_client, ClientInitError};
pub use self::{ pub use self::{
ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt}, ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
layers::{ layers::{

View File

@ -67,6 +67,7 @@ pub struct PolicyFactory {
} }
impl PolicyFactory { impl PolicyFactory {
#[tracing::instrument(skip(source), err(Display))]
pub async fn load( pub async fn load(
mut source: impl AsyncRead + std::marker::Unpin, mut source: impl AsyncRead + std::marker::Unpin,
data: serde_json::Value, data: serde_json::Value,
@ -125,6 +126,7 @@ impl PolicyFactory {
.await .await
} }
#[tracing::instrument(skip(self), err)]
pub async fn instantiate(&self) -> Result<Policy, anyhow::Error> { pub async fn instantiate(&self) -> Result<Policy, anyhow::Error> {
let mut store = Store::new(&self.engine, ()); let mut store = Store::new(&self.engine, ());
let runtime = Runtime::new(&mut store, &self.module).await?; let runtime = Runtime::new(&mut store, &self.module).await?;