diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 0f63f506..ca639b0e 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2024 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,24 +12,46 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::{collections::HashMap, io::Write}; + use anyhow::Context; -use clap::Parser; +use clap::{ArgAction, CommandFactory, Parser}; use figment::Figment; use mas_config::{ConfigurationSection, DatabaseConfig, PasswordsConfig}; -use mas_data_model::{Device, TokenType}; +use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User}; +use mas_email::Address; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob}, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, - RepositoryAccess, SystemClock, + Clock, RepositoryAccess, SystemClock, }; use mas_storage_pg::PgRepository; -use rand::SeedableRng; +use rand::{RngCore, SeedableRng}; use sqlx::{types::Uuid, Acquire}; use tracing::{info, info_span, warn}; use crate::util::{database_connection_from_config, password_manager_from_config}; +const USER_ATTRIBUTES_HEADING: &str = "User attributes"; + +#[derive(Debug, Clone)] +struct UpstreamProviderMapping { + upstream_provider_id: Ulid, + subject: String, +} + +fn parse_upstream_provider_mapping(s: &str) -> Result { + let (id, subject) = s.split_once(':').context("Invalid format")?; + let upstream_provider_id = id.parse().context("Invalid upstream provider ID")?; + let subject = subject.to_owned(); + + Ok(UpstreamProviderMapping { + upstream_provider_id, + subject, + }) +} + #[derive(Parser, Debug)] pub(super) struct Options { #[command(subcommand)] @@ -86,6 +108,40 @@ enum Subcommand { /// User to unlock username: String, }, + + /// Register a user + RegisterUser { + /// Username to register + #[arg(short, long, help_heading = USER_ATTRIBUTES_HEADING)] + username: Option, + + /// Password to set + #[arg(short, long, help_heading = USER_ATTRIBUTES_HEADING)] + password: Option, + + /// Email to add + #[arg(short, long = "email", action = ArgAction::Append, help_heading = USER_ATTRIBUTES_HEADING)] + emails: Vec
, + + /// Upstream OAuth 2.0 provider mapping to add + #[arg( + short = 'M', + long = "upstream-provider-mapping", + value_parser = parse_upstream_provider_mapping, + action = ArgAction::Append, + value_name = "UPSTREAM_PROVIDER_ID:SUBJECT", + help_heading = USER_ATTRIBUTES_HEADING + )] + upstream_provider_mappings: Vec, + + /// Make the user an admin + #[arg(long, action = ArgAction::SetTrue, help_heading = USER_ATTRIBUTES_HEADING)] + admin: bool, + + /// Set the user's display name + #[arg(short = 'D', long, help_heading = USER_ATTRIBUTES_HEADING)] + display_name: Option, + }, } impl Options { @@ -412,6 +468,216 @@ impl Options { Ok(()) } + + SC::RegisterUser { + username, + password, + emails, + upstream_provider_mappings, + admin, + display_name, + } => { + let password_config = PasswordsConfig::extract(figment)?; + let database_config = DatabaseConfig::extract(figment)?; + + let password_manager = password_manager_from_config(&password_config).await?; + let mut conn = database_connection_from_config(&database_config).await?; + let txn = conn.begin().await?; + let mut repo = PgRepository::from_conn(txn); + + // Load all the providers we need + let mut upstream_providers = HashMap::new(); + for mapping in &upstream_provider_mappings { + if upstream_providers.contains_key(&mapping.upstream_provider_id) { + continue; + } + + let provider = repo + .upstream_oauth_provider() + .lookup(mapping.upstream_provider_id) + .await? + .context("Upstream provider not found")?; + + upstream_providers.insert(provider.id, provider); + } + + let upstream_provider_mappings = upstream_provider_mappings + .into_iter() + .map(|mapping| { + ( + &upstream_providers[&mapping.upstream_provider_id], + mapping.subject, + ) + }) + .collect(); + + // Hash the password if it's provided + let hashed_password = if let Some(password) = password { + let password = password.into_bytes().into(); + Some(password_manager.hash(&mut rng, password).await?) + } else { + None + }; + + // TODO: prompt + let username = username.context("Username is required")?; + + if repo.user().exists(&username).await? { + anyhow::bail!("User already exists"); + } + + let req = UserCreationRequest { + username, + hashed_password, + emails, + upstream_provider_mappings, + display_name, + admin, + }; + + req.command(&mut std::io::stdout())?; + + //do_register(&mut repo, &mut rng, &clock, req).await?; + + repo.into_inner().commit().await?; + + Ok(()) + } } } } + +struct UserCreationRequest<'a> { + username: String, + hashed_password: Option<(u16, String)>, + emails: Vec
, + upstream_provider_mappings: Vec<(&'a UpstreamOAuthProvider, String)>, + display_name: Option, + admin: bool, +} + +impl<'a> UserCreationRequest<'a> { + fn command(&self, w: &mut W) -> std::io::Result<()> { + let command = super::Options::command(); + let manage = command.find_subcommand("manage").unwrap(); + let register_user = manage.find_subcommand("register-user").unwrap(); + let username_arg = ®ister_user[&clap::Id::from("username")]; + let password_arg = ®ister_user[&clap::Id::from("password")]; + let email_arg = ®ister_user[&clap::Id::from("emails")]; + let upstream_provider_mapping_arg = + ®ister_user[&clap::Id::from("upstream_provider_mappings")]; + let display_name_arg = ®ister_user[&clap::Id::from("display_name")]; + let admin_arg = ®ister_user[&clap::Id::from("admin")]; + + write!( + w, + "{} {} {}", + command.get_name(), + manage.get_name(), + register_user.get_name() + )?; + + write!( + w, + " --{} {:?}", + username_arg.get_long().unwrap(), + self.username + )?; + + for email in &self.emails { + let email: &str = email.as_ref(); + write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?; + } + + if let Some(display_name) = &self.display_name { + write!( + w, + " --{} {:?}", + display_name_arg.get_long().unwrap(), + display_name + )?; + } + + if self.hashed_password.is_some() { + write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?; + } + + for (provider, subject) in &self.upstream_provider_mappings { + let mapping = format!("{}:{}", provider.id, subject); + write!( + w, + " --{} {mapping:?}", + upstream_provider_mapping_arg.get_long().unwrap(), + )?; + } + + if self.admin { + write!(w, " --{}", admin_arg.get_long().unwrap())?; + } + + Ok(()) + } +} + +async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>( + repo: &mut dyn RepositoryAccess, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + UserCreationRequest { + username, + hashed_password, + emails, + upstream_provider_mappings, + display_name, + admin, + }: UserCreationRequest<'a>, +) -> Result { + let mut user = repo.user().add(rng, clock, username).await?; + + if let Some((version, hashed_password)) = hashed_password { + repo.user_password() + .add(rng, clock, &user, version, hashed_password, None) + .await?; + } + + for email in emails { + let user_email = repo + .user_email() + .add(rng, clock, &user, email.to_string()) + .await?; + + let user_email = repo + .user_email() + .mark_as_verified(clock, user_email) + .await?; + + if user.primary_user_email_id.is_none() { + repo.user_email().set_as_primary(&user_email).await?; + user.primary_user_email_id = Some(user_email.id); + } + } + + for (provider, subject) in upstream_provider_mappings { + let link = repo + .upstream_oauth_link() + .add(rng, clock, provider, subject) + .await?; + + repo.upstream_oauth_link() + .associate_to_user(&link, &user) + .await?; + } + + if admin { + user = repo.user().set_can_request_admin(user, true).await?; + } + + let mut provision_job = ProvisionUserJob::new(&user); + if let Some(display_name) = display_name { + provision_job = provision_job.set_display_name(display_name); + } + + repo.job().schedule_job(provision_job).await?; + + Ok(user) +} diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 9b52f010..c79bbe2b 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -28,6 +28,8 @@ pub(crate) mod users; #[error("invalid state transition")] pub struct InvalidTransitionError; +pub use ulid::Ulid; + pub use self::{ compat::{ CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,