You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
Automatically sync the configuration on server startup
This commit is contained in:
@@ -12,75 +12,18 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::collections::HashSet;
|
use anyhow::Context;
|
||||||
|
|
||||||
use camino::Utf8PathBuf;
|
use camino::Utf8PathBuf;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
|
use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
|
||||||
use mas_storage::{
|
use mas_storage::SystemClock;
|
||||||
upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository},
|
use mas_storage_pg::MIGRATOR;
|
||||||
RepositoryAccess, SystemClock,
|
|
||||||
};
|
|
||||||
use mas_storage_pg::PgRepository;
|
|
||||||
use rand::SeedableRng;
|
use rand::SeedableRng;
|
||||||
use sqlx::{postgres::PgAdvisoryLock, Acquire};
|
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::AsyncWriteExt;
|
||||||
use tracing::{error, info, info_span, warn};
|
use tracing::{info, info_span, Instrument};
|
||||||
|
|
||||||
use crate::util::database_connection_from_config;
|
use crate::util::database_connection_from_config;
|
||||||
|
|
||||||
fn map_import_action(
|
|
||||||
config: &mas_config::UpstreamOAuth2ImportAction,
|
|
||||||
) -> mas_data_model::UpstreamOAuthProviderImportAction {
|
|
||||||
match config {
|
|
||||||
mas_config::UpstreamOAuth2ImportAction::Ignore => {
|
|
||||||
mas_data_model::UpstreamOAuthProviderImportAction::Ignore
|
|
||||||
}
|
|
||||||
mas_config::UpstreamOAuth2ImportAction::Suggest => {
|
|
||||||
mas_data_model::UpstreamOAuthProviderImportAction::Suggest
|
|
||||||
}
|
|
||||||
mas_config::UpstreamOAuth2ImportAction::Force => {
|
|
||||||
mas_data_model::UpstreamOAuthProviderImportAction::Force
|
|
||||||
}
|
|
||||||
mas_config::UpstreamOAuth2ImportAction::Require => {
|
|
||||||
mas_data_model::UpstreamOAuthProviderImportAction::Require
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn map_claims_imports(
|
|
||||||
config: &mas_config::UpstreamOAuth2ClaimsImports,
|
|
||||||
) -> mas_data_model::UpstreamOAuthProviderClaimsImports {
|
|
||||||
mas_data_model::UpstreamOAuthProviderClaimsImports {
|
|
||||||
subject: mas_data_model::UpstreamOAuthProviderSubjectPreference {
|
|
||||||
template: config.subject.template.clone(),
|
|
||||||
},
|
|
||||||
localpart: mas_data_model::UpstreamOAuthProviderImportPreference {
|
|
||||||
action: map_import_action(&config.localpart.action),
|
|
||||||
template: config.localpart.template.clone(),
|
|
||||||
},
|
|
||||||
displayname: mas_data_model::UpstreamOAuthProviderImportPreference {
|
|
||||||
action: map_import_action(&config.displayname.action),
|
|
||||||
template: config.displayname.template.clone(),
|
|
||||||
},
|
|
||||||
email: mas_data_model::UpstreamOAuthProviderImportPreference {
|
|
||||||
action: map_import_action(&config.email.action),
|
|
||||||
template: config.email.template.clone(),
|
|
||||||
},
|
|
||||||
verify_email: match config.email.set_email_verification {
|
|
||||||
mas_config::UpstreamOAuth2SetEmailVerification::Always => {
|
|
||||||
mas_data_model::UpsreamOAuthProviderSetEmailVerification::Always
|
|
||||||
}
|
|
||||||
mas_config::UpstreamOAuth2SetEmailVerification::Never => {
|
|
||||||
mas_data_model::UpsreamOAuthProviderSetEmailVerification::Never
|
|
||||||
}
|
|
||||||
mas_config::UpstreamOAuth2SetEmailVerification::Import => {
|
|
||||||
mas_data_model::UpsreamOAuthProviderSetEmailVerification::Import
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
pub(super) struct Options {
|
pub(super) struct Options {
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
@@ -169,230 +112,32 @@ impl Options {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SC::Sync { prune, dry_run } => {
|
SC::Sync { prune, dry_run } => {
|
||||||
sync(root, prune, dry_run).await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(name = "cli.config.sync", skip(root), err(Debug))]
|
|
||||||
async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Result<()> {
|
|
||||||
// XXX: we should disallow SeedableRng::from_entropy
|
|
||||||
let clock = SystemClock::default();
|
|
||||||
|
|
||||||
let config: SyncConfig = root.load_config()?;
|
let config: SyncConfig = root.load_config()?;
|
||||||
|
let clock = SystemClock::default();
|
||||||
let encrypter = config.secrets.encrypter();
|
let encrypter = config.secrets.encrypter();
|
||||||
|
|
||||||
// Grab a connection to the database
|
// Grab a connection to the database
|
||||||
let mut conn = database_connection_from_config(&config.database).await?;
|
let mut conn = database_connection_from_config(&config.database).await?;
|
||||||
// Start a transaction
|
|
||||||
let txn = conn.begin().await?;
|
|
||||||
|
|
||||||
// Grab a lock within the transaction
|
MIGRATOR
|
||||||
tracing::info!("Acquiring config lock");
|
.run(&mut conn)
|
||||||
let lock = PgAdvisoryLock::new("MAS config sync");
|
.instrument(info_span!("db.migrate"))
|
||||||
let lock = lock.acquire(txn).await?;
|
.await
|
||||||
|
.context("could not run migrations")?;
|
||||||
|
|
||||||
// Create a repository from the connection with the lock
|
crate::sync::config_sync(
|
||||||
let mut repo = PgRepository::from_conn(lock);
|
config.upstream_oauth2,
|
||||||
|
config.clients,
|
||||||
tracing::info!(
|
&mut conn,
|
||||||
|
&encrypter,
|
||||||
|
&clock,
|
||||||
prune,
|
prune,
|
||||||
dry_run,
|
dry_run,
|
||||||
"Syncing providers and clients defined in config to database"
|
|
||||||
);
|
|
||||||
|
|
||||||
{
|
|
||||||
let _span = info_span!("cli.config.sync.providers").entered();
|
|
||||||
let config_ids = config
|
|
||||||
.upstream_oauth2
|
|
||||||
.providers
|
|
||||||
.iter()
|
|
||||||
.map(|p| p.id)
|
|
||||||
.collect::<HashSet<_>>();
|
|
||||||
|
|
||||||
let existing = repo.upstream_oauth_provider().all().await?;
|
|
||||||
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
|
||||||
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
|
||||||
if prune {
|
|
||||||
for provider in to_delete {
|
|
||||||
info!(%provider.id, "Deleting provider");
|
|
||||||
|
|
||||||
if dry_run {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
repo.upstream_oauth_provider().delete(provider).await?;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let len = to_delete.count();
|
|
||||||
match len {
|
|
||||||
0 => {},
|
|
||||||
1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."),
|
|
||||||
n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for provider in config.upstream_oauth2.providers {
|
|
||||||
let _span = info_span!("provider", %provider.id).entered();
|
|
||||||
if existing_ids.contains(&provider.id) {
|
|
||||||
info!("Updating provider");
|
|
||||||
} else {
|
|
||||||
info!("Adding provider");
|
|
||||||
}
|
|
||||||
|
|
||||||
if dry_run {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let encrypted_client_secret = provider
|
|
||||||
.client_secret()
|
|
||||||
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
|
||||||
.transpose()?;
|
|
||||||
let token_endpoint_auth_method = provider.client_auth_method();
|
|
||||||
let token_endpoint_signing_alg = provider.client_auth_signing_alg();
|
|
||||||
|
|
||||||
let discovery_mode = match provider.discovery_mode {
|
|
||||||
mas_config::UpstreamOAuth2DiscoveryMode::Oidc => {
|
|
||||||
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc
|
|
||||||
}
|
|
||||||
mas_config::UpstreamOAuth2DiscoveryMode::Insecure => {
|
|
||||||
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Insecure
|
|
||||||
}
|
|
||||||
mas_config::UpstreamOAuth2DiscoveryMode::Disabled => {
|
|
||||||
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if discovery_mode.is_disabled() {
|
|
||||||
if provider.authorization_endpoint.is_none() {
|
|
||||||
error!("Provider has discovery disabled but no authorization endpoint set");
|
|
||||||
}
|
|
||||||
|
|
||||||
if provider.token_endpoint.is_none() {
|
|
||||||
error!("Provider has discovery disabled but no token endpoint set");
|
|
||||||
}
|
|
||||||
|
|
||||||
if provider.jwks_uri.is_none() {
|
|
||||||
error!("Provider has discovery disabled but no JWKS URI set");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let pkce_mode = match provider.pkce_method {
|
|
||||||
mas_config::UpstreamOAuth2PkceMethod::Auto => {
|
|
||||||
mas_data_model::UpstreamOAuthProviderPkceMode::Auto
|
|
||||||
}
|
|
||||||
mas_config::UpstreamOAuth2PkceMethod::Always => {
|
|
||||||
mas_data_model::UpstreamOAuthProviderPkceMode::S256
|
|
||||||
}
|
|
||||||
mas_config::UpstreamOAuth2PkceMethod::Never => {
|
|
||||||
mas_data_model::UpstreamOAuthProviderPkceMode::Disabled
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
repo.upstream_oauth_provider()
|
|
||||||
.upsert(
|
|
||||||
&clock,
|
|
||||||
provider.id,
|
|
||||||
UpstreamOAuthProviderParams {
|
|
||||||
issuer: provider.issuer,
|
|
||||||
human_name: provider.human_name,
|
|
||||||
brand_name: provider.brand_name,
|
|
||||||
scope: provider.scope.parse()?,
|
|
||||||
token_endpoint_auth_method,
|
|
||||||
token_endpoint_signing_alg,
|
|
||||||
client_id: provider.client_id,
|
|
||||||
encrypted_client_secret,
|
|
||||||
claims_imports: map_claims_imports(&provider.claims_imports),
|
|
||||||
token_endpoint_override: provider.token_endpoint,
|
|
||||||
authorization_endpoint_override: provider.authorization_endpoint,
|
|
||||||
jwks_uri_override: provider.jwks_uri,
|
|
||||||
discovery_mode,
|
|
||||||
pkce_mode,
|
|
||||||
additional_authorization_parameters: provider
|
|
||||||
.additional_authorization_parameters
|
|
||||||
.into_iter()
|
|
||||||
.collect(),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
|
||||||
let _span = info_span!("cli.config.sync.clients").entered();
|
|
||||||
let config_ids = config
|
|
||||||
.clients
|
|
||||||
.iter()
|
|
||||||
.map(|c| c.client_id)
|
|
||||||
.collect::<HashSet<_>>();
|
|
||||||
|
|
||||||
let existing = repo.oauth2_client().all_static().await?;
|
|
||||||
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
|
||||||
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
|
||||||
if prune {
|
|
||||||
for client in to_delete {
|
|
||||||
info!(client.id = %client.client_id, "Deleting client");
|
|
||||||
|
|
||||||
if dry_run {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
repo.oauth2_client().delete(client).await?;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let len = to_delete.count();
|
|
||||||
match len {
|
|
||||||
0 => {},
|
|
||||||
1 => warn!("A static client in the database is not in the config. Run with `--prune` to delete it."),
|
|
||||||
n => warn!("{n} static clients in the database are not in the config. Run with `--prune` to delete them."),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for client in config.clients.iter() {
|
|
||||||
let _span = info_span!("client", client.id = %client.client_id).entered();
|
|
||||||
if existing_ids.contains(&client.client_id) {
|
|
||||||
info!("Updating client");
|
|
||||||
} else {
|
|
||||||
info!("Adding client");
|
|
||||||
}
|
|
||||||
|
|
||||||
if dry_run {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let client_secret = client.client_secret();
|
|
||||||
let client_auth_method = client.client_auth_method();
|
|
||||||
let jwks = client.jwks();
|
|
||||||
let jwks_uri = client.jwks_uri();
|
|
||||||
|
|
||||||
// TODO: should be moved somewhere else
|
|
||||||
let encrypted_client_secret = client_secret
|
|
||||||
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
|
||||||
.transpose()?;
|
|
||||||
|
|
||||||
repo.oauth2_client()
|
|
||||||
.upsert_static(
|
|
||||||
client.client_id,
|
|
||||||
client_auth_method,
|
|
||||||
encrypted_client_secret,
|
|
||||||
jwks.cloned(),
|
|
||||||
jwks_uri.cloned(),
|
|
||||||
client.redirect_uris.clone(),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the lock and release it to commit the transaction
|
|
||||||
let lock = repo.into_inner();
|
|
||||||
let txn = lock.release_now().await?;
|
|
||||||
if dry_run {
|
|
||||||
info!("Dry run, rolling back changes");
|
|
||||||
txn.rollback().await?;
|
|
||||||
} else {
|
|
||||||
txn.commit().await?;
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@@ -22,6 +22,7 @@ use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCa
|
|||||||
use mas_listener::{server::Server, shutdown::ShutdownStream};
|
use mas_listener::{server::Server, shutdown::ShutdownStream};
|
||||||
use mas_matrix_synapse::SynapseConnection;
|
use mas_matrix_synapse::SynapseConnection;
|
||||||
use mas_router::UrlBuilder;
|
use mas_router::UrlBuilder;
|
||||||
|
use mas_storage::SystemClock;
|
||||||
use mas_storage_pg::MIGRATOR;
|
use mas_storage_pg::MIGRATOR;
|
||||||
use rand::{
|
use rand::{
|
||||||
distributions::{Alphanumeric, DistString},
|
distributions::{Alphanumeric, DistString},
|
||||||
@@ -39,6 +40,7 @@ use crate::{
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[allow(clippy::struct_excessive_bools)]
|
||||||
#[derive(Parser, Debug, Default)]
|
#[derive(Parser, Debug, Default)]
|
||||||
pub(super) struct Options {
|
pub(super) struct Options {
|
||||||
/// Do not apply pending migrations on start
|
/// Do not apply pending migrations on start
|
||||||
@@ -53,6 +55,10 @@ pub(super) struct Options {
|
|||||||
/// Do not start the task worker
|
/// Do not start the task worker
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
no_worker: bool,
|
no_worker: bool,
|
||||||
|
|
||||||
|
/// Do not sync the configuration with the database
|
||||||
|
#[arg(long)]
|
||||||
|
no_sync: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Options {
|
impl Options {
|
||||||
@@ -88,6 +94,28 @@ impl Options {
|
|||||||
.context("could not run migrations")?;
|
.context("could not run migrations")?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let encrypter = config.secrets.encrypter();
|
||||||
|
|
||||||
|
if self.no_sync {
|
||||||
|
info!("Skipping configuration sync");
|
||||||
|
} else {
|
||||||
|
// Sync the configuration with the database
|
||||||
|
let mut conn = pool.acquire().await?;
|
||||||
|
let clients_config = root.load_config()?;
|
||||||
|
let upstream_oauth2_config = root.load_config()?;
|
||||||
|
|
||||||
|
crate::sync::config_sync(
|
||||||
|
upstream_oauth2_config,
|
||||||
|
clients_config,
|
||||||
|
&mut conn,
|
||||||
|
&encrypter,
|
||||||
|
&SystemClock::default(),
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize the key store
|
// Initialize the key store
|
||||||
let key_store = config
|
let key_store = config
|
||||||
.secrets
|
.secrets
|
||||||
@@ -95,7 +123,6 @@ impl Options {
|
|||||||
.await
|
.await
|
||||||
.context("could not import keys from config")?;
|
.context("could not import keys from config")?;
|
||||||
|
|
||||||
let encrypter = config.secrets.encrypter();
|
|
||||||
let cookie_manager =
|
let cookie_manager =
|
||||||
CookieManager::derive_from(config.http.public_base.clone(), &config.secrets.encryption);
|
CookieManager::derive_from(config.http.public_base.clone(), &config.secrets.encryption);
|
||||||
|
|
||||||
|
@@ -30,6 +30,7 @@ mod app_state;
|
|||||||
mod commands;
|
mod commands;
|
||||||
mod sentry_transport;
|
mod sentry_transport;
|
||||||
mod server;
|
mod server;
|
||||||
|
mod sync;
|
||||||
mod telemetry;
|
mod telemetry;
|
||||||
mod util;
|
mod util;
|
||||||
|
|
||||||
|
295
crates/cli/src/sync.rs
Normal file
295
crates/cli/src/sync.rs
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
// Copyright 2024 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
//! Utilities to synchronize the configuration file with the database.
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use mas_config::{ClientsConfig, UpstreamOAuth2Config};
|
||||||
|
use mas_keystore::Encrypter;
|
||||||
|
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderParams, Clock, RepositoryAccess};
|
||||||
|
use mas_storage_pg::PgRepository;
|
||||||
|
use sqlx::{postgres::PgAdvisoryLock, Connection, PgConnection};
|
||||||
|
use tracing::{error, info, info_span, warn};
|
||||||
|
|
||||||
|
fn map_import_action(
|
||||||
|
config: &mas_config::UpstreamOAuth2ImportAction,
|
||||||
|
) -> mas_data_model::UpstreamOAuthProviderImportAction {
|
||||||
|
match config {
|
||||||
|
mas_config::UpstreamOAuth2ImportAction::Ignore => {
|
||||||
|
mas_data_model::UpstreamOAuthProviderImportAction::Ignore
|
||||||
|
}
|
||||||
|
mas_config::UpstreamOAuth2ImportAction::Suggest => {
|
||||||
|
mas_data_model::UpstreamOAuthProviderImportAction::Suggest
|
||||||
|
}
|
||||||
|
mas_config::UpstreamOAuth2ImportAction::Force => {
|
||||||
|
mas_data_model::UpstreamOAuthProviderImportAction::Force
|
||||||
|
}
|
||||||
|
mas_config::UpstreamOAuth2ImportAction::Require => {
|
||||||
|
mas_data_model::UpstreamOAuthProviderImportAction::Require
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_claims_imports(
|
||||||
|
config: &mas_config::UpstreamOAuth2ClaimsImports,
|
||||||
|
) -> mas_data_model::UpstreamOAuthProviderClaimsImports {
|
||||||
|
mas_data_model::UpstreamOAuthProviderClaimsImports {
|
||||||
|
subject: mas_data_model::UpstreamOAuthProviderSubjectPreference {
|
||||||
|
template: config.subject.template.clone(),
|
||||||
|
},
|
||||||
|
localpart: mas_data_model::UpstreamOAuthProviderImportPreference {
|
||||||
|
action: map_import_action(&config.localpart.action),
|
||||||
|
template: config.localpart.template.clone(),
|
||||||
|
},
|
||||||
|
displayname: mas_data_model::UpstreamOAuthProviderImportPreference {
|
||||||
|
action: map_import_action(&config.displayname.action),
|
||||||
|
template: config.displayname.template.clone(),
|
||||||
|
},
|
||||||
|
email: mas_data_model::UpstreamOAuthProviderImportPreference {
|
||||||
|
action: map_import_action(&config.email.action),
|
||||||
|
template: config.email.template.clone(),
|
||||||
|
},
|
||||||
|
verify_email: match config.email.set_email_verification {
|
||||||
|
mas_config::UpstreamOAuth2SetEmailVerification::Always => {
|
||||||
|
mas_data_model::UpsreamOAuthProviderSetEmailVerification::Always
|
||||||
|
}
|
||||||
|
mas_config::UpstreamOAuth2SetEmailVerification::Never => {
|
||||||
|
mas_data_model::UpsreamOAuthProviderSetEmailVerification::Never
|
||||||
|
}
|
||||||
|
mas_config::UpstreamOAuth2SetEmailVerification::Import => {
|
||||||
|
mas_data_model::UpsreamOAuthProviderSetEmailVerification::Import
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(name = "config.sync", skip_all, err(Debug))]
|
||||||
|
pub async fn config_sync(
|
||||||
|
upstream_oauth2_config: UpstreamOAuth2Config,
|
||||||
|
clients_config: ClientsConfig,
|
||||||
|
connection: &mut PgConnection,
|
||||||
|
encrypter: &Encrypter,
|
||||||
|
clock: &dyn Clock,
|
||||||
|
prune: bool,
|
||||||
|
dry_run: bool,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
// Start a transaction
|
||||||
|
let txn = connection.begin().await?;
|
||||||
|
|
||||||
|
// Grab a lock within the transaction
|
||||||
|
tracing::info!("Acquiring configuration lock");
|
||||||
|
let lock = PgAdvisoryLock::new("MAS config sync");
|
||||||
|
let lock = lock.acquire(txn).await?;
|
||||||
|
|
||||||
|
// Create a repository from the connection with the lock
|
||||||
|
let mut repo = PgRepository::from_conn(lock);
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
prune,
|
||||||
|
dry_run,
|
||||||
|
"Syncing providers and clients defined in config to database"
|
||||||
|
);
|
||||||
|
|
||||||
|
{
|
||||||
|
let _span = info_span!("cli.config.sync.providers").entered();
|
||||||
|
let config_ids = upstream_oauth2_config
|
||||||
|
.providers
|
||||||
|
.iter()
|
||||||
|
.map(|p| p.id)
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
|
|
||||||
|
let existing = repo.upstream_oauth_provider().all().await?;
|
||||||
|
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
||||||
|
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
||||||
|
if prune {
|
||||||
|
for provider in to_delete {
|
||||||
|
info!(%provider.id, "Deleting provider");
|
||||||
|
|
||||||
|
if dry_run {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.upstream_oauth_provider().delete(provider).await?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let len = to_delete.count();
|
||||||
|
match len {
|
||||||
|
0 => {},
|
||||||
|
1 => warn!("A provider in the database is not in the config. Run `mas-cli config sync --prune` to delete it."),
|
||||||
|
n => warn!("{n} providers in the database are not in the config. Run `mas-cli config sync --prune` to delete them."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for provider in upstream_oauth2_config.providers {
|
||||||
|
let _span = info_span!("provider", %provider.id).entered();
|
||||||
|
if existing_ids.contains(&provider.id) {
|
||||||
|
info!("Updating provider");
|
||||||
|
} else {
|
||||||
|
info!("Adding provider");
|
||||||
|
}
|
||||||
|
|
||||||
|
if dry_run {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let encrypted_client_secret = provider
|
||||||
|
.client_secret()
|
||||||
|
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
||||||
|
.transpose()?;
|
||||||
|
let token_endpoint_auth_method = provider.client_auth_method();
|
||||||
|
let token_endpoint_signing_alg = provider.client_auth_signing_alg();
|
||||||
|
|
||||||
|
let discovery_mode = match provider.discovery_mode {
|
||||||
|
mas_config::UpstreamOAuth2DiscoveryMode::Oidc => {
|
||||||
|
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc
|
||||||
|
}
|
||||||
|
mas_config::UpstreamOAuth2DiscoveryMode::Insecure => {
|
||||||
|
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Insecure
|
||||||
|
}
|
||||||
|
mas_config::UpstreamOAuth2DiscoveryMode::Disabled => {
|
||||||
|
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if discovery_mode.is_disabled() {
|
||||||
|
if provider.authorization_endpoint.is_none() {
|
||||||
|
error!("Provider has discovery disabled but no authorization endpoint set");
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider.token_endpoint.is_none() {
|
||||||
|
error!("Provider has discovery disabled but no token endpoint set");
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider.jwks_uri.is_none() {
|
||||||
|
error!("Provider has discovery disabled but no JWKS URI set");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let pkce_mode = match provider.pkce_method {
|
||||||
|
mas_config::UpstreamOAuth2PkceMethod::Auto => {
|
||||||
|
mas_data_model::UpstreamOAuthProviderPkceMode::Auto
|
||||||
|
}
|
||||||
|
mas_config::UpstreamOAuth2PkceMethod::Always => {
|
||||||
|
mas_data_model::UpstreamOAuthProviderPkceMode::S256
|
||||||
|
}
|
||||||
|
mas_config::UpstreamOAuth2PkceMethod::Never => {
|
||||||
|
mas_data_model::UpstreamOAuthProviderPkceMode::Disabled
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
repo.upstream_oauth_provider()
|
||||||
|
.upsert(
|
||||||
|
clock,
|
||||||
|
provider.id,
|
||||||
|
UpstreamOAuthProviderParams {
|
||||||
|
issuer: provider.issuer,
|
||||||
|
human_name: provider.human_name,
|
||||||
|
brand_name: provider.brand_name,
|
||||||
|
scope: provider.scope.parse()?,
|
||||||
|
token_endpoint_auth_method,
|
||||||
|
token_endpoint_signing_alg,
|
||||||
|
client_id: provider.client_id,
|
||||||
|
encrypted_client_secret,
|
||||||
|
claims_imports: map_claims_imports(&provider.claims_imports),
|
||||||
|
token_endpoint_override: provider.token_endpoint,
|
||||||
|
authorization_endpoint_override: provider.authorization_endpoint,
|
||||||
|
jwks_uri_override: provider.jwks_uri,
|
||||||
|
discovery_mode,
|
||||||
|
pkce_mode,
|
||||||
|
additional_authorization_parameters: provider
|
||||||
|
.additional_authorization_parameters
|
||||||
|
.into_iter()
|
||||||
|
.collect(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let _span = info_span!("cli.config.sync.clients").entered();
|
||||||
|
let config_ids = clients_config
|
||||||
|
.iter()
|
||||||
|
.map(|c| c.client_id)
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
|
|
||||||
|
let existing = repo.oauth2_client().all_static().await?;
|
||||||
|
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
||||||
|
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
||||||
|
if prune {
|
||||||
|
for client in to_delete {
|
||||||
|
info!(client.id = %client.client_id, "Deleting client");
|
||||||
|
|
||||||
|
if dry_run {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.oauth2_client().delete(client).await?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let len = to_delete.count();
|
||||||
|
match len {
|
||||||
|
0 => {},
|
||||||
|
1 => warn!("A static client in the database is not in the config. Run with `--prune` to delete it."),
|
||||||
|
n => warn!("{n} static clients in the database are not in the config. Run with `--prune` to delete them."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for client in clients_config {
|
||||||
|
let _span = info_span!("client", client.id = %client.client_id).entered();
|
||||||
|
if existing_ids.contains(&client.client_id) {
|
||||||
|
info!("Updating client");
|
||||||
|
} else {
|
||||||
|
info!("Adding client");
|
||||||
|
}
|
||||||
|
|
||||||
|
if dry_run {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let client_secret = client.client_secret();
|
||||||
|
let client_auth_method = client.client_auth_method();
|
||||||
|
let jwks = client.jwks();
|
||||||
|
let jwks_uri = client.jwks_uri();
|
||||||
|
|
||||||
|
// TODO: should be moved somewhere else
|
||||||
|
let encrypted_client_secret = client_secret
|
||||||
|
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
repo.oauth2_client()
|
||||||
|
.upsert_static(
|
||||||
|
client.client_id,
|
||||||
|
client_auth_method,
|
||||||
|
encrypted_client_secret,
|
||||||
|
jwks.cloned(),
|
||||||
|
jwks_uri.cloned(),
|
||||||
|
client.redirect_uris,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the lock and release it to commit the transaction
|
||||||
|
let lock = repo.into_inner();
|
||||||
|
let txn = lock.release_now().await?;
|
||||||
|
if dry_run {
|
||||||
|
info!("Dry run, rolling back changes");
|
||||||
|
txn.rollback().await?;
|
||||||
|
} else {
|
||||||
|
txn.commit().await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
@@ -170,6 +170,15 @@ impl DerefMut for ClientsConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl IntoIterator for ClientsConfig {
|
||||||
|
type Item = ClientConfig;
|
||||||
|
type IntoIter = std::vec::IntoIter<ClientConfig>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
self.0.into_iter()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl ConfigurationSection for ClientsConfig {
|
impl ConfigurationSection for ClientsConfig {
|
||||||
fn path() -> &'static str {
|
fn path() -> &'static str {
|
||||||
|
Reference in New Issue
Block a user