1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-06 06:02:40 +03:00

Prompt for username and confirm user creation

This commit is contained in:
Quentin Gliech
2024-04-16 17:21:47 +02:00
parent 1cb48b8026
commit 8c402a1f50
3 changed files with 150 additions and 20 deletions

21
Cargo.lock generated
View File

@@ -1029,6 +1029,7 @@ dependencies = [
"encode_unicode", "encode_unicode",
"lazy_static", "lazy_static",
"libc", "libc",
"unicode-width",
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
@@ -1492,6 +1493,19 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "dialoguer"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de"
dependencies = [
"console",
"shell-words",
"tempfile",
"thiserror",
"zeroize",
]
[[package]] [[package]]
name = "digest" name = "digest"
version = "0.10.7" version = "0.10.7"
@@ -3022,6 +3036,7 @@ dependencies = [
"axum", "axum",
"camino", "camino",
"clap", "clap",
"dialoguer",
"dotenvy", "dotenvy",
"figment", "figment",
"httpdate", "httpdate",
@@ -5455,6 +5470,12 @@ dependencies = [
"lazy_static", "lazy_static",
] ]
[[package]]
name = "shell-words"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.1" version = "1.4.1"

View File

@@ -16,6 +16,7 @@ anyhow.workspace = true
axum = "0.6.20" axum = "0.6.20"
camino.workspace = true camino.workspace = true
clap.workspace = true clap.workspace = true
dialoguer = "0.11.0"
dotenvy = "0.15.7" dotenvy = "0.15.7"
figment.workspace = true figment.workspace = true
httpdate = "1.0.3" httpdate = "1.0.3"

View File

@@ -12,21 +12,25 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::HashMap, io::Write}; use std::collections::HashMap;
use anyhow::Context; use anyhow::Context;
use clap::{ArgAction, CommandFactory, Parser}; use clap::{ArgAction, CommandFactory, Parser};
use dialoguer::{Confirm, Input};
use figment::Figment; use figment::Figment;
use mas_config::{ConfigurationSection, DatabaseConfig, PasswordsConfig}; use mas_config::{ConfigurationSection, DatabaseConfig, MatrixConfig, PasswordsConfig};
use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User}; use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User};
use mas_email::Address; use mas_email::Address;
use mas_handlers::HttpClientFactory;
use mas_matrix::HomeserverConnection;
use mas_matrix_synapse::SynapseConnection;
use mas_storage::{ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository}, compat::{CompatAccessTokenRepository, CompatSessionRepository},
job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob}, job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
user::{UserEmailRepository, UserPasswordRepository, UserRepository}, user::{UserEmailRepository, UserPasswordRepository, UserRepository},
Clock, RepositoryAccess, SystemClock, Clock, RepositoryAccess, SystemClock,
}; };
use mas_storage_pg::PgRepository; use mas_storage_pg::{DatabaseError, PgRepository};
use rand::{RngCore, SeedableRng}; use rand::{RngCore, SeedableRng};
use sqlx::{types::Uuid, Acquire}; use sqlx::{types::Uuid, Acquire};
use tracing::{info, info_span, warn}; use tracing::{info, info_span, warn};
@@ -477,10 +481,18 @@ impl Options {
admin, admin,
display_name, display_name,
} => { } => {
let http_client_factory = HttpClientFactory::new();
let password_config = PasswordsConfig::extract(figment)?; let password_config = PasswordsConfig::extract(figment)?;
let database_config = DatabaseConfig::extract(figment)?; let database_config = DatabaseConfig::extract(figment)?;
let matrix_config = MatrixConfig::extract(figment)?;
let password_manager = password_manager_from_config(&password_config).await?; let password_manager = password_manager_from_config(&password_config).await?;
let homeserver = SynapseConnection::new(
matrix_config.homeserver,
matrix_config.endpoint,
matrix_config.secret,
http_client_factory,
);
let mut conn = database_connection_from_config(&database_config).await?; let mut conn = database_connection_from_config(&database_config).await?;
let txn = conn.begin().await?; let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn); let mut repo = PgRepository::from_conn(txn);
@@ -501,6 +513,7 @@ impl Options {
upstream_providers.insert(provider.id, provider); upstream_providers.insert(provider.id, provider);
} }
// TODO: prompt for link?
let upstream_provider_mappings = upstream_provider_mappings let upstream_provider_mappings = upstream_provider_mappings
.into_iter() .into_iter()
.map(|mapping| { .map(|mapping| {
@@ -516,30 +529,61 @@ impl Options {
let password = password.into_bytes().into(); let password = password.into_bytes().into();
Some(password_manager.hash(&mut rng, password).await?) Some(password_manager.hash(&mut rng, password).await?)
} else { } else {
// TODO: prompt for password if not provided and needed
None None
}; };
// TODO: prompt let localpart = if let Some(username) = username {
let username = username.context("Username is required")?; check_and_normalize_username(&username, &mut repo, &homeserver)
.await?
.to_owned()
} else {
loop {
let username = tokio::task::spawn_blocking(|| {
Input::<String>::new()
.with_prompt("Username")
.interact_text()
})
.await??;
if repo.user().exists(&username).await? { match check_and_normalize_username(&username, &mut repo, &homeserver).await
anyhow::bail!("User already exists"); {
} Ok(localpart) => break localpart.to_owned(),
Err(e) => {
warn!("Invalid username: {e}");
}
}
}
};
let req = UserCreationRequest { let req = UserCreationRequest {
username, homeserver: &homeserver,
username: localpart,
hashed_password, hashed_password,
// TODO: prompt for emails
emails, emails,
upstream_provider_mappings, upstream_provider_mappings,
// TODO: prompt for display name
display_name, display_name,
// TODO: prompt for admin
admin, admin,
}; };
req.command(&mut std::io::stdout())?; info!("Do you want to register this user?\n{req}");
//do_register(&mut repo, &mut rng, &clock, req).await?; let confirmation = tokio::task::spawn_blocking(|| {
Confirm::new().with_prompt("Confirm?").interact()
})
.await??;
repo.into_inner().commit().await?; if confirmation {
let user = do_register(&mut repo, &mut rng, &clock, req).await?;
repo.into_inner().commit().await?;
info!(%user.id, "User registered");
} else {
let cmd = UserCreationCommand(&req);
info!("Aborted. {cmd}");
}
Ok(()) Ok(())
} }
@@ -547,7 +591,37 @@ impl Options {
} }
} }
async fn check_and_normalize_username<'a>(
localpart_or_mxid: &'a str,
repo: &mut dyn RepositoryAccess<Error = DatabaseError>,
homeserver: &SynapseConnection,
) -> anyhow::Result<&'a str> {
// XXX: this is a very basic MXID to localpart conversion
// Strip any leading '@'
let mut localpart = localpart_or_mxid.trim_start_matches('@');
// Strip any trailing ':homeserver'
if let Some(index) = localpart.find(':') {
localpart = &localpart[..index];
}
if localpart.is_empty() {
return Err(anyhow::anyhow!("Username cannot be empty"));
}
if repo.user().exists(localpart).await? {
return Err(anyhow::anyhow!("User already exists"));
}
if !homeserver.is_localpart_available(localpart).await? {
return Err(anyhow::anyhow!("Username not available on homeserver"));
}
Ok(localpart)
}
struct UserCreationRequest<'a> { struct UserCreationRequest<'a> {
homeserver: &'a SynapseConnection,
username: String, username: String,
hashed_password: Option<(u16, String)>, hashed_password: Option<(u16, String)>,
emails: Vec<Address>, emails: Vec<Address>,
@@ -556,8 +630,10 @@ struct UserCreationRequest<'a> {
admin: bool, admin: bool,
} }
impl<'a> UserCreationRequest<'a> { struct UserCreationCommand<'a>(&'a UserCreationRequest<'a>);
fn command<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
impl std::fmt::Display for UserCreationCommand<'_> {
fn fmt(&self, w: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let command = super::Options::command(); let command = super::Options::command();
let manage = command.find_subcommand("manage").unwrap(); let manage = command.find_subcommand("manage").unwrap();
let register_user = manage.find_subcommand("register-user").unwrap(); let register_user = manage.find_subcommand("register-user").unwrap();
@@ -581,15 +657,15 @@ impl<'a> UserCreationRequest<'a> {
w, w,
" --{} {:?}", " --{} {:?}",
username_arg.get_long().unwrap(), username_arg.get_long().unwrap(),
self.username self.0.username
)?; )?;
for email in &self.emails { for email in &self.0.emails {
let email: &str = email.as_ref(); let email: &str = email.as_ref();
write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?; write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?;
} }
if let Some(display_name) = &self.display_name { if let Some(display_name) = &self.0.display_name {
write!( write!(
w, w,
" --{} {:?}", " --{} {:?}",
@@ -598,11 +674,11 @@ impl<'a> UserCreationRequest<'a> {
)?; )?;
} }
if self.hashed_password.is_some() { if self.0.hashed_password.is_some() {
write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?; write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?;
} }
for (provider, subject) in &self.upstream_provider_mappings { for (provider, subject) in &self.0.upstream_provider_mappings {
let mapping = format!("{}:{}", provider.id, subject); let mapping = format!("{}:{}", provider.id, subject);
write!( write!(
w, w,
@@ -611,7 +687,7 @@ impl<'a> UserCreationRequest<'a> {
)?; )?;
} }
if self.admin { if self.0.admin {
write!(w, " --{}", admin_arg.get_long().unwrap())?; write!(w, " --{}", admin_arg.get_long().unwrap())?;
} }
@@ -619,6 +695,37 @@ impl<'a> UserCreationRequest<'a> {
} }
} }
impl std::fmt::Display for UserCreationRequest<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let username = &self.username;
let mxid = self.homeserver.mxid(username);
writeln!(f, "Username: {username}")?;
writeln!(f, "Matrix ID: {mxid}")?;
if let Some(display_name) = &self.display_name {
writeln!(f, "Display name: {display_name}")?;
}
if self.hashed_password.is_some() {
writeln!(f, "Password: <SET>")?;
}
for (provider, subject) in &self.upstream_provider_mappings {
let provider = provider
.human_name
.clone()
.unwrap_or_else(|| provider.id.to_string());
writeln!(f, "Upstream account: {provider} => {subject}")?;
}
for email in &self.emails {
writeln!(f, "Email: {email}")?;
}
writeln!(f, "Can request admin: {}", self.admin)?;
Ok(())
}
}
async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>( async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
repo: &mut dyn RepositoryAccess<Error = E>, repo: &mut dyn RepositoryAccess<Error = E>,
rng: &mut (dyn RngCore + Send), rng: &mut (dyn RngCore + Send),
@@ -630,6 +737,7 @@ async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
upstream_provider_mappings, upstream_provider_mappings,
display_name, display_name,
admin, admin,
..
}: UserCreationRequest<'a>, }: UserCreationRequest<'a>,
) -> Result<User, E> { ) -> Result<User, E> {
let mut user = repo.user().add(rng, clock, username).await?; let mut user = repo.user().add(rng, clock, username).await?;