You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
@ -17,13 +17,14 @@ use std::collections::HashSet;
|
||||
use clap::Parser;
|
||||
use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
|
||||
use mas_storage::{
|
||||
upstream_oauth2::UpstreamOAuthProviderRepository, Repository, RepositoryAccess, SystemClock,
|
||||
upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess, SystemClock,
|
||||
};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use rand::SeedableRng;
|
||||
use sqlx::{postgres::PgAdvisoryLock, Acquire};
|
||||
use tracing::{info, info_span, warn};
|
||||
|
||||
use crate::util::database_from_config;
|
||||
use crate::util::database_connection_from_config;
|
||||
|
||||
fn map_import_preference(
|
||||
config: &mas_config::UpstreamOAuth2ImportPreference,
|
||||
@ -144,8 +145,18 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
|
||||
|
||||
let config: SyncConfig = root.load_config()?;
|
||||
let encrypter = config.secrets.encrypter();
|
||||
let pool = database_from_config(&config.database).await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
// Grab a connection to the database
|
||||
let mut conn = database_connection_from_config(&config.database).await?;
|
||||
// Start a transaction
|
||||
let txn = conn.begin().await?;
|
||||
|
||||
// Grab a lock within the transaction
|
||||
tracing::info!("Acquiring config lock");
|
||||
let lock = PgAdvisoryLock::new("MAS config sync");
|
||||
let lock = lock.acquire(txn).await?;
|
||||
|
||||
// Create a repository from the connection with the lock
|
||||
let mut repo = PgRepository::from_conn(lock);
|
||||
|
||||
tracing::info!(
|
||||
prune,
|
||||
@ -284,11 +295,14 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
|
||||
}
|
||||
}
|
||||
|
||||
// Get the lock and release it to commit the transaction
|
||||
let lock = repo.into_inner();
|
||||
let txn = lock.release_now().await?;
|
||||
if dry_run {
|
||||
info!("Dry run, rolling back changes");
|
||||
repo.cancel().await?;
|
||||
txn.rollback().await?;
|
||||
} else {
|
||||
repo.save().await?;
|
||||
txn.commit().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ use mas_config::DatabaseConfig;
|
||||
use mas_storage_pg::MIGRATOR;
|
||||
use tracing::{info_span, Instrument};
|
||||
|
||||
use crate::util::database_from_config;
|
||||
use crate::util::database_connection_from_config;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
pub(super) struct Options {
|
||||
@ -36,11 +36,11 @@ impl Options {
|
||||
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
|
||||
let _span = info_span!("cli.database.migrate").entered();
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
let mut conn = database_connection_from_config(&config).await?;
|
||||
|
||||
// Run pending migrations
|
||||
MIGRATOR
|
||||
.run(&pool)
|
||||
.run(&mut conn)
|
||||
.instrument(info_span!("db.migrate"))
|
||||
.await
|
||||
.context("could not run migrations")?;
|
||||
|
@ -20,14 +20,14 @@ use mas_storage::{
|
||||
compat::{CompatAccessTokenRepository, CompatSessionRepository},
|
||||
job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
|
||||
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
|
||||
Repository, RepositoryAccess, SystemClock,
|
||||
RepositoryAccess, SystemClock,
|
||||
};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use rand::SeedableRng;
|
||||
use sqlx::types::Uuid;
|
||||
use sqlx::{types::Uuid, Acquire};
|
||||
use tracing::{info, info_span, warn};
|
||||
|
||||
use crate::util::{database_from_config, password_manager_from_config};
|
||||
use crate::util::{database_connection_from_config, password_manager_from_config};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
pub(super) struct Options {
|
||||
@ -103,10 +103,11 @@ impl Options {
|
||||
let database_config: DatabaseConfig = root.load_config()?;
|
||||
let passwords_config: PasswordsConfig = root.load_config()?;
|
||||
|
||||
let pool = database_from_config(&database_config).await?;
|
||||
let mut conn = database_connection_from_config(&database_config).await?;
|
||||
let password_manager = password_manager_from_config(&passwords_config).await?;
|
||||
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
let txn = conn.begin().await?;
|
||||
let mut repo = PgRepository::from_conn(txn);
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username(&username)
|
||||
@ -122,7 +123,7 @@ impl Options {
|
||||
.await?;
|
||||
|
||||
info!(%user.id, %user.username, "Password changed");
|
||||
repo.save().await?;
|
||||
repo.into_inner().commit().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -135,9 +136,10 @@ impl Options {
|
||||
)
|
||||
.entered();
|
||||
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
let database_config: DatabaseConfig = root.load_config()?;
|
||||
let mut conn = database_connection_from_config(&database_config).await?;
|
||||
let txn = conn.begin().await?;
|
||||
let mut repo = PgRepository::from_conn(txn);
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
@ -152,7 +154,7 @@ impl Options {
|
||||
.context("Email not found")?;
|
||||
let email = repo.user_email().mark_as_verified(&clock, email).await?;
|
||||
|
||||
repo.save().await?;
|
||||
repo.into_inner().commit().await?;
|
||||
info!(?email, "Email marked as verified");
|
||||
|
||||
Ok(())
|
||||
@ -163,9 +165,10 @@ impl Options {
|
||||
admin,
|
||||
device_id,
|
||||
} => {
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
let database_config: DatabaseConfig = root.load_config()?;
|
||||
let mut conn = database_connection_from_config(&database_config).await?;
|
||||
let txn = conn.begin().await?;
|
||||
let mut repo = PgRepository::from_conn(txn);
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
@ -191,7 +194,7 @@ impl Options {
|
||||
.add(&mut rng, &clock, &compat_session, token, None)
|
||||
.await?;
|
||||
|
||||
repo.save().await?;
|
||||
repo.into_inner().commit().await?;
|
||||
|
||||
info!(
|
||||
%compat_access_token.id,
|
||||
@ -207,16 +210,16 @@ impl Options {
|
||||
|
||||
SC::ProvisionAllUsers => {
|
||||
let _span = info_span!("cli.manage.provision_all_users").entered();
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
let mut conn = pool.acquire().await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
let database_config: DatabaseConfig = root.load_config()?;
|
||||
let mut conn = database_connection_from_config(&database_config).await?;
|
||||
let mut txn = conn.begin().await?;
|
||||
|
||||
// TODO: do some pagination here
|
||||
let ids: Vec<Uuid> = sqlx::query_scalar("SELECT user_id FROM users")
|
||||
.fetch_all(&mut *conn)
|
||||
.fetch_all(&mut *txn)
|
||||
.await?;
|
||||
drop(conn);
|
||||
|
||||
let mut repo = PgRepository::from_conn(txn);
|
||||
|
||||
for id in ids {
|
||||
let id = id.into();
|
||||
@ -225,7 +228,7 @@ impl Options {
|
||||
repo.job().schedule_job(job).await?;
|
||||
}
|
||||
|
||||
repo.save().await?;
|
||||
repo.into_inner().commit().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -233,10 +236,10 @@ impl Options {
|
||||
SC::KillSessions { username, dry_run } => {
|
||||
let _span =
|
||||
info_span!("cli.manage.kill_sessions", user.username = username).entered();
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
let mut conn = pool.acquire().await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
let database_config: DatabaseConfig = root.load_config()?;
|
||||
let mut conn = database_connection_from_config(&database_config).await?;
|
||||
let txn = conn.begin().await?;
|
||||
let mut repo = PgRepository::from_conn(txn);
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
@ -251,7 +254,7 @@ impl Options {
|
||||
"#,
|
||||
)
|
||||
.bind(Uuid::from(user.id))
|
||||
.fetch_all(&mut *conn)
|
||||
.fetch_all(&mut **repo)
|
||||
.await?;
|
||||
|
||||
for id in compat_sessions_ids {
|
||||
@ -281,7 +284,7 @@ impl Options {
|
||||
"#,
|
||||
)
|
||||
.bind(Uuid::from(user.id))
|
||||
.fetch_all(&mut *conn)
|
||||
.fetch_all(&mut **repo)
|
||||
.await?;
|
||||
|
||||
for id in oauth2_sessions_ids {
|
||||
@ -316,7 +319,7 @@ impl Options {
|
||||
"#,
|
||||
)
|
||||
.bind(Uuid::from(user.id))
|
||||
.fetch_all(&mut *conn)
|
||||
.fetch_all(&mut **repo)
|
||||
.await?;
|
||||
|
||||
for id in user_sessions_ids {
|
||||
@ -337,11 +340,12 @@ impl Options {
|
||||
.await?;
|
||||
}
|
||||
|
||||
let txn = repo.into_inner();
|
||||
if dry_run {
|
||||
info!("Dry run, not saving");
|
||||
repo.cancel().await?;
|
||||
txn.rollback().await?;
|
||||
} else {
|
||||
repo.save().await?;
|
||||
txn.commit().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -353,8 +357,9 @@ impl Options {
|
||||
} => {
|
||||
let _span = info_span!("cli.manage.lock_user", user.username = username).entered();
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
let mut conn = database_connection_from_config(&config).await?;
|
||||
let txn = conn.begin().await?;
|
||||
let mut repo = PgRepository::from_conn(txn);
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
@ -375,7 +380,8 @@ impl Options {
|
||||
.schedule_job(DeactivateUserJob::new(&user, false))
|
||||
.await?;
|
||||
}
|
||||
repo.save().await?;
|
||||
|
||||
repo.into_inner().commit().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -383,8 +389,9 @@ impl Options {
|
||||
SC::UnlockUser { username } => {
|
||||
let _span = info_span!("cli.manage.lock_user", user.username = username).entered();
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
let mut conn = database_connection_from_config(&config).await?;
|
||||
let txn = conn.begin().await?;
|
||||
let mut repo = PgRepository::from_conn(txn);
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
@ -395,7 +402,7 @@ impl Options {
|
||||
info!(%user.id, "Unlocking user");
|
||||
|
||||
repo.user().unlock(user).await?;
|
||||
repo.save().await?;
|
||||
repo.into_inner().commit().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ use tokio::signal::unix::SignalKind;
|
||||
use tracing::{info, info_span, warn, Instrument};
|
||||
|
||||
use crate::util::{
|
||||
database_from_config, mailer_from_config, password_manager_from_config,
|
||||
database_pool_from_config, mailer_from_config, password_manager_from_config,
|
||||
policy_factory_from_config, register_sighup, templates_from_config,
|
||||
};
|
||||
|
||||
@ -59,7 +59,7 @@ impl Options {
|
||||
|
||||
// Connect to the database
|
||||
info!("Connecting to the database");
|
||||
let pool = database_from_config(&config.database).await?;
|
||||
let pool = database_pool_from_config(&config.database).await?;
|
||||
|
||||
if self.migrate {
|
||||
info!("Running pending migrations");
|
||||
|
@ -23,7 +23,7 @@ use rand::{
|
||||
};
|
||||
use tracing::{info, info_span};
|
||||
|
||||
use crate::util::{database_from_config, mailer_from_config, templates_from_config};
|
||||
use crate::util::{database_pool_from_config, mailer_from_config, templates_from_config};
|
||||
|
||||
#[derive(Parser, Debug, Default)]
|
||||
pub(super) struct Options {}
|
||||
@ -35,7 +35,7 @@ impl Options {
|
||||
|
||||
// Connect to the database
|
||||
info!("Connecting to the database");
|
||||
let pool = database_from_config(&config.database).await?;
|
||||
let pool = database_pool_from_config(&config.database).await?;
|
||||
|
||||
let url_builder = UrlBuilder::new(
|
||||
config.http.public_base.clone(),
|
||||
|
@ -26,7 +26,7 @@ use mas_router::UrlBuilder;
|
||||
use mas_templates::{TemplateLoadingError, Templates};
|
||||
use sqlx::{
|
||||
postgres::{PgConnectOptions, PgPoolOptions},
|
||||
ConnectOptions, PgPool,
|
||||
ConnectOptions, PgConnection, PgPool,
|
||||
};
|
||||
use tracing::{error, info, log::LevelFilter};
|
||||
|
||||
@ -120,8 +120,9 @@ pub async fn templates_from_config(
|
||||
.await
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
|
||||
pub async fn database_from_config(config: &DatabaseConfig) -> Result<PgPool, anyhow::Error> {
|
||||
fn database_connect_options_from_config(
|
||||
config: &DatabaseConfig,
|
||||
) -> Result<PgConnectOptions, anyhow::Error> {
|
||||
let options = match &config.options {
|
||||
DatabaseConnectConfig::Uri { uri } => uri
|
||||
.parse()
|
||||
@ -169,6 +170,13 @@ pub async fn database_from_config(config: &DatabaseConfig) -> Result<PgPool, any
|
||||
.log_statements(LevelFilter::Debug)
|
||||
.log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
|
||||
|
||||
Ok(options)
|
||||
}
|
||||
|
||||
/// Create a database connection pool from the configuration
|
||||
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
|
||||
pub async fn database_pool_from_config(config: &DatabaseConfig) -> Result<PgPool, anyhow::Error> {
|
||||
let options = database_connect_options_from_config(config)?;
|
||||
PgPoolOptions::new()
|
||||
.max_connections(config.max_connections.into())
|
||||
.min_connections(config.min_connections)
|
||||
@ -180,6 +188,17 @@ pub async fn database_from_config(config: &DatabaseConfig) -> Result<PgPool, any
|
||||
.context("could not connect to the database")
|
||||
}
|
||||
|
||||
/// Create a single database connection from the configuration
|
||||
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
|
||||
pub async fn database_connection_from_config(
|
||||
config: &DatabaseConfig,
|
||||
) -> Result<PgConnection, anyhow::Error> {
|
||||
database_connect_options_from_config(config)?
|
||||
.connect()
|
||||
.await
|
||||
.context("could not connect to the database")
|
||||
}
|
||||
|
||||
/// Reload templates on SIGHUP
|
||||
pub fn register_sighup(templates: &Templates) -> anyhow::Result<()> {
|
||||
#[cfg(unix)]
|
||||
|
Reference in New Issue
Block a user