1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Improve the configuration schema

This commit is contained in:
Quentin Gliech
2022-01-28 11:49:15 +01:00
parent 2db595b9c0
commit 05f0756c13
12 changed files with 272 additions and 45 deletions

View File

@ -14,7 +14,7 @@
use clap::Parser;
use mas_config::{ConfigurationSection, RootConfig};
use schemars::schema_for;
use schemars::gen::SchemaSettings;
use tracing::info;
use super::RootCommand;
@ -52,7 +52,12 @@ impl ConfigCommand {
Ok(())
}
SC::Schema => {
let schema = schema_for!(RootConfig);
let settings = SchemaSettings::draft07().with(|s| {
s.option_nullable = false;
s.option_add_null_type = false;
});
let gen = settings.into_generator();
let schema = gen.into_root_schema_for::<RootConfig>();
serde_yaml::to_writer(std::io::stdout(), &schema)?;

View File

@ -13,20 +13,26 @@
// limitations under the License.
use async_trait::async_trait;
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use super::ConfigurationSection;
fn secret_schema(gen: &mut SchemaGenerator) -> Schema {
String::json_schema(gen)
fn example_secret() -> &'static str {
"0000111122223333444455556666777788889999aaaabbbbccccddddeeeeffff"
}
/// Cookies-related configuration
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct CookiesConfig {
#[schemars(schema_with = "secret_schema")]
/// Encryption key for secure cookies
#[schemars(
with = "String",
regex(pattern = r"[0-9a-fA-F]{64}"),
example = "example_secret"
)]
#[serde_as(as = "serde_with::hex::Hex")]
pub secret: [u8; 32],
}

View File

@ -14,7 +14,7 @@
use async_trait::async_trait;
use chrono::Duration;
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
@ -24,14 +24,12 @@ fn default_ttl() -> Duration {
Duration::hours(1)
}
fn ttl_schema(gen: &mut SchemaGenerator) -> Schema {
u64::json_schema(gen)
}
/// Configuration related to Cross-Site Request Forgery protections
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct CsrfConfig {
#[schemars(schema_with = "ttl_schema")]
/// Time-to-live of a CSRF token in seconds
#[schemars(with = "u64", range(min = 60, max = 86400))]
#[serde(default = "default_ttl")]
#[serde_as(as = "serde_with::DurationSeconds<i64>")]
pub ttl: Duration,

View File

@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{path::PathBuf, time::Duration};
use std::{num::NonZeroU32, path::PathBuf, time::Duration};
use anyhow::Context;
use async_trait::async_trait;
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none};
use sqlx::{
@ -26,9 +26,14 @@ use sqlx::{
use tracing::log::LevelFilter;
use super::ConfigurationSection;
use crate::schema;
fn default_max_connections() -> u32 {
10
fn default_connection_string() -> String {
"postgresql://".to_string()
}
fn default_max_connections() -> NonZeroU32 {
NonZeroU32::new(10).unwrap()
}
fn default_connect_timeout() -> Duration {
@ -58,31 +63,38 @@ impl Default for DatabaseConfig {
}
}
fn duration_schema(gen: &mut SchemaGenerator) -> Schema {
Option::<u64>::json_schema(gen)
}
fn optional_duration_schema(gen: &mut SchemaGenerator) -> Schema {
u64::json_schema(gen)
}
#[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(untagged)]
enum ConnectConfig {
Uri {
/// Connection URI
#[schemars(url, default = "default_connection_string")]
uri: String,
},
Options {
/// Name of host to connect to
#[schemars(schema_with = "schema::hostname")]
#[serde(default)]
host: Option<String>,
/// Port number to connect at the server host
#[schemars(schema_with = "schema::port")]
#[serde(default)]
port: Option<u16>,
/// Directory containing the UNIX socket to connect to
#[serde(default)]
socket: Option<PathBuf>,
/// PostgreSQL user name to connect as
#[serde(default)]
username: Option<String>,
/// Password to be used if the server demands password authentication
#[serde(default)]
password: Option<String>,
/// The database name
#[serde(default)]
database: Option<String>,
/* TODO
@ -141,41 +153,49 @@ impl TryInto<PgConnectOptions> for &ConnectConfig {
impl Default for ConnectConfig {
fn default() -> Self {
Self::Uri {
uri: "postgresql://".to_string(),
uri: default_connection_string(),
}
}
}
/// Database connection configuration
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct DatabaseConfig {
/// Options related to how to connect to the database
#[serde(default, flatten)]
options: ConnectConfig,
/// Set the maximum number of connections the pool should maintain
#[serde(default = "default_max_connections")]
max_connections: u32,
max_connections: NonZeroU32,
/// Set the minimum number of connections the pool should maintain
#[serde(default)]
min_connections: u32,
#[schemars(schema_with = "duration_schema")]
/// Set the amount of time to attempt connecting to the database
#[schemars(with = "u64")]
#[serde(default = "default_connect_timeout")]
#[serde_as(as = "serde_with::DurationSeconds<u64>")]
connect_timeout: Duration,
#[schemars(schema_with = "optional_duration_schema")]
/// Set a maximum idle duration for individual connections
#[schemars(with = "Option<u64>")]
#[serde(default = "default_idle_timeout")]
#[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
idle_timeout: Option<Duration>,
#[schemars(schema_with = "optional_duration_schema")]
/// Set the maximum lifetime of individual connections
#[schemars(with = "u64")]
#[serde(default = "default_max_lifetime")]
#[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
max_lifetime: Option<Duration>,
}
impl DatabaseConfig {
/// Connect to the database
#[tracing::instrument(err, skip_all)]
pub async fn connect(&self) -> anyhow::Result<PgPool> {
let mut options: PgConnectOptions = (&self.options)
@ -187,7 +207,7 @@ impl DatabaseConfig {
.log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
PgPoolOptions::new()
.max_connections(self.max_connections)
.max_connections(self.max_connections.into())
.min_connections(self.min_connections)
.connect_timeout(self.connect_timeout)
.idle_timeout(self.idle_timeout)

View File

@ -12,46 +12,90 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::num::NonZeroU16;
use async_trait::async_trait;
use lettre::{message::Mailbox, Address};
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
use schemars::{
gen::SchemaGenerator,
schema::{InstanceType, Schema, SchemaObject},
JsonSchema,
};
use serde::{Deserialize, Serialize};
use super::ConfigurationSection;
fn mailbox_schema(gen: &mut SchemaGenerator) -> Schema {
// TODO: proper email schema
String::json_schema(gen)
fn mailbox_schema(_gen: &mut SchemaGenerator) -> Schema {
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
format: Some("email".to_string()),
..SchemaObject::default()
})
}
fn hostname_schema(_gen: &mut SchemaGenerator) -> Schema {
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
format: Some("hostname".to_string()),
..SchemaObject::default()
})
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct Credentials {
/// Username for use to authenticate when connecting to the SMTP server
pub username: String,
/// Password for use to authenticate when connecting to the SMTP server
pub password: String,
}
/// Encryption mode to use
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum EmailSmtpMode {
/// Plain text
Plain,
/// StartTLS (starts as plain text then upgrade to TLS)
StartTls,
/// TLS
Tls,
}
/// What backend should be used when sending emails
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "transport", rename_all = "snake_case")]
pub enum EmailTransportConfig {
/// Don't send emails anywhere
Blackhole,
/// Send emails via an SMTP relay
Smtp {
/// Connection mode to the relay
mode: EmailSmtpMode,
/// Hostname to connect to
#[schemars(schema_with = "hostname_schema")]
hostname: String,
#[serde(default)]
port: Option<u16>,
/// Port to connect to. Default is 25 for plain, 465 for TLS and 587 for
/// StartTLS
#[serde(default, skip_serializing_if = "Option::is_none")]
port: Option<NonZeroU16>,
/// Set of credentials to use
#[serde(flatten, default)]
credentials: Option<Credentials>,
},
/// Send emails by calling sendmail
Sendmail {
/// Command to execute
#[serde(default = "default_sendmail_command")]
command: String,
},
/// Send emails via the AWS SESv2 API
AwsSes,
}
@ -61,25 +105,38 @@ impl Default for EmailTransportConfig {
}
}
fn default_email() -> Mailbox {
let address = Address::new("root", "localhost").unwrap();
Mailbox::new(Some("Authentication Service".to_string()), address)
}
fn default_sendmail_command() -> String {
"sendmail".to_string()
}
/// Configuration related to sending emails
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EmailConfig {
/// Email address to use as From when sending emails
#[serde(default = "default_email")]
#[schemars(schema_with = "mailbox_schema")]
pub from: Mailbox,
/// Email address to use as Reply-To when sending emails
#[serde(default = "default_email")]
#[schemars(schema_with = "mailbox_schema")]
pub reply_to: Mailbox,
#[serde(flatten)]
/// What backend should be used when sending emails
#[serde(flatten, default)]
pub transport: EmailTransportConfig,
}
impl Default for EmailConfig {
fn default() -> Self {
let address = Address::new("root", "localhost").unwrap();
let mailbox = Mailbox::new(Some("Authentication Service".to_string()), address);
Self {
from: mailbox.clone(),
reply_to: mailbox,
from: default_email(),
reply_to: default_email(),
transport: EmailTransportConfig::Blackhole,
}
}

View File

@ -24,11 +24,34 @@ fn default_http_address() -> String {
"[::]:8080".into()
}
fn http_address_example_1() -> &'static str {
"[::1]:8080"
}
fn http_address_example_2() -> &'static str {
"[::]:8080"
}
fn http_address_example_3() -> &'static str {
"127.0.0.1:8080"
}
fn http_address_example_4() -> &'static str {
"0.0.0.0:8080"
}
/// Configuration related to the web server
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct HttpConfig {
/// IP and port the server should listen to
#[schemars(
example = "http_address_example_1",
example = "http_address_example_2",
example = "http_address_example_3",
example = "http_address_example_4"
)]
#[serde(default = "default_http_address")]
pub address: String,
/// Path from which to serve static files. If not specified, it will serve
/// the static files embedded in the server binary
#[serde(default)]
pub web_root: Option<PathBuf>,
}

View File

@ -20,6 +20,8 @@
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::missing_errors_doc)]
//! Application configuration logic
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -30,6 +32,7 @@ mod database;
mod email;
mod http;
mod oauth2;
pub(crate) mod schema;
mod telemetry;
mod templates;
mod util;
@ -49,27 +52,36 @@ pub use self::{
util::ConfigurationSection,
};
/// Application configuration root
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RootConfig {
/// Configuration related to OAuth 2.0/OIDC operations
pub oauth2: OAuth2Config,
/// Configuration of the HTTP server
#[serde(default)]
pub http: HttpConfig,
/// Database connection configuration
#[serde(default)]
pub database: DatabaseConfig,
/// Configuration related to cookies
pub cookies: CookiesConfig,
/// Configuration related to sending monitoring data
#[serde(default)]
pub telemetry: TelemetryConfig,
/// Configuration related to templates
#[serde(default)]
pub templates: TemplatesConfig,
/// Configuration related to Cross-Site Request Forgery protections
#[serde(default)]
pub csrf: CsrfConfig,
/// Configuration related to sending emails
#[serde(default)]
pub email: EmailConfig,
}

View File

@ -0,0 +1,38 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use schemars::{
gen::SchemaGenerator,
schema::{InstanceType, NumberValidation, Schema, SchemaObject},
};
pub fn port(_gen: &mut SchemaGenerator) -> Schema {
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::Integer.into()),
number: Some(Box::new(NumberValidation {
minimum: Some(1.0),
maximum: Some(65535.0),
..NumberValidation::default()
})),
..SchemaObject::default()
})
}
pub fn hostname(_gen: &mut SchemaGenerator) -> Schema {
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
format: Some("hostname".to_string()),
..SchemaObject::default()
})
}

View File

@ -22,32 +22,72 @@ use url::Url;
use super::ConfigurationSection;
/// Propagation format for incoming and outgoing requests
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum Propagator {
/// Propagate according to the W3C Trace Context specification
TraceContext,
/// Propagate according to the W3C Baggage specification
Baggage,
/// Propagate trace context with Jaeger compatible headers
Jaeger,
/// Propagate trace context with Zipkin compatible headers (single `b3`
/// header variant)
B3,
/// Propagate trace context with Zipkin compatible headers (multiple
/// `x-b3-*` headers variant)
B3Multi,
}
fn otlp_endpoint_example() -> &'static str {
"https://localhost:4317"
}
fn jaeger_agent_endpoint_example() -> &'static str {
"127.0.0.1:6831"
}
fn zipkin_collector_endpoint_example() -> &'static str {
"http://127.0.0.1:9411/api/v2/spans"
}
/// Exporter to use when exporting traces
#[skip_serializing_none]
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "exporter", rename_all = "lowercase")]
pub enum TracingExporterConfig {
/// Don't export traces
None,
/// Export traces to the standard output. Only useful for debugging
Stdout,
/// Export traces to an OpenTelemetry protocol compatible endpoint
Otlp {
/// OTLP compatible endpoint
#[schemars(url, example = "otlp_endpoint_example")]
#[serde(default)]
endpoint: Option<Url>,
},
/// Export traces to a Jaeger agent
Jaeger {
/// Jaeger agent endpoint
#[schemars(example = "jaeger_agent_endpoint_example")]
#[serde(default)]
agent_endpoint: Option<SocketAddr>,
},
/// Export traces to a Zipkin collector
Zipkin {
/// Zipkin collector endpoint
#[schemars(url, example = "zipkin_collector_endpoint_example")]
#[serde(default)]
collector_endpoint: Option<Url>,
},
@ -59,23 +99,34 @@ impl Default for TracingExporterConfig {
}
}
/// Configuration related to exporting traces
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
pub struct TracingConfig {
/// Exporter to use when exporting traces
#[serde(default, flatten)]
pub exporter: TracingExporterConfig,
/// List of propagation formats to use for incoming and outgoing requests
pub propagators: Vec<Propagator>,
}
/// Exporter to use when exporting metrics
#[skip_serializing_none]
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "exporter", rename_all = "lowercase")]
pub enum MetricsExporterConfig {
/// Don't export metrics
None,
/// Export metrics to stdout. Only useful for debugging
Stdout,
/// Export metrics to an OpenTelemetry protocol compatible endpoint
Otlp {
/// OTLP compatible endpoint
#[schemars(url, example = "otlp_endpoint_example")]
#[serde(default)]
endpoint: Option<url::Url>,
endpoint: Option<Url>,
},
}
@ -85,17 +136,22 @@ impl Default for MetricsExporterConfig {
}
}
/// Configuration related to exporting metrics
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
pub struct MetricsConfig {
/// Exporter to use when exporting metrics
#[serde(default, flatten)]
pub exporter: MetricsExporterConfig,
}
/// Configuration related to sending monitoring data
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
pub struct TelemetryConfig {
/// Configuration related to exporting traces
#[serde(default)]
pub tracing: TracingConfig,
/// Configuration related to exporting metrics
#[serde(default)]
pub metrics: MetricsConfig,
}

View File

@ -22,6 +22,7 @@ fn default_builtin() -> bool {
true
}
/// Configuration related to templates
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct TemplatesConfig {
/// Path to the folder that holds the custom templates

View File

@ -19,4 +19,4 @@ aws-config = "0.5.2"
[dependencies.lettre]
version = "0.10.0-rc.4"
default-features = false
features = ["tokio1-rustls-tls", "hostname", "builder", "tracing", "pool", "smtp-transport"]
features = ["tokio1-rustls-tls", "hostname", "builder", "tracing", "pool", "smtp-transport", "sendmail-transport"]

View File

@ -17,7 +17,10 @@ use std::sync::Arc;
use async_trait::async_trait;
use lettre::{
address::Envelope,
transport::smtp::{authentication::Credentials, AsyncSmtpTransport},
transport::{
sendmail::AsyncSendmailTransport,
smtp::{authentication::Credentials, AsyncSmtpTransport},
},
AsyncTransport, Tokio1Executor,
};
use mas_config::{EmailSmtpMode, EmailTransportConfig};
@ -32,6 +35,7 @@ pub struct Transport {
enum TransportInner {
Blackhole,
Smtp(AsyncSmtpTransport<Tokio1Executor>),
Sendmail(AsyncSendmailTransport<Tokio1Executor>),
AwsSes(aws_ses::Transport),
}
@ -63,11 +67,14 @@ impl Transport {
}
if let Some(port) = port {
t = t.port(*port);
t = t.port((*port).into());
}
TransportInner::Smtp(t.build())
}
EmailTransportConfig::Sendmail { command } => {
TransportInner::Sendmail(AsyncSendmailTransport::new_with_command(command))
}
EmailTransportConfig::AwsSes => {
TransportInner::AwsSes(aws_ses::Transport::from_env().await)
}
@ -84,6 +91,7 @@ impl Transport {
TransportInner::Smtp(t) => {
t.test_connection().await?;
}
&TransportInner::Sendmail(_) => {}
TransportInner::AwsSes(_) => {}
}
@ -113,6 +121,9 @@ impl AsyncTransport for Transport {
TransportInner::Smtp(t) => {
t.send_raw(envelope, email).await?;
}
TransportInner::Sendmail(t) => {
t.send_raw(envelope, email).await?;
}
TransportInner::AwsSes(t) => {
t.send_raw(envelope, email).await?;
}