diff --git a/Cargo.lock b/Cargo.lock index 92d3856a..01ef99e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2732,6 +2732,7 @@ dependencies = [ "rustls", "serde_json", "serde_yaml", + "sqlx", "tokio", "tower", "tower-http", @@ -2764,7 +2765,6 @@ dependencies = [ "serde", "serde_json", "serde_with", - "sqlx", "thiserror", "tokio", "tracing", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index e50a0802..d46d5379 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -21,6 +21,7 @@ rand_chacha = "0.3.1" rustls = "0.20.7" serde_json = "1.0.89" serde_yaml = "0.9.14" +sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } tokio = { version = "1.23.0", features = ["full"] } tower = { version = "0.4.13", features = ["full"] } tower-http = { version = "0.3.5", features = ["fs", "compression-full"] } diff --git a/crates/cli/src/commands/database.rs b/crates/cli/src/commands/database.rs index b1e296d3..338fdbf9 100644 --- a/crates/cli/src/commands/database.rs +++ b/crates/cli/src/commands/database.rs @@ -17,6 +17,8 @@ use clap::Parser; use mas_config::DatabaseConfig; use mas_storage::MIGRATOR; +use crate::util::database_from_config; + #[derive(Parser, Debug)] pub(super) struct Options { #[command(subcommand)] @@ -32,7 +34,7 @@ enum Subcommand { impl Options { pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { let config: DatabaseConfig = root.load_config()?; - let pool = config.connect().await?; + let pool = database_from_config(&config).await?; // Run pending migrations MIGRATOR diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 8ad6b3a8..5472b78c 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -28,7 +28,7 @@ use oauth2_types::scope::Scope; use rand::SeedableRng; use tracing::{info, warn}; -use crate::util::password_manager_from_config; +use crate::util::{database_from_config, password_manager_from_config}; #[derive(Parser, Debug)] pub(super) struct Options { @@ -197,7 +197,7 @@ impl Options { let database_config: DatabaseConfig = root.load_config()?; let passwords_config: PasswordsConfig = root.load_config()?; - let pool = database_config.connect().await?; + let pool = database_from_config(&database_config).await?; let password_manager = password_manager_from_config(&passwords_config).await?; let mut txn = pool.begin().await?; @@ -228,7 +228,7 @@ impl Options { SC::VerifyEmail { username, email } => { let config: DatabaseConfig = root.load_config()?; - let pool = config.connect().await?; + let pool = database_from_config(&config).await?; let mut txn = pool.begin().await?; let user = lookup_user_by_username(&mut txn, username) @@ -247,7 +247,7 @@ impl Options { SC::ImportClients { truncate } => { let config: RootConfig = root.load_config()?; - let pool = config.database.connect().await?; + let pool = database_from_config(&config.database).await?; let encrypter = config.secrets.encrypter(); let mut txn = pool.begin().await?; @@ -306,7 +306,7 @@ impl Options { } => { let config: RootConfig = root.load_config()?; let encrypter = config.secrets.encrypter(); - let pool = config.database.connect().await?; + let pool = database_from_config(&config.database).await?; let url_builder = UrlBuilder::new(config.http.public_base); let mut conn = pool.acquire().await?; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 22eac325..3e10f091 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -28,7 +28,10 @@ use mas_templates::Templates; use tokio::signal::unix::SignalKind; use tracing::{error, info, log::warn}; -use crate::util::{mailer_from_config, password_manager_from_config, policy_factory_from_config}; +use crate::util::{ + database_from_config, mailer_from_config, password_manager_from_config, + policy_factory_from_config, templates_from_config, +}; #[derive(Parser, Debug, Default)] pub(super) struct Options { @@ -105,7 +108,7 @@ impl Options { let config: RootConfig = root.load_config()?; // Connect to the database - let pool = config.database.connect().await?; + let pool = database_from_config(&config.database).await?; if self.migrate { info!("Running pending migrations"); @@ -120,7 +123,6 @@ impl Options { queue.recuring(Duration::from_secs(15), mas_tasks::cleanup_expired(&pool)); queue.start(); - // TODO: task queue, key store, encrypter, url builder, http client // Initialize the key store let key_store = config .secrets @@ -138,9 +140,7 @@ impl Options { let url_builder = UrlBuilder::new(config.http.public_base.clone()); // Load and compile the templates - let templates = Templates::load(config.templates.path.clone(), url_builder.clone()) - .await - .context("could not load templates")?; + let templates = templates_from_config(&config.templates, &url_builder).await?; let mailer = mailer_from_config(&config.email, &templates).await?; mailer.test_connection().await?; diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index ba7c07db..f138eddd 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -12,12 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::time::Duration; + use anyhow::Context; -use mas_config::{EmailConfig, EmailSmtpMode, EmailTransportConfig, PasswordsConfig, PolicyConfig}; +use mas_config::{ + DatabaseConfig, DatabaseConnectConfig, EmailConfig, EmailSmtpMode, EmailTransportConfig, + PasswordsConfig, PolicyConfig, TemplatesConfig, +}; use mas_email::{MailTransport, Mailer}; use mas_handlers::passwords::PasswordManager; use mas_policy::PolicyFactory; -use mas_templates::Templates; +use mas_router::UrlBuilder; +use mas_templates::{TemplateLoadingError, Templates}; +use sqlx::{ + postgres::{PgConnectOptions, PgPoolOptions}, + ConnectOptions, PgPool, +}; +use tracing::log::LevelFilter; pub async fn password_manager_from_config( config: &PasswordsConfig, @@ -91,3 +102,69 @@ pub async fn policy_factory_from_config( .await .context("failed to load the policy") } + +pub async fn templates_from_config( + config: &TemplatesConfig, + url_builder: &UrlBuilder, +) -> Result { + Templates::load(config.path.clone(), url_builder.clone()).await +} + +pub async fn database_from_config(config: &DatabaseConfig) -> Result { + let mut options = match &config.options { + DatabaseConnectConfig::Uri { uri } => uri + .parse() + .context("could not parse database connection string")?, + DatabaseConnectConfig::Options { + host, + port, + socket, + username, + password, + database, + } => { + let mut opts = + PgConnectOptions::new().application_name("matrix-authentication-service"); + + if let Some(host) = host { + opts = opts.host(host); + } + + if let Some(port) = port { + opts = opts.port(*port); + } + + if let Some(socket) = socket { + opts = opts.socket(socket); + } + + if let Some(username) = username { + opts = opts.username(username); + } + + if let Some(password) = password { + opts = opts.password(password); + } + + if let Some(database) = database { + opts = opts.database(database); + } + + opts + } + }; + + options + .log_statements(LevelFilter::Debug) + .log_slow_statements(LevelFilter::Warn, Duration::from_millis(100)); + + PgPoolOptions::new() + .max_connections(config.max_connections.into()) + .min_connections(config.min_connections) + .acquire_timeout(config.connect_timeout) + .idle_timeout(config.idle_timeout) + .max_lifetime(config.max_lifetime) + .connect_with(options) + .await + .context("could not connect to the database") +} diff --git a/crates/config/Cargo.toml b/crates/config/Cargo.toml index bf957d8f..46486092 100644 --- a/crates/config/Cargo.toml +++ b/crates/config/Cargo.toml @@ -6,8 +6,8 @@ edition = "2021" license = "Apache-2.0" [dependencies] -tokio = { version = "1.23.0", features = [] } -tracing = { version = "0.1.37", features = ["log"] } +tokio = { version = "1.23.0", features = ["fs", "rt"] } +tracing = { version = "0.1.37" } async-trait = "0.1.59" thiserror = "1.0.37" @@ -23,7 +23,6 @@ url = { version = "2.3.1", features = ["serde"] } serde = { version = "1.0.150", features = ["derive"] } serde_with = { version = "2.1.0", features = ["hex", "chrono"] } serde_json = "1.0.89" -sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } pem-rfc7468 = "0.6.0" rustls-pemfile = "1.0.1" diff --git a/crates/config/src/schema.rs b/crates/config/src/schema.rs index e0861ac8..5f47ddb6 100644 --- a/crates/config/src/schema.rs +++ b/crates/config/src/schema.rs @@ -40,3 +40,12 @@ pub fn hostname(_gen: &mut SchemaGenerator) -> Schema { ..SchemaObject::default() }) } + +/// An email address +pub fn mailbox(_gen: &mut SchemaGenerator) -> Schema { + Schema::Object(SchemaObject { + instance_type: Some(InstanceType::String.into()), + format: Some("email".to_owned()), + ..SchemaObject::default() + }) +} diff --git a/crates/config/src/sections/database.rs b/crates/config/src/sections/database.rs index 772ee885..bc84c1b9 100644 --- a/crates/config/src/sections/database.rs +++ b/crates/config/src/sections/database.rs @@ -14,18 +14,12 @@ use std::{num::NonZeroU32, time::Duration}; -use anyhow::Context; use async_trait::async_trait; use camino::Utf8PathBuf; use rand::Rng; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none}; -use sqlx::{ - postgres::{PgConnectOptions, PgPool, PgPoolOptions}, - ConnectOptions, -}; -use tracing::log::LevelFilter; use super::ConfigurationSection; use crate::schema; @@ -65,14 +59,17 @@ impl Default for DatabaseConfig { } } -#[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)] +/// Database connection configuration +#[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] #[serde(untagged)] -enum ConnectConfig { +pub enum ConnectConfig { + /// Connect via a full URI Uri { /// Connection URI #[schemars(url, default = "default_connection_string")] uri: String, }, + /// Connect via a map of options Options { /// Name of host to connect to #[schemars(schema_with = "schema::hostname")] @@ -100,59 +97,9 @@ enum ConnectConfig { /// The database name #[serde(default)] database: Option, - /* TODO - * ssl_mode: PgSslMode, - * ssl_root_cert: Option, */ }, } -impl TryInto for &ConnectConfig { - type Error = sqlx::Error; - - fn try_into(self) -> Result { - match self { - ConnectConfig::Uri { uri } => uri.parse(), - ConnectConfig::Options { - host, - port, - socket, - username, - password, - database, - } => { - let mut opts = - PgConnectOptions::new().application_name("matrix-authentication-service"); - - if let Some(host) = host { - opts = opts.host(host); - } - - if let Some(port) = port { - opts = opts.port(*port); - } - - if let Some(socket) = socket { - opts = opts.socket(socket); - } - - if let Some(username) = username { - opts = opts.username(username); - } - - if let Some(password) = password { - opts = opts.password(password); - } - - if let Some(database) = database { - opts = opts.database(database); - } - - Ok(opts) - } - } - } -} - impl Default for ConnectConfig { fn default() -> Self { Self::Uri { @@ -168,57 +115,33 @@ impl Default for ConnectConfig { pub struct DatabaseConfig { /// Options related to how to connect to the database #[serde(default, flatten)] - options: ConnectConfig, + pub options: ConnectConfig, /// Set the maximum number of connections the pool should maintain #[serde(default = "default_max_connections")] - max_connections: NonZeroU32, + pub max_connections: NonZeroU32, /// Set the minimum number of connections the pool should maintain #[serde(default)] - min_connections: u32, + pub min_connections: u32, /// Set the amount of time to attempt connecting to the database #[schemars(with = "u64")] #[serde(default = "default_connect_timeout")] #[serde_as(as = "serde_with::DurationSeconds")] - connect_timeout: Duration, + pub connect_timeout: Duration, /// Set a maximum idle duration for individual connections #[schemars(with = "Option")] #[serde(default = "default_idle_timeout")] #[serde_as(as = "Option>")] - idle_timeout: Option, + pub idle_timeout: Option, /// Set the maximum lifetime of individual connections #[schemars(with = "u64")] #[serde(default = "default_max_lifetime")] #[serde_as(as = "Option>")] - max_lifetime: Option, -} - -impl DatabaseConfig { - /// Connect to the database - #[tracing::instrument(err, skip_all)] - pub async fn connect(&self) -> anyhow::Result { - let mut options: PgConnectOptions = (&self.options) - .try_into() - .context("invalid database config")?; - - options - .log_statements(LevelFilter::Debug) - .log_slow_statements(LevelFilter::Warn, Duration::from_millis(100)); - - PgPoolOptions::new() - .max_connections(self.max_connections.into()) - .min_connections(self.min_connections) - .acquire_timeout(self.connect_timeout) - .idle_timeout(self.idle_timeout) - .max_lifetime(self.max_lifetime) - .connect_with(options) - .await - .context("could not connect to the database") - } + pub max_lifetime: Option, } #[async_trait] diff --git a/crates/config/src/sections/email.rs b/crates/config/src/sections/email.rs index 9da6ae68..8b74c21d 100644 --- a/crates/config/src/sections/email.rs +++ b/crates/config/src/sections/email.rs @@ -16,31 +16,11 @@ use std::num::NonZeroU16; use async_trait::async_trait; use rand::Rng; -use schemars::{ - gen::SchemaGenerator, - schema::{InstanceType, Schema, SchemaObject}, - JsonSchema, -}; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use super::ConfigurationSection; -fn mailbox_schema(_gen: &mut SchemaGenerator) -> Schema { - Schema::Object(SchemaObject { - instance_type: Some(InstanceType::String.into()), - format: Some("email".to_owned()), - ..SchemaObject::default() - }) -} - -fn hostname_schema(_gen: &mut SchemaGenerator) -> Schema { - Schema::Object(SchemaObject { - instance_type: Some(InstanceType::String.into()), - format: Some("hostname".to_owned()), - ..SchemaObject::default() - }) -} - #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct Credentials { /// Username for use to authenticate when connecting to the SMTP server @@ -77,7 +57,7 @@ pub enum EmailTransportConfig { mode: EmailSmtpMode, /// Hostname to connect to - #[schemars(schema_with = "hostname_schema")] + #[schemars(schema_with = "crate::schema::hostname")] hostname: String, /// Port to connect to. Default is 25 for plain, 465 for TLS and 587 for @@ -120,12 +100,12 @@ fn default_sendmail_command() -> String { pub struct EmailConfig { /// Email address to use as From when sending emails #[serde(default = "default_email")] - #[schemars(schema_with = "mailbox_schema")] + #[schemars(schema_with = "crate::schema::mailbox")] pub from: String, /// Email address to use as Reply-To when sending emails #[serde(default = "default_email")] - #[schemars(schema_with = "mailbox_schema")] + #[schemars(schema_with = "crate::schema::mailbox")] pub reply_to: String, /// What backend should be used when sending emails diff --git a/crates/config/src/sections/mod.rs b/crates/config/src/sections/mod.rs index d2c6280f..017f3e3d 100644 --- a/crates/config/src/sections/mod.rs +++ b/crates/config/src/sections/mod.rs @@ -32,7 +32,7 @@ mod templates; pub use self::{ clients::{ClientAuthMethodConfig, ClientConfig, ClientsConfig}, csrf::CsrfConfig, - database::DatabaseConfig, + database::{ConnectConfig as DatabaseConnectConfig, DatabaseConfig}, email::{EmailConfig, EmailSmtpMode, EmailTransportConfig}, http::{ BindConfig as HttpBindConfig, HttpConfig, ListenerConfig as HttpListenerConfig, diff --git a/crates/config/src/sections/secrets.rs b/crates/config/src/sections/secrets.rs index b4c5e7d2..3a0f79a4 100644 --- a/crates/config/src/sections/secrets.rs +++ b/crates/config/src/sections/secrets.rs @@ -214,7 +214,7 @@ impl ConfigurationSection<'_> for SecretsConfig { }; Ok(Self { - encryption: rand::random(), + encryption: rng.gen(), keys: vec![rsa_key, ec_p256_key, ec_p384_key, ec_k256_key], }) } diff --git a/docs/config.schema.json b/docs/config.schema.json index d075dbd3..42582e88 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -45,8 +45,8 @@ "email": { "description": "Configuration related to sending emails", "default": { - "from": "Authentication Service ", - "reply_to": "Authentication Service ", + "from": "\"Authentication Service\" ", + "reply_to": "\"Authentication Service\" ", "transport": "blackhole" }, "allOf": [ @@ -436,6 +436,7 @@ "type": "object", "anyOf": [ { + "description": "Connect via a full URI", "type": "object", "properties": { "uri": { @@ -447,6 +448,7 @@ } }, { + "description": "Connect via a map of options", "type": "object", "properties": { "database": { @@ -625,13 +627,13 @@ "properties": { "from": { "description": "Email address to use as From when sending emails", - "default": "Authentication Service ", + "default": "\"Authentication Service\" ", "type": "string", "format": "email" }, "reply_to": { "description": "Email address to use as Reply-To when sending emails", - "default": "Authentication Service ", + "default": "\"Authentication Service\" ", "type": "string", "format": "email" } diff --git a/misc/update.sh b/misc/update.sh index 1953109c..b9ec5a48 100644 --- a/misc/update.sh +++ b/misc/update.sh @@ -9,7 +9,7 @@ GRAPHQL_SCHEMA="${BASE_DIR}/frontend/schema.graphql" set -x # XXX: we shouldn't have to specify this feature -cargo run -p mas-config --features webpki-roots > "${CONFIG_SCHEMA}" +cargo run -p mas-config > "${CONFIG_SCHEMA}" cargo run -p mas-graphql --features webpki-roots > "${GRAPHQL_SCHEMA}" cd "${BASE_DIR}/frontend"