1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Add a soft-deletion column on upstream OAuth 2.0 providers

This commit is contained in:
Quentin Gliech
2024-04-02 14:38:54 +02:00
parent 58fd6ab4c1
commit 4e3823fe4f
20 changed files with 369 additions and 40 deletions

View File

@@ -110,7 +110,7 @@ pub async fn config_sync(
.map(|p| p.id)
.collect::<HashSet<_>>();
let existing = repo.upstream_oauth_provider().all().await?;
let existing = repo.upstream_oauth_provider().all_enabled().await?;
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
if prune {

View File

@@ -141,10 +141,19 @@ pub struct UpstreamOAuthProvider {
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
pub created_at: DateTime<Utc>,
pub disabled_at: Option<DateTime<Utc>>,
pub claims_imports: ClaimsImports,
pub additional_authorization_parameters: Vec<(String, String)>,
}
impl UpstreamOAuthProvider {
/// Returns `true` if the provider is enabled
#[must_use]
pub const fn enabled(&self) -> bool {
self.disabled_at.is_none()
}
}
/// Whether to set the email as verified when importing it from the upstream
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]

View File

@@ -501,7 +501,9 @@ impl User {
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let filter = UpstreamOAuthLinkFilter::new().for_user(&self.0);
let filter = UpstreamOAuthLinkFilter::new()
.for_user(&self.0)
.enabled_providers_only();
let page = repo.upstream_oauth_link().list(filter, pagination).await?;

View File

@@ -73,6 +73,11 @@ impl UpstreamOAuthQuery {
return Ok(None);
};
// We only allow enabled providers to be fetched
if !provider.enabled() {
return Ok(None);
}
Ok(Some(UpstreamOAuth2Provider::new(provider)))
}
@@ -110,7 +115,9 @@ impl UpstreamOAuthQuery {
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let filter = UpstreamOAuthProviderFilter::new();
// We only want enabled providers
// XXX: we may want to let admins see disabled providers
let filter = UpstreamOAuthProviderFilter::new().enabled_only();
let page = repo
.upstream_oauth_provider()

View File

@@ -20,6 +20,7 @@ use hyper::StatusCode;
use mas_axum_utils::{
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID,
};
use mas_data_model::UpstreamOAuthProvider;
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::UrlBuilder;
use mas_storage::{
@@ -81,6 +82,7 @@ pub(crate) async fn get(
.upstream_oauth_provider()
.lookup(provider_id)
.await?
.filter(UpstreamOAuthProvider::enabled)
.ok_or(RouteError::ProviderNotFound)?;
let http_service = http_client_factory.http_service("upstream_oauth2.authorize");

View File

@@ -167,7 +167,7 @@ impl MetadataCache {
interval: std::time::Duration,
repository: &mut R,
) -> Result<tokio::task::JoinHandle<()>, R::Error> {
let providers = repository.upstream_oauth_provider().all().await?;
let providers = repository.upstream_oauth_provider().all_enabled().await?;
for provider in providers {
let verify = match provider.discovery_mode {
@@ -504,6 +504,7 @@ mod tests {
token_endpoint_signing_alg: None,
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
created_at: clock.now(),
disabled_at: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
additional_authorization_parameters: Vec::new(),
};

View File

@@ -20,6 +20,7 @@ use hyper::StatusCode;
use mas_axum_utils::{
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID,
};
use mas_data_model::UpstreamOAuthProvider;
use mas_keystore::{Encrypter, Keystore};
use mas_oidc_client::requests::{
authorization_code::AuthorizationValidationData, jose::JwtVerificationData,
@@ -146,6 +147,7 @@ pub(crate) async fn get(
.upstream_oauth_provider()
.lookup(provider_id)
.await?
.filter(UpstreamOAuthProvider::enabled)
.ok_or(RouteError::ProviderNotFound)?;
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);

View File

@@ -78,7 +78,7 @@ pub(crate) async fn get(
return Ok((cookie_jar, reply).into_response());
};
let providers = repo.upstream_oauth_provider().all().await?;
let providers = repo.upstream_oauth_provider().all_enabled().await?;
// If password-based login is disabled, and there is only one upstream provider,
// we can directly start an authorization flow
@@ -149,7 +149,7 @@ pub(crate) async fn post(
};
if !state.is_valid() {
let providers = repo.upstream_oauth_provider().all().await?;
let providers = repo.upstream_oauth_provider().all_enabled().await?;
let content = render(
locale,
LoginContext::default()

View File

@@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE upstream_oauth_providers\n SET disabled_at = $2\n WHERE upstream_oauth_provider_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz"
]
},
"nullable": []
},
"hash": "048eec775f4af3ffd805e830e8286c6a5745e523b76e1083d6bfced0035c2f76"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
"describe": {
"columns": [
{
@@ -55,36 +55,41 @@
},
{
"ordinal": 10,
"name": "disabled_at",
"type_info": "Timestamptz"
},
{
"ordinal": 11,
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
"type_info": "Jsonb"
},
{
"ordinal": 11,
"ordinal": 12,
"name": "jwks_uri_override",
"type_info": "Text"
},
{
"ordinal": 12,
"ordinal": 13,
"name": "authorization_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 13,
"ordinal": 14,
"name": "token_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 14,
"ordinal": 15,
"name": "discovery_mode",
"type_info": "Text"
},
{
"ordinal": 15,
"ordinal": 16,
"name": "pkce_mode",
"type_info": "Text"
},
{
"ordinal": 16,
"ordinal": 17,
"name": "additional_parameters: Json<Vec<(String, String)>>",
"type_info": "Jsonb"
}
@@ -105,6 +110,7 @@
true,
false,
false,
true,
false,
true,
true,
@@ -114,5 +120,5 @@
true
]
},
"hash": "d8d9e49227b7945b4c3bcd842b59a7af6a21f7e9e2d715dc6360c3d691373903"
"hash": "51b204376c63671a47b73ee8b3f8e669f90933f7e81ba744dca88d6bb94bf96a"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n ",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ",
"describe": {
"columns": [
{
@@ -55,36 +55,41 @@
},
{
"ordinal": 10,
"name": "disabled_at",
"type_info": "Timestamptz"
},
{
"ordinal": 11,
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
"type_info": "Jsonb"
},
{
"ordinal": 11,
"ordinal": 12,
"name": "jwks_uri_override",
"type_info": "Text"
},
{
"ordinal": 12,
"ordinal": 13,
"name": "authorization_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 13,
"ordinal": 14,
"name": "token_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 14,
"ordinal": 15,
"name": "discovery_mode",
"type_info": "Text"
},
{
"ordinal": 15,
"ordinal": 16,
"name": "pkce_mode",
"type_info": "Text"
},
{
"ordinal": 16,
"ordinal": 17,
"name": "additional_parameters: Json<Vec<(String, String)>>",
"type_info": "Jsonb"
}
@@ -103,6 +108,7 @@
true,
false,
false,
true,
false,
true,
true,
@@ -112,5 +118,5 @@
true
]
},
"hash": "ccf4965aa84c497ac9759cb31f3ecba59fdf18085791f799dfd398bef4f8eb8c"
"hash": "5d9f3d47ce6164b3f81aa09ef4fd8d5cd070945fd497d209ac1df99abcfb7c5d"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ",
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ",
"describe": {
"columns": [
{
@@ -34,5 +34,5 @@
false
]
},
"hash": "21132afc29be5394a03680dd27d2aff5e2249a973083c0675935dc658f73b1f4"
"hash": "94fd87e99088671b6a20bb7b9a3838ecce8df564257b348adf22f2e9356e6dae"
}

View File

@@ -0,0 +1,18 @@
-- Copyright 2024 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- Adds a `disabled_at` column to the `upstream_oauth_providers` table, to soft-delete providers.
ALTER TABLE "upstream_oauth_providers"
ADD COLUMN "disabled_at" TIMESTAMP WITH TIME ZONE;

View File

@@ -107,6 +107,7 @@ pub enum UpstreamOAuthProviders {
TokenEndpointSigningAlg,
TokenEndpointAuthMethod,
CreatedAt,
DisabledAt,
ClaimsImports,
DiscoveryMode,
PkceMode,

View File

@@ -27,7 +27,10 @@ use ulid::Ulid;
use uuid::Uuid;
use crate::{
iden::UpstreamOAuthLinks, pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError,
iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
pagination::QueryBuilderExt,
tracing::ExecuteExt,
DatabaseError,
};
/// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL
@@ -280,6 +283,29 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
))
.eq(Uuid::from(provider.id))
}))
.and_where_option(filter.provider_enabled().map(|enabled| {
Expr::col((
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::UpstreamOAuthProviderId,
))
.eq(Expr::any(
Query::select()
.expr(Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::UpstreamOAuthProviderId,
)))
.from(UpstreamOAuthProviders::Table)
.and_where(
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::DisabledAt,
))
.is_null()
.eq(enabled),
)
.take(),
))
}))
.generate_pagination(
(
UpstreamOAuthLinks::Table,
@@ -328,6 +354,29 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
))
.eq(Uuid::from(provider.id))
}))
.and_where_option(filter.provider_enabled().map(|enabled| {
Expr::col((
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::UpstreamOAuthProviderId,
))
.eq(Expr::any(
Query::select()
.expr(Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::UpstreamOAuthProviderId,
)))
.from(UpstreamOAuthProviders::Table)
.and_where(
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::DisabledAt,
))
.is_null()
.eq(enabled),
)
.take(),
))
}))
.build_sqlx(PostgresQueryBuilder);
let count: i64 = sqlx::query_scalar_with(&sql, arguments)

View File

@@ -51,7 +51,7 @@ mod tests {
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
// The provider list should be empty at the start
let all_providers = repo.upstream_oauth_provider().all().await.unwrap();
let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
assert!(all_providers.is_empty());
// Let's add a provider
@@ -93,7 +93,7 @@ mod tests {
assert_eq!(provider.client_id, "client-id");
// It should be in the list of all providers
let providers = repo.upstream_oauth_provider().all().await.unwrap();
let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
assert_eq!(providers.len(), 1);
assert_eq!(providers[0].issuer, "https://example.com/");
assert_eq!(providers[0].client_id, "client-id");
@@ -192,7 +192,8 @@ mod tests {
// XXX: we should also try other combinations of the filter
let filter = UpstreamOAuthLinkFilter::new()
.for_user(&user)
.for_provider(&provider);
.for_provider(&provider)
.enabled_providers_only();
let links = repo
.upstream_oauth_link()
@@ -207,13 +208,70 @@ mod tests {
assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);
// There should be exactly one enabled provider
assert_eq!(
repo.upstream_oauth_provider()
.count(UpstreamOAuthProviderFilter::new())
.await
.unwrap(),
1
);
assert_eq!(
repo.upstream_oauth_provider()
.count(UpstreamOAuthProviderFilter::new().enabled_only())
.await
.unwrap(),
1
);
assert_eq!(
repo.upstream_oauth_provider()
.count(UpstreamOAuthProviderFilter::new().disabled_only())
.await
.unwrap(),
0
);
// Disable the provider
repo.upstream_oauth_provider()
.disable(&clock, provider.clone())
.await
.unwrap();
// There should be exactly one disabled provider
assert_eq!(
repo.upstream_oauth_provider()
.count(UpstreamOAuthProviderFilter::new())
.await
.unwrap(),
1
);
assert_eq!(
repo.upstream_oauth_provider()
.count(UpstreamOAuthProviderFilter::new().enabled_only())
.await
.unwrap(),
0
);
assert_eq!(
repo.upstream_oauth_provider()
.count(UpstreamOAuthProviderFilter::new().disabled_only())
.await
.unwrap(),
1
);
// Try deleting the provider
repo.upstream_oauth_provider()
.delete(provider)
.await
.unwrap();
let providers = repo.upstream_oauth_provider().all().await.unwrap();
assert!(providers.is_empty());
assert_eq!(
repo.upstream_oauth_provider()
.count(UpstreamOAuthProviderFilter::new())
.await
.unwrap(),
0
);
}
/// Test that the pagination works as expected in the upstream OAuth
@@ -287,6 +345,16 @@ mod tests {
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[..10]);
// Getting the same page with the "enabled only" filter should return the same
// results
let other_page = repo
.upstream_oauth_provider()
.list(filter.enabled_only(), Pagination::first(10))
.await
.unwrap();
assert_eq!(page, other_page);
// Lookup the next 10 items
let page = repo
.upstream_oauth_provider()
@@ -334,5 +402,17 @@ mod tests {
assert!(!page.has_next_page);
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
assert_eq!(&edge_ids, &ids[6..8]);
// There should not be any disabled providers
assert!(repo
.upstream_oauth_provider()
.list(
UpstreamOAuthProviderFilter::new().disabled_only(),
Pagination::first(1)
)
.await
.unwrap()
.edges
.is_empty());
}
}

View File

@@ -62,6 +62,7 @@ struct ProviderLookup {
token_endpoint_signing_alg: Option<String>,
token_endpoint_auth_method: String,
created_at: DateTime<Utc>,
disabled_at: Option<DateTime<Utc>>,
claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
jwks_uri_override: Option<String>,
authorization_endpoint_override: Option<String>,
@@ -161,6 +162,7 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
token_endpoint_auth_method,
token_endpoint_signing_alg,
created_at: value.created_at,
disabled_at: value.disabled_at,
claims_imports: value.claims_imports.0,
authorization_endpoint_override,
token_endpoint_override,
@@ -200,6 +202,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at,
disabled_at,
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
jwks_uri_override,
authorization_endpoint_override,
@@ -308,6 +311,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
token_endpoint_auth_method: params.token_endpoint_auth_method,
created_at,
disabled_at: None,
claims_imports: params.claims_imports,
authorization_endpoint_override: params.authorization_endpoint_override,
token_endpoint_override: params.token_endpoint_override,
@@ -434,6 +438,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
scope = EXCLUDED.scope,
token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
disabled_at = NULL,
client_id = EXCLUDED.client_id,
encrypted_client_secret = EXCLUDED.encrypted_client_secret,
claims_imports = EXCLUDED.claims_imports,
@@ -487,6 +492,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
token_endpoint_auth_method: params.token_endpoint_auth_method,
created_at,
disabled_at: None,
claims_imports: params.claims_imports,
authorization_endpoint_override: params.authorization_endpoint_override,
token_endpoint_override: params.token_endpoint_override,
@@ -497,6 +503,37 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
})
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.disable",
skip_all,
fields(
db.statement,
%upstream_oauth_provider.id,
),
err,
)]
async fn disable(
&mut self,
clock: &dyn Clock,
upstream_oauth_provider: UpstreamOAuthProvider,
) -> Result<(), Self::Error> {
let disabled_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE upstream_oauth_providers
SET disabled_at = $2
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(upstream_oauth_provider.id),
disabled_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.list",
skip_all,
@@ -507,10 +544,9 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
)]
async fn list(
&mut self,
_filter: UpstreamOAuthProviderFilter<'_>,
filter: UpstreamOAuthProviderFilter<'_>,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
// XXX: the filter is currently ignored, as it does not have any fields
let (sql, arguments) = Query::select()
.expr_as(
Expr::col((
@@ -579,6 +615,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
)),
ProviderLookupIden::CreatedAt,
)
.expr_as(
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::DisabledAt,
)),
ProviderLookupIden::DisabledAt,
)
.expr_as(
Expr::col((
UpstreamOAuthProviders::Table,
@@ -629,6 +672,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
ProviderLookupIden::AdditionalParameters,
)
.from(UpstreamOAuthProviders::Table)
.and_where_option(filter.enabled().map(|enabled| {
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::DisabledAt,
))
.is_null()
.eq(enabled)
}))
.generate_pagination(
(
UpstreamOAuthProviders::Table,
@@ -660,9 +711,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
)]
async fn count(
&mut self,
_filter: UpstreamOAuthProviderFilter<'_>,
filter: UpstreamOAuthProviderFilter<'_>,
) -> Result<usize, Self::Error> {
// XXX: the filter is currently ignored, as it does not have any fields
let (sql, arguments) = Query::select()
.expr(
Expr::col((
@@ -672,6 +722,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
.count(),
)
.from(UpstreamOAuthProviders::Table)
.and_where_option(filter.enabled().map(|enabled| {
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::DisabledAt,
))
.is_null()
.eq(enabled)
}))
.build_sqlx(PostgresQueryBuilder);
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
@@ -685,14 +743,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.all",
name = "db.upstream_oauth_provider.all_enabled",
skip_all,
fields(
db.statement,
),
err,
)]
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
let res = sqlx::query_as!(
ProviderLookup,
r#"
@@ -707,6 +765,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at,
disabled_at,
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
jwks_uri_override,
authorization_endpoint_override,
@@ -715,6 +774,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
pkce_mode,
additional_parameters as "additional_parameters: Json<Vec<(String, String)>>"
FROM upstream_oauth_providers
WHERE disabled_at IS NULL
"#,
)
.traced()

View File

@@ -137,6 +137,7 @@ impl Pagination {
}
/// A page of results returned by a paginated query
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Page<T> {
/// When paginating forwards, this is true if there are more items after
pub has_next_page: bool,

View File

@@ -25,6 +25,7 @@ pub struct UpstreamOAuthLinkFilter<'a> {
// XXX: we might also want to filter for links without a user linked to them
user: Option<&'a User>,
provider: Option<&'a UpstreamOAuthProvider>,
provider_enabled: Option<bool>,
}
impl<'a> UpstreamOAuthLinkFilter<'a> {
@@ -63,6 +64,26 @@ impl<'a> UpstreamOAuthLinkFilter<'a> {
pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
self.provider
}
/// Set whether to filter for enabled providers
#[must_use]
pub const fn enabled_providers_only(mut self) -> Self {
self.provider_enabled = Some(true);
self
}
/// Set whether to filter for disabled providers
#[must_use]
pub const fn disabled_providers_only(mut self) -> Self {
self.provider_enabled = Some(false);
self
}
/// Get the provider enabled filter
#[must_use]
pub const fn provider_enabled(&self) -> Option<bool> {
self.provider_enabled
}
}
/// An [`UpstreamOAuthLinkRepository`] helps interacting with

View File

@@ -1,4 +1,4 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -82,6 +82,11 @@ pub struct UpstreamOAuthProviderParams {
/// Filter parameters for listing upstream OAuth 2.0 providers
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UpstreamOAuthProviderFilter<'a> {
/// Filter by whether the provider is enabled
///
/// If `None`, all providers are returned
enabled: Option<bool>,
_lifetime: PhantomData<&'a ()>,
}
@@ -91,6 +96,28 @@ impl<'a> UpstreamOAuthProviderFilter<'a> {
pub fn new() -> Self {
Self::default()
}
/// Return only enabled providers
#[must_use]
pub const fn enabled_only(mut self) -> Self {
self.enabled = Some(true);
self
}
/// Return only disabled providers
#[must_use]
pub const fn disabled_only(mut self) -> Self {
self.enabled = Some(false);
self
}
/// Get the enabled filter
///
/// Returns `None` if the filter is not set
#[must_use]
pub const fn enabled(&self) -> Option<bool> {
self.enabled
}
}
/// An [`UpstreamOAuthProviderRepository`] helps interacting with
@@ -175,6 +202,22 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
params: UpstreamOAuthProviderParams,
) -> Result<UpstreamOAuthProvider, Self::Error>;
/// Disable an upstream OAuth provider
///
/// # Parameters
///
/// * `clock`: The clock used to generate timestamps
/// * `provider`: The provider to disable
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn disable(
&mut self,
clock: &dyn Clock,
provider: UpstreamOAuthProvider,
) -> Result<(), Self::Error>;
/// List [`UpstreamOAuthProvider`] with the given filter and pagination
///
/// # Parameters
@@ -205,12 +248,12 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
filter: UpstreamOAuthProviderFilter<'_>,
) -> Result<usize, Self::Error>;
/// Get all upstream OAuth providers
/// Get all enabled upstream OAuth providers
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
}
repository_impl!(UpstreamOAuthProviderRepository:
@@ -234,6 +277,12 @@ repository_impl!(UpstreamOAuthProviderRepository:
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
async fn disable(
&mut self,
clock: &dyn Clock,
provider: UpstreamOAuthProvider
) -> Result<(), Self::Error>;
async fn list(
&mut self,
filter: UpstreamOAuthProviderFilter<'_>,
@@ -245,5 +294,5 @@ repository_impl!(UpstreamOAuthProviderRepository:
filter: UpstreamOAuthProviderFilter<'_>
) -> Result<usize, Self::Error>;
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
);