diff --git a/crates/cli/src/sync.rs b/crates/cli/src/sync.rs index bdbb4e7f..b0d22cdc 100644 --- a/crates/cli/src/sync.rs +++ b/crates/cli/src/sync.rs @@ -110,7 +110,7 @@ pub async fn config_sync( .map(|p| p.id) .collect::>(); - let existing = repo.upstream_oauth_provider().all().await?; + let existing = repo.upstream_oauth_provider().all_enabled().await?; let existing_ids = existing.iter().map(|p| p.id).collect::>(); let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); if prune { diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index af0d26bc..f3eb76fd 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -141,10 +141,19 @@ pub struct UpstreamOAuthProvider { pub token_endpoint_signing_alg: Option, pub token_endpoint_auth_method: OAuthClientAuthenticationMethod, pub created_at: DateTime, + pub disabled_at: Option>, pub claims_imports: ClaimsImports, pub additional_authorization_parameters: Vec<(String, String)>, } +impl UpstreamOAuthProvider { + /// Returns `true` if the provider is enabled + #[must_use] + pub const fn enabled(&self) -> bool { + self.disabled_at.is_none() + } +} + /// Whether to set the email as verified when importing it from the upstream #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(rename_all = "lowercase")] diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 49f47664..3b4662d5 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -501,7 +501,9 @@ impl User { .transpose()?; let pagination = Pagination::try_new(before_id, after_id, first, last)?; - let filter = UpstreamOAuthLinkFilter::new().for_user(&self.0); + let filter = UpstreamOAuthLinkFilter::new() + .for_user(&self.0) + .enabled_providers_only(); let page = repo.upstream_oauth_link().list(filter, pagination).await?; diff --git a/crates/graphql/src/query/upstream_oauth.rs b/crates/graphql/src/query/upstream_oauth.rs index 520be2fd..b8a2d750 100644 --- a/crates/graphql/src/query/upstream_oauth.rs +++ b/crates/graphql/src/query/upstream_oauth.rs @@ -73,6 +73,11 @@ impl UpstreamOAuthQuery { return Ok(None); }; + // We only allow enabled providers to be fetched + if !provider.enabled() { + return Ok(None); + } + Ok(Some(UpstreamOAuth2Provider::new(provider))) } @@ -110,7 +115,9 @@ impl UpstreamOAuthQuery { .transpose()?; let pagination = Pagination::try_new(before_id, after_id, first, last)?; - let filter = UpstreamOAuthProviderFilter::new(); + // We only want enabled providers + // XXX: we may want to let admins see disabled providers + let filter = UpstreamOAuthProviderFilter::new().enabled_only(); let page = repo .upstream_oauth_provider() diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index d8171f96..71f5bf1c 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -20,6 +20,7 @@ use hyper::StatusCode; use mas_axum_utils::{ cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID, }; +use mas_data_model::UpstreamOAuthProvider; use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; use mas_storage::{ @@ -81,6 +82,7 @@ pub(crate) async fn get( .upstream_oauth_provider() .lookup(provider_id) .await? + .filter(UpstreamOAuthProvider::enabled) .ok_or(RouteError::ProviderNotFound)?; let http_service = http_client_factory.http_service("upstream_oauth2.authorize"); diff --git a/crates/handlers/src/upstream_oauth2/cache.rs b/crates/handlers/src/upstream_oauth2/cache.rs index 69fcc857..cd04260d 100644 --- a/crates/handlers/src/upstream_oauth2/cache.rs +++ b/crates/handlers/src/upstream_oauth2/cache.rs @@ -167,7 +167,7 @@ impl MetadataCache { interval: std::time::Duration, repository: &mut R, ) -> Result, R::Error> { - let providers = repository.upstream_oauth_provider().all().await?; + let providers = repository.upstream_oauth_provider().all_enabled().await?; for provider in providers { let verify = match provider.discovery_mode { @@ -504,6 +504,7 @@ mod tests { token_endpoint_signing_alg: None, token_endpoint_auth_method: OAuthClientAuthenticationMethod::None, created_at: clock.now(), + disabled_at: None, claims_imports: UpstreamOAuthProviderClaimsImports::default(), additional_authorization_parameters: Vec::new(), }; diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index c8ca954a..a49f245b 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -20,6 +20,7 @@ use hyper::StatusCode; use mas_axum_utils::{ cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID, }; +use mas_data_model::UpstreamOAuthProvider; use mas_keystore::{Encrypter, Keystore}; use mas_oidc_client::requests::{ authorization_code::AuthorizationValidationData, jose::JwtVerificationData, @@ -146,6 +147,7 @@ pub(crate) async fn get( .upstream_oauth_provider() .lookup(provider_id) .await? + .filter(UpstreamOAuthProvider::enabled) .ok_or(RouteError::ProviderNotFound)?; let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 90d0a497..a626a53c 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -78,7 +78,7 @@ pub(crate) async fn get( return Ok((cookie_jar, reply).into_response()); }; - let providers = repo.upstream_oauth_provider().all().await?; + let providers = repo.upstream_oauth_provider().all_enabled().await?; // If password-based login is disabled, and there is only one upstream provider, // we can directly start an authorization flow @@ -149,7 +149,7 @@ pub(crate) async fn post( }; if !state.is_valid() { - let providers = repo.upstream_oauth_provider().all().await?; + let providers = repo.upstream_oauth_provider().all_enabled().await?; let content = render( locale, LoginContext::default() diff --git a/crates/storage-pg/.sqlx/query-048eec775f4af3ffd805e830e8286c6a5745e523b76e1083d6bfced0035c2f76.json b/crates/storage-pg/.sqlx/query-048eec775f4af3ffd805e830e8286c6a5745e523b76e1083d6bfced0035c2f76.json new file mode 100644 index 00000000..707eead7 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-048eec775f4af3ffd805e830e8286c6a5745e523b76e1083d6bfced0035c2f76.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE upstream_oauth_providers\n SET disabled_at = $2\n WHERE upstream_oauth_provider_id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "048eec775f4af3ffd805e830e8286c6a5745e523b76e1083d6bfced0035c2f76" +} diff --git a/crates/storage-pg/.sqlx/query-d8d9e49227b7945b4c3bcd842b59a7af6a21f7e9e2d715dc6360c3d691373903.json b/crates/storage-pg/.sqlx/query-51b204376c63671a47b73ee8b3f8e669f90933f7e81ba744dca88d6bb94bf96a.json similarity index 76% rename from crates/storage-pg/.sqlx/query-d8d9e49227b7945b4c3bcd842b59a7af6a21f7e9e2d715dc6360c3d691373903.json rename to crates/storage-pg/.sqlx/query-51b204376c63671a47b73ee8b3f8e669f90933f7e81ba744dca88d6bb94bf96a.json index e923947b..ac9f681c 100644 --- a/crates/storage-pg/.sqlx/query-d8d9e49227b7945b4c3bcd842b59a7af6a21f7e9e2d715dc6360c3d691373903.json +++ b/crates/storage-pg/.sqlx/query-51b204376c63671a47b73ee8b3f8e669f90933f7e81ba744dca88d6bb94bf96a.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ", + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ", "describe": { "columns": [ { @@ -55,36 +55,41 @@ }, { "ordinal": 10, + "name": "disabled_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 11, "name": "claims_imports: Json", "type_info": "Jsonb" }, { - "ordinal": 11, + "ordinal": 12, "name": "jwks_uri_override", "type_info": "Text" }, { - "ordinal": 12, + "ordinal": 13, "name": "authorization_endpoint_override", "type_info": "Text" }, { - "ordinal": 13, + "ordinal": 14, "name": "token_endpoint_override", "type_info": "Text" }, { - "ordinal": 14, + "ordinal": 15, "name": "discovery_mode", "type_info": "Text" }, { - "ordinal": 15, + "ordinal": 16, "name": "pkce_mode", "type_info": "Text" }, { - "ordinal": 16, + "ordinal": 17, "name": "additional_parameters: Json>", "type_info": "Jsonb" } @@ -105,6 +110,7 @@ true, false, false, + true, false, true, true, @@ -114,5 +120,5 @@ true ] }, - "hash": "d8d9e49227b7945b4c3bcd842b59a7af6a21f7e9e2d715dc6360c3d691373903" + "hash": "51b204376c63671a47b73ee8b3f8e669f90933f7e81ba744dca88d6bb94bf96a" } diff --git a/crates/storage-pg/.sqlx/query-ccf4965aa84c497ac9759cb31f3ecba59fdf18085791f799dfd398bef4f8eb8c.json b/crates/storage-pg/.sqlx/query-5d9f3d47ce6164b3f81aa09ef4fd8d5cd070945fd497d209ac1df99abcfb7c5d.json similarity index 76% rename from crates/storage-pg/.sqlx/query-ccf4965aa84c497ac9759cb31f3ecba59fdf18085791f799dfd398bef4f8eb8c.json rename to crates/storage-pg/.sqlx/query-5d9f3d47ce6164b3f81aa09ef4fd8d5cd070945fd497d209ac1df99abcfb7c5d.json index a30ed5c7..f6f2b0db 100644 --- a/crates/storage-pg/.sqlx/query-ccf4965aa84c497ac9759cb31f3ecba59fdf18085791f799dfd398bef4f8eb8c.json +++ b/crates/storage-pg/.sqlx/query-5d9f3d47ce6164b3f81aa09ef4fd8d5cd070945fd497d209ac1df99abcfb7c5d.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n ", + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ", "describe": { "columns": [ { @@ -55,36 +55,41 @@ }, { "ordinal": 10, + "name": "disabled_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 11, "name": "claims_imports: Json", "type_info": "Jsonb" }, { - "ordinal": 11, + "ordinal": 12, "name": "jwks_uri_override", "type_info": "Text" }, { - "ordinal": 12, + "ordinal": 13, "name": "authorization_endpoint_override", "type_info": "Text" }, { - "ordinal": 13, + "ordinal": 14, "name": "token_endpoint_override", "type_info": "Text" }, { - "ordinal": 14, + "ordinal": 15, "name": "discovery_mode", "type_info": "Text" }, { - "ordinal": 15, + "ordinal": 16, "name": "pkce_mode", "type_info": "Text" }, { - "ordinal": 16, + "ordinal": 17, "name": "additional_parameters: Json>", "type_info": "Jsonb" } @@ -103,6 +108,7 @@ true, false, false, + true, false, true, true, @@ -112,5 +118,5 @@ true ] }, - "hash": "ccf4965aa84c497ac9759cb31f3ecba59fdf18085791f799dfd398bef4f8eb8c" + "hash": "5d9f3d47ce6164b3f81aa09ef4fd8d5cd070945fd497d209ac1df99abcfb7c5d" } diff --git a/crates/storage-pg/.sqlx/query-21132afc29be5394a03680dd27d2aff5e2249a973083c0675935dc658f73b1f4.json b/crates/storage-pg/.sqlx/query-94fd87e99088671b6a20bb7b9a3838ecce8df564257b348adf22f2e9356e6dae.json similarity index 70% rename from crates/storage-pg/.sqlx/query-21132afc29be5394a03680dd27d2aff5e2249a973083c0675935dc658f73b1f4.json rename to crates/storage-pg/.sqlx/query-94fd87e99088671b6a20bb7b9a3838ecce8df564257b348adf22f2e9356e6dae.json index 403bee8e..64314392 100644 --- a/crates/storage-pg/.sqlx/query-21132afc29be5394a03680dd27d2aff5e2249a973083c0675935dc658f73b1f4.json +++ b/crates/storage-pg/.sqlx/query-94fd87e99088671b6a20bb7b9a3838ecce8df564257b348adf22f2e9356e6dae.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ", + "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ", "describe": { "columns": [ { @@ -34,5 +34,5 @@ false ] }, - "hash": "21132afc29be5394a03680dd27d2aff5e2249a973083c0675935dc658f73b1f4" + "hash": "94fd87e99088671b6a20bb7b9a3838ecce8df564257b348adf22f2e9356e6dae" } diff --git a/crates/storage-pg/migrations/20240402084854_upstream_oauth_disabled_at.sql b/crates/storage-pg/migrations/20240402084854_upstream_oauth_disabled_at.sql new file mode 100644 index 00000000..d1469198 --- /dev/null +++ b/crates/storage-pg/migrations/20240402084854_upstream_oauth_disabled_at.sql @@ -0,0 +1,18 @@ +-- Copyright 2024 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + + +-- Adds a `disabled_at` column to the `upstream_oauth_providers` table, to soft-delete providers. +ALTER TABLE "upstream_oauth_providers" + ADD COLUMN "disabled_at" TIMESTAMP WITH TIME ZONE; diff --git a/crates/storage-pg/src/iden.rs b/crates/storage-pg/src/iden.rs index 18f745e5..7b58a4a9 100644 --- a/crates/storage-pg/src/iden.rs +++ b/crates/storage-pg/src/iden.rs @@ -107,6 +107,7 @@ pub enum UpstreamOAuthProviders { TokenEndpointSigningAlg, TokenEndpointAuthMethod, CreatedAt, + DisabledAt, ClaimsImports, DiscoveryMode, PkceMode, diff --git a/crates/storage-pg/src/upstream_oauth2/link.rs b/crates/storage-pg/src/upstream_oauth2/link.rs index 1ccdf062..05f9c236 100644 --- a/crates/storage-pg/src/upstream_oauth2/link.rs +++ b/crates/storage-pg/src/upstream_oauth2/link.rs @@ -27,7 +27,10 @@ use ulid::Ulid; use uuid::Uuid; use crate::{ - iden::UpstreamOAuthLinks, pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, + iden::{UpstreamOAuthLinks, UpstreamOAuthProviders}, + pagination::QueryBuilderExt, + tracing::ExecuteExt, + DatabaseError, }; /// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL @@ -280,6 +283,29 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { )) .eq(Uuid::from(provider.id)) })) + .and_where_option(filter.provider_enabled().map(|enabled| { + Expr::col(( + UpstreamOAuthLinks::Table, + UpstreamOAuthLinks::UpstreamOAuthProviderId, + )) + .eq(Expr::any( + Query::select() + .expr(Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::UpstreamOAuthProviderId, + ))) + .from(UpstreamOAuthProviders::Table) + .and_where( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::DisabledAt, + )) + .is_null() + .eq(enabled), + ) + .take(), + )) + })) .generate_pagination( ( UpstreamOAuthLinks::Table, @@ -328,6 +354,29 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { )) .eq(Uuid::from(provider.id)) })) + .and_where_option(filter.provider_enabled().map(|enabled| { + Expr::col(( + UpstreamOAuthLinks::Table, + UpstreamOAuthLinks::UpstreamOAuthProviderId, + )) + .eq(Expr::any( + Query::select() + .expr(Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::UpstreamOAuthProviderId, + ))) + .from(UpstreamOAuthProviders::Table) + .and_where( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::DisabledAt, + )) + .is_null() + .eq(enabled), + ) + .take(), + )) + })) .build_sqlx(PostgresQueryBuilder); let count: i64 = sqlx::query_scalar_with(&sql, arguments) diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index 06596ffe..61cf23c4 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -51,7 +51,7 @@ mod tests { let mut repo = PgRepository::from_pool(&pool).await.unwrap(); // The provider list should be empty at the start - let all_providers = repo.upstream_oauth_provider().all().await.unwrap(); + let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap(); assert!(all_providers.is_empty()); // Let's add a provider @@ -93,7 +93,7 @@ mod tests { assert_eq!(provider.client_id, "client-id"); // It should be in the list of all providers - let providers = repo.upstream_oauth_provider().all().await.unwrap(); + let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap(); assert_eq!(providers.len(), 1); assert_eq!(providers[0].issuer, "https://example.com/"); assert_eq!(providers[0].client_id, "client-id"); @@ -192,7 +192,8 @@ mod tests { // XXX: we should also try other combinations of the filter let filter = UpstreamOAuthLinkFilter::new() .for_user(&user) - .for_provider(&provider); + .for_provider(&provider) + .enabled_providers_only(); let links = repo .upstream_oauth_link() @@ -207,13 +208,70 @@ mod tests { assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1); + // There should be exactly one enabled provider + assert_eq!( + repo.upstream_oauth_provider() + .count(UpstreamOAuthProviderFilter::new()) + .await + .unwrap(), + 1 + ); + assert_eq!( + repo.upstream_oauth_provider() + .count(UpstreamOAuthProviderFilter::new().enabled_only()) + .await + .unwrap(), + 1 + ); + assert_eq!( + repo.upstream_oauth_provider() + .count(UpstreamOAuthProviderFilter::new().disabled_only()) + .await + .unwrap(), + 0 + ); + + // Disable the provider + repo.upstream_oauth_provider() + .disable(&clock, provider.clone()) + .await + .unwrap(); + + // There should be exactly one disabled provider + assert_eq!( + repo.upstream_oauth_provider() + .count(UpstreamOAuthProviderFilter::new()) + .await + .unwrap(), + 1 + ); + assert_eq!( + repo.upstream_oauth_provider() + .count(UpstreamOAuthProviderFilter::new().enabled_only()) + .await + .unwrap(), + 0 + ); + assert_eq!( + repo.upstream_oauth_provider() + .count(UpstreamOAuthProviderFilter::new().disabled_only()) + .await + .unwrap(), + 1 + ); + // Try deleting the provider repo.upstream_oauth_provider() .delete(provider) .await .unwrap(); - let providers = repo.upstream_oauth_provider().all().await.unwrap(); - assert!(providers.is_empty()); + assert_eq!( + repo.upstream_oauth_provider() + .count(UpstreamOAuthProviderFilter::new()) + .await + .unwrap(), + 0 + ); } /// Test that the pagination works as expected in the upstream OAuth @@ -287,6 +345,16 @@ mod tests { let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); assert_eq!(&edge_ids, &ids[..10]); + // Getting the same page with the "enabled only" filter should return the same + // results + let other_page = repo + .upstream_oauth_provider() + .list(filter.enabled_only(), Pagination::first(10)) + .await + .unwrap(); + + assert_eq!(page, other_page); + // Lookup the next 10 items let page = repo .upstream_oauth_provider() @@ -334,5 +402,17 @@ mod tests { assert!(!page.has_next_page); let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect(); assert_eq!(&edge_ids, &ids[6..8]); + + // There should not be any disabled providers + assert!(repo + .upstream_oauth_provider() + .list( + UpstreamOAuthProviderFilter::new().disabled_only(), + Pagination::first(1) + ) + .await + .unwrap() + .edges + .is_empty()); } } diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index dcd6e202..5da62444 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -62,6 +62,7 @@ struct ProviderLookup { token_endpoint_signing_alg: Option, token_endpoint_auth_method: String, created_at: DateTime, + disabled_at: Option>, claims_imports: Json, jwks_uri_override: Option, authorization_endpoint_override: Option, @@ -161,6 +162,7 @@ impl TryFrom for UpstreamOAuthProvider { token_endpoint_auth_method, token_endpoint_signing_alg, created_at: value.created_at, + disabled_at: value.disabled_at, claims_imports: value.claims_imports.0, authorization_endpoint_override, token_endpoint_override, @@ -200,6 +202,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' token_endpoint_signing_alg, token_endpoint_auth_method, created_at, + disabled_at, claims_imports as "claims_imports: Json", jwks_uri_override, authorization_endpoint_override, @@ -308,6 +311,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' token_endpoint_signing_alg: params.token_endpoint_signing_alg, token_endpoint_auth_method: params.token_endpoint_auth_method, created_at, + disabled_at: None, claims_imports: params.claims_imports, authorization_endpoint_override: params.authorization_endpoint_override, token_endpoint_override: params.token_endpoint_override, @@ -434,6 +438,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' scope = EXCLUDED.scope, token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method, token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg, + disabled_at = NULL, client_id = EXCLUDED.client_id, encrypted_client_secret = EXCLUDED.encrypted_client_secret, claims_imports = EXCLUDED.claims_imports, @@ -487,6 +492,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' token_endpoint_signing_alg: params.token_endpoint_signing_alg, token_endpoint_auth_method: params.token_endpoint_auth_method, created_at, + disabled_at: None, claims_imports: params.claims_imports, authorization_endpoint_override: params.authorization_endpoint_override, token_endpoint_override: params.token_endpoint_override, @@ -497,6 +503,37 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' }) } + #[tracing::instrument( + name = "db.upstream_oauth_provider.disable", + skip_all, + fields( + db.statement, + %upstream_oauth_provider.id, + ), + err, + )] + async fn disable( + &mut self, + clock: &dyn Clock, + upstream_oauth_provider: UpstreamOAuthProvider, + ) -> Result<(), Self::Error> { + let disabled_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE upstream_oauth_providers + SET disabled_at = $2 + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(upstream_oauth_provider.id), + disabled_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1) + } + #[tracing::instrument( name = "db.upstream_oauth_provider.list", skip_all, @@ -507,10 +544,9 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' )] async fn list( &mut self, - _filter: UpstreamOAuthProviderFilter<'_>, + filter: UpstreamOAuthProviderFilter<'_>, pagination: Pagination, ) -> Result, Self::Error> { - // XXX: the filter is currently ignored, as it does not have any fields let (sql, arguments) = Query::select() .expr_as( Expr::col(( @@ -579,6 +615,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' )), ProviderLookupIden::CreatedAt, ) + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::DisabledAt, + )), + ProviderLookupIden::DisabledAt, + ) .expr_as( Expr::col(( UpstreamOAuthProviders::Table, @@ -629,6 +672,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' ProviderLookupIden::AdditionalParameters, ) .from(UpstreamOAuthProviders::Table) + .and_where_option(filter.enabled().map(|enabled| { + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::DisabledAt, + )) + .is_null() + .eq(enabled) + })) .generate_pagination( ( UpstreamOAuthProviders::Table, @@ -660,9 +711,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' )] async fn count( &mut self, - _filter: UpstreamOAuthProviderFilter<'_>, + filter: UpstreamOAuthProviderFilter<'_>, ) -> Result { - // XXX: the filter is currently ignored, as it does not have any fields let (sql, arguments) = Query::select() .expr( Expr::col(( @@ -672,6 +722,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' .count(), ) .from(UpstreamOAuthProviders::Table) + .and_where_option(filter.enabled().map(|enabled| { + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::DisabledAt, + )) + .is_null() + .eq(enabled) + })) .build_sqlx(PostgresQueryBuilder); let count: i64 = sqlx::query_scalar_with(&sql, arguments) @@ -685,14 +743,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' } #[tracing::instrument( - name = "db.upstream_oauth_provider.all", + name = "db.upstream_oauth_provider.all_enabled", skip_all, fields( db.statement, ), err, )] - async fn all(&mut self) -> Result, Self::Error> { + async fn all_enabled(&mut self) -> Result, Self::Error> { let res = sqlx::query_as!( ProviderLookup, r#" @@ -707,6 +765,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' token_endpoint_signing_alg, token_endpoint_auth_method, created_at, + disabled_at, claims_imports as "claims_imports: Json", jwks_uri_override, authorization_endpoint_override, @@ -715,6 +774,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' pkce_mode, additional_parameters as "additional_parameters: Json>" FROM upstream_oauth_providers + WHERE disabled_at IS NULL "#, ) .traced() diff --git a/crates/storage/src/pagination.rs b/crates/storage/src/pagination.rs index d8d8bc1c..e3efcf0b 100644 --- a/crates/storage/src/pagination.rs +++ b/crates/storage/src/pagination.rs @@ -137,6 +137,7 @@ impl Pagination { } /// A page of results returned by a paginated query +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Page { /// When paginating forwards, this is true if there are more items after pub has_next_page: bool, diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index b22d760e..a54350c5 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -25,6 +25,7 @@ pub struct UpstreamOAuthLinkFilter<'a> { // XXX: we might also want to filter for links without a user linked to them user: Option<&'a User>, provider: Option<&'a UpstreamOAuthProvider>, + provider_enabled: Option, } impl<'a> UpstreamOAuthLinkFilter<'a> { @@ -63,6 +64,26 @@ impl<'a> UpstreamOAuthLinkFilter<'a> { pub fn provider(&self) -> Option<&UpstreamOAuthProvider> { self.provider } + + /// Set whether to filter for enabled providers + #[must_use] + pub const fn enabled_providers_only(mut self) -> Self { + self.provider_enabled = Some(true); + self + } + + /// Set whether to filter for disabled providers + #[must_use] + pub const fn disabled_providers_only(mut self) -> Self { + self.provider_enabled = Some(false); + self + } + + /// Get the provider enabled filter + #[must_use] + pub const fn provider_enabled(&self) -> Option { + self.provider_enabled + } } /// An [`UpstreamOAuthLinkRepository`] helps interacting with diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 3458bfa8..af6c0b58 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -1,4 +1,4 @@ -// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// Copyright 2022-2024 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -82,6 +82,11 @@ pub struct UpstreamOAuthProviderParams { /// Filter parameters for listing upstream OAuth 2.0 providers #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] pub struct UpstreamOAuthProviderFilter<'a> { + /// Filter by whether the provider is enabled + /// + /// If `None`, all providers are returned + enabled: Option, + _lifetime: PhantomData<&'a ()>, } @@ -91,6 +96,28 @@ impl<'a> UpstreamOAuthProviderFilter<'a> { pub fn new() -> Self { Self::default() } + + /// Return only enabled providers + #[must_use] + pub const fn enabled_only(mut self) -> Self { + self.enabled = Some(true); + self + } + + /// Return only disabled providers + #[must_use] + pub const fn disabled_only(mut self) -> Self { + self.enabled = Some(false); + self + } + + /// Get the enabled filter + /// + /// Returns `None` if the filter is not set + #[must_use] + pub const fn enabled(&self) -> Option { + self.enabled + } } /// An [`UpstreamOAuthProviderRepository`] helps interacting with @@ -175,6 +202,22 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { params: UpstreamOAuthProviderParams, ) -> Result; + /// Disable an upstream OAuth provider + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `provider`: The provider to disable + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn disable( + &mut self, + clock: &dyn Clock, + provider: UpstreamOAuthProvider, + ) -> Result<(), Self::Error>; + /// List [`UpstreamOAuthProvider`] with the given filter and pagination /// /// # Parameters @@ -205,12 +248,12 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { filter: UpstreamOAuthProviderFilter<'_>, ) -> Result; - /// Get all upstream OAuth providers + /// Get all enabled upstream OAuth providers /// /// # Errors /// /// Returns [`Self::Error`] if the underlying repository fails - async fn all(&mut self) -> Result, Self::Error>; + async fn all_enabled(&mut self) -> Result, Self::Error>; } repository_impl!(UpstreamOAuthProviderRepository: @@ -234,6 +277,12 @@ repository_impl!(UpstreamOAuthProviderRepository: async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>; + async fn disable( + &mut self, + clock: &dyn Clock, + provider: UpstreamOAuthProvider + ) -> Result<(), Self::Error>; + async fn list( &mut self, filter: UpstreamOAuthProviderFilter<'_>, @@ -245,5 +294,5 @@ repository_impl!(UpstreamOAuthProviderRepository: filter: UpstreamOAuthProviderFilter<'_> ) -> Result; - async fn all(&mut self) -> Result, Self::Error>; + async fn all_enabled(&mut self) -> Result, Self::Error>; );