1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-06 06:02:40 +03:00

Soft-delete upstream OAuth 2.0 providers on config sync

This commit is contained in:
Quentin Gliech
2024-04-03 09:18:22 +02:00
parent 4e3823fe4f
commit cd0ec35d2f
6 changed files with 108 additions and 20 deletions

View File

@@ -14,11 +14,14 @@
//! Utilities to synchronize the configuration file with the database. //! Utilities to synchronize the configuration file with the database.
use std::collections::HashSet; use std::collections::{BTreeMap, BTreeSet};
use mas_config::{ClientsConfig, UpstreamOAuth2Config}; use mas_config::{ClientsConfig, UpstreamOAuth2Config};
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderParams, Clock, RepositoryAccess}; use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams},
Clock, Pagination, RepositoryAccess,
};
use mas_storage_pg::PgRepository; use mas_storage_pg::PgRepository;
use sqlx::{postgres::PgAdvisoryLock, Connection, PgConnection}; use sqlx::{postgres::PgAdvisoryLock, Connection, PgConnection};
use tracing::{error, info, info_span, warn}; use tracing::{error, info, info_span, warn};
@@ -107,35 +110,83 @@ pub async fn config_sync(
let config_ids = upstream_oauth2_config let config_ids = upstream_oauth2_config
.providers .providers
.iter() .iter()
.filter(|p| p.enabled)
.map(|p| p.id) .map(|p| p.id)
.collect::<HashSet<_>>(); .collect::<BTreeSet<_>>();
// Let's assume we have less than 1000 providers
let page = repo
.upstream_oauth_provider()
.list(
UpstreamOAuthProviderFilter::default(),
Pagination::first(1000),
)
.await?;
// A warning is probably enough
if page.has_next_page {
warn!(
"More than 1000 providers in the database, only the first 1000 will be considered"
);
}
let mut existing_enabled_ids = BTreeSet::new();
let mut existing_disabled = BTreeMap::new();
// Process the existing providers
for provider in page.edges {
if provider.enabled() {
if config_ids.contains(&provider.id) {
existing_enabled_ids.insert(provider.id);
} else {
// Provider is enabled in the database but not in the config
info!(%provider.id, "Disabling provider");
let provider = if dry_run {
provider
} else {
repo.upstream_oauth_provider()
.disable(clock, provider)
.await?
};
existing_disabled.insert(provider.id, provider);
}
} else {
existing_disabled.insert(provider.id, provider);
}
}
let existing = repo.upstream_oauth_provider().all_enabled().await?;
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
if prune { if prune {
for provider in to_delete { for provider_id in existing_disabled.keys().copied() {
info!(%provider.id, "Deleting provider"); info!(provider.id = %provider_id, "Deleting provider");
if dry_run { if dry_run {
continue; continue;
} }
repo.upstream_oauth_provider().delete(provider).await?; repo.upstream_oauth_provider()
.delete_by_id(provider_id)
.await?;
} }
} else { } else {
let len = to_delete.count(); let len = existing_disabled.len();
match len { match len {
0 => {}, 0 => {},
1 => warn!("A provider in the database is not in the config. Run `mas-cli config sync --prune` to delete it."), 1 => warn!("A provider is soft-deleted in the database. Run `mas-cli config sync --prune` to delete it."),
n => warn!("{n} providers in the database are not in the config. Run `mas-cli config sync --prune` to delete them."), n => warn!("{n} providers are soft-deleted in the database. Run `mas-cli config sync --prune` to delete them."),
} }
} }
for provider in upstream_oauth2_config.providers { for provider in upstream_oauth2_config.providers {
if !provider.enabled {
continue;
}
let _span = info_span!("provider", %provider.id).entered(); let _span = info_span!("provider", %provider.id).entered();
if existing_ids.contains(&provider.id) { if existing_enabled_ids.contains(&provider.id) {
info!("Updating provider"); info!("Updating provider");
} else if existing_disabled.contains_key(&provider.id) {
info!("Enabling and updating provider");
} else { } else {
info!("Adding provider"); info!("Adding provider");
} }
@@ -224,10 +275,10 @@ pub async fn config_sync(
let config_ids = clients_config let config_ids = clients_config
.iter() .iter()
.map(|c| c.client_id) .map(|c| c.client_id)
.collect::<HashSet<_>>(); .collect::<BTreeSet<_>>();
let existing = repo.oauth2_client().all_static().await?; let existing = repo.oauth2_client().all_static().await?;
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>(); let existing_ids = existing.iter().map(|p| p.id).collect::<BTreeSet<_>>();
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
if prune { if prune {
for client in to_delete { for client in to_delete {

View File

@@ -342,9 +342,24 @@ impl PkceMethod {
} }
} }
fn default_true() -> bool {
true
}
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_true(value: &bool) -> bool {
*value
}
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider { pub struct Provider {
/// Whether this provider is enabled.
///
/// Defaults to `true`
#[serde(default = "default_true", skip_serializing_if = "is_default_true")]
pub enabled: bool,
/// An internal unique identifier for this provider /// An internal unique identifier for this provider
#[schemars( #[schemars(
with = "String", with = "String",

View File

@@ -146,6 +146,18 @@ pub struct UpstreamOAuthProvider {
pub additional_authorization_parameters: Vec<(String, String)>, pub additional_authorization_parameters: Vec<(String, String)>,
} }
impl PartialOrd for UpstreamOAuthProvider {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.id.cmp(&other.id))
}
}
impl Ord for UpstreamOAuthProvider {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.id.cmp(&other.id)
}
}
impl UpstreamOAuthProvider { impl UpstreamOAuthProvider {
/// Returns `true` if the provider is enabled /// Returns `true` if the provider is enabled
#[must_use] #[must_use]

View File

@@ -515,8 +515,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
async fn disable( async fn disable(
&mut self, &mut self,
clock: &dyn Clock, clock: &dyn Clock,
upstream_oauth_provider: UpstreamOAuthProvider, mut upstream_oauth_provider: UpstreamOAuthProvider,
) -> Result<(), Self::Error> { ) -> Result<UpstreamOAuthProvider, Self::Error> {
let disabled_at = clock.now(); let disabled_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
@@ -531,7 +531,11 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
.execute(&mut *self.conn) .execute(&mut *self.conn)
.await?; .await?;
DatabaseError::ensure_affected_rows(&res, 1) DatabaseError::ensure_affected_rows(&res, 1)?;
upstream_oauth_provider.disabled_at = Some(disabled_at);
Ok(upstream_oauth_provider)
} }
#[tracing::instrument( #[tracing::instrument(

View File

@@ -204,6 +204,8 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
/// Disable an upstream OAuth provider /// Disable an upstream OAuth provider
/// ///
/// Returns the disabled provider
///
/// # Parameters /// # Parameters
/// ///
/// * `clock`: The clock used to generate timestamps /// * `clock`: The clock used to generate timestamps
@@ -216,7 +218,7 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
&mut self, &mut self,
clock: &dyn Clock, clock: &dyn Clock,
provider: UpstreamOAuthProvider, provider: UpstreamOAuthProvider,
) -> Result<(), Self::Error>; ) -> Result<UpstreamOAuthProvider, Self::Error>;
/// List [`UpstreamOAuthProvider`] with the given filter and pagination /// List [`UpstreamOAuthProvider`] with the given filter and pagination
/// ///
@@ -281,7 +283,7 @@ repository_impl!(UpstreamOAuthProviderRepository:
&mut self, &mut self,
clock: &dyn Clock, clock: &dyn Clock,
provider: UpstreamOAuthProvider provider: UpstreamOAuthProvider
) -> Result<(), Self::Error>; ) -> Result<UpstreamOAuthProvider, Self::Error>;
async fn list( async fn list(
&mut self, &mut self,

View File

@@ -1562,6 +1562,10 @@
"token_endpoint_auth_method" "token_endpoint_auth_method"
], ],
"properties": { "properties": {
"enabled": {
"description": "Whether this provider is enabled.\n\nDefaults to `true`",
"type": "boolean"
},
"id": { "id": {
"description": "A ULID as per https://github.com/ulid/spec", "description": "A ULID as per https://github.com/ulid/spec",
"type": "string", "type": "string",