From 3eab10672f0bad7b384524cb3ad03cab107c4141 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 15 Jul 2024 13:39:58 +0200 Subject: [PATCH] Add a lock during syncs of user devices --- crates/handlers/src/compat/login.rs | 3 ++ .../handlers/src/compat/login_sso_complete.rs | 3 ++ .../src/graphql/mutations/oauth2_session.rs | 3 ++ crates/handlers/src/oauth2/token.rs | 10 ++++++ ...c1bd60bb771c6f075df15ab0137a7ffc896da.json | 22 ++++++++++++ crates/storage-pg/src/user/mod.rs | 34 +++++++++++++++++++ crates/storage/src/user/mod.rs | 14 ++++++++ crates/tasks/src/matrix.rs | 10 ++++-- 8 files changed, 96 insertions(+), 3 deletions(-) create mode 100644 crates/storage-pg/.sqlx/query-e68a7084d44462d19f30902d7e6c1bd60bb771c6f075df15ab0137a7ffc896da.json diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index da8756ae..3e83a260 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -419,6 +419,9 @@ async fn user_password_login( .await?; } + // Lock the user sync to make sure we don't get into a race condition + repo.user().acquire_lock_for_sync(&user).await?; + // Now that the user credentials have been verified, start a new compat session let device = Device::generate(&mut rng); let mxid = homeserver.mxid(&user.username); diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 361a568b..79eaa299 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -202,6 +202,9 @@ pub async fn post( redirect_uri }; + // Lock the user sync to make sure we don't get into a race condition + repo.user().acquire_lock_for_sync(&session.user).await?; + let device = Device::generate(&mut rng); let mxid = homeserver.mxid(&session.user.username); homeserver diff --git a/crates/handlers/src/graphql/mutations/oauth2_session.rs b/crates/handlers/src/graphql/mutations/oauth2_session.rs index c8c17805..f7fc6c3a 100644 --- a/crates/handlers/src/graphql/mutations/oauth2_session.rs +++ b/crates/handlers/src/graphql/mutations/oauth2_session.rs @@ -168,6 +168,9 @@ impl OAuth2SessionMutations { .add(&mut rng, &clock, &client, Some(&user), None, scope) .await?; + // Lock the user sync to make sure we don't get into a race condition + repo.user().acquire_lock_for_sync(&user).await?; + // Look for devices to provision let mxid = homeserver.mxid(&user.username); for scope in &*session.scope { diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 0f27edba..8ef50837 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -461,6 +461,11 @@ async fn authorization_code_grant( params = params.with_id_token(id_token); } + // Lock the user sync to make sure we don't get into a race condition + repo.user() + .acquire_lock_for_sync(&browser_session.user) + .await?; + // Look for device to provision let mxid = homeserver.mxid(&browser_session.user.username); for scope in &*session.scope { @@ -748,6 +753,11 @@ async fn device_code_grant( params = params.with_id_token(id_token); } + // Lock the user sync to make sure we don't get into a race condition + repo.user() + .acquire_lock_for_sync(&browser_session.user) + .await?; + // Look for device to provision let mxid = homeserver.mxid(&browser_session.user.username); for scope in &*session.scope { diff --git a/crates/storage-pg/.sqlx/query-e68a7084d44462d19f30902d7e6c1bd60bb771c6f075df15ab0137a7ffc896da.json b/crates/storage-pg/.sqlx/query-e68a7084d44462d19f30902d7e6c1bd60bb771c6f075df15ab0137a7ffc896da.json new file mode 100644 index 00000000..aa173785 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-e68a7084d44462d19f30902d7e6c1bd60bb771c6f075df15ab0137a7ffc896da.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT pg_advisory_xact_lock($1)\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "pg_advisory_xact_lock", + "type_info": "Void" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + null + ] + }, + "hash": "e68a7084d44462d19f30902d7e6c1bd60bb771c6f075df15ab0137a7ffc896da" +} diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs index a4675f2f..37718282 100644 --- a/crates/storage-pg/src/user/mod.rs +++ b/crates/storage-pg/src/user/mod.rs @@ -437,4 +437,38 @@ impl<'c> UserRepository for PgUserRepository<'c> { .try_into() .map_err(DatabaseError::to_invalid_operation) } + + #[tracing::instrument( + name = "db.user.acquire_lock_for_sync", + skip_all, + fields( + db.statement, + user.id = %user.id, + ), + err, + )] + async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> { + // XXX: this lock isn't stictly scoped to users, but as we don't use many + // postgres advisory locks, it's fine for now. Later on, we could use row-level + // locks to make sure we don't get into trouble + + // Convert the user ID to a u128 and grab the lower 64 bits + // As this includes 64bit of the random part of the ULID, it should be random + // enough to not collide + let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64; + + // Use a PG advisory lock, which will be released when the transaction is + // committed or rolled back + sqlx::query!( + r#" + SELECT pg_advisory_xact_lock($1) + "#, + lock_id, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } } diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 9763afed..6e696c06 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -259,6 +259,19 @@ pub trait UserRepository: Send + Sync { /// /// Returns [`Self::Error`] if the underlying repository fails async fn count(&mut self, filter: UserFilter<'_>) -> Result; + + /// Acquire a lock on the user to make sure device operations are done in a + /// sequential way. The lock is released when the repository is saved or + /// rolled back. + /// + /// # Parameters + /// + /// * `user`: The user to lock + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error>; } repository_impl!(UserRepository: @@ -284,4 +297,5 @@ repository_impl!(UserRepository: pagination: Pagination, ) -> Result, Self::Error>; async fn count(&mut self, filter: UserFilter<'_>) -> Result; + async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error>; ); diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs index c1c07481..7f10ed09 100644 --- a/crates/tasks/src/matrix.rs +++ b/crates/tasks/src/matrix.rs @@ -170,6 +170,9 @@ async fn sync_devices( .await? .context("User not found")?; + // Lock the user sync to make sure we don't get into a race condition + repo.user().acquire_lock_for_sync(&user).await?; + let mut devices = HashSet::new(); // Cycle through all the compat sessions of the user, and grab the devices @@ -219,12 +222,13 @@ async fn sync_devices( } } - // We now have a complete list of devices, we can now release the connection and - // sync with the homeserver - repo.save().await?; let mxid = matrix.mxid(&user.username); matrix.sync_devices(&mxid, devices).await?; + // We kept the connection until now, so that we still hold the lock on the user + // throughout the sync + repo.save().await?; + Ok(()) }