1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-06 06:02:40 +03:00

Prompt for username and confirm user creation

This commit is contained in:
Quentin Gliech
2024-04-16 17:21:47 +02:00
parent 1cb48b8026
commit 8c402a1f50
3 changed files with 150 additions and 20 deletions

View File

@@ -12,21 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashMap, io::Write};
use std::collections::HashMap;
use anyhow::Context;
use clap::{ArgAction, CommandFactory, Parser};
use dialoguer::{Confirm, Input};
use figment::Figment;
use mas_config::{ConfigurationSection, DatabaseConfig, PasswordsConfig};
use mas_config::{ConfigurationSection, DatabaseConfig, MatrixConfig, PasswordsConfig};
use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User};
use mas_email::Address;
use mas_handlers::HttpClientFactory;
use mas_matrix::HomeserverConnection;
use mas_matrix_synapse::SynapseConnection;
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository},
job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
Clock, RepositoryAccess, SystemClock,
};
use mas_storage_pg::PgRepository;
use mas_storage_pg::{DatabaseError, PgRepository};
use rand::{RngCore, SeedableRng};
use sqlx::{types::Uuid, Acquire};
use tracing::{info, info_span, warn};
@@ -477,10 +481,18 @@ impl Options {
admin,
display_name,
} => {
let http_client_factory = HttpClientFactory::new();
let password_config = PasswordsConfig::extract(figment)?;
let database_config = DatabaseConfig::extract(figment)?;
let matrix_config = MatrixConfig::extract(figment)?;
let password_manager = password_manager_from_config(&password_config).await?;
let homeserver = SynapseConnection::new(
matrix_config.homeserver,
matrix_config.endpoint,
matrix_config.secret,
http_client_factory,
);
let mut conn = database_connection_from_config(&database_config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
@@ -501,6 +513,7 @@ impl Options {
upstream_providers.insert(provider.id, provider);
}
// TODO: prompt for link?
let upstream_provider_mappings = upstream_provider_mappings
.into_iter()
.map(|mapping| {
@@ -516,30 +529,61 @@ impl Options {
let password = password.into_bytes().into();
Some(password_manager.hash(&mut rng, password).await?)
} else {
// TODO: prompt for password if not provided and needed
None
};
// TODO: prompt
let username = username.context("Username is required")?;
let localpart = if let Some(username) = username {
check_and_normalize_username(&username, &mut repo, &homeserver)
.await?
.to_owned()
} else {
loop {
let username = tokio::task::spawn_blocking(|| {
Input::<String>::new()
.with_prompt("Username")
.interact_text()
})
.await??;
if repo.user().exists(&username).await? {
anyhow::bail!("User already exists");
}
match check_and_normalize_username(&username, &mut repo, &homeserver).await
{
Ok(localpart) => break localpart.to_owned(),
Err(e) => {
warn!("Invalid username: {e}");
}
}
}
};
let req = UserCreationRequest {
username,
homeserver: &homeserver,
username: localpart,
hashed_password,
// TODO: prompt for emails
emails,
upstream_provider_mappings,
// TODO: prompt for display name
display_name,
// TODO: prompt for admin
admin,
};
req.command(&mut std::io::stdout())?;
info!("Do you want to register this user?\n{req}");
//do_register(&mut repo, &mut rng, &clock, req).await?;
let confirmation = tokio::task::spawn_blocking(|| {
Confirm::new().with_prompt("Confirm?").interact()
})
.await??;
repo.into_inner().commit().await?;
if confirmation {
let user = do_register(&mut repo, &mut rng, &clock, req).await?;
repo.into_inner().commit().await?;
info!(%user.id, "User registered");
} else {
let cmd = UserCreationCommand(&req);
info!("Aborted. {cmd}");
}
Ok(())
}
@@ -547,7 +591,37 @@ impl Options {
}
}
async fn check_and_normalize_username<'a>(
localpart_or_mxid: &'a str,
repo: &mut dyn RepositoryAccess<Error = DatabaseError>,
homeserver: &SynapseConnection,
) -> anyhow::Result<&'a str> {
// XXX: this is a very basic MXID to localpart conversion
// Strip any leading '@'
let mut localpart = localpart_or_mxid.trim_start_matches('@');
// Strip any trailing ':homeserver'
if let Some(index) = localpart.find(':') {
localpart = &localpart[..index];
}
if localpart.is_empty() {
return Err(anyhow::anyhow!("Username cannot be empty"));
}
if repo.user().exists(localpart).await? {
return Err(anyhow::anyhow!("User already exists"));
}
if !homeserver.is_localpart_available(localpart).await? {
return Err(anyhow::anyhow!("Username not available on homeserver"));
}
Ok(localpart)
}
struct UserCreationRequest<'a> {
homeserver: &'a SynapseConnection,
username: String,
hashed_password: Option<(u16, String)>,
emails: Vec<Address>,
@@ -556,8 +630,10 @@ struct UserCreationRequest<'a> {
admin: bool,
}
impl<'a> UserCreationRequest<'a> {
fn command<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
struct UserCreationCommand<'a>(&'a UserCreationRequest<'a>);
impl std::fmt::Display for UserCreationCommand<'_> {
fn fmt(&self, w: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let command = super::Options::command();
let manage = command.find_subcommand("manage").unwrap();
let register_user = manage.find_subcommand("register-user").unwrap();
@@ -581,15 +657,15 @@ impl<'a> UserCreationRequest<'a> {
w,
" --{} {:?}",
username_arg.get_long().unwrap(),
self.username
self.0.username
)?;
for email in &self.emails {
for email in &self.0.emails {
let email: &str = email.as_ref();
write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?;
}
if let Some(display_name) = &self.display_name {
if let Some(display_name) = &self.0.display_name {
write!(
w,
" --{} {:?}",
@@ -598,11 +674,11 @@ impl<'a> UserCreationRequest<'a> {
)?;
}
if self.hashed_password.is_some() {
if self.0.hashed_password.is_some() {
write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?;
}
for (provider, subject) in &self.upstream_provider_mappings {
for (provider, subject) in &self.0.upstream_provider_mappings {
let mapping = format!("{}:{}", provider.id, subject);
write!(
w,
@@ -611,7 +687,7 @@ impl<'a> UserCreationRequest<'a> {
)?;
}
if self.admin {
if self.0.admin {
write!(w, " --{}", admin_arg.get_long().unwrap())?;
}
@@ -619,6 +695,37 @@ impl<'a> UserCreationRequest<'a> {
}
}
impl std::fmt::Display for UserCreationRequest<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let username = &self.username;
let mxid = self.homeserver.mxid(username);
writeln!(f, "Username: {username}")?;
writeln!(f, "Matrix ID: {mxid}")?;
if let Some(display_name) = &self.display_name {
writeln!(f, "Display name: {display_name}")?;
}
if self.hashed_password.is_some() {
writeln!(f, "Password: <SET>")?;
}
for (provider, subject) in &self.upstream_provider_mappings {
let provider = provider
.human_name
.clone()
.unwrap_or_else(|| provider.id.to_string());
writeln!(f, "Upstream account: {provider} => {subject}")?;
}
for email in &self.emails {
writeln!(f, "Email: {email}")?;
}
writeln!(f, "Can request admin: {}", self.admin)?;
Ok(())
}
}
async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
repo: &mut dyn RepositoryAccess<Error = E>,
rng: &mut (dyn RngCore + Send),
@@ -630,6 +737,7 @@ async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
upstream_provider_mappings,
display_name,
admin,
..
}: UserCreationRequest<'a>,
) -> Result<User, E> {
let mut user = repo.user().add(rng, clock, username).await?;