diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index f95d7edf..5f37ac93 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -18,7 +18,7 @@ use mas_config::{DatabaseConfig, PasswordsConfig}; use mas_data_model::{Device, TokenType}; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - job::{JobRepositoryExt, ProvisionUserJob}, + job::{DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob}, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Repository, RepositoryAccess, SystemClock, }; @@ -59,6 +59,16 @@ enum Subcommand { /// Trigger a provisioning job for all users ProvisionAllUsers, + + /// Kill all sessions for a user + KillSessions { + /// User for which to kill sessions + username: String, + + /// Do a dry run + #[arg(long)] + dry_run: bool, + }, } impl Options { @@ -203,6 +213,123 @@ impl Options { Ok(()) } + + SC::KillSessions { username, dry_run } => { + let _span = + info_span!("cli.manage.kill_sessions", user.username = username).entered(); + let config: DatabaseConfig = root.load_config()?; + let pool = database_from_config(&config).await?; + let mut conn = pool.acquire().await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + + let user = repo + .user() + .find_by_username(&username) + .await? + .context("User not found")?; + + let compat_sessions_ids: Vec = sqlx::query_scalar( + r#" + SELECT compat_session_id FROM compat_sessions + WHERE user_id = $1 AND finished_at IS NULL + "#, + ) + .bind(Uuid::from(user.id)) + .fetch_all(&mut conn) + .await?; + + for id in compat_sessions_ids { + let id = id.into(); + let compat_session = repo + .compat_session() + .lookup(id) + .await? + .context("Session not found")?; + info!(%compat_session.id, %compat_session.device, "Killing compat session"); + + if dry_run { + continue; + } + + let job = DeleteDeviceJob::new(&user, &compat_session.device); + repo.job().schedule_job(job).await?; + repo.compat_session().finish(&clock, compat_session).await?; + } + + let oauth2_sessions_ids: Vec = sqlx::query_scalar( + r#" + SELECT oauth2_sessions.oauth2_session_id + FROM oauth2_sessions + INNER JOIN user_sessions USING (user_session_id) + WHERE user_sessions.user_id = $1 AND oauth2_sessions.finished_at IS NULL + "#, + ) + .bind(Uuid::from(user.id)) + .fetch_all(&mut conn) + .await?; + + for id in oauth2_sessions_ids { + let id = id.into(); + let oauth2_session = repo + .oauth2_session() + .lookup(id) + .await? + .context("Session not found")?; + info!(%oauth2_session.id, %oauth2_session.scope, "Killing oauth2 session"); + + if dry_run { + continue; + } + + for scope in oauth2_session.scope.iter() { + if let Some(device) = Device::from_scope_token(scope) { + // Schedule a job to delete the device. + repo.job() + .schedule_job(DeleteDeviceJob::new(&user, &device)) + .await?; + } + } + + repo.oauth2_session().finish(&clock, oauth2_session).await?; + } + + let user_sessions_ids: Vec = sqlx::query_scalar( + r#" + SELECT user_session_id FROM user_sessions + WHERE user_id = $1 AND finished_at IS NULL + "#, + ) + .bind(Uuid::from(user.id)) + .fetch_all(&mut conn) + .await?; + + for id in user_sessions_ids { + let id = id.into(); + let browser_session = repo + .browser_session() + .lookup(id) + .await? + .context("Session not found")?; + info!(%browser_session.id, "Killing browser session"); + + if dry_run { + continue; + } + + repo.browser_session() + .finish(&clock, browser_session) + .await?; + } + + if dry_run { + info!("Dry run, not saving"); + repo.cancel().await?; + } else { + repo.save().await?; + } + + Ok(()) + } } } }