You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-26 10:44:51 +03:00
295 lines
9.5 KiB
Rust
295 lines
9.5 KiB
Rust
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
use std::collections::HashSet;
|
|
|
|
use clap::Parser;
|
|
use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
|
|
use mas_storage::{
|
|
upstream_oauth2::UpstreamOAuthProviderRepository, Repository, RepositoryAccess, SystemClock,
|
|
};
|
|
use mas_storage_pg::PgRepository;
|
|
use rand::SeedableRng;
|
|
use tracing::{info, info_span, warn};
|
|
|
|
use crate::util::database_from_config;
|
|
|
|
fn map_import_preference(
|
|
config: &mas_config::UpstreamOAuth2ImportPreference,
|
|
) -> mas_data_model::UpstreamOAuthProviderImportPreference {
|
|
let action = match &config.action {
|
|
mas_config::UpstreamOAuth2ImportAction::Ignore => {
|
|
mas_data_model::UpstreamOAuthProviderImportAction::Ignore
|
|
}
|
|
mas_config::UpstreamOAuth2ImportAction::Suggest => {
|
|
mas_data_model::UpstreamOAuthProviderImportAction::Suggest
|
|
}
|
|
mas_config::UpstreamOAuth2ImportAction::Force => {
|
|
mas_data_model::UpstreamOAuthProviderImportAction::Force
|
|
}
|
|
mas_config::UpstreamOAuth2ImportAction::Require => {
|
|
mas_data_model::UpstreamOAuthProviderImportAction::Require
|
|
}
|
|
};
|
|
|
|
mas_data_model::UpstreamOAuthProviderImportPreference { action }
|
|
}
|
|
|
|
fn map_claims_imports(
|
|
config: &mas_config::UpstreamOAuth2ClaimsImports,
|
|
) -> mas_data_model::UpstreamOAuthProviderClaimsImports {
|
|
mas_data_model::UpstreamOAuthProviderClaimsImports {
|
|
localpart: config
|
|
.localpart
|
|
.as_ref()
|
|
.map(map_import_preference)
|
|
.unwrap_or_default(),
|
|
displayname: config
|
|
.displayname
|
|
.as_ref()
|
|
.map(map_import_preference)
|
|
.unwrap_or_default(),
|
|
email: config
|
|
.email
|
|
.as_ref()
|
|
.map(map_import_preference)
|
|
.unwrap_or_default(),
|
|
}
|
|
}
|
|
|
|
// Command-line options for the `config` command.
// NOTE: plain `//` comments are used on purpose. clap's derive turns `///` doc
// comments into `--help` text, so adding them here would change the CLI output.
#[derive(Parser, Debug)]
pub(super) struct Options {
    // Which `config` subcommand to run; `Options::run` matches on it.
    #[command(subcommand)]
    subcommand: Subcommand,
}
|
|
|
|
#[derive(Parser, Debug)]
|
|
enum Subcommand {
|
|
/// Dump the current config as YAML
|
|
Dump,
|
|
|
|
/// Check a config file
|
|
Check,
|
|
|
|
/// Generate a new config file
|
|
Generate,
|
|
|
|
/// Sync the clients and providers from the config file to the database
|
|
Sync {
|
|
/// Prune elements that are in the database but not in the config file
|
|
/// anymore
|
|
#[clap(long)]
|
|
prune: bool,
|
|
|
|
/// Do not actually write to the database
|
|
#[clap(long)]
|
|
dry_run: bool,
|
|
},
|
|
}
|
|
|
|
impl Options {
|
|
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
|
|
use Subcommand as SC;
|
|
match self.subcommand {
|
|
SC::Dump => {
|
|
let _span = info_span!("cli.config.dump").entered();
|
|
|
|
let config: RootConfig = root.load_config()?;
|
|
|
|
serde_yaml::to_writer(std::io::stdout(), &config)?;
|
|
}
|
|
|
|
SC::Check => {
|
|
let _span = info_span!("cli.config.check").entered();
|
|
|
|
let _config: RootConfig = root.load_config()?;
|
|
info!(path = ?root.config, "Configuration file looks good");
|
|
}
|
|
|
|
SC::Generate => {
|
|
let _span = info_span!("cli.config.generate").entered();
|
|
|
|
// XXX: we should disallow SeedableRng::from_entropy
|
|
let rng = rand_chacha::ChaChaRng::from_entropy();
|
|
let config = RootConfig::load_and_generate(rng).await?;
|
|
|
|
serde_yaml::to_writer(std::io::stdout(), &config)?;
|
|
}
|
|
|
|
SC::Sync { prune, dry_run } => {
|
|
sync(root, prune, dry_run).await?;
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[tracing::instrument(name = "cli.config.sync", skip(root), err(Debug))]
|
|
async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Result<()> {
|
|
// XXX: we should disallow SeedableRng::from_entropy
|
|
let mut rng = rand_chacha::ChaChaRng::from_entropy();
|
|
let clock = SystemClock::default();
|
|
|
|
let config: SyncConfig = root.load_config()?;
|
|
let encrypter = config.secrets.encrypter();
|
|
let pool = database_from_config(&config.database).await?;
|
|
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
|
|
|
tracing::info!(
|
|
prune,
|
|
dry_run,
|
|
"Syncing providers and clients defined in config to database"
|
|
);
|
|
|
|
{
|
|
let _span = info_span!("cli.config.sync.providers").entered();
|
|
let config_ids = config
|
|
.upstream_oauth2
|
|
.providers
|
|
.iter()
|
|
.map(|p| p.id)
|
|
.collect::<HashSet<_>>();
|
|
|
|
let existing = repo.upstream_oauth_provider().all().await?;
|
|
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
|
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
|
if prune {
|
|
for provider in to_delete {
|
|
info!(%provider.id, "Deleting provider");
|
|
|
|
if dry_run {
|
|
continue;
|
|
}
|
|
|
|
repo.upstream_oauth_provider().delete(provider).await?;
|
|
}
|
|
} else {
|
|
let len = to_delete.count();
|
|
match len {
|
|
0 => {},
|
|
1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."),
|
|
n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."),
|
|
}
|
|
}
|
|
|
|
for provider in config.upstream_oauth2.providers {
|
|
if existing_ids.contains(&provider.id) {
|
|
info!(%provider.id, "Updating provider");
|
|
} else {
|
|
info!(%provider.id, "Adding provider");
|
|
}
|
|
|
|
if dry_run {
|
|
continue;
|
|
}
|
|
|
|
let encrypted_client_secret = provider
|
|
.client_secret()
|
|
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
|
.transpose()?;
|
|
let client_auth_method = provider.client_auth_method();
|
|
let client_auth_signing_alg = provider.client_auth_signing_alg();
|
|
|
|
repo.upstream_oauth_provider()
|
|
.upsert(
|
|
&clock,
|
|
provider.id,
|
|
provider.issuer,
|
|
provider.scope.parse()?,
|
|
client_auth_method,
|
|
client_auth_signing_alg,
|
|
provider.client_id,
|
|
encrypted_client_secret,
|
|
map_claims_imports(&provider.claims_imports),
|
|
)
|
|
.await?;
|
|
}
|
|
}
|
|
|
|
{
|
|
let _span = info_span!("cli.config.sync.clients").entered();
|
|
let config_ids = config
|
|
.clients
|
|
.iter()
|
|
.map(|c| c.client_id)
|
|
.collect::<HashSet<_>>();
|
|
|
|
let existing = repo.oauth2_client().all_static().await?;
|
|
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
|
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
|
if prune {
|
|
for client in to_delete {
|
|
info!(client.id = %client.client_id, "Deleting client");
|
|
|
|
if dry_run {
|
|
continue;
|
|
}
|
|
|
|
repo.oauth2_client().delete(client).await?;
|
|
}
|
|
} else {
|
|
let len = to_delete.count();
|
|
match len {
|
|
0 => {},
|
|
1 => warn!("A static client in the database is not in the config. Run with `--prune` to delete it."),
|
|
n => warn!("{n} static clients in the database are not in the config. Run with `--prune` to delete them."),
|
|
}
|
|
}
|
|
|
|
for client in config.clients.iter() {
|
|
if existing_ids.contains(&client.client_id) {
|
|
info!(client.id = %client.client_id, "Updating client");
|
|
} else {
|
|
info!(client.id = %client.client_id, "Adding client");
|
|
}
|
|
|
|
if dry_run {
|
|
continue;
|
|
}
|
|
|
|
let client_secret = client.client_secret();
|
|
let client_auth_method = client.client_auth_method();
|
|
let jwks = client.jwks();
|
|
let jwks_uri = client.jwks_uri();
|
|
|
|
// TODO: should be moved somewhere else
|
|
let encrypted_client_secret = client_secret
|
|
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
|
.transpose()?;
|
|
|
|
repo.oauth2_client()
|
|
.upsert_static(
|
|
&mut rng,
|
|
&clock,
|
|
client.client_id,
|
|
client_auth_method,
|
|
encrypted_client_secret,
|
|
jwks.cloned(),
|
|
jwks_uri.cloned(),
|
|
client.redirect_uris.clone(),
|
|
)
|
|
.await?;
|
|
}
|
|
}
|
|
|
|
if dry_run {
|
|
info!("Dry run, rolling back changes");
|
|
repo.cancel().await?;
|
|
} else {
|
|
repo.save().await?;
|
|
}
|
|
Ok(())
|
|
}
|