You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
Add a soft-deletion column on upstream OAuth 2.0 providers
This commit is contained in:
@@ -110,7 +110,7 @@ pub async fn config_sync(
|
||||
.map(|p| p.id)
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let existing = repo.upstream_oauth_provider().all().await?;
|
||||
let existing = repo.upstream_oauth_provider().all_enabled().await?;
|
||||
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
||||
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
||||
if prune {
|
||||
|
@@ -141,10 +141,19 @@ pub struct UpstreamOAuthProvider {
|
||||
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub disabled_at: Option<DateTime<Utc>>,
|
||||
pub claims_imports: ClaimsImports,
|
||||
pub additional_authorization_parameters: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl UpstreamOAuthProvider {
|
||||
/// Returns `true` if the provider is enabled
|
||||
#[must_use]
|
||||
pub const fn enabled(&self) -> bool {
|
||||
self.disabled_at.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether to set the email as verified when importing it from the upstream
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
|
@@ -501,7 +501,9 @@ impl User {
|
||||
.transpose()?;
|
||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||
|
||||
let filter = UpstreamOAuthLinkFilter::new().for_user(&self.0);
|
||||
let filter = UpstreamOAuthLinkFilter::new()
|
||||
.for_user(&self.0)
|
||||
.enabled_providers_only();
|
||||
|
||||
let page = repo.upstream_oauth_link().list(filter, pagination).await?;
|
||||
|
||||
|
@@ -73,6 +73,11 @@ impl UpstreamOAuthQuery {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// We only allow enabled providers to be fetched
|
||||
if !provider.enabled() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(UpstreamOAuth2Provider::new(provider)))
|
||||
}
|
||||
|
||||
@@ -110,7 +115,9 @@ impl UpstreamOAuthQuery {
|
||||
.transpose()?;
|
||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||
|
||||
let filter = UpstreamOAuthProviderFilter::new();
|
||||
// We only want enabled providers
|
||||
// XXX: we may want to let admins see disabled providers
|
||||
let filter = UpstreamOAuthProviderFilter::new().enabled_only();
|
||||
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
|
@@ -20,6 +20,7 @@ use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID,
|
||||
};
|
||||
use mas_data_model::UpstreamOAuthProvider;
|
||||
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{
|
||||
@@ -81,6 +82,7 @@ pub(crate) async fn get(
|
||||
.upstream_oauth_provider()
|
||||
.lookup(provider_id)
|
||||
.await?
|
||||
.filter(UpstreamOAuthProvider::enabled)
|
||||
.ok_or(RouteError::ProviderNotFound)?;
|
||||
|
||||
let http_service = http_client_factory.http_service("upstream_oauth2.authorize");
|
||||
|
@@ -167,7 +167,7 @@ impl MetadataCache {
|
||||
interval: std::time::Duration,
|
||||
repository: &mut R,
|
||||
) -> Result<tokio::task::JoinHandle<()>, R::Error> {
|
||||
let providers = repository.upstream_oauth_provider().all().await?;
|
||||
let providers = repository.upstream_oauth_provider().all_enabled().await?;
|
||||
|
||||
for provider in providers {
|
||||
let verify = match provider.discovery_mode {
|
||||
@@ -504,6 +504,7 @@ mod tests {
|
||||
token_endpoint_signing_alg: None,
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
|
||||
created_at: clock.now(),
|
||||
disabled_at: None,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||
additional_authorization_parameters: Vec::new(),
|
||||
};
|
||||
|
@@ -20,6 +20,7 @@ use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID,
|
||||
};
|
||||
use mas_data_model::UpstreamOAuthProvider;
|
||||
use mas_keystore::{Encrypter, Keystore};
|
||||
use mas_oidc_client::requests::{
|
||||
authorization_code::AuthorizationValidationData, jose::JwtVerificationData,
|
||||
@@ -146,6 +147,7 @@ pub(crate) async fn get(
|
||||
.upstream_oauth_provider()
|
||||
.lookup(provider_id)
|
||||
.await?
|
||||
.filter(UpstreamOAuthProvider::enabled)
|
||||
.ok_or(RouteError::ProviderNotFound)?;
|
||||
|
||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||
|
@@ -78,7 +78,7 @@ pub(crate) async fn get(
|
||||
return Ok((cookie_jar, reply).into_response());
|
||||
};
|
||||
|
||||
let providers = repo.upstream_oauth_provider().all().await?;
|
||||
let providers = repo.upstream_oauth_provider().all_enabled().await?;
|
||||
|
||||
// If password-based login is disabled, and there is only one upstream provider,
|
||||
// we can directly start an authorization flow
|
||||
@@ -149,7 +149,7 @@ pub(crate) async fn post(
|
||||
};
|
||||
|
||||
if !state.is_valid() {
|
||||
let providers = repo.upstream_oauth_provider().all().await?;
|
||||
let providers = repo.upstream_oauth_provider().all_enabled().await?;
|
||||
let content = render(
|
||||
locale,
|
||||
LoginContext::default()
|
||||
|
15
crates/storage-pg/.sqlx/query-048eec775f4af3ffd805e830e8286c6a5745e523b76e1083d6bfced0035c2f76.json
generated
Normal file
15
crates/storage-pg/.sqlx/query-048eec775f4af3ffd805e830e8286c6a5745e523b76e1083d6bfced0035c2f76.json
generated
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE upstream_oauth_providers\n SET disabled_at = $2\n WHERE upstream_oauth_provider_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "048eec775f4af3ffd805e830e8286c6a5745e523b76e1083d6bfced0035c2f76"
|
||||
}
|
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -55,36 +55,41 @@
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "disabled_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
"type_info": "Jsonb"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"ordinal": 12,
|
||||
"name": "jwks_uri_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 12,
|
||||
"ordinal": 13,
|
||||
"name": "authorization_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 13,
|
||||
"ordinal": 14,
|
||||
"name": "token_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 14,
|
||||
"ordinal": 15,
|
||||
"name": "discovery_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 15,
|
||||
"ordinal": 16,
|
||||
"name": "pkce_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 16,
|
||||
"ordinal": 17,
|
||||
"name": "additional_parameters: Json<Vec<(String, String)>>",
|
||||
"type_info": "Jsonb"
|
||||
}
|
||||
@@ -105,6 +110,7 @@
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
@@ -114,5 +120,5 @@
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "d8d9e49227b7945b4c3bcd842b59a7af6a21f7e9e2d715dc6360c3d691373903"
|
||||
"hash": "51b204376c63671a47b73ee8b3f8e669f90933f7e81ba744dca88d6bb94bf96a"
|
||||
}
|
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n ",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -55,36 +55,41 @@
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "disabled_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
"type_info": "Jsonb"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"ordinal": 12,
|
||||
"name": "jwks_uri_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 12,
|
||||
"ordinal": 13,
|
||||
"name": "authorization_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 13,
|
||||
"ordinal": 14,
|
||||
"name": "token_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 14,
|
||||
"ordinal": 15,
|
||||
"name": "discovery_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 15,
|
||||
"ordinal": 16,
|
||||
"name": "pkce_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 16,
|
||||
"ordinal": 17,
|
||||
"name": "additional_parameters: Json<Vec<(String, String)>>",
|
||||
"type_info": "Jsonb"
|
||||
}
|
||||
@@ -103,6 +108,7 @@
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
@@ -112,5 +118,5 @@
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "ccf4965aa84c497ac9759cb31f3ecba59fdf18085791f799dfd398bef4f8eb8c"
|
||||
"hash": "5d9f3d47ce6164b3f81aa09ef4fd8d5cd070945fd497d209ac1df99abcfb7c5d"
|
||||
}
|
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ",
|
||||
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -34,5 +34,5 @@
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "21132afc29be5394a03680dd27d2aff5e2249a973083c0675935dc658f73b1f4"
|
||||
"hash": "94fd87e99088671b6a20bb7b9a3838ecce8df564257b348adf22f2e9356e6dae"
|
||||
}
|
@@ -0,0 +1,18 @@
|
||||
-- Copyright 2024 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
|
||||
-- Adds a `disabled_at` column to the `upstream_oauth_providers` table, to soft-delete providers.
|
||||
ALTER TABLE "upstream_oauth_providers"
|
||||
ADD COLUMN "disabled_at" TIMESTAMP WITH TIME ZONE;
|
@@ -107,6 +107,7 @@ pub enum UpstreamOAuthProviders {
|
||||
TokenEndpointSigningAlg,
|
||||
TokenEndpointAuthMethod,
|
||||
CreatedAt,
|
||||
DisabledAt,
|
||||
ClaimsImports,
|
||||
DiscoveryMode,
|
||||
PkceMode,
|
||||
|
@@ -27,7 +27,10 @@ use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
iden::UpstreamOAuthLinks, pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError,
|
||||
iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
|
||||
pagination::QueryBuilderExt,
|
||||
tracing::ExecuteExt,
|
||||
DatabaseError,
|
||||
};
|
||||
|
||||
/// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL
|
||||
@@ -280,6 +283,29 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
|
||||
))
|
||||
.eq(Uuid::from(provider.id))
|
||||
}))
|
||||
.and_where_option(filter.provider_enabled().map(|enabled| {
|
||||
Expr::col((
|
||||
UpstreamOAuthLinks::Table,
|
||||
UpstreamOAuthLinks::UpstreamOAuthProviderId,
|
||||
))
|
||||
.eq(Expr::any(
|
||||
Query::select()
|
||||
.expr(Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::UpstreamOAuthProviderId,
|
||||
)))
|
||||
.from(UpstreamOAuthProviders::Table)
|
||||
.and_where(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::DisabledAt,
|
||||
))
|
||||
.is_null()
|
||||
.eq(enabled),
|
||||
)
|
||||
.take(),
|
||||
))
|
||||
}))
|
||||
.generate_pagination(
|
||||
(
|
||||
UpstreamOAuthLinks::Table,
|
||||
@@ -328,6 +354,29 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
|
||||
))
|
||||
.eq(Uuid::from(provider.id))
|
||||
}))
|
||||
.and_where_option(filter.provider_enabled().map(|enabled| {
|
||||
Expr::col((
|
||||
UpstreamOAuthLinks::Table,
|
||||
UpstreamOAuthLinks::UpstreamOAuthProviderId,
|
||||
))
|
||||
.eq(Expr::any(
|
||||
Query::select()
|
||||
.expr(Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::UpstreamOAuthProviderId,
|
||||
)))
|
||||
.from(UpstreamOAuthProviders::Table)
|
||||
.and_where(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::DisabledAt,
|
||||
))
|
||||
.is_null()
|
||||
.eq(enabled),
|
||||
)
|
||||
.take(),
|
||||
))
|
||||
}))
|
||||
.build_sqlx(PostgresQueryBuilder);
|
||||
|
||||
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
||||
|
@@ -51,7 +51,7 @@ mod tests {
|
||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||
|
||||
// The provider list should be empty at the start
|
||||
let all_providers = repo.upstream_oauth_provider().all().await.unwrap();
|
||||
let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
|
||||
assert!(all_providers.is_empty());
|
||||
|
||||
// Let's add a provider
|
||||
@@ -93,7 +93,7 @@ mod tests {
|
||||
assert_eq!(provider.client_id, "client-id");
|
||||
|
||||
// It should be in the list of all providers
|
||||
let providers = repo.upstream_oauth_provider().all().await.unwrap();
|
||||
let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
|
||||
assert_eq!(providers.len(), 1);
|
||||
assert_eq!(providers[0].issuer, "https://example.com/");
|
||||
assert_eq!(providers[0].client_id, "client-id");
|
||||
@@ -192,7 +192,8 @@ mod tests {
|
||||
// XXX: we should also try other combinations of the filter
|
||||
let filter = UpstreamOAuthLinkFilter::new()
|
||||
.for_user(&user)
|
||||
.for_provider(&provider);
|
||||
.for_provider(&provider)
|
||||
.enabled_providers_only();
|
||||
|
||||
let links = repo
|
||||
.upstream_oauth_link()
|
||||
@@ -207,13 +208,70 @@ mod tests {
|
||||
|
||||
assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);
|
||||
|
||||
// There should be exactly one enabled provider
|
||||
assert_eq!(
|
||||
repo.upstream_oauth_provider()
|
||||
.count(UpstreamOAuthProviderFilter::new())
|
||||
.await
|
||||
.unwrap(),
|
||||
1
|
||||
);
|
||||
assert_eq!(
|
||||
repo.upstream_oauth_provider()
|
||||
.count(UpstreamOAuthProviderFilter::new().enabled_only())
|
||||
.await
|
||||
.unwrap(),
|
||||
1
|
||||
);
|
||||
assert_eq!(
|
||||
repo.upstream_oauth_provider()
|
||||
.count(UpstreamOAuthProviderFilter::new().disabled_only())
|
||||
.await
|
||||
.unwrap(),
|
||||
0
|
||||
);
|
||||
|
||||
// Disable the provider
|
||||
repo.upstream_oauth_provider()
|
||||
.disable(&clock, provider.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// There should be exactly one disabled provider
|
||||
assert_eq!(
|
||||
repo.upstream_oauth_provider()
|
||||
.count(UpstreamOAuthProviderFilter::new())
|
||||
.await
|
||||
.unwrap(),
|
||||
1
|
||||
);
|
||||
assert_eq!(
|
||||
repo.upstream_oauth_provider()
|
||||
.count(UpstreamOAuthProviderFilter::new().enabled_only())
|
||||
.await
|
||||
.unwrap(),
|
||||
0
|
||||
);
|
||||
assert_eq!(
|
||||
repo.upstream_oauth_provider()
|
||||
.count(UpstreamOAuthProviderFilter::new().disabled_only())
|
||||
.await
|
||||
.unwrap(),
|
||||
1
|
||||
);
|
||||
|
||||
// Try deleting the provider
|
||||
repo.upstream_oauth_provider()
|
||||
.delete(provider)
|
||||
.await
|
||||
.unwrap();
|
||||
let providers = repo.upstream_oauth_provider().all().await.unwrap();
|
||||
assert!(providers.is_empty());
|
||||
assert_eq!(
|
||||
repo.upstream_oauth_provider()
|
||||
.count(UpstreamOAuthProviderFilter::new())
|
||||
.await
|
||||
.unwrap(),
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
/// Test that the pagination works as expected in the upstream OAuth
|
||||
@@ -287,6 +345,16 @@ mod tests {
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[..10]);
|
||||
|
||||
// Getting the same page with the "enabled only" filter should return the same
|
||||
// results
|
||||
let other_page = repo
|
||||
.upstream_oauth_provider()
|
||||
.list(filter.enabled_only(), Pagination::first(10))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(page, other_page);
|
||||
|
||||
// Lookup the next 10 items
|
||||
let page = repo
|
||||
.upstream_oauth_provider()
|
||||
@@ -334,5 +402,17 @@ mod tests {
|
||||
assert!(!page.has_next_page);
|
||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||
assert_eq!(&edge_ids, &ids[6..8]);
|
||||
|
||||
// There should not be any disabled providers
|
||||
assert!(repo
|
||||
.upstream_oauth_provider()
|
||||
.list(
|
||||
UpstreamOAuthProviderFilter::new().disabled_only(),
|
||||
Pagination::first(1)
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.edges
|
||||
.is_empty());
|
||||
}
|
||||
}
|
||||
|
@@ -62,6 +62,7 @@ struct ProviderLookup {
|
||||
token_endpoint_signing_alg: Option<String>,
|
||||
token_endpoint_auth_method: String,
|
||||
created_at: DateTime<Utc>,
|
||||
disabled_at: Option<DateTime<Utc>>,
|
||||
claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
|
||||
jwks_uri_override: Option<String>,
|
||||
authorization_endpoint_override: Option<String>,
|
||||
@@ -161,6 +162,7 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_signing_alg,
|
||||
created_at: value.created_at,
|
||||
disabled_at: value.disabled_at,
|
||||
claims_imports: value.claims_imports.0,
|
||||
authorization_endpoint_override,
|
||||
token_endpoint_override,
|
||||
@@ -200,6 +202,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at,
|
||||
disabled_at,
|
||||
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
jwks_uri_override,
|
||||
authorization_endpoint_override,
|
||||
@@ -308,6 +311,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method: params.token_endpoint_auth_method,
|
||||
created_at,
|
||||
disabled_at: None,
|
||||
claims_imports: params.claims_imports,
|
||||
authorization_endpoint_override: params.authorization_endpoint_override,
|
||||
token_endpoint_override: params.token_endpoint_override,
|
||||
@@ -434,6 +438,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
scope = EXCLUDED.scope,
|
||||
token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
|
||||
token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
|
||||
disabled_at = NULL,
|
||||
client_id = EXCLUDED.client_id,
|
||||
encrypted_client_secret = EXCLUDED.encrypted_client_secret,
|
||||
claims_imports = EXCLUDED.claims_imports,
|
||||
@@ -487,6 +492,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method: params.token_endpoint_auth_method,
|
||||
created_at,
|
||||
disabled_at: None,
|
||||
claims_imports: params.claims_imports,
|
||||
authorization_endpoint_override: params.authorization_endpoint_override,
|
||||
token_endpoint_override: params.token_endpoint_override,
|
||||
@@ -497,6 +503,37 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_provider.disable",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%upstream_oauth_provider.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn disable(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
upstream_oauth_provider: UpstreamOAuthProvider,
|
||||
) -> Result<(), Self::Error> {
|
||||
let disabled_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE upstream_oauth_providers
|
||||
SET disabled_at = $2
|
||||
WHERE upstream_oauth_provider_id = $1
|
||||
"#,
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
disabled_at,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_provider.list",
|
||||
skip_all,
|
||||
@@ -507,10 +544,9 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
)]
|
||||
async fn list(
|
||||
&mut self,
|
||||
_filter: UpstreamOAuthProviderFilter<'_>,
|
||||
filter: UpstreamOAuthProviderFilter<'_>,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
|
||||
// XXX: the filter is currently ignored, as it does not have any fields
|
||||
let (sql, arguments) = Query::select()
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
@@ -579,6 +615,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
)),
|
||||
ProviderLookupIden::CreatedAt,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::DisabledAt,
|
||||
)),
|
||||
ProviderLookupIden::DisabledAt,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
@@ -629,6 +672,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
ProviderLookupIden::AdditionalParameters,
|
||||
)
|
||||
.from(UpstreamOAuthProviders::Table)
|
||||
.and_where_option(filter.enabled().map(|enabled| {
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::DisabledAt,
|
||||
))
|
||||
.is_null()
|
||||
.eq(enabled)
|
||||
}))
|
||||
.generate_pagination(
|
||||
(
|
||||
UpstreamOAuthProviders::Table,
|
||||
@@ -660,9 +711,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
)]
|
||||
async fn count(
|
||||
&mut self,
|
||||
_filter: UpstreamOAuthProviderFilter<'_>,
|
||||
filter: UpstreamOAuthProviderFilter<'_>,
|
||||
) -> Result<usize, Self::Error> {
|
||||
// XXX: the filter is currently ignored, as it does not have any fields
|
||||
let (sql, arguments) = Query::select()
|
||||
.expr(
|
||||
Expr::col((
|
||||
@@ -672,6 +722,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
.count(),
|
||||
)
|
||||
.from(UpstreamOAuthProviders::Table)
|
||||
.and_where_option(filter.enabled().map(|enabled| {
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::DisabledAt,
|
||||
))
|
||||
.is_null()
|
||||
.eq(enabled)
|
||||
}))
|
||||
.build_sqlx(PostgresQueryBuilder);
|
||||
|
||||
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
||||
@@ -685,14 +743,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.upstream_oauth_provider.all",
|
||||
name = "db.upstream_oauth_provider.all_enabled",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
|
||||
async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
ProviderLookup,
|
||||
r#"
|
||||
@@ -707,6 +765,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at,
|
||||
disabled_at,
|
||||
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
jwks_uri_override,
|
||||
authorization_endpoint_override,
|
||||
@@ -715,6 +774,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
pkce_mode,
|
||||
additional_parameters as "additional_parameters: Json<Vec<(String, String)>>"
|
||||
FROM upstream_oauth_providers
|
||||
WHERE disabled_at IS NULL
|
||||
"#,
|
||||
)
|
||||
.traced()
|
||||
|
@@ -137,6 +137,7 @@ impl Pagination {
|
||||
}
|
||||
|
||||
/// A page of results returned by a paginated query
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Page<T> {
|
||||
/// When paginating forwards, this is true if there are more items after
|
||||
pub has_next_page: bool,
|
||||
|
@@ -25,6 +25,7 @@ pub struct UpstreamOAuthLinkFilter<'a> {
|
||||
// XXX: we might also want to filter for links without a user linked to them
|
||||
user: Option<&'a User>,
|
||||
provider: Option<&'a UpstreamOAuthProvider>,
|
||||
provider_enabled: Option<bool>,
|
||||
}
|
||||
|
||||
impl<'a> UpstreamOAuthLinkFilter<'a> {
|
||||
@@ -63,6 +64,26 @@ impl<'a> UpstreamOAuthLinkFilter<'a> {
|
||||
pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
|
||||
self.provider
|
||||
}
|
||||
|
||||
/// Set whether to filter for enabled providers
|
||||
#[must_use]
|
||||
pub const fn enabled_providers_only(mut self) -> Self {
|
||||
self.provider_enabled = Some(true);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set whether to filter for disabled providers
|
||||
#[must_use]
|
||||
pub const fn disabled_providers_only(mut self) -> Self {
|
||||
self.provider_enabled = Some(false);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the provider enabled filter
|
||||
#[must_use]
|
||||
pub const fn provider_enabled(&self) -> Option<bool> {
|
||||
self.provider_enabled
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`UpstreamOAuthLinkRepository`] helps interacting with
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -82,6 +82,11 @@ pub struct UpstreamOAuthProviderParams {
|
||||
/// Filter parameters for listing upstream OAuth 2.0 providers
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||
pub struct UpstreamOAuthProviderFilter<'a> {
|
||||
/// Filter by whether the provider is enabled
|
||||
///
|
||||
/// If `None`, all providers are returned
|
||||
enabled: Option<bool>,
|
||||
|
||||
_lifetime: PhantomData<&'a ()>,
|
||||
}
|
||||
|
||||
@@ -91,6 +96,28 @@ impl<'a> UpstreamOAuthProviderFilter<'a> {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Return only enabled providers
|
||||
#[must_use]
|
||||
pub const fn enabled_only(mut self) -> Self {
|
||||
self.enabled = Some(true);
|
||||
self
|
||||
}
|
||||
|
||||
/// Return only disabled providers
|
||||
#[must_use]
|
||||
pub const fn disabled_only(mut self) -> Self {
|
||||
self.enabled = Some(false);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the enabled filter
|
||||
///
|
||||
/// Returns `None` if the filter is not set
|
||||
#[must_use]
|
||||
pub const fn enabled(&self) -> Option<bool> {
|
||||
self.enabled
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`UpstreamOAuthProviderRepository`] helps interacting with
|
||||
@@ -175,6 +202,22 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
||||
params: UpstreamOAuthProviderParams,
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error>;
|
||||
|
||||
/// Disable an upstream OAuth provider
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `clock`: The clock used to generate timestamps
|
||||
/// * `provider`: The provider to disable
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn disable(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
provider: UpstreamOAuthProvider,
|
||||
) -> Result<(), Self::Error>;
|
||||
|
||||
/// List [`UpstreamOAuthProvider`] with the given filter and pagination
|
||||
///
|
||||
/// # Parameters
|
||||
@@ -205,12 +248,12 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
||||
filter: UpstreamOAuthProviderFilter<'_>,
|
||||
) -> Result<usize, Self::Error>;
|
||||
|
||||
/// Get all upstream OAuth providers
|
||||
/// Get all enabled upstream OAuth providers
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
||||
async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
||||
}
|
||||
|
||||
repository_impl!(UpstreamOAuthProviderRepository:
|
||||
@@ -234,6 +277,12 @@ repository_impl!(UpstreamOAuthProviderRepository:
|
||||
|
||||
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
|
||||
|
||||
async fn disable(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
provider: UpstreamOAuthProvider
|
||||
) -> Result<(), Self::Error>;
|
||||
|
||||
async fn list(
|
||||
&mut self,
|
||||
filter: UpstreamOAuthProviderFilter<'_>,
|
||||
@@ -245,5 +294,5 @@ repository_impl!(UpstreamOAuthProviderRepository:
|
||||
filter: UpstreamOAuthProviderFilter<'_>
|
||||
) -> Result<usize, Self::Error>;
|
||||
|
||||
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
||||
async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
||||
);
|
||||
|
Reference in New Issue
Block a user