From 9289922dfb73e7fe185c3378ed321388ddfcf655 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 25 Aug 2023 15:39:31 +0200 Subject: [PATCH] Grab a database lock when syncing the config Fixes #1475 --- crates/cli/src/commands/config.rs | 26 ++++++-- crates/cli/src/commands/database.rs | 6 +- crates/cli/src/commands/manage.rs | 79 +++++++++++++----------- crates/cli/src/commands/server.rs | 4 +- crates/cli/src/commands/worker.rs | 4 +- crates/cli/src/util.rs | 25 +++++++- crates/storage-pg/src/repository.rs | 94 +++++++++++++++++++++-------- 7 files changed, 162 insertions(+), 76 deletions(-) diff --git a/crates/cli/src/commands/config.rs b/crates/cli/src/commands/config.rs index 3dfab39d..dea8b6de 100644 --- a/crates/cli/src/commands/config.rs +++ b/crates/cli/src/commands/config.rs @@ -17,13 +17,14 @@ use std::collections::HashSet; use clap::Parser; use mas_config::{ConfigurationSection, RootConfig, SyncConfig}; use mas_storage::{ - upstream_oauth2::UpstreamOAuthProviderRepository, Repository, RepositoryAccess, SystemClock, + upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess, SystemClock, }; use mas_storage_pg::PgRepository; use rand::SeedableRng; +use sqlx::{postgres::PgAdvisoryLock, Acquire}; use tracing::{info, info_span, warn}; -use crate::util::database_from_config; +use crate::util::database_connection_from_config; fn map_import_preference( config: &mas_config::UpstreamOAuth2ImportPreference, @@ -144,8 +145,18 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu let config: SyncConfig = root.load_config()?; let encrypter = config.secrets.encrypter(); - let pool = database_from_config(&config.database).await?; - let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + // Grab a connection to the database + let mut conn = database_connection_from_config(&config.database).await?; + // Start a transaction + let txn = conn.begin().await?; + + // Grab a lock within the transaction + tracing::info!("Acquiring config lock"); + let lock = PgAdvisoryLock::new("MAS config sync"); + let lock = lock.acquire(txn).await?; + + // Create a repository from the connection with the lock + let mut repo = PgRepository::from_conn(lock); tracing::info!( prune, @@ -284,11 +295,14 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu } } + // Get the lock and release it to commit the transaction + let lock = repo.into_inner(); + let txn = lock.release_now().await?; if dry_run { info!("Dry run, rolling back changes"); - repo.cancel().await?; + txn.rollback().await?; } else { - repo.save().await?; + txn.commit().await?; } Ok(()) } diff --git a/crates/cli/src/commands/database.rs b/crates/cli/src/commands/database.rs index 0277a09d..9c84535f 100644 --- a/crates/cli/src/commands/database.rs +++ b/crates/cli/src/commands/database.rs @@ -18,7 +18,7 @@ use mas_config::DatabaseConfig; use mas_storage_pg::MIGRATOR; use tracing::{info_span, Instrument}; -use crate::util::database_from_config; +use crate::util::database_connection_from_config; #[derive(Parser, Debug)] pub(super) struct Options { @@ -36,11 +36,11 @@ impl Options { pub async fn run(self, root: &super::Options) -> anyhow::Result<()> { let _span = info_span!("cli.database.migrate").entered(); let config: DatabaseConfig = root.load_config()?; - let pool = database_from_config(&config).await?; + let mut conn = database_connection_from_config(&config).await?; // Run pending migrations MIGRATOR - .run(&pool) + .run(&mut conn) .instrument(info_span!("db.migrate")) .await .context("could not run migrations")?; diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 4d41ee82..4da5577e 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -20,14 +20,14 @@ use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob}, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, - Repository, RepositoryAccess, SystemClock, + RepositoryAccess, SystemClock, }; use mas_storage_pg::PgRepository; use rand::SeedableRng; -use sqlx::types::Uuid; +use sqlx::{types::Uuid, Acquire}; use tracing::{info, info_span, warn}; -use crate::util::{database_from_config, password_manager_from_config}; +use crate::util::{database_connection_from_config, password_manager_from_config}; #[derive(Parser, Debug)] pub(super) struct Options { @@ -103,10 +103,11 @@ impl Options { let database_config: DatabaseConfig = root.load_config()?; let passwords_config: PasswordsConfig = root.load_config()?; - let pool = database_from_config(&database_config).await?; + let mut conn = database_connection_from_config(&database_config).await?; let password_manager = password_manager_from_config(&passwords_config).await?; - let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + let txn = conn.begin().await?; + let mut repo = PgRepository::from_conn(txn); let user = repo .user() .find_by_username(&username) @@ -122,7 +123,7 @@ impl Options { .await?; info!(%user.id, %user.username, "Password changed"); - repo.save().await?; + repo.into_inner().commit().await?; Ok(()) } @@ -135,9 +136,10 @@ impl Options { ) .entered(); - let config: DatabaseConfig = root.load_config()?; - let pool = database_from_config(&config).await?; - let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + let database_config: DatabaseConfig = root.load_config()?; + let mut conn = database_connection_from_config(&database_config).await?; + let txn = conn.begin().await?; + let mut repo = PgRepository::from_conn(txn); let user = repo .user() @@ -152,7 +154,7 @@ impl Options { .context("Email not found")?; let email = repo.user_email().mark_as_verified(&clock, email).await?; - repo.save().await?; + repo.into_inner().commit().await?; info!(?email, "Email marked as verified"); Ok(()) @@ -163,9 +165,10 @@ impl Options { admin, device_id, } => { - let config: DatabaseConfig = root.load_config()?; - let pool = database_from_config(&config).await?; - let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + let database_config: DatabaseConfig = root.load_config()?; + let mut conn = database_connection_from_config(&database_config).await?; + let txn = conn.begin().await?; + let mut repo = PgRepository::from_conn(txn); let user = repo .user() @@ -191,7 +194,7 @@ impl Options { .add(&mut rng, &clock, &compat_session, token, None) .await?; - repo.save().await?; + repo.into_inner().commit().await?; info!( %compat_access_token.id, @@ -207,16 +210,16 @@ impl Options { SC::ProvisionAllUsers => { let _span = info_span!("cli.manage.provision_all_users").entered(); - let config: DatabaseConfig = root.load_config()?; - let pool = database_from_config(&config).await?; - let mut conn = pool.acquire().await?; - let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + let database_config: DatabaseConfig = root.load_config()?; + let mut conn = database_connection_from_config(&database_config).await?; + let mut txn = conn.begin().await?; // TODO: do some pagination here let ids: Vec = sqlx::query_scalar("SELECT user_id FROM users") - .fetch_all(&mut *conn) + .fetch_all(&mut *txn) .await?; - drop(conn); + + let mut repo = PgRepository::from_conn(txn); for id in ids { let id = id.into(); @@ -225,7 +228,7 @@ impl Options { repo.job().schedule_job(job).await?; } - repo.save().await?; + repo.into_inner().commit().await?; Ok(()) } @@ -233,10 +236,10 @@ impl Options { SC::KillSessions { username, dry_run } => { let _span = info_span!("cli.manage.kill_sessions", user.username = username).entered(); - let config: DatabaseConfig = root.load_config()?; - let pool = database_from_config(&config).await?; - let mut conn = pool.acquire().await?; - let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + let database_config: DatabaseConfig = root.load_config()?; + let mut conn = database_connection_from_config(&database_config).await?; + let txn = conn.begin().await?; + let mut repo = PgRepository::from_conn(txn); let user = repo .user() @@ -251,7 +254,7 @@ impl Options { "#, ) .bind(Uuid::from(user.id)) - .fetch_all(&mut *conn) + .fetch_all(&mut **repo) .await?; for id in compat_sessions_ids { @@ -281,7 +284,7 @@ impl Options { "#, ) .bind(Uuid::from(user.id)) - .fetch_all(&mut *conn) + .fetch_all(&mut **repo) .await?; for id in oauth2_sessions_ids { @@ -316,7 +319,7 @@ impl Options { "#, ) .bind(Uuid::from(user.id)) - .fetch_all(&mut *conn) + .fetch_all(&mut **repo) .await?; for id in user_sessions_ids { @@ -337,11 +340,12 @@ impl Options { .await?; } + let txn = repo.into_inner(); if dry_run { info!("Dry run, not saving"); - repo.cancel().await?; + txn.rollback().await?; } else { - repo.save().await?; + txn.commit().await?; } Ok(()) @@ -353,8 +357,9 @@ impl Options { } => { let _span = info_span!("cli.manage.lock_user", user.username = username).entered(); let config: DatabaseConfig = root.load_config()?; - let pool = database_from_config(&config).await?; - let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + let mut conn = database_connection_from_config(&config).await?; + let txn = conn.begin().await?; + let mut repo = PgRepository::from_conn(txn); let user = repo .user() @@ -375,7 +380,8 @@ impl Options { .schedule_job(DeactivateUserJob::new(&user, false)) .await?; } - repo.save().await?; + + repo.into_inner().commit().await?; Ok(()) } @@ -383,8 +389,9 @@ impl Options { SC::UnlockUser { username } => { let _span = info_span!("cli.manage.lock_user", user.username = username).entered(); let config: DatabaseConfig = root.load_config()?; - let pool = database_from_config(&config).await?; - let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + let mut conn = database_connection_from_config(&config).await?; + let txn = conn.begin().await?; + let mut repo = PgRepository::from_conn(txn); let user = repo .user() @@ -395,7 +402,7 @@ impl Options { info!(%user.id, "Unlocking user"); repo.user().unlock(user).await?; - repo.save().await?; + repo.into_inner().commit().await?; Ok(()) } diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 45416ba2..fa9b7179 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -31,7 +31,7 @@ use tokio::signal::unix::SignalKind; use tracing::{info, info_span, warn, Instrument}; use crate::util::{ - database_from_config, mailer_from_config, password_manager_from_config, + database_pool_from_config, mailer_from_config, password_manager_from_config, policy_factory_from_config, register_sighup, templates_from_config, }; @@ -59,7 +59,7 @@ impl Options { // Connect to the database info!("Connecting to the database"); - let pool = database_from_config(&config.database).await?; + let pool = database_pool_from_config(&config.database).await?; if self.migrate { info!("Running pending migrations"); diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index dbdd4af4..e55dc27f 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -23,7 +23,7 @@ use rand::{ }; use tracing::{info, info_span}; -use crate::util::{database_from_config, mailer_from_config, templates_from_config}; +use crate::util::{database_pool_from_config, mailer_from_config, templates_from_config}; #[derive(Parser, Debug, Default)] pub(super) struct Options {} @@ -35,7 +35,7 @@ impl Options { // Connect to the database info!("Connecting to the database"); - let pool = database_from_config(&config.database).await?; + let pool = database_pool_from_config(&config.database).await?; let url_builder = UrlBuilder::new( config.http.public_base.clone(), diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index 20a1a84b..5eb934ba 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -26,7 +26,7 @@ use mas_router::UrlBuilder; use mas_templates::{TemplateLoadingError, Templates}; use sqlx::{ postgres::{PgConnectOptions, PgPoolOptions}, - ConnectOptions, PgPool, + ConnectOptions, PgConnection, PgPool, }; use tracing::{error, info, log::LevelFilter}; @@ -120,8 +120,9 @@ pub async fn templates_from_config( .await } -#[tracing::instrument(name = "db.connect", skip_all, err(Debug))] -pub async fn database_from_config(config: &DatabaseConfig) -> Result { +fn database_connect_options_from_config( + config: &DatabaseConfig, +) -> Result { let options = match &config.options { DatabaseConnectConfig::Uri { uri } => uri .parse() @@ -169,6 +170,13 @@ pub async fn database_from_config(config: &DatabaseConfig) -> Result Result { + let options = database_connect_options_from_config(config)?; PgPoolOptions::new() .max_connections(config.max_connections.into()) .min_connections(config.min_connections) @@ -180,6 +188,17 @@ pub async fn database_from_config(config: &DatabaseConfig) -> Result Result { + database_connect_options_from_config(config)? + .connect() + .await + .context("could not connect to the database") +} + /// Reload templates on SIGHUP pub fn register_sighup(templates: &Templates) -> anyhow::Result<()> { #[cfg(unix)] diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index 3e749d97..c53723e7 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::ops::{Deref, DerefMut}; + use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; use mas_storage::{ compat::{ @@ -30,7 +32,7 @@ use mas_storage::{ user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, Repository, RepositoryAccess, RepositoryTransaction, }; -use sqlx::{PgPool, Postgres, Transaction}; +use sqlx::{PgConnection, PgPool, Postgres, Transaction}; use tracing::Instrument; use crate::{ @@ -56,8 +58,8 @@ use crate::{ /// An implementation of the [`Repository`] trait backed by a PostgreSQL /// transaction. -pub struct PgRepository { - txn: Transaction<'static, Postgres>, +pub struct PgRepository> { + conn: C, } impl PgRepository { @@ -69,7 +71,46 @@ impl PgRepository { /// Returns a [`DatabaseError`] if the transaction could not be started. pub async fn from_pool(pool: &PgPool) -> Result { let txn = pool.begin().await?; - Ok(PgRepository { txn }) + Ok(Self::from_conn(txn)) + } +} + +impl PgRepository { + /// Create a new [`PgRepository`] from an existing PostgreSQL connection + /// with a transaction + pub fn from_conn(conn: C) -> Self { + PgRepository { conn } + } + + /// Consume this [`PgRepository`], returning the underlying connection. + pub fn into_inner(self) -> C { + self.conn + } +} + +impl AsRef for PgRepository { + fn as_ref(&self) -> &C { + &self.conn + } +} + +impl AsMut for PgRepository { + fn as_mut(&mut self) -> &mut C { + &mut self.conn + } +} + +impl Deref for PgRepository { + type Target = C; + + fn deref(&self) -> &Self::Target { + &self.conn + } +} + +impl DerefMut for PgRepository { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.conn } } @@ -80,7 +121,7 @@ impl RepositoryTransaction for PgRepository { fn save(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { let span = tracing::info_span!("db.save"); - self.txn + self.conn .commit() .map_err(DatabaseError::from) .instrument(span) @@ -89,7 +130,7 @@ impl RepositoryTransaction for PgRepository { fn cancel(self: Box) -> BoxFuture<'static, Result<(), Self::Error>> { let span = tracing::info_span!("db.cancel"); - self.txn + self.conn .rollback() .map_err(DatabaseError::from) .instrument(span) @@ -97,102 +138,107 @@ impl RepositoryTransaction for PgRepository { } } -impl RepositoryAccess for PgRepository { +impl RepositoryAccess for PgRepository +where + C: AsMut + Send, +{ type Error = DatabaseError; fn upstream_oauth_link<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgUpstreamOAuthLinkRepository::new(&mut self.txn)) + Box::new(PgUpstreamOAuthLinkRepository::new(self.conn.as_mut())) } fn upstream_oauth_provider<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgUpstreamOAuthProviderRepository::new(&mut self.txn)) + Box::new(PgUpstreamOAuthProviderRepository::new(self.conn.as_mut())) } fn upstream_oauth_session<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgUpstreamOAuthSessionRepository::new(&mut self.txn)) + Box::new(PgUpstreamOAuthSessionRepository::new(self.conn.as_mut())) } fn user<'c>(&'c mut self) -> Box + 'c> { - Box::new(PgUserRepository::new(&mut self.txn)) + Box::new(PgUserRepository::new(self.conn.as_mut())) } fn user_email<'c>(&'c mut self) -> Box + 'c> { - Box::new(PgUserEmailRepository::new(&mut self.txn)) + Box::new(PgUserEmailRepository::new(self.conn.as_mut())) } fn user_password<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgUserPasswordRepository::new(&mut self.txn)) + Box::new(PgUserPasswordRepository::new(self.conn.as_mut())) } fn browser_session<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgBrowserSessionRepository::new(&mut self.txn)) + Box::new(PgBrowserSessionRepository::new(self.conn.as_mut())) } fn oauth2_client<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgOAuth2ClientRepository::new(&mut self.txn)) + Box::new(PgOAuth2ClientRepository::new(self.conn.as_mut())) } fn oauth2_authorization_grant<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgOAuth2AuthorizationGrantRepository::new(&mut self.txn)) + Box::new(PgOAuth2AuthorizationGrantRepository::new( + self.conn.as_mut(), + )) } fn oauth2_session<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgOAuth2SessionRepository::new(&mut self.txn)) + Box::new(PgOAuth2SessionRepository::new(self.conn.as_mut())) } fn oauth2_access_token<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgOAuth2AccessTokenRepository::new(&mut self.txn)) + Box::new(PgOAuth2AccessTokenRepository::new(self.conn.as_mut())) } fn oauth2_refresh_token<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgOAuth2RefreshTokenRepository::new(&mut self.txn)) + Box::new(PgOAuth2RefreshTokenRepository::new(self.conn.as_mut())) } fn compat_session<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgCompatSessionRepository::new(&mut self.txn)) + Box::new(PgCompatSessionRepository::new(self.conn.as_mut())) } fn compat_sso_login<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgCompatSsoLoginRepository::new(&mut self.txn)) + Box::new(PgCompatSsoLoginRepository::new(self.conn.as_mut())) } fn compat_access_token<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgCompatAccessTokenRepository::new(&mut self.txn)) + Box::new(PgCompatAccessTokenRepository::new(self.conn.as_mut())) } fn compat_refresh_token<'c>( &'c mut self, ) -> Box + 'c> { - Box::new(PgCompatRefreshTokenRepository::new(&mut self.txn)) + Box::new(PgCompatRefreshTokenRepository::new(self.conn.as_mut())) } fn job<'c>(&'c mut self) -> Box + 'c> { - Box::new(PgJobRepository::new(&mut self.txn)) + Box::new(PgJobRepository::new(self.conn.as_mut())) } }