You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
Allow endpoints and discovery mode override for upstream oauth2 providers
This time, at the configuration and database level
This commit is contained in:
@@ -17,12 +17,13 @@ use std::collections::HashSet;
|
||||
use clap::Parser;
|
||||
use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
|
||||
use mas_storage::{
|
||||
upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess, SystemClock,
|
||||
upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository},
|
||||
RepositoryAccess, SystemClock,
|
||||
};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use rand::SeedableRng;
|
||||
use sqlx::{postgres::PgAdvisoryLock, Acquire};
|
||||
use tracing::{info, info_span, warn};
|
||||
use tracing::{error, info, info_span, warn};
|
||||
|
||||
use crate::util::database_connection_from_config;
|
||||
|
||||
@@ -204,10 +205,11 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
|
||||
}
|
||||
|
||||
for provider in config.upstream_oauth2.providers {
|
||||
let _span = info_span!("provider", %provider.id).entered();
|
||||
if existing_ids.contains(&provider.id) {
|
||||
info!(%provider.id, "Updating provider");
|
||||
info!("Updating provider");
|
||||
} else {
|
||||
info!(%provider.id, "Adding provider");
|
||||
info!("Adding provider");
|
||||
}
|
||||
|
||||
if dry_run {
|
||||
@@ -218,20 +220,65 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
|
||||
.client_secret()
|
||||
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
||||
.transpose()?;
|
||||
let client_auth_method = provider.client_auth_method();
|
||||
let client_auth_signing_alg = provider.client_auth_signing_alg();
|
||||
let token_endpoint_auth_method = provider.client_auth_method();
|
||||
let token_endpoint_signing_alg = provider.client_auth_signing_alg();
|
||||
|
||||
let discovery_mode = match provider.discovery_mode {
|
||||
mas_config::UpstreamOAuth2DiscoveryMode::Oidc => {
|
||||
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc
|
||||
}
|
||||
mas_config::UpstreamOAuth2DiscoveryMode::Insecure => {
|
||||
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Insecure
|
||||
}
|
||||
mas_config::UpstreamOAuth2DiscoveryMode::Disabled => {
|
||||
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled
|
||||
}
|
||||
};
|
||||
|
||||
if discovery_mode.is_disabled() {
|
||||
if provider.authorization_endpoint.is_none() {
|
||||
error!("Provider has discovery disabled but no authorization endpoint set");
|
||||
}
|
||||
|
||||
if provider.token_endpoint.is_none() {
|
||||
error!("Provider has discovery disabled but no token endpoint set");
|
||||
}
|
||||
|
||||
if provider.jwks_uri.is_none() {
|
||||
error!("Provider has discovery disabled but no JWKS URI set");
|
||||
}
|
||||
}
|
||||
|
||||
let pkce_mode = match provider.pkce_method {
|
||||
mas_config::UpstreamOAuth2PkceMethod::Auto => {
|
||||
mas_data_model::UpstreamOAuthProviderPkceMode::Auto
|
||||
}
|
||||
mas_config::UpstreamOAuth2PkceMethod::Always => {
|
||||
mas_data_model::UpstreamOAuthProviderPkceMode::S256
|
||||
}
|
||||
mas_config::UpstreamOAuth2PkceMethod::Never => {
|
||||
mas_data_model::UpstreamOAuthProviderPkceMode::Disabled
|
||||
}
|
||||
};
|
||||
|
||||
repo.upstream_oauth_provider()
|
||||
.upsert(
|
||||
&clock,
|
||||
provider.id,
|
||||
provider.issuer,
|
||||
provider.scope.parse()?,
|
||||
client_auth_method,
|
||||
client_auth_signing_alg,
|
||||
provider.client_id,
|
||||
encrypted_client_secret,
|
||||
map_claims_imports(&provider.claims_imports),
|
||||
UpstreamOAuthProviderParams {
|
||||
issuer: provider.issuer,
|
||||
scope: provider.scope.parse()?,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_signing_alg,
|
||||
client_id: provider.client_id,
|
||||
encrypted_client_secret,
|
||||
claims_imports: map_claims_imports(&provider.claims_imports),
|
||||
token_endpoint_override: provider.token_endpoint,
|
||||
authorization_endpoint_override: provider.authorization_endpoint,
|
||||
jwks_uri_override: provider.jwks_uri,
|
||||
discovery_mode,
|
||||
pkce_mode,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
@@ -268,10 +315,11 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
|
||||
}
|
||||
|
||||
for client in config.clients.iter() {
|
||||
let _span = info_span!("client", client.id = %client.client_id).entered();
|
||||
if existing_ids.contains(&client.client_id) {
|
||||
info!(client.id = %client.client_id, "Updating client");
|
||||
info!("Updating client");
|
||||
} else {
|
||||
info!(client.id = %client.client_id, "Adding client");
|
||||
info!("Adding client");
|
||||
}
|
||||
|
||||
if dry_run {
|
||||
|
@@ -51,10 +51,10 @@ pub use self::{
|
||||
},
|
||||
templates::TemplatesConfig,
|
||||
upstream_oauth2::{
|
||||
ClaimsImports as UpstreamOAuth2ClaimsImports,
|
||||
ClaimsImports as UpstreamOAuth2ClaimsImports, DiscoveryMode as UpstreamOAuth2DiscoveryMode,
|
||||
EmailImportPreference as UpstreamOAuth2EmailImportPreference,
|
||||
ImportAction as UpstreamOAuth2ImportAction,
|
||||
ImportPreference as UpstreamOAuth2ImportPreference,
|
||||
ImportPreference as UpstreamOAuth2ImportPreference, PkceMethod as UpstreamOAuth2PkceMethod,
|
||||
SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config,
|
||||
},
|
||||
};
|
||||
|
@@ -21,6 +21,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
|
||||
use crate::ConfigurationSection;
|
||||
|
||||
@@ -197,6 +198,39 @@ pub struct ClaimsImports {
|
||||
pub email: EmailImportPreference,
|
||||
}
|
||||
|
||||
/// How to discover the provider's configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum DiscoveryMode {
|
||||
/// Use OIDC discovery with strict metadata verification
|
||||
#[default]
|
||||
Oidc,
|
||||
|
||||
/// Use OIDC discovery with relaxed metadata verification
|
||||
Insecure,
|
||||
|
||||
/// Use a static configuration
|
||||
Disabled,
|
||||
}
|
||||
|
||||
/// Whether to use proof key for code exchange (PKCE) when requesting and
|
||||
/// exchanging the token.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PkceMethod {
|
||||
/// Use PKCE if the provider supports it
|
||||
///
|
||||
/// Defaults to no PKCE if provider discovery is disabled
|
||||
#[default]
|
||||
Auto,
|
||||
|
||||
/// Always use PKCE with the S256 challenge method
|
||||
Always,
|
||||
|
||||
/// Never use PKCE
|
||||
Never,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct Provider {
|
||||
@@ -220,6 +254,34 @@ pub struct Provider {
|
||||
#[serde(flatten)]
|
||||
pub token_auth_method: TokenAuthMethod,
|
||||
|
||||
/// How to discover the provider's configuration
|
||||
///
|
||||
/// Defaults to use OIDC discovery with strict metadata verification
|
||||
#[serde(default)]
|
||||
pub discovery_mode: DiscoveryMode,
|
||||
|
||||
/// Whether to use proof key for code exchange (PKCE) when requesting and
|
||||
/// exchanging the token.
|
||||
///
|
||||
/// Defaults to `auto`, which uses PKCE if the provider supports it.
|
||||
#[serde(default)]
|
||||
pub pkce_method: PkceMethod,
|
||||
|
||||
/// The URL to use for the provider's authorization endpoint
|
||||
///
|
||||
/// Defaults to the `authorization_endpoint` provided through discovery
|
||||
pub authorization_endpoint: Option<Url>,
|
||||
|
||||
/// The URL to use for the provider's token endpoint
|
||||
///
|
||||
/// Defaults to the `token_endpoint` provided through discovery
|
||||
pub token_endpoint: Option<Url>,
|
||||
|
||||
/// The URL to use for getting the provider's public keys
|
||||
///
|
||||
/// Defaults to the `jwks_uri` provided through discovery
|
||||
pub jwks_uri: Option<Url>,
|
||||
|
||||
/// How claims should be imported from the `id_token` provided by the
|
||||
/// provider
|
||||
pub claims_imports: ClaimsImports,
|
||||
|
@@ -16,6 +16,7 @@ use chrono::{DateTime, Utc};
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use oauth2_types::scope::Scope;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
|
||||
@@ -33,6 +34,48 @@ pub enum DiscoveryMode {
|
||||
Disabled,
|
||||
}
|
||||
|
||||
impl DiscoveryMode {
|
||||
/// Returns `true` if discovery is disabled
|
||||
#[must_use]
|
||||
pub fn is_disabled(&self) -> bool {
|
||||
matches!(self, DiscoveryMode::Disabled)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Error)]
|
||||
#[error("Invalid discovery mode {0:?}")]
|
||||
pub struct InvalidDiscoveryModeError(String);
|
||||
|
||||
impl std::str::FromStr for DiscoveryMode {
|
||||
type Err = InvalidDiscoveryModeError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"oidc" => Ok(Self::Oidc),
|
||||
"insecure" => Ok(Self::Insecure),
|
||||
"disabled" => Ok(Self::Disabled),
|
||||
s => Err(InvalidDiscoveryModeError(s.to_owned())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DiscoveryMode {
|
||||
#[must_use]
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Oidc => "oidc",
|
||||
Self::Insecure => "insecure",
|
||||
Self::Disabled => "disabled",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DiscoveryMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum PkceMode {
|
||||
@@ -47,6 +90,40 @@ pub enum PkceMode {
|
||||
Disabled,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Error)]
|
||||
#[error("Invalid PKCE mode {0:?}")]
|
||||
pub struct InvalidPkceModeError(String);
|
||||
|
||||
impl std::str::FromStr for PkceMode {
|
||||
type Err = InvalidPkceModeError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"auto" => Ok(Self::Auto),
|
||||
"s256" => Ok(Self::S256),
|
||||
"disabled" => Ok(Self::Disabled),
|
||||
s => Err(InvalidPkceModeError(s.to_owned())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PkceMode {
|
||||
#[must_use]
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Auto => "auto",
|
||||
Self::S256 => "s256",
|
||||
Self::Disabled => "disabled",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PkceMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct UpstreamOAuthProvider {
|
||||
pub id: Ulid,
|
||||
|
@@ -292,9 +292,8 @@ mod tests {
|
||||
use tower::BoxError;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::test_utils::init_tracing;
|
||||
|
||||
use super::*;
|
||||
use crate::test_utils::init_tracing;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_metadata_cache() {
|
||||
|
@@ -803,6 +803,7 @@ mod tests {
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
|
||||
use mas_router::Route;
|
||||
use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams;
|
||||
use oauth2_types::scope::{Scope, OPENID};
|
||||
use sqlx::PgPool;
|
||||
|
||||
@@ -858,13 +859,20 @@ mod tests {
|
||||
.add(
|
||||
&mut rng,
|
||||
&state.clock,
|
||||
"https://example.com/".to_owned(),
|
||||
Scope::from_iter([OPENID]),
|
||||
OAuthClientAuthenticationMethod::None,
|
||||
None,
|
||||
"client".to_owned(),
|
||||
None,
|
||||
claims_imports,
|
||||
UpstreamOAuthProviderParams {
|
||||
issuer: "https://example.com/".to_owned(),
|
||||
scope: Scope::from_iter([OPENID]),
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
|
||||
token_endpoint_signing_alg: None,
|
||||
client_id: "client".to_owned(),
|
||||
encrypted_client_secret: None,
|
||||
claims_imports,
|
||||
authorization_endpoint_override: None,
|
||||
token_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
@@ -311,7 +311,10 @@ mod test {
|
||||
use mas_data_model::UpstreamOAuthProviderClaimsImports;
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_router::Route;
|
||||
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess};
|
||||
use mas_storage::{
|
||||
upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository},
|
||||
RepositoryAccess,
|
||||
};
|
||||
use mas_templates::escape_html;
|
||||
use oauth2_types::scope::OPENID;
|
||||
use sqlx::PgPool;
|
||||
@@ -346,13 +349,20 @@ mod test {
|
||||
.add(
|
||||
&mut rng,
|
||||
&state.clock,
|
||||
"https://first.com/".into(),
|
||||
[OPENID].into_iter().collect(),
|
||||
OAuthClientAuthenticationMethod::None,
|
||||
None,
|
||||
"first_client".into(),
|
||||
None,
|
||||
UpstreamOAuthProviderClaimsImports::default(),
|
||||
UpstreamOAuthProviderParams {
|
||||
issuer: "https://first.com/".to_owned(),
|
||||
scope: [OPENID].into_iter().collect(),
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
|
||||
token_endpoint_signing_alg: None,
|
||||
client_id: "client".to_owned(),
|
||||
encrypted_client_secret: None,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||
authorization_endpoint_override: None,
|
||||
token_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -371,13 +381,20 @@ mod test {
|
||||
.add(
|
||||
&mut rng,
|
||||
&state.clock,
|
||||
"https://second.com/".into(),
|
||||
[OPENID].into_iter().collect(),
|
||||
OAuthClientAuthenticationMethod::None,
|
||||
None,
|
||||
"second_client".into(),
|
||||
None,
|
||||
UpstreamOAuthProviderClaimsImports::default(),
|
||||
UpstreamOAuthProviderParams {
|
||||
issuer: "https://second.com/".to_owned(),
|
||||
scope: [OPENID].into_iter().collect(),
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
|
||||
token_endpoint_signing_alg: None,
|
||||
client_id: "client".to_owned(),
|
||||
encrypted_client_secret: None,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||
authorization_endpoint_override: None,
|
||||
token_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at,\n claims_imports\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n ",
|
||||
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14)\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
@@ -12,11 +12,16 @@
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Timestamptz",
|
||||
"Jsonb"
|
||||
"Jsonb",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "6021c1b9e17b0b2e8b511888f8c6be00683ba0635a13eb7fcd403d3d4a3f90db"
|
||||
"hash": "311957a0b745660aa2a21b1bd211376739318efa1e84670e04189e1257d4a8ed"
|
||||
}
|
35
crates/storage-pg/.sqlx/query-75b58c1b7f4e26997e961ad64418938938f09b3215a9b14f7edb3dd91cdf2dd5.json
generated
Normal file
35
crates/storage-pg/.sqlx/query-75b58c1b7f4e26997e961ad64418938938f09b3215a9b14f7edb3dd91cdf2dd5.json
generated
Normal file
@@ -0,0 +1,35 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode\n RETURNING created_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Jsonb",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "75b58c1b7f4e26997e961ad64418938938f09b3215a9b14f7edb3dd91cdf2dd5"
|
||||
}
|
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\"\n FROM upstream_oauth_providers\n ",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode\n FROM upstream_oauth_providers\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -47,6 +47,31 @@
|
||||
"ordinal": 8,
|
||||
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
"type_info": "Jsonb"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "jwks_uri_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "authorization_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"name": "token_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 12,
|
||||
"name": "discovery_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 13,
|
||||
"name": "pkce_mode",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
@@ -61,8 +86,13 @@
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "af65441068530b68826561d4308e15923ba6c6882ded4860ebde4a7641359abb"
|
||||
"hash": "b44e77ba737c9ec9af3838f148e2e882c90c0118ff77a92d2d93fe97dbd33233"
|
||||
}
|
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -47,6 +47,31 @@
|
||||
"ordinal": 8,
|
||||
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
"type_info": "Jsonb"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "jwks_uri_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "authorization_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"name": "token_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 12,
|
||||
"name": "discovery_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 13,
|
||||
"name": "pkce_mode",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
@@ -63,8 +88,13 @@
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "6733c54a8d9ed93a760f365a9362fdb0f77340d7a4df642a2942174aba2c6502"
|
||||
"hash": "e1759a6bda20a09a423e9dcb3a7544dbf259fea54e7cdaa714455f05814f39f6"
|
||||
}
|
@@ -1,30 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at,\n claims_imports\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports\n RETURNING created_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Timestamptz",
|
||||
"Jsonb"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "e7ce95415bb6b57cd601393c6abe5febfec2a963ce6eac7b099b761594b1dfaf"
|
||||
}
|
@@ -0,0 +1,21 @@
|
||||
-- Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
-- Adds various endpoint overrides for oauth providers
|
||||
ALTER TABLE upstream_oauth_providers
|
||||
ADD COLUMN "jwks_uri_override" TEXT,
|
||||
ADD COLUMN "authorization_endpoint_override" TEXT,
|
||||
ADD COLUMN "token_endpoint_override" TEXT,
|
||||
ADD COLUMN "discovery_mode" TEXT NOT NULL DEFAULT 'oidc',
|
||||
ADD COLUMN "pkce_mode" TEXT NOT NULL DEFAULT 'auto';
|
@@ -103,6 +103,11 @@ pub enum UpstreamOAuthProviders {
|
||||
TokenEndpointAuthMethod,
|
||||
CreatedAt,
|
||||
ClaimsImports,
|
||||
DiscoveryMode,
|
||||
PkceMode,
|
||||
JwksUriOverride,
|
||||
TokenEndpointOverride,
|
||||
AuthorizationEndpointOverride,
|
||||
}
|
||||
|
||||
#[derive(sea_query::Iden)]
|
||||
|
@@ -32,7 +32,8 @@ mod tests {
|
||||
clock::MockClock,
|
||||
upstream_oauth2::{
|
||||
UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
|
||||
UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository,
|
||||
UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
|
||||
UpstreamOAuthSessionRepository,
|
||||
},
|
||||
user::UserRepository,
|
||||
Pagination, RepositoryAccess,
|
||||
@@ -59,13 +60,21 @@ mod tests {
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
"https://example.com/".to_owned(),
|
||||
Scope::from_iter([OPENID]),
|
||||
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
|
||||
None,
|
||||
"client-id".to_owned(),
|
||||
None,
|
||||
UpstreamOAuthProviderClaimsImports::default(),
|
||||
UpstreamOAuthProviderParams {
|
||||
issuer: "https://example.com/".to_owned(),
|
||||
scope: Scope::from_iter([OPENID]),
|
||||
token_endpoint_auth_method:
|
||||
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
|
||||
token_endpoint_signing_alg: None,
|
||||
client_id: "client-id".to_owned(),
|
||||
encrypted_client_secret: None,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||
token_endpoint_override: None,
|
||||
authorization_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -232,13 +241,21 @@ mod tests {
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
ISSUER.to_owned(),
|
||||
scope.clone(),
|
||||
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
|
||||
None,
|
||||
client_id,
|
||||
None,
|
||||
UpstreamOAuthProviderClaimsImports::default(),
|
||||
UpstreamOAuthProviderParams {
|
||||
issuer: ISSUER.to_owned(),
|
||||
scope: scope.clone(),
|
||||
token_endpoint_auth_method:
|
||||
mas_iana::oauth::OAuthClientAuthenticationMethod::None,
|
||||
token_endpoint_signing_alg: None,
|
||||
client_id,
|
||||
encrypted_client_secret: None,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||
token_endpoint_override: None,
|
||||
authorization_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
@@ -14,16 +14,13 @@
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
|
||||
UpstreamOAuthProviderPkceMode,
|
||||
};
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
|
||||
use mas_storage::{
|
||||
upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository},
|
||||
upstream_oauth2::{
|
||||
UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
|
||||
},
|
||||
Clock, Page, Pagination,
|
||||
};
|
||||
use oauth2_types::scope::Scope;
|
||||
use rand::RngCore;
|
||||
use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query};
|
||||
use sea_query_binder::SqlxBinder;
|
||||
@@ -63,6 +60,11 @@ struct ProviderLookup {
|
||||
token_endpoint_auth_method: String,
|
||||
created_at: DateTime<Utc>,
|
||||
claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
|
||||
jwks_uri_override: Option<String>,
|
||||
authorization_endpoint_override: Option<String>,
|
||||
token_endpoint_override: Option<String>,
|
||||
discovery_mode: String,
|
||||
pkce_mode: String,
|
||||
}
|
||||
|
||||
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||
@@ -92,6 +94,53 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let authorization_endpoint_override = value
|
||||
.authorization_endpoint_override
|
||||
.map(|x| x.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("authorization_endpoint_override")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let token_endpoint_override = value
|
||||
.token_endpoint_override
|
||||
.map(|x| x.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("token_endpoint_override")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let jwks_uri_override = value
|
||||
.jwks_uri_override
|
||||
.map(|x| x.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("jwks_uri_override")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let discovery_mode = value.discovery_mode.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("discovery_mode")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let pkce_mode = value.pkce_mode.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("pkce_mode")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Ok(UpstreamOAuthProvider {
|
||||
id,
|
||||
issuer: value.issuer,
|
||||
@@ -102,13 +151,11 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||
token_endpoint_signing_alg,
|
||||
created_at: value.created_at,
|
||||
claims_imports: value.claims_imports.0,
|
||||
|
||||
// TODO
|
||||
authorization_endpoint_override: None,
|
||||
token_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
|
||||
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
|
||||
authorization_endpoint_override,
|
||||
token_endpoint_override,
|
||||
jwks_uri_override,
|
||||
discovery_mode,
|
||||
pkce_mode,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -139,7 +186,12 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at,
|
||||
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>"
|
||||
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
jwks_uri_override,
|
||||
authorization_endpoint_override,
|
||||
token_endpoint_override,
|
||||
discovery_mode,
|
||||
pkce_mode
|
||||
FROM upstream_oauth_providers
|
||||
WHERE upstream_oauth_provider_id = $1
|
||||
"#,
|
||||
@@ -163,23 +215,16 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_provider.id,
|
||||
upstream_oauth_provider.issuer = %issuer,
|
||||
upstream_oauth_provider.client_id = %client_id,
|
||||
upstream_oauth_provider.issuer = %params.issuer,
|
||||
upstream_oauth_provider.client_id = %params.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
issuer: String,
|
||||
scope: Scope,
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
client_id: String,
|
||||
encrypted_client_secret: Option<String>,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports,
|
||||
params: UpstreamOAuthProviderParams,
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
@@ -195,19 +240,39 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
token_endpoint_signing_alg,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
created_at,
|
||||
claims_imports
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
claims_imports,
|
||||
authorization_endpoint_override,
|
||||
token_endpoint_override,
|
||||
jwks_uri_override,
|
||||
discovery_mode,
|
||||
pkce_mode,
|
||||
created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,
|
||||
$10, $11, $12, $13, $14)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
&issuer,
|
||||
scope.to_string(),
|
||||
token_endpoint_auth_method.to_string(),
|
||||
token_endpoint_signing_alg.as_ref().map(ToString::to_string),
|
||||
&client_id,
|
||||
encrypted_client_secret.as_deref(),
|
||||
¶ms.issuer,
|
||||
params.scope.to_string(),
|
||||
params.token_endpoint_auth_method.to_string(),
|
||||
params
|
||||
.token_endpoint_signing_alg
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
¶ms.client_id,
|
||||
params.encrypted_client_secret.as_deref(),
|
||||
Json(¶ms.claims_imports) as _,
|
||||
params
|
||||
.authorization_endpoint_override
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
params
|
||||
.token_endpoint_override
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
params.jwks_uri_override.as_ref().map(ToString::to_string),
|
||||
params.discovery_mode.as_str(),
|
||||
params.pkce_mode.as_str(),
|
||||
created_at,
|
||||
Json(&claims_imports) as _,
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
@@ -215,21 +280,19 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
|
||||
Ok(UpstreamOAuthProvider {
|
||||
id,
|
||||
issuer,
|
||||
scope,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
issuer: params.issuer,
|
||||
scope: params.scope,
|
||||
client_id: params.client_id,
|
||||
encrypted_client_secret: params.encrypted_client_secret,
|
||||
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method: params.token_endpoint_auth_method,
|
||||
created_at,
|
||||
claims_imports,
|
||||
|
||||
// TODO
|
||||
authorization_endpoint_override: None,
|
||||
token_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
|
||||
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
|
||||
claims_imports: params.claims_imports,
|
||||
authorization_endpoint_override: params.authorization_endpoint_override,
|
||||
token_endpoint_override: params.token_endpoint_override,
|
||||
jwks_uri_override: params.jwks_uri_override,
|
||||
discovery_mode: params.discovery_mode,
|
||||
pkce_mode: params.pkce_mode,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -305,23 +368,16 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
fields(
|
||||
db.statement,
|
||||
upstream_oauth_provider.id = %id,
|
||||
upstream_oauth_provider.issuer = %issuer,
|
||||
upstream_oauth_provider.client_id = %client_id,
|
||||
upstream_oauth_provider.issuer = %params.issuer,
|
||||
upstream_oauth_provider.client_id = %params.client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn upsert(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
id: Ulid,
|
||||
issuer: String,
|
||||
scope: Scope,
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
client_id: String,
|
||||
encrypted_client_secret: Option<String>,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports,
|
||||
params: UpstreamOAuthProviderParams,
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error> {
|
||||
let created_at = clock.now();
|
||||
|
||||
@@ -335,9 +391,15 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
token_endpoint_signing_alg,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
created_at,
|
||||
claims_imports
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
claims_imports,
|
||||
authorization_endpoint_override,
|
||||
token_endpoint_override,
|
||||
jwks_uri_override,
|
||||
discovery_mode,
|
||||
pkce_mode,
|
||||
created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,
|
||||
$10, $11, $12, $13, $14)
|
||||
ON CONFLICT (upstream_oauth_provider_id)
|
||||
DO UPDATE
|
||||
SET
|
||||
@@ -347,18 +409,37 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
|
||||
client_id = EXCLUDED.client_id,
|
||||
encrypted_client_secret = EXCLUDED.encrypted_client_secret,
|
||||
claims_imports = EXCLUDED.claims_imports
|
||||
claims_imports = EXCLUDED.claims_imports,
|
||||
authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
|
||||
token_endpoint_override = EXCLUDED.token_endpoint_override,
|
||||
jwks_uri_override = EXCLUDED.jwks_uri_override,
|
||||
discovery_mode = EXCLUDED.discovery_mode,
|
||||
pkce_mode = EXCLUDED.pkce_mode
|
||||
RETURNING created_at
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
&issuer,
|
||||
scope.to_string(),
|
||||
token_endpoint_auth_method.to_string(),
|
||||
token_endpoint_signing_alg.as_ref().map(ToString::to_string),
|
||||
&client_id,
|
||||
encrypted_client_secret.as_deref(),
|
||||
¶ms.issuer,
|
||||
params.scope.to_string(),
|
||||
params.token_endpoint_auth_method.to_string(),
|
||||
params
|
||||
.token_endpoint_signing_alg
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
¶ms.client_id,
|
||||
params.encrypted_client_secret.as_deref(),
|
||||
Json(¶ms.claims_imports) as _,
|
||||
params
|
||||
.authorization_endpoint_override
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
params
|
||||
.token_endpoint_override
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
params.jwks_uri_override.as_ref().map(ToString::to_string),
|
||||
params.discovery_mode.as_str(),
|
||||
params.pkce_mode.as_str(),
|
||||
created_at,
|
||||
Json(&claims_imports) as _,
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
@@ -366,21 +447,19 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
|
||||
Ok(UpstreamOAuthProvider {
|
||||
id,
|
||||
issuer,
|
||||
scope,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
issuer: params.issuer,
|
||||
scope: params.scope,
|
||||
client_id: params.client_id,
|
||||
encrypted_client_secret: params.encrypted_client_secret,
|
||||
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method: params.token_endpoint_auth_method,
|
||||
created_at,
|
||||
claims_imports,
|
||||
|
||||
// TODO
|
||||
authorization_endpoint_override: None,
|
||||
token_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
|
||||
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
|
||||
claims_imports: params.claims_imports,
|
||||
authorization_endpoint_override: params.authorization_endpoint_override,
|
||||
token_endpoint_override: params.token_endpoint_override,
|
||||
jwks_uri_override: params.jwks_uri_override,
|
||||
discovery_mode: params.discovery_mode,
|
||||
pkce_mode: params.pkce_mode,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -459,6 +538,41 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
)),
|
||||
ProviderLookupIden::ClaimsImports,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::JwksUriOverride,
|
||||
)),
|
||||
ProviderLookupIden::JwksUriOverride,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::TokenEndpointOverride,
|
||||
)),
|
||||
ProviderLookupIden::TokenEndpointOverride,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::AuthorizationEndpointOverride,
|
||||
)),
|
||||
ProviderLookupIden::AuthorizationEndpointOverride,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::DiscoveryMode,
|
||||
)),
|
||||
ProviderLookupIden::DiscoveryMode,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::PkceMode,
|
||||
)),
|
||||
ProviderLookupIden::PkceMode,
|
||||
)
|
||||
.from(UpstreamOAuthProviders::Table)
|
||||
.generate_pagination(
|
||||
(
|
||||
@@ -536,7 +650,12 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
created_at,
|
||||
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>"
|
||||
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
jwks_uri_override,
|
||||
authorization_endpoint_override,
|
||||
token_endpoint_override,
|
||||
discovery_mode,
|
||||
pkce_mode
|
||||
FROM upstream_oauth_providers
|
||||
"#,
|
||||
)
|
||||
|
@@ -21,6 +21,8 @@ mod session;
|
||||
|
||||
pub use self::{
|
||||
link::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
|
||||
provider::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository},
|
||||
provider::{
|
||||
UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
|
||||
},
|
||||
session::UpstreamOAuthSessionRepository,
|
||||
};
|
||||
|
@@ -15,14 +15,61 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
|
||||
use mas_data_model::{
|
||||
UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
|
||||
UpstreamOAuthProviderPkceMode,
|
||||
};
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use oauth2_types::scope::Scope;
|
||||
use rand_core::RngCore;
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
|
||||
use crate::{pagination::Page, repository_impl, Clock, Pagination};
|
||||
|
||||
/// Structure which holds parameters when inserting or updating an upstream
|
||||
/// OAuth 2.0 provider
|
||||
pub struct UpstreamOAuthProviderParams {
|
||||
/// The OIDC issuer of the provider
|
||||
pub issuer: String,
|
||||
|
||||
/// The scope to request during the authorization flow
|
||||
pub scope: Scope,
|
||||
|
||||
/// The token endpoint authentication method
|
||||
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
|
||||
/// The JWT signing algorithm to use when then `client_secret_jwt` or
|
||||
/// `private_key_jwt` authentication methods are used
|
||||
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
|
||||
/// The client ID to use when authenticating to the upstream
|
||||
pub client_id: String,
|
||||
|
||||
/// The encrypted client secret to use when authenticating to the upstream
|
||||
pub encrypted_client_secret: Option<String>,
|
||||
|
||||
/// How claims should be imported from the upstream provider
|
||||
pub claims_imports: UpstreamOAuthProviderClaimsImports,
|
||||
|
||||
/// The URL to use as the authorization endpoint. If `None`, the URL will be
|
||||
/// discovered
|
||||
pub authorization_endpoint_override: Option<Url>,
|
||||
|
||||
/// The URL to use as the token endpoint. If `None`, the URL will be
|
||||
/// discovered
|
||||
pub token_endpoint_override: Option<Url>,
|
||||
|
||||
/// The URL to use when fetching JWKS. If `None`, the URL will be discovered
|
||||
pub jwks_uri_override: Option<Url>,
|
||||
|
||||
/// How the provider metadata should be discovered
|
||||
pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
|
||||
|
||||
/// How should PKCE be used
|
||||
pub pkce_mode: UpstreamOAuthProviderPkceMode,
|
||||
}
|
||||
|
||||
/// Filter parameters for listing upstream OAuth 2.0 providers
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
|
||||
pub struct UpstreamOAuthProviderFilter<'a> {
|
||||
@@ -65,33 +112,16 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
||||
///
|
||||
/// * `rng`: A random number generator
|
||||
/// * `clock`: The clock used to generate timestamps
|
||||
/// * `issuer`: The OIDC issuer of the provider
|
||||
/// * `scope`: The scope to request during the authorization flow
|
||||
/// * `token_endpoint_auth_method`: The token endpoint authentication method
|
||||
/// * `token_endpoint_auth_signing_alg`: The JWT signing algorithm to use
|
||||
/// when then `client_secret_jwt` or `private_key_jwt` authentication
|
||||
/// methods are used
|
||||
/// * `client_id`: The client ID to use when authenticating to the upstream
|
||||
/// * `encrypted_client_secret`: The encrypted client secret to use when
|
||||
/// authenticating to the upstream
|
||||
/// * `claims_imports`: How claims should be imported from the upstream
|
||||
/// provider
|
||||
/// * `params`: The parameters of the provider to add
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
issuer: String,
|
||||
scope: Scope,
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
client_id: String,
|
||||
encrypted_client_secret: Option<String>,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports,
|
||||
params: UpstreamOAuthProviderParams,
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error>;
|
||||
|
||||
/// Delete an upstream OAuth provider
|
||||
@@ -124,33 +154,16 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
||||
///
|
||||
/// * `clock`: The clock used to generate timestamps
|
||||
/// * `id`: The ID of the provider to update
|
||||
/// * `issuer`: The OIDC issuer of the provider
|
||||
/// * `scope`: The scope to request during the authorization flow
|
||||
/// * `token_endpoint_auth_method`: The token endpoint authentication method
|
||||
/// * `token_endpoint_auth_signing_alg`: The JWT signing algorithm to use
|
||||
/// when then `client_secret_jwt` or `private_key_jwt` authentication
|
||||
/// methods are used
|
||||
/// * `client_id`: The client ID to use when authenticating to the upstream
|
||||
/// * `encrypted_client_secret`: The encrypted client secret to use when
|
||||
/// authenticating to the upstream
|
||||
/// * `claims_imports`: How claims should be imported from the upstream
|
||||
/// provider
|
||||
/// * `params`: The parameters of the provider to update
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Self::Error`] if the underlying repository fails
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn upsert(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
id: Ulid,
|
||||
issuer: String,
|
||||
scope: Scope,
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
client_id: String,
|
||||
encrypted_client_secret: Option<String>,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports,
|
||||
params: UpstreamOAuthProviderParams,
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error>;
|
||||
|
||||
/// List [`UpstreamOAuthProvider`] with the given filter and pagination
|
||||
@@ -198,26 +211,14 @@ repository_impl!(UpstreamOAuthProviderRepository:
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
issuer: String,
|
||||
scope: Scope,
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
client_id: String,
|
||||
encrypted_client_secret: Option<String>,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports
|
||||
params: UpstreamOAuthProviderParams
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error>;
|
||||
|
||||
async fn upsert(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
id: Ulid,
|
||||
issuer: String,
|
||||
scope: Scope,
|
||||
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
client_id: String,
|
||||
encrypted_client_secret: Option<String>,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports,
|
||||
params: UpstreamOAuthProviderParams
|
||||
) -> Result<UpstreamOAuthProvider, Self::Error>;
|
||||
|
||||
async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;
|
||||
|
@@ -633,6 +633,32 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"DiscoveryMode": {
|
||||
"description": "How to discover the provider's configuration",
|
||||
"oneOf": [
|
||||
{
|
||||
"description": "Use OIDC discovery with strict metadata verification",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"oidc"
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Use OIDC discovery with relaxed metadata verification",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"insecure"
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Use a static configuration",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"disabled"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"DisplaynameImportPreference": {
|
||||
"description": "What should be done for the displayname attribute",
|
||||
"type": "object",
|
||||
@@ -1520,6 +1546,32 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"PkceMethod": {
|
||||
"description": "Whether to use proof key for code exchange (PKCE) when requesting and exchanging the token.",
|
||||
"oneOf": [
|
||||
{
|
||||
"description": "Use PKCE if the provider supports it\n\nDefaults to no PKCE if provider discovery is disabled",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"auto"
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Always use PKCE with the S256 challenge method",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"always"
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Never use PKCE",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"never"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"PolicyConfig": {
|
||||
"description": "Application secrets",
|
||||
"type": "object",
|
||||
@@ -1706,6 +1758,11 @@
|
||||
"scope"
|
||||
],
|
||||
"properties": {
|
||||
"authorization_endpoint": {
|
||||
"description": "The URL to use for the provider's authorization endpoint\n\nDefaults to the `authorization_endpoint` provided through discovery",
|
||||
"type": "string",
|
||||
"format": "uri"
|
||||
},
|
||||
"claims_imports": {
|
||||
"description": "How claims should be imported from the `id_token` provided by the provider",
|
||||
"allOf": [
|
||||
@@ -1718,6 +1775,15 @@
|
||||
"description": "The client ID to use when authenticating with the provider",
|
||||
"type": "string"
|
||||
},
|
||||
"discovery_mode": {
|
||||
"description": "How to discover the provider's configuration\n\nDefaults to use OIDC discovery with strict metadata verification",
|
||||
"default": "oidc",
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/DiscoveryMode"
|
||||
}
|
||||
]
|
||||
},
|
||||
"id": {
|
||||
"description": "A ULID as per https://github.com/ulid/spec",
|
||||
"type": "string",
|
||||
@@ -1727,9 +1793,28 @@
|
||||
"description": "The OIDC issuer URL",
|
||||
"type": "string"
|
||||
},
|
||||
"jwks_uri": {
|
||||
"description": "The URL to use for getting the provider's public keys\n\nDefaults to the `jwks_uri` provided through discovery",
|
||||
"type": "string",
|
||||
"format": "uri"
|
||||
},
|
||||
"pkce_method": {
|
||||
"description": "Whether to use proof key for code exchange (PKCE) when requesting and exchanging the token.\n\nDefaults to `auto`, which uses PKCE if the provider supports it.",
|
||||
"default": "auto",
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/PkceMethod"
|
||||
}
|
||||
]
|
||||
},
|
||||
"scope": {
|
||||
"description": "The scopes to request from the provider",
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint": {
|
||||
"description": "The URL to use for the provider's token endpoint\n\nDefaults to the `token_endpoint` provided through discovery",
|
||||
"type": "string",
|
||||
"format": "uri"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
Reference in New Issue
Block a user