You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
CLI tool to sync the upstream IDPs with the config
This commit is contained in:
@ -16,13 +16,46 @@ use std::collections::HashSet;
|
||||
|
||||
use clap::Parser;
|
||||
use mas_config::{ConfigurationSection, RootConfig};
|
||||
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository, RepositoryAccess};
|
||||
use mas_storage::{
|
||||
upstream_oauth2::UpstreamOAuthProviderRepository, Repository, RepositoryAccess, SystemClock,
|
||||
};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use rand::SeedableRng;
|
||||
use tracing::{info, info_span, warn};
|
||||
|
||||
use crate::util::database_from_config;
|
||||
|
||||
fn map_import_preference(
|
||||
config: &mas_config::UpstreamOAuth2ImportPreference,
|
||||
) -> mas_data_model::UpstreamOAuthProviderImportPreference {
|
||||
let action = match &config.action {
|
||||
mas_config::UpstreamOAuth2ImportAction::Ignore => {
|
||||
mas_data_model::UpstreamOAuthProviderImportAction::Ignore
|
||||
}
|
||||
mas_config::UpstreamOAuth2ImportAction::Suggest => {
|
||||
mas_data_model::UpstreamOAuthProviderImportAction::Suggest
|
||||
}
|
||||
mas_config::UpstreamOAuth2ImportAction::Force => {
|
||||
mas_data_model::UpstreamOAuthProviderImportAction::Force
|
||||
}
|
||||
mas_config::UpstreamOAuth2ImportAction::Require => {
|
||||
mas_data_model::UpstreamOAuthProviderImportAction::Require
|
||||
}
|
||||
};
|
||||
|
||||
mas_data_model::UpstreamOAuthProviderImportPreference { action }
|
||||
}
|
||||
|
||||
fn map_claims_imports(
|
||||
config: &mas_config::UpstreamOAuth2ClaimsImports,
|
||||
) -> mas_data_model::UpstreamOAuthProviderClaimsImports {
|
||||
mas_data_model::UpstreamOAuthProviderClaimsImports {
|
||||
localpart: config.localpart.as_ref().map(map_import_preference).unwrap_or_default(),
|
||||
displayname: config.displayname.as_ref().map(map_import_preference).unwrap_or_default(),
|
||||
email: config.email.as_ref().map(map_import_preference).unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
pub(super) struct Options {
|
||||
#[command(subcommand)]
|
||||
@ -64,12 +97,14 @@ impl Options {
|
||||
|
||||
serde_yaml::to_writer(std::io::stdout(), &config)?;
|
||||
}
|
||||
|
||||
SC::Check => {
|
||||
let _span = info_span!("cli.config.check").entered();
|
||||
|
||||
let _config: RootConfig = root.load_config()?;
|
||||
info!(path = ?root.config, "Configuration file looks good");
|
||||
}
|
||||
|
||||
SC::Generate => {
|
||||
let _span = info_span!("cli.config.generate").entered();
|
||||
|
||||
@ -81,9 +116,13 @@ impl Options {
|
||||
}
|
||||
|
||||
SC::Sync { prune, dry_run } => {
|
||||
let _span = info_span!("cli.config.sync").entered();
|
||||
let _span =
|
||||
info_span!("cli.config.sync", prune = prune, dry_run = dry_run).entered();
|
||||
|
||||
let clock = SystemClock::default();
|
||||
|
||||
let config: RootConfig = root.load_config()?;
|
||||
let encrypter = config.secrets.encrypter();
|
||||
let pool = database_from_config(&config.database).await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
|
||||
@ -93,9 +132,6 @@ impl Options {
|
||||
"Syncing providers and clients defined in config to database"
|
||||
);
|
||||
|
||||
let existing = repo.upstream_oauth_provider().all().await?;
|
||||
|
||||
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
||||
let config_ids = config
|
||||
.upstream_oauth2
|
||||
.providers
|
||||
@ -103,24 +139,54 @@ impl Options {
|
||||
.map(|p| p.id)
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let needs_pruning = existing_ids.difference(&config_ids).collect::<Vec<_>>();
|
||||
let existing = repo.upstream_oauth_provider().all().await?;
|
||||
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
||||
if prune {
|
||||
for id in needs_pruning {
|
||||
info!(provider.id = %id, "Deleting provider");
|
||||
for provider in to_delete {
|
||||
info!(%provider.id, "Deleting provider");
|
||||
|
||||
if dry_run {
|
||||
continue;
|
||||
}
|
||||
|
||||
repo.upstream_oauth_provider().delete(provider).await?;
|
||||
}
|
||||
} else {
|
||||
let len = to_delete.count();
|
||||
match len {
|
||||
0 => {},
|
||||
1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."),
|
||||
n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."),
|
||||
}
|
||||
} else if !needs_pruning.is_empty() {
|
||||
warn!(
|
||||
"{} provider(s) in the database are not in the config. Run with `--prune` to delete them.",
|
||||
needs_pruning.len()
|
||||
);
|
||||
}
|
||||
|
||||
for provider in config.upstream_oauth2.providers {
|
||||
if existing_ids.contains(&provider.id) {
|
||||
info!(%provider.id, "Updating provider");
|
||||
} else {
|
||||
info!(%provider.id, "Adding provider");
|
||||
info!(%provider.id, "Syncing provider");
|
||||
|
||||
if dry_run {
|
||||
continue;
|
||||
}
|
||||
|
||||
let encrypted_client_secret = provider
|
||||
.client_secret()
|
||||
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
||||
.transpose()?;
|
||||
let client_auth_method = provider.client_auth_method();
|
||||
let client_auth_signing_alg = provider.client_auth_signing_alg();
|
||||
|
||||
repo.upstream_oauth_provider()
|
||||
.upsert(
|
||||
&clock,
|
||||
provider.id,
|
||||
provider.issuer,
|
||||
provider.scope.parse()?,
|
||||
client_auth_method,
|
||||
client_auth_signing_alg,
|
||||
provider.client_id,
|
||||
encrypted_client_secret,
|
||||
map_claims_imports(&provider.claims_imports),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if dry_run {
|
||||
|
Reference in New Issue
Block a user