You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
Add a soft-deletion column on upstream OAuth 2.0 providers
This commit is contained in:
@@ -110,7 +110,7 @@ pub async fn config_sync(
|
|||||||
.map(|p| p.id)
|
.map(|p| p.id)
|
||||||
.collect::<HashSet<_>>();
|
.collect::<HashSet<_>>();
|
||||||
|
|
||||||
let existing = repo.upstream_oauth_provider().all().await?;
|
let existing = repo.upstream_oauth_provider().all_enabled().await?;
|
||||||
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
||||||
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
||||||
if prune {
|
if prune {
|
||||||
|
@@ -141,10 +141,19 @@ pub struct UpstreamOAuthProvider {
|
|||||||
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||||
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||||
pub created_at: DateTime<Utc>,
|
pub created_at: DateTime<Utc>,
|
||||||
|
pub disabled_at: Option<DateTime<Utc>>,
|
||||||
pub claims_imports: ClaimsImports,
|
pub claims_imports: ClaimsImports,
|
||||||
pub additional_authorization_parameters: Vec<(String, String)>,
|
pub additional_authorization_parameters: Vec<(String, String)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl UpstreamOAuthProvider {
|
||||||
|
/// Returns `true` if the provider is enabled
|
||||||
|
#[must_use]
|
||||||
|
pub const fn enabled(&self) -> bool {
|
||||||
|
self.disabled_at.is_none()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Whether to set the email as verified when importing it from the upstream
|
/// Whether to set the email as verified when importing it from the upstream
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
|
@@ -501,7 +501,9 @@ impl User {
|
|||||||
.transpose()?;
|
.transpose()?;
|
||||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||||
|
|
||||||
let filter = UpstreamOAuthLinkFilter::new().for_user(&self.0);
|
let filter = UpstreamOAuthLinkFilter::new()
|
||||||
|
.for_user(&self.0)
|
||||||
|
.enabled_providers_only();
|
||||||
|
|
||||||
let page = repo.upstream_oauth_link().list(filter, pagination).await?;
|
let page = repo.upstream_oauth_link().list(filter, pagination).await?;
|
||||||
|
|
||||||
|
@@ -73,6 +73,11 @@ impl UpstreamOAuthQuery {
|
|||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// We only allow enabled providers to be fetched
|
||||||
|
if !provider.enabled() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(Some(UpstreamOAuth2Provider::new(provider)))
|
Ok(Some(UpstreamOAuth2Provider::new(provider)))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,7 +115,9 @@ impl UpstreamOAuthQuery {
|
|||||||
.transpose()?;
|
.transpose()?;
|
||||||
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||||
|
|
||||||
let filter = UpstreamOAuthProviderFilter::new();
|
// We only want enabled providers
|
||||||
|
// XXX: we may want to let admins see disabled providers
|
||||||
|
let filter = UpstreamOAuthProviderFilter::new().enabled_only();
|
||||||
|
|
||||||
let page = repo
|
let page = repo
|
||||||
.upstream_oauth_provider()
|
.upstream_oauth_provider()
|
||||||
|
@@ -20,6 +20,7 @@ use hyper::StatusCode;
|
|||||||
use mas_axum_utils::{
|
use mas_axum_utils::{
|
||||||
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID,
|
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID,
|
||||||
};
|
};
|
||||||
|
use mas_data_model::UpstreamOAuthProvider;
|
||||||
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
|
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
|
||||||
use mas_router::UrlBuilder;
|
use mas_router::UrlBuilder;
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
@@ -81,6 +82,7 @@ pub(crate) async fn get(
|
|||||||
.upstream_oauth_provider()
|
.upstream_oauth_provider()
|
||||||
.lookup(provider_id)
|
.lookup(provider_id)
|
||||||
.await?
|
.await?
|
||||||
|
.filter(UpstreamOAuthProvider::enabled)
|
||||||
.ok_or(RouteError::ProviderNotFound)?;
|
.ok_or(RouteError::ProviderNotFound)?;
|
||||||
|
|
||||||
let http_service = http_client_factory.http_service("upstream_oauth2.authorize");
|
let http_service = http_client_factory.http_service("upstream_oauth2.authorize");
|
||||||
|
@@ -167,7 +167,7 @@ impl MetadataCache {
|
|||||||
interval: std::time::Duration,
|
interval: std::time::Duration,
|
||||||
repository: &mut R,
|
repository: &mut R,
|
||||||
) -> Result<tokio::task::JoinHandle<()>, R::Error> {
|
) -> Result<tokio::task::JoinHandle<()>, R::Error> {
|
||||||
let providers = repository.upstream_oauth_provider().all().await?;
|
let providers = repository.upstream_oauth_provider().all_enabled().await?;
|
||||||
|
|
||||||
for provider in providers {
|
for provider in providers {
|
||||||
let verify = match provider.discovery_mode {
|
let verify = match provider.discovery_mode {
|
||||||
@@ -504,6 +504,7 @@ mod tests {
|
|||||||
token_endpoint_signing_alg: None,
|
token_endpoint_signing_alg: None,
|
||||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
|
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
|
||||||
created_at: clock.now(),
|
created_at: clock.now(),
|
||||||
|
disabled_at: None,
|
||||||
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||||
additional_authorization_parameters: Vec::new(),
|
additional_authorization_parameters: Vec::new(),
|
||||||
};
|
};
|
||||||
|
@@ -20,6 +20,7 @@ use hyper::StatusCode;
|
|||||||
use mas_axum_utils::{
|
use mas_axum_utils::{
|
||||||
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID,
|
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID,
|
||||||
};
|
};
|
||||||
|
use mas_data_model::UpstreamOAuthProvider;
|
||||||
use mas_keystore::{Encrypter, Keystore};
|
use mas_keystore::{Encrypter, Keystore};
|
||||||
use mas_oidc_client::requests::{
|
use mas_oidc_client::requests::{
|
||||||
authorization_code::AuthorizationValidationData, jose::JwtVerificationData,
|
authorization_code::AuthorizationValidationData, jose::JwtVerificationData,
|
||||||
@@ -146,6 +147,7 @@ pub(crate) async fn get(
|
|||||||
.upstream_oauth_provider()
|
.upstream_oauth_provider()
|
||||||
.lookup(provider_id)
|
.lookup(provider_id)
|
||||||
.await?
|
.await?
|
||||||
|
.filter(UpstreamOAuthProvider::enabled)
|
||||||
.ok_or(RouteError::ProviderNotFound)?;
|
.ok_or(RouteError::ProviderNotFound)?;
|
||||||
|
|
||||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||||
|
@@ -78,7 +78,7 @@ pub(crate) async fn get(
|
|||||||
return Ok((cookie_jar, reply).into_response());
|
return Ok((cookie_jar, reply).into_response());
|
||||||
};
|
};
|
||||||
|
|
||||||
let providers = repo.upstream_oauth_provider().all().await?;
|
let providers = repo.upstream_oauth_provider().all_enabled().await?;
|
||||||
|
|
||||||
// If password-based login is disabled, and there is only one upstream provider,
|
// If password-based login is disabled, and there is only one upstream provider,
|
||||||
// we can directly start an authorization flow
|
// we can directly start an authorization flow
|
||||||
@@ -149,7 +149,7 @@ pub(crate) async fn post(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if !state.is_valid() {
|
if !state.is_valid() {
|
||||||
let providers = repo.upstream_oauth_provider().all().await?;
|
let providers = repo.upstream_oauth_provider().all_enabled().await?;
|
||||||
let content = render(
|
let content = render(
|
||||||
locale,
|
locale,
|
||||||
LoginContext::default()
|
LoginContext::default()
|
||||||
|
15
crates/storage-pg/.sqlx/query-048eec775f4af3ffd805e830e8286c6a5745e523b76e1083d6bfced0035c2f76.json
generated
Normal file
15
crates/storage-pg/.sqlx/query-048eec775f4af3ffd805e830e8286c6a5745e523b76e1083d6bfced0035c2f76.json
generated
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
"db_name": "PostgreSQL",
|
||||||
|
"query": "\n UPDATE upstream_oauth_providers\n SET disabled_at = $2\n WHERE upstream_oauth_provider_id = $1\n ",
|
||||||
|
"describe": {
|
||||||
|
"columns": [],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Uuid",
|
||||||
|
"Timestamptz"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"nullable": []
|
||||||
|
},
|
||||||
|
"hash": "048eec775f4af3ffd805e830e8286c6a5745e523b76e1083d6bfced0035c2f76"
|
||||||
|
}
|
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"db_name": "PostgreSQL",
|
"db_name": "PostgreSQL",
|
||||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
|
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [
|
"columns": [
|
||||||
{
|
{
|
||||||
@@ -55,36 +55,41 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 10,
|
"ordinal": 10,
|
||||||
|
"name": "disabled_at",
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ordinal": 11,
|
||||||
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||||
"type_info": "Jsonb"
|
"type_info": "Jsonb"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 11,
|
"ordinal": 12,
|
||||||
"name": "jwks_uri_override",
|
"name": "jwks_uri_override",
|
||||||
"type_info": "Text"
|
"type_info": "Text"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 12,
|
"ordinal": 13,
|
||||||
"name": "authorization_endpoint_override",
|
"name": "authorization_endpoint_override",
|
||||||
"type_info": "Text"
|
"type_info": "Text"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 13,
|
"ordinal": 14,
|
||||||
"name": "token_endpoint_override",
|
"name": "token_endpoint_override",
|
||||||
"type_info": "Text"
|
"type_info": "Text"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 14,
|
"ordinal": 15,
|
||||||
"name": "discovery_mode",
|
"name": "discovery_mode",
|
||||||
"type_info": "Text"
|
"type_info": "Text"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 15,
|
"ordinal": 16,
|
||||||
"name": "pkce_mode",
|
"name": "pkce_mode",
|
||||||
"type_info": "Text"
|
"type_info": "Text"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 16,
|
"ordinal": 17,
|
||||||
"name": "additional_parameters: Json<Vec<(String, String)>>",
|
"name": "additional_parameters: Json<Vec<(String, String)>>",
|
||||||
"type_info": "Jsonb"
|
"type_info": "Jsonb"
|
||||||
}
|
}
|
||||||
@@ -105,6 +110,7 @@
|
|||||||
true,
|
true,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
true,
|
||||||
false,
|
false,
|
||||||
true,
|
true,
|
||||||
true,
|
true,
|
||||||
@@ -114,5 +120,5 @@
|
|||||||
true
|
true
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"hash": "d8d9e49227b7945b4c3bcd842b59a7af6a21f7e9e2d715dc6360c3d691373903"
|
"hash": "51b204376c63671a47b73ee8b3f8e669f90933f7e81ba744dca88d6bb94bf96a"
|
||||||
}
|
}
|
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"db_name": "PostgreSQL",
|
"db_name": "PostgreSQL",
|
||||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n ",
|
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ",
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [
|
"columns": [
|
||||||
{
|
{
|
||||||
@@ -55,36 +55,41 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 10,
|
"ordinal": 10,
|
||||||
|
"name": "disabled_at",
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ordinal": 11,
|
||||||
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||||
"type_info": "Jsonb"
|
"type_info": "Jsonb"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 11,
|
"ordinal": 12,
|
||||||
"name": "jwks_uri_override",
|
"name": "jwks_uri_override",
|
||||||
"type_info": "Text"
|
"type_info": "Text"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 12,
|
"ordinal": 13,
|
||||||
"name": "authorization_endpoint_override",
|
"name": "authorization_endpoint_override",
|
||||||
"type_info": "Text"
|
"type_info": "Text"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 13,
|
"ordinal": 14,
|
||||||
"name": "token_endpoint_override",
|
"name": "token_endpoint_override",
|
||||||
"type_info": "Text"
|
"type_info": "Text"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 14,
|
"ordinal": 15,
|
||||||
"name": "discovery_mode",
|
"name": "discovery_mode",
|
||||||
"type_info": "Text"
|
"type_info": "Text"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 15,
|
"ordinal": 16,
|
||||||
"name": "pkce_mode",
|
"name": "pkce_mode",
|
||||||
"type_info": "Text"
|
"type_info": "Text"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ordinal": 16,
|
"ordinal": 17,
|
||||||
"name": "additional_parameters: Json<Vec<(String, String)>>",
|
"name": "additional_parameters: Json<Vec<(String, String)>>",
|
||||||
"type_info": "Jsonb"
|
"type_info": "Jsonb"
|
||||||
}
|
}
|
||||||
@@ -103,6 +108,7 @@
|
|||||||
true,
|
true,
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
true,
|
||||||
false,
|
false,
|
||||||
true,
|
true,
|
||||||
true,
|
true,
|
||||||
@@ -112,5 +118,5 @@
|
|||||||
true
|
true
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"hash": "ccf4965aa84c497ac9759cb31f3ecba59fdf18085791f799dfd398bef4f8eb8c"
|
"hash": "5d9f3d47ce6164b3f81aa09ef4fd8d5cd070945fd497d209ac1df99abcfb7c5d"
|
||||||
}
|
}
|
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"db_name": "PostgreSQL",
|
"db_name": "PostgreSQL",
|
||||||
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ",
|
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ",
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [
|
"columns": [
|
||||||
{
|
{
|
||||||
@@ -34,5 +34,5 @@
|
|||||||
false
|
false
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"hash": "21132afc29be5394a03680dd27d2aff5e2249a973083c0675935dc658f73b1f4"
|
"hash": "94fd87e99088671b6a20bb7b9a3838ecce8df564257b348adf22f2e9356e6dae"
|
||||||
}
|
}
|
@@ -0,0 +1,18 @@
|
|||||||
|
-- Copyright 2024 The Matrix.org Foundation C.I.C.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
-- Adds a `disabled_at` column to the `upstream_oauth_providers` table, to soft-delete providers.
|
||||||
|
ALTER TABLE "upstream_oauth_providers"
|
||||||
|
ADD COLUMN "disabled_at" TIMESTAMP WITH TIME ZONE;
|
@@ -107,6 +107,7 @@ pub enum UpstreamOAuthProviders {
|
|||||||
TokenEndpointSigningAlg,
|
TokenEndpointSigningAlg,
|
||||||
TokenEndpointAuthMethod,
|
TokenEndpointAuthMethod,
|
||||||
CreatedAt,
|
CreatedAt,
|
||||||
|
DisabledAt,
|
||||||
ClaimsImports,
|
ClaimsImports,
|
||||||
DiscoveryMode,
|
DiscoveryMode,
|
||||||
PkceMode,
|
PkceMode,
|
||||||
|
@@ -27,7 +27,10 @@ use ulid::Ulid;
|
|||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
iden::UpstreamOAuthLinks, pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError,
|
iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
|
||||||
|
pagination::QueryBuilderExt,
|
||||||
|
tracing::ExecuteExt,
|
||||||
|
DatabaseError,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL
|
/// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL
|
||||||
@@ -280,6 +283,29 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
|
|||||||
))
|
))
|
||||||
.eq(Uuid::from(provider.id))
|
.eq(Uuid::from(provider.id))
|
||||||
}))
|
}))
|
||||||
|
.and_where_option(filter.provider_enabled().map(|enabled| {
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthLinks::Table,
|
||||||
|
UpstreamOAuthLinks::UpstreamOAuthProviderId,
|
||||||
|
))
|
||||||
|
.eq(Expr::any(
|
||||||
|
Query::select()
|
||||||
|
.expr(Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::UpstreamOAuthProviderId,
|
||||||
|
)))
|
||||||
|
.from(UpstreamOAuthProviders::Table)
|
||||||
|
.and_where(
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::DisabledAt,
|
||||||
|
))
|
||||||
|
.is_null()
|
||||||
|
.eq(enabled),
|
||||||
|
)
|
||||||
|
.take(),
|
||||||
|
))
|
||||||
|
}))
|
||||||
.generate_pagination(
|
.generate_pagination(
|
||||||
(
|
(
|
||||||
UpstreamOAuthLinks::Table,
|
UpstreamOAuthLinks::Table,
|
||||||
@@ -328,6 +354,29 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
|
|||||||
))
|
))
|
||||||
.eq(Uuid::from(provider.id))
|
.eq(Uuid::from(provider.id))
|
||||||
}))
|
}))
|
||||||
|
.and_where_option(filter.provider_enabled().map(|enabled| {
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthLinks::Table,
|
||||||
|
UpstreamOAuthLinks::UpstreamOAuthProviderId,
|
||||||
|
))
|
||||||
|
.eq(Expr::any(
|
||||||
|
Query::select()
|
||||||
|
.expr(Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::UpstreamOAuthProviderId,
|
||||||
|
)))
|
||||||
|
.from(UpstreamOAuthProviders::Table)
|
||||||
|
.and_where(
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::DisabledAt,
|
||||||
|
))
|
||||||
|
.is_null()
|
||||||
|
.eq(enabled),
|
||||||
|
)
|
||||||
|
.take(),
|
||||||
|
))
|
||||||
|
}))
|
||||||
.build_sqlx(PostgresQueryBuilder);
|
.build_sqlx(PostgresQueryBuilder);
|
||||||
|
|
||||||
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
||||||
|
@@ -51,7 +51,7 @@ mod tests {
|
|||||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
||||||
|
|
||||||
// The provider list should be empty at the start
|
// The provider list should be empty at the start
|
||||||
let all_providers = repo.upstream_oauth_provider().all().await.unwrap();
|
let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
|
||||||
assert!(all_providers.is_empty());
|
assert!(all_providers.is_empty());
|
||||||
|
|
||||||
// Let's add a provider
|
// Let's add a provider
|
||||||
@@ -93,7 +93,7 @@ mod tests {
|
|||||||
assert_eq!(provider.client_id, "client-id");
|
assert_eq!(provider.client_id, "client-id");
|
||||||
|
|
||||||
// It should be in the list of all providers
|
// It should be in the list of all providers
|
||||||
let providers = repo.upstream_oauth_provider().all().await.unwrap();
|
let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
|
||||||
assert_eq!(providers.len(), 1);
|
assert_eq!(providers.len(), 1);
|
||||||
assert_eq!(providers[0].issuer, "https://example.com/");
|
assert_eq!(providers[0].issuer, "https://example.com/");
|
||||||
assert_eq!(providers[0].client_id, "client-id");
|
assert_eq!(providers[0].client_id, "client-id");
|
||||||
@@ -192,7 +192,8 @@ mod tests {
|
|||||||
// XXX: we should also try other combinations of the filter
|
// XXX: we should also try other combinations of the filter
|
||||||
let filter = UpstreamOAuthLinkFilter::new()
|
let filter = UpstreamOAuthLinkFilter::new()
|
||||||
.for_user(&user)
|
.for_user(&user)
|
||||||
.for_provider(&provider);
|
.for_provider(&provider)
|
||||||
|
.enabled_providers_only();
|
||||||
|
|
||||||
let links = repo
|
let links = repo
|
||||||
.upstream_oauth_link()
|
.upstream_oauth_link()
|
||||||
@@ -207,13 +208,70 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);
|
assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);
|
||||||
|
|
||||||
|
// There should be exactly one enabled provider
|
||||||
|
assert_eq!(
|
||||||
|
repo.upstream_oauth_provider()
|
||||||
|
.count(UpstreamOAuthProviderFilter::new())
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
1
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
repo.upstream_oauth_provider()
|
||||||
|
.count(UpstreamOAuthProviderFilter::new().enabled_only())
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
1
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
repo.upstream_oauth_provider()
|
||||||
|
.count(UpstreamOAuthProviderFilter::new().disabled_only())
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
0
|
||||||
|
);
|
||||||
|
|
||||||
|
// Disable the provider
|
||||||
|
repo.upstream_oauth_provider()
|
||||||
|
.disable(&clock, provider.clone())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// There should be exactly one disabled provider
|
||||||
|
assert_eq!(
|
||||||
|
repo.upstream_oauth_provider()
|
||||||
|
.count(UpstreamOAuthProviderFilter::new())
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
1
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
repo.upstream_oauth_provider()
|
||||||
|
.count(UpstreamOAuthProviderFilter::new().enabled_only())
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
0
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
repo.upstream_oauth_provider()
|
||||||
|
.count(UpstreamOAuthProviderFilter::new().disabled_only())
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
1
|
||||||
|
);
|
||||||
|
|
||||||
// Try deleting the provider
|
// Try deleting the provider
|
||||||
repo.upstream_oauth_provider()
|
repo.upstream_oauth_provider()
|
||||||
.delete(provider)
|
.delete(provider)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let providers = repo.upstream_oauth_provider().all().await.unwrap();
|
assert_eq!(
|
||||||
assert!(providers.is_empty());
|
repo.upstream_oauth_provider()
|
||||||
|
.count(UpstreamOAuthProviderFilter::new())
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
0
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Test that the pagination works as expected in the upstream OAuth
|
/// Test that the pagination works as expected in the upstream OAuth
|
||||||
@@ -287,6 +345,16 @@ mod tests {
|
|||||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||||
assert_eq!(&edge_ids, &ids[..10]);
|
assert_eq!(&edge_ids, &ids[..10]);
|
||||||
|
|
||||||
|
// Getting the same page with the "enabled only" filter should return the same
|
||||||
|
// results
|
||||||
|
let other_page = repo
|
||||||
|
.upstream_oauth_provider()
|
||||||
|
.list(filter.enabled_only(), Pagination::first(10))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(page, other_page);
|
||||||
|
|
||||||
// Lookup the next 10 items
|
// Lookup the next 10 items
|
||||||
let page = repo
|
let page = repo
|
||||||
.upstream_oauth_provider()
|
.upstream_oauth_provider()
|
||||||
@@ -334,5 +402,17 @@ mod tests {
|
|||||||
assert!(!page.has_next_page);
|
assert!(!page.has_next_page);
|
||||||
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
|
||||||
assert_eq!(&edge_ids, &ids[6..8]);
|
assert_eq!(&edge_ids, &ids[6..8]);
|
||||||
|
|
||||||
|
// There should not be any disabled providers
|
||||||
|
assert!(repo
|
||||||
|
.upstream_oauth_provider()
|
||||||
|
.list(
|
||||||
|
UpstreamOAuthProviderFilter::new().disabled_only(),
|
||||||
|
Pagination::first(1)
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.edges
|
||||||
|
.is_empty());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -62,6 +62,7 @@ struct ProviderLookup {
|
|||||||
token_endpoint_signing_alg: Option<String>,
|
token_endpoint_signing_alg: Option<String>,
|
||||||
token_endpoint_auth_method: String,
|
token_endpoint_auth_method: String,
|
||||||
created_at: DateTime<Utc>,
|
created_at: DateTime<Utc>,
|
||||||
|
disabled_at: Option<DateTime<Utc>>,
|
||||||
claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
|
claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
|
||||||
jwks_uri_override: Option<String>,
|
jwks_uri_override: Option<String>,
|
||||||
authorization_endpoint_override: Option<String>,
|
authorization_endpoint_override: Option<String>,
|
||||||
@@ -161,6 +162,7 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
|||||||
token_endpoint_auth_method,
|
token_endpoint_auth_method,
|
||||||
token_endpoint_signing_alg,
|
token_endpoint_signing_alg,
|
||||||
created_at: value.created_at,
|
created_at: value.created_at,
|
||||||
|
disabled_at: value.disabled_at,
|
||||||
claims_imports: value.claims_imports.0,
|
claims_imports: value.claims_imports.0,
|
||||||
authorization_endpoint_override,
|
authorization_endpoint_override,
|
||||||
token_endpoint_override,
|
token_endpoint_override,
|
||||||
@@ -200,6 +202,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
token_endpoint_signing_alg,
|
token_endpoint_signing_alg,
|
||||||
token_endpoint_auth_method,
|
token_endpoint_auth_method,
|
||||||
created_at,
|
created_at,
|
||||||
|
disabled_at,
|
||||||
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||||
jwks_uri_override,
|
jwks_uri_override,
|
||||||
authorization_endpoint_override,
|
authorization_endpoint_override,
|
||||||
@@ -308,6 +311,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
|
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
|
||||||
token_endpoint_auth_method: params.token_endpoint_auth_method,
|
token_endpoint_auth_method: params.token_endpoint_auth_method,
|
||||||
created_at,
|
created_at,
|
||||||
|
disabled_at: None,
|
||||||
claims_imports: params.claims_imports,
|
claims_imports: params.claims_imports,
|
||||||
authorization_endpoint_override: params.authorization_endpoint_override,
|
authorization_endpoint_override: params.authorization_endpoint_override,
|
||||||
token_endpoint_override: params.token_endpoint_override,
|
token_endpoint_override: params.token_endpoint_override,
|
||||||
@@ -434,6 +438,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
scope = EXCLUDED.scope,
|
scope = EXCLUDED.scope,
|
||||||
token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
|
token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
|
||||||
token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
|
token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
|
||||||
|
disabled_at = NULL,
|
||||||
client_id = EXCLUDED.client_id,
|
client_id = EXCLUDED.client_id,
|
||||||
encrypted_client_secret = EXCLUDED.encrypted_client_secret,
|
encrypted_client_secret = EXCLUDED.encrypted_client_secret,
|
||||||
claims_imports = EXCLUDED.claims_imports,
|
claims_imports = EXCLUDED.claims_imports,
|
||||||
@@ -487,6 +492,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
|
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
|
||||||
token_endpoint_auth_method: params.token_endpoint_auth_method,
|
token_endpoint_auth_method: params.token_endpoint_auth_method,
|
||||||
created_at,
|
created_at,
|
||||||
|
disabled_at: None,
|
||||||
claims_imports: params.claims_imports,
|
claims_imports: params.claims_imports,
|
||||||
authorization_endpoint_override: params.authorization_endpoint_override,
|
authorization_endpoint_override: params.authorization_endpoint_override,
|
||||||
token_endpoint_override: params.token_endpoint_override,
|
token_endpoint_override: params.token_endpoint_override,
|
||||||
@@ -497,6 +503,37 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "db.upstream_oauth_provider.disable",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
db.statement,
|
||||||
|
%upstream_oauth_provider.id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
async fn disable(
|
||||||
|
&mut self,
|
||||||
|
clock: &dyn Clock,
|
||||||
|
upstream_oauth_provider: UpstreamOAuthProvider,
|
||||||
|
) -> Result<(), Self::Error> {
|
||||||
|
let disabled_at = clock.now();
|
||||||
|
let res = sqlx::query!(
|
||||||
|
r#"
|
||||||
|
UPDATE upstream_oauth_providers
|
||||||
|
SET disabled_at = $2
|
||||||
|
WHERE upstream_oauth_provider_id = $1
|
||||||
|
"#,
|
||||||
|
Uuid::from(upstream_oauth_provider.id),
|
||||||
|
disabled_at,
|
||||||
|
)
|
||||||
|
.traced()
|
||||||
|
.execute(&mut *self.conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
DatabaseError::ensure_affected_rows(&res, 1)
|
||||||
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
name = "db.upstream_oauth_provider.list",
|
name = "db.upstream_oauth_provider.list",
|
||||||
skip_all,
|
skip_all,
|
||||||
@@ -507,10 +544,9 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
)]
|
)]
|
||||||
async fn list(
|
async fn list(
|
||||||
&mut self,
|
&mut self,
|
||||||
_filter: UpstreamOAuthProviderFilter<'_>,
|
filter: UpstreamOAuthProviderFilter<'_>,
|
||||||
pagination: Pagination,
|
pagination: Pagination,
|
||||||
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
|
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
|
||||||
// XXX: the filter is currently ignored, as it does not have any fields
|
|
||||||
let (sql, arguments) = Query::select()
|
let (sql, arguments) = Query::select()
|
||||||
.expr_as(
|
.expr_as(
|
||||||
Expr::col((
|
Expr::col((
|
||||||
@@ -579,6 +615,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
)),
|
)),
|
||||||
ProviderLookupIden::CreatedAt,
|
ProviderLookupIden::CreatedAt,
|
||||||
)
|
)
|
||||||
|
.expr_as(
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::DisabledAt,
|
||||||
|
)),
|
||||||
|
ProviderLookupIden::DisabledAt,
|
||||||
|
)
|
||||||
.expr_as(
|
.expr_as(
|
||||||
Expr::col((
|
Expr::col((
|
||||||
UpstreamOAuthProviders::Table,
|
UpstreamOAuthProviders::Table,
|
||||||
@@ -629,6 +672,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
ProviderLookupIden::AdditionalParameters,
|
ProviderLookupIden::AdditionalParameters,
|
||||||
)
|
)
|
||||||
.from(UpstreamOAuthProviders::Table)
|
.from(UpstreamOAuthProviders::Table)
|
||||||
|
.and_where_option(filter.enabled().map(|enabled| {
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::DisabledAt,
|
||||||
|
))
|
||||||
|
.is_null()
|
||||||
|
.eq(enabled)
|
||||||
|
}))
|
||||||
.generate_pagination(
|
.generate_pagination(
|
||||||
(
|
(
|
||||||
UpstreamOAuthProviders::Table,
|
UpstreamOAuthProviders::Table,
|
||||||
@@ -660,9 +711,8 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
)]
|
)]
|
||||||
async fn count(
|
async fn count(
|
||||||
&mut self,
|
&mut self,
|
||||||
_filter: UpstreamOAuthProviderFilter<'_>,
|
filter: UpstreamOAuthProviderFilter<'_>,
|
||||||
) -> Result<usize, Self::Error> {
|
) -> Result<usize, Self::Error> {
|
||||||
// XXX: the filter is currently ignored, as it does not have any fields
|
|
||||||
let (sql, arguments) = Query::select()
|
let (sql, arguments) = Query::select()
|
||||||
.expr(
|
.expr(
|
||||||
Expr::col((
|
Expr::col((
|
||||||
@@ -672,6 +722,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
.count(),
|
.count(),
|
||||||
)
|
)
|
||||||
.from(UpstreamOAuthProviders::Table)
|
.from(UpstreamOAuthProviders::Table)
|
||||||
|
.and_where_option(filter.enabled().map(|enabled| {
|
||||||
|
Expr::col((
|
||||||
|
UpstreamOAuthProviders::Table,
|
||||||
|
UpstreamOAuthProviders::DisabledAt,
|
||||||
|
))
|
||||||
|
.is_null()
|
||||||
|
.eq(enabled)
|
||||||
|
}))
|
||||||
.build_sqlx(PostgresQueryBuilder);
|
.build_sqlx(PostgresQueryBuilder);
|
||||||
|
|
||||||
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
||||||
@@ -685,14 +743,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
name = "db.upstream_oauth_provider.all",
|
name = "db.upstream_oauth_provider.all_enabled",
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(
|
fields(
|
||||||
db.statement,
|
db.statement,
|
||||||
),
|
),
|
||||||
err,
|
err,
|
||||||
)]
|
)]
|
||||||
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
|
async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
ProviderLookup,
|
ProviderLookup,
|
||||||
r#"
|
r#"
|
||||||
@@ -707,6 +765,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
token_endpoint_signing_alg,
|
token_endpoint_signing_alg,
|
||||||
token_endpoint_auth_method,
|
token_endpoint_auth_method,
|
||||||
created_at,
|
created_at,
|
||||||
|
disabled_at,
|
||||||
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||||
jwks_uri_override,
|
jwks_uri_override,
|
||||||
authorization_endpoint_override,
|
authorization_endpoint_override,
|
||||||
@@ -715,6 +774,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
pkce_mode,
|
pkce_mode,
|
||||||
additional_parameters as "additional_parameters: Json<Vec<(String, String)>>"
|
additional_parameters as "additional_parameters: Json<Vec<(String, String)>>"
|
||||||
FROM upstream_oauth_providers
|
FROM upstream_oauth_providers
|
||||||
|
WHERE disabled_at IS NULL
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
.traced()
|
.traced()
|
||||||
|
@@ -137,6 +137,7 @@ impl Pagination {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A page of results returned by a paginated query
|
/// A page of results returned by a paginated query
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct Page<T> {
|
pub struct Page<T> {
|
||||||
/// When paginating forwards, this is true if there are more items after
|
/// When paginating forwards, this is true if there are more items after
|
||||||
pub has_next_page: bool,
|
pub has_next_page: bool,
|
||||||
|
@@ -25,6 +25,7 @@ pub struct UpstreamOAuthLinkFilter<'a> {
|
|||||||
// XXX: we might also want to filter for links without a user linked to them
|
// XXX: we might also want to filter for links without a user linked to them
|
||||||
user: Option<&'a User>,
|
user: Option<&'a User>,
|
||||||
provider: Option<&'a UpstreamOAuthProvider>,
|
provider: Option<&'a UpstreamOAuthProvider>,
|
||||||
|
provider_enabled: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> UpstreamOAuthLinkFilter<'a> {
|
impl<'a> UpstreamOAuthLinkFilter<'a> {
|
||||||
@@ -63,6 +64,26 @@ impl<'a> UpstreamOAuthLinkFilter<'a> {
|
|||||||
pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
|
pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
|
||||||
self.provider
|
self.provider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set whether to filter for enabled providers
|
||||||
|
#[must_use]
|
||||||
|
pub const fn enabled_providers_only(mut self) -> Self {
|
||||||
|
self.provider_enabled = Some(true);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set whether to filter for disabled providers
|
||||||
|
#[must_use]
|
||||||
|
pub const fn disabled_providers_only(mut self) -> Self {
|
||||||
|
self.provider_enabled = Some(false);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the provider enabled filter
|
||||||
|
#[must_use]
|
||||||
|
pub const fn provider_enabled(&self) -> Option<bool> {
|
||||||
|
self.provider_enabled
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An [`UpstreamOAuthLinkRepository`] helps interacting with
|
/// An [`UpstreamOAuthLinkRepository`] helps interacting with
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@@ -82,6 +82,11 @@ pub struct UpstreamOAuthProviderParams {
|
|||||||
/// Filter parameters for listing upstream OAuth 2.0 providers
|
/// Filter parameters for listing upstream OAuth 2.0 providers
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||||
pub struct UpstreamOAuthProviderFilter<'a> {
|
pub struct UpstreamOAuthProviderFilter<'a> {
|
||||||
|
/// Filter by whether the provider is enabled
|
||||||
|
///
|
||||||
|
/// If `None`, all providers are returned
|
||||||
|
enabled: Option<bool>,
|
||||||
|
|
||||||
_lifetime: PhantomData<&'a ()>,
|
_lifetime: PhantomData<&'a ()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,6 +96,28 @@ impl<'a> UpstreamOAuthProviderFilter<'a> {
|
|||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self::default()
|
Self::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return only enabled providers
|
||||||
|
#[must_use]
|
||||||
|
pub const fn enabled_only(mut self) -> Self {
|
||||||
|
self.enabled = Some(true);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return only disabled providers
|
||||||
|
#[must_use]
|
||||||
|
pub const fn disabled_only(mut self) -> Self {
|
||||||
|
self.enabled = Some(false);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the enabled filter
|
||||||
|
///
|
||||||
|
/// Returns `None` if the filter is not set
|
||||||
|
#[must_use]
|
||||||
|
pub const fn enabled(&self) -> Option<bool> {
|
||||||
|
self.enabled
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An [`UpstreamOAuthProviderRepository`] helps interacting with
|
/// An [`UpstreamOAuthProviderRepository`] helps interacting with
|
||||||
@@ -175,6 +202,22 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
|||||||
params: UpstreamOAuthProviderParams,
|
params: UpstreamOAuthProviderParams,
|
||||||
) -> Result<UpstreamOAuthProvider, Self::Error>;
|
) -> Result<UpstreamOAuthProvider, Self::Error>;
|
||||||
|
|
||||||
|
/// Disable an upstream OAuth provider
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
///
|
||||||
|
/// * `clock`: The clock used to generate timestamps
|
||||||
|
/// * `provider`: The provider to disable
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`Self::Error`] if the underlying repository fails
|
||||||
|
async fn disable(
|
||||||
|
&mut self,
|
||||||
|
clock: &dyn Clock,
|
||||||
|
provider: UpstreamOAuthProvider,
|
||||||
|
) -> Result<(), Self::Error>;
|
||||||
|
|
||||||
/// List [`UpstreamOAuthProvider`] with the given filter and pagination
|
/// List [`UpstreamOAuthProvider`] with the given filter and pagination
|
||||||
///
|
///
|
||||||
/// # Parameters
|
/// # Parameters
|
||||||
@@ -205,12 +248,12 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
|||||||
filter: UpstreamOAuthProviderFilter<'_>,
|
filter: UpstreamOAuthProviderFilter<'_>,
|
||||||
) -> Result<usize, Self::Error>;
|
) -> Result<usize, Self::Error>;
|
||||||
|
|
||||||
/// Get all upstream OAuth providers
|
/// Get all enabled upstream OAuth providers
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns [`Self::Error`] if the underlying repository fails
|
/// Returns [`Self::Error`] if the underlying repository fails
|
||||||
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
repository_impl!(UpstreamOAuthProviderRepository:
|
repository_impl!(UpstreamOAuthProviderRepository:
|
||||||
@@ -234,6 +277,12 @@ repository_impl!(UpstreamOAuthProviderRepository:
|
|||||||
|
|
||||||
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
|
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
|
||||||
|
|
||||||
|
async fn disable(
|
||||||
|
&mut self,
|
||||||
|
clock: &dyn Clock,
|
||||||
|
provider: UpstreamOAuthProvider
|
||||||
|
) -> Result<(), Self::Error>;
|
||||||
|
|
||||||
async fn list(
|
async fn list(
|
||||||
&mut self,
|
&mut self,
|
||||||
filter: UpstreamOAuthProviderFilter<'_>,
|
filter: UpstreamOAuthProviderFilter<'_>,
|
||||||
@@ -245,5 +294,5 @@ repository_impl!(UpstreamOAuthProviderRepository:
|
|||||||
filter: UpstreamOAuthProviderFilter<'_>
|
filter: UpstreamOAuthProviderFilter<'_>
|
||||||
) -> Result<usize, Self::Error>;
|
) -> Result<usize, Self::Error>;
|
||||||
|
|
||||||
async fn all(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
|
||||||
);
|
);
|
||||||
|
Reference in New Issue
Block a user