diff --git a/Cargo.lock b/Cargo.lock index 74de32e6..d8aca5e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1500,6 +1500,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" dependencies = [ "console", + "fuzzy-matcher", "shell-words", "tempfile", "thiserror", @@ -1967,6 +1968,15 @@ dependencies = [ "slab", ] +[[package]] +name = "fuzzy-matcher" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54614a3312934d066701a80f20f15fa3b56d67ac7722b39eea5b4c9dd1d66c94" +dependencies = [ + "thread_local", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -3036,6 +3046,7 @@ dependencies = [ "axum", "camino", "clap", + "console", "dialoguer", "dotenvy", "figment", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 5faaf078..03976429 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -16,7 +16,8 @@ anyhow.workspace = true axum = "0.6.20" camino.workspace = true clap.workspace = true -dialoguer = "0.11.0" +console = "0.15.8" +dialoguer = { version = "0.11.0", features = ["fuzzy-select"] } dotenvy = "0.15.7" figment.workspace = true httpdate = "1.0.3" diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 57e8c67f..3a129b40 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; +use std::collections::BTreeMap; use anyhow::Context; use clap::{ArgAction, CommandFactory, Parser}; -use dialoguer::{Confirm, Input}; +use console::{pad_str, style, Alignment, Style, Term}; +use dialoguer::{theme::ColorfulTheme, Confirm, FuzzySelect, Input, Password}; use figment::Figment; use mas_config::{ConfigurationSection, DatabaseConfig, MatrixConfig, PasswordsConfig}; use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User}; @@ -114,9 +115,13 @@ enum Subcommand { }, /// Register a user + /// + /// This will interactively prompt for the user's attributes unless the + /// `--yes` flag is set. It bypasses any policy check on the password, + /// email, etc. RegisterUser { /// Username to register - #[arg(short, long, help_heading = USER_ATTRIBUTES_HEADING)] + #[arg(help_heading = USER_ATTRIBUTES_HEADING, required_if_eq("yes", "true"))] username: Option, /// Password to set @@ -129,7 +134,7 @@ enum Subcommand { /// Upstream OAuth 2.0 provider mapping to add #[arg( - short = 'M', + short = 'm', long = "upstream-provider-mapping", value_parser = parse_upstream_provider_mapping, action = ArgAction::Append, @@ -139,11 +144,19 @@ enum Subcommand { upstream_provider_mappings: Vec, /// Make the user an admin - #[arg(long, action = ArgAction::SetTrue, help_heading = USER_ATTRIBUTES_HEADING)] + #[arg(short, long, action = ArgAction::SetTrue, group = "admin-flag", help_heading = USER_ATTRIBUTES_HEADING)] admin: bool, + /// Make the user not an admin + #[arg(short = 'A', long, action = ArgAction::SetTrue, group = "admin-flag", help_heading = USER_ATTRIBUTES_HEADING)] + no_admin: bool, + + // Don't ask questions, just do it + #[arg(short, long, action = ArgAction::SetTrue)] + yes: bool, + /// Set the user's display name - #[arg(short = 'D', long, help_heading = USER_ATTRIBUTES_HEADING)] + #[arg(short, long, help_heading = USER_ATTRIBUTES_HEADING)] display_name: Option, }, } @@ -343,7 +356,7 @@ impl Options { let oauth2_sessions_ids: Vec = sqlx::query_scalar( r" - SELECT oauth2_sessions.oauth2_session_id + SELECT oauth2_sessions.oauth2_session_id FROM oauth2_sessions INNER JOIN user_sessions USING (user_session_id) WHERE user_sessions.user_id = $1 AND oauth2_sessions.finished_at IS NULL @@ -479,7 +492,9 @@ impl Options { emails, upstream_provider_mappings, admin, + no_admin, display_name, + yes, } => { let http_client_factory = HttpClientFactory::new(); let password_config = PasswordsConfig::extract(figment)?; @@ -497,50 +512,16 @@ impl Options { let txn = conn.begin().await?; let mut repo = PgRepository::from_conn(txn); - // Load all the providers we need - let mut upstream_providers = HashMap::new(); - for mapping in &upstream_provider_mappings { - if upstream_providers.contains_key(&mapping.upstream_provider_id) { - continue; - } - - let provider = repo - .upstream_oauth_provider() - .lookup(mapping.upstream_provider_id) - .await? - .context("Upstream provider not found")?; - - upstream_providers.insert(provider.id, provider); - } - - // TODO: prompt for link? - let upstream_provider_mappings = upstream_provider_mappings - .into_iter() - .map(|mapping| { - ( - &upstream_providers[&mapping.upstream_provider_id], - mapping.subject, - ) - }) - .collect(); - - // Hash the password if it's provided - let hashed_password = if let Some(password) = password { - let password = password.into_bytes().into(); - Some(password_manager.hash(&mut rng, password).await?) - } else { - // TODO: prompt for password if not provided and needed - None - }; - + // If the username is provided, check if it's available and normalize it. let localpart = if let Some(username) = username { check_and_normalize_username(&username, &mut repo, &homeserver) .await? .to_owned() } else { + // Else we prompt for one until we get a valid one. loop { let username = tokio::task::spawn_blocking(|| { - Input::::new() + Input::::with_theme(&ColorfulTheme::default()) .with_prompt("Username") .interact_text() }) @@ -556,33 +537,189 @@ impl Options { } }; - let req = UserCreationRequest { - homeserver: &homeserver, + // Load all the upstream providers + let upstream_providers: BTreeMap<_, _> = repo + .upstream_oauth_provider() + .all_enabled() + .await? + .into_iter() + .map(|provider| (provider.id, provider)) + .collect(); + + let upstream_provider_mappings = upstream_provider_mappings + .into_iter() + .map(|mapping| { + ( + &upstream_providers[&mapping.upstream_provider_id], + mapping.subject, + ) + }) + .collect(); + + let admin = match (admin, no_admin) { + (false, false) => None, + (true, false) => Some(true), + (false, true) => Some(false), + _ => unreachable!("This should be handled by the clap group"), + }; + + // Hash the password if it's provided + let hashed_password = if let Some(password) = password { + let password = password.into_bytes().into(); + Some(password_manager.hash(&mut rng, password).await?) + } else { + None + }; + + let mut req = UserCreationRequest { username: localpart, hashed_password, - // TODO: prompt for emails emails, upstream_provider_mappings, - // TODO: prompt for display name display_name, - // TODO: prompt for admin admin, }; - info!("Do you want to register this user?\n{req}"); + let term = Term::buffered_stdout(); + loop { + req.show(&term, &homeserver)?; - let confirmation = tokio::task::spawn_blocking(|| { - Confirm::new().with_prompt("Confirm?").interact() - }) - .await??; + // If we're in `yes` mode, we don't prompt for actions + if yes { + break; + } + + term.write_line(&format!( + "\n{msg}:\n\n {cmd}\n", + msg = style("Non-interactive equivalent to create this user").bold(), + cmd = style(UserCreationCommand(&req)).underlined(), + ))?; + + term.flush()?; + + let action = req + .prompt_action( + password_manager.is_enabled(), + !upstream_providers.is_empty(), + ) + .await? + .context("Aborted")?; + + match action { + Action::CreateUser => break, + Action::ChangeUsername => { + req.username = loop { + let current_username = req.username.clone(); + let username = tokio::task::spawn_blocking(|| { + Input::::with_theme(&ColorfulTheme::default()) + .with_prompt("Username") + .with_initial_text(current_username) + .interact_text() + }) + .await??; + + match check_and_normalize_username( + &username, + &mut repo, + &homeserver, + ) + .await + { + Ok(localpart) => break localpart.to_owned(), + Err(e) => { + warn!("Invalid username: {e}"); + } + } + }; + } + Action::SetPassword => { + let password = tokio::task::spawn_blocking(|| { + Password::with_theme(&ColorfulTheme::default()) + .with_prompt("Password") + .with_confirmation("Confirm password", "Passwords mismatching") + .interact() + }) + .await??; + let password = password.into_bytes().into(); + req.hashed_password = + Some(password_manager.hash(&mut rng, password).await?); + } + Action::SetDisplayName => { + let display_name = tokio::task::spawn_blocking(|| { + Input::::with_theme(&ColorfulTheme::default()) + .with_prompt("Display name") + .interact() + }) + .await??; + req.display_name = Some(display_name); + } + Action::AddEmail => { + let email = tokio::task::spawn_blocking(|| { + Input::
::with_theme(&ColorfulTheme::default()) + .with_prompt("Email") + .interact_text() + }) + .await??; + req.emails.push(email); + } + Action::SetAdmin => { + let admin = tokio::task::spawn_blocking(|| { + Confirm::with_theme(&ColorfulTheme::default()) + .with_prompt("Make user admin?") + .interact() + }) + .await??; + req.admin = Some(admin); + } + Action::AddUpstreamProviderMapping => { + let providers = upstream_providers.clone(); + let provider_id = tokio::task::spawn_blocking(move || { + let providers: Vec<_> = providers.into_values().collect(); + let human_readable_providers: Vec<_> = + providers.iter().map(HumanReadable).collect(); + FuzzySelect::with_theme(&ColorfulTheme::default()) + .with_prompt("Upstream provider") + .items(&human_readable_providers) + .default(0) + .interact() + .map(move |selected| providers[selected].id) + }) + .await??; + let provider = &upstream_providers[&provider_id]; + + let subject = tokio::task::spawn_blocking(|| { + Input::::with_theme(&ColorfulTheme::default()) + .with_prompt("Subject") + .interact() + }) + .await??; + + req.upstream_provider_mappings.push((&provider, subject)); + } + } + } + + if req.emails.is_empty() { + warn!("No email address provided, user will need to add one"); + } + + let confirmation = if yes { + true + } else { + tokio::task::spawn_blocking(|| { + Confirm::with_theme(&ColorfulTheme::default()) + .with_prompt("Confirm?") + .interact() + }) + .await?? + }; if confirmation { - let user = do_register(&mut repo, &mut rng, &clock, req).await?; + let user = req.do_register(&mut repo, &mut rng, &clock).await?; repo.into_inner().commit().await?; info!(%user.id, "User registered"); } else { - let cmd = UserCreationCommand(&req); - info!("Aborted. {cmd}"); + warn!("Aborted"); } Ok(()) @@ -591,6 +728,21 @@ impl Options { } } +/// A wrapper to display some objects differently +#[derive(Debug, Clone, Copy)] +struct HumanReadable(T); + +impl std::fmt::Display for HumanReadable<&UpstreamOAuthProvider> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let provider = self.0; + if let Some(human_name) = &provider.human_name { + write!(f, "{} ({})", human_name, provider.id) + } else { + write!(f, "{} ({})", provider.issuer, provider.id) + } + } +} + async fn check_and_normalize_username<'a>( localpart_or_mxid: &'a str, repo: &mut dyn RepositoryAccess, @@ -621,15 +773,215 @@ async fn check_and_normalize_username<'a>( } struct UserCreationRequest<'a> { - homeserver: &'a SynapseConnection, username: String, hashed_password: Option<(u16, String)>, emails: Vec
, upstream_provider_mappings: Vec<(&'a UpstreamOAuthProvider, String)>, display_name: Option, - admin: bool, + admin: Option, } +impl UserCreationRequest<'_> { + // Get a list of the possible actions + fn possible_actions( + &self, + has_password_auth: bool, + has_upstream_providers: bool, + ) -> Vec { + let mut actions = vec![Action::CreateUser, Action::ChangeUsername, Action::AddEmail]; + + if has_password_auth && self.hashed_password.is_none() { + actions.push(Action::SetPassword); + } + + if has_upstream_providers { + actions.push(Action::AddUpstreamProviderMapping); + } + + if self.admin.is_none() { + actions.push(Action::SetAdmin); + } + + if self.display_name.is_none() { + actions.push(Action::SetDisplayName); + } + + actions + } + + /// Prompt for the next action + async fn prompt_action( + &self, + has_password_auth: bool, + has_upstream_providers: bool, + ) -> anyhow::Result> { + let actions = self.possible_actions(has_password_auth, has_upstream_providers); + tokio::task::spawn_blocking(move || { + let index = FuzzySelect::with_theme(&ColorfulTheme::default()) + .with_prompt("What do you want to do next? ( to abort)") + .items(&actions) + .default(0) + .interact_opt()?; + Ok(index.map(|index| actions[index])) + }) + .await? + } + + /// Show the user creation request in a human-readable format + fn show(&self, term: &Term, homeserver: &SynapseConnection) -> std::io::Result<()> { + let value_style = Style::new().green(); + let key_style = Style::new().bold(); + let warning_style = Style::new().italic().red().bright(); + let username = &self.username; + let mxid = homeserver.mxid(username); + + term.write_line(&style("User attributes").bold().underlined().to_string())?; + + macro_rules! display { + ($key:expr, $value:expr) => { + term.write_line(&format!( + "{key}: {value}", + key = key_style.apply_to(pad_str($key, 17, Alignment::Right, None)), + value = value_style.apply_to($value) + ))?; + }; + } + + display!("Username", username); + display!("Matrix ID", mxid); + if let Some(display_name) = &self.display_name { + display!("Display name", display_name); + } + + if self.hashed_password.is_some() { + display!("Password", "********"); + } + + for (provider, subject) in &self.upstream_provider_mappings { + let provider = HumanReadable(*provider); + display!("Upstream account", format!("{provider} : {subject:?}")); + } + + for email in &self.emails { + display!("Email", email); + } + + if self.emails.is_empty() { + term.write_line( + &warning_style + .apply_to("No email address provided, user will be prompted to add one") + .to_string(), + )?; + } + + if self.hashed_password.is_none() && self.upstream_provider_mappings.is_empty() { + term.write_line( + &warning_style.apply_to("No password or upstream provider mapping provided, user will not be able to log in") + .to_string(), + )?; + } + + if let Some(admin) = self.admin { + display!("Can request admin", admin); + } + + term.flush()?; + + Ok(()) + } + + /// Submit the user creation request + async fn do_register( + self, + repo: &mut dyn RepositoryAccess, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + ) -> Result { + let Self { + username, + hashed_password, + emails, + upstream_provider_mappings, + display_name, + admin, + } = self; + let mut user = repo.user().add(rng, clock, username).await?; + + if let Some((version, hashed_password)) = hashed_password { + repo.user_password() + .add(rng, clock, &user, version, hashed_password, None) + .await?; + } + + for email in emails { + let user_email = repo + .user_email() + .add(rng, clock, &user, email.to_string()) + .await?; + + let user_email = repo + .user_email() + .mark_as_verified(clock, user_email) + .await?; + + if user.primary_user_email_id.is_none() { + repo.user_email().set_as_primary(&user_email).await?; + user.primary_user_email_id = Some(user_email.id); + } + } + + for (provider, subject) in upstream_provider_mappings { + let link = repo + .upstream_oauth_link() + .add(rng, clock, provider, subject) + .await?; + + repo.upstream_oauth_link() + .associate_to_user(&link, &user) + .await?; + } + + if let Some(admin) = admin { + user = repo.user().set_can_request_admin(user, admin).await?; + } + + let mut provision_job = ProvisionUserJob::new(&user); + if let Some(display_name) = display_name { + provision_job = provision_job.set_display_name(display_name); + } + + repo.job().schedule_job(provision_job).await?; + + Ok(user) + } +} + +#[derive(Debug, Clone, Copy)] +enum Action { + CreateUser, + ChangeUsername, + SetPassword, + SetDisplayName, + AddEmail, + SetAdmin, + AddUpstreamProviderMapping, +} + +impl std::fmt::Display for Action { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Action::CreateUser => write!(f, "Create the user"), + Action::ChangeUsername => write!(f, "Change the username"), + Action::SetPassword => write!(f, "Set a password"), + Action::AddEmail => write!(f, "Add email"), + Action::SetDisplayName => write!(f, "Set a display name"), + Action::SetAdmin => write!(f, "Set the admin status"), + Action::AddUpstreamProviderMapping => write!(f, "Add upstream provider mapping"), + } + } +} + +/// A wrapper to display the user creation request as a command struct UserCreationCommand<'a>(&'a UserCreationRequest<'a>); impl std::fmt::Display for UserCreationCommand<'_> { @@ -637,27 +989,23 @@ impl std::fmt::Display for UserCreationCommand<'_> { let command = super::Options::command(); let manage = command.find_subcommand("manage").unwrap(); let register_user = manage.find_subcommand("register-user").unwrap(); - let username_arg = ®ister_user[&clap::Id::from("username")]; + let yes_arg = ®ister_user[&clap::Id::from("yes")]; let password_arg = ®ister_user[&clap::Id::from("password")]; let email_arg = ®ister_user[&clap::Id::from("emails")]; let upstream_provider_mapping_arg = ®ister_user[&clap::Id::from("upstream_provider_mappings")]; let display_name_arg = ®ister_user[&clap::Id::from("display_name")]; let admin_arg = ®ister_user[&clap::Id::from("admin")]; + let no_admin_arg = ®ister_user[&clap::Id::from("no_admin")]; write!( w, - "{} {} {}", + "{} {} {} --{} {}", command.get_name(), manage.get_name(), - register_user.get_name() - )?; - - write!( - w, - " --{} {:?}", - username_arg.get_long().unwrap(), - self.0.username + register_user.get_name(), + yes_arg.get_long().unwrap(), + self.0.username, )?; for email in &self.0.emails { @@ -687,105 +1035,12 @@ impl std::fmt::Display for UserCreationCommand<'_> { )?; } - if self.0.admin { - write!(w, " --{}", admin_arg.get_long().unwrap())?; + match self.0.admin { + Some(true) => write!(w, " --{}", admin_arg.get_long().unwrap())?, + Some(false) => write!(w, " --{}", no_admin_arg.get_long().unwrap())?, + None => {} } Ok(()) } } - -impl std::fmt::Display for UserCreationRequest<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let username = &self.username; - let mxid = self.homeserver.mxid(username); - writeln!(f, "Username: {username}")?; - writeln!(f, "Matrix ID: {mxid}")?; - if let Some(display_name) = &self.display_name { - writeln!(f, "Display name: {display_name}")?; - } - if self.hashed_password.is_some() { - writeln!(f, "Password: ")?; - } - - for (provider, subject) in &self.upstream_provider_mappings { - let provider = provider - .human_name - .clone() - .unwrap_or_else(|| provider.id.to_string()); - writeln!(f, "Upstream account: {provider} => {subject}")?; - } - - for email in &self.emails { - writeln!(f, "Email: {email}")?; - } - - writeln!(f, "Can request admin: {}", self.admin)?; - - Ok(()) - } -} - -async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>( - repo: &mut dyn RepositoryAccess, - rng: &mut (dyn RngCore + Send), - clock: &dyn Clock, - UserCreationRequest { - username, - hashed_password, - emails, - upstream_provider_mappings, - display_name, - admin, - .. - }: UserCreationRequest<'a>, -) -> Result { - let mut user = repo.user().add(rng, clock, username).await?; - - if let Some((version, hashed_password)) = hashed_password { - repo.user_password() - .add(rng, clock, &user, version, hashed_password, None) - .await?; - } - - for email in emails { - let user_email = repo - .user_email() - .add(rng, clock, &user, email.to_string()) - .await?; - - let user_email = repo - .user_email() - .mark_as_verified(clock, user_email) - .await?; - - if user.primary_user_email_id.is_none() { - repo.user_email().set_as_primary(&user_email).await?; - user.primary_user_email_id = Some(user_email.id); - } - } - - for (provider, subject) in upstream_provider_mappings { - let link = repo - .upstream_oauth_link() - .add(rng, clock, provider, subject) - .await?; - - repo.upstream_oauth_link() - .associate_to_user(&link, &user) - .await?; - } - - if admin { - user = repo.user().set_can_request_admin(user, true).await?; - } - - let mut provision_job = ProvisionUserJob::new(&user); - if let Some(display_name) = display_name { - provision_job = provision_job.set_display_name(display_name); - } - - repo.job().schedule_job(provision_job).await?; - - Ok(user) -}