diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 325d1eed..81e918bf 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -48,8 +48,9 @@ pub use self::{ upstream_oauth2::{ UpsreamOAuthProviderSetEmailVerification, UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider, - UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportAction, - UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderSubjectPreference, + UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode, + UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference, + UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderSubjectPreference, }, users::{ Authentication, AuthenticationMethod, BrowserSession, Password, User, UserEmail, diff --git a/crates/data-model/src/upstream_oauth2/mod.rs b/crates/data-model/src/upstream_oauth2/mod.rs index 1e5d9f9b..338f099a 100644 --- a/crates/data-model/src/upstream_oauth2/mod.rs +++ b/crates/data-model/src/upstream_oauth2/mod.rs @@ -20,8 +20,10 @@ pub use self::{ link::UpstreamOAuthLink, provider::{ ClaimsImports as UpstreamOAuthProviderClaimsImports, + DiscoveryMode as UpstreamOAuthProviderDiscoveryMode, ImportAction as UpstreamOAuthProviderImportAction, ImportPreference as UpstreamOAuthProviderImportPreference, + PkceMode as UpstreamOAuthProviderPkceMode, SetEmailVerification as UpsreamOAuthProviderSetEmailVerification, SubjectPreference as UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProvider, }, diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index 3c905696..79827953 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -17,11 +17,45 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod use oauth2_types::scope::Scope; use serde::{Deserialize, Serialize}; use ulid::Ulid; +use url::Url; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum DiscoveryMode { + /// Use OIDC discovery to fetch and verify the provider metadata + #[default] + Oidc, + + /// Use OIDC discovery to fetch the provider metadata, but don't verify it + Insecure, + + /// Don't fetch the provider metadata + Disabled, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum PkceMode { + /// Use PKCE if the provider supports it + #[default] + Auto, + + /// Always use PKCE with the S256 method + S256, + + /// Don't use PKCE + Disabled, +} #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct UpstreamOAuthProvider { pub id: Ulid, pub issuer: String, + pub discovery_mode: DiscoveryMode, + pub pkce_mode: PkceMode, + pub jwks_uri_override: Option, + pub authorization_endpoint_override: Option, + pub token_endpoint_override: Option, pub scope: Scope, pub client_id: String, pub encrypted_client_secret: Option, diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 580afcfe..54937e03 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -29,7 +29,7 @@ use mas_storage::{ use thiserror::Error; use ulid::Ulid; -use super::UpstreamSessionsCookie; +use super::{cache::LazyProviderInfos, UpstreamSessionsCookie}; use crate::{ impl_from_error_for_route, upstream_oauth2::cache::MetadataCache, views::shared::OptionalPostAuthAction, @@ -87,23 +87,28 @@ pub(crate) async fn get( let http_service = http_client_factory.http_service("upstream_oauth2.authorize"); // First, discover the provider - let metadata = metadata_cache.get(&http_service, &provider.issuer).await?; + // This is done lazyly according to provider.discovery_mode and the various + // endpoint overrides + let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_service); + lazy_metadata.maybe_discover().await?; let redirect_uri = url_builder.upstream_oauth_callback(provider.id); - let mut data = AuthorizationRequestData::new( + let data = AuthorizationRequestData::new( provider.client_id.clone(), provider.scope.clone(), redirect_uri, ); - if let Some(methods) = metadata.code_challenge_methods_supported.clone() { - data = data.with_code_challenge_methods_supported(methods); - } + let data = if let Some(methods) = lazy_metadata.pkce_methods().await? { + data.with_code_challenge_methods_supported(methods) + } else { + data + }; // Build an authorization request for it let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url( - metadata.authorization_endpoint().clone(), + lazy_metadata.authorization_endpoint().await?.clone(), data, &mut rng, )?; diff --git a/crates/handlers/src/upstream_oauth2/cache.rs b/crates/handlers/src/upstream_oauth2/cache.rs index cdb40fa8..be14ead2 100644 --- a/crates/handlers/src/upstream_oauth2/cache.rs +++ b/crates/handlers/src/upstream_oauth2/cache.rs @@ -14,11 +14,128 @@ use std::{collections::HashMap, sync::Arc}; +use mas_data_model::{ + UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode, +}; use mas_http::HttpService; +use mas_iana::oauth::PkceCodeChallengeMethod; use mas_oidc_client::error::DiscoveryError; use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess}; use oauth2_types::oidc::VerifiedProviderMetadata; use tokio::sync::RwLock; +use url::Url; + +/// A high-level layer over metadata cache and provider configuration, which +/// resolves endpoint overrides and discovery modes. +pub struct LazyProviderInfos<'a> { + cache: &'a MetadataCache, + provider: &'a UpstreamOAuthProvider, + http_service: &'a HttpService, + loaded_metadata: Option>, +} + +impl<'a> LazyProviderInfos<'a> { + pub fn new( + cache: &'a MetadataCache, + provider: &'a UpstreamOAuthProvider, + http_service: &'a HttpService, + ) -> Self { + Self { + cache, + provider, + http_service, + loaded_metadata: None, + } + } + + /// Trigger the discovery process and return the metadata if discovery is + /// enabled. + pub async fn maybe_discover<'b>( + &'b mut self, + ) -> Result, DiscoveryError> { + match self.load().await { + Ok(metadata) => Ok(Some(metadata)), + Err(DiscoveryError::Disabled) => Ok(None), + Err(e) => Err(e), + } + } + + async fn load<'b>(&'b mut self) -> Result<&'b VerifiedProviderMetadata, DiscoveryError> { + if self.loaded_metadata.is_none() { + let verify = match self.provider.discovery_mode { + UpstreamOAuthProviderDiscoveryMode::Oidc => true, + UpstreamOAuthProviderDiscoveryMode::Insecure => false, + UpstreamOAuthProviderDiscoveryMode::Disabled => { + return Err(DiscoveryError::Disabled) + } + }; + + let metadata = self + .cache + .get(self.http_service, &self.provider.issuer, verify) + .await?; + + self.loaded_metadata = Some(metadata); + } + + Ok(self.loaded_metadata.as_ref().unwrap()) + } + + /// Get the JWKS URI for the provider. + /// + /// Uses [`UpstreamOAuthProvider.jwks_uri_override`] if set, otherwise uses + /// the one from discovery. + pub async fn jwks_uri(&mut self) -> Result<&Url, DiscoveryError> { + if let Some(jwks_uri) = &self.provider.jwks_uri_override { + return Ok(jwks_uri); + } + + Ok(self.load().await?.jwks_uri()) + } + + /// Get the authorization endpoint for the provider. + /// + /// Uses [`UpstreamOAuthProvider.authorization_endpoint_override`] if set, + /// otherwise uses the one from discovery. + pub async fn authorization_endpoint(&mut self) -> Result<&Url, DiscoveryError> { + if let Some(authorization_endpoint) = &self.provider.authorization_endpoint_override { + return Ok(authorization_endpoint); + } + + Ok(self.load().await?.authorization_endpoint()) + } + + /// Get the token endpoint for the provider. + /// + /// Uses [`UpstreamOAuthProvider.token_endpoint_override`] if set, otherwise + /// uses the one from discovery. + pub async fn token_endpoint(&mut self) -> Result<&Url, DiscoveryError> { + if let Some(token_endpoint) = &self.provider.token_endpoint_override { + return Ok(token_endpoint); + } + + Ok(self.load().await?.token_endpoint()) + } + + /// Get the PKCE methods supported by the provider. + /// + /// If the mode is set to auto, it will use the ones from discovery, + /// defaulting to none if discovery is disabled. + pub async fn pkce_methods( + &mut self, + ) -> Result>, DiscoveryError> { + let methods = match self.provider.pkce_mode { + UpstreamOAuthProviderPkceMode::Auto => self + .maybe_discover() + .await? + .and_then(|metadata| metadata.code_challenge_methods_supported.clone()), + UpstreamOAuthProviderPkceMode::S256 => Some(vec![PkceCodeChallengeMethod::S256]), + UpstreamOAuthProviderPkceMode::Disabled => None, + }; + + Ok(methods) + } +} /// A simple OIDC metadata cache /// @@ -28,7 +145,8 @@ use tokio::sync::RwLock; #[allow(clippy::module_name_repetitions)] #[derive(Debug, Clone, Default)] pub struct MetadataCache { - cache: Arc>>, + cache: Arc>>>, + insecure_cache: Arc>>>, } impl MetadataCache { @@ -52,7 +170,13 @@ impl MetadataCache { let providers = repository.upstream_oauth_provider().all().await?; for provider in providers { - if let Err(e) = self.fetch(&http_service, &provider.issuer).await { + let verify = match provider.discovery_mode { + UpstreamOAuthProviderDiscoveryMode::Oidc => true, + UpstreamOAuthProviderDiscoveryMode::Insecure => false, + UpstreamOAuthProviderDiscoveryMode::Disabled => continue, + }; + + if let Err(e) = self.fetch(&http_service, &provider.issuer, verify).await { tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata"); } } @@ -73,15 +197,32 @@ impl MetadataCache { &self, http_service: &HttpService, issuer: &str, - ) -> Result { - let metadata = mas_oidc_client::requests::discovery::discover(http_service, issuer).await?; + verify: bool, + ) -> Result, DiscoveryError> { + if verify { + let metadata = + mas_oidc_client::requests::discovery::discover(http_service, issuer).await?; + let metadata = Arc::new(metadata); - self.cache - .write() - .await - .insert(issuer.to_owned(), metadata.clone()); + self.cache + .write() + .await + .insert(issuer.to_owned(), metadata.clone()); - Ok(metadata) + Ok(metadata) + } else { + let metadata = + mas_oidc_client::requests::discovery::insecure_discover(http_service, issuer) + .await?; + let metadata = Arc::new(metadata); + + self.insecure_cache + .write() + .await + .insert(issuer.to_owned(), metadata.clone()); + + Ok(metadata) + } } /// Get the metadata for the given issuer. @@ -90,13 +231,21 @@ impl MetadataCache { &self, http_service: &HttpService, issuer: &str, - ) -> Result { - let cache = self.cache.read().await; - if let Some(metadata) = cache.get(issuer) { - return Ok(metadata.clone()); - } + verify: bool, + ) -> Result, DiscoveryError> { + let cache = if verify { + self.cache.read().await + } else { + self.insecure_cache.read().await + }; - let metadata = self.fetch(http_service, issuer).await?; + if let Some(metadata) = cache.get(issuer) { + return Ok(Arc::clone(metadata)); + } + // Drop the cache guard so that we don't deadlock when we try to fetch + drop(cache); + + let metadata = self.fetch(http_service, issuer, verify).await?; Ok(metadata) } @@ -109,9 +258,369 @@ impl MetadataCache { }; for issuer in keys { - if let Err(e) = self.fetch(http_service, &issuer).await { + if let Err(e) = self.fetch(http_service, &issuer, true).await { + tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata"); + } + } + + // Do the same for the insecure cache + let keys: Vec = { + let cache = self.insecure_cache.read().await; + cache.keys().cloned().collect() + }; + + for issuer in keys { + if let Err(e) = self.fetch(http_service, &issuer, false).await { tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata"); } } } } + +#[cfg(test)] +mod tests { + #![allow(clippy::too_many_lines)] + + use std::sync::atomic::{AtomicUsize, Ordering}; + + use hyper::{body::Bytes, Request, Response, StatusCode}; + use mas_data_model::UpstreamOAuthProviderClaimsImports; + use mas_http::BoxCloneSyncService; + use mas_iana::oauth::OAuthClientAuthenticationMethod; + use mas_storage::{clock::MockClock, Clock}; + use oauth2_types::scope::{Scope, OPENID}; + use tower::BoxError; + use ulid::Ulid; + + use crate::test_utils::init_tracing; + + use super::*; + + #[tokio::test] + async fn test_metadata_cache() { + init_tracing(); + let calls = Arc::new(AtomicUsize::new(0)); + let closure_calls = Arc::clone(&calls); + let handler = move |req: Request| { + let calls = Arc::clone(&closure_calls); + async move { + calls.fetch_add(1, Ordering::SeqCst); + + let body = match req.uri().authority().unwrap().as_str() { + "valid.example.com" => Bytes::from_static( + br#"{ + "issuer": "https://valid.example.com/", + "authorization_endpoint": "https://valid.example.com/authorize", + "token_endpoint": "https://valid.example.com/token", + "jwks_uri": "https://valid.example.com/jwks", + "response_types_supported": [ + "code" + ], + "grant_types_supported": [ + "authorization_code" + ], + "subject_types_supported": [ + "public" + ], + "id_token_signing_alg_values_supported": [ + "RS256" + ], + "scopes_supported": [ + "openid", + "profile", + "email" + ] + }"#, + ), + "insecure.example.com" => Bytes::from_static( + br#"{ + "issuer": "http://insecure.example.com/", + "authorization_endpoint": "http://insecure.example.com/authorize", + "token_endpoint": "http://insecure.example.com/token", + "jwks_uri": "http://insecure.example.com/jwks", + "response_types_supported": [ + "code" + ], + "grant_types_supported": [ + "authorization_code" + ], + "subject_types_supported": [ + "public" + ], + "id_token_signing_alg_values_supported": [ + "RS256" + ], + "scopes_supported": [ + "openid", + "profile", + "email" + ] + }"#, + ), + _ => Bytes::default(), + }; + + let mut response = Response::new(body); + *response.status_mut() = StatusCode::OK; + Ok::<_, BoxError>(response) + } + }; + + let service = BoxCloneSyncService::new(tower::service_fn(handler)); + let cache = MetadataCache::new(); + + // An inexistant issuer should fail + cache + .get(&service, "https://inexistant.example.com/", true) + .await + .unwrap_err(); + assert_eq!(calls.load(Ordering::SeqCst), 1); + + // A valid issuer should succeed + cache + .get(&service, "https://valid.example.com/", true) + .await + .unwrap(); + assert_eq!(calls.load(Ordering::SeqCst), 2); + + // Calling again should not trigger a new fetch + cache + .get(&service, "https://valid.example.com/", true) + .await + .unwrap(); + assert_eq!(calls.load(Ordering::SeqCst), 2); + + // An insecure issuer should work with insecure discovery + cache + .get(&service, "http://insecure.example.com/", false) + .await + .unwrap(); + assert_eq!(calls.load(Ordering::SeqCst), 3); + + // Doing it again shpoild not trigger a new fetch + cache + .get(&service, "http://insecure.example.com/", false) + .await + .unwrap(); + assert_eq!(calls.load(Ordering::SeqCst), 3); + + // But it should fail with secure discovery + // Note that it still fetched because secure and insecure caches are distinct + cache + .get(&service, "http://insecure.example.com/", true) + .await + .unwrap_err(); + assert_eq!(calls.load(Ordering::SeqCst), 4); + + // Calling refresh should refresh all the known valid issuers + cache.refresh_all(&service).await; + assert_eq!(calls.load(Ordering::SeqCst), 6); + } + + #[tokio::test] + async fn test_lazy_provider_infos() { + init_tracing(); + let calls = Arc::new(AtomicUsize::new(0)); + let closure_calls = Arc::clone(&calls); + let handler = move |req: Request| { + let calls = Arc::clone(&closure_calls); + async move { + calls.fetch_add(1, Ordering::SeqCst); + + let body = match req.uri().authority().unwrap().as_str() { + "valid.example.com" => Bytes::from_static( + br#"{ + "issuer": "https://valid.example.com/", + "authorization_endpoint": "https://valid.example.com/authorize", + "token_endpoint": "https://valid.example.com/token", + "jwks_uri": "https://valid.example.com/jwks", + "response_types_supported": [ + "code" + ], + "grant_types_supported": [ + "authorization_code" + ], + "subject_types_supported": [ + "public" + ], + "id_token_signing_alg_values_supported": [ + "RS256" + ], + "scopes_supported": [ + "openid", + "profile", + "email" + ] + }"#, + ), + "insecure.example.com" => Bytes::from_static( + br#"{ + "issuer": "http://insecure.example.com/", + "authorization_endpoint": "http://insecure.example.com/authorize", + "token_endpoint": "http://insecure.example.com/token", + "jwks_uri": "http://insecure.example.com/jwks", + "response_types_supported": [ + "code" + ], + "grant_types_supported": [ + "authorization_code" + ], + "subject_types_supported": [ + "public" + ], + "id_token_signing_alg_values_supported": [ + "RS256" + ], + "scopes_supported": [ + "openid", + "profile", + "email" + ] + }"#, + ), + _ => Bytes::default(), + }; + + let mut response = Response::new(body); + *response.status_mut() = StatusCode::OK; + Ok::<_, BoxError>(response) + } + }; + + let clock = MockClock::default(); + let service = BoxCloneSyncService::new(tower::service_fn(handler)); + let provider = UpstreamOAuthProvider { + id: Ulid::nil(), + issuer: "https://valid.example.com/".to_owned(), + discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc, + pkce_mode: UpstreamOAuthProviderPkceMode::Auto, + jwks_uri_override: None, + authorization_endpoint_override: None, + token_endpoint_override: None, + scope: Scope::from_iter([OPENID]), + client_id: "client_id".to_owned(), + encrypted_client_secret: None, + token_endpoint_signing_alg: None, + token_endpoint_auth_method: OAuthClientAuthenticationMethod::None, + created_at: clock.now(), + claims_imports: UpstreamOAuthProviderClaimsImports::default(), + }; + + // Without any override, it should just use discovery + { + let cache = MetadataCache::new(); + let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service); + assert_eq!(calls.load(Ordering::SeqCst), 0); + lazy_metadata.maybe_discover().await.unwrap(); + assert_eq!(calls.load(Ordering::SeqCst), 1); + assert_eq!( + lazy_metadata + .authorization_endpoint() + .await + .unwrap() + .as_str(), + "https://valid.example.com/authorize" + ); + } + + // Test overriding endpoints + { + let provider = UpstreamOAuthProvider { + jwks_uri_override: Some("https://valid.example.com/jwks_override".parse().unwrap()), + authorization_endpoint_override: Some( + "https://valid.example.com/authorize_override" + .parse() + .unwrap(), + ), + token_endpoint_override: Some( + "https://valid.example.com/token_override".parse().unwrap(), + ), + ..provider.clone() + }; + let cache = MetadataCache::new(); + let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service); + assert_eq!( + lazy_metadata.jwks_uri().await.unwrap().as_str(), + "https://valid.example.com/jwks_override" + ); + assert_eq!( + lazy_metadata + .authorization_endpoint() + .await + .unwrap() + .as_str(), + "https://valid.example.com/authorize_override" + ); + assert_eq!( + lazy_metadata.token_endpoint().await.unwrap().as_str(), + "https://valid.example.com/token_override" + ); + // This shouldn't trigger a new fetch as the endpoint is overriden + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + // Insecure providers don't work with secure discovery + { + let provider = UpstreamOAuthProvider { + issuer: "http://insecure.example.com/".to_owned(), + ..provider.clone() + }; + let cache = MetadataCache::new(); + let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service); + lazy_metadata.authorization_endpoint().await.unwrap_err(); + // This triggered a fetch, even though it failed + assert_eq!(calls.load(Ordering::SeqCst), 2); + } + + // Insecure providers work with insecure discovery + { + let provider = UpstreamOAuthProvider { + issuer: "http://insecure.example.com/".to_owned(), + discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure, + ..provider.clone() + }; + let cache = MetadataCache::new(); + let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service); + assert_eq!( + lazy_metadata + .authorization_endpoint() + .await + .unwrap() + .as_str(), + "http://insecure.example.com/authorize" + ); + // This triggered a fetch + assert_eq!(calls.load(Ordering::SeqCst), 3); + } + + // Getting endpoints when discovery is disabled only works for overriden ones + { + let provider = UpstreamOAuthProvider { + discovery_mode: UpstreamOAuthProviderDiscoveryMode::Disabled, + authorization_endpoint_override: Some( + Url::parse("https://valid.example.com/authorize_override").unwrap(), + ), + token_endpoint_override: None, + ..provider.clone() + }; + let cache = MetadataCache::new(); + let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &service); + // This should not fail, but also does nothing + assert!(lazy_metadata.maybe_discover().await.unwrap().is_none()); + assert_eq!( + lazy_metadata + .authorization_endpoint() + .await + .unwrap() + .as_str(), + "https://valid.example.com/authorize_override" + ); + assert!(matches!( + lazy_metadata.token_endpoint().await, + Err(DiscoveryError::Disabled), + )); + // This did not trigger a fetch + assert_eq!(calls.load(Ordering::SeqCst), 3); + } + } +} diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 425809e5..2b695844 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -37,7 +37,7 @@ use serde::Deserialize; use thiserror::Error; use ulid::Ulid; -use super::{client_credentials_for_provider, UpstreamSessionsCookie}; +use super::{cache::LazyProviderInfos, client_credentials_for_provider, UpstreamSessionsCookie}; use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache}; #[derive(Deserialize)] @@ -191,18 +191,17 @@ pub(crate) async fn get( }; let http_service = http_client_factory.http_service("upstream_oauth2.callback"); - - // Discover the provider - let metadata = metadata_cache.get(&http_service, &provider.issuer).await?; + let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_service); // Fetch the JWKS let jwks = - mas_oidc_client::requests::jose::fetch_jwks(&http_service, metadata.jwks_uri()).await?; + mas_oidc_client::requests::jose::fetch_jwks(&http_service, lazy_metadata.jwks_uri().await?) + .await?; // Figure out the client credentials let client_credentials = client_credentials_for_provider( &provider, - metadata.token_endpoint(), + lazy_metadata.token_endpoint().await?, &keystore, &encrypter, )?; @@ -229,7 +228,7 @@ pub(crate) async fn get( mas_oidc_client::requests::authorization_code::access_token_with_authorization_code( &http_service, client_credentials, - metadata.token_endpoint(), + lazy_metadata.token_endpoint().await?, code, validation_data, Some(id_token_verification_data), diff --git a/crates/oidc-client/src/error.rs b/crates/oidc-client/src/error.rs index 4787b9c1..2aeb158a 100644 --- a/crates/oidc-client/src/error.rs +++ b/crates/oidc-client/src/error.rs @@ -93,6 +93,10 @@ pub enum DiscoveryError { /// An error occurred sending the request. #[error(transparent)] Service(BoxError), + + /// Discovery is disabled for this provider. + #[error("Discovery is disabled for this provider")] + Disabled, } impl From> for DiscoveryError diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index 12641783..95040962 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -14,7 +14,10 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports}; +use mas_data_model::{ + UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode, + UpstreamOAuthProviderPkceMode, +}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_storage::{ upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository}, @@ -99,6 +102,13 @@ impl TryFrom for UpstreamOAuthProvider { token_endpoint_signing_alg, created_at: value.created_at, claims_imports: value.claims_imports.0, + + // TODO + authorization_endpoint_override: None, + token_endpoint_override: None, + jwks_uri_override: None, + discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(), + pkce_mode: UpstreamOAuthProviderPkceMode::default(), }) } } @@ -213,6 +223,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' token_endpoint_auth_method, created_at, claims_imports, + + // TODO + authorization_endpoint_override: None, + token_endpoint_override: None, + jwks_uri_override: None, + discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(), + pkce_mode: UpstreamOAuthProviderPkceMode::default(), }) } @@ -357,6 +374,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' token_endpoint_auth_method, created_at, claims_imports, + + // TODO + authorization_endpoint_override: None, + token_endpoint_override: None, + jwks_uri_override: None, + discovery_mode: UpstreamOAuthProviderDiscoveryMode::default(), + pkce_mode: UpstreamOAuthProviderPkceMode::default(), }) }