1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Prompt for all parameters interactively

This commit is contained in:
Quentin Gliech
2024-04-17 16:45:35 +02:00
parent 8c402a1f50
commit 4d1b6aeded
3 changed files with 434 additions and 167 deletions

11
Cargo.lock generated
View File

@@ -1500,6 +1500,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de"
dependencies = [ dependencies = [
"console", "console",
"fuzzy-matcher",
"shell-words", "shell-words",
"tempfile", "tempfile",
"thiserror", "thiserror",
@@ -1967,6 +1968,15 @@ dependencies = [
"slab", "slab",
] ]
[[package]]
name = "fuzzy-matcher"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54614a3312934d066701a80f20f15fa3b56d67ac7722b39eea5b4c9dd1d66c94"
dependencies = [
"thread_local",
]
[[package]] [[package]]
name = "generic-array" name = "generic-array"
version = "0.14.7" version = "0.14.7"
@@ -3036,6 +3046,7 @@ dependencies = [
"axum", "axum",
"camino", "camino",
"clap", "clap",
"console",
"dialoguer", "dialoguer",
"dotenvy", "dotenvy",
"figment", "figment",

View File

@@ -16,7 +16,8 @@ anyhow.workspace = true
axum = "0.6.20" axum = "0.6.20"
camino.workspace = true camino.workspace = true
clap.workspace = true clap.workspace = true
dialoguer = "0.11.0" console = "0.15.8"
dialoguer = { version = "0.11.0", features = ["fuzzy-select"] }
dotenvy = "0.15.7" dotenvy = "0.15.7"
figment.workspace = true figment.workspace = true
httpdate = "1.0.3" httpdate = "1.0.3"

View File

@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::collections::HashMap; use std::collections::BTreeMap;
use anyhow::Context; use anyhow::Context;
use clap::{ArgAction, CommandFactory, Parser}; use clap::{ArgAction, CommandFactory, Parser};
use dialoguer::{Confirm, Input}; use console::{pad_str, style, Alignment, Style, Term};
use dialoguer::{theme::ColorfulTheme, Confirm, FuzzySelect, Input, Password};
use figment::Figment; use figment::Figment;
use mas_config::{ConfigurationSection, DatabaseConfig, MatrixConfig, PasswordsConfig}; use mas_config::{ConfigurationSection, DatabaseConfig, MatrixConfig, PasswordsConfig};
use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User}; use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User};
@@ -114,9 +115,13 @@ enum Subcommand {
}, },
/// Register a user /// Register a user
///
/// This will interactively prompt for the user's attributes unless the
/// `--yes` flag is set. It bypasses any policy check on the password,
/// email, etc.
RegisterUser { RegisterUser {
/// Username to register /// Username to register
#[arg(short, long, help_heading = USER_ATTRIBUTES_HEADING)] #[arg(help_heading = USER_ATTRIBUTES_HEADING, required_if_eq("yes", "true"))]
username: Option<String>, username: Option<String>,
/// Password to set /// Password to set
@@ -129,7 +134,7 @@ enum Subcommand {
/// Upstream OAuth 2.0 provider mapping to add /// Upstream OAuth 2.0 provider mapping to add
#[arg( #[arg(
short = 'M', short = 'm',
long = "upstream-provider-mapping", long = "upstream-provider-mapping",
value_parser = parse_upstream_provider_mapping, value_parser = parse_upstream_provider_mapping,
action = ArgAction::Append, action = ArgAction::Append,
@@ -139,11 +144,19 @@ enum Subcommand {
upstream_provider_mappings: Vec<UpstreamProviderMapping>, upstream_provider_mappings: Vec<UpstreamProviderMapping>,
/// Make the user an admin /// Make the user an admin
#[arg(long, action = ArgAction::SetTrue, help_heading = USER_ATTRIBUTES_HEADING)] #[arg(short, long, action = ArgAction::SetTrue, group = "admin-flag", help_heading = USER_ATTRIBUTES_HEADING)]
admin: bool, admin: bool,
/// Make the user not an admin
#[arg(short = 'A', long, action = ArgAction::SetTrue, group = "admin-flag", help_heading = USER_ATTRIBUTES_HEADING)]
no_admin: bool,
// Don't ask questions, just do it
#[arg(short, long, action = ArgAction::SetTrue)]
yes: bool,
/// Set the user's display name /// Set the user's display name
#[arg(short = 'D', long, help_heading = USER_ATTRIBUTES_HEADING)] #[arg(short, long, help_heading = USER_ATTRIBUTES_HEADING)]
display_name: Option<String>, display_name: Option<String>,
}, },
} }
@@ -479,7 +492,9 @@ impl Options {
emails, emails,
upstream_provider_mappings, upstream_provider_mappings,
admin, admin,
no_admin,
display_name, display_name,
yes,
} => { } => {
let http_client_factory = HttpClientFactory::new(); let http_client_factory = HttpClientFactory::new();
let password_config = PasswordsConfig::extract(figment)?; let password_config = PasswordsConfig::extract(figment)?;
@@ -497,50 +512,16 @@ impl Options {
let txn = conn.begin().await?; let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn); let mut repo = PgRepository::from_conn(txn);
// Load all the providers we need // If the username is provided, check if it's available and normalize it.
let mut upstream_providers = HashMap::new();
for mapping in &upstream_provider_mappings {
if upstream_providers.contains_key(&mapping.upstream_provider_id) {
continue;
}
let provider = repo
.upstream_oauth_provider()
.lookup(mapping.upstream_provider_id)
.await?
.context("Upstream provider not found")?;
upstream_providers.insert(provider.id, provider);
}
// TODO: prompt for link?
let upstream_provider_mappings = upstream_provider_mappings
.into_iter()
.map(|mapping| {
(
&upstream_providers[&mapping.upstream_provider_id],
mapping.subject,
)
})
.collect();
// Hash the password if it's provided
let hashed_password = if let Some(password) = password {
let password = password.into_bytes().into();
Some(password_manager.hash(&mut rng, password).await?)
} else {
// TODO: prompt for password if not provided and needed
None
};
let localpart = if let Some(username) = username { let localpart = if let Some(username) = username {
check_and_normalize_username(&username, &mut repo, &homeserver) check_and_normalize_username(&username, &mut repo, &homeserver)
.await? .await?
.to_owned() .to_owned()
} else { } else {
// Else we prompt for one until we get a valid one.
loop { loop {
let username = tokio::task::spawn_blocking(|| { let username = tokio::task::spawn_blocking(|| {
Input::<String>::new() Input::<String>::with_theme(&ColorfulTheme::default())
.with_prompt("Username") .with_prompt("Username")
.interact_text() .interact_text()
}) })
@@ -556,33 +537,189 @@ impl Options {
} }
}; };
let req = UserCreationRequest { // Load all the upstream providers
homeserver: &homeserver, let upstream_providers: BTreeMap<_, _> = repo
.upstream_oauth_provider()
.all_enabled()
.await?
.into_iter()
.map(|provider| (provider.id, provider))
.collect();
let upstream_provider_mappings = upstream_provider_mappings
.into_iter()
.map(|mapping| {
(
&upstream_providers[&mapping.upstream_provider_id],
mapping.subject,
)
})
.collect();
let admin = match (admin, no_admin) {
(false, false) => None,
(true, false) => Some(true),
(false, true) => Some(false),
_ => unreachable!("This should be handled by the clap group"),
};
// Hash the password if it's provided
let hashed_password = if let Some(password) = password {
let password = password.into_bytes().into();
Some(password_manager.hash(&mut rng, password).await?)
} else {
None
};
let mut req = UserCreationRequest {
username: localpart, username: localpart,
hashed_password, hashed_password,
// TODO: prompt for emails
emails, emails,
upstream_provider_mappings, upstream_provider_mappings,
// TODO: prompt for display name
display_name, display_name,
// TODO: prompt for admin
admin, admin,
}; };
info!("Do you want to register this user?\n{req}"); let term = Term::buffered_stdout();
loop {
req.show(&term, &homeserver)?;
let confirmation = tokio::task::spawn_blocking(|| { // If we're in `yes` mode, we don't prompt for actions
Confirm::new().with_prompt("Confirm?").interact() if yes {
break;
}
term.write_line(&format!(
"\n{msg}:\n\n {cmd}\n",
msg = style("Non-interactive equivalent to create this user").bold(),
cmd = style(UserCreationCommand(&req)).underlined(),
))?;
term.flush()?;
let action = req
.prompt_action(
password_manager.is_enabled(),
!upstream_providers.is_empty(),
)
.await?
.context("Aborted")?;
match action {
Action::CreateUser => break,
Action::ChangeUsername => {
req.username = loop {
let current_username = req.username.clone();
let username = tokio::task::spawn_blocking(|| {
Input::<String>::with_theme(&ColorfulTheme::default())
.with_prompt("Username")
.with_initial_text(current_username)
.interact_text()
}) })
.await??; .await??;
match check_and_normalize_username(
&username,
&mut repo,
&homeserver,
)
.await
{
Ok(localpart) => break localpart.to_owned(),
Err(e) => {
warn!("Invalid username: {e}");
}
}
};
}
Action::SetPassword => {
let password = tokio::task::spawn_blocking(|| {
Password::with_theme(&ColorfulTheme::default())
.with_prompt("Password")
.with_confirmation("Confirm password", "Passwords mismatching")
.interact()
})
.await??;
let password = password.into_bytes().into();
req.hashed_password =
Some(password_manager.hash(&mut rng, password).await?);
}
Action::SetDisplayName => {
let display_name = tokio::task::spawn_blocking(|| {
Input::<String>::with_theme(&ColorfulTheme::default())
.with_prompt("Display name")
.interact()
})
.await??;
req.display_name = Some(display_name);
}
Action::AddEmail => {
let email = tokio::task::spawn_blocking(|| {
Input::<Address>::with_theme(&ColorfulTheme::default())
.with_prompt("Email")
.interact_text()
})
.await??;
req.emails.push(email);
}
Action::SetAdmin => {
let admin = tokio::task::spawn_blocking(|| {
Confirm::with_theme(&ColorfulTheme::default())
.with_prompt("Make user admin?")
.interact()
})
.await??;
req.admin = Some(admin);
}
Action::AddUpstreamProviderMapping => {
let providers = upstream_providers.clone();
let provider_id = tokio::task::spawn_blocking(move || {
let providers: Vec<_> = providers.into_values().collect();
let human_readable_providers: Vec<_> =
providers.iter().map(HumanReadable).collect();
FuzzySelect::with_theme(&ColorfulTheme::default())
.with_prompt("Upstream provider")
.items(&human_readable_providers)
.default(0)
.interact()
.map(move |selected| providers[selected].id)
})
.await??;
let provider = &upstream_providers[&provider_id];
let subject = tokio::task::spawn_blocking(|| {
Input::<String>::with_theme(&ColorfulTheme::default())
.with_prompt("Subject")
.interact()
})
.await??;
req.upstream_provider_mappings.push((&provider, subject));
}
}
}
if req.emails.is_empty() {
warn!("No email address provided, user will need to add one");
}
let confirmation = if yes {
true
} else {
tokio::task::spawn_blocking(|| {
Confirm::with_theme(&ColorfulTheme::default())
.with_prompt("Confirm?")
.interact()
})
.await??
};
if confirmation { if confirmation {
let user = do_register(&mut repo, &mut rng, &clock, req).await?; let user = req.do_register(&mut repo, &mut rng, &clock).await?;
repo.into_inner().commit().await?; repo.into_inner().commit().await?;
info!(%user.id, "User registered"); info!(%user.id, "User registered");
} else { } else {
let cmd = UserCreationCommand(&req); warn!("Aborted");
info!("Aborted. {cmd}");
} }
Ok(()) Ok(())
@@ -591,6 +728,21 @@ impl Options {
} }
} }
/// A wrapper to display some objects differently
#[derive(Debug, Clone, Copy)]
struct HumanReadable<T>(T);
impl std::fmt::Display for HumanReadable<&UpstreamOAuthProvider> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let provider = self.0;
if let Some(human_name) = &provider.human_name {
write!(f, "{} ({})", human_name, provider.id)
} else {
write!(f, "{} ({})", provider.issuer, provider.id)
}
}
}
async fn check_and_normalize_username<'a>( async fn check_and_normalize_username<'a>(
localpart_or_mxid: &'a str, localpart_or_mxid: &'a str,
repo: &mut dyn RepositoryAccess<Error = DatabaseError>, repo: &mut dyn RepositoryAccess<Error = DatabaseError>,
@@ -621,125 +773,138 @@ async fn check_and_normalize_username<'a>(
} }
struct UserCreationRequest<'a> { struct UserCreationRequest<'a> {
homeserver: &'a SynapseConnection,
username: String, username: String,
hashed_password: Option<(u16, String)>, hashed_password: Option<(u16, String)>,
emails: Vec<Address>, emails: Vec<Address>,
upstream_provider_mappings: Vec<(&'a UpstreamOAuthProvider, String)>, upstream_provider_mappings: Vec<(&'a UpstreamOAuthProvider, String)>,
display_name: Option<String>, display_name: Option<String>,
admin: bool, admin: Option<bool>,
} }
struct UserCreationCommand<'a>(&'a UserCreationRequest<'a>); impl UserCreationRequest<'_> {
// Get a list of the possible actions
fn possible_actions(
&self,
has_password_auth: bool,
has_upstream_providers: bool,
) -> Vec<Action> {
let mut actions = vec![Action::CreateUser, Action::ChangeUsername, Action::AddEmail];
impl std::fmt::Display for UserCreationCommand<'_> { if has_password_auth && self.hashed_password.is_none() {
fn fmt(&self, w: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { actions.push(Action::SetPassword);
let command = super::Options::command();
let manage = command.find_subcommand("manage").unwrap();
let register_user = manage.find_subcommand("register-user").unwrap();
let username_arg = &register_user[&clap::Id::from("username")];
let password_arg = &register_user[&clap::Id::from("password")];
let email_arg = &register_user[&clap::Id::from("emails")];
let upstream_provider_mapping_arg =
&register_user[&clap::Id::from("upstream_provider_mappings")];
let display_name_arg = &register_user[&clap::Id::from("display_name")];
let admin_arg = &register_user[&clap::Id::from("admin")];
write!(
w,
"{} {} {}",
command.get_name(),
manage.get_name(),
register_user.get_name()
)?;
write!(
w,
" --{} {:?}",
username_arg.get_long().unwrap(),
self.0.username
)?;
for email in &self.0.emails {
let email: &str = email.as_ref();
write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?;
} }
if let Some(display_name) = &self.0.display_name { if has_upstream_providers {
write!( actions.push(Action::AddUpstreamProviderMapping);
w,
" --{} {:?}",
display_name_arg.get_long().unwrap(),
display_name
)?;
} }
if self.0.hashed_password.is_some() { if self.admin.is_none() {
write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?; actions.push(Action::SetAdmin);
} }
for (provider, subject) in &self.0.upstream_provider_mappings { if self.display_name.is_none() {
let mapping = format!("{}:{}", provider.id, subject); actions.push(Action::SetDisplayName);
write!(
w,
" --{} {mapping:?}",
upstream_provider_mapping_arg.get_long().unwrap(),
)?;
} }
if self.0.admin { actions
write!(w, " --{}", admin_arg.get_long().unwrap())?;
} }
Ok(()) /// Prompt for the next action
} async fn prompt_action(
&self,
has_password_auth: bool,
has_upstream_providers: bool,
) -> anyhow::Result<Option<Action>> {
let actions = self.possible_actions(has_password_auth, has_upstream_providers);
tokio::task::spawn_blocking(move || {
let index = FuzzySelect::with_theme(&ColorfulTheme::default())
.with_prompt("What do you want to do next? (<Esc> to abort)")
.items(&actions)
.default(0)
.interact_opt()?;
Ok(index.map(|index| actions[index]))
})
.await?
} }
impl std::fmt::Display for UserCreationRequest<'_> { /// Show the user creation request in a human-readable format
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn show(&self, term: &Term, homeserver: &SynapseConnection) -> std::io::Result<()> {
let value_style = Style::new().green();
let key_style = Style::new().bold();
let warning_style = Style::new().italic().red().bright();
let username = &self.username; let username = &self.username;
let mxid = self.homeserver.mxid(username); let mxid = homeserver.mxid(username);
writeln!(f, "Username: {username}")?;
writeln!(f, "Matrix ID: {mxid}")?; term.write_line(&style("User attributes").bold().underlined().to_string())?;
if let Some(display_name) = &self.display_name {
writeln!(f, "Display name: {display_name}")?; macro_rules! display {
($key:expr, $value:expr) => {
term.write_line(&format!(
"{key}: {value}",
key = key_style.apply_to(pad_str($key, 17, Alignment::Right, None)),
value = value_style.apply_to($value)
))?;
};
} }
display!("Username", username);
display!("Matrix ID", mxid);
if let Some(display_name) = &self.display_name {
display!("Display name", display_name);
}
if self.hashed_password.is_some() { if self.hashed_password.is_some() {
writeln!(f, "Password: <SET>")?; display!("Password", "********");
} }
for (provider, subject) in &self.upstream_provider_mappings { for (provider, subject) in &self.upstream_provider_mappings {
let provider = provider let provider = HumanReadable(*provider);
.human_name display!("Upstream account", format!("{provider} : {subject:?}"));
.clone()
.unwrap_or_else(|| provider.id.to_string());
writeln!(f, "Upstream account: {provider} => {subject}")?;
} }
for email in &self.emails { for email in &self.emails {
writeln!(f, "Email: {email}")?; display!("Email", email);
} }
writeln!(f, "Can request admin: {}", self.admin)?; if self.emails.is_empty() {
term.write_line(
&warning_style
.apply_to("No email address provided, user will be prompted to add one")
.to_string(),
)?;
}
if self.hashed_password.is_none() && self.upstream_provider_mappings.is_empty() {
term.write_line(
&warning_style.apply_to("No password or upstream provider mapping provided, user will not be able to log in")
.to_string(),
)?;
}
if let Some(admin) = self.admin {
display!("Can request admin", admin);
}
term.flush()?;
Ok(()) Ok(())
} }
}
async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>( /// Submit the user creation request
async fn do_register<E: std::error::Error + Send + Sync + 'static>(
self,
repo: &mut dyn RepositoryAccess<Error = E>, repo: &mut dyn RepositoryAccess<Error = E>,
rng: &mut (dyn RngCore + Send), rng: &mut (dyn RngCore + Send),
clock: &dyn Clock, clock: &dyn Clock,
UserCreationRequest { ) -> Result<User, E> {
let Self {
username, username,
hashed_password, hashed_password,
emails, emails,
upstream_provider_mappings, upstream_provider_mappings,
display_name, display_name,
admin, admin,
.. } = self;
}: UserCreationRequest<'a>,
) -> Result<User, E> {
let mut user = repo.user().add(rng, clock, username).await?; let mut user = repo.user().add(rng, clock, username).await?;
if let Some((version, hashed_password)) = hashed_password { if let Some((version, hashed_password)) = hashed_password {
@@ -776,8 +941,8 @@ async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
.await?; .await?;
} }
if admin { if let Some(admin) = admin {
user = repo.user().set_can_request_admin(user, true).await?; user = repo.user().set_can_request_admin(user, admin).await?;
} }
let mut provision_job = ProvisionUserJob::new(&user); let mut provision_job = ProvisionUserJob::new(&user);
@@ -789,3 +954,93 @@ async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
Ok(user) Ok(user)
} }
}
#[derive(Debug, Clone, Copy)]
enum Action {
CreateUser,
ChangeUsername,
SetPassword,
SetDisplayName,
AddEmail,
SetAdmin,
AddUpstreamProviderMapping,
}
impl std::fmt::Display for Action {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Action::CreateUser => write!(f, "Create the user"),
Action::ChangeUsername => write!(f, "Change the username"),
Action::SetPassword => write!(f, "Set a password"),
Action::AddEmail => write!(f, "Add email"),
Action::SetDisplayName => write!(f, "Set a display name"),
Action::SetAdmin => write!(f, "Set the admin status"),
Action::AddUpstreamProviderMapping => write!(f, "Add upstream provider mapping"),
}
}
}
/// A wrapper to display the user creation request as a command
struct UserCreationCommand<'a>(&'a UserCreationRequest<'a>);
impl std::fmt::Display for UserCreationCommand<'_> {
fn fmt(&self, w: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let command = super::Options::command();
let manage = command.find_subcommand("manage").unwrap();
let register_user = manage.find_subcommand("register-user").unwrap();
let yes_arg = &register_user[&clap::Id::from("yes")];
let password_arg = &register_user[&clap::Id::from("password")];
let email_arg = &register_user[&clap::Id::from("emails")];
let upstream_provider_mapping_arg =
&register_user[&clap::Id::from("upstream_provider_mappings")];
let display_name_arg = &register_user[&clap::Id::from("display_name")];
let admin_arg = &register_user[&clap::Id::from("admin")];
let no_admin_arg = &register_user[&clap::Id::from("no_admin")];
write!(
w,
"{} {} {} --{} {}",
command.get_name(),
manage.get_name(),
register_user.get_name(),
yes_arg.get_long().unwrap(),
self.0.username,
)?;
for email in &self.0.emails {
let email: &str = email.as_ref();
write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?;
}
if let Some(display_name) = &self.0.display_name {
write!(
w,
" --{} {:?}",
display_name_arg.get_long().unwrap(),
display_name
)?;
}
if self.0.hashed_password.is_some() {
write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?;
}
for (provider, subject) in &self.0.upstream_provider_mappings {
let mapping = format!("{}:{}", provider.id, subject);
write!(
w,
" --{} {mapping:?}",
upstream_provider_mapping_arg.get_long().unwrap(),
)?;
}
match self.0.admin {
Some(true) => write!(w, " --{}", admin_arg.get_long().unwrap())?,
Some(false) => write!(w, " --{}", no_admin_arg.get_long().unwrap())?,
None => {}
}
Ok(())
}
}