1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-28 11:02:02 +03:00

Use ResponseType that doesn't care about tokens order

This commit is contained in:
Kévin Commaille
2022-09-09 14:52:59 +02:00
committed by Quentin Gliech
parent f5715018a6
commit fca6cfa393
10 changed files with 618 additions and 145 deletions

View File

@ -23,7 +23,6 @@ use axum_extra::extract::PrivateCookieJar;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::SessionInfoExt; use mas_axum_utils::SessionInfoExt;
use mas_data_model::{AuthorizationCode, Pkce}; use mas_data_model::{AuthorizationCode, Pkce};
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
@ -35,8 +34,8 @@ use mas_templates::Templates;
use oauth2_types::{ use oauth2_types::{
errors::{ClientError, ClientErrorCode}, errors::{ClientError, ClientErrorCode},
pkce, pkce,
prelude::*,
requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode}, requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode},
response_type::ResponseType,
}; };
use rand::{distributions::Alphanumeric, thread_rng, Rng}; use rand::{distributions::Alphanumeric, thread_rng, Rng};
use serde::Deserialize; use serde::Deserialize;
@ -134,7 +133,7 @@ pub(crate) struct Params {
/// figure out what response mode must be used, and emit an error if the /// figure out what response mode must be used, and emit an error if the
/// suggested response mode isn't allowed for the given response types. /// suggested response mode isn't allowed for the given response types.
fn resolve_response_mode( fn resolve_response_mode(
response_type: OAuthAuthorizationEndpointResponseType, response_type: &ResponseType,
suggested_response_mode: Option<ResponseMode>, suggested_response_mode: Option<ResponseMode>,
) -> anyhow::Result<ResponseMode> { ) -> anyhow::Result<ResponseMode> {
use ResponseMode as M; use ResponseMode as M;
@ -172,7 +171,7 @@ pub(crate) async fn get(
.resolve_redirect_uri(&params.auth.redirect_uri)? .resolve_redirect_uri(&params.auth.redirect_uri)?
.clone(); .clone();
let response_type = params.auth.response_type; let response_type = params.auth.response_type;
let response_mode = resolve_response_mode(response_type, params.auth.response_mode)?; let response_mode = resolve_response_mode(&response_type, params.auth.response_mode)?;
// Now we have a proper callback destination to go to on error // Now we have a proper callback destination to go to on error
let callback_destination = CallbackDestination::try_new( let callback_destination = CallbackDestination::try_new(

View File

@ -66,9 +66,9 @@ pub(crate) async fn get(
let scopes_supported = Some(vec![scope::OPENID.to_string(), scope::EMAIL.to_string()]); let scopes_supported = Some(vec![scope::OPENID.to_string(), scope::EMAIL.to_string()]);
let response_types_supported = Some(vec![ let response_types_supported = Some(vec![
OAuthAuthorizationEndpointResponseType::Code, OAuthAuthorizationEndpointResponseType::Code.into(),
OAuthAuthorizationEndpointResponseType::IdToken, OAuthAuthorizationEndpointResponseType::IdToken.into(),
OAuthAuthorizationEndpointResponseType::CodeIdToken, OAuthAuthorizationEndpointResponseType::CodeIdToken.into(),
]); ]);
let response_modes_supported = Some(vec![ let response_modes_supported = Some(vec![

View File

@ -138,7 +138,7 @@ pub(crate) async fn post(
&client_id, &client_id,
metadata.redirect_uris(), metadata.redirect_uris(),
None, None,
metadata.response_types(), &metadata.response_types(),
metadata.grant_types(), metadata.grant_types(),
contacts, contacts,
metadata metadata

View File

@ -17,47 +17,17 @@
#![warn(clippy::pedantic)] #![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)] #![allow(clippy::module_name_repetitions)]
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
pub trait ResponseTypeExt {
fn has_code(&self) -> bool;
fn has_token(&self) -> bool;
fn has_id_token(&self) -> bool;
}
impl ResponseTypeExt for OAuthAuthorizationEndpointResponseType {
fn has_code(&self) -> bool {
matches!(
self,
Self::Code | Self::CodeToken | Self::CodeIdToken | Self::CodeIdTokenToken
)
}
fn has_token(&self) -> bool {
matches!(
self,
Self::Token | Self::CodeToken | Self::IdTokenToken | Self::CodeIdTokenToken
)
}
fn has_id_token(&self) -> bool {
matches!(
self,
Self::IdToken | Self::IdTokenToken | Self::CodeIdToken | Self::CodeIdTokenToken
)
}
}
pub mod errors; pub mod errors;
pub mod oidc; pub mod oidc;
pub mod pkce; pub mod pkce;
pub mod registration; pub mod registration;
pub mod requests; pub mod requests;
pub mod response_type;
pub mod scope; pub mod scope;
pub mod webfinger; pub mod webfinger;
pub mod prelude { pub mod prelude {
pub use crate::{pkce::CodeChallengeMethodExt, ResponseTypeExt}; pub use crate::pkce::CodeChallengeMethodExt;
} }
#[cfg(test)] #[cfg(test)]

View File

@ -17,10 +17,7 @@ use std::ops::Deref;
use language_tags::LanguageTag; use language_tags::LanguageTag;
use mas_iana::{ use mas_iana::{
jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg}, jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg},
oauth::{ oauth::{OAuthClientAuthenticationMethod, PkceCodeChallengeMethod},
OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod,
PkceCodeChallengeMethod,
},
}; };
use parse_display::{Display, FromStr}; use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -28,7 +25,10 @@ use serde_with::{skip_serializing_none, DeserializeFromStr, SerializeDisplay};
use thiserror::Error; use thiserror::Error;
use url::Url; use url::Url;
use crate::requests::{Display, GrantType, Prompt, ResponseMode}; use crate::{
requests::{Display, GrantType, Prompt, ResponseMode},
response_type::ResponseType,
};
#[derive( #[derive(
SerializeDisplay, DeserializeFromStr, Clone, Copy, PartialEq, Eq, Hash, Debug, Display, FromStr, SerializeDisplay, DeserializeFromStr, Clone, Copy, PartialEq, Eq, Hash, Debug, Display, FromStr,
@ -128,7 +128,7 @@ pub struct ProviderMetadata {
/// This field is required. /// This field is required.
/// ///
/// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9 /// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9
pub response_types_supported: Option<Vec<OAuthAuthorizationEndpointResponseType>>, pub response_types_supported: Option<Vec<ResponseType>>,
/// JSON array containing a list of the [OAuth 2.0 `response_mode` values] /// JSON array containing a list of the [OAuth 2.0 `response_mode` values]
/// that this authorization server supports. /// that this authorization server supports.
@ -707,7 +707,7 @@ impl VerifiedProviderMetadata {
/// JSON array containing a list of the OAuth 2.0 `response_type` values /// JSON array containing a list of the OAuth 2.0 `response_type` values
/// that this authorization server supports. /// that this authorization server supports.
#[must_use] #[must_use]
pub fn response_types_supported(&self) -> &[OAuthAuthorizationEndpointResponseType] { pub fn response_types_supported(&self) -> &[ResponseType] {
match &self.response_types_supported { match &self.response_types_supported {
Some(u) => u, Some(u) => u,
None => unreachable!(), None => unreachable!(),
@ -934,7 +934,9 @@ mod tests {
authorization_endpoint: Some(Url::parse("https://localhost/auth").unwrap()), authorization_endpoint: Some(Url::parse("https://localhost/auth").unwrap()),
token_endpoint: Some(Url::parse("https://localhost/token").unwrap()), token_endpoint: Some(Url::parse("https://localhost/token").unwrap()),
jwks_uri: Some(Url::parse("https://localhost/jwks").unwrap()), jwks_uri: Some(Url::parse("https://localhost/jwks").unwrap()),
response_types_supported: Some(vec![OAuthAuthorizationEndpointResponseType::Code]), response_types_supported: Some(vec![
OAuthAuthorizationEndpointResponseType::Code.into()
]),
subject_types_supported: Some(vec![SubjectType::Public]), subject_types_supported: Some(vec![SubjectType::Public]),
id_token_signing_alg_values_supported: Some(vec![JsonWebSignatureAlg::Rs256]), id_token_signing_alg_values_supported: Some(vec![JsonWebSignatureAlg::Rs256]),
..Default::default() ..Default::default()
@ -1158,7 +1160,7 @@ mod tests {
// Ok - Present // Ok - Present
metadata.response_types_supported = metadata.response_types_supported =
Some(vec![OAuthAuthorizationEndpointResponseType::Code]); Some(vec![OAuthAuthorizationEndpointResponseType::Code.into()]);
metadata.validate(&issuer).unwrap(); metadata.validate(&issuer).unwrap();
} }

View File

@ -18,7 +18,7 @@ use chrono::Duration;
use language_tags::LanguageTag; use language_tags::LanguageTag;
use mas_iana::{ use mas_iana::{
jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg}, jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg},
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, oauth::OAuthClientAuthenticationMethod,
}; };
use mas_jose::jwk::PublicJsonWebKeySet; use mas_jose::jwk::PublicJsonWebKeySet;
use serde::{ use serde::{
@ -34,6 +34,7 @@ use super::{ClientMetadata, Localized, VerifiedClientMetadata};
use crate::{ use crate::{
oidc::{ApplicationType, SubjectType}, oidc::{ApplicationType, SubjectType},
requests::GrantType, requests::GrantType,
response_type::ResponseType,
}; };
impl<T> Localized<T> { impl<T> Localized<T> {
@ -94,7 +95,7 @@ impl<T> Localized<T> {
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct ClientMetadataSerdeHelper { pub struct ClientMetadataSerdeHelper {
redirect_uris: Option<Vec<Url>>, redirect_uris: Option<Vec<Url>>,
response_types: Option<Vec<OAuthAuthorizationEndpointResponseType>>, response_types: Option<Vec<ResponseType>>,
grant_types: Option<Vec<GrantType>>, grant_types: Option<Vec<GrantType>>,
application_type: Option<ApplicationType>, application_type: Option<ApplicationType>,
contacts: Option<Vec<String>>, contacts: Option<Vec<String>>,

View File

@ -29,13 +29,14 @@ use url::Url;
use crate::{ use crate::{
oidc::{ApplicationType, SubjectType}, oidc::{ApplicationType, SubjectType},
requests::GrantType, requests::GrantType,
response_type::ResponseType,
}; };
mod client_metadata_serde; mod client_metadata_serde;
use client_metadata_serde::ClientMetadataSerdeHelper; use client_metadata_serde::ClientMetadataSerdeHelper;
pub const DEFAULT_RESPONSE_TYPES: &[OAuthAuthorizationEndpointResponseType] = pub const DEFAULT_RESPONSE_TYPES: [OAuthAuthorizationEndpointResponseType; 1] =
&[OAuthAuthorizationEndpointResponseType::Code]; [OAuthAuthorizationEndpointResponseType::Code];
pub const DEFAULT_GRANT_TYPES: &[GrantType] = &[GrantType::AuthorizationCode]; pub const DEFAULT_GRANT_TYPES: &[GrantType] = &[GrantType::AuthorizationCode];
@ -134,7 +135,7 @@ pub struct ClientMetadata {
/// ///
/// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9 /// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9
/// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1 /// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1
pub response_types: Option<Vec<OAuthAuthorizationEndpointResponseType>>, pub response_types: Option<Vec<ResponseType>>,
/// Array of [OAuth 2.0 `grant_type` values] that the client can use at the /// Array of [OAuth 2.0 `grant_type` values] that the client can use at the
/// [token endpoint]. /// [token endpoint].
@ -431,21 +432,18 @@ impl ClientMetadata {
let has_authorization_code = grant_types.contains(&GrantType::AuthorizationCode); let has_authorization_code = grant_types.contains(&GrantType::AuthorizationCode);
let has_both = has_implicit && has_authorization_code; let has_both = has_implicit && has_authorization_code;
for response_type in response_types { for response_type in &response_types {
let is_ok = match response_type { let has_code = response_type.has_code();
OAuthAuthorizationEndpointResponseType::Code => has_authorization_code, let has_id_token = response_type.has_id_token();
OAuthAuthorizationEndpointResponseType::CodeIdToken let has_token = response_type.has_token();
| OAuthAuthorizationEndpointResponseType::CodeIdTokenToken let is_ok = has_code && has_both
| OAuthAuthorizationEndpointResponseType::CodeToken => has_both, || !has_code && has_implicit
OAuthAuthorizationEndpointResponseType::IdToken || has_authorization_code && !has_id_token && !has_token
| OAuthAuthorizationEndpointResponseType::IdTokenToken || !has_code && !has_id_token && !has_token;
| OAuthAuthorizationEndpointResponseType::Token => has_implicit,
OAuthAuthorizationEndpointResponseType::None => true,
};
if !is_ok { if !is_ok {
return Err(ClientMetadataVerificationError::IncoherentResponseType( return Err(ClientMetadataVerificationError::IncoherentResponseType(
*response_type, response_type.clone(),
)); ));
} }
} }
@ -489,11 +487,7 @@ impl ClientMetadata {
} }
if self.id_token_signed_response_alg() == JsonWebSignatureAlg::None if self.id_token_signed_response_alg() == JsonWebSignatureAlg::None
&& (response_types.contains(&OAuthAuthorizationEndpointResponseType::CodeIdToken) && response_types.iter().any(ResponseType::has_id_token)
|| response_types
.contains(&OAuthAuthorizationEndpointResponseType::CodeIdTokenToken)
|| response_types.contains(&OAuthAuthorizationEndpointResponseType::IdToken)
|| response_types.contains(&OAuthAuthorizationEndpointResponseType::IdTokenToken))
{ {
return Err(ClientMetadataVerificationError::IdTokenSigningAlgNone); return Err(ClientMetadataVerificationError::IdTokenSigningAlgNone);
} }
@ -547,10 +541,10 @@ impl ClientMetadata {
/// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9 /// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9
/// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1 /// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1
#[must_use] #[must_use]
pub fn response_types(&self) -> &[OAuthAuthorizationEndpointResponseType] { pub fn response_types(&self) -> Vec<ResponseType> {
self.response_types self.response_types
.as_deref() .clone()
.unwrap_or(DEFAULT_RESPONSE_TYPES) .unwrap_or_else(|| DEFAULT_RESPONSE_TYPES.map(ResponseType::from).into())
} }
/// Array of [OAuth 2.0 `grant_type` values] that the client can use at the /// Array of [OAuth 2.0 `grant_type` values] that the client can use at the
@ -801,7 +795,7 @@ pub enum ClientMetadataVerificationError {
/// The given response type is not compatible with the grant types. /// The given response type is not compatible with the grant types.
#[error("'{0}' response type not compatible with grant types")] #[error("'{0}' response type not compatible with grant types")]
IncoherentResponseType(OAuthAuthorizationEndpointResponseType), IncoherentResponseType(ResponseType),
/// Both the `jwks_uri` and `jwks` fields are present but only one is /// Both the `jwks_uri` and `jwks` fields are present but only one is
/// allowed. /// allowed.
@ -865,7 +859,7 @@ mod tests {
use url::Url; use url::Url;
use super::{ClientMetadata, ClientMetadataVerificationError}; use super::{ClientMetadata, ClientMetadataVerificationError};
use crate::requests::GrantType; use crate::{requests::GrantType, response_type::ResponseType};
fn valid_client_metadata() -> ClientMetadata { fn valid_client_metadata() -> ClientMetadata {
ClientMetadata { ClientMetadata {
@ -934,173 +928,192 @@ mod tests {
// grant_type = authorization_code // grant_type = authorization_code
// code - Ok // code - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code]); metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code.into()]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// code id_token - Err // code id_token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdToken; let response_type: ResponseType =
metadata.response_types = Some(vec![response_type]); OAuthAuthorizationEndpointResponseType::CodeIdToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// code id_token token - Err // code id_token token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdTokenToken; let response_type: ResponseType =
metadata.response_types = Some(vec![response_type]); OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// code token - Err // code token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeToken; let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::CodeToken.into();
metadata.response_types = Some(vec![response_type]); metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// id_token - Err // id_token - Err
let response_type = OAuthAuthorizationEndpointResponseType::IdToken; let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::IdToken.into();
metadata.response_types = Some(vec![response_type]); metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// id_token token - Err // id_token token - Err
let response_type = OAuthAuthorizationEndpointResponseType::IdTokenToken; let response_type: ResponseType =
metadata.response_types = Some(vec![response_type]); OAuthAuthorizationEndpointResponseType::IdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// token - Err // token - Err
let response_type = OAuthAuthorizationEndpointResponseType::IdTokenToken; let response_type: ResponseType =
metadata.response_types = Some(vec![response_type]); OAuthAuthorizationEndpointResponseType::IdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// none - Ok // none - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]); metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// grant_type = implicit // grant_type = implicit
metadata.grant_types = Some(vec![GrantType::Implicit]); metadata.grant_types = Some(vec![GrantType::Implicit]);
// code - Err // code - Err
let response_type = OAuthAuthorizationEndpointResponseType::Code; let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::Code.into();
metadata.response_types = Some(vec![response_type]); metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// code id_token - Err // code id_token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdToken; let response_type: ResponseType =
metadata.response_types = Some(vec![response_type]); OAuthAuthorizationEndpointResponseType::CodeIdToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// code id_token token - Err // code id_token token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdTokenToken; let response_type: ResponseType =
metadata.response_types = Some(vec![response_type]); OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// code token - Err // code token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeToken; let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::CodeToken.into();
metadata.response_types = Some(vec![response_type]); metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// id_token - Ok // id_token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdToken]); metadata.response_types =
Some(vec![OAuthAuthorizationEndpointResponseType::IdToken.into()]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// id_token token - Ok // id_token token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdTokenToken]); metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::IdTokenToken.into()
]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// token - Ok // token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token]); metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token.into()]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// none - Ok // none - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]); metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// grant_types = [authorization_code, implicit] // grant_types = [authorization_code, implicit]
metadata.grant_types = Some(vec![GrantType::AuthorizationCode, GrantType::Implicit]); metadata.grant_types = Some(vec![GrantType::AuthorizationCode, GrantType::Implicit]);
// code - Ok // code - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code]); metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code.into()]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// code id_token - Ok // code id_token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::CodeIdToken]); metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::CodeIdToken.into()
]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// code id_token token - Ok // code id_token token - Ok
metadata.response_types = Some(vec![ metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken, OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into(),
]); ]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// code token - Ok // code token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::CodeToken]); metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::CodeToken.into()
]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// id_token - Ok // id_token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdToken]); metadata.response_types =
Some(vec![OAuthAuthorizationEndpointResponseType::IdToken.into()]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// id_token token - Ok // id_token token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdTokenToken]); metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::IdTokenToken.into()
]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// token - Ok // token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token]); metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token.into()]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// none - Ok // none - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]); metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]);
metadata.clone().validate().unwrap(); metadata.clone().validate().unwrap();
// other grant_types // other grant_types
metadata.grant_types = Some(vec![GrantType::RefreshToken, GrantType::ClientCredentials]); metadata.grant_types = Some(vec![GrantType::RefreshToken, GrantType::ClientCredentials]);
// code - Err // code - Err
let response_type = OAuthAuthorizationEndpointResponseType::Code; let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::Code.into();
metadata.response_types = Some(vec![response_type]); metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// code id_token - Err // code id_token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdToken; let response_type: ResponseType =
metadata.response_types = Some(vec![response_type]); OAuthAuthorizationEndpointResponseType::CodeIdToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// code id_token token - Err // code id_token token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdTokenToken; let response_type: ResponseType =
metadata.response_types = Some(vec![response_type]); OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// code token - Err // code token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeToken; let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::CodeToken.into();
metadata.response_types = Some(vec![response_type]); metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// id_token - Err // id_token - Err
let response_type = OAuthAuthorizationEndpointResponseType::IdToken; let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::IdToken.into();
metadata.response_types = Some(vec![response_type]); metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// id_token token - Err // id_token token - Err
let response_type = OAuthAuthorizationEndpointResponseType::IdTokenToken; let response_type: ResponseType =
metadata.response_types = Some(vec![response_type]); OAuthAuthorizationEndpointResponseType::IdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// token - Err // token - Err
let response_type = OAuthAuthorizationEndpointResponseType::Token; let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::Token.into();
metadata.response_types = Some(vec![response_type]); metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res); let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type); assert_eq!(res, response_type);
// none - Ok // none - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]); metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]);
metadata.validate().unwrap(); metadata.validate().unwrap();
} }
@ -1206,7 +1219,9 @@ mod tests {
metadata.grant_types = Some(vec![GrantType::AuthorizationCode, GrantType::Implicit]); metadata.grant_types = Some(vec![GrantType::AuthorizationCode, GrantType::Implicit]);
// Err - code id_token // Err - code id_token
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::CodeIdToken]); metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::CodeIdToken.into()
]);
assert_matches!( assert_matches!(
metadata.clone().validate(), metadata.clone().validate(),
Err(ClientMetadataVerificationError::IdTokenSigningAlgNone) Err(ClientMetadataVerificationError::IdTokenSigningAlgNone)
@ -1214,7 +1229,7 @@ mod tests {
// Err - code id_token token // Err - code id_token token
metadata.response_types = Some(vec![ metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken, OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into(),
]); ]);
assert_matches!( assert_matches!(
metadata.clone().validate(), metadata.clone().validate(),
@ -1222,14 +1237,17 @@ mod tests {
); );
// Err - id_token // Err - id_token
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdToken]); metadata.response_types =
Some(vec![OAuthAuthorizationEndpointResponseType::IdToken.into()]);
assert_matches!( assert_matches!(
metadata.clone().validate(), metadata.clone().validate(),
Err(ClientMetadataVerificationError::IdTokenSigningAlgNone) Err(ClientMetadataVerificationError::IdTokenSigningAlgNone)
); );
// Err - id_token token // Err - id_token token
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdTokenToken]); metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::IdTokenToken.into()
]);
assert_matches!( assert_matches!(
metadata.clone().validate(), metadata.clone().validate(),
Err(ClientMetadataVerificationError::IdTokenSigningAlgNone) Err(ClientMetadataVerificationError::IdTokenSigningAlgNone)
@ -1237,10 +1255,10 @@ mod tests {
// Ok - Other response types // Ok - Other response types
metadata.response_types = Some(vec![ metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::Code, OAuthAuthorizationEndpointResponseType::Code.into(),
OAuthAuthorizationEndpointResponseType::CodeToken, OAuthAuthorizationEndpointResponseType::CodeToken.into(),
OAuthAuthorizationEndpointResponseType::Token, OAuthAuthorizationEndpointResponseType::Token.into(),
OAuthAuthorizationEndpointResponseType::None, OAuthAuthorizationEndpointResponseType::None.into(),
]); ]);
metadata.validate().unwrap(); metadata.validate().unwrap();
} }

View File

@ -16,9 +16,7 @@ use std::{collections::HashSet, hash::Hash, num::NonZeroU32};
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use language_tags::LanguageTag; use language_tags::LanguageTag;
use mas_iana::oauth::{ use mas_iana::oauth::{OAuthAccessTokenType, OAuthTokenTypeHint};
OAuthAccessTokenType, OAuthAuthorizationEndpointResponseType, OAuthTokenTypeHint,
};
use parse_display::{Display, FromStr}; use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{ use serde_with::{
@ -27,7 +25,7 @@ use serde_with::{
}; };
use url::Url; use url::Url;
use crate::scope::Scope; use crate::{response_type::ResponseType, scope::Scope};
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml // ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
@ -170,7 +168,7 @@ pub enum Prompt {
pub struct AuthorizationRequest { pub struct AuthorizationRequest {
/// OAuth 2.0 Response Type value that determines the authorization /// OAuth 2.0 Response Type value that determines the authorization
/// processing flow to be used. /// processing flow to be used.
pub response_type: OAuthAuthorizationEndpointResponseType, pub response_type: ResponseType,
/// OAuth 2.0 Client Identifier valid at the Authorization Server. /// OAuth 2.0 Client Identifier valid at the Authorization Server.
pub client_id: String, pub client_id: String,
@ -264,11 +262,7 @@ pub struct AuthorizationRequest {
impl AuthorizationRequest { impl AuthorizationRequest {
/// Creates a basic `AuthorizationRequest`. /// Creates a basic `AuthorizationRequest`.
#[must_use] #[must_use]
pub fn new( pub fn new(response_type: ResponseType, client_id: String, scope: Scope) -> Self {
response_type: OAuthAuthorizationEndpointResponseType,
client_id: String,
scope: Scope,
) -> Self {
Self { Self {
response_type, response_type,
client_id, client_id,

View File

@ -0,0 +1,489 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::module_name_repetitions)]
use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr};
use itertools::Itertools;
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
use parse_display::{Display, FromStr};
use serde_with::{DeserializeFromStr, SerializeDisplay};
use thiserror::Error;
/// An error encountered when trying to parse an invalid [`ResponseType`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
#[error("invalid response type")]
pub struct InvalidResponseType;
/// The accepted tokens in a [`ResponseType`].
///
/// `none` is not in this enum because it is represented by an empty
/// [`ResponseType`].
///
/// This type also accepts unknown tokens that can be constructed via it's
/// `FromStr` implementation or used via its `Display` implementation.
#[derive(
Debug,
Clone,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Display,
FromStr,
SerializeDisplay,
DeserializeFromStr,
)]
#[display(style = "snake_case")]
#[non_exhaustive]
pub enum ResponseTypeToken {
/// `code`
Code,
/// `id_token`
IdToken,
/// `token`
Token,
/// Unknown token.
#[display("{0}")]
Unknown(String),
}
/// An [OAuth 2.0 `response_type` value] that the client can use
/// at the [authorization endpoint].
///
/// It is recommended to construct this type from an
/// [`OAuthAuthorizationEndpointResponseType`].
///
/// [OAuth 2.0 `response_type` value]: https://www.rfc-editor.org/rfc/rfc7591#page-9
/// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1
#[derive(Debug, Clone, PartialEq, Eq, SerializeDisplay, DeserializeFromStr)]
pub struct ResponseType(BTreeSet<ResponseTypeToken>);
impl std::ops::Deref for ResponseType {
type Target = BTreeSet<ResponseTypeToken>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl ResponseType {
/// Whether this response type requests a code.
#[must_use]
pub fn has_code(&self) -> bool {
self.0.contains(&ResponseTypeToken::Code)
}
/// Whether this response type requests an ID token.
#[must_use]
pub fn has_id_token(&self) -> bool {
self.0.contains(&ResponseTypeToken::IdToken)
}
/// Whether this response type requests a token.
#[must_use]
pub fn has_token(&self) -> bool {
self.0.contains(&ResponseTypeToken::Token)
}
}
impl FromStr for ResponseType {
type Err = InvalidResponseType;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let s = s.trim();
if s.is_empty() {
Err(InvalidResponseType)
} else if s == "none" {
Ok(Self(BTreeSet::new()))
} else {
s.split_ascii_whitespace()
.map(|t| ResponseTypeToken::from_str(t).or(Err(InvalidResponseType)))
.collect::<Result<_, _>>()
}
}
}
impl fmt::Display for ResponseType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let res = Itertools::intersperse(self.iter().map(ToString::to_string), ' '.to_string())
.collect::<String>();
if res.is_empty() {
write!(f, "none")
} else {
f.write_str(&res)
}
}
}
impl FromIterator<ResponseTypeToken> for ResponseType {
fn from_iter<T: IntoIterator<Item = ResponseTypeToken>>(iter: T) -> Self {
Self(BTreeSet::from_iter(iter))
}
}
impl From<OAuthAuthorizationEndpointResponseType> for ResponseType {
fn from(response_type: OAuthAuthorizationEndpointResponseType) -> Self {
match response_type {
OAuthAuthorizationEndpointResponseType::Code => Self([ResponseTypeToken::Code].into()),
OAuthAuthorizationEndpointResponseType::CodeIdToken => {
Self([ResponseTypeToken::Code, ResponseTypeToken::IdToken].into())
}
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken => Self(
[
ResponseTypeToken::Code,
ResponseTypeToken::IdToken,
ResponseTypeToken::Token,
]
.into(),
),
OAuthAuthorizationEndpointResponseType::CodeToken => {
Self([ResponseTypeToken::Code, ResponseTypeToken::Token].into())
}
OAuthAuthorizationEndpointResponseType::IdToken => {
Self([ResponseTypeToken::IdToken].into())
}
OAuthAuthorizationEndpointResponseType::IdTokenToken => {
Self([ResponseTypeToken::IdToken, ResponseTypeToken::Token].into())
}
OAuthAuthorizationEndpointResponseType::None => Self(BTreeSet::new()),
OAuthAuthorizationEndpointResponseType::Token => {
Self([ResponseTypeToken::Token].into())
}
}
}
}
impl TryFrom<ResponseType> for OAuthAuthorizationEndpointResponseType {
type Error = InvalidResponseType;
fn try_from(response_type: ResponseType) -> Result<Self, Self::Error> {
if response_type
.iter()
.any(|t| matches!(t, ResponseTypeToken::Unknown(_)))
{
return Err(InvalidResponseType);
}
let tokens = response_type.iter().collect::<Vec<_>>();
let res = match *tokens {
[ResponseTypeToken::Code] => OAuthAuthorizationEndpointResponseType::Code,
[ResponseTypeToken::IdToken] => OAuthAuthorizationEndpointResponseType::IdToken,
[ResponseTypeToken::Token] => OAuthAuthorizationEndpointResponseType::Token,
[ResponseTypeToken::Code, ResponseTypeToken::IdToken] => {
OAuthAuthorizationEndpointResponseType::CodeIdToken
}
[ResponseTypeToken::Code, ResponseTypeToken::Token] => {
OAuthAuthorizationEndpointResponseType::CodeToken
}
[ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
OAuthAuthorizationEndpointResponseType::IdTokenToken
}
[ResponseTypeToken::Code, ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
}
_ => OAuthAuthorizationEndpointResponseType::None,
};
Ok(res)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn deserialize_response_type_token() {
assert_eq!(
serde_json::from_str::<ResponseTypeToken>("\"code\"").unwrap(),
ResponseTypeToken::Code
);
assert_eq!(
serde_json::from_str::<ResponseTypeToken>("\"id_token\"").unwrap(),
ResponseTypeToken::IdToken
);
assert_eq!(
serde_json::from_str::<ResponseTypeToken>("\"token\"").unwrap(),
ResponseTypeToken::Token
);
assert_eq!(
serde_json::from_str::<ResponseTypeToken>("\"something_unsupported\"").unwrap(),
ResponseTypeToken::Unknown("something_unsupported".to_owned())
);
}
#[test]
fn serialize_response_type_token() {
assert_eq!(
serde_json::to_string(&ResponseTypeToken::Code).unwrap(),
"\"code\""
);
assert_eq!(
serde_json::to_string(&ResponseTypeToken::IdToken).unwrap(),
"\"id_token\""
);
assert_eq!(
serde_json::to_string(&ResponseTypeToken::Token).unwrap(),
"\"token\""
);
assert_eq!(
serde_json::to_string(&ResponseTypeToken::Unknown(
"something_unsupported".to_owned()
))
.unwrap(),
"\"something_unsupported\""
);
}
#[test]
#[allow(clippy::too_many_lines)]
fn deserialize_response_type() {
serde_json::from_str::<ResponseType>("\"\"").unwrap_err();
let res_type = serde_json::from_str::<ResponseType>("\"none\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::None
);
let res_type = serde_json::from_str::<ResponseType>("\"code\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::Code
);
let res_type = serde_json::from_str::<ResponseType>("\"code\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::Code
);
let res_type = serde_json::from_str::<ResponseType>("\"id_token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::IdToken
);
let res_type = serde_json::from_str::<ResponseType>("\"token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::Token
);
let res_type = serde_json::from_str::<ResponseType>("\"something_unsupported\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(
iter.next(),
Some(&ResponseTypeToken::Unknown(
"something_unsupported".to_owned()
))
);
assert_eq!(iter.next(), None);
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();
let res_type = serde_json::from_str::<ResponseType>("\"code id_token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::CodeIdToken
);
let res_type = serde_json::from_str::<ResponseType>("\"code token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::CodeToken
);
let res_type = serde_json::from_str::<ResponseType>("\"id_token token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::IdTokenToken
);
let res_type = serde_json::from_str::<ResponseType>("\"code id_token token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
);
let res_type =
serde_json::from_str::<ResponseType>("\"code id_token token something_unsupported\"")
.unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(
iter.next(),
Some(&ResponseTypeToken::Unknown(
"something_unsupported".to_owned()
))
);
assert_eq!(iter.next(), None);
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();
// Order doesn't matter
let res_type = serde_json::from_str::<ResponseType>("\"token code id_token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
);
let res_type =
serde_json::from_str::<ResponseType>("\"id_token token id_token code\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
);
}
#[test]
fn serialize_response_type() {
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::None
))
.unwrap(),
"\"none\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::Code
))
.unwrap(),
"\"code\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::IdToken
))
.unwrap(),
"\"id_token\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::CodeIdToken
))
.unwrap(),
"\"code id_token\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::CodeToken
))
.unwrap(),
"\"code token\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::IdTokenToken
))
.unwrap(),
"\"id_token token\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
))
.unwrap(),
"\"code id_token token\""
);
assert_eq!(
serde_json::to_string(
&[
ResponseTypeToken::Unknown("something_unsupported".to_owned()),
ResponseTypeToken::Code
]
.into_iter()
.collect::<ResponseType>()
)
.unwrap(),
"\"code something_unsupported\""
);
// Order doesn't matter.
let res = [
ResponseTypeToken::IdToken,
ResponseTypeToken::Token,
ResponseTypeToken::Code,
]
.into_iter()
.collect::<ResponseType>();
assert_eq!(
serde_json::to_string(&res).unwrap(),
"\"code id_token token\""
);
let res = [
ResponseTypeToken::Code,
ResponseTypeToken::Token,
ResponseTypeToken::IdToken,
]
.into_iter()
.collect::<ResponseType>();
assert_eq!(
serde_json::to_string(&res).unwrap(),
"\"code id_token token\""
);
}
}

View File

@ -20,7 +20,7 @@ use mas_iana::{
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod}, oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
}; };
use mas_jose::jwk::PublicJsonWebKeySet; use mas_jose::jwk::PublicJsonWebKeySet;
use oauth2_types::requests::GrantType; use oauth2_types::{requests::GrantType, response_type::ResponseType};
use sqlx::{PgConnection, PgExecutor}; use sqlx::{PgConnection, PgExecutor};
use thiserror::Error; use thiserror::Error;
use url::Url; use url::Url;
@ -322,7 +322,7 @@ pub async fn insert_client(
client_id: &str, client_id: &str,
redirect_uris: &[Url], redirect_uris: &[Url],
encrypted_client_secret: Option<&str>, encrypted_client_secret: Option<&str>,
response_types: &[OAuthAuthorizationEndpointResponseType], response_types: &[ResponseType],
grant_types: &[GrantType], grant_types: &[GrantType],
contacts: &[String], contacts: &[String],
client_name: Option<&str>, client_name: Option<&str>,