1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Remove the config dependency from the email, templates & handlers crates

This commit is contained in:
Quentin Gliech
2022-09-02 15:15:36 +02:00
parent e58dd6d33d
commit cc6c6e8bdb
14 changed files with 148 additions and 90 deletions

4
Cargo.lock generated
View File

@ -2420,6 +2420,7 @@ dependencies = [
"figment", "figment",
"indoc", "indoc",
"lettre", "lettre",
"mas-email",
"mas-iana", "mas-iana",
"mas-jose", "mas-jose",
"mas-keystore", "mas-keystore",
@ -2461,7 +2462,6 @@ dependencies = [
"aws-sdk-sesv2", "aws-sdk-sesv2",
"aws-types", "aws-types",
"lettre", "lettre",
"mas-config",
"mas-templates", "mas-templates",
"tokio", "tokio",
"tracing", "tracing",
@ -2483,7 +2483,6 @@ dependencies = [
"indoc", "indoc",
"lettre", "lettre",
"mas-axum-utils", "mas-axum-utils",
"mas-config",
"mas-data-model", "mas-data-model",
"mas-email", "mas-email",
"mas-http", "mas-http",
@ -2707,7 +2706,6 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"chrono", "chrono",
"mas-config",
"mas-data-model", "mas-data-model",
"mas-router", "mas-router",
"oauth2-types", "oauth2-types",

View File

@ -23,7 +23,8 @@ use clap::Parser;
use futures::stream::{StreamExt, TryStreamExt}; use futures::stream::{StreamExt, TryStreamExt};
use hyper::Server; use hyper::Server;
use mas_config::RootConfig; use mas_config::RootConfig;
use mas_email::{MailTransport, Mailer}; use mas_email::Mailer;
use mas_handlers::MatrixHomeserver;
use mas_http::ServerLayer; use mas_http::ServerLayer;
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
@ -148,7 +149,7 @@ impl Options {
let listener = TcpListener::bind(addr).context("could not bind address")?; let listener = TcpListener::bind(addr).context("could not bind address")?;
// Connect to the mail server // Connect to the mail server
let mail_transport = MailTransport::from_config(&config.email.transport).await?; let mail_transport = config.email.transport.to_transport().await?;
mail_transport.test_connection().await?; mail_transport.test_connection().await?;
// Connect to the database // Connect to the database
@ -203,7 +204,7 @@ impl Options {
let policy_factory = Arc::new(policy_factory); let policy_factory = Arc::new(policy_factory);
// Load and compile the templates // Load and compile the templates
let templates = Templates::load_from_config(&config.templates) let templates = Templates::load(config.templates.path.clone(), config.templates.builtin)
.await .await
.context("could not load templates")?; .context("could not load templates")?;
@ -218,7 +219,7 @@ impl Options {
let static_files = mas_static_files::service(&config.http.web_root); let static_files = mas_static_files::service(&config.http.web_root);
let matrix_config = config.matrix.clone(); let homeserver = MatrixHomeserver::new(config.matrix.homeserver.clone());
// Explicitely the config to properly zeroize secret keys // Explicitely the config to properly zeroize secret keys
drop(config); drop(config);
@ -242,7 +243,7 @@ impl Options {
&encrypter, &encrypter,
&mailer, &mailer,
&url_builder, &url_builder,
&matrix_config, &homeserver,
&policy_factory, &policy_factory,
) )
.fallback(static_files) .fallback(static_files)

View File

@ -62,7 +62,7 @@ impl Options {
path: Some(path.to_string()), path: Some(path.to_string()),
builtin: !skip_builtin, builtin: !skip_builtin,
}; };
let templates = Templates::load_from_config(&config).await?; let templates = Templates::load(config.path.clone(), config.builtin).await?;
templates.check_render().await?; templates.check_render().await?;
Ok(()) Ok(())

View File

@ -32,3 +32,4 @@ indoc = "1.0.7"
mas-jose = { path = "../jose" } mas-jose = { path = "../jose" }
mas-keystore = { path = "../keystore" } mas-keystore = { path = "../keystore" }
mas-iana = { path = "../iana" } mas-iana = { path = "../iana" }
mas-email = { path = "../email" }

View File

@ -14,8 +14,10 @@
use std::num::NonZeroU16; use std::num::NonZeroU16;
use anyhow::Context;
use async_trait::async_trait; use async_trait::async_trait;
use lettre::{message::Mailbox, Address}; use lettre::{message::Mailbox, Address};
use mas_email::MailTransport;
use schemars::{ use schemars::{
gen::SchemaGenerator, gen::SchemaGenerator,
schema::{InstanceType, Schema, SchemaObject}, schema::{InstanceType, Schema, SchemaObject},
@ -51,7 +53,7 @@ pub struct Credentials {
} }
/// Encryption mode to use /// Encryption mode to use
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum EmailSmtpMode { pub enum EmailSmtpMode {
/// Plain text /// Plain text
@ -62,6 +64,16 @@ pub enum EmailSmtpMode {
Tls, Tls,
} }
impl From<&EmailSmtpMode> for mas_email::SmtpMode {
fn from(value: &EmailSmtpMode) -> Self {
match value {
EmailSmtpMode::Plain => Self::Plain,
EmailSmtpMode::StartTls => Self::StartTls,
EmailSmtpMode::Tls => Self::Tls,
}
}
}
/// What backend should be used when sending emails /// What backend should be used when sending emails
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "transport", rename_all = "snake_case")] #[serde(tag = "transport", rename_all = "snake_case")]
@ -156,3 +168,30 @@ impl ConfigurationSection<'_> for EmailConfig {
Self::default() Self::default()
} }
} }
impl EmailTransportConfig {
/// Create a [`lettre::Transport`] out of this config
///
/// # Errors
///
/// Returns an error if the transport could not be created
pub async fn to_transport(&self) -> Result<MailTransport, anyhow::Error> {
match self {
Self::Blackhole => Ok(MailTransport::blackhole()),
Self::Smtp {
mode,
hostname,
credentials,
port,
} => {
let credentials = credentials
.clone()
.map(|c| mas_email::SmtpCredentials::new(c.username, c.password));
MailTransport::smtp(mode.into(), hostname, port.as_ref().copied(), credentials)
.context("failed to build SMTP transport")
}
EmailTransportConfig::Sendmail { command } => Ok(MailTransport::sendmail(command)),
EmailTransportConfig::AwsSes => Ok(MailTransport::aws_ses().await),
}
}
}

View File

@ -15,7 +15,6 @@ aws-config = "0.47.0"
aws-types = "0.47.0" aws-types = "0.47.0"
mas-templates = { path = "../templates" } mas-templates = { path = "../templates" }
mas-config = { path = "../config" }
[dependencies.lettre] [dependencies.lettre]
version = "0.10.1" version = "0.10.1"

View File

@ -26,7 +26,9 @@
mod mailer; mod mailer;
mod transport; mod transport;
pub use lettre::transport::smtp::authentication::Credentials as SmtpCredentials;
pub use self::{ pub use self::{
mailer::Mailer, mailer::Mailer,
transport::{aws_ses::Transport as AwsSesTransport, Transport as MailTransport}, transport::{aws_ses::Transport as AwsSesTransport, SmtpMode, Transport as MailTransport},
}; };

View File

@ -14,7 +14,7 @@
//! Email transport backends //! Email transport backends
use std::sync::Arc; use std::{ffi::OsString, num::NonZeroU16, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use lettre::{ use lettre::{
@ -25,10 +25,20 @@ use lettre::{
}, },
AsyncTransport, Tokio1Executor, AsyncTransport, Tokio1Executor,
}; };
use mas_config::{EmailSmtpMode, EmailTransportConfig};
pub mod aws_ses; pub mod aws_ses;
/// Encryption mode to use
#[derive(Debug, Clone, Copy)]
pub enum SmtpMode {
/// Plain text
Plain,
/// StartTLS (starts as plain text then upgrade to TLS)
StartTls,
/// TLS
Tls,
}
/// A wrapper around many [`AsyncTransport`]s /// A wrapper around many [`AsyncTransport`]s
#[derive(Default, Clone)] #[derive(Default, Clone)]
pub struct Transport { pub struct Transport {
@ -43,52 +53,56 @@ enum TransportInner {
} }
impl Transport { impl Transport {
/// Construct a transport from a user configration fn new(inner: TransportInner) -> Self {
let inner = Arc::new(inner);
Self { inner }
}
/// Construct a blackhole transport
#[must_use]
pub fn blackhole() -> Self {
Self::new(TransportInner::Blackhole)
}
/// Construct a SMTP transport
/// ///
/// # Errors /// # Errors
/// ///
/// Will return `Err` on invalid confiuration /// Returns an error if the underlying SMTP transport could not be built
pub async fn from_config(config: &EmailTransportConfig) -> Result<Self, anyhow::Error> { pub fn smtp(
let inner = match config { mode: SmtpMode,
EmailTransportConfig::Blackhole => TransportInner::Blackhole, hostname: &str,
EmailTransportConfig::Smtp { port: Option<NonZeroU16>,
mode, credentials: Option<Credentials>,
hostname, ) -> Result<Self, lettre::transport::smtp::Error> {
credentials, let mut t = match mode {
port, SmtpMode::Plain => AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(hostname),
} => { SmtpMode::StartTls => AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(hostname)?,
let mut t = match mode { SmtpMode::Tls => AsyncSmtpTransport::<Tokio1Executor>::relay(hostname)?,
EmailSmtpMode::Plain => {
AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(hostname)
}
EmailSmtpMode::StartTls => {
AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(hostname)?
}
EmailSmtpMode::Tls => AsyncSmtpTransport::<Tokio1Executor>::relay(hostname)?,
};
if let Some(credentials) = credentials {
t = t.credentials(Credentials::new(
credentials.username.clone(),
credentials.password.clone(),
));
}
if let Some(port) = port {
t = t.port((*port).into());
}
TransportInner::Smtp(t.build())
}
EmailTransportConfig::Sendmail { command } => {
TransportInner::Sendmail(AsyncSendmailTransport::new_with_command(command))
}
EmailTransportConfig::AwsSes => {
TransportInner::AwsSes(aws_ses::Transport::from_env().await)
}
}; };
let inner = Arc::new(inner);
Ok(Self { inner }) if let Some(credentials) = credentials {
t = t.credentials(credentials);
}
if let Some(port) = port {
t = t.port(port.into());
}
Ok(Self::new(TransportInner::Smtp(t.build())))
}
/// Construct a Sendmail transport
#[must_use]
pub fn sendmail(command: impl Into<OsString>) -> Self {
Self::new(TransportInner::Sendmail(
AsyncSendmailTransport::new_with_command(command),
))
}
/// Construct a AWS SES transport
pub async fn aws_ses() -> Self {
Self::new(TransportInner::AwsSes(aws_ses::Transport::from_env().await))
} }
} }

View File

@ -52,7 +52,6 @@ headers = "0.3.7"
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }
mas-axum-utils = { path = "../axum-utils" } mas-axum-utils = { path = "../axum-utils" }
mas-config = { path = "../config" }
mas-data-model = { path = "../data-model" } mas-data-model = { path = "../data-model" }
mas-email = { path = "../email" } mas-email = { path = "../email" }
mas-http = { path = "../http" } mas-http = { path = "../http" }

View File

@ -15,7 +15,6 @@
use axum::{response::IntoResponse, Extension, Json}; use axum::{response::IntoResponse, Extension, Json};
use chrono::{Duration, Utc}; use chrono::{Duration, Utc};
use hyper::StatusCode; use hyper::StatusCode;
use mas_config::MatrixConfig;
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType}; use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType};
use mas_storage::{ use mas_storage::{
compat::{ compat::{
@ -31,7 +30,7 @@ use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
use sqlx::{PgPool, Postgres, Transaction}; use sqlx::{PgPool, Postgres, Transaction};
use thiserror::Error; use thiserror::Error;
use super::MatrixError; use super::{MatrixError, MatrixHomeserver};
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
@ -199,7 +198,7 @@ impl IntoResponse for RouteError {
#[tracing::instrument(skip_all, err)] #[tracing::instrument(skip_all, err)]
pub(crate) async fn post( pub(crate) async fn post(
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Extension(config): Extension<MatrixConfig>, Extension(homeserver): Extension<MatrixHomeserver>,
Json(input): Json<RequestBody>, Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
@ -216,7 +215,7 @@ pub(crate) async fn post(
} }
}; };
let user_id = format!("@{}:{}", session.user.username, config.homeserver); let user_id = format!("@{}:{}", session.user.username, homeserver);
// If the client asked for a refreshable token, make it expire // If the client asked for a refreshable token, make it expire
let expires_in = if input.refresh_token { let expires_in = if input.refresh_token {

View File

@ -22,6 +22,22 @@ pub(crate) mod login_sso_redirect;
pub(crate) mod logout; pub(crate) mod logout;
pub(crate) mod refresh; pub(crate) mod refresh;
#[derive(Debug, Clone)]
pub struct MatrixHomeserver(String);
impl MatrixHomeserver {
#[must_use]
pub const fn new(hs: String) -> Self {
Self(hs)
}
}
impl std::fmt::Display for MatrixHomeserver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct MatrixError { struct MatrixError {
errcode: &'static str, errcode: &'static str,

View File

@ -30,7 +30,6 @@ use axum::{
}; };
use headers::HeaderName; use headers::HeaderName;
use hyper::header::{ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_TYPE}; use hyper::header::{ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_TYPE};
use mas_config::MatrixConfig;
use mas_email::Mailer; use mas_email::Mailer;
use mas_http::CorsLayerExt; use mas_http::CorsLayerExt;
use mas_keystore::{Encrypter, Keystore}; use mas_keystore::{Encrypter, Keystore};
@ -46,6 +45,8 @@ mod health;
mod oauth2; mod oauth2;
mod views; mod views;
pub use compat::MatrixHomeserver;
#[must_use] #[must_use]
#[allow( #[allow(
clippy::too_many_lines, clippy::too_many_lines,
@ -60,7 +61,7 @@ pub fn router<B>(
encrypter: &Encrypter, encrypter: &Encrypter,
mailer: &Mailer, mailer: &Mailer,
url_builder: &UrlBuilder, url_builder: &UrlBuilder,
matrix_config: &MatrixConfig, homeserver: &MatrixHomeserver,
policy_factory: &Arc<PolicyFactory>, policy_factory: &Arc<PolicyFactory>,
) -> Router<B> ) -> Router<B>
where where
@ -239,33 +240,28 @@ where
.layer(Extension(encrypter.clone())) .layer(Extension(encrypter.clone()))
.layer(Extension(url_builder.clone())) .layer(Extension(url_builder.clone()))
.layer(Extension(mailer.clone())) .layer(Extension(mailer.clone()))
.layer(Extension(matrix_config.clone())) .layer(Extension(homeserver.clone()))
.layer(Extension(policy_factory.clone())) .layer(Extension(policy_factory.clone()))
} }
#[cfg(test)] #[cfg(test)]
async fn test_router(pool: &PgPool) -> Result<Router, anyhow::Error> { async fn test_router(pool: &PgPool) -> Result<Router, anyhow::Error> {
use mas_config::TemplatesConfig;
use mas_email::MailTransport; use mas_email::MailTransport;
let templates_config = TemplatesConfig::default(); let templates = Templates::load(None, true).await?;
let templates = Templates::load_from_config(&templates_config).await?;
// TODO: add test keys to the store // TODO: add test keys to the store
let key_store = Keystore::default(); let key_store = Keystore::default();
let encrypter = Encrypter::new(&[0x42; 32]); let encrypter = Encrypter::new(&[0x42; 32]);
let transport = MailTransport::default(); let transport = MailTransport::blackhole();
let mailbox = "server@example.com".parse()?; let mailbox = "server@example.com".parse()?;
let mailer = Mailer::new(&templates, &transport, &mailbox, &mailbox); let mailer = Mailer::new(&templates, &transport, &mailbox, &mailbox);
let url_builder = UrlBuilder::new("https://example.com/".parse()?); let url_builder = UrlBuilder::new("https://example.com/".parse()?);
let matrix_config = MatrixConfig { let homeserver = MatrixHomeserver::new("example.com".to_owned());
homeserver: "example.com".to_owned(),
};
let policy_factory = PolicyFactory::load_default(serde_json::json!({})).await?; let policy_factory = PolicyFactory::load_default(serde_json::json!({})).await?;
let policy_factory = Arc::new(policy_factory); let policy_factory = Arc::new(policy_factory);
@ -276,7 +272,7 @@ async fn test_router(pool: &PgPool) -> Result<Router, anyhow::Error> {
&encrypter, &encrypter,
&mailer, &mailer,
&url_builder, &url_builder,
&matrix_config, &homeserver,
&policy_factory, &policy_factory,
)) ))
} }

View File

@ -25,5 +25,4 @@ url = "2.2.2"
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" } mas-data-model = { path = "../data-model" }
mas-config = { path = "../config" }
mas-router = { path = "../router" } mas-router = { path = "../router" }

View File

@ -33,7 +33,6 @@ use std::{
}; };
use anyhow::{bail, Context as _}; use anyhow::{bail, Context as _};
use mas_config::TemplatesConfig;
use mas_data_model::StorageBackend; use mas_data_model::StorageBackend;
use serde::Serialize; use serde::Serialize;
use tera::{Context, Error as TeraError, Tera}; use tera::{Context, Error as TeraError, Tera};
@ -63,7 +62,8 @@ pub use self::{
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Templates { pub struct Templates {
tera: Arc<RwLock<Tera>>, tera: Arc<RwLock<Tera>>,
config: TemplatesConfig, path: Option<String>,
builtin: bool,
} }
/// There was an issue while loading the templates /// There was an issue while loading the templates
@ -90,7 +90,7 @@ pub enum TemplateLoadingError {
impl Templates { impl Templates {
/// List directories to watch /// List directories to watch
pub async fn watch_roots(&self) -> Vec<PathBuf> { pub async fn watch_roots(&self) -> Vec<PathBuf> {
Self::roots(self.config.path.as_deref(), self.config.builtin) Self::roots(self.path.as_deref(), self.builtin)
.await .await
.into_iter() .into_iter()
.filter_map(Result::ok) .filter_map(Result::ok)
@ -133,17 +133,17 @@ impl Templates {
Ok(tera) Ok(tera)
} }
/// Load the templates from [the config][`TemplatesConfig`] /// Load the templates from the given config
pub async fn load_from_config(config: &TemplatesConfig) -> Result<Self, TemplateLoadingError> { pub async fn load(path: Option<String>, builtin: bool) -> Result<Self, TemplateLoadingError> {
let tera = Self::load(config.path.as_deref(), config.builtin).await?; let tera = Self::load_(path.as_deref(), builtin).await?;
Ok(Self { Ok(Self {
tera: Arc::new(RwLock::new(tera)), tera: Arc::new(RwLock::new(tera)),
config: config.clone(), path,
builtin,
}) })
} }
async fn load(path: Option<&str>, builtin: bool) -> Result<Tera, TemplateLoadingError> { async fn load_(path: Option<&str>, builtin: bool) -> Result<Tera, TemplateLoadingError> {
let mut teras = Vec::new(); let mut teras = Vec::new();
let roots = Self::roots(path, builtin).await; let roots = Self::roots(path, builtin).await;
@ -202,7 +202,7 @@ impl Templates {
/// Reload the templates on disk /// Reload the templates on disk
pub async fn reload(&self) -> anyhow::Result<()> { pub async fn reload(&self) -> anyhow::Result<()> {
// Prepare the new Tera instance // Prepare the new Tera instance
let new_tera = Self::load(self.config.path.as_deref(), self.config.builtin).await?; let new_tera = Self::load_(self.path.as_deref(), self.builtin).await?;
// Swap it // Swap it
*self.tera.write().await = new_tera; *self.tera.write().await = new_tera;
@ -378,12 +378,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn check_builtin_templates() { async fn check_builtin_templates() {
let config = TemplatesConfig { let templates = Templates::load(None, true).await.unwrap();
path: None,
builtin: true,
};
let templates = Templates::load_from_config(&config).await.unwrap();
templates.check_render().await.unwrap(); templates.check_render().await.unwrap();
} }
} }