1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-09 10:01:45 +03:00

Use ResponseType that doesn't care about tokens order

This commit is contained in:
Kévin Commaille
2022-09-09 14:52:59 +02:00
committed by Quentin Gliech
parent f5715018a6
commit fca6cfa393
10 changed files with 618 additions and 145 deletions

View File

@ -17,47 +17,17 @@
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
pub trait ResponseTypeExt {
fn has_code(&self) -> bool;
fn has_token(&self) -> bool;
fn has_id_token(&self) -> bool;
}
impl ResponseTypeExt for OAuthAuthorizationEndpointResponseType {
fn has_code(&self) -> bool {
matches!(
self,
Self::Code | Self::CodeToken | Self::CodeIdToken | Self::CodeIdTokenToken
)
}
fn has_token(&self) -> bool {
matches!(
self,
Self::Token | Self::CodeToken | Self::IdTokenToken | Self::CodeIdTokenToken
)
}
fn has_id_token(&self) -> bool {
matches!(
self,
Self::IdToken | Self::IdTokenToken | Self::CodeIdToken | Self::CodeIdTokenToken
)
}
}
pub mod errors;
pub mod oidc;
pub mod pkce;
pub mod registration;
pub mod requests;
pub mod response_type;
pub mod scope;
pub mod webfinger;
pub mod prelude {
pub use crate::{pkce::CodeChallengeMethodExt, ResponseTypeExt};
pub use crate::pkce::CodeChallengeMethodExt;
}
#[cfg(test)]

View File

@ -17,10 +17,7 @@ use std::ops::Deref;
use language_tags::LanguageTag;
use mas_iana::{
jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg},
oauth::{
OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod,
PkceCodeChallengeMethod,
},
oauth::{OAuthClientAuthenticationMethod, PkceCodeChallengeMethod},
};
use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize};
@ -28,7 +25,10 @@ use serde_with::{skip_serializing_none, DeserializeFromStr, SerializeDisplay};
use thiserror::Error;
use url::Url;
use crate::requests::{Display, GrantType, Prompt, ResponseMode};
use crate::{
requests::{Display, GrantType, Prompt, ResponseMode},
response_type::ResponseType,
};
#[derive(
SerializeDisplay, DeserializeFromStr, Clone, Copy, PartialEq, Eq, Hash, Debug, Display, FromStr,
@ -128,7 +128,7 @@ pub struct ProviderMetadata {
/// This field is required.
///
/// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9
pub response_types_supported: Option<Vec<OAuthAuthorizationEndpointResponseType>>,
pub response_types_supported: Option<Vec<ResponseType>>,
/// JSON array containing a list of the [OAuth 2.0 `response_mode` values]
/// that this authorization server supports.
@ -707,7 +707,7 @@ impl VerifiedProviderMetadata {
/// JSON array containing a list of the OAuth 2.0 `response_type` values
/// that this authorization server supports.
#[must_use]
pub fn response_types_supported(&self) -> &[OAuthAuthorizationEndpointResponseType] {
pub fn response_types_supported(&self) -> &[ResponseType] {
match &self.response_types_supported {
Some(u) => u,
None => unreachable!(),
@ -934,7 +934,9 @@ mod tests {
authorization_endpoint: Some(Url::parse("https://localhost/auth").unwrap()),
token_endpoint: Some(Url::parse("https://localhost/token").unwrap()),
jwks_uri: Some(Url::parse("https://localhost/jwks").unwrap()),
response_types_supported: Some(vec![OAuthAuthorizationEndpointResponseType::Code]),
response_types_supported: Some(vec![
OAuthAuthorizationEndpointResponseType::Code.into()
]),
subject_types_supported: Some(vec![SubjectType::Public]),
id_token_signing_alg_values_supported: Some(vec![JsonWebSignatureAlg::Rs256]),
..Default::default()
@ -1158,7 +1160,7 @@ mod tests {
// Ok - Present
metadata.response_types_supported =
Some(vec![OAuthAuthorizationEndpointResponseType::Code]);
Some(vec![OAuthAuthorizationEndpointResponseType::Code.into()]);
metadata.validate(&issuer).unwrap();
}

View File

@ -18,7 +18,7 @@ use chrono::Duration;
use language_tags::LanguageTag;
use mas_iana::{
jose::{JsonWebEncryptionAlg, JsonWebEncryptionEnc, JsonWebSignatureAlg},
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
oauth::OAuthClientAuthenticationMethod,
};
use mas_jose::jwk::PublicJsonWebKeySet;
use serde::{
@ -34,6 +34,7 @@ use super::{ClientMetadata, Localized, VerifiedClientMetadata};
use crate::{
oidc::{ApplicationType, SubjectType},
requests::GrantType,
response_type::ResponseType,
};
impl<T> Localized<T> {
@ -94,7 +95,7 @@ impl<T> Localized<T> {
#[derive(Serialize, Deserialize)]
pub struct ClientMetadataSerdeHelper {
redirect_uris: Option<Vec<Url>>,
response_types: Option<Vec<OAuthAuthorizationEndpointResponseType>>,
response_types: Option<Vec<ResponseType>>,
grant_types: Option<Vec<GrantType>>,
application_type: Option<ApplicationType>,
contacts: Option<Vec<String>>,

View File

@ -29,13 +29,14 @@ use url::Url;
use crate::{
oidc::{ApplicationType, SubjectType},
requests::GrantType,
response_type::ResponseType,
};
mod client_metadata_serde;
use client_metadata_serde::ClientMetadataSerdeHelper;
pub const DEFAULT_RESPONSE_TYPES: &[OAuthAuthorizationEndpointResponseType] =
&[OAuthAuthorizationEndpointResponseType::Code];
pub const DEFAULT_RESPONSE_TYPES: [OAuthAuthorizationEndpointResponseType; 1] =
[OAuthAuthorizationEndpointResponseType::Code];
pub const DEFAULT_GRANT_TYPES: &[GrantType] = &[GrantType::AuthorizationCode];
@ -134,7 +135,7 @@ pub struct ClientMetadata {
///
/// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9
/// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1
pub response_types: Option<Vec<OAuthAuthorizationEndpointResponseType>>,
pub response_types: Option<Vec<ResponseType>>,
/// Array of [OAuth 2.0 `grant_type` values] that the client can use at the
/// [token endpoint].
@ -431,21 +432,18 @@ impl ClientMetadata {
let has_authorization_code = grant_types.contains(&GrantType::AuthorizationCode);
let has_both = has_implicit && has_authorization_code;
for response_type in response_types {
let is_ok = match response_type {
OAuthAuthorizationEndpointResponseType::Code => has_authorization_code,
OAuthAuthorizationEndpointResponseType::CodeIdToken
| OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
| OAuthAuthorizationEndpointResponseType::CodeToken => has_both,
OAuthAuthorizationEndpointResponseType::IdToken
| OAuthAuthorizationEndpointResponseType::IdTokenToken
| OAuthAuthorizationEndpointResponseType::Token => has_implicit,
OAuthAuthorizationEndpointResponseType::None => true,
};
for response_type in &response_types {
let has_code = response_type.has_code();
let has_id_token = response_type.has_id_token();
let has_token = response_type.has_token();
let is_ok = has_code && has_both
|| !has_code && has_implicit
|| has_authorization_code && !has_id_token && !has_token
|| !has_code && !has_id_token && !has_token;
if !is_ok {
return Err(ClientMetadataVerificationError::IncoherentResponseType(
*response_type,
response_type.clone(),
));
}
}
@ -489,11 +487,7 @@ impl ClientMetadata {
}
if self.id_token_signed_response_alg() == JsonWebSignatureAlg::None
&& (response_types.contains(&OAuthAuthorizationEndpointResponseType::CodeIdToken)
|| response_types
.contains(&OAuthAuthorizationEndpointResponseType::CodeIdTokenToken)
|| response_types.contains(&OAuthAuthorizationEndpointResponseType::IdToken)
|| response_types.contains(&OAuthAuthorizationEndpointResponseType::IdTokenToken))
&& response_types.iter().any(ResponseType::has_id_token)
{
return Err(ClientMetadataVerificationError::IdTokenSigningAlgNone);
}
@ -547,10 +541,10 @@ impl ClientMetadata {
/// [OAuth 2.0 `response_type` values]: https://www.rfc-editor.org/rfc/rfc7591#page-9
/// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1
#[must_use]
pub fn response_types(&self) -> &[OAuthAuthorizationEndpointResponseType] {
pub fn response_types(&self) -> Vec<ResponseType> {
self.response_types
.as_deref()
.unwrap_or(DEFAULT_RESPONSE_TYPES)
.clone()
.unwrap_or_else(|| DEFAULT_RESPONSE_TYPES.map(ResponseType::from).into())
}
/// Array of [OAuth 2.0 `grant_type` values] that the client can use at the
@ -801,7 +795,7 @@ pub enum ClientMetadataVerificationError {
/// The given response type is not compatible with the grant types.
#[error("'{0}' response type not compatible with grant types")]
IncoherentResponseType(OAuthAuthorizationEndpointResponseType),
IncoherentResponseType(ResponseType),
/// Both the `jwks_uri` and `jwks` fields are present but only one is
/// allowed.
@ -865,7 +859,7 @@ mod tests {
use url::Url;
use super::{ClientMetadata, ClientMetadataVerificationError};
use crate::requests::GrantType;
use crate::{requests::GrantType, response_type::ResponseType};
fn valid_client_metadata() -> ClientMetadata {
ClientMetadata {
@ -934,173 +928,192 @@ mod tests {
// grant_type = authorization_code
// code - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code]);
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code.into()]);
metadata.clone().validate().unwrap();
// code id_token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType =
OAuthAuthorizationEndpointResponseType::CodeIdToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// code id_token token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdTokenToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType =
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// code token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::CodeToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// id_token - Err
let response_type = OAuthAuthorizationEndpointResponseType::IdToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::IdToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// id_token token - Err
let response_type = OAuthAuthorizationEndpointResponseType::IdTokenToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType =
OAuthAuthorizationEndpointResponseType::IdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// token - Err
let response_type = OAuthAuthorizationEndpointResponseType::IdTokenToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType =
OAuthAuthorizationEndpointResponseType::IdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// none - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]);
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]);
metadata.clone().validate().unwrap();
// grant_type = implicit
metadata.grant_types = Some(vec![GrantType::Implicit]);
// code - Err
let response_type = OAuthAuthorizationEndpointResponseType::Code;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::Code.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// code id_token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType =
OAuthAuthorizationEndpointResponseType::CodeIdToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// code id_token token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdTokenToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType =
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// code token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::CodeToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// id_token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdToken]);
metadata.response_types =
Some(vec![OAuthAuthorizationEndpointResponseType::IdToken.into()]);
metadata.clone().validate().unwrap();
// id_token token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdTokenToken]);
metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::IdTokenToken.into()
]);
metadata.clone().validate().unwrap();
// token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token]);
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token.into()]);
metadata.clone().validate().unwrap();
// none - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]);
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]);
metadata.clone().validate().unwrap();
// grant_types = [authorization_code, implicit]
metadata.grant_types = Some(vec![GrantType::AuthorizationCode, GrantType::Implicit]);
// code - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code]);
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Code.into()]);
metadata.clone().validate().unwrap();
// code id_token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::CodeIdToken]);
metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::CodeIdToken.into()
]);
metadata.clone().validate().unwrap();
// code id_token token - Ok
metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken,
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into(),
]);
metadata.clone().validate().unwrap();
// code token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::CodeToken]);
metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::CodeToken.into()
]);
metadata.clone().validate().unwrap();
// id_token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdToken]);
metadata.response_types =
Some(vec![OAuthAuthorizationEndpointResponseType::IdToken.into()]);
metadata.clone().validate().unwrap();
// id_token token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdTokenToken]);
metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::IdTokenToken.into()
]);
metadata.clone().validate().unwrap();
// token - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token]);
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::Token.into()]);
metadata.clone().validate().unwrap();
// none - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]);
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]);
metadata.clone().validate().unwrap();
// other grant_types
metadata.grant_types = Some(vec![GrantType::RefreshToken, GrantType::ClientCredentials]);
// code - Err
let response_type = OAuthAuthorizationEndpointResponseType::Code;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::Code.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// code id_token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType =
OAuthAuthorizationEndpointResponseType::CodeIdToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// code id_token token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeIdTokenToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType =
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// code token - Err
let response_type = OAuthAuthorizationEndpointResponseType::CodeToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::CodeToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// id_token - Err
let response_type = OAuthAuthorizationEndpointResponseType::IdToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::IdToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// id_token token - Err
let response_type = OAuthAuthorizationEndpointResponseType::IdTokenToken;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType =
OAuthAuthorizationEndpointResponseType::IdTokenToken.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// token - Err
let response_type = OAuthAuthorizationEndpointResponseType::Token;
metadata.response_types = Some(vec![response_type]);
let response_type: ResponseType = OAuthAuthorizationEndpointResponseType::Token.into();
metadata.response_types = Some(vec![response_type.clone()]);
let res = assert_matches!(metadata.clone().validate(), Err(ClientMetadataVerificationError::IncoherentResponseType(res)) => res);
assert_eq!(res, response_type);
// none - Ok
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None]);
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::None.into()]);
metadata.validate().unwrap();
}
@ -1206,7 +1219,9 @@ mod tests {
metadata.grant_types = Some(vec![GrantType::AuthorizationCode, GrantType::Implicit]);
// Err - code id_token
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::CodeIdToken]);
metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::CodeIdToken.into()
]);
assert_matches!(
metadata.clone().validate(),
Err(ClientMetadataVerificationError::IdTokenSigningAlgNone)
@ -1214,7 +1229,7 @@ mod tests {
// Err - code id_token token
metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken,
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.into(),
]);
assert_matches!(
metadata.clone().validate(),
@ -1222,14 +1237,17 @@ mod tests {
);
// Err - id_token
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdToken]);
metadata.response_types =
Some(vec![OAuthAuthorizationEndpointResponseType::IdToken.into()]);
assert_matches!(
metadata.clone().validate(),
Err(ClientMetadataVerificationError::IdTokenSigningAlgNone)
);
// Err - id_token token
metadata.response_types = Some(vec![OAuthAuthorizationEndpointResponseType::IdTokenToken]);
metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::IdTokenToken.into()
]);
assert_matches!(
metadata.clone().validate(),
Err(ClientMetadataVerificationError::IdTokenSigningAlgNone)
@ -1237,10 +1255,10 @@ mod tests {
// Ok - Other response types
metadata.response_types = Some(vec![
OAuthAuthorizationEndpointResponseType::Code,
OAuthAuthorizationEndpointResponseType::CodeToken,
OAuthAuthorizationEndpointResponseType::Token,
OAuthAuthorizationEndpointResponseType::None,
OAuthAuthorizationEndpointResponseType::Code.into(),
OAuthAuthorizationEndpointResponseType::CodeToken.into(),
OAuthAuthorizationEndpointResponseType::Token.into(),
OAuthAuthorizationEndpointResponseType::None.into(),
]);
metadata.validate().unwrap();
}

View File

@ -16,9 +16,7 @@ use std::{collections::HashSet, hash::Hash, num::NonZeroU32};
use chrono::{DateTime, Duration, Utc};
use language_tags::LanguageTag;
use mas_iana::oauth::{
OAuthAccessTokenType, OAuthAuthorizationEndpointResponseType, OAuthTokenTypeHint,
};
use mas_iana::oauth::{OAuthAccessTokenType, OAuthTokenTypeHint};
use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize};
use serde_with::{
@ -27,7 +25,7 @@ use serde_with::{
};
use url::Url;
use crate::scope::Scope;
use crate::{response_type::ResponseType, scope::Scope};
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
@ -170,7 +168,7 @@ pub enum Prompt {
pub struct AuthorizationRequest {
/// OAuth 2.0 Response Type value that determines the authorization
/// processing flow to be used.
pub response_type: OAuthAuthorizationEndpointResponseType,
pub response_type: ResponseType,
/// OAuth 2.0 Client Identifier valid at the Authorization Server.
pub client_id: String,
@ -264,11 +262,7 @@ pub struct AuthorizationRequest {
impl AuthorizationRequest {
/// Creates a basic `AuthorizationRequest`.
#[must_use]
pub fn new(
response_type: OAuthAuthorizationEndpointResponseType,
client_id: String,
scope: Scope,
) -> Self {
pub fn new(response_type: ResponseType, client_id: String, scope: Scope) -> Self {
Self {
response_type,
client_id,

View File

@ -0,0 +1,489 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::module_name_repetitions)]
use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr};
use itertools::Itertools;
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
use parse_display::{Display, FromStr};
use serde_with::{DeserializeFromStr, SerializeDisplay};
use thiserror::Error;
/// An error encountered when trying to parse an invalid [`ResponseType`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
#[error("invalid response type")]
pub struct InvalidResponseType;
/// The accepted tokens in a [`ResponseType`].
///
/// `none` is not in this enum because it is represented by an empty
/// [`ResponseType`].
///
/// This type also accepts unknown tokens that can be constructed via it's
/// `FromStr` implementation or used via its `Display` implementation.
#[derive(
Debug,
Clone,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Display,
FromStr,
SerializeDisplay,
DeserializeFromStr,
)]
#[display(style = "snake_case")]
#[non_exhaustive]
pub enum ResponseTypeToken {
/// `code`
Code,
/// `id_token`
IdToken,
/// `token`
Token,
/// Unknown token.
#[display("{0}")]
Unknown(String),
}
/// An [OAuth 2.0 `response_type` value] that the client can use
/// at the [authorization endpoint].
///
/// It is recommended to construct this type from an
/// [`OAuthAuthorizationEndpointResponseType`].
///
/// [OAuth 2.0 `response_type` value]: https://www.rfc-editor.org/rfc/rfc7591#page-9
/// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1
#[derive(Debug, Clone, PartialEq, Eq, SerializeDisplay, DeserializeFromStr)]
pub struct ResponseType(BTreeSet<ResponseTypeToken>);
impl std::ops::Deref for ResponseType {
type Target = BTreeSet<ResponseTypeToken>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl ResponseType {
/// Whether this response type requests a code.
#[must_use]
pub fn has_code(&self) -> bool {
self.0.contains(&ResponseTypeToken::Code)
}
/// Whether this response type requests an ID token.
#[must_use]
pub fn has_id_token(&self) -> bool {
self.0.contains(&ResponseTypeToken::IdToken)
}
/// Whether this response type requests a token.
#[must_use]
pub fn has_token(&self) -> bool {
self.0.contains(&ResponseTypeToken::Token)
}
}
impl FromStr for ResponseType {
type Err = InvalidResponseType;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let s = s.trim();
if s.is_empty() {
Err(InvalidResponseType)
} else if s == "none" {
Ok(Self(BTreeSet::new()))
} else {
s.split_ascii_whitespace()
.map(|t| ResponseTypeToken::from_str(t).or(Err(InvalidResponseType)))
.collect::<Result<_, _>>()
}
}
}
impl fmt::Display for ResponseType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let res = Itertools::intersperse(self.iter().map(ToString::to_string), ' '.to_string())
.collect::<String>();
if res.is_empty() {
write!(f, "none")
} else {
f.write_str(&res)
}
}
}
impl FromIterator<ResponseTypeToken> for ResponseType {
fn from_iter<T: IntoIterator<Item = ResponseTypeToken>>(iter: T) -> Self {
Self(BTreeSet::from_iter(iter))
}
}
impl From<OAuthAuthorizationEndpointResponseType> for ResponseType {
fn from(response_type: OAuthAuthorizationEndpointResponseType) -> Self {
match response_type {
OAuthAuthorizationEndpointResponseType::Code => Self([ResponseTypeToken::Code].into()),
OAuthAuthorizationEndpointResponseType::CodeIdToken => {
Self([ResponseTypeToken::Code, ResponseTypeToken::IdToken].into())
}
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken => Self(
[
ResponseTypeToken::Code,
ResponseTypeToken::IdToken,
ResponseTypeToken::Token,
]
.into(),
),
OAuthAuthorizationEndpointResponseType::CodeToken => {
Self([ResponseTypeToken::Code, ResponseTypeToken::Token].into())
}
OAuthAuthorizationEndpointResponseType::IdToken => {
Self([ResponseTypeToken::IdToken].into())
}
OAuthAuthorizationEndpointResponseType::IdTokenToken => {
Self([ResponseTypeToken::IdToken, ResponseTypeToken::Token].into())
}
OAuthAuthorizationEndpointResponseType::None => Self(BTreeSet::new()),
OAuthAuthorizationEndpointResponseType::Token => {
Self([ResponseTypeToken::Token].into())
}
}
}
}
impl TryFrom<ResponseType> for OAuthAuthorizationEndpointResponseType {
type Error = InvalidResponseType;
fn try_from(response_type: ResponseType) -> Result<Self, Self::Error> {
if response_type
.iter()
.any(|t| matches!(t, ResponseTypeToken::Unknown(_)))
{
return Err(InvalidResponseType);
}
let tokens = response_type.iter().collect::<Vec<_>>();
let res = match *tokens {
[ResponseTypeToken::Code] => OAuthAuthorizationEndpointResponseType::Code,
[ResponseTypeToken::IdToken] => OAuthAuthorizationEndpointResponseType::IdToken,
[ResponseTypeToken::Token] => OAuthAuthorizationEndpointResponseType::Token,
[ResponseTypeToken::Code, ResponseTypeToken::IdToken] => {
OAuthAuthorizationEndpointResponseType::CodeIdToken
}
[ResponseTypeToken::Code, ResponseTypeToken::Token] => {
OAuthAuthorizationEndpointResponseType::CodeToken
}
[ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
OAuthAuthorizationEndpointResponseType::IdTokenToken
}
[ResponseTypeToken::Code, ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
}
_ => OAuthAuthorizationEndpointResponseType::None,
};
Ok(res)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn deserialize_response_type_token() {
assert_eq!(
serde_json::from_str::<ResponseTypeToken>("\"code\"").unwrap(),
ResponseTypeToken::Code
);
assert_eq!(
serde_json::from_str::<ResponseTypeToken>("\"id_token\"").unwrap(),
ResponseTypeToken::IdToken
);
assert_eq!(
serde_json::from_str::<ResponseTypeToken>("\"token\"").unwrap(),
ResponseTypeToken::Token
);
assert_eq!(
serde_json::from_str::<ResponseTypeToken>("\"something_unsupported\"").unwrap(),
ResponseTypeToken::Unknown("something_unsupported".to_owned())
);
}
#[test]
fn serialize_response_type_token() {
assert_eq!(
serde_json::to_string(&ResponseTypeToken::Code).unwrap(),
"\"code\""
);
assert_eq!(
serde_json::to_string(&ResponseTypeToken::IdToken).unwrap(),
"\"id_token\""
);
assert_eq!(
serde_json::to_string(&ResponseTypeToken::Token).unwrap(),
"\"token\""
);
assert_eq!(
serde_json::to_string(&ResponseTypeToken::Unknown(
"something_unsupported".to_owned()
))
.unwrap(),
"\"something_unsupported\""
);
}
#[test]
#[allow(clippy::too_many_lines)]
fn deserialize_response_type() {
serde_json::from_str::<ResponseType>("\"\"").unwrap_err();
let res_type = serde_json::from_str::<ResponseType>("\"none\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::None
);
let res_type = serde_json::from_str::<ResponseType>("\"code\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::Code
);
let res_type = serde_json::from_str::<ResponseType>("\"code\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::Code
);
let res_type = serde_json::from_str::<ResponseType>("\"id_token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::IdToken
);
let res_type = serde_json::from_str::<ResponseType>("\"token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::Token
);
let res_type = serde_json::from_str::<ResponseType>("\"something_unsupported\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(
iter.next(),
Some(&ResponseTypeToken::Unknown(
"something_unsupported".to_owned()
))
);
assert_eq!(iter.next(), None);
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();
let res_type = serde_json::from_str::<ResponseType>("\"code id_token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::CodeIdToken
);
let res_type = serde_json::from_str::<ResponseType>("\"code token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::CodeToken
);
let res_type = serde_json::from_str::<ResponseType>("\"id_token token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::IdTokenToken
);
let res_type = serde_json::from_str::<ResponseType>("\"code id_token token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
);
let res_type =
serde_json::from_str::<ResponseType>("\"code id_token token something_unsupported\"")
.unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(
iter.next(),
Some(&ResponseTypeToken::Unknown(
"something_unsupported".to_owned()
))
);
assert_eq!(iter.next(), None);
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();
// Order doesn't matter
let res_type = serde_json::from_str::<ResponseType>("\"token code id_token\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
);
let res_type =
serde_json::from_str::<ResponseType>("\"id_token token id_token code\"").unwrap();
let mut iter = res_type.iter();
assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
assert_eq!(iter.next(), None);
assert_eq!(
OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
);
}
#[test]
fn serialize_response_type() {
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::None
))
.unwrap(),
"\"none\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::Code
))
.unwrap(),
"\"code\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::IdToken
))
.unwrap(),
"\"id_token\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::CodeIdToken
))
.unwrap(),
"\"code id_token\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::CodeToken
))
.unwrap(),
"\"code token\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::IdTokenToken
))
.unwrap(),
"\"id_token token\""
);
assert_eq!(
serde_json::to_string(&ResponseType::from(
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
))
.unwrap(),
"\"code id_token token\""
);
assert_eq!(
serde_json::to_string(
&[
ResponseTypeToken::Unknown("something_unsupported".to_owned()),
ResponseTypeToken::Code
]
.into_iter()
.collect::<ResponseType>()
)
.unwrap(),
"\"code something_unsupported\""
);
// Order doesn't matter.
let res = [
ResponseTypeToken::IdToken,
ResponseTypeToken::Token,
ResponseTypeToken::Code,
]
.into_iter()
.collect::<ResponseType>();
assert_eq!(
serde_json::to_string(&res).unwrap(),
"\"code id_token token\""
);
let res = [
ResponseTypeToken::Code,
ResponseTypeToken::Token,
ResponseTypeToken::IdToken,
]
.into_iter()
.collect::<ResponseType>();
assert_eq!(
serde_json::to_string(&res).unwrap(),
"\"code id_token token\""
);
}
}