diff --git a/crates/cli/src/sync.rs b/crates/cli/src/sync.rs index b0d22cdc..72cbe64c 100644 --- a/crates/cli/src/sync.rs +++ b/crates/cli/src/sync.rs @@ -14,11 +14,14 @@ //! Utilities to synchronize the configuration file with the database. -use std::collections::HashSet; +use std::collections::{BTreeMap, BTreeSet}; use mas_config::{ClientsConfig, UpstreamOAuth2Config}; use mas_keystore::Encrypter; -use mas_storage::{upstream_oauth2::UpstreamOAuthProviderParams, Clock, RepositoryAccess}; +use mas_storage::{ + upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams}, + Clock, Pagination, RepositoryAccess, +}; use mas_storage_pg::PgRepository; use sqlx::{postgres::PgAdvisoryLock, Connection, PgConnection}; use tracing::{error, info, info_span, warn}; @@ -107,35 +110,83 @@ pub async fn config_sync( let config_ids = upstream_oauth2_config .providers .iter() + .filter(|p| p.enabled) .map(|p| p.id) - .collect::>(); + .collect::>(); + + // Let's assume we have less than 1000 providers + let page = repo + .upstream_oauth_provider() + .list( + UpstreamOAuthProviderFilter::default(), + Pagination::first(1000), + ) + .await?; + + // A warning is probably enough + if page.has_next_page { + warn!( + "More than 1000 providers in the database, only the first 1000 will be considered" + ); + } + + let mut existing_enabled_ids = BTreeSet::new(); + let mut existing_disabled = BTreeMap::new(); + // Process the existing providers + for provider in page.edges { + if provider.enabled() { + if config_ids.contains(&provider.id) { + existing_enabled_ids.insert(provider.id); + } else { + // Provider is enabled in the database but not in the config + info!(%provider.id, "Disabling provider"); + + let provider = if dry_run { + provider + } else { + repo.upstream_oauth_provider() + .disable(clock, provider) + .await? + }; + + existing_disabled.insert(provider.id, provider); + } + } else { + existing_disabled.insert(provider.id, provider); + } + } - let existing = repo.upstream_oauth_provider().all_enabled().await?; - let existing_ids = existing.iter().map(|p| p.id).collect::>(); - let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); if prune { - for provider in to_delete { - info!(%provider.id, "Deleting provider"); + for provider_id in existing_disabled.keys().copied() { + info!(provider.id = %provider_id, "Deleting provider"); if dry_run { continue; } - repo.upstream_oauth_provider().delete(provider).await?; + repo.upstream_oauth_provider() + .delete_by_id(provider_id) + .await?; } } else { - let len = to_delete.count(); + let len = existing_disabled.len(); match len { 0 => {}, - 1 => warn!("A provider in the database is not in the config. Run `mas-cli config sync --prune` to delete it."), - n => warn!("{n} providers in the database are not in the config. Run `mas-cli config sync --prune` to delete them."), + 1 => warn!("A provider is soft-deleted in the database. Run `mas-cli config sync --prune` to delete it."), + n => warn!("{n} providers are soft-deleted in the database. Run `mas-cli config sync --prune` to delete them."), } } for provider in upstream_oauth2_config.providers { + if !provider.enabled { + continue; + } + let _span = info_span!("provider", %provider.id).entered(); - if existing_ids.contains(&provider.id) { + if existing_enabled_ids.contains(&provider.id) { info!("Updating provider"); + } else if existing_disabled.contains_key(&provider.id) { + info!("Enabling and updating provider"); } else { info!("Adding provider"); } @@ -224,10 +275,10 @@ pub async fn config_sync( let config_ids = clients_config .iter() .map(|c| c.client_id) - .collect::>(); + .collect::>(); let existing = repo.oauth2_client().all_static().await?; - let existing_ids = existing.iter().map(|p| p.id).collect::>(); + let existing_ids = existing.iter().map(|p| p.id).collect::>(); let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); if prune { for client in to_delete { diff --git a/crates/config/src/sections/upstream_oauth2.rs b/crates/config/src/sections/upstream_oauth2.rs index 0b1d4d82..bca320c5 100644 --- a/crates/config/src/sections/upstream_oauth2.rs +++ b/crates/config/src/sections/upstream_oauth2.rs @@ -342,9 +342,24 @@ impl PkceMethod { } } +fn default_true() -> bool { + true +} + +#[allow(clippy::trivially_copy_pass_by_ref)] +fn is_default_true(value: &bool) -> bool { + *value +} + #[skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct Provider { + /// Whether this provider is enabled. + /// + /// Defaults to `true` + #[serde(default = "default_true", skip_serializing_if = "is_default_true")] + pub enabled: bool, + /// An internal unique identifier for this provider #[schemars( with = "String", diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index f3eb76fd..2b8a4010 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -146,6 +146,18 @@ pub struct UpstreamOAuthProvider { pub additional_authorization_parameters: Vec<(String, String)>, } +impl PartialOrd for UpstreamOAuthProvider { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.id.cmp(&other.id)) + } +} + +impl Ord for UpstreamOAuthProvider { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.id.cmp(&other.id) + } +} + impl UpstreamOAuthProvider { /// Returns `true` if the provider is enabled #[must_use] diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index 5da62444..4603d8ab 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -515,8 +515,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' async fn disable( &mut self, clock: &dyn Clock, - upstream_oauth_provider: UpstreamOAuthProvider, - ) -> Result<(), Self::Error> { + mut upstream_oauth_provider: UpstreamOAuthProvider, + ) -> Result { let disabled_at = clock.now(); let res = sqlx::query!( r#" @@ -531,7 +531,11 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' .execute(&mut *self.conn) .await?; - DatabaseError::ensure_affected_rows(&res, 1) + DatabaseError::ensure_affected_rows(&res, 1)?; + + upstream_oauth_provider.disabled_at = Some(disabled_at); + + Ok(upstream_oauth_provider) } #[tracing::instrument( diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index af6c0b58..bf8c31b8 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -204,6 +204,8 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { /// Disable an upstream OAuth provider /// + /// Returns the disabled provider + /// /// # Parameters /// /// * `clock`: The clock used to generate timestamps @@ -216,7 +218,7 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { &mut self, clock: &dyn Clock, provider: UpstreamOAuthProvider, - ) -> Result<(), Self::Error>; + ) -> Result; /// List [`UpstreamOAuthProvider`] with the given filter and pagination /// @@ -281,7 +283,7 @@ repository_impl!(UpstreamOAuthProviderRepository: &mut self, clock: &dyn Clock, provider: UpstreamOAuthProvider - ) -> Result<(), Self::Error>; + ) -> Result; async fn list( &mut self, diff --git a/docs/config.schema.json b/docs/config.schema.json index 0b607a71..4212cb32 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -1562,6 +1562,10 @@ "token_endpoint_auth_method" ], "properties": { + "enabled": { + "description": "Whether this provider is enabled.\n\nDefaults to `true`", + "type": "boolean" + }, "id": { "description": "A ULID as per https://github.com/ulid/spec", "type": "string",