diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index e861b9f0..50aa7250 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -27,7 +27,7 @@ use mas_matrix::HomeserverConnection; use mas_matrix_synapse::SynapseConnection; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob}, + job::{DeactivateUserJob, JobRepositoryExt, ProvisionUserJob, SyncDevicesJob}, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, RepositoryAccess, SystemClock, }; @@ -368,10 +368,6 @@ impl Options { if dry_run { continue; } - - let job = DeleteDeviceJob::new(&user, &compat_session.device); - repo.job().schedule_job(job).await?; - repo.compat_session().finish(&clock, compat_session).await?; } let oauth2_sessions_ids: Vec = sqlx::query_scalar( @@ -398,16 +394,6 @@ impl Options { if dry_run { continue; } - - for scope in &*oauth2_session.scope { - if let Some(device) = Device::from_scope_token(scope) { - // Schedule a job to delete the device. - repo.job() - .schedule_job(DeleteDeviceJob::new(&user, &device)) - .await?; - } - } - repo.oauth2_session().finish(&clock, oauth2_session).await?; } @@ -439,6 +425,10 @@ impl Options { .await?; } + // Schedule a job to sync the devices of the user with the homeserver + warn!("Scheduling job to sync devices for the user"); + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; + let txn = repo.into_inner(); if dry_run { info!("Dry run, not saving"); diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 55fb2066..76000aef 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -20,7 +20,7 @@ use mas_axum_utils::sentry::SentryEventID; use mas_data_model::TokenType; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - job::{DeleteDeviceJob, JobRepositoryExt}, + job::{JobRepositoryExt, SyncDevicesJob}, BoxClock, BoxRepository, Clock, RepositoryAccess, }; use thiserror::Error; @@ -111,9 +111,8 @@ pub(crate) async fn post( // XXX: this is probably not the right error .ok_or(RouteError::InvalidAuthorization)?; - repo.job() - .schedule_job(DeleteDeviceJob::new(&user, &session.device)) - .await?; + // Schedule a job to sync the devices of the user with the homeserver + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; repo.compat_session().finish(&clock, session).await?; diff --git a/crates/handlers/src/graphql/mutations/compat_session.rs b/crates/handlers/src/graphql/mutations/compat_session.rs index 49c57735..9c8034b9 100644 --- a/crates/handlers/src/graphql/mutations/compat_session.rs +++ b/crates/handlers/src/graphql/mutations/compat_session.rs @@ -16,7 +16,7 @@ use anyhow::Context as _; use async_graphql::{Context, Enum, InputObject, Object, ID}; use mas_storage::{ compat::CompatSessionRepository, - job::{DeleteDeviceJob, JobRepositoryExt}, + job::{JobRepositoryExt, SyncDevicesJob}, RepositoryAccess, }; @@ -101,10 +101,8 @@ impl CompatSessionMutations { .await? .context("Could not load user")?; - // Schedule a job to delete the device. - repo.job() - .schedule_job(DeleteDeviceJob::new(&user, &session.device)) - .await?; + // Schedule a job to sync the devices of the user with the homeserver + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; let session = repo.compat_session().finish(&clock, session).await?; diff --git a/crates/handlers/src/graphql/mutations/oauth2_session.rs b/crates/handlers/src/graphql/mutations/oauth2_session.rs index 0acc137b..69a2dcb6 100644 --- a/crates/handlers/src/graphql/mutations/oauth2_session.rs +++ b/crates/handlers/src/graphql/mutations/oauth2_session.rs @@ -17,7 +17,7 @@ use async_graphql::{Context, Description, Enum, InputObject, Object, ID}; use chrono::Duration; use mas_data_model::{Device, TokenType}; use mas_storage::{ - job::{DeleteDeviceJob, JobRepositoryExt, ProvisionDeviceJob}, + job::{JobRepositoryExt, ProvisionDeviceJob, SyncDevicesJob}, oauth2::{ OAuth2AccessTokenRepository, OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, @@ -236,20 +236,8 @@ impl OAuth2SessionMutations { .await? .context("Could not load user")?; - // Scan the scopes of the session to find if there is any device that should be - // deleted from the Matrix server. - // TODO: this should be moved in a higher level "end oauth session" method. - // XXX: this might not be the right semantic, but it's the best we - // can do for now, since we're not explicitly storing devices for OAuth2 - // sessions. - for scope in &*session.scope { - if let Some(device) = Device::from_scope_token(scope) { - // Schedule a job to delete the device. - repo.job() - .schedule_job(DeleteDeviceJob::new(&user, &device)) - .await?; - } - } + // Schedule a job to sync the devices of the user with the homeserver + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; } let session = repo.oauth2_session().finish(&clock, session).await?; diff --git a/crates/handlers/src/oauth2/revoke.rs b/crates/handlers/src/oauth2/revoke.rs index ff725df3..21d33023 100644 --- a/crates/handlers/src/oauth2/revoke.rs +++ b/crates/handlers/src/oauth2/revoke.rs @@ -19,11 +19,11 @@ use mas_axum_utils::{ http_client_factory::HttpClientFactory, sentry::SentryEventID, }; -use mas_data_model::{Device, TokenType}; +use mas_data_model::TokenType; use mas_iana::oauth::OAuthTokenTypeHint; use mas_keystore::Encrypter; use mas_storage::{ - job::{DeleteDeviceJob, JobRepositoryExt}, + job::{JobRepositoryExt, SyncDevicesJob}, BoxClock, BoxRepository, RepositoryAccess, }; use oauth2_types::{ @@ -217,20 +217,8 @@ pub(crate) async fn post( .await? .ok_or(RouteError::UnknownToken)?; - // Scan the scopes of the session to find if there is any device that should be - // deleted from the Matrix server. - // TODO: this should be moved in a higher level "end oauth session" method. - // XXX: this might not be the right semantic, but it's the best we - // can do for now, since we're not explicitly storing devices for OAuth2 - // sessions. - for scope in &*session.scope { - if let Some(device) = Device::from_scope_token(scope) { - // Schedule a job to delete the device. - repo.job() - .schedule_job(DeleteDeviceJob::new(&user, &device)) - .await?; - } - } + // Schedule a job to sync the devices of the user with the homeserver + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; } // Now that we checked everything, we can end the session. diff --git a/crates/matrix-synapse/src/lib.rs b/crates/matrix-synapse/src/lib.rs index 46be1b2e..77a3c711 100644 --- a/crates/matrix-synapse/src/lib.rs +++ b/crates/matrix-synapse/src/lib.rs @@ -14,6 +14,8 @@ #![allow(clippy::blocks_in_conditions)] +use std::collections::HashSet; + use anyhow::{bail, Context}; use http::{header::AUTHORIZATION, request::Builder, Method, Request, StatusCode}; use mas_axum_utils::http_client_factory::HttpClientFactory; @@ -131,9 +133,19 @@ struct SynapseUser { external_ids: Option>, } +#[derive(Deserialize)] +struct SynapseDeviceListResponse { + devices: Vec, +} + +#[derive(Serialize, Deserialize)] +struct SynapseDevice { + device_id: String, +} + #[derive(Serialize)] -struct SynapseDevice<'a> { - device_id: &'a str, +struct SynapseDeleteDevicesRequest { + devices: Vec, } #[derive(Serialize)] @@ -356,7 +368,9 @@ impl HomeserverConnection for SynapseConnection { let request = self .post(&format!("_synapse/admin/v2/users/{mxid}/devices")) - .body(SynapseDevice { device_id })?; + .body(SynapseDevice { + device_id: device_id.to_owned(), + })?; let response = client .ready() @@ -411,6 +425,82 @@ impl HomeserverConnection for SynapseConnection { Ok(()) } + #[tracing::instrument( + name = "homeserver.sync_devices", + skip_all, + fields( + matrix.homeserver = self.homeserver, + matrix.mxid = mxid, + ), + err(Debug), + )] + async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + // Get the list of current devices + let mxid_url = urlencoding::encode(mxid); + let mut client = self + .http_client_factory + .client("homeserver.sync_devices.query") + .response_body_to_bytes() + .catch_http_errors(catch_homeserver_error) + .json_response(); + + let request = self + .get(&format!("_synapse/admin/v2/users/{mxid_url}/devices")) + .body(EmptyBody::new())?; + + let response = client + .ready() + .await? + .call(request) + .await + .context("Failed to query user from Synapse")?; + + if response.status() != StatusCode::OK { + return Err(anyhow::anyhow!("Failed to query user devices from Synapse")); + } + + let body: SynapseDeviceListResponse = response.into_body(); + + let existing_devices: HashSet = + body.devices.into_iter().map(|d| d.device_id).collect(); + + // First, delete all the devices that are not needed anymore + let to_delete = existing_devices.difference(&devices).cloned().collect(); + + let mut client = self + .http_client_factory + .client("homeserver.sync_devices.delete") + .response_body_to_bytes() + .catch_http_errors(catch_homeserver_error) + .request_bytes_to_body() + .json_request(); + + let request = self + .post(&format!( + "_synapse/admin/v2/users/{mxid_url}/delete_devices" + )) + .body(SynapseDeleteDevicesRequest { devices: to_delete })?; + + let response = client + .ready() + .await? + .call(request) + .await + .context("Failed to query user from Synapse")?; + + if response.status() != StatusCode::OK { + return Err(anyhow::anyhow!("Failed to delete devices from Synapse")); + } + + // Then, create the devices that are missing. There is no batching API to do + // this, so we do this sequentially, which is fine as the API is idempotent. + for device_id in devices.difference(&existing_devices) { + self.create_device(mxid, device_id).await?; + } + + Ok(()) + } + #[tracing::instrument( name = "homeserver.delete_user", skip_all, diff --git a/crates/matrix/src/lib.rs b/crates/matrix/src/lib.rs index 26b921fa..6b416092 100644 --- a/crates/matrix/src/lib.rs +++ b/crates/matrix/src/lib.rs @@ -14,7 +14,7 @@ mod mock; -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; pub use self::mock::HomeserverConnection as MockHomeserverConnection; @@ -262,6 +262,19 @@ pub trait HomeserverConnection: Send + Sync { /// not be deleted. async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>; + /// Sync the list of devices of a user with the homeserver. + /// + /// # Parameters + /// + /// * `mxid` - The Matrix ID of the user to sync the devices for. + /// * `devices` - The list of devices to sync. + /// + /// # Errors + /// + /// Returns an error if the homeserver is unreachable or the devices could + /// not be synced. + async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error>; + /// Delete a user on the homeserver. /// /// # Parameters @@ -341,6 +354,10 @@ impl HomeserverConnection for &T (**self).delete_device(mxid, device_id).await } + async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + (**self).sync_devices(mxid, devices).await + } + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { (**self).delete_user(mxid, erase).await } @@ -387,6 +404,10 @@ impl HomeserverConnection for Arc { (**self).delete_device(mxid, device_id).await } + async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + (**self).sync_devices(mxid, devices).await + } + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { (**self).delete_user(mxid, erase).await } diff --git a/crates/matrix/src/mock.rs b/crates/matrix/src/mock.rs index 7a67c550..d7f0421e 100644 --- a/crates/matrix/src/mock.rs +++ b/crates/matrix/src/mock.rs @@ -128,6 +128,13 @@ impl crate::HomeserverConnection for HomeserverConnection { Ok(()) } + async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + let mut users = self.users.write().await; + let user = users.get_mut(mxid).context("User not found")?; + user.devices = devices; + Ok(()) + } + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { let mut users = self.users.write().await; let user = users.get_mut(mxid).context("User not found")?; diff --git a/crates/storage/src/job.rs b/crates/storage/src/job.rs index 599cdca0..b8895004 100644 --- a/crates/storage/src/job.rs +++ b/crates/storage/src/job.rs @@ -394,6 +394,31 @@ mod jobs { const NAME: &'static str = "delete-device"; } + /// A job which syncs the list of devices of a user with the homeserver + #[derive(Serialize, Deserialize, Debug, Clone)] + pub struct SyncDevicesJob { + user_id: Ulid, + } + + impl SyncDevicesJob { + /// Create a new job to sync the list of devices of a user with the + /// homeserver + #[must_use] + pub fn new(user: &User) -> Self { + Self { user_id: user.id } + } + + /// The ID of the user to sync the devices for + #[must_use] + pub fn user_id(&self) -> Ulid { + self.user_id + } + } + + impl Job for SyncDevicesJob { + const NAME: &'static str = "sync-devices"; + } + /// A job to deactivate and lock a user #[derive(Serialize, Deserialize, Debug, Clone)] pub struct DeactivateUserJob { @@ -468,5 +493,5 @@ mod jobs { pub use self::jobs::{ DeactivateUserJob, DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, - SendAccountRecoveryEmailsJob, VerifyEmailJob, + SendAccountRecoveryEmailsJob, SyncDevicesJob, VerifyEmailJob, }; diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs index a3c99309..d085236e 100644 --- a/crates/tasks/src/matrix.rs +++ b/crates/tasks/src/matrix.rs @@ -12,13 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashSet; + use anyhow::Context; use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; +use mas_data_model::Device; use mas_matrix::ProvisionRequest; use mas_storage::{ - job::{DeleteDeviceJob, JobWithSpanContext, ProvisionDeviceJob, ProvisionUserJob}, + compat::CompatSessionFilter, + job::{ + DeleteDeviceJob, JobRepositoryExt as _, JobWithSpanContext, ProvisionDeviceJob, + ProvisionUserJob, SyncDevicesJob, + }, + oauth2::OAuth2SessionFilter, user::{UserEmailRepository, UserRepository}, - RepositoryAccess, + Pagination, RepositoryAccess, }; use tracing::info; @@ -56,9 +64,6 @@ async fn provision_user( .filter(|email| email.confirmed_at.is_some()) .map(|email| email.email) .collect(); - - repo.cancel().await?; - let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails); if let Some(display_name) = job.display_name_to_set() { @@ -73,6 +78,12 @@ async fn provision_user( info!(%user.id, %mxid, "User updated"); } + // Schedule a device sync job + let sync_device_job = SyncDevicesJob::new(&user); + repo.job().schedule_job(sync_device_job).await?; + + repo.save().await?; + Ok(()) } @@ -144,6 +155,84 @@ async fn delete_device( Ok(()) } +/// Job to sync the list of devices of a user with the homeserver. +#[tracing::instrument( + name = "job.sync_devices", + fields(user.id = %job.user_id()), + skip_all, + err(Debug), +)] +async fn sync_devices( + job: JobWithSpanContext, + ctx: JobContext, +) -> Result<(), anyhow::Error> { + let state = ctx.state(); + let matrix = state.matrix_connection(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(job.user_id()) + .await? + .context("User not found")?; + + let mut devices = HashSet::new(); + + // Cycle through all the compat sessions of the user, and grab the devices + let mut cursor = Pagination::first(100); + loop { + let page = repo + .compat_session() + .list( + CompatSessionFilter::new().for_user(&user).active_only(), + cursor, + ) + .await?; + + for (compat_session, _) in page.edges { + devices.insert(compat_session.device.as_str().to_owned()); + cursor = cursor.after(compat_session.id); + } + + if !page.has_next_page { + break; + } + } + + // Cycle though all the oauth2 sessions of the user, and grab the devices + let mut cursor = Pagination::first(100); + loop { + let page = repo + .oauth2_session() + .list( + OAuth2SessionFilter::new().for_user(&user).active_only(), + cursor, + ) + .await?; + + for oauth2_session in page.edges { + for scope in &*oauth2_session.scope { + if let Some(device) = Device::from_scope_token(scope) { + devices.insert(device.as_str().to_owned()); + } + } + + cursor = cursor.after(oauth2_session.id); + } + + if !page.has_next_page { + break; + } + } + + // We now have a complete list of devices, so we can sync them with the + // homeserver + let mxid = matrix.mxid(&user.username); + matrix.sync_devices(&mxid, devices).await?; + + Ok(()) +} + pub(crate) fn register( suffix: &str, monitor: Monitor, @@ -156,9 +245,12 @@ pub(crate) fn register( crate::build!(ProvisionDeviceJob => provision_device, suffix, state, storage_factory); let delete_device_worker = crate::build!(DeleteDeviceJob => delete_device, suffix, state, storage_factory); + let sync_devices_worker = + crate::build!(SyncDevicesJob => sync_devices, suffix, state, storage_factory); monitor .register(provision_user_worker) .register(provision_device_worker) .register(delete_device_worker) + .register(sync_devices_worker) }