You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-04 18:42:14 +03:00
Soft-delete upstream OAuth 2.0 providers on config sync
This commit is contained in:
@@ -14,11 +14,14 @@
|
||||
|
||||
//! Utilities to synchronize the configuration file with the database.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
use mas_config::{ClientsConfig, UpstreamOAuth2Config};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderParams, Clock, RepositoryAccess};
|
||||
use mas_storage::{
|
||||
upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams},
|
||||
Clock, Pagination, RepositoryAccess,
|
||||
};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use sqlx::{postgres::PgAdvisoryLock, Connection, PgConnection};
|
||||
use tracing::{error, info, info_span, warn};
|
||||
@@ -107,35 +110,83 @@ pub async fn config_sync(
|
||||
let config_ids = upstream_oauth2_config
|
||||
.providers
|
||||
.iter()
|
||||
.filter(|p| p.enabled)
|
||||
.map(|p| p.id)
|
||||
.collect::<HashSet<_>>();
|
||||
.collect::<BTreeSet<_>>();
|
||||
|
||||
// Let's assume we have less than 1000 providers
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list(
|
||||
UpstreamOAuthProviderFilter::default(),
|
||||
Pagination::first(1000),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// A warning is probably enough
|
||||
if page.has_next_page {
|
||||
warn!(
|
||||
"More than 1000 providers in the database, only the first 1000 will be considered"
|
||||
);
|
||||
}
|
||||
|
||||
let mut existing_enabled_ids = BTreeSet::new();
|
||||
let mut existing_disabled = BTreeMap::new();
|
||||
// Process the existing providers
|
||||
for provider in page.edges {
|
||||
if provider.enabled() {
|
||||
if config_ids.contains(&provider.id) {
|
||||
existing_enabled_ids.insert(provider.id);
|
||||
} else {
|
||||
// Provider is enabled in the database but not in the config
|
||||
info!(%provider.id, "Disabling provider");
|
||||
|
||||
let provider = if dry_run {
|
||||
provider
|
||||
} else {
|
||||
repo.upstream_oauth_provider()
|
||||
.disable(clock, provider)
|
||||
.await?
|
||||
};
|
||||
|
||||
existing_disabled.insert(provider.id, provider);
|
||||
}
|
||||
} else {
|
||||
existing_disabled.insert(provider.id, provider);
|
||||
}
|
||||
}
|
||||
|
||||
let existing = repo.upstream_oauth_provider().all_enabled().await?;
|
||||
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
||||
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
||||
if prune {
|
||||
for provider in to_delete {
|
||||
info!(%provider.id, "Deleting provider");
|
||||
for provider_id in existing_disabled.keys().copied() {
|
||||
info!(provider.id = %provider_id, "Deleting provider");
|
||||
|
||||
if dry_run {
|
||||
continue;
|
||||
}
|
||||
|
||||
repo.upstream_oauth_provider().delete(provider).await?;
|
||||
repo.upstream_oauth_provider()
|
||||
.delete_by_id(provider_id)
|
||||
.await?;
|
||||
}
|
||||
} else {
|
||||
let len = to_delete.count();
|
||||
let len = existing_disabled.len();
|
||||
match len {
|
||||
0 => {},
|
||||
1 => warn!("A provider in the database is not in the config. Run `mas-cli config sync --prune` to delete it."),
|
||||
n => warn!("{n} providers in the database are not in the config. Run `mas-cli config sync --prune` to delete them."),
|
||||
1 => warn!("A provider is soft-deleted in the database. Run `mas-cli config sync --prune` to delete it."),
|
||||
n => warn!("{n} providers are soft-deleted in the database. Run `mas-cli config sync --prune` to delete them."),
|
||||
}
|
||||
}
|
||||
|
||||
for provider in upstream_oauth2_config.providers {
|
||||
if !provider.enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
let _span = info_span!("provider", %provider.id).entered();
|
||||
if existing_ids.contains(&provider.id) {
|
||||
if existing_enabled_ids.contains(&provider.id) {
|
||||
info!("Updating provider");
|
||||
} else if existing_disabled.contains_key(&provider.id) {
|
||||
info!("Enabling and updating provider");
|
||||
} else {
|
||||
info!("Adding provider");
|
||||
}
|
||||
@@ -224,10 +275,10 @@ pub async fn config_sync(
|
||||
let config_ids = clients_config
|
||||
.iter()
|
||||
.map(|c| c.client_id)
|
||||
.collect::<HashSet<_>>();
|
||||
.collect::<BTreeSet<_>>();
|
||||
|
||||
let existing = repo.oauth2_client().all_static().await?;
|
||||
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
||||
let existing_ids = existing.iter().map(|p| p.id).collect::<BTreeSet<_>>();
|
||||
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
||||
if prune {
|
||||
for client in to_delete {
|
||||
|
@@ -342,9 +342,24 @@ impl PkceMethod {
|
||||
}
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[allow(clippy::trivially_copy_pass_by_ref)]
|
||||
fn is_default_true(value: &bool) -> bool {
|
||||
*value
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct Provider {
|
||||
/// Whether this provider is enabled.
|
||||
///
|
||||
/// Defaults to `true`
|
||||
#[serde(default = "default_true", skip_serializing_if = "is_default_true")]
|
||||
pub enabled: bool,
|
||||
|
||||
/// An internal unique identifier for this provider
|
||||
#[schemars(
|
||||
with = "String",
|
||||
|
@@ -146,6 +146,18 @@ pub struct UpstreamOAuthProvider {
|
||||
pub additional_authorization_parameters: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl PartialOrd for UpstreamOAuthProvider {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.id.cmp(&other.id))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for UpstreamOAuthProvider {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.id.cmp(&other.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl UpstreamOAuthProvider {
|
||||
/// Returns `true` if the provider is enabled
|
||||
#[must_use]
|
||||
|
@@ -515,8 +515,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
async fn disable(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
upstream_oauth_provider: UpstreamOAuthProvider,
|
||||
) -> Result<(), Self::Error> {
|
||||
mut upstream_oauth_provider: UpstreamOAuthProvider,
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error> {
|
||||
let disabled_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
@@ -531,7 +531,11 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
upstream_oauth_provider.disabled_at = Some(disabled_at);
|
||||
|
||||
Ok(upstream_oauth_provider)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
|
@@ -204,6 +204,8 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
||||
|
||||
/// Disable an upstream OAuth provider
|
||||
///
|
||||
/// Returns the disabled provider
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `clock`: The clock used to generate timestamps
|
||||
@@ -216,7 +218,7 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
provider: UpstreamOAuthProvider,
|
||||
) -> Result<(), Self::Error>;
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error>;
|
||||
|
||||
/// List [`UpstreamOAuthProvider`] with the given filter and pagination
|
||||
///
|
||||
@@ -281,7 +283,7 @@ repository_impl!(UpstreamOAuthProviderRepository:
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
provider: UpstreamOAuthProvider
|
||||
) -> Result<(), Self::Error>;
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error>;
|
||||
|
||||
async fn list(
|
||||
&mut self,
|
||||
|
@@ -1562,6 +1562,10 @@
|
||||
"token_endpoint_auth_method"
|
||||
],
|
||||
"properties": {
|
||||
"enabled": {
|
||||
"description": "Whether this provider is enabled.\n\nDefaults to `true`",
|
||||
"type": "boolean"
|
||||
},
|
||||
"id": {
|
||||
"description": "A ULID as per https://github.com/ulid/spec",
|
||||
"type": "string",
|
||||
|
Reference in New Issue
Block a user