You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Allow splitting database connection options
This commit is contained in:
@ -7,7 +7,7 @@ license = "Apache-2.0"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tokio = { version = "1.11.0", features = [] }
|
tokio = { version = "1.11.0", features = [] }
|
||||||
tracing = "0.1.27"
|
tracing = { version = "0.1.27", features = ["log"] }
|
||||||
async-trait = "0.1.51"
|
async-trait = "0.1.51"
|
||||||
|
|
||||||
thiserror = "1.0.29"
|
thiserror = "1.0.29"
|
||||||
|
@ -12,24 +12,21 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::time::Duration;
|
use std::{convert::TryInto, path::PathBuf, time::Duration};
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
|
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_with::{serde_as, skip_serializing_none};
|
use serde_with::{serde_as, skip_serializing_none};
|
||||||
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
|
use sqlx::{
|
||||||
|
postgres::{PgConnectOptions, PgPool, PgPoolOptions},
|
||||||
|
ConnectOptions,
|
||||||
|
};
|
||||||
|
use tracing::log::LevelFilter;
|
||||||
|
|
||||||
// FIXME
|
|
||||||
// use sqlx::ConnectOptions
|
|
||||||
// use tracing::log::LevelFilter;
|
|
||||||
use super::ConfigurationSection;
|
use super::ConfigurationSection;
|
||||||
|
|
||||||
fn default_uri() -> String {
|
|
||||||
"postgresql://".to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_max_connections() -> u32 {
|
fn default_max_connections() -> u32 {
|
||||||
10
|
10
|
||||||
}
|
}
|
||||||
@ -51,7 +48,7 @@ fn default_max_lifetime() -> Option<Duration> {
|
|||||||
impl Default for DatabaseConfig {
|
impl Default for DatabaseConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
uri: default_uri(),
|
options: ConnectConfig::default(),
|
||||||
max_connections: default_max_connections(),
|
max_connections: default_max_connections(),
|
||||||
min_connections: Default::default(),
|
min_connections: Default::default(),
|
||||||
connect_timeout: default_connect_timeout(),
|
connect_timeout: default_connect_timeout(),
|
||||||
@ -69,12 +66,92 @@ fn optional_duration_schema(gen: &mut SchemaGenerator) -> Schema {
|
|||||||
u64::json_schema(gen)
|
u64::json_schema(gen)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum ConnectConfig {
|
||||||
|
Uri {
|
||||||
|
uri: String,
|
||||||
|
},
|
||||||
|
Options {
|
||||||
|
#[serde(default)]
|
||||||
|
host: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
port: Option<u16>,
|
||||||
|
#[serde(default)]
|
||||||
|
socket: Option<PathBuf>,
|
||||||
|
#[serde(default)]
|
||||||
|
username: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
password: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
database: Option<String>,
|
||||||
|
/* TODO
|
||||||
|
* ssl_mode: PgSslMode,
|
||||||
|
* ssl_root_cert: Option<CertificateInput>, */
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryInto<PgConnectOptions> for &ConnectConfig {
|
||||||
|
type Error = sqlx::Error;
|
||||||
|
|
||||||
|
fn try_into(self) -> Result<PgConnectOptions, Self::Error> {
|
||||||
|
match self {
|
||||||
|
ConnectConfig::Uri { uri } => uri.parse(),
|
||||||
|
ConnectConfig::Options {
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
socket,
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
database,
|
||||||
|
} => {
|
||||||
|
let mut opts =
|
||||||
|
PgConnectOptions::new().application_name("matrix-authentication-service");
|
||||||
|
|
||||||
|
if let Some(host) = host {
|
||||||
|
opts = opts.host(host);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(port) = port {
|
||||||
|
opts = opts.port(*port);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(socket) = socket {
|
||||||
|
opts = opts.socket(socket);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(username) = username {
|
||||||
|
opts = opts.username(username);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(password) = password {
|
||||||
|
opts = opts.password(password);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(database) = database {
|
||||||
|
opts = opts.database(database);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(opts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ConnectConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Uri {
|
||||||
|
uri: "postgresql://".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[serde_as]
|
#[serde_as]
|
||||||
#[skip_serializing_none]
|
#[skip_serializing_none]
|
||||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
pub struct DatabaseConfig {
|
pub struct DatabaseConfig {
|
||||||
#[serde(default = "default_uri")]
|
#[serde(default, flatten)]
|
||||||
uri: String,
|
options: ConnectConfig,
|
||||||
|
|
||||||
#[serde(default = "default_max_connections")]
|
#[serde(default = "default_max_connections")]
|
||||||
max_connections: u32,
|
max_connections: u32,
|
||||||
@ -101,16 +178,13 @@ pub struct DatabaseConfig {
|
|||||||
impl DatabaseConfig {
|
impl DatabaseConfig {
|
||||||
#[tracing::instrument(err)]
|
#[tracing::instrument(err)]
|
||||||
pub async fn connect(&self) -> anyhow::Result<PgPool> {
|
pub async fn connect(&self) -> anyhow::Result<PgPool> {
|
||||||
let options = self
|
let mut options: PgConnectOptions = (&self.options)
|
||||||
.uri
|
.try_into()
|
||||||
.parse::<PgConnectOptions>()
|
.context("invalid database config")?;
|
||||||
.context("invalid database URL")?
|
|
||||||
.application_name("matrix-authentication-service");
|
|
||||||
|
|
||||||
// FIXME
|
options
|
||||||
// options
|
.log_statements(LevelFilter::Debug)
|
||||||
// .log_statements(LevelFilter::Debug)
|
.log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
|
||||||
// .log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
|
|
||||||
|
|
||||||
PgPoolOptions::new()
|
PgPoolOptions::new()
|
||||||
.max_connections(self.max_connections)
|
.max_connections(self.max_connections)
|
||||||
@ -158,7 +232,12 @@ mod tests {
|
|||||||
|
|
||||||
let config = DatabaseConfig::load_from_file("config.yaml")?;
|
let config = DatabaseConfig::load_from_file("config.yaml")?;
|
||||||
|
|
||||||
assert_eq!(config.uri, "postgresql://user:password@host/database");
|
assert_eq!(
|
||||||
|
config.options,
|
||||||
|
ConnectConfig::Uri {
|
||||||
|
uri: "postgresql://user:password@host/database".to_string()
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
});
|
});
|
||||||
|
Reference in New Issue
Block a user