diff --git a/crates/handlers/src/oauth2/discovery.rs b/crates/handlers/src/oauth2/discovery.rs index 878a9e09..2c5853f7 100644 --- a/crates/handlers/src/oauth2/discovery.rs +++ b/crates/handlers/src/oauth2/discovery.rs @@ -75,7 +75,8 @@ pub(crate) async fn get( let token_endpoint_auth_signing_alg_values_supported = client_auth_signing_alg_values_supported.clone(); - let introspection_endpoint_auth_methods_supported = client_auth_methods_supported; + let introspection_endpoint_auth_methods_supported = + client_auth_methods_supported.map(|v| v.into_iter().map(Into::into).collect()); let introspection_endpoint_auth_signing_alg_values_supported = client_auth_signing_alg_values_supported; diff --git a/crates/oauth2-types/src/oidc.rs b/crates/oauth2-types/src/oidc.rs index 245f5c1d..3a43f84f 100644 --- a/crates/oauth2-types/src/oidc.rs +++ b/crates/oauth2-types/src/oidc.rs @@ -17,7 +17,7 @@ use std::ops::Deref; use language_tags::LanguageTag; use mas_iana::{ jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg}, - oauth::{OAuthClientAuthenticationMethod, PkceCodeChallengeMethod}, + oauth::{OAuthAccessTokenType, OAuthClientAuthenticationMethod, PkceCodeChallengeMethod}, }; use parse_display::{Display, FromStr}; use serde::{Deserialize, Serialize}; @@ -30,6 +30,55 @@ use crate::{ response_type::ResponseType, }; +/// An enum for types that accept either an [`OAuthClientAuthenticationMethod`] +/// or an [`OAuthAccessTokenType`]. +#[derive( + SerializeDisplay, DeserializeFromStr, Clone, Copy, PartialEq, Eq, Hash, Debug, Display, FromStr, +)] +pub enum AuthenticationMethodOrAccessTokenType { + /// An authentication method. + #[display("{0}")] + AuthenticationMethod(OAuthClientAuthenticationMethod), + + /// An access token type. + #[display("{0}")] + AccessTokenType(OAuthAccessTokenType), +} + +impl AuthenticationMethodOrAccessTokenType { + /// Get the authentication method of this + /// `AuthenticationMethodOrAccessTokenType`. + #[must_use] + pub fn authentication_method(&self) -> Option { + match self { + Self::AuthenticationMethod(m) => Some(*m), + Self::AccessTokenType(_) => None, + } + } + + /// Get the access token type of this + /// `AuthenticationMethodOrAccessTokenType`. + #[must_use] + pub fn access_token_type(&self) -> Option { + match self { + Self::AuthenticationMethod(_) => None, + Self::AccessTokenType(t) => Some(*t), + } + } +} + +impl From for AuthenticationMethodOrAccessTokenType { + fn from(t: OAuthClientAuthenticationMethod) -> Self { + Self::AuthenticationMethod(t) + } +} + +impl From for AuthenticationMethodOrAccessTokenType { + fn from(t: OAuthAccessTokenType) -> Self { + Self::AccessTokenType(t) + } +} + #[derive( SerializeDisplay, DeserializeFromStr, Clone, Copy, PartialEq, Eq, Hash, Debug, Display, FromStr, )] @@ -214,9 +263,10 @@ pub struct ProviderMetadata { /// [OAuth 2.0 introspection endpoint]: https://www.rfc-editor.org/rfc/rfc7662 pub introspection_endpoint: Option, - /// JSON array containing a list of client authentication methods supported - /// by this introspection endpoint. - pub introspection_endpoint_auth_methods_supported: Option>, + /// JSON array containing a list of client authentication methods or token + /// types supported by this introspection endpoint. + pub introspection_endpoint_auth_methods_supported: + Option>, /// JSON array containing a list of the JWS signing algorithms supported by /// the introspection endpoint for the signature on the JWT used to @@ -438,10 +488,20 @@ impl ProviderMetadata { validate_url("introspection_endpoint", url, ExtraUrlRestrictions::None)?; } + // The list can also contain token types so remove them as we don't need to + // check them. + let introspection_methods = metadata + .introspection_endpoint_auth_methods_supported + .as_ref() + .map(|v| { + v.iter() + .filter_map(AuthenticationMethodOrAccessTokenType::authentication_method) + .collect::>() + }); validate_signing_alg_values_supported( "introspection_endpoint", &metadata.introspection_endpoint_auth_signing_alg_values_supported, - &metadata.introspection_endpoint_auth_methods_supported, + &introspection_methods, )?; if let Some(url) = &metadata.userinfo_endpoint {