1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Allow splitting database connection options

This commit is contained in:
Quentin Gliech
2021-09-17 12:03:00 +02:00
parent 789ace84fd
commit bd441ceef7
2 changed files with 102 additions and 23 deletions

View File

@ -7,7 +7,7 @@ license = "Apache-2.0"
[dependencies]
tokio = { version = "1.11.0", features = [] }
tracing = "0.1.27"
tracing = { version = "0.1.27", features = ["log"] }
async-trait = "0.1.51"
thiserror = "1.0.29"

View File

@ -12,24 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
use std::{convert::TryInto, path::PathBuf, time::Duration};
use anyhow::Context;
use async_trait::async_trait;
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none};
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
use sqlx::{
postgres::{PgConnectOptions, PgPool, PgPoolOptions},
ConnectOptions,
};
use tracing::log::LevelFilter;
// FIXME
// use sqlx::ConnectOptions
// use tracing::log::LevelFilter;
use super::ConfigurationSection;
fn default_uri() -> String {
"postgresql://".to_string()
}
fn default_max_connections() -> u32 {
10
}
@ -51,7 +48,7 @@ fn default_max_lifetime() -> Option<Duration> {
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
uri: default_uri(),
options: ConnectConfig::default(),
max_connections: default_max_connections(),
min_connections: Default::default(),
connect_timeout: default_connect_timeout(),
@ -69,12 +66,92 @@ fn optional_duration_schema(gen: &mut SchemaGenerator) -> Schema {
u64::json_schema(gen)
}
#[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(untagged)]
enum ConnectConfig {
Uri {
uri: String,
},
Options {
#[serde(default)]
host: Option<String>,
#[serde(default)]
port: Option<u16>,
#[serde(default)]
socket: Option<PathBuf>,
#[serde(default)]
username: Option<String>,
#[serde(default)]
password: Option<String>,
#[serde(default)]
database: Option<String>,
/* TODO
* ssl_mode: PgSslMode,
* ssl_root_cert: Option<CertificateInput>, */
},
}
impl TryInto<PgConnectOptions> for &ConnectConfig {
type Error = sqlx::Error;
fn try_into(self) -> Result<PgConnectOptions, Self::Error> {
match self {
ConnectConfig::Uri { uri } => uri.parse(),
ConnectConfig::Options {
host,
port,
socket,
username,
password,
database,
} => {
let mut opts =
PgConnectOptions::new().application_name("matrix-authentication-service");
if let Some(host) = host {
opts = opts.host(host);
}
if let Some(port) = port {
opts = opts.port(*port);
}
if let Some(socket) = socket {
opts = opts.socket(socket);
}
if let Some(username) = username {
opts = opts.username(username);
}
if let Some(password) = password {
opts = opts.password(password);
}
if let Some(database) = database {
opts = opts.database(database);
}
Ok(opts)
}
}
}
}
impl Default for ConnectConfig {
fn default() -> Self {
Self::Uri {
uri: "postgresql://".to_string(),
}
}
}
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct DatabaseConfig {
#[serde(default = "default_uri")]
uri: String,
#[serde(default, flatten)]
options: ConnectConfig,
#[serde(default = "default_max_connections")]
max_connections: u32,
@ -101,16 +178,13 @@ pub struct DatabaseConfig {
impl DatabaseConfig {
#[tracing::instrument(err)]
pub async fn connect(&self) -> anyhow::Result<PgPool> {
let options = self
.uri
.parse::<PgConnectOptions>()
.context("invalid database URL")?
.application_name("matrix-authentication-service");
let mut options: PgConnectOptions = (&self.options)
.try_into()
.context("invalid database config")?;
// FIXME
// options
// .log_statements(LevelFilter::Debug)
// .log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
options
.log_statements(LevelFilter::Debug)
.log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
PgPoolOptions::new()
.max_connections(self.max_connections)
@ -158,7 +232,12 @@ mod tests {
let config = DatabaseConfig::load_from_file("config.yaml")?;
assert_eq!(config.uri, "postgresql://user:password@host/database");
assert_eq!(
config.options,
ConnectConfig::Uri {
uri: "postgresql://user:password@host/database".to_string()
}
);
Ok(())
});