diff --git a/crates/cli/src/commands/config.rs b/crates/cli/src/commands/config.rs index 1bfdbb11..fc1d47fa 100644 --- a/crates/cli/src/commands/config.rs +++ b/crates/cli/src/commands/config.rs @@ -12,75 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashSet; - +use anyhow::Context; use camino::Utf8PathBuf; use clap::Parser; use mas_config::{ConfigurationSection, RootConfig, SyncConfig}; -use mas_storage::{ - upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository}, - RepositoryAccess, SystemClock, -}; -use mas_storage_pg::PgRepository; +use mas_storage::SystemClock; +use mas_storage_pg::MIGRATOR; use rand::SeedableRng; -use sqlx::{postgres::PgAdvisoryLock, Acquire}; use tokio::io::AsyncWriteExt; -use tracing::{error, info, info_span, warn}; +use tracing::{info, info_span, Instrument}; use crate::util::database_connection_from_config; -fn map_import_action( - config: &mas_config::UpstreamOAuth2ImportAction, -) -> mas_data_model::UpstreamOAuthProviderImportAction { - match config { - mas_config::UpstreamOAuth2ImportAction::Ignore => { - mas_data_model::UpstreamOAuthProviderImportAction::Ignore - } - mas_config::UpstreamOAuth2ImportAction::Suggest => { - mas_data_model::UpstreamOAuthProviderImportAction::Suggest - } - mas_config::UpstreamOAuth2ImportAction::Force => { - mas_data_model::UpstreamOAuthProviderImportAction::Force - } - mas_config::UpstreamOAuth2ImportAction::Require => { - mas_data_model::UpstreamOAuthProviderImportAction::Require - } - } -} - -fn map_claims_imports( - config: &mas_config::UpstreamOAuth2ClaimsImports, -) -> mas_data_model::UpstreamOAuthProviderClaimsImports { - mas_data_model::UpstreamOAuthProviderClaimsImports { - subject: mas_data_model::UpstreamOAuthProviderSubjectPreference { - template: config.subject.template.clone(), - }, - localpart: mas_data_model::UpstreamOAuthProviderImportPreference { - action: map_import_action(&config.localpart.action), - template: config.localpart.template.clone(), - }, - displayname: mas_data_model::UpstreamOAuthProviderImportPreference { - action: map_import_action(&config.displayname.action), - template: config.displayname.template.clone(), - }, - email: mas_data_model::UpstreamOAuthProviderImportPreference { - action: map_import_action(&config.email.action), - template: config.email.template.clone(), - }, - verify_email: match config.email.set_email_verification { - mas_config::UpstreamOAuth2SetEmailVerification::Always => { - mas_data_model::UpsreamOAuthProviderSetEmailVerification::Always - } - mas_config::UpstreamOAuth2SetEmailVerification::Never => { - mas_data_model::UpsreamOAuthProviderSetEmailVerification::Never - } - mas_config::UpstreamOAuth2SetEmailVerification::Import => { - mas_data_model::UpsreamOAuthProviderSetEmailVerification::Import - } - }, - } -} - #[derive(Parser, Debug)] pub(super) struct Options { #[command(subcommand)] @@ -169,230 +112,32 @@ impl Options { } SC::Sync { prune, dry_run } => { - sync(root, prune, dry_run).await?; + let config: SyncConfig = root.load_config()?; + let clock = SystemClock::default(); + let encrypter = config.secrets.encrypter(); + + // Grab a connection to the database + let mut conn = database_connection_from_config(&config.database).await?; + + MIGRATOR + .run(&mut conn) + .instrument(info_span!("db.migrate")) + .await + .context("could not run migrations")?; + + crate::sync::config_sync( + config.upstream_oauth2, + config.clients, + &mut conn, + &encrypter, + &clock, + prune, + dry_run, + ) + .await?; } } Ok(()) } } - -#[tracing::instrument(name = "cli.config.sync", skip(root), err(Debug))] -async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Result<()> { - // XXX: we should disallow SeedableRng::from_entropy - let clock = SystemClock::default(); - - let config: SyncConfig = root.load_config()?; - let encrypter = config.secrets.encrypter(); - // Grab a connection to the database - let mut conn = database_connection_from_config(&config.database).await?; - // Start a transaction - let txn = conn.begin().await?; - - // Grab a lock within the transaction - tracing::info!("Acquiring config lock"); - let lock = PgAdvisoryLock::new("MAS config sync"); - let lock = lock.acquire(txn).await?; - - // Create a repository from the connection with the lock - let mut repo = PgRepository::from_conn(lock); - - tracing::info!( - prune, - dry_run, - "Syncing providers and clients defined in config to database" - ); - - { - let _span = info_span!("cli.config.sync.providers").entered(); - let config_ids = config - .upstream_oauth2 - .providers - .iter() - .map(|p| p.id) - .collect::>(); - - let existing = repo.upstream_oauth_provider().all().await?; - let existing_ids = existing.iter().map(|p| p.id).collect::>(); - let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); - if prune { - for provider in to_delete { - info!(%provider.id, "Deleting provider"); - - if dry_run { - continue; - } - - repo.upstream_oauth_provider().delete(provider).await?; - } - } else { - let len = to_delete.count(); - match len { - 0 => {}, - 1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."), - n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."), - } - } - - for provider in config.upstream_oauth2.providers { - let _span = info_span!("provider", %provider.id).entered(); - if existing_ids.contains(&provider.id) { - info!("Updating provider"); - } else { - info!("Adding provider"); - } - - if dry_run { - continue; - } - - let encrypted_client_secret = provider - .client_secret() - .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) - .transpose()?; - let token_endpoint_auth_method = provider.client_auth_method(); - let token_endpoint_signing_alg = provider.client_auth_signing_alg(); - - let discovery_mode = match provider.discovery_mode { - mas_config::UpstreamOAuth2DiscoveryMode::Oidc => { - mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc - } - mas_config::UpstreamOAuth2DiscoveryMode::Insecure => { - mas_data_model::UpstreamOAuthProviderDiscoveryMode::Insecure - } - mas_config::UpstreamOAuth2DiscoveryMode::Disabled => { - mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled - } - }; - - if discovery_mode.is_disabled() { - if provider.authorization_endpoint.is_none() { - error!("Provider has discovery disabled but no authorization endpoint set"); - } - - if provider.token_endpoint.is_none() { - error!("Provider has discovery disabled but no token endpoint set"); - } - - if provider.jwks_uri.is_none() { - error!("Provider has discovery disabled but no JWKS URI set"); - } - } - - let pkce_mode = match provider.pkce_method { - mas_config::UpstreamOAuth2PkceMethod::Auto => { - mas_data_model::UpstreamOAuthProviderPkceMode::Auto - } - mas_config::UpstreamOAuth2PkceMethod::Always => { - mas_data_model::UpstreamOAuthProviderPkceMode::S256 - } - mas_config::UpstreamOAuth2PkceMethod::Never => { - mas_data_model::UpstreamOAuthProviderPkceMode::Disabled - } - }; - - repo.upstream_oauth_provider() - .upsert( - &clock, - provider.id, - UpstreamOAuthProviderParams { - issuer: provider.issuer, - human_name: provider.human_name, - brand_name: provider.brand_name, - scope: provider.scope.parse()?, - token_endpoint_auth_method, - token_endpoint_signing_alg, - client_id: provider.client_id, - encrypted_client_secret, - claims_imports: map_claims_imports(&provider.claims_imports), - token_endpoint_override: provider.token_endpoint, - authorization_endpoint_override: provider.authorization_endpoint, - jwks_uri_override: provider.jwks_uri, - discovery_mode, - pkce_mode, - additional_authorization_parameters: provider - .additional_authorization_parameters - .into_iter() - .collect(), - }, - ) - .await?; - } - } - - { - let _span = info_span!("cli.config.sync.clients").entered(); - let config_ids = config - .clients - .iter() - .map(|c| c.client_id) - .collect::>(); - - let existing = repo.oauth2_client().all_static().await?; - let existing_ids = existing.iter().map(|p| p.id).collect::>(); - let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); - if prune { - for client in to_delete { - info!(client.id = %client.client_id, "Deleting client"); - - if dry_run { - continue; - } - - repo.oauth2_client().delete(client).await?; - } - } else { - let len = to_delete.count(); - match len { - 0 => {}, - 1 => warn!("A static client in the database is not in the config. Run with `--prune` to delete it."), - n => warn!("{n} static clients in the database are not in the config. Run with `--prune` to delete them."), - } - } - - for client in config.clients.iter() { - let _span = info_span!("client", client.id = %client.client_id).entered(); - if existing_ids.contains(&client.client_id) { - info!("Updating client"); - } else { - info!("Adding client"); - } - - if dry_run { - continue; - } - - let client_secret = client.client_secret(); - let client_auth_method = client.client_auth_method(); - let jwks = client.jwks(); - let jwks_uri = client.jwks_uri(); - - // TODO: should be moved somewhere else - let encrypted_client_secret = client_secret - .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) - .transpose()?; - - repo.oauth2_client() - .upsert_static( - client.client_id, - client_auth_method, - encrypted_client_secret, - jwks.cloned(), - jwks_uri.cloned(), - client.redirect_uris.clone(), - ) - .await?; - } - } - - // Get the lock and release it to commit the transaction - let lock = repo.into_inner(); - let txn = lock.release_now().await?; - if dry_run { - info!("Dry run, rolling back changes"); - txn.rollback().await?; - } else { - txn.commit().await?; - } - Ok(()) -} diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index c8b5b091..f4673637 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -22,6 +22,7 @@ use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCa use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; +use mas_storage::SystemClock; use mas_storage_pg::MIGRATOR; use rand::{ distributions::{Alphanumeric, DistString}, @@ -39,6 +40,7 @@ use crate::{ }, }; +#[allow(clippy::struct_excessive_bools)] #[derive(Parser, Debug, Default)] pub(super) struct Options { /// Do not apply pending migrations on start @@ -53,6 +55,10 @@ pub(super) struct Options { /// Do not start the task worker #[arg(long)] no_worker: bool, + + /// Do not sync the configuration with the database + #[arg(long)] + no_sync: bool, } impl Options { @@ -88,6 +94,28 @@ impl Options { .context("could not run migrations")?; } + let encrypter = config.secrets.encrypter(); + + if self.no_sync { + info!("Skipping configuration sync"); + } else { + // Sync the configuration with the database + let mut conn = pool.acquire().await?; + let clients_config = root.load_config()?; + let upstream_oauth2_config = root.load_config()?; + + crate::sync::config_sync( + upstream_oauth2_config, + clients_config, + &mut conn, + &encrypter, + &SystemClock::default(), + false, + false, + ) + .await?; + } + // Initialize the key store let key_store = config .secrets @@ -95,7 +123,6 @@ impl Options { .await .context("could not import keys from config")?; - let encrypter = config.secrets.encrypter(); let cookie_manager = CookieManager::derive_from(config.http.public_base.clone(), &config.secrets.encryption); diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 4c183c48..67494e6e 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -30,6 +30,7 @@ mod app_state; mod commands; mod sentry_transport; mod server; +mod sync; mod telemetry; mod util; diff --git a/crates/cli/src/sync.rs b/crates/cli/src/sync.rs new file mode 100644 index 00000000..e9042877 --- /dev/null +++ b/crates/cli/src/sync.rs @@ -0,0 +1,295 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Utilities to synchronize the configuration file with the database. + +use std::collections::HashSet; + +use mas_config::{ClientsConfig, UpstreamOAuth2Config}; +use mas_keystore::Encrypter; +use mas_storage::{upstream_oauth2::UpstreamOAuthProviderParams, Clock, RepositoryAccess}; +use mas_storage_pg::PgRepository; +use sqlx::{postgres::PgAdvisoryLock, Connection, PgConnection}; +use tracing::{error, info, info_span, warn}; + +fn map_import_action( + config: &mas_config::UpstreamOAuth2ImportAction, +) -> mas_data_model::UpstreamOAuthProviderImportAction { + match config { + mas_config::UpstreamOAuth2ImportAction::Ignore => { + mas_data_model::UpstreamOAuthProviderImportAction::Ignore + } + mas_config::UpstreamOAuth2ImportAction::Suggest => { + mas_data_model::UpstreamOAuthProviderImportAction::Suggest + } + mas_config::UpstreamOAuth2ImportAction::Force => { + mas_data_model::UpstreamOAuthProviderImportAction::Force + } + mas_config::UpstreamOAuth2ImportAction::Require => { + mas_data_model::UpstreamOAuthProviderImportAction::Require + } + } +} + +fn map_claims_imports( + config: &mas_config::UpstreamOAuth2ClaimsImports, +) -> mas_data_model::UpstreamOAuthProviderClaimsImports { + mas_data_model::UpstreamOAuthProviderClaimsImports { + subject: mas_data_model::UpstreamOAuthProviderSubjectPreference { + template: config.subject.template.clone(), + }, + localpart: mas_data_model::UpstreamOAuthProviderImportPreference { + action: map_import_action(&config.localpart.action), + template: config.localpart.template.clone(), + }, + displayname: mas_data_model::UpstreamOAuthProviderImportPreference { + action: map_import_action(&config.displayname.action), + template: config.displayname.template.clone(), + }, + email: mas_data_model::UpstreamOAuthProviderImportPreference { + action: map_import_action(&config.email.action), + template: config.email.template.clone(), + }, + verify_email: match config.email.set_email_verification { + mas_config::UpstreamOAuth2SetEmailVerification::Always => { + mas_data_model::UpsreamOAuthProviderSetEmailVerification::Always + } + mas_config::UpstreamOAuth2SetEmailVerification::Never => { + mas_data_model::UpsreamOAuthProviderSetEmailVerification::Never + } + mas_config::UpstreamOAuth2SetEmailVerification::Import => { + mas_data_model::UpsreamOAuthProviderSetEmailVerification::Import + } + }, + } +} + +#[tracing::instrument(name = "config.sync", skip_all, err(Debug))] +pub async fn config_sync( + upstream_oauth2_config: UpstreamOAuth2Config, + clients_config: ClientsConfig, + connection: &mut PgConnection, + encrypter: &Encrypter, + clock: &dyn Clock, + prune: bool, + dry_run: bool, +) -> anyhow::Result<()> { + // Start a transaction + let txn = connection.begin().await?; + + // Grab a lock within the transaction + tracing::info!("Acquiring configuration lock"); + let lock = PgAdvisoryLock::new("MAS config sync"); + let lock = lock.acquire(txn).await?; + + // Create a repository from the connection with the lock + let mut repo = PgRepository::from_conn(lock); + + tracing::info!( + prune, + dry_run, + "Syncing providers and clients defined in config to database" + ); + + { + let _span = info_span!("cli.config.sync.providers").entered(); + let config_ids = upstream_oauth2_config + .providers + .iter() + .map(|p| p.id) + .collect::>(); + + let existing = repo.upstream_oauth_provider().all().await?; + let existing_ids = existing.iter().map(|p| p.id).collect::>(); + let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); + if prune { + for provider in to_delete { + info!(%provider.id, "Deleting provider"); + + if dry_run { + continue; + } + + repo.upstream_oauth_provider().delete(provider).await?; + } + } else { + let len = to_delete.count(); + match len { + 0 => {}, + 1 => warn!("A provider in the database is not in the config. Run `mas-cli config sync --prune` to delete it."), + n => warn!("{n} providers in the database are not in the config. Run `mas-cli config sync --prune` to delete them."), + } + } + + for provider in upstream_oauth2_config.providers { + let _span = info_span!("provider", %provider.id).entered(); + if existing_ids.contains(&provider.id) { + info!("Updating provider"); + } else { + info!("Adding provider"); + } + + if dry_run { + continue; + } + + let encrypted_client_secret = provider + .client_secret() + .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) + .transpose()?; + let token_endpoint_auth_method = provider.client_auth_method(); + let token_endpoint_signing_alg = provider.client_auth_signing_alg(); + + let discovery_mode = match provider.discovery_mode { + mas_config::UpstreamOAuth2DiscoveryMode::Oidc => { + mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc + } + mas_config::UpstreamOAuth2DiscoveryMode::Insecure => { + mas_data_model::UpstreamOAuthProviderDiscoveryMode::Insecure + } + mas_config::UpstreamOAuth2DiscoveryMode::Disabled => { + mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled + } + }; + + if discovery_mode.is_disabled() { + if provider.authorization_endpoint.is_none() { + error!("Provider has discovery disabled but no authorization endpoint set"); + } + + if provider.token_endpoint.is_none() { + error!("Provider has discovery disabled but no token endpoint set"); + } + + if provider.jwks_uri.is_none() { + error!("Provider has discovery disabled but no JWKS URI set"); + } + } + + let pkce_mode = match provider.pkce_method { + mas_config::UpstreamOAuth2PkceMethod::Auto => { + mas_data_model::UpstreamOAuthProviderPkceMode::Auto + } + mas_config::UpstreamOAuth2PkceMethod::Always => { + mas_data_model::UpstreamOAuthProviderPkceMode::S256 + } + mas_config::UpstreamOAuth2PkceMethod::Never => { + mas_data_model::UpstreamOAuthProviderPkceMode::Disabled + } + }; + + repo.upstream_oauth_provider() + .upsert( + clock, + provider.id, + UpstreamOAuthProviderParams { + issuer: provider.issuer, + human_name: provider.human_name, + brand_name: provider.brand_name, + scope: provider.scope.parse()?, + token_endpoint_auth_method, + token_endpoint_signing_alg, + client_id: provider.client_id, + encrypted_client_secret, + claims_imports: map_claims_imports(&provider.claims_imports), + token_endpoint_override: provider.token_endpoint, + authorization_endpoint_override: provider.authorization_endpoint, + jwks_uri_override: provider.jwks_uri, + discovery_mode, + pkce_mode, + additional_authorization_parameters: provider + .additional_authorization_parameters + .into_iter() + .collect(), + }, + ) + .await?; + } + } + + { + let _span = info_span!("cli.config.sync.clients").entered(); + let config_ids = clients_config + .iter() + .map(|c| c.client_id) + .collect::>(); + + let existing = repo.oauth2_client().all_static().await?; + let existing_ids = existing.iter().map(|p| p.id).collect::>(); + let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); + if prune { + for client in to_delete { + info!(client.id = %client.client_id, "Deleting client"); + + if dry_run { + continue; + } + + repo.oauth2_client().delete(client).await?; + } + } else { + let len = to_delete.count(); + match len { + 0 => {}, + 1 => warn!("A static client in the database is not in the config. Run with `--prune` to delete it."), + n => warn!("{n} static clients in the database are not in the config. Run with `--prune` to delete them."), + } + } + + for client in clients_config { + let _span = info_span!("client", client.id = %client.client_id).entered(); + if existing_ids.contains(&client.client_id) { + info!("Updating client"); + } else { + info!("Adding client"); + } + + if dry_run { + continue; + } + + let client_secret = client.client_secret(); + let client_auth_method = client.client_auth_method(); + let jwks = client.jwks(); + let jwks_uri = client.jwks_uri(); + + // TODO: should be moved somewhere else + let encrypted_client_secret = client_secret + .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) + .transpose()?; + + repo.oauth2_client() + .upsert_static( + client.client_id, + client_auth_method, + encrypted_client_secret, + jwks.cloned(), + jwks_uri.cloned(), + client.redirect_uris, + ) + .await?; + } + } + + // Get the lock and release it to commit the transaction + let lock = repo.into_inner(); + let txn = lock.release_now().await?; + if dry_run { + info!("Dry run, rolling back changes"); + txn.rollback().await?; + } else { + txn.commit().await?; + } + Ok(()) +} diff --git a/crates/config/src/sections/clients.rs b/crates/config/src/sections/clients.rs index 837f8c91..1cd36fde 100644 --- a/crates/config/src/sections/clients.rs +++ b/crates/config/src/sections/clients.rs @@ -170,6 +170,15 @@ impl DerefMut for ClientsConfig { } } +impl IntoIterator for ClientsConfig { + type Item = ClientConfig; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + #[async_trait] impl ConfigurationSection for ClientsConfig { fn path() -> &'static str {