You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Allow splitting database connection options
This commit is contained in:
@ -7,7 +7,7 @@ license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.11.0", features = [] }
|
||||
tracing = "0.1.27"
|
||||
tracing = { version = "0.1.27", features = ["log"] }
|
||||
async-trait = "0.1.51"
|
||||
|
||||
thiserror = "1.0.29"
|
||||
|
@ -12,24 +12,21 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::time::Duration;
|
||||
use std::{convert::TryInto, path::PathBuf, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, skip_serializing_none};
|
||||
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
|
||||
use sqlx::{
|
||||
postgres::{PgConnectOptions, PgPool, PgPoolOptions},
|
||||
ConnectOptions,
|
||||
};
|
||||
use tracing::log::LevelFilter;
|
||||
|
||||
// FIXME
|
||||
// use sqlx::ConnectOptions
|
||||
// use tracing::log::LevelFilter;
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn default_uri() -> String {
|
||||
"postgresql://".to_string()
|
||||
}
|
||||
|
||||
fn default_max_connections() -> u32 {
|
||||
10
|
||||
}
|
||||
@ -51,7 +48,7 @@ fn default_max_lifetime() -> Option<Duration> {
|
||||
impl Default for DatabaseConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
uri: default_uri(),
|
||||
options: ConnectConfig::default(),
|
||||
max_connections: default_max_connections(),
|
||||
min_connections: Default::default(),
|
||||
connect_timeout: default_connect_timeout(),
|
||||
@ -69,12 +66,92 @@ fn optional_duration_schema(gen: &mut SchemaGenerator) -> Schema {
|
||||
u64::json_schema(gen)
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
enum ConnectConfig {
|
||||
Uri {
|
||||
uri: String,
|
||||
},
|
||||
Options {
|
||||
#[serde(default)]
|
||||
host: Option<String>,
|
||||
#[serde(default)]
|
||||
port: Option<u16>,
|
||||
#[serde(default)]
|
||||
socket: Option<PathBuf>,
|
||||
#[serde(default)]
|
||||
username: Option<String>,
|
||||
#[serde(default)]
|
||||
password: Option<String>,
|
||||
#[serde(default)]
|
||||
database: Option<String>,
|
||||
/* TODO
|
||||
* ssl_mode: PgSslMode,
|
||||
* ssl_root_cert: Option<CertificateInput>, */
|
||||
},
|
||||
}
|
||||
|
||||
impl TryInto<PgConnectOptions> for &ConnectConfig {
|
||||
type Error = sqlx::Error;
|
||||
|
||||
fn try_into(self) -> Result<PgConnectOptions, Self::Error> {
|
||||
match self {
|
||||
ConnectConfig::Uri { uri } => uri.parse(),
|
||||
ConnectConfig::Options {
|
||||
host,
|
||||
port,
|
||||
socket,
|
||||
username,
|
||||
password,
|
||||
database,
|
||||
} => {
|
||||
let mut opts =
|
||||
PgConnectOptions::new().application_name("matrix-authentication-service");
|
||||
|
||||
if let Some(host) = host {
|
||||
opts = opts.host(host);
|
||||
}
|
||||
|
||||
if let Some(port) = port {
|
||||
opts = opts.port(*port);
|
||||
}
|
||||
|
||||
if let Some(socket) = socket {
|
||||
opts = opts.socket(socket);
|
||||
}
|
||||
|
||||
if let Some(username) = username {
|
||||
opts = opts.username(username);
|
||||
}
|
||||
|
||||
if let Some(password) = password {
|
||||
opts = opts.password(password);
|
||||
}
|
||||
|
||||
if let Some(database) = database {
|
||||
opts = opts.database(database);
|
||||
}
|
||||
|
||||
Ok(opts)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ConnectConfig {
|
||||
fn default() -> Self {
|
||||
Self::Uri {
|
||||
uri: "postgresql://".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct DatabaseConfig {
|
||||
#[serde(default = "default_uri")]
|
||||
uri: String,
|
||||
#[serde(default, flatten)]
|
||||
options: ConnectConfig,
|
||||
|
||||
#[serde(default = "default_max_connections")]
|
||||
max_connections: u32,
|
||||
@ -101,16 +178,13 @@ pub struct DatabaseConfig {
|
||||
impl DatabaseConfig {
|
||||
#[tracing::instrument(err)]
|
||||
pub async fn connect(&self) -> anyhow::Result<PgPool> {
|
||||
let options = self
|
||||
.uri
|
||||
.parse::<PgConnectOptions>()
|
||||
.context("invalid database URL")?
|
||||
.application_name("matrix-authentication-service");
|
||||
let mut options: PgConnectOptions = (&self.options)
|
||||
.try_into()
|
||||
.context("invalid database config")?;
|
||||
|
||||
// FIXME
|
||||
// options
|
||||
// .log_statements(LevelFilter::Debug)
|
||||
// .log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
|
||||
options
|
||||
.log_statements(LevelFilter::Debug)
|
||||
.log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
|
||||
|
||||
PgPoolOptions::new()
|
||||
.max_connections(self.max_connections)
|
||||
@ -158,7 +232,12 @@ mod tests {
|
||||
|
||||
let config = DatabaseConfig::load_from_file("config.yaml")?;
|
||||
|
||||
assert_eq!(config.uri, "postgresql://user:password@host/database");
|
||||
assert_eq!(
|
||||
config.options,
|
||||
ConnectConfig::Uri {
|
||||
uri: "postgresql://user:password@host/database".to_string()
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
Reference in New Issue
Block a user