diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index f8bf4182..7d8931e4 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -23,7 +23,6 @@ use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; use mas_axum_utils::SessionInfoExt; use mas_data_model::{AuthorizationCode, Pkce}; -use mas_iana::oauth::OAuthAuthorizationEndpointResponseType; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; @@ -35,8 +34,8 @@ use mas_templates::Templates; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, pkce, - prelude::*, requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode}, + response_type::ResponseType, }; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use serde::Deserialize; @@ -134,7 +133,7 @@ pub(crate) struct Params { /// figure out what response mode must be used, and emit an error if the /// suggested response mode isn't allowed for the given response types. fn resolve_response_mode( - response_type: OAuthAuthorizationEndpointResponseType, + response_type: &ResponseType, suggested_response_mode: Option, ) -> anyhow::Result { use ResponseMode as M; @@ -172,7 +171,7 @@ pub(crate) async fn get( .resolve_redirect_uri(¶ms.auth.redirect_uri)? .clone(); let response_type = params.auth.response_type; - let response_mode = resolve_response_mode(response_type, params.auth.response_mode)?; + let response_mode = resolve_response_mode(&response_type, params.auth.response_mode)?; // Now we have a proper callback destination to go to on error let callback_destination = CallbackDestination::try_new( diff --git a/crates/handlers/src/oauth2/discovery.rs b/crates/handlers/src/oauth2/discovery.rs index f2983ab9..18d47b51 100644 --- a/crates/handlers/src/oauth2/discovery.rs +++ b/crates/handlers/src/oauth2/discovery.rs @@ -66,9 +66,9 @@ pub(crate) async fn get( let scopes_supported = Some(vec![scope::OPENID.to_string(), scope::EMAIL.to_string()]); let response_types_supported = Some(vec![ - OAuthAuthorizationEndpointResponseType::Code, - OAuthAuthorizationEndpointResponseType::IdToken, - OAuthAuthorizationEndpointResponseType::CodeIdToken, + OAuthAuthorizationEndpointResponseType::Code.into(), + OAuthAuthorizationEndpointResponseType::IdToken.into(), + OAuthAuthorizationEndpointResponseType::CodeIdToken.into(), ]); let response_modes_supported = Some(vec![ diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index 9c48922e..9d8cf3a9 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -138,7 +138,7 @@ pub(crate) async fn post( &client_id, metadata.redirect_uris(), None, - metadata.response_types(), + &metadata.response_types(), metadata.grant_types(), contacts, metadata diff --git a/crates/oauth2-types/src/lib.rs b/crates/oauth2-types/src/lib.rs index b4c1a459..d8a3c674 100644 --- a/crates/oauth2-types/src/lib.rs +++ b/crates/oauth2-types/src/lib.rs @@ -17,47 +17,17 @@ #![warn(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] -use mas_iana::oauth::OAuthAuthorizationEndpointResponseType; - -pub trait ResponseTypeExt { - fn has_code(&self) -> bool; - fn has_token(&self) -> bool; - fn has_id_token(&self) -> bool; -} - -impl ResponseTypeExt for OAuthAuthorizationEndpointResponseType { - fn has_code(&self) -> bool { - matches!( - self, - Self::Code | Self::CodeToken | Self::CodeIdToken | Self::CodeIdTokenToken - ) - } - - fn has_token(&self) -> bool { - matches!( - self, - Self::Token | Self::CodeToken | Self::IdTokenToken | Self::CodeIdTokenToken - ) - } - - fn has_id_token(&self) -> bool { - matches!( - self, - Self::IdToken | Self::IdTokenToken | Self::CodeIdToken | Self::CodeIdTokenToken - ) - } -} - pub mod errors; pub mod oidc; pub mod pkce; pub mod registration; pub mod requests; +pub mod response_type; pub mod scope; pub mod webfinger; pub mod prelude { - pub use crate::{pkce::CodeChallengeMethodExt, ResponseTypeExt}; + pub use crate::pkce::CodeChallengeMethodExt; } #[cfg(test)] diff --git a/crates/oauth2-types/src/oidc.rs b/crates/oauth2-types/src/oidc.rs index d74181f3..f20f372c 100644 --- a/crates/oauth2-types/src/oidc.rs +++ b/crates/oauth2-types/src/oidc.rs @@ -17,10 +17,7 @@ use std::ops::Deref; use language_tags::LanguageTag; use mas_iana::{ jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg}, - oauth::{ - OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod, - PkceCodeChallengeMethod, - }, + oauth::{OAuthClientAuthenticationMethod, PkceCodeChallengeMethod}, }; use parse_display::{Display, FromStr}; use serde::{Deserialize, Serialize}; @@ -28,7 +25,10 @@ use serde_with::{skip_serializing_none, DeserializeFromStr, SerializeDisplay}; use thiserror::Error; use url::Url; -use crate::requests::{Display, GrantType, Prompt, ResponseMode}; +use crate::{ + requests::{Display, GrantType, Prompt, ResponseMode}, + response_type::ResponseType, +}; #[derive( SerializeDisplay, DeserializeFromStr, Clone, Copy, PartialEq, Eq, Hash, Debug, Display, FromStr, @@ -128,7 +128,7 @@ pub struct ProviderMetadata { /// This field is required. /// /// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9 - pub response_types_supported: Option>, + pub response_types_supported: Option>, /// JSON array containing a list of the [OAuth 2.0 `response_mode` values] /// that this authorization server supports. @@ -707,7 +707,7 @@ impl VerifiedProviderMetadata { /// JSON array containing a list of the OAuth 2.0 `response_type` values /// that this authorization server supports. #[must_use] - pub fn response_types_supported(&self) -> &[OAuthAuthorizationEndpointResponseType] { + pub fn response_types_supported(&self) -> &[ResponseType] { match &self.response_types_supported { Some(u) => u, None => unreachable!(), @@ -934,7 +934,9 @@ mod tests { authorization_endpoint: Some(Url::parse("https://localhost/auth").unwrap()), token_endpoint: Some(Url::parse("https://localhost/token").unwrap()), jwks_uri: Some(Url::parse("https://localhost/jwks").unwrap()), - response_types_supported: Some(vec![OAuthAuthorizationEndpointResponseType::Code]), + response_types_supported: Some(vec![ + OAuthAuthorizationEndpointResponseType::Code.into() + ]), subject_types_supported: Some(vec![SubjectType::Public]), id_token_signing_alg_values_supported: Some(vec![JsonWebSignatureAlg::Rs256]), ..Default::default() @@ -1158,7 +1160,7 @@ mod tests { // Ok - Present metadata.response_types_supported = - Some(vec![OAuthAuthorizationEndpointResponseType::Code]); + Some(vec![OAuthAuthorizationEndpointResponseType::Code.into()]); metadata.validate(&issuer).unwrap(); } diff --git a/crates/oauth2-types/src/registration/client_metadata_serde.rs b/crates/oauth2-types/src/registration/client_metadata_serde.rs index da6a7150..892c6fd2 100644 --- a/crates/oauth2-types/src/registration/client_metadata_serde.rs +++ b/crates/oauth2-types/src/registration/client_metadata_serde.rs @@ -18,7 +18,7 @@ use chrono::Duration; use language_tags::LanguageTag; use mas_iana::{ jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg}, - oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, + oauth::OAuthClientAuthenticationMethod, }; use mas_jose::jwk::PublicJsonWebKeySet; use serde::{ @@ -34,6 +34,7 @@ use super::{ClientMetadata, Localized, VerifiedClientMetadata}; use crate::{ oidc::{ApplicationType, SubjectType}, requests::GrantType, + response_type::ResponseType, }; impl Localized { @@ -94,7 +95,7 @@ impl Localized { #[derive(Serialize, Deserialize)] pub struct ClientMetadataSerdeHelper { redirect_uris: Option>, - response_types: Option>, + response_types: Option>, grant_types: Option>, application_type: Option, contacts: Option>, diff --git a/crates/oauth2-types/src/registration/mod.rs b/crates/oauth2-types/src/registration/mod.rs index b1011128..895b71e0 100644 --- a/crates/oauth2-types/src/registration/mod.rs +++ b/crates/oauth2-types/src/registration/mod.rs @@ -29,13 +29,14 @@ use url::Url; use crate::{ oidc::{ApplicationType, SubjectType}, requests::GrantType, + response_type::ResponseType, }; mod client_metadata_serde; use client_metadata_serde::ClientMetadataSerdeHelper; -pub const DEFAULT_RESPONSE_TYPES: &[OAuthAuthorizationEndpointResponseType] = - &[OAuthAuthorizationEndpointResponseType::Code]; +pub const DEFAULT_RESPONSE_TYPES: [OAuthAuthorizationEndpointResponseType; 1] = + [OAuthAuthorizationEndpointResponseType::Code]; pub const DEFAULT_GRANT_TYPES: &[GrantType] = &[GrantType::AuthorizationCode]; @@ -134,7 +135,7 @@ pub struct ClientMetadata { /// /// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9 /// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1 - pub response_types: Option>, + pub response_types: Option>, /// Array of [OAuth 2.0 `grant_type` values] that the client can use at the /// [token endpoint]. @@ -431,21 +432,18 @@ impl ClientMetadata { let has_authorization_code = grant_types.contains(&GrantType::AuthorizationCode); let has_both = has_implicit && has_authorization_code; - for response_type in response_types { - let is_ok = match response_type { - OAuthAuthorizationEndpointResponseType::Code => has_authorization_code, - OAuthAuthorizationEndpointResponseType::CodeIdToken - | OAuthAuthorizationEndpointResponseType::CodeIdTokenToken - | OAuthAuthorizationEndpointResponseType::CodeToken => has_both, - OAuthAuthorizationEndpointResponseType::IdToken - | OAuthAuthorizationEndpointResponseType::IdTokenToken - | OAuthAuthorizationEndpointResponseType::Token => has_implicit, - OAuthAuthorizationEndpointResponseType::None => true, - }; + for response_type in &response_types { + let has_code = response_type.has_code(); + let has_id_token = response_type.has_id_token(); + let has_token = response_type.has_token(); + let is_ok = has_code && has_both + || !has_code && has_implicit + || has_authorization_code && !has_id_token && !has_token + || !has_code && !has_id_token && !has_token; if !is_ok { return Err(ClientMetadataVerificationError::IncoherentResponseType( - *response_type, + response_type.clone(), )); } } @@ -489,11 +487,7 @@ impl ClientMetadata { } if self.id_token_signed_response_alg() == JsonWebSignatureAlg::None - && (response_types.contains(&OAuthAuthorizationEndpointResponseType::CodeIdToken) - || response_types - .contains(&OAuthAuthorizationEndpointResponseType::CodeIdTokenToken) - || response_types.contains(&OAuthAuthorizationEndpointResponseType::IdToken) - || response_types.contains(&OAuthAuthorizationEndpointResponseType::IdTokenToken)) + && response_types.iter().any(ResponseType::has_id_token) { return Err(ClientMetadataVerificationError::IdTokenSigningAlgNone); } @@ -547,10 +541,10 @@ impl ClientMetadata { /// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9 /// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1 #[must_use] - pub fn response_types(&self) -> &[OAuthAuthorizationEndpointResponseType] { + pub fn response_types(&self) -> Vec { self.response_types - .as_deref() - .unwrap_or(DEFAULT_RESPONSE_TYPES) + .clone() + .unwrap_or_else(|| DEFAULT_RESPONSE_TYPES.map(ResponseType::from).into()) } /// Array of [OAuth 2.0 `grant_type` values] that the client can use at the @@ -801,7 +795,7 @@ pub enum ClientMetadataVerificationError { /// The given response type is not compatible with the grant types. #[error("'{0}' response type not compatible with grant types")] - IncoherentResponseType(OAuthAuthorizationEndpointResponseType), + IncoherentResponseType(ResponseType), /// Both the `jwks_uri` and `jwks` fields are present but only one is /// allowed. @@ -865,7 +859,7 @@ mod tests { use url::Url; use super::{ClientMetadata, ClientMetadataVerificationError}; - use crate::requests::GrantType; + use crate::{requests::GrantType, response_type::ResponseType}; fn valid_client_metadata() -> ClientMetadata { ClientMetadata { @@ -934,173 +928,192 @@ mod tests { // grant_type = authorization_code // code - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code]); + metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code.into()]); metadata.clone().validate().unwrap(); // code id_token - Err - let response_type = OAuthAuthorizationEndpointResponseType::CodeIdToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = + OAuthAuthorizationEndpointResponseType::CodeIdToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // code id_token token - Err - let response_type = OAuthAuthorizationEndpointResponseType::CodeIdTokenToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = + OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // code token - Err - let response_type = OAuthAuthorizationEndpointResponseType::CodeToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::CodeToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // id_token - Err - let response_type = OAuthAuthorizationEndpointResponseType::IdToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::IdToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // id_token token - Err - let response_type = OAuthAuthorizationEndpointResponseType::IdTokenToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = + OAuthAuthorizationEndpointResponseType::IdTokenToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // token - Err - let response_type = OAuthAuthorizationEndpointResponseType::IdTokenToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = + OAuthAuthorizationEndpointResponseType::IdTokenToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // none - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]); + metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]); metadata.clone().validate().unwrap(); // grant_type = implicit metadata.grant_types = Some(vec![GrantType::Implicit]); // code - Err - let response_type = OAuthAuthorizationEndpointResponseType::Code; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::Code.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // code id_token - Err - let response_type = OAuthAuthorizationEndpointResponseType::CodeIdToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = + OAuthAuthorizationEndpointResponseType::CodeIdToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // code id_token token - Err - let response_type = OAuthAuthorizationEndpointResponseType::CodeIdTokenToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = + OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // code token - Err - let response_type = OAuthAuthorizationEndpointResponseType::CodeToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::CodeToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // id_token - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdToken]); + metadata.response_types = + Some(vec![OAuthAuthorizationEndpointResponseType::IdToken.into()]); metadata.clone().validate().unwrap(); // id_token token - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdTokenToken]); + metadata.response_types = Some(vec![ + OAuthAuthorizationEndpointResponseType::IdTokenToken.into() + ]); metadata.clone().validate().unwrap(); // token - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token]); + metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token.into()]); metadata.clone().validate().unwrap(); // none - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]); + metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]); metadata.clone().validate().unwrap(); // grant_types = [authorization_code, implicit] metadata.grant_types = Some(vec![GrantType::AuthorizationCode, GrantType::Implicit]); // code - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code]); + metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code.into()]); metadata.clone().validate().unwrap(); // code id_token - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::CodeIdToken]); + metadata.response_types = Some(vec![ + OAuthAuthorizationEndpointResponseType::CodeIdToken.into() + ]); metadata.clone().validate().unwrap(); // code id_token token - Ok metadata.response_types = Some(vec![ - OAuthAuthorizationEndpointResponseType::CodeIdTokenToken, + OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into(), ]); metadata.clone().validate().unwrap(); // code token - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::CodeToken]); + metadata.response_types = Some(vec![ + OAuthAuthorizationEndpointResponseType::CodeToken.into() + ]); metadata.clone().validate().unwrap(); // id_token - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdToken]); + metadata.response_types = + Some(vec![OAuthAuthorizationEndpointResponseType::IdToken.into()]); metadata.clone().validate().unwrap(); // id_token token - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdTokenToken]); + metadata.response_types = Some(vec![ + OAuthAuthorizationEndpointResponseType::IdTokenToken.into() + ]); metadata.clone().validate().unwrap(); // token - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token]); + metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token.into()]); metadata.clone().validate().unwrap(); // none - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]); + metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]); metadata.clone().validate().unwrap(); // other grant_types metadata.grant_types = Some(vec![GrantType::RefreshToken, GrantType::ClientCredentials]); // code - Err - let response_type = OAuthAuthorizationEndpointResponseType::Code; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::Code.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // code id_token - Err - let response_type = OAuthAuthorizationEndpointResponseType::CodeIdToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = + OAuthAuthorizationEndpointResponseType::CodeIdToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // code id_token token - Err - let response_type = OAuthAuthorizationEndpointResponseType::CodeIdTokenToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = + OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // code token - Err - let response_type = OAuthAuthorizationEndpointResponseType::CodeToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::CodeToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // id_token - Err - let response_type = OAuthAuthorizationEndpointResponseType::IdToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::IdToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // id_token token - Err - let response_type = OAuthAuthorizationEndpointResponseType::IdTokenToken; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = + OAuthAuthorizationEndpointResponseType::IdTokenToken.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // token - Err - let response_type = OAuthAuthorizationEndpointResponseType::Token; - metadata.response_types = Some(vec![response_type]); + let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::Token.into(); + metadata.response_types = Some(vec![response_type.clone()]); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); assert_eq!(res, response_type); // none - Ok - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]); + metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]); metadata.validate().unwrap(); } @@ -1206,7 +1219,9 @@ mod tests { metadata.grant_types = Some(vec![GrantType::AuthorizationCode, GrantType::Implicit]); // Err - code id_token - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::CodeIdToken]); + metadata.response_types = Some(vec![ + OAuthAuthorizationEndpointResponseType::CodeIdToken.into() + ]); assert_matches!( metadata.clone().validate(), Err(ClientMetadataVerificationError::IdTokenSigningAlgNone) @@ -1214,7 +1229,7 @@ mod tests { // Err - code id_token token metadata.response_types = Some(vec![ - OAuthAuthorizationEndpointResponseType::CodeIdTokenToken, + OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into(), ]); assert_matches!( metadata.clone().validate(), @@ -1222,14 +1237,17 @@ mod tests { ); // Err - id_token - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdToken]); + metadata.response_types = + Some(vec![OAuthAuthorizationEndpointResponseType::IdToken.into()]); assert_matches!( metadata.clone().validate(), Err(ClientMetadataVerificationError::IdTokenSigningAlgNone) ); // Err - id_token token - metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdTokenToken]); + metadata.response_types = Some(vec![ + OAuthAuthorizationEndpointResponseType::IdTokenToken.into() + ]); assert_matches!( metadata.clone().validate(), Err(ClientMetadataVerificationError::IdTokenSigningAlgNone) @@ -1237,10 +1255,10 @@ mod tests { // Ok - Other response types metadata.response_types = Some(vec![ - OAuthAuthorizationEndpointResponseType::Code, - OAuthAuthorizationEndpointResponseType::CodeToken, - OAuthAuthorizationEndpointResponseType::Token, - OAuthAuthorizationEndpointResponseType::None, + OAuthAuthorizationEndpointResponseType::Code.into(), + OAuthAuthorizationEndpointResponseType::CodeToken.into(), + OAuthAuthorizationEndpointResponseType::Token.into(), + OAuthAuthorizationEndpointResponseType::None.into(), ]); metadata.validate().unwrap(); } diff --git a/crates/oauth2-types/src/requests.rs b/crates/oauth2-types/src/requests.rs index 3e2e8f4f..be3f25bf 100644 --- a/crates/oauth2-types/src/requests.rs +++ b/crates/oauth2-types/src/requests.rs @@ -16,9 +16,7 @@ use std::{collections::HashSet, hash::Hash, num::NonZeroU32}; use chrono::{DateTime, Duration, Utc}; use language_tags::LanguageTag; -use mas_iana::oauth::{ - OAuthAccessTokenType, OAuthAuthorizationEndpointResponseType, OAuthTokenTypeHint, -}; +use mas_iana::oauth::{OAuthAccessTokenType, OAuthTokenTypeHint}; use parse_display::{Display, FromStr}; use serde::{Deserialize, Serialize}; use serde_with::{ @@ -27,7 +25,7 @@ use serde_with::{ }; use url::Url; -use crate::scope::Scope; +use crate::{response_type::ResponseType, scope::Scope}; // ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml @@ -170,7 +168,7 @@ pub enum Prompt { pub struct AuthorizationRequest { /// OAuth 2.0 Response Type value that determines the authorization /// processing flow to be used. - pub response_type: OAuthAuthorizationEndpointResponseType, + pub response_type: ResponseType, /// OAuth 2.0 Client Identifier valid at the Authorization Server. pub client_id: String, @@ -264,11 +262,7 @@ pub struct AuthorizationRequest { impl AuthorizationRequest { /// Creates a basic `AuthorizationRequest`. #[must_use] - pub fn new( - response_type: OAuthAuthorizationEndpointResponseType, - client_id: String, - scope: Scope, - ) -> Self { + pub fn new(response_type: ResponseType, client_id: String, scope: Scope) -> Self { Self { response_type, client_id, diff --git a/crates/oauth2-types/src/response_type.rs b/crates/oauth2-types/src/response_type.rs new file mode 100644 index 00000000..1aab1639 --- /dev/null +++ b/crates/oauth2-types/src/response_type.rs @@ -0,0 +1,489 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![allow(clippy::module_name_repetitions)] + +use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr}; + +use itertools::Itertools; +use mas_iana::oauth::OAuthAuthorizationEndpointResponseType; +use parse_display::{Display, FromStr}; +use serde_with::{DeserializeFromStr, SerializeDisplay}; +use thiserror::Error; + +/// An error encountered when trying to parse an invalid [`ResponseType`]. +#[derive(Debug, Error, Clone, PartialEq, Eq)] +#[error("invalid response type")] +pub struct InvalidResponseType; + +/// The accepted tokens in a [`ResponseType`]. +/// +/// `none` is not in this enum because it is represented by an empty +/// [`ResponseType`]. +/// +/// This type also accepts unknown tokens that can be constructed via it's +/// `FromStr` implementation or used via its `Display` implementation. +#[derive( + Debug, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Display, + FromStr, + SerializeDisplay, + DeserializeFromStr, +)] +#[display(style = "snake_case")] +#[non_exhaustive] +pub enum ResponseTypeToken { + /// `code` + Code, + + /// `id_token` + IdToken, + + /// `token` + Token, + + /// Unknown token. + #[display("{0}")] + Unknown(String), +} + +/// An [OAuth 2.0 `response_type` value] that the client can use +/// at the [authorization endpoint]. +/// +/// It is recommended to construct this type from an +/// [`OAuthAuthorizationEndpointResponseType`]. +/// +/// [OAuth 2.0 `response_type` value]: https://www.rfc-editor.org/rfc/rfc7591#page-9 +/// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1 +#[derive(Debug, Clone, PartialEq, Eq, SerializeDisplay, DeserializeFromStr)] +pub struct ResponseType(BTreeSet); + +impl std::ops::Deref for ResponseType { + type Target = BTreeSet; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl ResponseType { + /// Whether this response type requests a code. + #[must_use] + pub fn has_code(&self) -> bool { + self.0.contains(&ResponseTypeToken::Code) + } + + /// Whether this response type requests an ID token. + #[must_use] + pub fn has_id_token(&self) -> bool { + self.0.contains(&ResponseTypeToken::IdToken) + } + + /// Whether this response type requests a token. + #[must_use] + pub fn has_token(&self) -> bool { + self.0.contains(&ResponseTypeToken::Token) + } +} + +impl FromStr for ResponseType { + type Err = InvalidResponseType; + + fn from_str(s: &str) -> Result { + let s = s.trim(); + + if s.is_empty() { + Err(InvalidResponseType) + } else if s == "none" { + Ok(Self(BTreeSet::new())) + } else { + s.split_ascii_whitespace() + .map(|t| ResponseTypeToken::from_str(t).or(Err(InvalidResponseType))) + .collect::>() + } + } +} + +impl fmt::Display for ResponseType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let res = Itertools::intersperse(self.iter().map(ToString::to_string), ' '.to_string()) + .collect::(); + + if res.is_empty() { + write!(f, "none") + } else { + f.write_str(&res) + } + } +} + +impl FromIterator for ResponseType { + fn from_iter>(iter: T) -> Self { + Self(BTreeSet::from_iter(iter)) + } +} + +impl From for ResponseType { + fn from(response_type: OAuthAuthorizationEndpointResponseType) -> Self { + match response_type { + OAuthAuthorizationEndpointResponseType::Code => Self([ResponseTypeToken::Code].into()), + OAuthAuthorizationEndpointResponseType::CodeIdToken => { + Self([ResponseTypeToken::Code, ResponseTypeToken::IdToken].into()) + } + OAuthAuthorizationEndpointResponseType::CodeIdTokenToken => Self( + [ + ResponseTypeToken::Code, + ResponseTypeToken::IdToken, + ResponseTypeToken::Token, + ] + .into(), + ), + OAuthAuthorizationEndpointResponseType::CodeToken => { + Self([ResponseTypeToken::Code, ResponseTypeToken::Token].into()) + } + OAuthAuthorizationEndpointResponseType::IdToken => { + Self([ResponseTypeToken::IdToken].into()) + } + OAuthAuthorizationEndpointResponseType::IdTokenToken => { + Self([ResponseTypeToken::IdToken, ResponseTypeToken::Token].into()) + } + OAuthAuthorizationEndpointResponseType::None => Self(BTreeSet::new()), + OAuthAuthorizationEndpointResponseType::Token => { + Self([ResponseTypeToken::Token].into()) + } + } + } +} + +impl TryFrom for OAuthAuthorizationEndpointResponseType { + type Error = InvalidResponseType; + + fn try_from(response_type: ResponseType) -> Result { + if response_type + .iter() + .any(|t| matches!(t, ResponseTypeToken::Unknown(_))) + { + return Err(InvalidResponseType); + } + + let tokens = response_type.iter().collect::>(); + let res = match *tokens { + [ResponseTypeToken::Code] => OAuthAuthorizationEndpointResponseType::Code, + [ResponseTypeToken::IdToken] => OAuthAuthorizationEndpointResponseType::IdToken, + [ResponseTypeToken::Token] => OAuthAuthorizationEndpointResponseType::Token, + [ResponseTypeToken::Code, ResponseTypeToken::IdToken] => { + OAuthAuthorizationEndpointResponseType::CodeIdToken + } + [ResponseTypeToken::Code, ResponseTypeToken::Token] => { + OAuthAuthorizationEndpointResponseType::CodeToken + } + [ResponseTypeToken::IdToken, ResponseTypeToken::Token] => { + OAuthAuthorizationEndpointResponseType::IdTokenToken + } + [ResponseTypeToken::Code, ResponseTypeToken::IdToken, ResponseTypeToken::Token] => { + OAuthAuthorizationEndpointResponseType::CodeIdTokenToken + } + _ => OAuthAuthorizationEndpointResponseType::None, + }; + + Ok(res) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn deserialize_response_type_token() { + assert_eq!( + serde_json::from_str::("\"code\"").unwrap(), + ResponseTypeToken::Code + ); + assert_eq!( + serde_json::from_str::("\"id_token\"").unwrap(), + ResponseTypeToken::IdToken + ); + assert_eq!( + serde_json::from_str::("\"token\"").unwrap(), + ResponseTypeToken::Token + ); + assert_eq!( + serde_json::from_str::("\"something_unsupported\"").unwrap(), + ResponseTypeToken::Unknown("something_unsupported".to_owned()) + ); + } + + #[test] + fn serialize_response_type_token() { + assert_eq!( + serde_json::to_string(&ResponseTypeToken::Code).unwrap(), + "\"code\"" + ); + assert_eq!( + serde_json::to_string(&ResponseTypeToken::IdToken).unwrap(), + "\"id_token\"" + ); + assert_eq!( + serde_json::to_string(&ResponseTypeToken::Token).unwrap(), + "\"token\"" + ); + assert_eq!( + serde_json::to_string(&ResponseTypeToken::Unknown( + "something_unsupported".to_owned() + )) + .unwrap(), + "\"something_unsupported\"" + ); + } + + #[test] + #[allow(clippy::too_many_lines)] + fn deserialize_response_type() { + serde_json::from_str::("\"\"").unwrap_err(); + + let res_type = serde_json::from_str::("\"none\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), None); + assert_eq!( + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(), + OAuthAuthorizationEndpointResponseType::None + ); + + let res_type = serde_json::from_str::("\"code\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Code)); + assert_eq!(iter.next(), None); + assert_eq!( + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(), + OAuthAuthorizationEndpointResponseType::Code + ); + + let res_type = serde_json::from_str::("\"code\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Code)); + assert_eq!(iter.next(), None); + assert_eq!( + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(), + OAuthAuthorizationEndpointResponseType::Code + ); + + let res_type = serde_json::from_str::("\"id_token\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken)); + assert_eq!(iter.next(), None); + assert_eq!( + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(), + OAuthAuthorizationEndpointResponseType::IdToken + ); + + let res_type = serde_json::from_str::("\"token\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Token)); + assert_eq!(iter.next(), None); + assert_eq!( + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(), + OAuthAuthorizationEndpointResponseType::Token + ); + + let res_type = serde_json::from_str::("\"something_unsupported\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!( + iter.next(), + Some(&ResponseTypeToken::Unknown( + "something_unsupported".to_owned() + )) + ); + assert_eq!(iter.next(), None); + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err(); + + let res_type = serde_json::from_str::("\"code id_token\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Code)); + assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken)); + assert_eq!(iter.next(), None); + assert_eq!( + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(), + OAuthAuthorizationEndpointResponseType::CodeIdToken + ); + + let res_type = serde_json::from_str::("\"code token\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Code)); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Token)); + assert_eq!(iter.next(), None); + assert_eq!( + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(), + OAuthAuthorizationEndpointResponseType::CodeToken + ); + + let res_type = serde_json::from_str::("\"id_token token\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken)); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Token)); + assert_eq!(iter.next(), None); + assert_eq!( + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(), + OAuthAuthorizationEndpointResponseType::IdTokenToken + ); + + let res_type = serde_json::from_str::("\"code id_token token\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Code)); + assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken)); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Token)); + assert_eq!(iter.next(), None); + assert_eq!( + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(), + OAuthAuthorizationEndpointResponseType::CodeIdTokenToken + ); + + let res_type = + serde_json::from_str::("\"code id_token token something_unsupported\"") + .unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Code)); + assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken)); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Token)); + assert_eq!( + iter.next(), + Some(&ResponseTypeToken::Unknown( + "something_unsupported".to_owned() + )) + ); + assert_eq!(iter.next(), None); + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err(); + + // Order doesn't matter + let res_type = serde_json::from_str::("\"token code id_token\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Code)); + assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken)); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Token)); + assert_eq!(iter.next(), None); + assert_eq!( + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(), + OAuthAuthorizationEndpointResponseType::CodeIdTokenToken + ); + + let res_type = + serde_json::from_str::("\"id_token token id_token code\"").unwrap(); + let mut iter = res_type.iter(); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Code)); + assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken)); + assert_eq!(iter.next(), Some(&ResponseTypeToken::Token)); + assert_eq!(iter.next(), None); + assert_eq!( + OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(), + OAuthAuthorizationEndpointResponseType::CodeIdTokenToken + ); + } + + #[test] + fn serialize_response_type() { + assert_eq!( + serde_json::to_string(&ResponseType::from( + OAuthAuthorizationEndpointResponseType::None + )) + .unwrap(), + "\"none\"" + ); + assert_eq!( + serde_json::to_string(&ResponseType::from( + OAuthAuthorizationEndpointResponseType::Code + )) + .unwrap(), + "\"code\"" + ); + assert_eq!( + serde_json::to_string(&ResponseType::from( + OAuthAuthorizationEndpointResponseType::IdToken + )) + .unwrap(), + "\"id_token\"" + ); + assert_eq!( + serde_json::to_string(&ResponseType::from( + OAuthAuthorizationEndpointResponseType::CodeIdToken + )) + .unwrap(), + "\"code id_token\"" + ); + assert_eq!( + serde_json::to_string(&ResponseType::from( + OAuthAuthorizationEndpointResponseType::CodeToken + )) + .unwrap(), + "\"code token\"" + ); + assert_eq!( + serde_json::to_string(&ResponseType::from( + OAuthAuthorizationEndpointResponseType::IdTokenToken + )) + .unwrap(), + "\"id_token token\"" + ); + assert_eq!( + serde_json::to_string(&ResponseType::from( + OAuthAuthorizationEndpointResponseType::CodeIdTokenToken + )) + .unwrap(), + "\"code id_token token\"" + ); + + assert_eq!( + serde_json::to_string( + &[ + ResponseTypeToken::Unknown("something_unsupported".to_owned()), + ResponseTypeToken::Code + ] + .into_iter() + .collect::() + ) + .unwrap(), + "\"code something_unsupported\"" + ); + + // Order doesn't matter. + let res = [ + ResponseTypeToken::IdToken, + ResponseTypeToken::Token, + ResponseTypeToken::Code, + ] + .into_iter() + .collect::(); + assert_eq!( + serde_json::to_string(&res).unwrap(), + "\"code id_token token\"" + ); + + let res = [ + ResponseTypeToken::Code, + ResponseTypeToken::Token, + ResponseTypeToken::IdToken, + ] + .into_iter() + .collect::(); + assert_eq!( + serde_json::to_string(&res).unwrap(), + "\"code id_token token\"" + ); + } +} diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 992bcb46..9357a7ab 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -20,7 +20,7 @@ use mas_iana::{ oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, }; use mas_jose::jwk::PublicJsonWebKeySet; -use oauth2_types::requests::GrantType; +use oauth2_types::{requests::GrantType, response_type::ResponseType}; use sqlx::{PgConnection, PgExecutor}; use thiserror::Error; use url::Url; @@ -322,7 +322,7 @@ pub async fn insert_client( client_id: &str, redirect_uris: &[Url], encrypted_client_secret: Option<&str>, - response_types: &[OAuthAuthorizationEndpointResponseType], + response_types: &[ResponseType], grant_types: &[GrantType], contacts: &[String], client_name: Option<&str>,