1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Flatten the database config

This commit is contained in:
Quentin Gliech
2024-03-20 17:49:15 +01:00
parent cba431d20e
commit bf50469da1
6 changed files with 220 additions and 186 deletions

View File

@@ -16,8 +16,8 @@ use std::time::Duration;
use anyhow::Context; use anyhow::Context;
use mas_config::{ use mas_config::{
BrandingConfig, DatabaseConfig, DatabaseConnectConfig, EmailConfig, EmailSmtpMode, BrandingConfig, DatabaseConfig, EmailConfig, EmailSmtpMode, EmailTransportConfig,
EmailTransportConfig, PasswordsConfig, PolicyConfig, TemplatesConfig, PasswordsConfig, PolicyConfig, TemplatesConfig,
}; };
use mas_email::{MailTransport, Mailer}; use mas_email::{MailTransport, Mailer};
use mas_handlers::{passwords::PasswordManager, ActivityTracker}; use mas_handlers::{passwords::PasswordManager, ActivityTracker};
@@ -151,47 +151,37 @@ pub async fn templates_from_config(
fn database_connect_options_from_config( fn database_connect_options_from_config(
config: &DatabaseConfig, config: &DatabaseConfig,
) -> Result<PgConnectOptions, anyhow::Error> { ) -> Result<PgConnectOptions, anyhow::Error> {
let options = match &config.options { let options = if let Some(uri) = config.uri.as_deref() {
DatabaseConnectConfig::Uri { uri } => uri uri.parse()
.parse() .context("could not parse database connection string")?
.context("could not parse database connection string")?, } else {
DatabaseConnectConfig::Options { let mut opts = PgConnectOptions::new().application_name("matrix-authentication-service");
host,
port,
socket,
username,
password,
database,
} => {
let mut opts =
PgConnectOptions::new().application_name("matrix-authentication-service");
if let Some(host) = host { if let Some(host) = config.host.as_deref() {
opts = opts.host(host); opts = opts.host(host);
} }
if let Some(port) = port { if let Some(port) = config.port {
opts = opts.port(*port); opts = opts.port(port);
} }
if let Some(socket) = socket { if let Some(socket) = config.socket.as_deref() {
opts = opts.socket(socket); opts = opts.socket(socket);
} }
if let Some(username) = username { if let Some(username) = config.username.as_deref() {
opts = opts.username(username); opts = opts.username(username);
} }
if let Some(password) = password { if let Some(password) = config.password.as_deref() {
opts = opts.password(password); opts = opts.password(password);
} }
if let Some(database) = database { if let Some(database) = config.database.as_deref() {
opts = opts.database(database); opts = opts.database(database);
} }
opts opts
}
}; };
let options = options let options = options

View File

@@ -16,36 +16,27 @@
use schemars::{ use schemars::{
gen::SchemaGenerator, gen::SchemaGenerator,
schema::{InstanceType, NumberValidation, Schema, SchemaObject}, schema::{InstanceType, Schema, SchemaObject},
JsonSchema,
}; };
/// A network port /// A network hostname
pub fn port(_gen: &mut SchemaGenerator) -> Schema { pub struct Hostname;
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::Integer.into()), impl JsonSchema for Hostname {
number: Some(Box::new(NumberValidation { fn schema_name() -> String {
minimum: Some(1.0), "Hostname".to_string()
maximum: Some(65535.0),
..NumberValidation::default()
})),
..SchemaObject::default()
})
} }
/// A network hostname fn json_schema(gen: &mut SchemaGenerator) -> Schema {
pub fn hostname(_gen: &mut SchemaGenerator) -> Schema { hostname(gen)
}
}
fn hostname(_gen: &mut SchemaGenerator) -> Schema {
Schema::Object(SchemaObject { Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()), instance_type: Some(InstanceType::String.into()),
format: Some("hostname".to_owned()), format: Some("hostname".to_owned()),
..SchemaObject::default() ..SchemaObject::default()
}) })
} }
/// An email address
pub fn mailbox(_gen: &mut SchemaGenerator) -> Schema {
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
format: Some("email".to_owned()),
..SchemaObject::default()
})
}

View File

@@ -19,13 +19,14 @@ use camino::Utf8PathBuf;
use rand::Rng; use rand::Rng;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none}; use serde_with::serde_as;
use super::ConfigurationSection; use super::ConfigurationSection;
use crate::schema; use crate::schema;
fn default_connection_string() -> String { #[allow(clippy::unnecessary_wraps)]
"postgresql://".to_owned() fn default_connection_string() -> Option<String> {
Some("postgresql://".to_owned())
} }
fn default_max_connections() -> NonZeroU32 { fn default_max_connections() -> NonZeroU32 {
@@ -49,7 +50,13 @@ fn default_max_lifetime() -> Option<Duration> {
impl Default for DatabaseConfig { impl Default for DatabaseConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
options: ConnectConfig::default(), uri: default_connection_string(),
host: None,
port: None,
socket: None,
username: None,
password: None,
database: None,
max_connections: default_max_connections(), max_connections: default_max_connections(),
min_connections: Default::default(), min_connections: Default::default(),
connect_timeout: default_connect_timeout(), connect_timeout: default_connect_timeout(),
@@ -59,63 +66,56 @@ impl Default for DatabaseConfig {
} }
} }
/// Database connection configuration
#[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(untagged)]
pub enum ConnectConfig {
/// Connect via a full URI
Uri {
/// Connection URI
#[schemars(url, default = "default_connection_string")]
uri: String,
},
/// Connect via a map of options
Options {
/// Name of host to connect to
#[schemars(schema_with = "schema::hostname")]
#[serde(default)]
host: Option<String>,
/// Port number to connect at the server host
#[schemars(schema_with = "schema::port")]
#[serde(default)]
port: Option<u16>,
/// Directory containing the UNIX socket to connect to
#[serde(default)]
#[schemars(with = "Option<String>")]
socket: Option<Utf8PathBuf>,
/// PostgreSQL user name to connect as
#[serde(default)]
username: Option<String>,
/// Password to be used if the server demands password authentication
#[serde(default)]
password: Option<String>,
/// The database name
#[serde(default)]
database: Option<String>,
},
}
impl Default for ConnectConfig {
fn default() -> Self {
Self::Uri {
uri: default_connection_string(),
}
}
}
/// Database connection configuration /// Database connection configuration
#[serde_as] #[serde_as]
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, JsonSchema)] #[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct DatabaseConfig { pub struct DatabaseConfig {
/// Options related to how to connect to the database /// Connection URI
#[serde(default, flatten)] ///
pub options: ConnectConfig, /// This must not be specified if `host`, `port`, `socket`, `username`,
/// `password`, or `database` are specified.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(url, default = "default_connection_string")]
pub uri: Option<String>,
/// Name of host to connect to
///
/// This must not be specified if `uri` is specified.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(with = "Option::<schema::Hostname>")]
pub host: Option<String>,
/// Port number to connect at the server host
///
/// This must not be specified if `uri` is specified.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(range(min = 1, max = 65535))]
pub port: Option<u16>,
/// Directory containing the UNIX socket to connect to
///
/// This must not be specified if `uri` is specified.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(with = "Option<String>")]
pub socket: Option<Utf8PathBuf>,
/// PostgreSQL user name to connect as
///
/// This must not be specified if `uri` is specified.
#[serde(skip_serializing_if = "Option::is_none")]
pub username: Option<String>,
/// Password to be used if the server demands password authentication
///
/// This must not be specified if `uri` is specified.
#[serde(skip_serializing_if = "Option::is_none")]
pub password: Option<String>,
/// The database name
///
/// This must not be specified if `uri` is specified.
#[serde(skip_serializing_if = "Option::is_none")]
pub database: Option<String>,
/// Set the maximum number of connections the pool should maintain /// Set the maximum number of connections the pool should maintain
#[serde(default = "default_max_connections")] #[serde(default = "default_max_connections")]
@@ -133,13 +133,19 @@ pub struct DatabaseConfig {
/// Set a maximum idle duration for individual connections /// Set a maximum idle duration for individual connections
#[schemars(with = "Option<u64>")] #[schemars(with = "Option<u64>")]
#[serde(default = "default_idle_timeout")] #[serde(
default = "default_idle_timeout",
skip_serializing_if = "Option::is_none"
)]
#[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")] #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
pub idle_timeout: Option<Duration>, pub idle_timeout: Option<Duration>,
/// Set the maximum lifetime of individual connections /// Set the maximum lifetime of individual connections
#[schemars(with = "u64")] #[schemars(with = "u64")]
#[serde(default = "default_max_lifetime")] #[serde(
default = "default_max_lifetime",
skip_serializing_if = "Option::is_none"
)]
#[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")] #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
pub max_lifetime: Option<Duration>, pub max_lifetime: Option<Duration>,
} }
@@ -155,6 +161,31 @@ impl ConfigurationSection for DatabaseConfig {
Ok(Self::default()) Ok(Self::default())
} }
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::error::Error> {
let metadata = figment.find_metadata(Self::PATH.unwrap());
// Check that the user did not specify both `uri` and the split options at the
// same time
let has_split_options = self.host.is_some()
|| self.port.is_some()
|| self.socket.is_some()
|| self.username.is_some()
|| self.password.is_some()
|| self.database.is_some();
if self.uri.is_some() && has_split_options {
let mut error = figment::error::Error::from(
"uri must not be specified if host, port, socket, username, password, or database are specified".to_owned(),
);
error.metadata = metadata.cloned();
error.profile = Some(figment::Profile::Default);
error.path = vec![Self::PATH.unwrap().to_owned(), "uri".to_owned()];
return Err(error);
}
Ok(())
}
fn test() -> Self { fn test() -> Self {
Self::default() Self::default()
} }
@@ -185,10 +216,8 @@ mod tests {
.extract_inner::<DatabaseConfig>("database")?; .extract_inner::<DatabaseConfig>("database")?;
assert_eq!( assert_eq!(
config.options, config.uri.as_deref(),
ConnectConfig::Uri { Some("postgresql://user:password@host/database")
uri: "postgresql://user:password@host/database".to_string()
}
); );
Ok(()) Ok(())

View File

@@ -59,7 +59,7 @@ pub enum EmailTransportConfig {
mode: EmailSmtpMode, mode: EmailSmtpMode,
/// Hostname to connect to /// Hostname to connect to
#[schemars(schema_with = "crate::schema::hostname")] #[schemars(with = "crate::schema::Hostname")]
hostname: String, hostname: String,
/// Port to connect to. Default is 25 for plain, 465 for TLS and 587 for /// Port to connect to. Default is 25 for plain, 465 for TLS and 587 for
@@ -103,12 +103,12 @@ fn default_sendmail_command() -> String {
pub struct EmailConfig { pub struct EmailConfig {
/// Email address to use as From when sending emails /// Email address to use as From when sending emails
#[serde(default = "default_email")] #[serde(default = "default_email")]
#[schemars(schema_with = "crate::schema::mailbox")] #[schemars(email)]
pub from: String, pub from: String,
/// Email address to use as Reply-To when sending emails /// Email address to use as Reply-To when sending emails
#[serde(default = "default_email")] #[serde(default = "default_email")]
#[schemars(schema_with = "crate::schema::mailbox")] #[schemars(email)]
pub reply_to: String, pub reply_to: String,
/// What backend should be used when sending emails /// What backend should be used when sending emails

View File

@@ -34,7 +34,7 @@ mod upstream_oauth2;
pub use self::{ pub use self::{
branding::BrandingConfig, branding::BrandingConfig,
clients::{ClientAuthMethodConfig, ClientConfig, ClientsConfig}, clients::{ClientAuthMethodConfig, ClientConfig, ClientsConfig},
database::{ConnectConfig as DatabaseConnectConfig, DatabaseConfig}, database::DatabaseConfig,
email::{EmailConfig, EmailSmtpMode, EmailTransportConfig}, email::{EmailConfig, EmailSmtpMode, EmailTransportConfig},
experimental::ExperimentalConfig, experimental::ExperimentalConfig,
http::{ http::{
@@ -137,6 +137,24 @@ impl ConfigurationSection for RootConfig {
}) })
} }
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::error::Error> {
self.clients.validate(figment)?;
self.http.validate(figment)?;
self.database.validate(figment)?;
self.telemetry.validate(figment)?;
self.templates.validate(figment)?;
self.email.validate(figment)?;
self.passwords.validate(figment)?;
self.secrets.validate(figment)?;
self.matrix.validate(figment)?;
self.policy.validate(figment)?;
self.upstream_oauth2.validate(figment)?;
self.branding.validate(figment)?;
self.experimental.validate(figment)?;
Ok(())
}
fn test() -> Self { fn test() -> Self {
Self { Self {
clients: ClientsConfig::test(), clients: ClientsConfig::test(),
@@ -209,6 +227,21 @@ impl ConfigurationSection for AppConfig {
}) })
} }
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::error::Error> {
self.http.validate(figment)?;
self.database.validate(figment)?;
self.templates.validate(figment)?;
self.email.validate(figment)?;
self.passwords.validate(figment)?;
self.secrets.validate(figment)?;
self.matrix.validate(figment)?;
self.policy.validate(figment)?;
self.branding.validate(figment)?;
self.experimental.validate(figment)?;
Ok(())
}
fn test() -> Self { fn test() -> Self {
Self { Self {
http: HttpConfig::test(), http: HttpConfig::test(),

View File

@@ -1016,60 +1016,44 @@
"DatabaseConfig": { "DatabaseConfig": {
"description": "Database connection configuration", "description": "Database connection configuration",
"type": "object", "type": "object",
"anyOf": [
{
"description": "Connect via a full URI",
"type": "object",
"properties": { "properties": {
"uri": { "uri": {
"description": "Connection URI", "description": "Connection URI\n\nThis must not be specified if `host`, `port`, `socket`, `username`, `password`, or `database` are specified.",
"default": "postgresql://", "default": "postgresql://",
"type": "string", "type": "string",
"format": "uri" "format": "uri"
}
}
}, },
{
"description": "Connect via a map of options",
"type": "object",
"properties": {
"host": { "host": {
"description": "Name of host to connect to", "description": "Name of host to connect to\n\nThis must not be specified if `uri` is specified.",
"default": null, "allOf": [
"type": "string", {
"format": "hostname" "$ref": "#/definitions/Hostname"
}
]
}, },
"port": { "port": {
"description": "Port number to connect at the server host", "description": "Port number to connect at the server host\n\nThis must not be specified if `uri` is specified.",
"default": null,
"type": "integer", "type": "integer",
"format": "uint16",
"maximum": 65535.0, "maximum": 65535.0,
"minimum": 1.0 "minimum": 1.0
}, },
"socket": { "socket": {
"description": "Directory containing the UNIX socket to connect to", "description": "Directory containing the UNIX socket to connect to\n\nThis must not be specified if `uri` is specified.",
"default": null,
"type": "string" "type": "string"
}, },
"username": { "username": {
"description": "PostgreSQL user name to connect as", "description": "PostgreSQL user name to connect as\n\nThis must not be specified if `uri` is specified.",
"default": null,
"type": "string" "type": "string"
}, },
"password": { "password": {
"description": "Password to be used if the server demands password authentication", "description": "Password to be used if the server demands password authentication\n\nThis must not be specified if `uri` is specified.",
"default": null,
"type": "string" "type": "string"
}, },
"database": { "database": {
"description": "The database name", "description": "The database name\n\nThis must not be specified if `uri` is specified.",
"default": null,
"type": "string" "type": "string"
} },
}
}
],
"properties": {
"max_connections": { "max_connections": {
"description": "Set the maximum number of connections the pool should maintain", "description": "Set the maximum number of connections the pool should maintain",
"default": 10, "default": 10,
@@ -1107,6 +1091,10 @@
} }
} }
}, },
"Hostname": {
"type": "string",
"format": "hostname"
},
"TelemetryConfig": { "TelemetryConfig": {
"description": "Configuration related to sending monitoring data", "description": "Configuration related to sending monitoring data",
"type": "object", "type": "object",
@@ -1401,8 +1389,11 @@
}, },
"hostname": { "hostname": {
"description": "Hostname to connect to", "description": "Hostname to connect to",
"type": "string", "allOf": [
"format": "hostname" {
"$ref": "#/definitions/Hostname"
}
]
}, },
"port": { "port": {
"description": "Port to connect to. Default is 25 for plain, 465 for TLS and 587 for StartTLS", "description": "Port to connect to. Default is 25 for plain, 465 for TLS and 587 for StartTLS",