You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
Add a manage register-user
utility to the CLI
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@@ -12,24 +12,46 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::{collections::HashMap, io::Write};
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use clap::Parser;
|
use clap::{ArgAction, CommandFactory, Parser};
|
||||||
use figment::Figment;
|
use figment::Figment;
|
||||||
use mas_config::{ConfigurationSection, DatabaseConfig, PasswordsConfig};
|
use mas_config::{ConfigurationSection, DatabaseConfig, PasswordsConfig};
|
||||||
use mas_data_model::{Device, TokenType};
|
use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User};
|
||||||
|
use mas_email::Address;
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
compat::{CompatAccessTokenRepository, CompatSessionRepository},
|
compat::{CompatAccessTokenRepository, CompatSessionRepository},
|
||||||
job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
|
job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
|
||||||
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
|
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
|
||||||
RepositoryAccess, SystemClock,
|
Clock, RepositoryAccess, SystemClock,
|
||||||
};
|
};
|
||||||
use mas_storage_pg::PgRepository;
|
use mas_storage_pg::PgRepository;
|
||||||
use rand::SeedableRng;
|
use rand::{RngCore, SeedableRng};
|
||||||
use sqlx::{types::Uuid, Acquire};
|
use sqlx::{types::Uuid, Acquire};
|
||||||
use tracing::{info, info_span, warn};
|
use tracing::{info, info_span, warn};
|
||||||
|
|
||||||
use crate::util::{database_connection_from_config, password_manager_from_config};
|
use crate::util::{database_connection_from_config, password_manager_from_config};
|
||||||
|
|
||||||
|
const USER_ATTRIBUTES_HEADING: &str = "User attributes";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct UpstreamProviderMapping {
|
||||||
|
upstream_provider_id: Ulid,
|
||||||
|
subject: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_upstream_provider_mapping(s: &str) -> Result<UpstreamProviderMapping, anyhow::Error> {
|
||||||
|
let (id, subject) = s.split_once(':').context("Invalid format")?;
|
||||||
|
let upstream_provider_id = id.parse().context("Invalid upstream provider ID")?;
|
||||||
|
let subject = subject.to_owned();
|
||||||
|
|
||||||
|
Ok(UpstreamProviderMapping {
|
||||||
|
upstream_provider_id,
|
||||||
|
subject,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
pub(super) struct Options {
|
pub(super) struct Options {
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
@@ -86,6 +108,40 @@ enum Subcommand {
|
|||||||
/// User to unlock
|
/// User to unlock
|
||||||
username: String,
|
username: String,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// Register a user
|
||||||
|
RegisterUser {
|
||||||
|
/// Username to register
|
||||||
|
#[arg(short, long, help_heading = USER_ATTRIBUTES_HEADING)]
|
||||||
|
username: Option<String>,
|
||||||
|
|
||||||
|
/// Password to set
|
||||||
|
#[arg(short, long, help_heading = USER_ATTRIBUTES_HEADING)]
|
||||||
|
password: Option<String>,
|
||||||
|
|
||||||
|
/// Email to add
|
||||||
|
#[arg(short, long = "email", action = ArgAction::Append, help_heading = USER_ATTRIBUTES_HEADING)]
|
||||||
|
emails: Vec<Address>,
|
||||||
|
|
||||||
|
/// Upstream OAuth 2.0 provider mapping to add
|
||||||
|
#[arg(
|
||||||
|
short = 'M',
|
||||||
|
long = "upstream-provider-mapping",
|
||||||
|
value_parser = parse_upstream_provider_mapping,
|
||||||
|
action = ArgAction::Append,
|
||||||
|
value_name = "UPSTREAM_PROVIDER_ID:SUBJECT",
|
||||||
|
help_heading = USER_ATTRIBUTES_HEADING
|
||||||
|
)]
|
||||||
|
upstream_provider_mappings: Vec<UpstreamProviderMapping>,
|
||||||
|
|
||||||
|
/// Make the user an admin
|
||||||
|
#[arg(long, action = ArgAction::SetTrue, help_heading = USER_ATTRIBUTES_HEADING)]
|
||||||
|
admin: bool,
|
||||||
|
|
||||||
|
/// Set the user's display name
|
||||||
|
#[arg(short = 'D', long, help_heading = USER_ATTRIBUTES_HEADING)]
|
||||||
|
display_name: Option<String>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Options {
|
impl Options {
|
||||||
@@ -412,6 +468,216 @@ impl Options {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SC::RegisterUser {
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
emails,
|
||||||
|
upstream_provider_mappings,
|
||||||
|
admin,
|
||||||
|
display_name,
|
||||||
|
} => {
|
||||||
|
let password_config = PasswordsConfig::extract(figment)?;
|
||||||
|
let database_config = DatabaseConfig::extract(figment)?;
|
||||||
|
|
||||||
|
let password_manager = password_manager_from_config(&password_config).await?;
|
||||||
|
let mut conn = database_connection_from_config(&database_config).await?;
|
||||||
|
let txn = conn.begin().await?;
|
||||||
|
let mut repo = PgRepository::from_conn(txn);
|
||||||
|
|
||||||
|
// Load all the providers we need
|
||||||
|
let mut upstream_providers = HashMap::new();
|
||||||
|
for mapping in &upstream_provider_mappings {
|
||||||
|
if upstream_providers.contains_key(&mapping.upstream_provider_id) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let provider = repo
|
||||||
|
.upstream_oauth_provider()
|
||||||
|
.lookup(mapping.upstream_provider_id)
|
||||||
|
.await?
|
||||||
|
.context("Upstream provider not found")?;
|
||||||
|
|
||||||
|
upstream_providers.insert(provider.id, provider);
|
||||||
|
}
|
||||||
|
|
||||||
|
let upstream_provider_mappings = upstream_provider_mappings
|
||||||
|
.into_iter()
|
||||||
|
.map(|mapping| {
|
||||||
|
(
|
||||||
|
&upstream_providers[&mapping.upstream_provider_id],
|
||||||
|
mapping.subject,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Hash the password if it's provided
|
||||||
|
let hashed_password = if let Some(password) = password {
|
||||||
|
let password = password.into_bytes().into();
|
||||||
|
Some(password_manager.hash(&mut rng, password).await?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: prompt
|
||||||
|
let username = username.context("Username is required")?;
|
||||||
|
|
||||||
|
if repo.user().exists(&username).await? {
|
||||||
|
anyhow::bail!("User already exists");
|
||||||
|
}
|
||||||
|
|
||||||
|
let req = UserCreationRequest {
|
||||||
|
username,
|
||||||
|
hashed_password,
|
||||||
|
emails,
|
||||||
|
upstream_provider_mappings,
|
||||||
|
display_name,
|
||||||
|
admin,
|
||||||
|
};
|
||||||
|
|
||||||
|
req.command(&mut std::io::stdout())?;
|
||||||
|
|
||||||
|
//do_register(&mut repo, &mut rng, &clock, req).await?;
|
||||||
|
|
||||||
|
repo.into_inner().commit().await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct UserCreationRequest<'a> {
|
||||||
|
username: String,
|
||||||
|
hashed_password: Option<(u16, String)>,
|
||||||
|
emails: Vec<Address>,
|
||||||
|
upstream_provider_mappings: Vec<(&'a UpstreamOAuthProvider, String)>,
|
||||||
|
display_name: Option<String>,
|
||||||
|
admin: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> UserCreationRequest<'a> {
|
||||||
|
fn command<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
|
||||||
|
let command = super::Options::command();
|
||||||
|
let manage = command.find_subcommand("manage").unwrap();
|
||||||
|
let register_user = manage.find_subcommand("register-user").unwrap();
|
||||||
|
let username_arg = ®ister_user[&clap::Id::from("username")];
|
||||||
|
let password_arg = ®ister_user[&clap::Id::from("password")];
|
||||||
|
let email_arg = ®ister_user[&clap::Id::from("emails")];
|
||||||
|
let upstream_provider_mapping_arg =
|
||||||
|
®ister_user[&clap::Id::from("upstream_provider_mappings")];
|
||||||
|
let display_name_arg = ®ister_user[&clap::Id::from("display_name")];
|
||||||
|
let admin_arg = ®ister_user[&clap::Id::from("admin")];
|
||||||
|
|
||||||
|
write!(
|
||||||
|
w,
|
||||||
|
"{} {} {}",
|
||||||
|
command.get_name(),
|
||||||
|
manage.get_name(),
|
||||||
|
register_user.get_name()
|
||||||
|
)?;
|
||||||
|
|
||||||
|
write!(
|
||||||
|
w,
|
||||||
|
" --{} {:?}",
|
||||||
|
username_arg.get_long().unwrap(),
|
||||||
|
self.username
|
||||||
|
)?;
|
||||||
|
|
||||||
|
for email in &self.emails {
|
||||||
|
let email: &str = email.as_ref();
|
||||||
|
write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(display_name) = &self.display_name {
|
||||||
|
write!(
|
||||||
|
w,
|
||||||
|
" --{} {:?}",
|
||||||
|
display_name_arg.get_long().unwrap(),
|
||||||
|
display_name
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.hashed_password.is_some() {
|
||||||
|
write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (provider, subject) in &self.upstream_provider_mappings {
|
||||||
|
let mapping = format!("{}:{}", provider.id, subject);
|
||||||
|
write!(
|
||||||
|
w,
|
||||||
|
" --{} {mapping:?}",
|
||||||
|
upstream_provider_mapping_arg.get_long().unwrap(),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.admin {
|
||||||
|
write!(w, " --{}", admin_arg.get_long().unwrap())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
|
||||||
|
repo: &mut dyn RepositoryAccess<Error = E>,
|
||||||
|
rng: &mut (dyn RngCore + Send),
|
||||||
|
clock: &dyn Clock,
|
||||||
|
UserCreationRequest {
|
||||||
|
username,
|
||||||
|
hashed_password,
|
||||||
|
emails,
|
||||||
|
upstream_provider_mappings,
|
||||||
|
display_name,
|
||||||
|
admin,
|
||||||
|
}: UserCreationRequest<'a>,
|
||||||
|
) -> Result<User, E> {
|
||||||
|
let mut user = repo.user().add(rng, clock, username).await?;
|
||||||
|
|
||||||
|
if let Some((version, hashed_password)) = hashed_password {
|
||||||
|
repo.user_password()
|
||||||
|
.add(rng, clock, &user, version, hashed_password, None)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
for email in emails {
|
||||||
|
let user_email = repo
|
||||||
|
.user_email()
|
||||||
|
.add(rng, clock, &user, email.to_string())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let user_email = repo
|
||||||
|
.user_email()
|
||||||
|
.mark_as_verified(clock, user_email)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if user.primary_user_email_id.is_none() {
|
||||||
|
repo.user_email().set_as_primary(&user_email).await?;
|
||||||
|
user.primary_user_email_id = Some(user_email.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (provider, subject) in upstream_provider_mappings {
|
||||||
|
let link = repo
|
||||||
|
.upstream_oauth_link()
|
||||||
|
.add(rng, clock, provider, subject)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
repo.upstream_oauth_link()
|
||||||
|
.associate_to_user(&link, &user)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if admin {
|
||||||
|
user = repo.user().set_can_request_admin(user, true).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut provision_job = ProvisionUserJob::new(&user);
|
||||||
|
if let Some(display_name) = display_name {
|
||||||
|
provision_job = provision_job.set_display_name(display_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.job().schedule_job(provision_job).await?;
|
||||||
|
|
||||||
|
Ok(user)
|
||||||
|
}
|
||||||
|
@@ -28,6 +28,8 @@ pub(crate) mod users;
|
|||||||
#[error("invalid state transition")]
|
#[error("invalid state transition")]
|
||||||
pub struct InvalidTransitionError;
|
pub struct InvalidTransitionError;
|
||||||
|
|
||||||
|
pub use ulid::Ulid;
|
||||||
|
|
||||||
pub use self::{
|
pub use self::{
|
||||||
compat::{
|
compat::{
|
||||||
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
|
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
|
||||||
|
Reference in New Issue
Block a user