You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Trace AWS operations & share TLS connector with mas-http
This commit is contained in:
9
Cargo.lock
generated
9
Cargo.lock
generated
@ -390,8 +390,6 @@ dependencies = [
|
|||||||
"http",
|
"http",
|
||||||
"http-body",
|
"http-body",
|
||||||
"hyper",
|
"hyper",
|
||||||
"hyper-rustls",
|
|
||||||
"lazy_static",
|
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower",
|
"tower",
|
||||||
@ -415,8 +413,6 @@ dependencies = [
|
|||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"pin-utils",
|
"pin-utils",
|
||||||
"tokio",
|
|
||||||
"tokio-util 0.7.4",
|
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -1980,9 +1976,7 @@ checksum = "d87c48c02e0dc5e3b849a2041db3029fd066650f8f717c07bf8ed78ccb895cac"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"http",
|
"http",
|
||||||
"hyper",
|
"hyper",
|
||||||
"log",
|
|
||||||
"rustls",
|
"rustls",
|
||||||
"rustls-native-certs",
|
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
]
|
]
|
||||||
@ -2486,7 +2480,10 @@ dependencies = [
|
|||||||
"async-trait",
|
"async-trait",
|
||||||
"aws-config",
|
"aws-config",
|
||||||
"aws-sdk-sesv2",
|
"aws-sdk-sesv2",
|
||||||
|
"aws-smithy-async",
|
||||||
|
"aws-smithy-client",
|
||||||
"lettre",
|
"lettre",
|
||||||
|
"mas-http",
|
||||||
"mas-templates",
|
"mas-templates",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
@ -195,7 +195,7 @@ impl EmailTransportConfig {
|
|||||||
.context("failed to build SMTP transport")
|
.context("failed to build SMTP transport")
|
||||||
}
|
}
|
||||||
EmailTransportConfig::Sendmail { command } => Ok(MailTransport::sendmail(command)),
|
EmailTransportConfig::Sendmail { command } => Ok(MailTransport::sendmail(command)),
|
||||||
EmailTransportConfig::AwsSes => Ok(MailTransport::aws_ses().await),
|
EmailTransportConfig::AwsSes => Ok(MailTransport::aws_ses().await?),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,10 +10,14 @@ anyhow = "1.0.66"
|
|||||||
async-trait = "0.1.58"
|
async-trait = "0.1.58"
|
||||||
tokio = { version = "1.21.2", features = ["macros"] }
|
tokio = { version = "1.21.2", features = ["macros"] }
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
aws-sdk-sesv2 = "0.21.0"
|
|
||||||
aws-config = "0.51.0"
|
aws-sdk-sesv2 = { version = "0.21.0", default-features = false }
|
||||||
|
aws-config = { version = "0.51.0", default-features = false }
|
||||||
|
aws-smithy-client = { version = "0.51.0", default-features = false, features = ["client-hyper"] }
|
||||||
|
aws-smithy-async = { version = "0.51.0", default-features = false, features = ["rt-tokio"] }
|
||||||
|
|
||||||
mas-templates = { path = "../templates" }
|
mas-templates = { path = "../templates" }
|
||||||
|
mas-http = { path = "../http" }
|
||||||
|
|
||||||
[dependencies.lettre]
|
[dependencies.lettre]
|
||||||
version = "0.10.1"
|
version = "0.10.1"
|
||||||
|
@ -12,13 +12,20 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use aws_config::provider_config::ProviderConfig;
|
||||||
use aws_sdk_sesv2::{
|
use aws_sdk_sesv2::{
|
||||||
|
middleware::DefaultMiddleware,
|
||||||
model::{EmailContent, RawMessage},
|
model::{EmailContent, RawMessage},
|
||||||
types::Blob,
|
types::Blob,
|
||||||
Client,
|
Client,
|
||||||
};
|
};
|
||||||
|
use aws_smithy_async::rt::sleep::TokioSleep;
|
||||||
|
use aws_smithy_client::erase::{DynConnector, DynMiddleware};
|
||||||
use lettre::{address::Envelope, AsyncTransport};
|
use lettre::{address::Envelope, AsyncTransport};
|
||||||
|
use mas_http::{otel::TraceLayer, ClientInitError};
|
||||||
|
|
||||||
/// An asynchronous email transport that sends email via the AWS Simple Email
|
/// An asynchronous email transport that sends email via the AWS Simple Email
|
||||||
/// Service v2 API
|
/// Service v2 API
|
||||||
@ -28,17 +35,47 @@ pub struct Transport {
|
|||||||
|
|
||||||
impl Transport {
|
impl Transport {
|
||||||
/// Construct a [`Transport`] from the environment
|
/// Construct a [`Transport`] from the environment
|
||||||
pub async fn from_env() -> Self {
|
///
|
||||||
let config = aws_config::from_env().load().await;
|
/// # Errors
|
||||||
let config = aws_sdk_sesv2::Config::from(&config);
|
///
|
||||||
Self::new(config)
|
/// Returns an error if the HTTP client failed to initialize
|
||||||
}
|
pub async fn from_env() -> Result<Self, ClientInitError> {
|
||||||
|
let sleep = Arc::new(TokioSleep::new());
|
||||||
|
|
||||||
/// Constructs a [`Transport`] from a given AWS SES SDK config
|
// Create the TCP connector from mas-http. This way we share the root
|
||||||
#[must_use]
|
// certificate loader with it
|
||||||
pub fn new(config: aws_sdk_sesv2::Config) -> Self {
|
let http_connector = mas_http::make_traced_connector()
|
||||||
let client = Client::from_conf(config);
|
.await
|
||||||
Self { client }
|
.expect("failed to create HTTPS connector");
|
||||||
|
|
||||||
|
let http_connector = aws_smithy_client::hyper_ext::Adapter::builder()
|
||||||
|
.sleep_impl(sleep.clone())
|
||||||
|
.build(http_connector);
|
||||||
|
|
||||||
|
let http_connector = DynConnector::new(http_connector);
|
||||||
|
|
||||||
|
// Middleware to add tracing to AWS SDK operations
|
||||||
|
let middleware = DynMiddleware::new((
|
||||||
|
TraceLayer::with_namespace("aws_sdk")
|
||||||
|
.make_span_builder(mas_http::otel::DefaultMakeSpanBuilder::new("aws_sdk"))
|
||||||
|
.on_error(mas_http::otel::DebugOnError),
|
||||||
|
DefaultMiddleware::default(),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Use that connector for discovering the config
|
||||||
|
let config = ProviderConfig::default().with_http_connector(http_connector.clone());
|
||||||
|
let config = aws_config::from_env().configure(config).load().await;
|
||||||
|
let config = aws_sdk_sesv2::Config::from(&config);
|
||||||
|
|
||||||
|
// As well as for the client itself
|
||||||
|
let client = aws_smithy_client::Client::builder()
|
||||||
|
.sleep_impl(sleep)
|
||||||
|
.connector(http_connector)
|
||||||
|
.middleware(middleware)
|
||||||
|
.build_dyn();
|
||||||
|
|
||||||
|
let client = Client::with_config(client, config);
|
||||||
|
Ok(Self { client })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ use lettre::{
|
|||||||
},
|
},
|
||||||
AsyncTransport, Tokio1Executor,
|
AsyncTransport, Tokio1Executor,
|
||||||
};
|
};
|
||||||
|
use mas_http::ClientInitError;
|
||||||
|
|
||||||
pub mod aws_ses;
|
pub mod aws_ses;
|
||||||
|
|
||||||
@ -101,8 +102,13 @@ impl Transport {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Construct a AWS SES transport
|
/// Construct a AWS SES transport
|
||||||
pub async fn aws_ses() -> Self {
|
///
|
||||||
Self::new(TransportInner::AwsSes(aws_ses::Transport::from_env().await))
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the HTTP client failed to initialize
|
||||||
|
pub async fn aws_ses() -> Result<Self, ClientInitError> {
|
||||||
|
let transport = aws_ses::Transport::from_env().await?;
|
||||||
|
Ok(Self::new(TransportInner::AwsSes(transport)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -148,6 +148,16 @@ pub enum NativeRootsLoadError {
|
|||||||
Empty,
|
Empty,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn make_tls_config() -> Result<rustls::ClientConfig, ClientInitError> {
|
||||||
|
let roots = tls_roots().await?;
|
||||||
|
let tls_config = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(roots)
|
||||||
|
.with_no_client_auth();
|
||||||
|
|
||||||
|
Ok(tls_config)
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a basic Hyper HTTP & HTTPS client without any tracing
|
/// Create a basic Hyper HTTP & HTTPS client without any tracing
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
@ -159,57 +169,63 @@ where
|
|||||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||||
E: Into<BoxError>,
|
E: Into<BoxError>,
|
||||||
{
|
{
|
||||||
let resolver = GaiResolver::new();
|
let https = make_untraced_connector().await?;
|
||||||
let roots = tls_roots().await?;
|
Ok(Client::builder().build(https))
|
||||||
let tls_config = rustls::ClientConfig::builder()
|
|
||||||
.with_safe_defaults()
|
|
||||||
.with_root_certificates(roots)
|
|
||||||
.with_no_client_auth();
|
|
||||||
|
|
||||||
Ok(make_client(resolver, tls_config))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn make_base_client<B, E>(
|
async fn make_traced_client<B, E>(
|
||||||
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
|
) -> Result<hyper::Client<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, B>, ClientInitError>
|
||||||
where
|
where
|
||||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
||||||
E: Into<BoxError>,
|
E: Into<BoxError>,
|
||||||
{
|
{
|
||||||
// Trace DNS requests
|
let https = make_traced_connector().await?;
|
||||||
let resolver = TraceLayer::dns().layer(GaiResolver::new());
|
Ok(Client::builder().build(https))
|
||||||
|
|
||||||
let roots = tls_roots().await?;
|
|
||||||
let tls_config = rustls::ClientConfig::builder()
|
|
||||||
.with_safe_defaults()
|
|
||||||
.with_root_certificates(roots)
|
|
||||||
.with_no_client_auth();
|
|
||||||
|
|
||||||
Ok(make_client(resolver, tls_config))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_client<R, B, E>(
|
/// Create a traced HTTP and HTTPS connector
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if it failed to load the TLS certificates
|
||||||
|
pub async fn make_traced_connector(
|
||||||
|
) -> Result<HttpsConnector<HttpConnector<TraceDns<GaiResolver>>>, ClientInitError>
|
||||||
|
where
|
||||||
|
{
|
||||||
|
// Trace DNS requests
|
||||||
|
let resolver = TraceLayer::dns().layer(GaiResolver::new());
|
||||||
|
let tls_config = make_tls_config().await?;
|
||||||
|
Ok(make_connector(resolver, tls_config))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn make_untraced_connector(
|
||||||
|
) -> Result<HttpsConnector<HttpConnector<GaiResolver>>, ClientInitError>
|
||||||
|
where
|
||||||
|
{
|
||||||
|
let resolver = GaiResolver::new();
|
||||||
|
let tls_config = make_tls_config().await?;
|
||||||
|
Ok(make_connector(resolver, tls_config))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_connector<R>(
|
||||||
resolver: R,
|
resolver: R,
|
||||||
tls_config: rustls::ClientConfig,
|
tls_config: rustls::ClientConfig,
|
||||||
) -> hyper::Client<HttpsConnector<HttpConnector<R>>, B>
|
) -> HttpsConnector<HttpConnector<R>>
|
||||||
where
|
where
|
||||||
R: Service<Name> + Send + Sync + Clone + 'static,
|
R: Service<Name> + Send + Sync + Clone + 'static,
|
||||||
R::Error: std::error::Error + Send + Sync,
|
R::Error: std::error::Error + Send + Sync,
|
||||||
R::Future: Send,
|
R::Future: Send,
|
||||||
R::Response: Iterator<Item = SocketAddr>,
|
R::Response: Iterator<Item = SocketAddr>,
|
||||||
B: http_body::Body<Data = Bytes, Error = E> + Send + 'static,
|
|
||||||
E: Into<BoxError>,
|
|
||||||
{
|
{
|
||||||
let mut http = HttpConnector::new_with_resolver(resolver);
|
let mut http = HttpConnector::new_with_resolver(resolver);
|
||||||
http.enforce_http(false);
|
http.enforce_http(false);
|
||||||
|
|
||||||
let https = HttpsConnectorBuilder::new()
|
HttpsConnectorBuilder::new()
|
||||||
.with_tls_config(tls_config)
|
.with_tls_config(tls_config)
|
||||||
.https_or_http()
|
.https_or_http()
|
||||||
.enable_http1()
|
.enable_http1()
|
||||||
.enable_http2()
|
.enable_http2()
|
||||||
.wrap_connector(http);
|
.wrap_connector(http)
|
||||||
|
|
||||||
Client::builder().build(https)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a traced HTTP client, with a default timeout, which follows redirects
|
/// Create a traced HTTP client, with a default timeout, which follows redirects
|
||||||
@ -228,7 +244,7 @@ where
|
|||||||
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
||||||
E: Into<BoxError> + 'static,
|
E: Into<BoxError> + 'static,
|
||||||
{
|
{
|
||||||
let client = make_base_client().await?;
|
let client = make_traced_client().await?;
|
||||||
|
|
||||||
let layer = (
|
let layer = (
|
||||||
// Convert the errors to ClientError to help dealing with them
|
// Convert the errors to ClientError to help dealing with them
|
||||||
|
@ -31,3 +31,16 @@ where
|
|||||||
span.add_event("exception".to_owned(), attributes);
|
span.add_event("exception".to_owned(), attributes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct DebugOnError;
|
||||||
|
|
||||||
|
impl<E> OnError<E> for DebugOnError
|
||||||
|
where
|
||||||
|
E: std::fmt::Debug,
|
||||||
|
{
|
||||||
|
fn on_error(&self, span: &SpanRef<'_>, _metrics_labels: &mut Vec<KeyValue>, err: &E) {
|
||||||
|
let attributes = vec![EXCEPTION_MESSAGE.string(format!("{err:?}"))];
|
||||||
|
span.add_event("exception".to_owned(), attributes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -30,7 +30,7 @@ mod ext;
|
|||||||
mod layers;
|
mod layers;
|
||||||
|
|
||||||
#[cfg(feature = "client")]
|
#[cfg(feature = "client")]
|
||||||
pub use self::client::{client, make_untraced_client};
|
pub use self::client::{client, make_traced_connector, make_untraced_client, ClientInitError};
|
||||||
pub use self::{
|
pub use self::{
|
||||||
ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
|
ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt},
|
||||||
layers::{
|
layers::{
|
||||||
|
@ -67,6 +67,7 @@ pub struct PolicyFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PolicyFactory {
|
impl PolicyFactory {
|
||||||
|
#[tracing::instrument(skip(source), err(Display))]
|
||||||
pub async fn load(
|
pub async fn load(
|
||||||
mut source: impl AsyncRead + std::marker::Unpin,
|
mut source: impl AsyncRead + std::marker::Unpin,
|
||||||
data: serde_json::Value,
|
data: serde_json::Value,
|
||||||
@ -125,6 +126,7 @@ impl PolicyFactory {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self), err)]
|
||||||
pub async fn instantiate(&self) -> Result<Policy, anyhow::Error> {
|
pub async fn instantiate(&self) -> Result<Policy, anyhow::Error> {
|
||||||
let mut store = Store::new(&self.engine, ());
|
let mut store = Store::new(&self.engine, ());
|
||||||
let runtime = Runtime::new(&mut store, &self.module).await?;
|
let runtime = Runtime::new(&mut store, &self.module).await?;
|
||||||
|
Reference in New Issue
Block a user