You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-06 06:02:40 +03:00
Prompt for username and confirm user creation
This commit is contained in:
21
Cargo.lock
generated
21
Cargo.lock
generated
@@ -1029,6 +1029,7 @@ dependencies = [
|
|||||||
"encode_unicode",
|
"encode_unicode",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"libc",
|
"libc",
|
||||||
|
"unicode-width",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -1492,6 +1493,19 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dialoguer"
|
||||||
|
version = "0.11.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de"
|
||||||
|
dependencies = [
|
||||||
|
"console",
|
||||||
|
"shell-words",
|
||||||
|
"tempfile",
|
||||||
|
"thiserror",
|
||||||
|
"zeroize",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "digest"
|
name = "digest"
|
||||||
version = "0.10.7"
|
version = "0.10.7"
|
||||||
@@ -3022,6 +3036,7 @@ dependencies = [
|
|||||||
"axum",
|
"axum",
|
||||||
"camino",
|
"camino",
|
||||||
"clap",
|
"clap",
|
||||||
|
"dialoguer",
|
||||||
"dotenvy",
|
"dotenvy",
|
||||||
"figment",
|
"figment",
|
||||||
"httpdate",
|
"httpdate",
|
||||||
@@ -5455,6 +5470,12 @@ dependencies = [
|
|||||||
"lazy_static",
|
"lazy_static",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "shell-words"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "signal-hook-registry"
|
name = "signal-hook-registry"
|
||||||
version = "1.4.1"
|
version = "1.4.1"
|
||||||
|
@@ -16,6 +16,7 @@ anyhow.workspace = true
|
|||||||
axum = "0.6.20"
|
axum = "0.6.20"
|
||||||
camino.workspace = true
|
camino.workspace = true
|
||||||
clap.workspace = true
|
clap.workspace = true
|
||||||
|
dialoguer = "0.11.0"
|
||||||
dotenvy = "0.15.7"
|
dotenvy = "0.15.7"
|
||||||
figment.workspace = true
|
figment.workspace = true
|
||||||
httpdate = "1.0.3"
|
httpdate = "1.0.3"
|
||||||
|
@@ -12,21 +12,25 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::{collections::HashMap, io::Write};
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use clap::{ArgAction, CommandFactory, Parser};
|
use clap::{ArgAction, CommandFactory, Parser};
|
||||||
|
use dialoguer::{Confirm, Input};
|
||||||
use figment::Figment;
|
use figment::Figment;
|
||||||
use mas_config::{ConfigurationSection, DatabaseConfig, PasswordsConfig};
|
use mas_config::{ConfigurationSection, DatabaseConfig, MatrixConfig, PasswordsConfig};
|
||||||
use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User};
|
use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User};
|
||||||
use mas_email::Address;
|
use mas_email::Address;
|
||||||
|
use mas_handlers::HttpClientFactory;
|
||||||
|
use mas_matrix::HomeserverConnection;
|
||||||
|
use mas_matrix_synapse::SynapseConnection;
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
compat::{CompatAccessTokenRepository, CompatSessionRepository},
|
compat::{CompatAccessTokenRepository, CompatSessionRepository},
|
||||||
job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
|
job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
|
||||||
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
|
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
|
||||||
Clock, RepositoryAccess, SystemClock,
|
Clock, RepositoryAccess, SystemClock,
|
||||||
};
|
};
|
||||||
use mas_storage_pg::PgRepository;
|
use mas_storage_pg::{DatabaseError, PgRepository};
|
||||||
use rand::{RngCore, SeedableRng};
|
use rand::{RngCore, SeedableRng};
|
||||||
use sqlx::{types::Uuid, Acquire};
|
use sqlx::{types::Uuid, Acquire};
|
||||||
use tracing::{info, info_span, warn};
|
use tracing::{info, info_span, warn};
|
||||||
@@ -477,10 +481,18 @@ impl Options {
|
|||||||
admin,
|
admin,
|
||||||
display_name,
|
display_name,
|
||||||
} => {
|
} => {
|
||||||
|
let http_client_factory = HttpClientFactory::new();
|
||||||
let password_config = PasswordsConfig::extract(figment)?;
|
let password_config = PasswordsConfig::extract(figment)?;
|
||||||
let database_config = DatabaseConfig::extract(figment)?;
|
let database_config = DatabaseConfig::extract(figment)?;
|
||||||
|
let matrix_config = MatrixConfig::extract(figment)?;
|
||||||
|
|
||||||
let password_manager = password_manager_from_config(&password_config).await?;
|
let password_manager = password_manager_from_config(&password_config).await?;
|
||||||
|
let homeserver = SynapseConnection::new(
|
||||||
|
matrix_config.homeserver,
|
||||||
|
matrix_config.endpoint,
|
||||||
|
matrix_config.secret,
|
||||||
|
http_client_factory,
|
||||||
|
);
|
||||||
let mut conn = database_connection_from_config(&database_config).await?;
|
let mut conn = database_connection_from_config(&database_config).await?;
|
||||||
let txn = conn.begin().await?;
|
let txn = conn.begin().await?;
|
||||||
let mut repo = PgRepository::from_conn(txn);
|
let mut repo = PgRepository::from_conn(txn);
|
||||||
@@ -501,6 +513,7 @@ impl Options {
|
|||||||
upstream_providers.insert(provider.id, provider);
|
upstream_providers.insert(provider.id, provider);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: prompt for link?
|
||||||
let upstream_provider_mappings = upstream_provider_mappings
|
let upstream_provider_mappings = upstream_provider_mappings
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|mapping| {
|
.map(|mapping| {
|
||||||
@@ -516,30 +529,61 @@ impl Options {
|
|||||||
let password = password.into_bytes().into();
|
let password = password.into_bytes().into();
|
||||||
Some(password_manager.hash(&mut rng, password).await?)
|
Some(password_manager.hash(&mut rng, password).await?)
|
||||||
} else {
|
} else {
|
||||||
|
// TODO: prompt for password if not provided and needed
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: prompt
|
let localpart = if let Some(username) = username {
|
||||||
let username = username.context("Username is required")?;
|
check_and_normalize_username(&username, &mut repo, &homeserver)
|
||||||
|
.await?
|
||||||
|
.to_owned()
|
||||||
|
} else {
|
||||||
|
loop {
|
||||||
|
let username = tokio::task::spawn_blocking(|| {
|
||||||
|
Input::<String>::new()
|
||||||
|
.with_prompt("Username")
|
||||||
|
.interact_text()
|
||||||
|
})
|
||||||
|
.await??;
|
||||||
|
|
||||||
if repo.user().exists(&username).await? {
|
match check_and_normalize_username(&username, &mut repo, &homeserver).await
|
||||||
anyhow::bail!("User already exists");
|
{
|
||||||
}
|
Ok(localpart) => break localpart.to_owned(),
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Invalid username: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let req = UserCreationRequest {
|
let req = UserCreationRequest {
|
||||||
username,
|
homeserver: &homeserver,
|
||||||
|
username: localpart,
|
||||||
hashed_password,
|
hashed_password,
|
||||||
|
// TODO: prompt for emails
|
||||||
emails,
|
emails,
|
||||||
upstream_provider_mappings,
|
upstream_provider_mappings,
|
||||||
|
// TODO: prompt for display name
|
||||||
display_name,
|
display_name,
|
||||||
|
// TODO: prompt for admin
|
||||||
admin,
|
admin,
|
||||||
};
|
};
|
||||||
|
|
||||||
req.command(&mut std::io::stdout())?;
|
info!("Do you want to register this user?\n{req}");
|
||||||
|
|
||||||
//do_register(&mut repo, &mut rng, &clock, req).await?;
|
let confirmation = tokio::task::spawn_blocking(|| {
|
||||||
|
Confirm::new().with_prompt("Confirm?").interact()
|
||||||
|
})
|
||||||
|
.await??;
|
||||||
|
|
||||||
repo.into_inner().commit().await?;
|
if confirmation {
|
||||||
|
let user = do_register(&mut repo, &mut rng, &clock, req).await?;
|
||||||
|
repo.into_inner().commit().await?;
|
||||||
|
info!(%user.id, "User registered");
|
||||||
|
} else {
|
||||||
|
let cmd = UserCreationCommand(&req);
|
||||||
|
info!("Aborted. {cmd}");
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -547,7 +591,37 @@ impl Options {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn check_and_normalize_username<'a>(
|
||||||
|
localpart_or_mxid: &'a str,
|
||||||
|
repo: &mut dyn RepositoryAccess<Error = DatabaseError>,
|
||||||
|
homeserver: &SynapseConnection,
|
||||||
|
) -> anyhow::Result<&'a str> {
|
||||||
|
// XXX: this is a very basic MXID to localpart conversion
|
||||||
|
// Strip any leading '@'
|
||||||
|
let mut localpart = localpart_or_mxid.trim_start_matches('@');
|
||||||
|
|
||||||
|
// Strip any trailing ':homeserver'
|
||||||
|
if let Some(index) = localpart.find(':') {
|
||||||
|
localpart = &localpart[..index];
|
||||||
|
}
|
||||||
|
|
||||||
|
if localpart.is_empty() {
|
||||||
|
return Err(anyhow::anyhow!("Username cannot be empty"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if repo.user().exists(localpart).await? {
|
||||||
|
return Err(anyhow::anyhow!("User already exists"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !homeserver.is_localpart_available(localpart).await? {
|
||||||
|
return Err(anyhow::anyhow!("Username not available on homeserver"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(localpart)
|
||||||
|
}
|
||||||
|
|
||||||
struct UserCreationRequest<'a> {
|
struct UserCreationRequest<'a> {
|
||||||
|
homeserver: &'a SynapseConnection,
|
||||||
username: String,
|
username: String,
|
||||||
hashed_password: Option<(u16, String)>,
|
hashed_password: Option<(u16, String)>,
|
||||||
emails: Vec<Address>,
|
emails: Vec<Address>,
|
||||||
@@ -556,8 +630,10 @@ struct UserCreationRequest<'a> {
|
|||||||
admin: bool,
|
admin: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> UserCreationRequest<'a> {
|
struct UserCreationCommand<'a>(&'a UserCreationRequest<'a>);
|
||||||
fn command<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
|
|
||||||
|
impl std::fmt::Display for UserCreationCommand<'_> {
|
||||||
|
fn fmt(&self, w: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
let command = super::Options::command();
|
let command = super::Options::command();
|
||||||
let manage = command.find_subcommand("manage").unwrap();
|
let manage = command.find_subcommand("manage").unwrap();
|
||||||
let register_user = manage.find_subcommand("register-user").unwrap();
|
let register_user = manage.find_subcommand("register-user").unwrap();
|
||||||
@@ -581,15 +657,15 @@ impl<'a> UserCreationRequest<'a> {
|
|||||||
w,
|
w,
|
||||||
" --{} {:?}",
|
" --{} {:?}",
|
||||||
username_arg.get_long().unwrap(),
|
username_arg.get_long().unwrap(),
|
||||||
self.username
|
self.0.username
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
for email in &self.emails {
|
for email in &self.0.emails {
|
||||||
let email: &str = email.as_ref();
|
let email: &str = email.as_ref();
|
||||||
write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?;
|
write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(display_name) = &self.display_name {
|
if let Some(display_name) = &self.0.display_name {
|
||||||
write!(
|
write!(
|
||||||
w,
|
w,
|
||||||
" --{} {:?}",
|
" --{} {:?}",
|
||||||
@@ -598,11 +674,11 @@ impl<'a> UserCreationRequest<'a> {
|
|||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.hashed_password.is_some() {
|
if self.0.hashed_password.is_some() {
|
||||||
write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?;
|
write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (provider, subject) in &self.upstream_provider_mappings {
|
for (provider, subject) in &self.0.upstream_provider_mappings {
|
||||||
let mapping = format!("{}:{}", provider.id, subject);
|
let mapping = format!("{}:{}", provider.id, subject);
|
||||||
write!(
|
write!(
|
||||||
w,
|
w,
|
||||||
@@ -611,7 +687,7 @@ impl<'a> UserCreationRequest<'a> {
|
|||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.admin {
|
if self.0.admin {
|
||||||
write!(w, " --{}", admin_arg.get_long().unwrap())?;
|
write!(w, " --{}", admin_arg.get_long().unwrap())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -619,6 +695,37 @@ impl<'a> UserCreationRequest<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for UserCreationRequest<'_> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let username = &self.username;
|
||||||
|
let mxid = self.homeserver.mxid(username);
|
||||||
|
writeln!(f, "Username: {username}")?;
|
||||||
|
writeln!(f, "Matrix ID: {mxid}")?;
|
||||||
|
if let Some(display_name) = &self.display_name {
|
||||||
|
writeln!(f, "Display name: {display_name}")?;
|
||||||
|
}
|
||||||
|
if self.hashed_password.is_some() {
|
||||||
|
writeln!(f, "Password: <SET>")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (provider, subject) in &self.upstream_provider_mappings {
|
||||||
|
let provider = provider
|
||||||
|
.human_name
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| provider.id.to_string());
|
||||||
|
writeln!(f, "Upstream account: {provider} => {subject}")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
for email in &self.emails {
|
||||||
|
writeln!(f, "Email: {email}")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
writeln!(f, "Can request admin: {}", self.admin)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
|
async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
|
||||||
repo: &mut dyn RepositoryAccess<Error = E>,
|
repo: &mut dyn RepositoryAccess<Error = E>,
|
||||||
rng: &mut (dyn RngCore + Send),
|
rng: &mut (dyn RngCore + Send),
|
||||||
@@ -630,6 +737,7 @@ async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
|
|||||||
upstream_provider_mappings,
|
upstream_provider_mappings,
|
||||||
display_name,
|
display_name,
|
||||||
admin,
|
admin,
|
||||||
|
..
|
||||||
}: UserCreationRequest<'a>,
|
}: UserCreationRequest<'a>,
|
||||||
) -> Result<User, E> {
|
) -> Result<User, E> {
|
||||||
let mut user = repo.user().add(rng, clock, username).await?;
|
let mut user = repo.user().add(rng, clock, username).await?;
|
||||||
|
Reference in New Issue
Block a user