1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Allow endpoints and discovery mode override for upstream oauth2 providers

This time, at the configuration and database level
This commit is contained in:
Quentin Gliech
2023-11-17 14:22:57 +01:00
parent 364093f12f
commit 7315dd9a7a
19 changed files with 764 additions and 233 deletions

View File

@@ -17,12 +17,13 @@ use std::collections::HashSet;
use clap::Parser; use clap::Parser;
use mas_config::{ConfigurationSection, RootConfig, SyncConfig}; use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
use mas_storage::{ use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess, SystemClock, upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository},
RepositoryAccess, SystemClock,
}; };
use mas_storage_pg::PgRepository; use mas_storage_pg::PgRepository;
use rand::SeedableRng; use rand::SeedableRng;
use sqlx::{postgres::PgAdvisoryLock, Acquire}; use sqlx::{postgres::PgAdvisoryLock, Acquire};
use tracing::{info, info_span, warn}; use tracing::{error, info, info_span, warn};
use crate::util::database_connection_from_config; use crate::util::database_connection_from_config;
@@ -204,10 +205,11 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
} }
for provider in config.upstream_oauth2.providers { for provider in config.upstream_oauth2.providers {
let _span = info_span!("provider", %provider.id).entered();
if existing_ids.contains(&provider.id) { if existing_ids.contains(&provider.id) {
info!(%provider.id, "Updating provider"); info!("Updating provider");
} else { } else {
info!(%provider.id, "Adding provider"); info!("Adding provider");
} }
if dry_run { if dry_run {
@@ -218,20 +220,65 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
.client_secret() .client_secret()
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
.transpose()?; .transpose()?;
let client_auth_method = provider.client_auth_method(); let token_endpoint_auth_method = provider.client_auth_method();
let client_auth_signing_alg = provider.client_auth_signing_alg(); let token_endpoint_signing_alg = provider.client_auth_signing_alg();
let discovery_mode = match provider.discovery_mode {
mas_config::UpstreamOAuth2DiscoveryMode::Oidc => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc
}
mas_config::UpstreamOAuth2DiscoveryMode::Insecure => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Insecure
}
mas_config::UpstreamOAuth2DiscoveryMode::Disabled => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled
}
};
if discovery_mode.is_disabled() {
if provider.authorization_endpoint.is_none() {
error!("Provider has discovery disabled but no authorization endpoint set");
}
if provider.token_endpoint.is_none() {
error!("Provider has discovery disabled but no token endpoint set");
}
if provider.jwks_uri.is_none() {
error!("Provider has discovery disabled but no JWKS URI set");
}
}
let pkce_mode = match provider.pkce_method {
mas_config::UpstreamOAuth2PkceMethod::Auto => {
mas_data_model::UpstreamOAuthProviderPkceMode::Auto
}
mas_config::UpstreamOAuth2PkceMethod::Always => {
mas_data_model::UpstreamOAuthProviderPkceMode::S256
}
mas_config::UpstreamOAuth2PkceMethod::Never => {
mas_data_model::UpstreamOAuthProviderPkceMode::Disabled
}
};
repo.upstream_oauth_provider() repo.upstream_oauth_provider()
.upsert( .upsert(
&clock, &clock,
provider.id, provider.id,
provider.issuer, UpstreamOAuthProviderParams {
provider.scope.parse()?, issuer: provider.issuer,
client_auth_method, scope: provider.scope.parse()?,
client_auth_signing_alg, token_endpoint_auth_method,
provider.client_id, token_endpoint_signing_alg,
client_id: provider.client_id,
encrypted_client_secret, encrypted_client_secret,
map_claims_imports(&provider.claims_imports), claims_imports: map_claims_imports(&provider.claims_imports),
token_endpoint_override: provider.token_endpoint,
authorization_endpoint_override: provider.authorization_endpoint,
jwks_uri_override: provider.jwks_uri,
discovery_mode,
pkce_mode,
},
) )
.await?; .await?;
} }
@@ -268,10 +315,11 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
} }
for client in config.clients.iter() { for client in config.clients.iter() {
let _span = info_span!("client", client.id = %client.client_id).entered();
if existing_ids.contains(&client.client_id) { if existing_ids.contains(&client.client_id) {
info!(client.id = %client.client_id, "Updating client"); info!("Updating client");
} else { } else {
info!(client.id = %client.client_id, "Adding client"); info!("Adding client");
} }
if dry_run { if dry_run {

View File

@@ -51,10 +51,10 @@ pub use self::{
}, },
templates::TemplatesConfig, templates::TemplatesConfig,
upstream_oauth2::{ upstream_oauth2::{
ClaimsImports as UpstreamOAuth2ClaimsImports, ClaimsImports as UpstreamOAuth2ClaimsImports, DiscoveryMode as UpstreamOAuth2DiscoveryMode,
EmailImportPreference as UpstreamOAuth2EmailImportPreference, EmailImportPreference as UpstreamOAuth2EmailImportPreference,
ImportAction as UpstreamOAuth2ImportAction, ImportAction as UpstreamOAuth2ImportAction,
ImportPreference as UpstreamOAuth2ImportPreference, ImportPreference as UpstreamOAuth2ImportPreference, PkceMethod as UpstreamOAuth2PkceMethod,
SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config, SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config,
}, },
}; };

View File

@@ -21,6 +21,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
use ulid::Ulid; use ulid::Ulid;
use url::Url;
use crate::ConfigurationSection; use crate::ConfigurationSection;
@@ -197,6 +198,39 @@ pub struct ClaimsImports {
pub email: EmailImportPreference, pub email: EmailImportPreference,
} }
/// How to discover the provider's configuration
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum DiscoveryMode {
/// Use OIDC discovery with strict metadata verification
#[default]
Oidc,
/// Use OIDC discovery with relaxed metadata verification
Insecure,
/// Use a static configuration
Disabled,
}
/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum PkceMethod {
/// Use PKCE if the provider supports it
///
/// Defaults to no PKCE if provider discovery is disabled
#[default]
Auto,
/// Always use PKCE with the S256 challenge method
Always,
/// Never use PKCE
Never,
}
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider { pub struct Provider {
@@ -220,6 +254,34 @@ pub struct Provider {
#[serde(flatten)] #[serde(flatten)]
pub token_auth_method: TokenAuthMethod, pub token_auth_method: TokenAuthMethod,
/// How to discover the provider's configuration
///
/// Defaults to use OIDC discovery with strict metadata verification
#[serde(default)]
pub discovery_mode: DiscoveryMode,
/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
///
/// Defaults to `auto`, which uses PKCE if the provider supports it.
#[serde(default)]
pub pkce_method: PkceMethod,
/// The URL to use for the provider's authorization endpoint
///
/// Defaults to the `authorization_endpoint` provided through discovery
pub authorization_endpoint: Option<Url>,
/// The URL to use for the provider's token endpoint
///
/// Defaults to the `token_endpoint` provided through discovery
pub token_endpoint: Option<Url>,
/// The URL to use for getting the provider's public keys
///
/// Defaults to the `jwks_uri` provided through discovery
pub jwks_uri: Option<Url>,
/// How claims should be imported from the `id_token` provided by the /// How claims should be imported from the `id_token` provided by the
/// provider /// provider
pub claims_imports: ClaimsImports, pub claims_imports: ClaimsImports,

View File

@@ -16,6 +16,7 @@ use chrono::{DateTime, Utc};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
@@ -33,6 +34,48 @@ pub enum DiscoveryMode {
Disabled, Disabled,
} }
impl DiscoveryMode {
/// Returns `true` if discovery is disabled
#[must_use]
pub fn is_disabled(&self) -> bool {
matches!(self, DiscoveryMode::Disabled)
}
}
#[derive(Debug, Clone, Error)]
#[error("Invalid discovery mode {0:?}")]
pub struct InvalidDiscoveryModeError(String);
impl std::str::FromStr for DiscoveryMode {
type Err = InvalidDiscoveryModeError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"oidc" => Ok(Self::Oidc),
"insecure" => Ok(Self::Insecure),
"disabled" => Ok(Self::Disabled),
s => Err(InvalidDiscoveryModeError(s.to_owned())),
}
}
}
impl DiscoveryMode {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Oidc => "oidc",
Self::Insecure => "insecure",
Self::Disabled => "disabled",
}
}
}
impl std::fmt::Display for DiscoveryMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum PkceMode { pub enum PkceMode {
@@ -47,6 +90,40 @@ pub enum PkceMode {
Disabled, Disabled,
} }
#[derive(Debug, Clone, Error)]
#[error("Invalid PKCE mode {0:?}")]
pub struct InvalidPkceModeError(String);
impl std::str::FromStr for PkceMode {
type Err = InvalidPkceModeError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"auto" => Ok(Self::Auto),
"s256" => Ok(Self::S256),
"disabled" => Ok(Self::Disabled),
s => Err(InvalidPkceModeError(s.to_owned())),
}
}
}
impl PkceMode {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Auto => "auto",
Self::S256 => "s256",
Self::Disabled => "disabled",
}
}
}
impl std::fmt::Display for PkceMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider { pub struct UpstreamOAuthProvider {
pub id: Ulid, pub id: Ulid,

View File

@@ -292,9 +292,8 @@ mod tests {
use tower::BoxError; use tower::BoxError;
use ulid::Ulid; use ulid::Ulid;
use crate::test_utils::init_tracing;
use super::*; use super::*;
use crate::test_utils::init_tracing;
#[tokio::test] #[tokio::test]
async fn test_metadata_cache() { async fn test_metadata_cache() {

View File

@@ -803,6 +803,7 @@ mod tests {
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::jwt::{JsonWebSignatureHeader, Jwt}; use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
use mas_router::Route; use mas_router::Route;
use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams;
use oauth2_types::scope::{Scope, OPENID}; use oauth2_types::scope::{Scope, OPENID};
use sqlx::PgPool; use sqlx::PgPool;
@@ -858,13 +859,20 @@ mod tests {
.add( .add(
&mut rng, &mut rng,
&state.clock, &state.clock,
"https://example.com/".to_owned(), UpstreamOAuthProviderParams {
Scope::from_iter([OPENID]), issuer: "https://example.com/".to_owned(),
OAuthClientAuthenticationMethod::None, scope: Scope::from_iter([OPENID]),
None, token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
"client".to_owned(), token_endpoint_signing_alg: None,
None, client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports, claims_imports,
authorization_endpoint_override: None,
token_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
},
) )
.await .await
.unwrap(); .unwrap();

View File

@@ -311,7 +311,10 @@ mod test {
use mas_data_model::UpstreamOAuthProviderClaimsImports; use mas_data_model::UpstreamOAuthProviderClaimsImports;
use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_router::Route; use mas_router::Route;
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess}; use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository},
RepositoryAccess,
};
use mas_templates::escape_html; use mas_templates::escape_html;
use oauth2_types::scope::OPENID; use oauth2_types::scope::OPENID;
use sqlx::PgPool; use sqlx::PgPool;
@@ -346,13 +349,20 @@ mod test {
.add( .add(
&mut rng, &mut rng,
&state.clock, &state.clock,
"https://first.com/".into(), UpstreamOAuthProviderParams {
[OPENID].into_iter().collect(), issuer: "https://first.com/".to_owned(),
OAuthClientAuthenticationMethod::None, scope: [OPENID].into_iter().collect(),
None, token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
"first_client".into(), token_endpoint_signing_alg: None,
None, client_id: "client".to_owned(),
UpstreamOAuthProviderClaimsImports::default(), encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
},
) )
.await .await
.unwrap(); .unwrap();
@@ -371,13 +381,20 @@ mod test {
.add( .add(
&mut rng, &mut rng,
&state.clock, &state.clock,
"https://second.com/".into(), UpstreamOAuthProviderParams {
[OPENID].into_iter().collect(), issuer: "https://second.com/".to_owned(),
OAuthClientAuthenticationMethod::None, scope: [OPENID].into_iter().collect(),
None, token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
"second_client".into(), token_endpoint_signing_alg: None,
None, client_id: "client".to_owned(),
UpstreamOAuthProviderClaimsImports::default(), encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
},
) )
.await .await
.unwrap(); .unwrap();

View File

@@ -1,6 +1,6 @@
{ {
"db_name": "PostgreSQL", "db_name": "PostgreSQL",
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at,\n claims_imports\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n ", "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14)\n ",
"describe": { "describe": {
"columns": [], "columns": [],
"parameters": { "parameters": {
@@ -12,11 +12,16 @@
"Text", "Text",
"Text", "Text",
"Text", "Text",
"Timestamptz", "Jsonb",
"Jsonb" "Text",
"Text",
"Text",
"Text",
"Text",
"Timestamptz"
] ]
}, },
"nullable": [] "nullable": []
}, },
"hash": "6021c1b9e17b0b2e8b511888f8c6be00683ba0635a13eb7fcd403d3d4a3f90db" "hash": "311957a0b745660aa2a21b1bd211376739318efa1e84670e04189e1257d4a8ed"
} }

View File

@@ -0,0 +1,35 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode\n RETURNING created_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Jsonb",
"Text",
"Text",
"Text",
"Text",
"Text",
"Timestamptz"
]
},
"nullable": [
false
]
},
"hash": "75b58c1b7f4e26997e961ad64418938938f09b3215a9b14f7edb3dd91cdf2dd5"
}

View File

@@ -1,6 +1,6 @@
{ {
"db_name": "PostgreSQL", "db_name": "PostgreSQL",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\"\n FROM upstream_oauth_providers\n ", "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode\n FROM upstream_oauth_providers\n ",
"describe": { "describe": {
"columns": [ "columns": [
{ {
@@ -47,6 +47,31 @@
"ordinal": 8, "ordinal": 8,
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>", "name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
"type_info": "Jsonb" "type_info": "Jsonb"
},
{
"ordinal": 9,
"name": "jwks_uri_override",
"type_info": "Text"
},
{
"ordinal": 10,
"name": "authorization_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 11,
"name": "token_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 12,
"name": "discovery_mode",
"type_info": "Text"
},
{
"ordinal": 13,
"name": "pkce_mode",
"type_info": "Text"
} }
], ],
"parameters": { "parameters": {
@@ -61,8 +86,13 @@
true, true,
false, false,
false, false,
false,
true,
true,
true,
false,
false false
] ]
}, },
"hash": "af65441068530b68826561d4308e15923ba6c6882ded4860ebde4a7641359abb" "hash": "b44e77ba737c9ec9af3838f148e2e882c90c0118ff77a92d2d93fe97dbd33233"
} }

View File

@@ -1,6 +1,6 @@
{ {
"db_name": "PostgreSQL", "db_name": "PostgreSQL",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ", "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
"describe": { "describe": {
"columns": [ "columns": [
{ {
@@ -47,6 +47,31 @@
"ordinal": 8, "ordinal": 8,
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>", "name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
"type_info": "Jsonb" "type_info": "Jsonb"
},
{
"ordinal": 9,
"name": "jwks_uri_override",
"type_info": "Text"
},
{
"ordinal": 10,
"name": "authorization_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 11,
"name": "token_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 12,
"name": "discovery_mode",
"type_info": "Text"
},
{
"ordinal": 13,
"name": "pkce_mode",
"type_info": "Text"
} }
], ],
"parameters": { "parameters": {
@@ -63,8 +88,13 @@
true, true,
false, false,
false, false,
false,
true,
true,
true,
false,
false false
] ]
}, },
"hash": "6733c54a8d9ed93a760f365a9362fdb0f77340d7a4df642a2942174aba2c6502" "hash": "e1759a6bda20a09a423e9dcb3a7544dbf259fea54e7cdaa714455f05814f39f6"
} }

View File

@@ -1,30 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at,\n claims_imports\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports\n RETURNING created_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Timestamptz",
"Jsonb"
]
},
"nullable": [
false
]
},
"hash": "e7ce95415bb6b57cd601393c6abe5febfec2a963ce6eac7b099b761594b1dfaf"
}

View File

@@ -0,0 +1,21 @@
-- Copyright 2023 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- Adds various endpoint overrides for oauth providers
ALTER TABLE upstream_oauth_providers
ADD COLUMN "jwks_uri_override" TEXT,
ADD COLUMN "authorization_endpoint_override" TEXT,
ADD COLUMN "token_endpoint_override" TEXT,
ADD COLUMN "discovery_mode" TEXT NOT NULL DEFAULT 'oidc',
ADD COLUMN "pkce_mode" TEXT NOT NULL DEFAULT 'auto';

View File

@@ -103,6 +103,11 @@ pub enum UpstreamOAuthProviders {
TokenEndpointAuthMethod, TokenEndpointAuthMethod,
CreatedAt, CreatedAt,
ClaimsImports, ClaimsImports,
DiscoveryMode,
PkceMode,
JwksUriOverride,
TokenEndpointOverride,
AuthorizationEndpointOverride,
} }
#[derive(sea_query::Iden)] #[derive(sea_query::Iden)]

View File

@@ -32,7 +32,8 @@ mod tests {
clock::MockClock, clock::MockClock,
upstream_oauth2::{ upstream_oauth2::{
UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter, UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
UpstreamOAuthSessionRepository,
}, },
user::UserRepository, user::UserRepository,
Pagination, RepositoryAccess, Pagination, RepositoryAccess,
@@ -59,13 +60,21 @@ mod tests {
.add( .add(
&mut rng, &mut rng,
&clock, &clock,
"https://example.com/".to_owned(), UpstreamOAuthProviderParams {
Scope::from_iter([OPENID]), issuer: "https://example.com/".to_owned(),
scope: Scope::from_iter([OPENID]),
token_endpoint_auth_method:
mas_iana::oauth::OAuthClientAuthenticationMethod::None, mas_iana::oauth::OAuthClientAuthenticationMethod::None,
None, token_endpoint_signing_alg: None,
"client-id".to_owned(), client_id: "client-id".to_owned(),
None, encrypted_client_secret: None,
UpstreamOAuthProviderClaimsImports::default(), claims_imports: UpstreamOAuthProviderClaimsImports::default(),
token_endpoint_override: None,
authorization_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
},
) )
.await .await
.unwrap(); .unwrap();
@@ -232,13 +241,21 @@ mod tests {
.add( .add(
&mut rng, &mut rng,
&clock, &clock,
ISSUER.to_owned(), UpstreamOAuthProviderParams {
scope.clone(), issuer: ISSUER.to_owned(),
scope: scope.clone(),
token_endpoint_auth_method:
mas_iana::oauth::OAuthClientAuthenticationMethod::None, mas_iana::oauth::OAuthClientAuthenticationMethod::None,
None, token_endpoint_signing_alg: None,
client_id, client_id,
None, encrypted_client_secret: None,
UpstreamOAuthProviderClaimsImports::default(), claims_imports: UpstreamOAuthProviderClaimsImports::default(),
token_endpoint_override: None,
authorization_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
},
) )
.await .await
.unwrap(); .unwrap();

View File

@@ -14,16 +14,13 @@
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{ use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
UpstreamOAuthProviderPkceMode,
};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_storage::{ use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository}, upstream_oauth2::{
UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
},
Clock, Page, Pagination, Clock, Page, Pagination,
}; };
use oauth2_types::scope::Scope;
use rand::RngCore; use rand::RngCore;
use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query}; use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query};
use sea_query_binder::SqlxBinder; use sea_query_binder::SqlxBinder;
@@ -63,6 +60,11 @@ struct ProviderLookup {
token_endpoint_auth_method: String, token_endpoint_auth_method: String,
created_at: DateTime<Utc>, created_at: DateTime<Utc>,
claims_imports: Json<UpstreamOAuthProviderClaimsImports>, claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
jwks_uri_override: Option<String>,
authorization_endpoint_override: Option<String>,
token_endpoint_override: Option<String>,
discovery_mode: String,
pkce_mode: String,
} }
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider { impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
@@ -92,6 +94,53 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
.source(e) .source(e)
})?; })?;
let authorization_endpoint_override = value
.authorization_endpoint_override
.map(|x| x.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("authorization_endpoint_override")
.row(id)
.source(e)
})?;
let token_endpoint_override = value
.token_endpoint_override
.map(|x| x.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("token_endpoint_override")
.row(id)
.source(e)
})?;
let jwks_uri_override = value
.jwks_uri_override
.map(|x| x.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("jwks_uri_override")
.row(id)
.source(e)
})?;
let discovery_mode = value.discovery_mode.parse().map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("discovery_mode")
.row(id)
.source(e)
})?;
let pkce_mode = value.pkce_mode.parse().map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("pkce_mode")
.row(id)
.source(e)
})?;
Ok(UpstreamOAuthProvider { Ok(UpstreamOAuthProvider {
id, id,
issuer: value.issuer, issuer: value.issuer,
@@ -102,13 +151,11 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
token_endpoint_signing_alg, token_endpoint_signing_alg,
created_at: value.created_at, created_at: value.created_at,
claims_imports: value.claims_imports.0, claims_imports: value.claims_imports.0,
authorization_endpoint_override,
// TODO token_endpoint_override,
authorization_endpoint_override: None, jwks_uri_override,
token_endpoint_override: None, discovery_mode,
jwks_uri_override: None, pkce_mode,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
}) })
} }
} }
@@ -139,7 +186,12 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_signing_alg, token_endpoint_signing_alg,
token_endpoint_auth_method, token_endpoint_auth_method,
created_at, created_at,
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>" claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
jwks_uri_override,
authorization_endpoint_override,
token_endpoint_override,
discovery_mode,
pkce_mode
FROM upstream_oauth_providers FROM upstream_oauth_providers
WHERE upstream_oauth_provider_id = $1 WHERE upstream_oauth_provider_id = $1
"#, "#,
@@ -163,23 +215,16 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
fields( fields(
db.statement, db.statement,
upstream_oauth_provider.id, upstream_oauth_provider.id,
upstream_oauth_provider.issuer = %issuer, upstream_oauth_provider.issuer = %params.issuer,
upstream_oauth_provider.client_id = %client_id, upstream_oauth_provider.client_id = %params.client_id,
), ),
err, err,
)] )]
#[allow(clippy::too_many_arguments)]
async fn add( async fn add(
&mut self, &mut self,
rng: &mut (dyn RngCore + Send), rng: &mut (dyn RngCore + Send),
clock: &dyn Clock, clock: &dyn Clock,
issuer: String, params: UpstreamOAuthProviderParams,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
claims_imports: UpstreamOAuthProviderClaimsImports,
) -> Result<UpstreamOAuthProvider, Self::Error> { ) -> Result<UpstreamOAuthProvider, Self::Error> {
let created_at = clock.now(); let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng); let id = Ulid::from_datetime_with_source(created_at.into(), rng);
@@ -195,19 +240,39 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_signing_alg, token_endpoint_signing_alg,
client_id, client_id,
encrypted_client_secret, encrypted_client_secret,
created_at, claims_imports,
claims_imports authorization_endpoint_override,
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) token_endpoint_override,
jwks_uri_override,
discovery_mode,
pkce_mode,
created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,
$10, $11, $12, $13, $14)
"#, "#,
Uuid::from(id), Uuid::from(id),
&issuer, &params.issuer,
scope.to_string(), params.scope.to_string(),
token_endpoint_auth_method.to_string(), params.token_endpoint_auth_method.to_string(),
token_endpoint_signing_alg.as_ref().map(ToString::to_string), params
&client_id, .token_endpoint_signing_alg
encrypted_client_secret.as_deref(), .as_ref()
.map(ToString::to_string),
&params.client_id,
params.encrypted_client_secret.as_deref(),
Json(&params.claims_imports) as _,
params
.authorization_endpoint_override
.as_ref()
.map(ToString::to_string),
params
.token_endpoint_override
.as_ref()
.map(ToString::to_string),
params.jwks_uri_override.as_ref().map(ToString::to_string),
params.discovery_mode.as_str(),
params.pkce_mode.as_str(),
created_at, created_at,
Json(&claims_imports) as _,
) )
.traced() .traced()
.execute(&mut *self.conn) .execute(&mut *self.conn)
@@ -215,21 +280,19 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
Ok(UpstreamOAuthProvider { Ok(UpstreamOAuthProvider {
id, id,
issuer, issuer: params.issuer,
scope, scope: params.scope,
client_id, client_id: params.client_id,
encrypted_client_secret, encrypted_client_secret: params.encrypted_client_secret,
token_endpoint_signing_alg, token_endpoint_signing_alg: params.token_endpoint_signing_alg,
token_endpoint_auth_method, token_endpoint_auth_method: params.token_endpoint_auth_method,
created_at, created_at,
claims_imports, claims_imports: params.claims_imports,
authorization_endpoint_override: params.authorization_endpoint_override,
// TODO token_endpoint_override: params.token_endpoint_override,
authorization_endpoint_override: None, jwks_uri_override: params.jwks_uri_override,
token_endpoint_override: None, discovery_mode: params.discovery_mode,
jwks_uri_override: None, pkce_mode: params.pkce_mode,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
}) })
} }
@@ -305,23 +368,16 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
fields( fields(
db.statement, db.statement,
upstream_oauth_provider.id = %id, upstream_oauth_provider.id = %id,
upstream_oauth_provider.issuer = %issuer, upstream_oauth_provider.issuer = %params.issuer,
upstream_oauth_provider.client_id = %client_id, upstream_oauth_provider.client_id = %params.client_id,
), ),
err, err,
)] )]
#[allow(clippy::too_many_arguments)]
async fn upsert( async fn upsert(
&mut self, &mut self,
clock: &dyn Clock, clock: &dyn Clock,
id: Ulid, id: Ulid,
issuer: String, params: UpstreamOAuthProviderParams,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
claims_imports: UpstreamOAuthProviderClaimsImports,
) -> Result<UpstreamOAuthProvider, Self::Error> { ) -> Result<UpstreamOAuthProvider, Self::Error> {
let created_at = clock.now(); let created_at = clock.now();
@@ -335,9 +391,15 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_signing_alg, token_endpoint_signing_alg,
client_id, client_id,
encrypted_client_secret, encrypted_client_secret,
created_at, claims_imports,
claims_imports authorization_endpoint_override,
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) token_endpoint_override,
jwks_uri_override,
discovery_mode,
pkce_mode,
created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,
$10, $11, $12, $13, $14)
ON CONFLICT (upstream_oauth_provider_id) ON CONFLICT (upstream_oauth_provider_id)
DO UPDATE DO UPDATE
SET SET
@@ -347,18 +409,37 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg, token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
client_id = EXCLUDED.client_id, client_id = EXCLUDED.client_id,
encrypted_client_secret = EXCLUDED.encrypted_client_secret, encrypted_client_secret = EXCLUDED.encrypted_client_secret,
claims_imports = EXCLUDED.claims_imports claims_imports = EXCLUDED.claims_imports,
authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
token_endpoint_override = EXCLUDED.token_endpoint_override,
jwks_uri_override = EXCLUDED.jwks_uri_override,
discovery_mode = EXCLUDED.discovery_mode,
pkce_mode = EXCLUDED.pkce_mode
RETURNING created_at RETURNING created_at
"#, "#,
Uuid::from(id), Uuid::from(id),
&issuer, &params.issuer,
scope.to_string(), params.scope.to_string(),
token_endpoint_auth_method.to_string(), params.token_endpoint_auth_method.to_string(),
token_endpoint_signing_alg.as_ref().map(ToString::to_string), params
&client_id, .token_endpoint_signing_alg
encrypted_client_secret.as_deref(), .as_ref()
.map(ToString::to_string),
&params.client_id,
params.encrypted_client_secret.as_deref(),
Json(&params.claims_imports) as _,
params
.authorization_endpoint_override
.as_ref()
.map(ToString::to_string),
params
.token_endpoint_override
.as_ref()
.map(ToString::to_string),
params.jwks_uri_override.as_ref().map(ToString::to_string),
params.discovery_mode.as_str(),
params.pkce_mode.as_str(),
created_at, created_at,
Json(&claims_imports) as _,
) )
.traced() .traced()
.fetch_one(&mut *self.conn) .fetch_one(&mut *self.conn)
@@ -366,21 +447,19 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
Ok(UpstreamOAuthProvider { Ok(UpstreamOAuthProvider {
id, id,
issuer, issuer: params.issuer,
scope, scope: params.scope,
client_id, client_id: params.client_id,
encrypted_client_secret, encrypted_client_secret: params.encrypted_client_secret,
token_endpoint_signing_alg, token_endpoint_signing_alg: params.token_endpoint_signing_alg,
token_endpoint_auth_method, token_endpoint_auth_method: params.token_endpoint_auth_method,
created_at, created_at,
claims_imports, claims_imports: params.claims_imports,
authorization_endpoint_override: params.authorization_endpoint_override,
// TODO token_endpoint_override: params.token_endpoint_override,
authorization_endpoint_override: None, jwks_uri_override: params.jwks_uri_override,
token_endpoint_override: None, discovery_mode: params.discovery_mode,
jwks_uri_override: None, pkce_mode: params.pkce_mode,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
}) })
} }
@@ -459,6 +538,41 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
)), )),
ProviderLookupIden::ClaimsImports, ProviderLookupIden::ClaimsImports,
) )
.expr_as(
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::JwksUriOverride,
)),
ProviderLookupIden::JwksUriOverride,
)
.expr_as(
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::TokenEndpointOverride,
)),
ProviderLookupIden::TokenEndpointOverride,
)
.expr_as(
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::AuthorizationEndpointOverride,
)),
ProviderLookupIden::AuthorizationEndpointOverride,
)
.expr_as(
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::DiscoveryMode,
)),
ProviderLookupIden::DiscoveryMode,
)
.expr_as(
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::PkceMode,
)),
ProviderLookupIden::PkceMode,
)
.from(UpstreamOAuthProviders::Table) .from(UpstreamOAuthProviders::Table)
.generate_pagination( .generate_pagination(
( (
@@ -536,7 +650,12 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_signing_alg, token_endpoint_signing_alg,
token_endpoint_auth_method, token_endpoint_auth_method,
created_at, created_at,
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>" claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
jwks_uri_override,
authorization_endpoint_override,
token_endpoint_override,
discovery_mode,
pkce_mode
FROM upstream_oauth_providers FROM upstream_oauth_providers
"#, "#,
) )

View File

@@ -21,6 +21,8 @@ mod session;
pub use self::{ pub use self::{
link::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository}, link::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
provider::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository}, provider::{
UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
},
session::UpstreamOAuthSessionRepository, session::UpstreamOAuthSessionRepository,
}; };

View File

@@ -15,14 +15,61 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use async_trait::async_trait; use async_trait::async_trait;
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports}; use mas_data_model::{
UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
UpstreamOAuthProviderPkceMode,
};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use rand_core::RngCore; use rand_core::RngCore;
use ulid::Ulid; use ulid::Ulid;
use url::Url;
use crate::{pagination::Page, repository_impl, Clock, Pagination}; use crate::{pagination::Page, repository_impl, Clock, Pagination};
/// Structure which holds parameters when inserting or updating an upstream
/// OAuth 2.0 provider
pub struct UpstreamOAuthProviderParams {
/// The OIDC issuer of the provider
pub issuer: String,
/// The scope to request during the authorization flow
pub scope: Scope,
/// The token endpoint authentication method
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
/// The JWT signing algorithm to use when then `client_secret_jwt` or
/// `private_key_jwt` authentication methods are used
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
/// The client ID to use when authenticating to the upstream
pub client_id: String,
/// The encrypted client secret to use when authenticating to the upstream
pub encrypted_client_secret: Option<String>,
/// How claims should be imported from the upstream provider
pub claims_imports: UpstreamOAuthProviderClaimsImports,
/// The URL to use as the authorization endpoint. If `None`, the URL will be
/// discovered
pub authorization_endpoint_override: Option<Url>,
/// The URL to use as the token endpoint. If `None`, the URL will be
/// discovered
pub token_endpoint_override: Option<Url>,
/// The URL to use when fetching JWKS. If `None`, the URL will be discovered
pub jwks_uri_override: Option<Url>,
/// How the provider metadata should be discovered
pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
/// How should PKCE be used
pub pkce_mode: UpstreamOAuthProviderPkceMode,
}
/// Filter parameters for listing upstream OAuth 2.0 providers /// Filter parameters for listing upstream OAuth 2.0 providers
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UpstreamOAuthProviderFilter<'a> { pub struct UpstreamOAuthProviderFilter<'a> {
@@ -65,33 +112,16 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
/// ///
/// * `rng`: A random number generator /// * `rng`: A random number generator
/// * `clock`: The clock used to generate timestamps /// * `clock`: The clock used to generate timestamps
/// * `issuer`: The OIDC issuer of the provider /// * `params`: The parameters of the provider to add
/// * `scope`: The scope to request during the authorization flow
/// * `token_endpoint_auth_method`: The token endpoint authentication method
/// * `token_endpoint_auth_signing_alg`: The JWT signing algorithm to use
/// when then `client_secret_jwt` or `private_key_jwt` authentication
/// methods are used
/// * `client_id`: The client ID to use when authenticating to the upstream
/// * `encrypted_client_secret`: The encrypted client secret to use when
/// authenticating to the upstream
/// * `claims_imports`: How claims should be imported from the upstream
/// provider
/// ///
/// # Errors /// # Errors
/// ///
/// Returns [`Self::Error`] if the underlying repository fails /// Returns [`Self::Error`] if the underlying repository fails
#[allow(clippy::too_many_arguments)]
async fn add( async fn add(
&mut self, &mut self,
rng: &mut (dyn RngCore + Send), rng: &mut (dyn RngCore + Send),
clock: &dyn Clock, clock: &dyn Clock,
issuer: String, params: UpstreamOAuthProviderParams,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
claims_imports: UpstreamOAuthProviderClaimsImports,
) -> Result<UpstreamOAuthProvider, Self::Error>; ) -> Result<UpstreamOAuthProvider, Self::Error>;
/// Delete an upstream OAuth provider /// Delete an upstream OAuth provider
@@ -124,33 +154,16 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
/// ///
/// * `clock`: The clock used to generate timestamps /// * `clock`: The clock used to generate timestamps
/// * `id`: The ID of the provider to update /// * `id`: The ID of the provider to update
/// * `issuer`: The OIDC issuer of the provider /// * `params`: The parameters of the provider to update
/// * `scope`: The scope to request during the authorization flow
/// * `token_endpoint_auth_method`: The token endpoint authentication method
/// * `token_endpoint_auth_signing_alg`: The JWT signing algorithm to use
/// when then `client_secret_jwt` or `private_key_jwt` authentication
/// methods are used
/// * `client_id`: The client ID to use when authenticating to the upstream
/// * `encrypted_client_secret`: The encrypted client secret to use when
/// authenticating to the upstream
/// * `claims_imports`: How claims should be imported from the upstream
/// provider
/// ///
/// # Errors /// # Errors
/// ///
/// Returns [`Self::Error`] if the underlying repository fails /// Returns [`Self::Error`] if the underlying repository fails
#[allow(clippy::too_many_arguments)]
async fn upsert( async fn upsert(
&mut self, &mut self,
clock: &dyn Clock, clock: &dyn Clock,
id: Ulid, id: Ulid,
issuer: String, params: UpstreamOAuthProviderParams,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
claims_imports: UpstreamOAuthProviderClaimsImports,
) -> Result<UpstreamOAuthProvider, Self::Error>; ) -> Result<UpstreamOAuthProvider, Self::Error>;
/// List [`UpstreamOAuthProvider`] with the given filter and pagination /// List [`UpstreamOAuthProvider`] with the given filter and pagination
@@ -198,26 +211,14 @@ repository_impl!(UpstreamOAuthProviderRepository:
&mut self, &mut self,
rng: &mut (dyn RngCore + Send), rng: &mut (dyn RngCore + Send),
clock: &dyn Clock, clock: &dyn Clock,
issuer: String, params: UpstreamOAuthProviderParams
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
claims_imports: UpstreamOAuthProviderClaimsImports
) -> Result<UpstreamOAuthProvider, Self::Error>; ) -> Result<UpstreamOAuthProvider, Self::Error>;
async fn upsert( async fn upsert(
&mut self, &mut self,
clock: &dyn Clock, clock: &dyn Clock,
id: Ulid, id: Ulid,
issuer: String, params: UpstreamOAuthProviderParams
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
claims_imports: UpstreamOAuthProviderClaimsImports,
) -> Result<UpstreamOAuthProvider, Self::Error>; ) -> Result<UpstreamOAuthProvider, Self::Error>;
async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>; async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;

View File

@@ -633,6 +633,32 @@
} }
} }
}, },
"DiscoveryMode": {
"description": "How to discover the provider's configuration",
"oneOf": [
{
"description": "Use OIDC discovery with strict metadata verification",
"type": "string",
"enum": [
"oidc"
]
},
{
"description": "Use OIDC discovery with relaxed metadata verification",
"type": "string",
"enum": [
"insecure"
]
},
{
"description": "Use a static configuration",
"type": "string",
"enum": [
"disabled"
]
}
]
},
"DisplaynameImportPreference": { "DisplaynameImportPreference": {
"description": "What should be done for the displayname attribute", "description": "What should be done for the displayname attribute",
"type": "object", "type": "object",
@@ -1520,6 +1546,32 @@
} }
} }
}, },
"PkceMethod": {
"description": "Whether to use proof key for code exchange (PKCE) when requesting and exchanging the token.",
"oneOf": [
{
"description": "Use PKCE if the provider supports it\n\nDefaults to no PKCE if provider discovery is disabled",
"type": "string",
"enum": [
"auto"
]
},
{
"description": "Always use PKCE with the S256 challenge method",
"type": "string",
"enum": [
"always"
]
},
{
"description": "Never use PKCE",
"type": "string",
"enum": [
"never"
]
}
]
},
"PolicyConfig": { "PolicyConfig": {
"description": "Application secrets", "description": "Application secrets",
"type": "object", "type": "object",
@@ -1706,6 +1758,11 @@
"scope" "scope"
], ],
"properties": { "properties": {
"authorization_endpoint": {
"description": "The URL to use for the provider's authorization endpoint\n\nDefaults to the `authorization_endpoint` provided through discovery",
"type": "string",
"format": "uri"
},
"claims_imports": { "claims_imports": {
"description": "How claims should be imported from the `id_token` provided by the provider", "description": "How claims should be imported from the `id_token` provided by the provider",
"allOf": [ "allOf": [
@@ -1718,6 +1775,15 @@
"description": "The client ID to use when authenticating with the provider", "description": "The client ID to use when authenticating with the provider",
"type": "string" "type": "string"
}, },
"discovery_mode": {
"description": "How to discover the provider's configuration\n\nDefaults to use OIDC discovery with strict metadata verification",
"default": "oidc",
"allOf": [
{
"$ref": "#/definitions/DiscoveryMode"
}
]
},
"id": { "id": {
"description": "A ULID as per https://github.com/ulid/spec", "description": "A ULID as per https://github.com/ulid/spec",
"type": "string", "type": "string",
@@ -1727,9 +1793,28 @@
"description": "The OIDC issuer URL", "description": "The OIDC issuer URL",
"type": "string" "type": "string"
}, },
"jwks_uri": {
"description": "The URL to use for getting the provider's public keys\n\nDefaults to the `jwks_uri` provided through discovery",
"type": "string",
"format": "uri"
},
"pkce_method": {
"description": "Whether to use proof key for code exchange (PKCE) when requesting and exchanging the token.\n\nDefaults to `auto`, which uses PKCE if the provider supports it.",
"default": "auto",
"allOf": [
{
"$ref": "#/definitions/PkceMethod"
}
]
},
"scope": { "scope": {
"description": "The scopes to request from the provider", "description": "The scopes to request from the provider",
"type": "string" "type": "string"
},
"token_endpoint": {
"description": "The URL to use for the provider's token endpoint\n\nDefaults to the `token_endpoint` provided through discovery",
"type": "string",
"format": "uri"
} }
} }
}, },