You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-06 06:02:40 +03:00
Prompt for username and confirm user creation
This commit is contained in:
@@ -12,21 +12,25 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashMap, io::Write};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::{ArgAction, CommandFactory, Parser};
|
||||
use dialoguer::{Confirm, Input};
|
||||
use figment::Figment;
|
||||
use mas_config::{ConfigurationSection, DatabaseConfig, PasswordsConfig};
|
||||
use mas_config::{ConfigurationSection, DatabaseConfig, MatrixConfig, PasswordsConfig};
|
||||
use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User};
|
||||
use mas_email::Address;
|
||||
use mas_handlers::HttpClientFactory;
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_matrix_synapse::SynapseConnection;
|
||||
use mas_storage::{
|
||||
compat::{CompatAccessTokenRepository, CompatSessionRepository},
|
||||
job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
|
||||
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
|
||||
Clock, RepositoryAccess, SystemClock,
|
||||
};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use mas_storage_pg::{DatabaseError, PgRepository};
|
||||
use rand::{RngCore, SeedableRng};
|
||||
use sqlx::{types::Uuid, Acquire};
|
||||
use tracing::{info, info_span, warn};
|
||||
@@ -477,10 +481,18 @@ impl Options {
|
||||
admin,
|
||||
display_name,
|
||||
} => {
|
||||
let http_client_factory = HttpClientFactory::new();
|
||||
let password_config = PasswordsConfig::extract(figment)?;
|
||||
let database_config = DatabaseConfig::extract(figment)?;
|
||||
let matrix_config = MatrixConfig::extract(figment)?;
|
||||
|
||||
let password_manager = password_manager_from_config(&password_config).await?;
|
||||
let homeserver = SynapseConnection::new(
|
||||
matrix_config.homeserver,
|
||||
matrix_config.endpoint,
|
||||
matrix_config.secret,
|
||||
http_client_factory,
|
||||
);
|
||||
let mut conn = database_connection_from_config(&database_config).await?;
|
||||
let txn = conn.begin().await?;
|
||||
let mut repo = PgRepository::from_conn(txn);
|
||||
@@ -501,6 +513,7 @@ impl Options {
|
||||
upstream_providers.insert(provider.id, provider);
|
||||
}
|
||||
|
||||
// TODO: prompt for link?
|
||||
let upstream_provider_mappings = upstream_provider_mappings
|
||||
.into_iter()
|
||||
.map(|mapping| {
|
||||
@@ -516,30 +529,61 @@ impl Options {
|
||||
let password = password.into_bytes().into();
|
||||
Some(password_manager.hash(&mut rng, password).await?)
|
||||
} else {
|
||||
// TODO: prompt for password if not provided and needed
|
||||
None
|
||||
};
|
||||
|
||||
// TODO: prompt
|
||||
let username = username.context("Username is required")?;
|
||||
let localpart = if let Some(username) = username {
|
||||
check_and_normalize_username(&username, &mut repo, &homeserver)
|
||||
.await?
|
||||
.to_owned()
|
||||
} else {
|
||||
loop {
|
||||
let username = tokio::task::spawn_blocking(|| {
|
||||
Input::<String>::new()
|
||||
.with_prompt("Username")
|
||||
.interact_text()
|
||||
})
|
||||
.await??;
|
||||
|
||||
if repo.user().exists(&username).await? {
|
||||
anyhow::bail!("User already exists");
|
||||
}
|
||||
match check_and_normalize_username(&username, &mut repo, &homeserver).await
|
||||
{
|
||||
Ok(localpart) => break localpart.to_owned(),
|
||||
Err(e) => {
|
||||
warn!("Invalid username: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let req = UserCreationRequest {
|
||||
username,
|
||||
homeserver: &homeserver,
|
||||
username: localpart,
|
||||
hashed_password,
|
||||
// TODO: prompt for emails
|
||||
emails,
|
||||
upstream_provider_mappings,
|
||||
// TODO: prompt for display name
|
||||
display_name,
|
||||
// TODO: prompt for admin
|
||||
admin,
|
||||
};
|
||||
|
||||
req.command(&mut std::io::stdout())?;
|
||||
info!("Do you want to register this user?\n{req}");
|
||||
|
||||
//do_register(&mut repo, &mut rng, &clock, req).await?;
|
||||
let confirmation = tokio::task::spawn_blocking(|| {
|
||||
Confirm::new().with_prompt("Confirm?").interact()
|
||||
})
|
||||
.await??;
|
||||
|
||||
repo.into_inner().commit().await?;
|
||||
if confirmation {
|
||||
let user = do_register(&mut repo, &mut rng, &clock, req).await?;
|
||||
repo.into_inner().commit().await?;
|
||||
info!(%user.id, "User registered");
|
||||
} else {
|
||||
let cmd = UserCreationCommand(&req);
|
||||
info!("Aborted. {cmd}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -547,7 +591,37 @@ impl Options {
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_and_normalize_username<'a>(
|
||||
localpart_or_mxid: &'a str,
|
||||
repo: &mut dyn RepositoryAccess<Error = DatabaseError>,
|
||||
homeserver: &SynapseConnection,
|
||||
) -> anyhow::Result<&'a str> {
|
||||
// XXX: this is a very basic MXID to localpart conversion
|
||||
// Strip any leading '@'
|
||||
let mut localpart = localpart_or_mxid.trim_start_matches('@');
|
||||
|
||||
// Strip any trailing ':homeserver'
|
||||
if let Some(index) = localpart.find(':') {
|
||||
localpart = &localpart[..index];
|
||||
}
|
||||
|
||||
if localpart.is_empty() {
|
||||
return Err(anyhow::anyhow!("Username cannot be empty"));
|
||||
}
|
||||
|
||||
if repo.user().exists(localpart).await? {
|
||||
return Err(anyhow::anyhow!("User already exists"));
|
||||
}
|
||||
|
||||
if !homeserver.is_localpart_available(localpart).await? {
|
||||
return Err(anyhow::anyhow!("Username not available on homeserver"));
|
||||
}
|
||||
|
||||
Ok(localpart)
|
||||
}
|
||||
|
||||
struct UserCreationRequest<'a> {
|
||||
homeserver: &'a SynapseConnection,
|
||||
username: String,
|
||||
hashed_password: Option<(u16, String)>,
|
||||
emails: Vec<Address>,
|
||||
@@ -556,8 +630,10 @@ struct UserCreationRequest<'a> {
|
||||
admin: bool,
|
||||
}
|
||||
|
||||
impl<'a> UserCreationRequest<'a> {
|
||||
fn command<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
|
||||
struct UserCreationCommand<'a>(&'a UserCreationRequest<'a>);
|
||||
|
||||
impl std::fmt::Display for UserCreationCommand<'_> {
|
||||
fn fmt(&self, w: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let command = super::Options::command();
|
||||
let manage = command.find_subcommand("manage").unwrap();
|
||||
let register_user = manage.find_subcommand("register-user").unwrap();
|
||||
@@ -581,15 +657,15 @@ impl<'a> UserCreationRequest<'a> {
|
||||
w,
|
||||
" --{} {:?}",
|
||||
username_arg.get_long().unwrap(),
|
||||
self.username
|
||||
self.0.username
|
||||
)?;
|
||||
|
||||
for email in &self.emails {
|
||||
for email in &self.0.emails {
|
||||
let email: &str = email.as_ref();
|
||||
write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?;
|
||||
}
|
||||
|
||||
if let Some(display_name) = &self.display_name {
|
||||
if let Some(display_name) = &self.0.display_name {
|
||||
write!(
|
||||
w,
|
||||
" --{} {:?}",
|
||||
@@ -598,11 +674,11 @@ impl<'a> UserCreationRequest<'a> {
|
||||
)?;
|
||||
}
|
||||
|
||||
if self.hashed_password.is_some() {
|
||||
if self.0.hashed_password.is_some() {
|
||||
write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?;
|
||||
}
|
||||
|
||||
for (provider, subject) in &self.upstream_provider_mappings {
|
||||
for (provider, subject) in &self.0.upstream_provider_mappings {
|
||||
let mapping = format!("{}:{}", provider.id, subject);
|
||||
write!(
|
||||
w,
|
||||
@@ -611,7 +687,7 @@ impl<'a> UserCreationRequest<'a> {
|
||||
)?;
|
||||
}
|
||||
|
||||
if self.admin {
|
||||
if self.0.admin {
|
||||
write!(w, " --{}", admin_arg.get_long().unwrap())?;
|
||||
}
|
||||
|
||||
@@ -619,6 +695,37 @@ impl<'a> UserCreationRequest<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for UserCreationRequest<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let username = &self.username;
|
||||
let mxid = self.homeserver.mxid(username);
|
||||
writeln!(f, "Username: {username}")?;
|
||||
writeln!(f, "Matrix ID: {mxid}")?;
|
||||
if let Some(display_name) = &self.display_name {
|
||||
writeln!(f, "Display name: {display_name}")?;
|
||||
}
|
||||
if self.hashed_password.is_some() {
|
||||
writeln!(f, "Password: <SET>")?;
|
||||
}
|
||||
|
||||
for (provider, subject) in &self.upstream_provider_mappings {
|
||||
let provider = provider
|
||||
.human_name
|
||||
.clone()
|
||||
.unwrap_or_else(|| provider.id.to_string());
|
||||
writeln!(f, "Upstream account: {provider} => {subject}")?;
|
||||
}
|
||||
|
||||
for email in &self.emails {
|
||||
writeln!(f, "Email: {email}")?;
|
||||
}
|
||||
|
||||
writeln!(f, "Can request admin: {}", self.admin)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
|
||||
repo: &mut dyn RepositoryAccess<Error = E>,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
@@ -630,6 +737,7 @@ async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
|
||||
upstream_provider_mappings,
|
||||
display_name,
|
||||
admin,
|
||||
..
|
||||
}: UserCreationRequest<'a>,
|
||||
) -> Result<User, E> {
|
||||
let mut user = repo.user().add(rng, clock, username).await?;
|
||||
|
Reference in New Issue
Block a user