diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index de808df6..8ebe6d25 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -16,8 +16,8 @@ use std::time::Duration; use anyhow::Context; use mas_config::{ - BrandingConfig, DatabaseConfig, DatabaseConnectConfig, EmailConfig, EmailSmtpMode, - EmailTransportConfig, PasswordsConfig, PolicyConfig, TemplatesConfig, + BrandingConfig, DatabaseConfig, EmailConfig, EmailSmtpMode, EmailTransportConfig, + PasswordsConfig, PolicyConfig, TemplatesConfig, }; use mas_email::{MailTransport, Mailer}; use mas_handlers::{passwords::PasswordManager, ActivityTracker}; @@ -151,47 +151,37 @@ pub async fn templates_from_config( fn database_connect_options_from_config( config: &DatabaseConfig, ) -> Result { - let options = match &config.options { - DatabaseConnectConfig::Uri { uri } => uri - .parse() - .context("could not parse database connection string")?, - DatabaseConnectConfig::Options { - host, - port, - socket, - username, - password, - database, - } => { - let mut opts = - PgConnectOptions::new().application_name("matrix-authentication-service"); + let options = if let Some(uri) = config.uri.as_deref() { + uri.parse() + .context("could not parse database connection string")? + } else { + let mut opts = PgConnectOptions::new().application_name("matrix-authentication-service"); - if let Some(host) = host { - opts = opts.host(host); - } - - if let Some(port) = port { - opts = opts.port(*port); - } - - if let Some(socket) = socket { - opts = opts.socket(socket); - } - - if let Some(username) = username { - opts = opts.username(username); - } - - if let Some(password) = password { - opts = opts.password(password); - } - - if let Some(database) = database { - opts = opts.database(database); - } - - opts + if let Some(host) = config.host.as_deref() { + opts = opts.host(host); } + + if let Some(port) = config.port { + opts = opts.port(port); + } + + if let Some(socket) = config.socket.as_deref() { + opts = opts.socket(socket); + } + + if let Some(username) = config.username.as_deref() { + opts = opts.username(username); + } + + if let Some(password) = config.password.as_deref() { + opts = opts.password(password); + } + + if let Some(database) = config.database.as_deref() { + opts = opts.database(database); + } + + opts }; let options = options diff --git a/crates/config/src/schema.rs b/crates/config/src/schema.rs index 5f47ddb6..ea26ad39 100644 --- a/crates/config/src/schema.rs +++ b/crates/config/src/schema.rs @@ -16,36 +16,27 @@ use schemars::{ gen::SchemaGenerator, - schema::{InstanceType, NumberValidation, Schema, SchemaObject}, + schema::{InstanceType, Schema, SchemaObject}, + JsonSchema, }; -/// A network port -pub fn port(_gen: &mut SchemaGenerator) -> Schema { - Schema::Object(SchemaObject { - instance_type: Some(InstanceType::Integer.into()), - number: Some(Box::new(NumberValidation { - minimum: Some(1.0), - maximum: Some(65535.0), - ..NumberValidation::default() - })), - ..SchemaObject::default() - }) +/// A network hostname +pub struct Hostname; + +impl JsonSchema for Hostname { + fn schema_name() -> String { + "Hostname".to_string() + } + + fn json_schema(gen: &mut SchemaGenerator) -> Schema { + hostname(gen) + } } -/// A network hostname -pub fn hostname(_gen: &mut SchemaGenerator) -> Schema { +fn hostname(_gen: &mut SchemaGenerator) -> Schema { Schema::Object(SchemaObject { instance_type: Some(InstanceType::String.into()), format: Some("hostname".to_owned()), ..SchemaObject::default() }) } - -/// An email address -pub fn mailbox(_gen: &mut SchemaGenerator) -> Schema { - Schema::Object(SchemaObject { - instance_type: Some(InstanceType::String.into()), - format: Some("email".to_owned()), - ..SchemaObject::default() - }) -} diff --git a/crates/config/src/sections/database.rs b/crates/config/src/sections/database.rs index c7521161..beb82ae2 100644 --- a/crates/config/src/sections/database.rs +++ b/crates/config/src/sections/database.rs @@ -19,13 +19,14 @@ use camino::Utf8PathBuf; use rand::Rng; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use serde_with::{serde_as, skip_serializing_none}; +use serde_with::serde_as; use super::ConfigurationSection; use crate::schema; -fn default_connection_string() -> String { - "postgresql://".to_owned() +#[allow(clippy::unnecessary_wraps)] +fn default_connection_string() -> Option { + Some("postgresql://".to_owned()) } fn default_max_connections() -> NonZeroU32 { @@ -49,7 +50,13 @@ fn default_max_lifetime() -> Option { impl Default for DatabaseConfig { fn default() -> Self { Self { - options: ConnectConfig::default(), + uri: default_connection_string(), + host: None, + port: None, + socket: None, + username: None, + password: None, + database: None, max_connections: default_max_connections(), min_connections: Default::default(), connect_timeout: default_connect_timeout(), @@ -59,63 +66,56 @@ impl Default for DatabaseConfig { } } -/// Database connection configuration -#[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] -#[serde(untagged)] -pub enum ConnectConfig { - /// Connect via a full URI - Uri { - /// Connection URI - #[schemars(url, default = "default_connection_string")] - uri: String, - }, - /// Connect via a map of options - Options { - /// Name of host to connect to - #[schemars(schema_with = "schema::hostname")] - #[serde(default)] - host: Option, - - /// Port number to connect at the server host - #[schemars(schema_with = "schema::port")] - #[serde(default)] - port: Option, - - /// Directory containing the UNIX socket to connect to - #[serde(default)] - #[schemars(with = "Option")] - socket: Option, - - /// PostgreSQL user name to connect as - #[serde(default)] - username: Option, - - /// Password to be used if the server demands password authentication - #[serde(default)] - password: Option, - - /// The database name - #[serde(default)] - database: Option, - }, -} - -impl Default for ConnectConfig { - fn default() -> Self { - Self::Uri { - uri: default_connection_string(), - } - } -} - /// Database connection configuration #[serde_as] -#[skip_serializing_none] #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct DatabaseConfig { - /// Options related to how to connect to the database - #[serde(default, flatten)] - pub options: ConnectConfig, + /// Connection URI + /// + /// This must not be specified if `host`, `port`, `socket`, `username`, + /// `password`, or `database` are specified. + #[serde(skip_serializing_if = "Option::is_none")] + #[schemars(url, default = "default_connection_string")] + pub uri: Option, + + /// Name of host to connect to + /// + /// This must not be specified if `uri` is specified. + #[serde(skip_serializing_if = "Option::is_none")] + #[schemars(with = "Option::")] + pub host: Option, + + /// Port number to connect at the server host + /// + /// This must not be specified if `uri` is specified. + #[serde(skip_serializing_if = "Option::is_none")] + #[schemars(range(min = 1, max = 65535))] + pub port: Option, + + /// Directory containing the UNIX socket to connect to + /// + /// This must not be specified if `uri` is specified. + #[serde(skip_serializing_if = "Option::is_none")] + #[schemars(with = "Option")] + pub socket: Option, + + /// PostgreSQL user name to connect as + /// + /// This must not be specified if `uri` is specified. + #[serde(skip_serializing_if = "Option::is_none")] + pub username: Option, + + /// Password to be used if the server demands password authentication + /// + /// This must not be specified if `uri` is specified. + #[serde(skip_serializing_if = "Option::is_none")] + pub password: Option, + + /// The database name + /// + /// This must not be specified if `uri` is specified. + #[serde(skip_serializing_if = "Option::is_none")] + pub database: Option, /// Set the maximum number of connections the pool should maintain #[serde(default = "default_max_connections")] @@ -133,13 +133,19 @@ pub struct DatabaseConfig { /// Set a maximum idle duration for individual connections #[schemars(with = "Option")] - #[serde(default = "default_idle_timeout")] + #[serde( + default = "default_idle_timeout", + skip_serializing_if = "Option::is_none" + )] #[serde_as(as = "Option>")] pub idle_timeout: Option, /// Set the maximum lifetime of individual connections #[schemars(with = "u64")] - #[serde(default = "default_max_lifetime")] + #[serde( + default = "default_max_lifetime", + skip_serializing_if = "Option::is_none" + )] #[serde_as(as = "Option>")] pub max_lifetime: Option, } @@ -155,6 +161,31 @@ impl ConfigurationSection for DatabaseConfig { Ok(Self::default()) } + fn validate(&self, figment: &figment::Figment) -> Result<(), figment::error::Error> { + let metadata = figment.find_metadata(Self::PATH.unwrap()); + + // Check that the user did not specify both `uri` and the split options at the + // same time + let has_split_options = self.host.is_some() + || self.port.is_some() + || self.socket.is_some() + || self.username.is_some() + || self.password.is_some() + || self.database.is_some(); + + if self.uri.is_some() && has_split_options { + let mut error = figment::error::Error::from( + "uri must not be specified if host, port, socket, username, password, or database are specified".to_owned(), + ); + error.metadata = metadata.cloned(); + error.profile = Some(figment::Profile::Default); + error.path = vec![Self::PATH.unwrap().to_owned(), "uri".to_owned()]; + return Err(error); + } + + Ok(()) + } + fn test() -> Self { Self::default() } @@ -185,10 +216,8 @@ mod tests { .extract_inner::("database")?; assert_eq!( - config.options, - ConnectConfig::Uri { - uri: "postgresql://user:password@host/database".to_string() - } + config.uri.as_deref(), + Some("postgresql://user:password@host/database") ); Ok(()) diff --git a/crates/config/src/sections/email.rs b/crates/config/src/sections/email.rs index 77beb838..af291c94 100644 --- a/crates/config/src/sections/email.rs +++ b/crates/config/src/sections/email.rs @@ -59,7 +59,7 @@ pub enum EmailTransportConfig { mode: EmailSmtpMode, /// Hostname to connect to - #[schemars(schema_with = "crate::schema::hostname")] + #[schemars(with = "crate::schema::Hostname")] hostname: String, /// Port to connect to. Default is 25 for plain, 465 for TLS and 587 for @@ -103,12 +103,12 @@ fn default_sendmail_command() -> String { pub struct EmailConfig { /// Email address to use as From when sending emails #[serde(default = "default_email")] - #[schemars(schema_with = "crate::schema::mailbox")] + #[schemars(email)] pub from: String, /// Email address to use as Reply-To when sending emails #[serde(default = "default_email")] - #[schemars(schema_with = "crate::schema::mailbox")] + #[schemars(email)] pub reply_to: String, /// What backend should be used when sending emails diff --git a/crates/config/src/sections/mod.rs b/crates/config/src/sections/mod.rs index 6032664e..6e439bba 100644 --- a/crates/config/src/sections/mod.rs +++ b/crates/config/src/sections/mod.rs @@ -34,7 +34,7 @@ mod upstream_oauth2; pub use self::{ branding::BrandingConfig, clients::{ClientAuthMethodConfig, ClientConfig, ClientsConfig}, - database::{ConnectConfig as DatabaseConnectConfig, DatabaseConfig}, + database::DatabaseConfig, email::{EmailConfig, EmailSmtpMode, EmailTransportConfig}, experimental::ExperimentalConfig, http::{ @@ -137,6 +137,24 @@ impl ConfigurationSection for RootConfig { }) } + fn validate(&self, figment: &figment::Figment) -> Result<(), figment::error::Error> { + self.clients.validate(figment)?; + self.http.validate(figment)?; + self.database.validate(figment)?; + self.telemetry.validate(figment)?; + self.templates.validate(figment)?; + self.email.validate(figment)?; + self.passwords.validate(figment)?; + self.secrets.validate(figment)?; + self.matrix.validate(figment)?; + self.policy.validate(figment)?; + self.upstream_oauth2.validate(figment)?; + self.branding.validate(figment)?; + self.experimental.validate(figment)?; + + Ok(()) + } + fn test() -> Self { Self { clients: ClientsConfig::test(), @@ -209,6 +227,21 @@ impl ConfigurationSection for AppConfig { }) } + fn validate(&self, figment: &figment::Figment) -> Result<(), figment::error::Error> { + self.http.validate(figment)?; + self.database.validate(figment)?; + self.templates.validate(figment)?; + self.email.validate(figment)?; + self.passwords.validate(figment)?; + self.secrets.validate(figment)?; + self.matrix.validate(figment)?; + self.policy.validate(figment)?; + self.branding.validate(figment)?; + self.experimental.validate(figment)?; + + Ok(()) + } + fn test() -> Self { Self { http: HttpConfig::test(), diff --git a/docs/config.schema.json b/docs/config.schema.json index f59b87a6..7c97cc88 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -1016,60 +1016,44 @@ "DatabaseConfig": { "description": "Database connection configuration", "type": "object", - "anyOf": [ - { - "description": "Connect via a full URI", - "type": "object", - "properties": { - "uri": { - "description": "Connection URI", - "default": "postgresql://", - "type": "string", - "format": "uri" - } - } - }, - { - "description": "Connect via a map of options", - "type": "object", - "properties": { - "host": { - "description": "Name of host to connect to", - "default": null, - "type": "string", - "format": "hostname" - }, - "port": { - "description": "Port number to connect at the server host", - "default": null, - "type": "integer", - "maximum": 65535.0, - "minimum": 1.0 - }, - "socket": { - "description": "Directory containing the UNIX socket to connect to", - "default": null, - "type": "string" - }, - "username": { - "description": "PostgreSQL user name to connect as", - "default": null, - "type": "string" - }, - "password": { - "description": "Password to be used if the server demands password authentication", - "default": null, - "type": "string" - }, - "database": { - "description": "The database name", - "default": null, - "type": "string" - } - } - } - ], "properties": { + "uri": { + "description": "Connection URI\n\nThis must not be specified if `host`, `port`, `socket`, `username`, `password`, or `database` are specified.", + "default": "postgresql://", + "type": "string", + "format": "uri" + }, + "host": { + "description": "Name of host to connect to\n\nThis must not be specified if `uri` is specified.", + "allOf": [ + { + "$ref": "#/definitions/Hostname" + } + ] + }, + "port": { + "description": "Port number to connect at the server host\n\nThis must not be specified if `uri` is specified.", + "type": "integer", + "format": "uint16", + "maximum": 65535.0, + "minimum": 1.0 + }, + "socket": { + "description": "Directory containing the UNIX socket to connect to\n\nThis must not be specified if `uri` is specified.", + "type": "string" + }, + "username": { + "description": "PostgreSQL user name to connect as\n\nThis must not be specified if `uri` is specified.", + "type": "string" + }, + "password": { + "description": "Password to be used if the server demands password authentication\n\nThis must not be specified if `uri` is specified.", + "type": "string" + }, + "database": { + "description": "The database name\n\nThis must not be specified if `uri` is specified.", + "type": "string" + }, "max_connections": { "description": "Set the maximum number of connections the pool should maintain", "default": 10, @@ -1107,6 +1091,10 @@ } } }, + "Hostname": { + "type": "string", + "format": "hostname" + }, "TelemetryConfig": { "description": "Configuration related to sending monitoring data", "type": "object", @@ -1401,8 +1389,11 @@ }, "hostname": { "description": "Hostname to connect to", - "type": "string", - "format": "hostname" + "allOf": [ + { + "$ref": "#/definitions/Hostname" + } + ] }, "port": { "description": "Port to connect to. Default is 25 for plain, 465 for TLS and 587 for StartTLS",