diff --git a/crates/cli/src/commands/database.rs b/crates/cli/src/commands/database.rs index 9572bf1b..212e4262 100644 --- a/crates/cli/src/commands/database.rs +++ b/crates/cli/src/commands/database.rs @@ -17,7 +17,7 @@ use std::process::ExitCode; use anyhow::Context; use clap::Parser; use figment::Figment; -use mas_config::{ConfigurationSection, DatabaseConfig}; +use mas_config::{ConfigurationSectionExt, DatabaseConfig}; use mas_storage_pg::MIGRATOR; use tracing::{info_span, Instrument}; @@ -38,7 +38,7 @@ enum Subcommand { impl Options { pub async fn run(self, figment: &Figment) -> anyhow::Result { let _span = info_span!("cli.database.migrate").entered(); - let config = DatabaseConfig::extract(figment)?; + let config = DatabaseConfig::extract_or_default(figment)?; let mut conn = database_connection_from_config(&config).await?; // Run pending migrations diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 25308695..ea86dd04 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -18,7 +18,7 @@ use clap::Parser; use figment::Figment; use http_body_util::BodyExt; use hyper::{Response, Uri}; -use mas_config::{ConfigurationSection, PolicyConfig}; +use mas_config::{ConfigurationSectionExt, PolicyConfig}; use mas_handlers::HttpClientFactory; use mas_http::HttpServiceExt; use tokio::io::AsyncWriteExt; @@ -124,7 +124,7 @@ impl Options { SC::Policy => { let _span = info_span!("cli.debug.policy").entered(); - let config = PolicyConfig::extract(figment)?; + let config = PolicyConfig::extract_or_default(figment)?; info!("Loading and compiling the policy module"); let policy_factory = policy_factory_from_config(&config).await?; diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index c76ec7bc..42b74a1c 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -19,7 +19,9 @@ use clap::{ArgAction, CommandFactory, Parser}; use console::{pad_str, style, Alignment, Style, Term}; use dialoguer::{theme::ColorfulTheme, Confirm, FuzzySelect, Input, Password}; use figment::Figment; -use mas_config::{ConfigurationSection, DatabaseConfig, MatrixConfig, PasswordsConfig}; +use mas_config::{ + ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, PasswordsConfig, +}; use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User}; use mas_email::Address; use mas_handlers::HttpClientFactory; @@ -192,8 +194,8 @@ impl Options { let _span = info_span!("cli.manage.set_password", user.username = %username).entered(); - let database_config = DatabaseConfig::extract(figment)?; - let passwords_config = PasswordsConfig::extract(figment)?; + let database_config = DatabaseConfig::extract_or_default(figment)?; + let passwords_config = PasswordsConfig::extract_or_default(figment)?; let mut conn = database_connection_from_config(&database_config).await?; let password_manager = password_manager_from_config(&passwords_config).await?; @@ -233,7 +235,7 @@ impl Options { ) .entered(); - let database_config = DatabaseConfig::extract(figment)?; + let database_config = DatabaseConfig::extract_or_default(figment)?; let mut conn = database_connection_from_config(&database_config).await?; let txn = conn.begin().await?; let mut repo = PgRepository::from_conn(txn); @@ -267,7 +269,7 @@ impl Options { admin, device_id, } => { - let database_config = DatabaseConfig::extract(figment)?; + let database_config = DatabaseConfig::extract_or_default(figment)?; let mut conn = database_connection_from_config(&database_config).await?; let txn = conn.begin().await?; let mut repo = PgRepository::from_conn(txn); @@ -312,7 +314,7 @@ impl Options { SC::ProvisionAllUsers => { let _span = info_span!("cli.manage.provision_all_users").entered(); - let database_config = DatabaseConfig::extract(figment)?; + let database_config = DatabaseConfig::extract_or_default(figment)?; let mut conn = database_connection_from_config(&database_config).await?; let mut txn = conn.begin().await?; @@ -338,7 +340,7 @@ impl Options { SC::KillSessions { username, dry_run } => { let _span = info_span!("cli.manage.kill_sessions", user.username = username).entered(); - let database_config = DatabaseConfig::extract(figment)?; + let database_config = DatabaseConfig::extract_or_default(figment)?; let mut conn = database_connection_from_config(&database_config).await?; let txn = conn.begin().await?; let mut repo = PgRepository::from_conn(txn); @@ -408,7 +410,7 @@ impl Options { deactivate, } => { let _span = info_span!("cli.manage.lock_user", user.username = username).entered(); - let config = DatabaseConfig::extract(figment)?; + let config = DatabaseConfig::extract_or_default(figment)?; let mut conn = database_connection_from_config(&config).await?; let txn = conn.begin().await?; let mut repo = PgRepository::from_conn(txn); @@ -440,7 +442,7 @@ impl Options { SC::UnlockUser { username } => { let _span = info_span!("cli.manage.lock_user", user.username = username).entered(); - let config = DatabaseConfig::extract(figment)?; + let config = DatabaseConfig::extract_or_default(figment)?; let mut conn = database_connection_from_config(&config).await?; let txn = conn.begin().await?; let mut repo = PgRepository::from_conn(txn); @@ -473,8 +475,8 @@ impl Options { ignore_password_complexity, } => { let http_client_factory = HttpClientFactory::new(); - let password_config = PasswordsConfig::extract(figment)?; - let database_config = DatabaseConfig::extract(figment)?; + let password_config = PasswordsConfig::extract_or_default(figment)?; + let database_config = DatabaseConfig::extract_or_default(figment)?; let matrix_config = MatrixConfig::extract(figment)?; let password_manager = password_manager_from_config(&password_config).await?; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 0eba7a72..aaf6a98e 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -18,7 +18,9 @@ use anyhow::Context; use clap::Parser; use figment::Figment; use itertools::Itertools; -use mas_config::{AppConfig, ClientsConfig, ConfigurationSection, UpstreamOAuth2Config}; +use mas_config::{ + AppConfig, ClientsConfig, ConfigurationSection, ConfigurationSectionExt, UpstreamOAuth2Config, +}; use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, Limiter, MetadataCache}; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_matrix_synapse::SynapseConnection; @@ -103,8 +105,8 @@ impl Options { } else { // Sync the configuration with the database let mut conn = pool.acquire().await?; - let clients_config = ClientsConfig::extract(figment)?; - let upstream_oauth2_config = UpstreamOAuth2Config::extract(figment)?; + let clients_config = ClientsConfig::extract_or_default(figment)?; + let upstream_oauth2_config = UpstreamOAuth2Config::extract_or_default(figment)?; crate::sync::config_sync( upstream_oauth2_config, diff --git a/crates/cli/src/commands/templates.rs b/crates/cli/src/commands/templates.rs index 5597356a..9b104310 100644 --- a/crates/cli/src/commands/templates.rs +++ b/crates/cli/src/commands/templates.rs @@ -17,8 +17,8 @@ use std::process::ExitCode; use clap::Parser; use figment::Figment; use mas_config::{ - AccountConfig, BrandingConfig, CaptchaConfig, ConfigurationSection, ExperimentalConfig, - MatrixConfig, PasswordsConfig, TemplatesConfig, + AccountConfig, BrandingConfig, CaptchaConfig, ConfigurationSection, ConfigurationSectionExt, + ExperimentalConfig, MatrixConfig, PasswordsConfig, TemplatesConfig, }; use mas_storage::{Clock, SystemClock}; use rand::SeedableRng; @@ -45,13 +45,13 @@ impl Options { SC::Check => { let _span = info_span!("cli.templates.check").entered(); - let template_config = TemplatesConfig::extract(figment)?; - let branding_config = BrandingConfig::extract(figment)?; + let template_config = TemplatesConfig::extract_or_default(figment)?; + let branding_config = BrandingConfig::extract_or_default(figment)?; let matrix_config = MatrixConfig::extract(figment)?; - let experimental_config = ExperimentalConfig::extract(figment)?; - let password_config = PasswordsConfig::extract(figment)?; - let account_config = AccountConfig::extract(figment)?; - let captcha_config = CaptchaConfig::extract(figment)?; + let experimental_config = ExperimentalConfig::extract_or_default(figment)?; + let password_config = PasswordsConfig::extract_or_default(figment)?; + let account_config = AccountConfig::extract_or_default(figment)?; + let captcha_config = CaptchaConfig::extract_or_default(figment)?; let clock = SystemClock::default(); // XXX: we should disallow SeedableRng::from_entropy diff --git a/crates/config/src/lib.rs b/crates/config/src/lib.rs index 26b45ea2..dbcd6cf5 100644 --- a/crates/config/src/lib.rs +++ b/crates/config/src/lib.rs @@ -26,4 +26,7 @@ pub(crate) mod schema; mod sections; pub(crate) mod util; -pub use self::{sections::*, util::ConfigurationSection}; +pub use self::{ + sections::*, + util::{ConfigurationSection, ConfigurationSectionExt}, +}; diff --git a/crates/config/src/util.rs b/crates/config/src/util.rs index 0603ec43..3608d6b2 100644 --- a/crates/config/src/util.rs +++ b/crates/config/src/util.rs @@ -46,3 +46,32 @@ pub trait ConfigurationSection: Sized + DeserializeOwned { Ok(this) } } + +/// Extension trait for [`ConfigurationSection`] to allow extracting the +/// configuration section from a [`Figment`] or return the default value if the +/// section is not present. +pub trait ConfigurationSectionExt: ConfigurationSection + Default { + /// Extract the configuration section from the given [`Figment`], or return + /// the default value if the section is not present. + /// + /// # Errors + /// + /// Returns an error if the configuration section is invalid. + fn extract_or_default(figment: &Figment) -> Result { + let this: Self = if let Some(path) = Self::PATH { + // If the configuration section is not present, we return the default value + if !figment.contains(path) { + return Ok(Self::default()); + } + + figment.extract_inner(path)? + } else { + figment.extract()? + }; + + this.validate(figment)?; + Ok(this) + } +} + +impl ConfigurationSectionExt for T {}