1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Add a manage register-user utility to the CLI

This commit is contained in:
Quentin Gliech
2024-04-16 16:20:07 +02:00
parent e179dc6b2b
commit 1cb48b8026
2 changed files with 273 additions and 5 deletions

View File

@@ -1,4 +1,4 @@
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. // Copyright 2021-2024 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@@ -12,24 +12,46 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::HashMap, io::Write};
use anyhow::Context; use anyhow::Context;
use clap::Parser; use clap::{ArgAction, CommandFactory, Parser};
use figment::Figment; use figment::Figment;
use mas_config::{ConfigurationSection, DatabaseConfig, PasswordsConfig}; use mas_config::{ConfigurationSection, DatabaseConfig, PasswordsConfig};
use mas_data_model::{Device, TokenType}; use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User};
use mas_email::Address;
use mas_storage::{ use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository}, compat::{CompatAccessTokenRepository, CompatSessionRepository},
job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob}, job::{DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob},
user::{UserEmailRepository, UserPasswordRepository, UserRepository}, user::{UserEmailRepository, UserPasswordRepository, UserRepository},
RepositoryAccess, SystemClock, Clock, RepositoryAccess, SystemClock,
}; };
use mas_storage_pg::PgRepository; use mas_storage_pg::PgRepository;
use rand::SeedableRng; use rand::{RngCore, SeedableRng};
use sqlx::{types::Uuid, Acquire}; use sqlx::{types::Uuid, Acquire};
use tracing::{info, info_span, warn}; use tracing::{info, info_span, warn};
use crate::util::{database_connection_from_config, password_manager_from_config}; use crate::util::{database_connection_from_config, password_manager_from_config};
const USER_ATTRIBUTES_HEADING: &str = "User attributes";
#[derive(Debug, Clone)]
struct UpstreamProviderMapping {
upstream_provider_id: Ulid,
subject: String,
}
fn parse_upstream_provider_mapping(s: &str) -> Result<UpstreamProviderMapping, anyhow::Error> {
let (id, subject) = s.split_once(':').context("Invalid format")?;
let upstream_provider_id = id.parse().context("Invalid upstream provider ID")?;
let subject = subject.to_owned();
Ok(UpstreamProviderMapping {
upstream_provider_id,
subject,
})
}
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
pub(super) struct Options { pub(super) struct Options {
#[command(subcommand)] #[command(subcommand)]
@@ -86,6 +108,40 @@ enum Subcommand {
/// User to unlock /// User to unlock
username: String, username: String,
}, },
/// Register a user
RegisterUser {
/// Username to register
#[arg(short, long, help_heading = USER_ATTRIBUTES_HEADING)]
username: Option<String>,
/// Password to set
#[arg(short, long, help_heading = USER_ATTRIBUTES_HEADING)]
password: Option<String>,
/// Email to add
#[arg(short, long = "email", action = ArgAction::Append, help_heading = USER_ATTRIBUTES_HEADING)]
emails: Vec<Address>,
/// Upstream OAuth 2.0 provider mapping to add
#[arg(
short = 'M',
long = "upstream-provider-mapping",
value_parser = parse_upstream_provider_mapping,
action = ArgAction::Append,
value_name = "UPSTREAM_PROVIDER_ID:SUBJECT",
help_heading = USER_ATTRIBUTES_HEADING
)]
upstream_provider_mappings: Vec<UpstreamProviderMapping>,
/// Make the user an admin
#[arg(long, action = ArgAction::SetTrue, help_heading = USER_ATTRIBUTES_HEADING)]
admin: bool,
/// Set the user's display name
#[arg(short = 'D', long, help_heading = USER_ATTRIBUTES_HEADING)]
display_name: Option<String>,
},
} }
impl Options { impl Options {
@@ -412,6 +468,216 @@ impl Options {
Ok(()) Ok(())
} }
SC::RegisterUser {
username,
password,
emails,
upstream_provider_mappings,
admin,
display_name,
} => {
let password_config = PasswordsConfig::extract(figment)?;
let database_config = DatabaseConfig::extract(figment)?;
let password_manager = password_manager_from_config(&password_config).await?;
let mut conn = database_connection_from_config(&database_config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
// Load all the providers we need
let mut upstream_providers = HashMap::new();
for mapping in &upstream_provider_mappings {
if upstream_providers.contains_key(&mapping.upstream_provider_id) {
continue;
}
let provider = repo
.upstream_oauth_provider()
.lookup(mapping.upstream_provider_id)
.await?
.context("Upstream provider not found")?;
upstream_providers.insert(provider.id, provider);
}
let upstream_provider_mappings = upstream_provider_mappings
.into_iter()
.map(|mapping| {
(
&upstream_providers[&mapping.upstream_provider_id],
mapping.subject,
)
})
.collect();
// Hash the password if it's provided
let hashed_password = if let Some(password) = password {
let password = password.into_bytes().into();
Some(password_manager.hash(&mut rng, password).await?)
} else {
None
};
// TODO: prompt
let username = username.context("Username is required")?;
if repo.user().exists(&username).await? {
anyhow::bail!("User already exists");
}
let req = UserCreationRequest {
username,
hashed_password,
emails,
upstream_provider_mappings,
display_name,
admin,
};
req.command(&mut std::io::stdout())?;
//do_register(&mut repo, &mut rng, &clock, req).await?;
repo.into_inner().commit().await?;
Ok(())
}
} }
} }
} }
struct UserCreationRequest<'a> {
username: String,
hashed_password: Option<(u16, String)>,
emails: Vec<Address>,
upstream_provider_mappings: Vec<(&'a UpstreamOAuthProvider, String)>,
display_name: Option<String>,
admin: bool,
}
impl<'a> UserCreationRequest<'a> {
fn command<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
let command = super::Options::command();
let manage = command.find_subcommand("manage").unwrap();
let register_user = manage.find_subcommand("register-user").unwrap();
let username_arg = &register_user[&clap::Id::from("username")];
let password_arg = &register_user[&clap::Id::from("password")];
let email_arg = &register_user[&clap::Id::from("emails")];
let upstream_provider_mapping_arg =
&register_user[&clap::Id::from("upstream_provider_mappings")];
let display_name_arg = &register_user[&clap::Id::from("display_name")];
let admin_arg = &register_user[&clap::Id::from("admin")];
write!(
w,
"{} {} {}",
command.get_name(),
manage.get_name(),
register_user.get_name()
)?;
write!(
w,
" --{} {:?}",
username_arg.get_long().unwrap(),
self.username
)?;
for email in &self.emails {
let email: &str = email.as_ref();
write!(w, " --{} {email:?}", email_arg.get_long().unwrap())?;
}
if let Some(display_name) = &self.display_name {
write!(
w,
" --{} {:?}",
display_name_arg.get_long().unwrap(),
display_name
)?;
}
if self.hashed_password.is_some() {
write!(w, " --{} $PASSWORD", password_arg.get_long().unwrap())?;
}
for (provider, subject) in &self.upstream_provider_mappings {
let mapping = format!("{}:{}", provider.id, subject);
write!(
w,
" --{} {mapping:?}",
upstream_provider_mapping_arg.get_long().unwrap(),
)?;
}
if self.admin {
write!(w, " --{}", admin_arg.get_long().unwrap())?;
}
Ok(())
}
}
async fn do_register<'a, E: std::error::Error + Send + Sync + 'static>(
repo: &mut dyn RepositoryAccess<Error = E>,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
UserCreationRequest {
username,
hashed_password,
emails,
upstream_provider_mappings,
display_name,
admin,
}: UserCreationRequest<'a>,
) -> Result<User, E> {
let mut user = repo.user().add(rng, clock, username).await?;
if let Some((version, hashed_password)) = hashed_password {
repo.user_password()
.add(rng, clock, &user, version, hashed_password, None)
.await?;
}
for email in emails {
let user_email = repo
.user_email()
.add(rng, clock, &user, email.to_string())
.await?;
let user_email = repo
.user_email()
.mark_as_verified(clock, user_email)
.await?;
if user.primary_user_email_id.is_none() {
repo.user_email().set_as_primary(&user_email).await?;
user.primary_user_email_id = Some(user_email.id);
}
}
for (provider, subject) in upstream_provider_mappings {
let link = repo
.upstream_oauth_link()
.add(rng, clock, provider, subject)
.await?;
repo.upstream_oauth_link()
.associate_to_user(&link, &user)
.await?;
}
if admin {
user = repo.user().set_can_request_admin(user, true).await?;
}
let mut provision_job = ProvisionUserJob::new(&user);
if let Some(display_name) = display_name {
provision_job = provision_job.set_display_name(display_name);
}
repo.job().schedule_job(provision_job).await?;
Ok(user)
}

View File

@@ -28,6 +28,8 @@ pub(crate) mod users;
#[error("invalid state transition")] #[error("invalid state transition")]
pub struct InvalidTransitionError; pub struct InvalidTransitionError;
pub use ulid::Ulid;
pub use self::{ pub use self::{
compat::{ compat::{
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,