diff --git a/Cargo.lock b/Cargo.lock index c41f4b64..74de32e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1029,6 +1029,7 @@ dependencies = [ "encode_unicode", "lazy_static", "libc", + "unicode-width", "windows-sys 0.52.0", ] @@ -1492,6 +1493,19 @@ dependencies = [ "serde", ] +[[package]] +name = "dialoguer" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +dependencies = [ + "console", + "shell-words", + "tempfile", + "thiserror", + "zeroize", +] + [[package]] name = "digest" version = "0.10.7" @@ -3022,6 +3036,7 @@ dependencies = [ "axum", "camino", "clap", + "dialoguer", "dotenvy", "figment", "httpdate", @@ -5455,6 +5470,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" + [[package]] name = "signal-hook-registry" version = "1.4.1" diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index b603b989..5faaf078 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -16,6 +16,7 @@ anyhow.workspace = true axum = "0.6.20" camino.workspace = true clap.workspace = true +dialoguer = "0.11.0" dotenvy = "0.15.7" figment.workspace = true httpdate = "1.0.3" diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index ca639b0e..57e8c67f 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -12,21 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashMap, io::Write}; +use std::collections::HashMap; use anyhow::Context; use clap::{ArgAction, CommandFactory, Parser}; +use dialoguer::{Confirm, Input}; use figment::Figment; -use mas_config::{ConfigurationSection, DatabaseConfig, PasswordsConfig}; +use mas_config::{ConfigurationSection, DatabaseConfig, MatrixConfig, PasswordsConfig}; use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User}; use mas_email::Address; +use mas_handlers::HttpClientFactory; +use mas_matrix::HomeserverConnection; +use mas_matrix_synapse::SynapseConnection; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob}, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, RepositoryAccess, SystemClock, }; -use mas_storage_pg::PgRepository; +use mas_storage_pg::{DatabaseError, PgRepository}; use rand::{RngCore, SeedableRng}; use sqlx::{types::Uuid, Acquire}; use tracing::{info, info_span, warn}; @@ -477,10 +481,18 @@ impl Options { admin, display_name, } => { + let http_client_factory = HttpClientFactory::new(); let password_config = PasswordsConfig::extract(figment)?; let database_config = DatabaseConfig::extract(figment)?; + let matrix_config = MatrixConfig::extract(figment)?; let password_manager = password_manager_from_config(&password_config).await?; + let homeserver = SynapseConnection::new( + matrix_config.homeserver, + matrix_config.endpoint, + matrix_config.secret, + http_client_factory, + ); let mut conn = database_connection_from_config(&database_config).await?; let txn = conn.begin().await?; let mut repo = PgRepository::from_conn(txn); @@ -501,6 +513,7 @@ impl Options { upstream_providers.insert(provider.id, provider); } + // TODO: prompt for link? let upstream_provider_mappings = upstream_provider_mappings .into_iter() .map(|mapping| { @@ -516,30 +529,61 @@ impl Options { let password = password.into_bytes().into(); Some(password_manager.hash(&mut rng, password).await?) } else { + // TODO: prompt for password if not provided and needed None }; - // TODO: prompt - let username = username.context("Username is required")?; + let localpart = if let Some(username) = username { + check_and_normalize_username(&username, &mut repo, &homeserver) + .await? + .to_owned() + } else { + loop { + let username = tokio::task::spawn_blocking(|| { + Input::::new() + .with_prompt("Username") + .interact_text() + }) + .await??; - if repo.user().exists(&username).await? { - anyhow::bail!("User already exists"); - } + match check_and_normalize_username(&username, &mut repo, &homeserver).await + { + Ok(localpart) => break localpart.to_owned(), + Err(e) => { + warn!("Invalid username: {e}"); + } + } + } + }; let req = UserCreationRequest { - username, + homeserver: &homeserver, + username: localpart, hashed_password, + // TODO: prompt for emails emails, upstream_provider_mappings, + // TODO: prompt for display name display_name, + // TODO: prompt for admin admin, }; - req.command(&mut std::io::stdout())?; + info!("Do you want to register this user?\n{req}"); - //do_register(&mut repo, &mut rng, &clock, req).await?; + let confirmation = tokio::task::spawn_blocking(|| { + Confirm::new().with_prompt("Confirm?").interact() + }) + .await??; - repo.into_inner().commit().await?; + if confirmation { + let user = do_register(&mut repo, &mut rng, &clock, req).await?; + repo.into_inner().commit().await?; + info!(%user.id, "User registered"); + } else { + let cmd = UserCreationCommand(&req); + info!("Aborted. {cmd}"); + } Ok(()) } @@ -547,7 +591,37 @@ impl Options { } } +async fn check_and_normalize_username<'a>( + localpart_or_mxid: &'a str, + repo: &mut dyn RepositoryAccess, + homeserver: &SynapseConnection, +) -> anyhow::Result<&'a str> { + // XXX: this is a very basic MXID to localpart conversion + // Strip any leading '@' + let mut localpart = localpart_or_mxid.trim_start_matches('@'); + + // Strip any trailing ':homeserver' + if let Some(index) = localpart.find(':') { + localpart = &localpart[..index]; + } + + if localpart.is_empty() { + return Err(anyhow::anyhow!("Username cannot be empty")); + } + + if repo.user().exists(localpart).await? { + return Err(anyhow::anyhow!("User already exists")); + } + + if !homeserver.is_localpart_available(localpart).await? { + return Err(anyhow::anyhow!("Username not available on homeserver")); + } + + Ok(localpart) +} + struct UserCreationRequest<'a> { + homeserver: &'a SynapseConnection, username: String, hashed_password: Option<(u16, String)>, emails: Vec
, @@ -556,8 +630,10 @@ struct UserCreationRequest<'a> { admin: bool, } -impl<'a> UserCreationRequest<'a> { - fn command(&self, w: &mut W) -> std::io::Result<()> { +struct UserCreationCommand<'a>(&'a UserCreationRequest<'a>); + +impl std::fmt::Display for UserCreationCommand<'_> { + fn fmt(&self, w: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let command = super::Options::command(); let manage = command.find_subcommand("manage").unwrap(); let register_user = manage.find_subcommand("register-user").unwrap(); @@ -581,15 +657,15 @@ impl<'a> UserCreationRequest<'a> { w, " --{} {:?}", username_arg.get_long().unwrap(), - self.username + self.0.username )?; - for email in &self.emails { + for email in &self.0.emails { let email: &str = email.as_ref(); write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?; } - if let Some(display_name) = &self.display_name { + if let Some(display_name) = &self.0.display_name { write!( w, " --{} {:?}", @@ -598,11 +674,11 @@ impl<'a> UserCreationRequest<'a> { )?; } - if self.hashed_password.is_some() { + if self.0.hashed_password.is_some() { write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?; } - for (provider, subject) in &self.upstream_provider_mappings { + for (provider, subject) in &self.0.upstream_provider_mappings { let mapping = format!("{}:{}", provider.id, subject); write!( w, @@ -611,7 +687,7 @@ impl<'a> UserCreationRequest<'a> { )?; } - if self.admin { + if self.0.admin { write!(w, " --{}", admin_arg.get_long().unwrap())?; } @@ -619,6 +695,37 @@ impl<'a> UserCreationRequest<'a> { } } +impl std::fmt::Display for UserCreationRequest<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let username = &self.username; + let mxid = self.homeserver.mxid(username); + writeln!(f, "Username: {username}")?; + writeln!(f, "Matrix ID: {mxid}")?; + if let Some(display_name) = &self.display_name { + writeln!(f, "Display name: {display_name}")?; + } + if self.hashed_password.is_some() { + writeln!(f, "Password: ")?; + } + + for (provider, subject) in &self.upstream_provider_mappings { + let provider = provider + .human_name + .clone() + .unwrap_or_else(|| provider.id.to_string()); + writeln!(f, "Upstream account: {provider} => {subject}")?; + } + + for email in &self.emails { + writeln!(f, "Email: {email}")?; + } + + writeln!(f, "Can request admin: {}", self.admin)?; + + Ok(()) + } +} + async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>( repo: &mut dyn RepositoryAccess, rng: &mut (dyn RngCore + Send), @@ -630,6 +737,7 @@ async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>( upstream_provider_mappings, display_name, admin, + .. }: UserCreationRequest<'a>, ) -> Result { let mut user = repo.user().add(rng, clock, username).await?;