diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index b8d6c6a5..c76ec7bc 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -26,11 +26,12 @@ use mas_handlers::HttpClientFactory; use mas_matrix::HomeserverConnection; use mas_matrix_synapse::SynapseConnection; use mas_storage::{ - compat::{CompatAccessTokenRepository, CompatSessionRepository}, + compat::{CompatAccessTokenRepository, CompatSessionFilter, CompatSessionRepository}, job::{ DeactivateUserJob, JobRepositoryExt, ProvisionUserJob, ReactivateUserJob, SyncDevicesJob, }, - user::{UserEmailRepository, UserPasswordRepository, UserRepository}, + oauth2::OAuth2SessionFilter, + user::{BrowserSessionFilter, UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, RepositoryAccess, SystemClock, }; use mas_storage_pg::{DatabaseError, PgRepository}; @@ -348,83 +349,43 @@ impl Options { .await? .context("User not found")?; - let compat_sessions_ids: Vec = sqlx::query_scalar( - r" - SELECT compat_session_id FROM compat_sessions - WHERE user_id = $1 AND finished_at IS NULL - ", - ) - .bind(Uuid::from(user.id)) - .fetch_all(&mut **repo) - .await?; + let filter = CompatSessionFilter::new().for_user(&user).active_only(); + let affected = if dry_run { + repo.compat_session().count(filter).await? + } else { + repo.compat_session().finish_bulk(&clock, filter).await? + }; - for id in compat_sessions_ids { - let id = id.into(); - let compat_session = repo - .compat_session() - .lookup(id) - .await? - .context("Session not found")?; - info!(%compat_session.id, %compat_session.device, "Killing compat session"); - - if dry_run { - continue; - } + match affected { + 0 => info!("No active compatibility sessions to end"), + 1 => info!("Ended 1 active compatibility session"), + _ => info!("Ended {affected} active compatibility sessions"), } - let oauth2_sessions_ids: Vec = sqlx::query_scalar( - r" - SELECT oauth2_sessions.oauth2_session_id - FROM oauth2_sessions - INNER JOIN user_sessions USING (user_session_id) - WHERE user_sessions.user_id = $1 AND oauth2_sessions.finished_at IS NULL - ", - ) - .bind(Uuid::from(user.id)) - .fetch_all(&mut **repo) - .await?; + let filter = OAuth2SessionFilter::new().for_user(&user).active_only(); + let affected = if dry_run { + repo.oauth2_session().count(filter).await? + } else { + repo.oauth2_session().finish_bulk(&clock, filter).await? + }; - for id in oauth2_sessions_ids { - let id = id.into(); - let oauth2_session = repo - .oauth2_session() - .lookup(id) - .await? - .context("Session not found")?; - info!(%oauth2_session.id, %oauth2_session.scope, "Killing oauth2 session"); + match affected { + 0 => info!("No active compatibility sessions to end"), + 1 => info!("Ended 1 active OAuth 2.0 session"), + _ => info!("Ended {affected} active OAuth 2.0 sessions"), + }; - if dry_run { - continue; - } - repo.oauth2_session().finish(&clock, oauth2_session).await?; - } + let filter = BrowserSessionFilter::new().for_user(&user).active_only(); + let affected = if dry_run { + repo.browser_session().count(filter).await? + } else { + repo.browser_session().finish_bulk(&clock, filter).await? + }; - let user_sessions_ids: Vec = sqlx::query_scalar( - r" - SELECT user_session_id FROM user_sessions - WHERE user_id = $1 AND finished_at IS NULL - ", - ) - .bind(Uuid::from(user.id)) - .fetch_all(&mut **repo) - .await?; - - for id in user_sessions_ids { - let id = id.into(); - let browser_session = repo - .browser_session() - .lookup(id) - .await? - .context("Session not found")?; - info!(%browser_session.id, "Killing browser session"); - - if dry_run { - continue; - } - - repo.browser_session() - .finish(&clock, browser_session) - .await?; + match affected { + 0 => info!("No active browser sessions to end"), + 1 => info!("Ended 1 active browser session"), + _ => info!("Ended {affected} active browser sessions"), } // Schedule a job to sync the devices of the user with the homeserver