1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Allow overriding usptream OAuth2 providers endpoints

Also have a way to disable OIDC discovery when all the endpoints are known.
This commit is contained in:
Quentin Gliech
2023-11-16 11:42:23 +01:00
parent 08d46a79a4
commit 364093f12f
8 changed files with 611 additions and 33 deletions

View File

@ -48,8 +48,9 @@ pub use self::{
upstream_oauth2::{ upstream_oauth2::{
UpsreamOAuthProviderSetEmailVerification, UpstreamOAuthAuthorizationSession, UpsreamOAuthProviderSetEmailVerification, UpstreamOAuthAuthorizationSession,
UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider,
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportAction, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference,
UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderSubjectPreference,
}, },
users::{ users::{
Authentication, AuthenticationMethod, BrowserSession, Password, User, UserEmail, Authentication, AuthenticationMethod, BrowserSession, Password, User, UserEmail,

View File

@ -20,8 +20,10 @@ pub use self::{
link::UpstreamOAuthLink, link::UpstreamOAuthLink,
provider::{ provider::{
ClaimsImports as UpstreamOAuthProviderClaimsImports, ClaimsImports as UpstreamOAuthProviderClaimsImports,
DiscoveryMode as UpstreamOAuthProviderDiscoveryMode,
ImportAction as UpstreamOAuthProviderImportAction, ImportAction as UpstreamOAuthProviderImportAction,
ImportPreference as UpstreamOAuthProviderImportPreference, ImportPreference as UpstreamOAuthProviderImportPreference,
PkceMode as UpstreamOAuthProviderPkceMode,
SetEmailVerification as UpsreamOAuthProviderSetEmailVerification, SetEmailVerification as UpsreamOAuthProviderSetEmailVerification,
SubjectPreference as UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProvider, SubjectPreference as UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProvider,
}, },

View File

@ -17,11 +17,45 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use ulid::Ulid; use ulid::Ulid;
use url::Url;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum DiscoveryMode {
/// Use OIDC discovery to fetch and verify the provider metadata
#[default]
Oidc,
/// Use OIDC discovery to fetch the provider metadata, but don't verify it
Insecure,
/// Don't fetch the provider metadata
Disabled,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PkceMode {
/// Use PKCE if the provider supports it
#[default]
Auto,
/// Always use PKCE with the S256 method
S256,
/// Don't use PKCE
Disabled,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider { pub struct UpstreamOAuthProvider {
pub id: Ulid, pub id: Ulid,
pub issuer: String, pub issuer: String,
pub discovery_mode: DiscoveryMode,
pub pkce_mode: PkceMode,
pub jwks_uri_override: Option<Url>,
pub authorization_endpoint_override: Option<Url>,
pub token_endpoint_override: Option<Url>,
pub scope: Scope, pub scope: Scope,
pub client_id: String, pub client_id: String,
pub encrypted_client_secret: Option<String>, pub encrypted_client_secret: Option<String>,

View File

@ -29,7 +29,7 @@ use mas_storage::{
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use super::UpstreamSessionsCookie; use super::{cache::LazyProviderInfos, UpstreamSessionsCookie};
use crate::{ use crate::{
impl_from_error_for_route, upstream_oauth2::cache::MetadataCache, impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
views::shared::OptionalPostAuthAction, views::shared::OptionalPostAuthAction,
@ -87,23 +87,28 @@ pub(crate) async fn get(
let http_service = http_client_factory.http_service("upstream_oauth2.authorize"); let http_service = http_client_factory.http_service("upstream_oauth2.authorize");
// First, discover the provider // First, discover the provider
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?; // This is done lazyly according to provider.discovery_mode and the various
// endpoint overrides
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_service);
lazy_metadata.maybe_discover().await?;
let redirect_uri = url_builder.upstream_oauth_callback(provider.id); let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
let mut data = AuthorizationRequestData::new( let data = AuthorizationRequestData::new(
provider.client_id.clone(), provider.client_id.clone(),
provider.scope.clone(), provider.scope.clone(),
redirect_uri, redirect_uri,
); );
if let Some(methods) = metadata.code_challenge_methods_supported.clone() { let data = if let Some(methods) = lazy_metadata.pkce_methods().await? {
data = data.with_code_challenge_methods_supported(methods); data.with_code_challenge_methods_supported(methods)
} } else {
data
};
// Build an authorization request for it // Build an authorization request for it
let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url( let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url(
metadata.authorization_endpoint().clone(), lazy_metadata.authorization_endpoint().await?.clone(),
data, data,
&mut rng, &mut rng,
)?; )?;

View File

@ -14,11 +14,128 @@
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use mas_data_model::{
UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
};
use mas_http::HttpService; use mas_http::HttpService;
use mas_iana::oauth::PkceCodeChallengeMethod;
use mas_oidc_client::error::DiscoveryError; use mas_oidc_client::error::DiscoveryError;
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess}; use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess};
use oauth2_types::oidc::VerifiedProviderMetadata; use oauth2_types::oidc::VerifiedProviderMetadata;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use url::Url;
/// A high-level layer over metadata cache and provider configuration, which
/// resolves endpoint overrides and discovery modes.
pub struct LazyProviderInfos<'a> {
cache: &'a MetadataCache,
provider: &'a UpstreamOAuthProvider,
http_service: &'a HttpService,
loaded_metadata: Option<Arc<VerifiedProviderMetadata>>,
}
impl<'a> LazyProviderInfos<'a> {
pub fn new(
cache: &'a MetadataCache,
provider: &'a UpstreamOAuthProvider,
http_service: &'a HttpService,
) -> Self {
Self {
cache,
provider,
http_service,
loaded_metadata: None,
}
}
/// Trigger the discovery process and return the metadata if discovery is
/// enabled.
pub async fn maybe_discover<'b>(
&'b mut self,
) -> Result<Option<&'b VerifiedProviderMetadata>, DiscoveryError> {
match self.load().await {
Ok(metadata) => Ok(Some(metadata)),
Err(DiscoveryError::Disabled) => Ok(None),
Err(e) => Err(e),
}
}
async fn load<'b>(&'b mut self) -> Result<&'b VerifiedProviderMetadata, DiscoveryError> {
if self.loaded_metadata.is_none() {
let verify = match self.provider.discovery_mode {
UpstreamOAuthProviderDiscoveryMode::Oidc => true,
UpstreamOAuthProviderDiscoveryMode::Insecure => false,
UpstreamOAuthProviderDiscoveryMode::Disabled => {
return Err(DiscoveryError::Disabled)
}
};
let metadata = self
.cache
.get(self.http_service, &self.provider.issuer, verify)
.await?;
self.loaded_metadata = Some(metadata);
}
Ok(self.loaded_metadata.as_ref().unwrap())
}
/// Get the JWKS URI for the provider.
///
/// Uses [`UpstreamOAuthProvider.jwks_uri_override`] if set, otherwise uses
/// the one from discovery.
pub async fn jwks_uri(&mut self) -> Result<&Url, DiscoveryError> {
if let Some(jwks_uri) = &self.provider.jwks_uri_override {
return Ok(jwks_uri);
}
Ok(self.load().await?.jwks_uri())
}
/// Get the authorization endpoint for the provider.
///
/// Uses [`UpstreamOAuthProvider.authorization_endpoint_override`] if set,
/// otherwise uses the one from discovery.
pub async fn authorization_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
if let Some(authorization_endpoint) = &self.provider.authorization_endpoint_override {
return Ok(authorization_endpoint);
}
Ok(self.load().await?.authorization_endpoint())
}
/// Get the token endpoint for the provider.
///
/// Uses [`UpstreamOAuthProvider.token_endpoint_override`] if set, otherwise
/// uses the one from discovery.
pub async fn token_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
if let Some(token_endpoint) = &self.provider.token_endpoint_override {
return Ok(token_endpoint);
}
Ok(self.load().await?.token_endpoint())
}
/// Get the PKCE methods supported by the provider.
///
/// If the mode is set to auto, it will use the ones from discovery,
/// defaulting to none if discovery is disabled.
pub async fn pkce_methods(
&mut self,
) -> Result<Option<Vec<PkceCodeChallengeMethod>>, DiscoveryError> {
let methods = match self.provider.pkce_mode {
UpstreamOAuthProviderPkceMode::Auto => self
.maybe_discover()
.await?
.and_then(|metadata| metadata.code_challenge_methods_supported.clone()),
UpstreamOAuthProviderPkceMode::S256 => Some(vec![PkceCodeChallengeMethod::S256]),
UpstreamOAuthProviderPkceMode::Disabled => None,
};
Ok(methods)
}
}
/// A simple OIDC metadata cache /// A simple OIDC metadata cache
/// ///
@ -28,7 +145,8 @@ use tokio::sync::RwLock;
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct MetadataCache { pub struct MetadataCache {
cache: Arc<RwLock<HashMap<String, VerifiedProviderMetadata>>>, cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
insecure_cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
} }
impl MetadataCache { impl MetadataCache {
@ -52,7 +170,13 @@ impl MetadataCache {
let providers = repository.upstream_oauth_provider().all().await?; let providers = repository.upstream_oauth_provider().all().await?;
for provider in providers { for provider in providers {
if let Err(e) = self.fetch(&http_service, &provider.issuer).await { let verify = match provider.discovery_mode {
UpstreamOAuthProviderDiscoveryMode::Oidc => true,
UpstreamOAuthProviderDiscoveryMode::Insecure => false,
UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
};
if let Err(e) = self.fetch(&http_service, &provider.issuer, verify).await {
tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata"); tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
} }
} }
@ -73,15 +197,32 @@ impl MetadataCache {
&self, &self,
http_service: &HttpService, http_service: &HttpService,
issuer: &str, issuer: &str,
) -> Result<VerifiedProviderMetadata, DiscoveryError> { verify: bool,
let metadata = mas_oidc_client::requests::discovery::discover(http_service, issuer).await?; ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
if verify {
let metadata =
mas_oidc_client::requests::discovery::discover(http_service, issuer).await?;
let metadata = Arc::new(metadata);
self.cache self.cache
.write() .write()
.await .await
.insert(issuer.to_owned(), metadata.clone()); .insert(issuer.to_owned(), metadata.clone());
Ok(metadata) Ok(metadata)
} else {
let metadata =
mas_oidc_client::requests::discovery::insecure_discover(http_service, issuer)
.await?;
let metadata = Arc::new(metadata);
self.insecure_cache
.write()
.await
.insert(issuer.to_owned(), metadata.clone());
Ok(metadata)
}
} }
/// Get the metadata for the given issuer. /// Get the metadata for the given issuer.
@ -90,13 +231,21 @@ impl MetadataCache {
&self, &self,
http_service: &HttpService, http_service: &HttpService,
issuer: &str, issuer: &str,
) -> Result<VerifiedProviderMetadata, DiscoveryError> { verify: bool,
let cache = self.cache.read().await; ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
if let Some(metadata) = cache.get(issuer) { let cache = if verify {
return Ok(metadata.clone()); self.cache.read().await
} } else {
self.insecure_cache.read().await
};
let metadata = self.fetch(http_service, issuer).await?; if let Some(metadata) = cache.get(issuer) {
return Ok(Arc::clone(metadata));
}
// Drop the cache guard so that we don't deadlock when we try to fetch
drop(cache);
let metadata = self.fetch(http_service, issuer, verify).await?;
Ok(metadata) Ok(metadata)
} }
@ -109,9 +258,369 @@ impl MetadataCache {
}; };
for issuer in keys { for issuer in keys {
if let Err(e) = self.fetch(http_service, &issuer).await { if let Err(e) = self.fetch(http_service, &issuer, true).await {
tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
}
}
// Do the same for the insecure cache
let keys: Vec<String> = {
let cache = self.insecure_cache.read().await;
cache.keys().cloned().collect()
};
for issuer in keys {
if let Err(e) = self.fetch(http_service, &issuer, false).await {
tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata"); tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
} }
} }
} }
} }
#[cfg(test)]
mod tests {
#![allow(clippy::too_many_lines)]
use std::sync::atomic::{AtomicUsize, Ordering};
use hyper::{body::Bytes, Request, Response, StatusCode};
use mas_data_model::UpstreamOAuthProviderClaimsImports;
use mas_http::BoxCloneSyncService;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_storage::{clock::MockClock, Clock};
use oauth2_types::scope::{Scope, OPENID};
use tower::BoxError;
use ulid::Ulid;
use crate::test_utils::init_tracing;
use super::*;
#[tokio::test]
async fn test_metadata_cache() {
init_tracing();
let calls = Arc::new(AtomicUsize::new(0));
let closure_calls = Arc::clone(&calls);
let handler = move |req: Request<Bytes>| {
let calls = Arc::clone(&closure_calls);
async move {
calls.fetch_add(1, Ordering::SeqCst);
let body = match req.uri().authority().unwrap().as_str() {
"valid.example.com" => Bytes::from_static(
br#"{
"issuer": "https://valid.example.com/",
"authorization_endpoint": "https://valid.example.com/authorize",
"token_endpoint": "https://valid.example.com/token",
"jwks_uri": "https://valid.example.com/jwks",
"response_types_supported": [
"code"
],
"grant_types_supported": [
"authorization_code"
],
"subject_types_supported": [
"public"
],
"id_token_signing_alg_values_supported": [
"RS256"
],
"scopes_supported": [
"openid",
"profile",
"email"
]
}"#,
),
"insecure.example.com" => Bytes::from_static(
br#"{
"issuer": "http://insecure.example.com/",
"authorization_endpoint": "http://insecure.example.com/authorize",
"token_endpoint": "http://insecure.example.com/token",
"jwks_uri": "http://insecure.example.com/jwks",
"response_types_supported": [
"code"
],
"grant_types_supported": [
"authorization_code"
],
"subject_types_supported": [
"public"
],
"id_token_signing_alg_values_supported": [
"RS256"
],
"scopes_supported": [
"openid",
"profile",
"email"
]
}"#,
),
_ => Bytes::default(),
};
let mut response = Response::new(body);
*response.status_mut() = StatusCode::OK;
Ok::<_, BoxError>(response)
}
};
let service = BoxCloneSyncService::new(tower::service_fn(handler));
let cache = MetadataCache::new();
// An inexistant issuer should fail
cache
.get(&service, "https://inexistant.example.com/", true)
.await
.unwrap_err();
assert_eq!(calls.load(Ordering::SeqCst), 1);
// A valid issuer should succeed
cache
.get(&service, "https://valid.example.com/", true)
.await
.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 2);
// Calling again should not trigger a new fetch
cache
.get(&service, "https://valid.example.com/", true)
.await
.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 2);
// An insecure issuer should work with insecure discovery
cache
.get(&service, "http://insecure.example.com/", false)
.await
.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 3);
// Doing it again shpoild not trigger a new fetch
cache
.get(&service, "http://insecure.example.com/", false)
.await
.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 3);
// But it should fail with secure discovery
// Note that it still fetched because secure and insecure caches are distinct
cache
.get(&service, "http://insecure.example.com/", true)
.await
.unwrap_err();
assert_eq!(calls.load(Ordering::SeqCst), 4);
// Calling refresh should refresh all the known valid issuers
cache.refresh_all(&service).await;
assert_eq!(calls.load(Ordering::SeqCst), 6);
}
#[tokio::test]
async fn test_lazy_provider_infos() {
init_tracing();
let calls = Arc::new(AtomicUsize::new(0));
let closure_calls = Arc::clone(&calls);
let handler = move |req: Request<Bytes>| {
let calls = Arc::clone(&closure_calls);
async move {
calls.fetch_add(1, Ordering::SeqCst);
let body = match req.uri().authority().unwrap().as_str() {
"valid.example.com" => Bytes::from_static(
br#"{
"issuer": "https://valid.example.com/",
"authorization_endpoint": "https://valid.example.com/authorize",
"token_endpoint": "https://valid.example.com/token",
"jwks_uri": "https://valid.example.com/jwks",
"response_types_supported": [
"code"
],
"grant_types_supported": [
"authorization_code"
],
"subject_types_supported": [
"public"
],
"id_token_signing_alg_values_supported": [
"RS256"
],
"scopes_supported": [
"openid",
"profile",
"email"
]
}"#,
),
"insecure.example.com" => Bytes::from_static(
br#"{
"issuer": "http://insecure.example.com/",
"authorization_endpoint": "http://insecure.example.com/authorize",
"token_endpoint": "http://insecure.example.com/token",
"jwks_uri": "http://insecure.example.com/jwks",
"response_types_supported": [
"code"
],
"grant_types_supported": [
"authorization_code"
],
"subject_types_supported": [
"public"
],
"id_token_signing_alg_values_supported": [
"RS256"
],
"scopes_supported": [
"openid",
"profile",
"email"
]
}"#,
),
_ => Bytes::default(),
};
let mut response = Response::new(body);
*response.status_mut() = StatusCode::OK;
Ok::<_, BoxError>(response)
}
};
let clock = MockClock::default();
let service = BoxCloneSyncService::new(tower::service_fn(handler));
let provider = UpstreamOAuthProvider {
id: Ulid::nil(),
issuer: "https://valid.example.com/".to_owned(),
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
jwks_uri_override: None,
authorization_endpoint_override: None,
token_endpoint_override: None,
scope: Scope::from_iter([OPENID]),
client_id: "client_id".to_owned(),
encrypted_client_secret: None,
token_endpoint_signing_alg: None,
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
created_at: clock.now(),
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
};
// Without any override, it should just use discovery
{
let cache = MetadataCache::new();
let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
assert_eq!(calls.load(Ordering::SeqCst), 0);
lazy_metadata.maybe_discover().await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 1);
assert_eq!(
lazy_metadata
.authorization_endpoint()
.await
.unwrap()
.as_str(),
"https://valid.example.com/authorize"
);
}
// Test overriding endpoints
{
let provider = UpstreamOAuthProvider {
jwks_uri_override: Some("https://valid.example.com/jwks_override".parse().unwrap()),
authorization_endpoint_override: Some(
"https://valid.example.com/authorize_override"
.parse()
.unwrap(),
),
token_endpoint_override: Some(
"https://valid.example.com/token_override".parse().unwrap(),
),
..provider.clone()
};
let cache = MetadataCache::new();
let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
assert_eq!(
lazy_metadata.jwks_uri().await.unwrap().as_str(),
"https://valid.example.com/jwks_override"
);
assert_eq!(
lazy_metadata
.authorization_endpoint()
.await
.unwrap()
.as_str(),
"https://valid.example.com/authorize_override"
);
assert_eq!(
lazy_metadata.token_endpoint().await.unwrap().as_str(),
"https://valid.example.com/token_override"
);
// This shouldn't trigger a new fetch as the endpoint is overriden
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
// Insecure providers don't work with secure discovery
{
let provider = UpstreamOAuthProvider {
issuer: "http://insecure.example.com/".to_owned(),
..provider.clone()
};
let cache = MetadataCache::new();
let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
lazy_metadata.authorization_endpoint().await.unwrap_err();
// This triggered a fetch, even though it failed
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
// Insecure providers work with insecure discovery
{
let provider = UpstreamOAuthProvider {
issuer: "http://insecure.example.com/".to_owned(),
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
..provider.clone()
};
let cache = MetadataCache::new();
let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
assert_eq!(
lazy_metadata
.authorization_endpoint()
.await
.unwrap()
.as_str(),
"http://insecure.example.com/authorize"
);
// This triggered a fetch
assert_eq!(calls.load(Ordering::SeqCst), 3);
}
// Getting endpoints when discovery is disabled only works for overriden ones
{
let provider = UpstreamOAuthProvider {
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Disabled,
authorization_endpoint_override: Some(
Url::parse("https://valid.example.com/authorize_override").unwrap(),
),
token_endpoint_override: None,
..provider.clone()
};
let cache = MetadataCache::new();
let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service);
// This should not fail, but also does nothing
assert!(lazy_metadata.maybe_discover().await.unwrap().is_none());
assert_eq!(
lazy_metadata
.authorization_endpoint()
.await
.unwrap()
.as_str(),
"https://valid.example.com/authorize_override"
);
assert!(matches!(
lazy_metadata.token_endpoint().await,
Err(DiscoveryError::Disabled),
));
// This did not trigger a fetch
assert_eq!(calls.load(Ordering::SeqCst), 3);
}
}
}

View File

@ -37,7 +37,7 @@ use serde::Deserialize;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use super::{client_credentials_for_provider, UpstreamSessionsCookie}; use super::{cache::LazyProviderInfos, client_credentials_for_provider, UpstreamSessionsCookie};
use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache}; use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};
#[derive(Deserialize)] #[derive(Deserialize)]
@ -191,18 +191,17 @@ pub(crate) async fn get(
}; };
let http_service = http_client_factory.http_service("upstream_oauth2.callback"); let http_service = http_client_factory.http_service("upstream_oauth2.callback");
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_service);
// Discover the provider
let metadata = metadata_cache.get(&http_service, &provider.issuer).await?;
// Fetch the JWKS // Fetch the JWKS
let jwks = let jwks =
mas_oidc_client::requests::jose::fetch_jwks(&http_service, metadata.jwks_uri()).await?; mas_oidc_client::requests::jose::fetch_jwks(&http_service, lazy_metadata.jwks_uri().await?)
.await?;
// Figure out the client credentials // Figure out the client credentials
let client_credentials = client_credentials_for_provider( let client_credentials = client_credentials_for_provider(
&provider, &provider,
metadata.token_endpoint(), lazy_metadata.token_endpoint().await?,
&keystore, &keystore,
&encrypter, &encrypter,
)?; )?;
@ -229,7 +228,7 @@ pub(crate) async fn get(
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code( mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
&http_service, &http_service,
client_credentials, client_credentials,
metadata.token_endpoint(), lazy_metadata.token_endpoint().await?,
code, code,
validation_data, validation_data,
Some(id_token_verification_data), Some(id_token_verification_data),

View File

@ -93,6 +93,10 @@ pub enum DiscoveryError {
/// An error occurred sending the request. /// An error occurred sending the request.
#[error(transparent)] #[error(transparent)]
Service(BoxError), Service(BoxError),
/// Discovery is disabled for this provider.
#[error("Discovery is disabled for this provider")]
Disabled,
} }
impl<S> From<json_response::Error<S>> for DiscoveryError impl<S> From<json_response::Error<S>> for DiscoveryError

View File

@ -14,7 +14,10 @@
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports}; use mas_data_model::{
UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
UpstreamOAuthProviderPkceMode,
};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_storage::{ use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository}, upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository},
@ -99,6 +102,13 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
token_endpoint_signing_alg, token_endpoint_signing_alg,
created_at: value.created_at, created_at: value.created_at,
claims_imports: value.claims_imports.0, claims_imports: value.claims_imports.0,
// TODO
authorization_endpoint_override: None,
token_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
}) })
} }
} }
@ -213,6 +223,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_auth_method, token_endpoint_auth_method,
created_at, created_at,
claims_imports, claims_imports,
// TODO
authorization_endpoint_override: None,
token_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
}) })
} }
@ -357,6 +374,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
token_endpoint_auth_method, token_endpoint_auth_method,
created_at, created_at,
claims_imports, claims_imports,
// TODO
authorization_endpoint_override: None,
token_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(),
pkce_mode: UpstreamOAuthProviderPkceMode::default(),
}) })
} }