diff --git a/crates/jose/src/claims.rs b/crates/jose/src/claims.rs index 4e7edff4..04b364ad 100644 --- a/crates/jose/src/claims.rs +++ b/crates/jose/src/claims.rs @@ -293,6 +293,38 @@ impl<'a> Validator for TokenHash<'a> { } } +#[derive(Debug, Clone)] +pub struct Equality<'a, T: ?Sized> { + value: &'a T, +} + +impl<'a, T: ?Sized> Equality<'a, T> { + /// Creates a new `Equality` validator for the given value. + #[must_use] + pub fn new(value: &'a T) -> Self { + Self { value } + } +} + +impl<'a, T1, T2: ?Sized> Validator for Equality<'a, T2> +where + T2: PartialEq, +{ + fn validate(&self, value: &T1) -> Result<(), anyhow::Error> { + if *self.value == *value { + Ok(()) + } else { + Err(anyhow::anyhow!("values don't match")) + } + } +} + +impl<'a, T: ?Sized> From<&'a T> for Equality<'a, T> { + fn from(value: &'a T) -> Self { + Self::new(value) + } +} + #[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] #[serde(transparent)] pub struct Timestamp(#[serde(with = "chrono::serde::ts_seconds")] chrono::DateTime); @@ -349,9 +381,9 @@ impl From for OneOrMany { /// Claims defined in RFC7519 sec. 4.1 /// mod rfc7519 { - use super::{Claim, OneOrMany, TimeNotAfter, TimeNotBefore, Timestamp}; + use super::{Claim, Equality, OneOrMany, TimeNotAfter, TimeNotBefore, Timestamp}; - pub const ISS: Claim = Claim::new("iss"); + pub const ISS: Claim> = Claim::new("iss"); pub const SUB: Claim = Claim::new("sub"); pub const AUD: Claim> = Claim::new("aud"); pub const NBF: Claim = Claim::new("nbf"); @@ -366,10 +398,10 @@ mod rfc7519 { mod oidc_core { use url::Url; - use super::{Claim, Timestamp, TokenHash}; + use super::{Claim, Equality, Timestamp, TokenHash}; pub const AUTH_TIME: Claim = Claim::new("auth_time"); - pub const NONCE: Claim = Claim::new("nonce"); + pub const NONCE: Claim> = Claim::new("nonce"); pub const AT_HASH: Claim = Claim::new("at_hash"); pub const C_HASH: Claim = Claim::new("c_hash"); @@ -466,7 +498,9 @@ mod tests { }); let mut claims = serde_json::from_value(claims).unwrap(); - let iss = ISS.extract_required(&mut claims).unwrap(); + let iss = ISS + .extract_required_with_options(&mut claims, "https://foo.com") + .unwrap(); let sub = SUB.extract_optional(&mut claims).unwrap(); let aud = AUD.extract_optional(&mut claims).unwrap(); let nbf = NBF @@ -617,7 +651,7 @@ mod tests { let mut claims = serde_json::from_value(claims).unwrap(); assert!(matches!( - ISS.extract_required(&mut claims), + ISS.extract_required_with_options(&mut claims, "https://foo.com"), Err(ClaimError::InvalidClaim("iss")) )); assert!(matches!( @@ -652,7 +686,7 @@ mod tests { let mut claims = HashMap::new(); assert!(matches!( - ISS.extract_required(&mut claims), + ISS.extract_required_with_options(&mut claims, "https://foo.com"), Err(ClaimError::MissingClaim("iss")) )); assert!(matches!( @@ -664,8 +698,28 @@ mod tests { Err(ClaimError::MissingClaim("aud")) )); - assert!(matches!(ISS.extract_optional(&mut claims), Ok(None))); + assert!(matches!( + ISS.extract_optional_with_options(&mut claims, "https://foo.com"), + Ok(None) + )); assert!(matches!(SUB.extract_optional(&mut claims), Ok(None))); assert!(matches!(AUD.extract_optional(&mut claims), Ok(None))); } + + #[test] + fn string_eq_validation() { + let claims = serde_json::json!({ + "iss": "https://foo.com", + }); + let mut claims: HashMap = + serde_json::from_value(claims).unwrap(); + + ISS.extract_required_with_options(&mut claims.clone(), "https://foo.com") + .unwrap(); + + assert!(matches!( + ISS.extract_required_with_options(&mut claims, "https://bar.com"), + Err(ClaimError::ValidationError { claim: "iss", .. }), + )); + } } diff --git a/crates/oidc-client/src/error.rs b/crates/oidc-client/src/error.rs index 62dcffe0..dae64eec 100644 --- a/crates/oidc-client/src/error.rs +++ b/crates/oidc-client/src/error.rs @@ -335,10 +335,6 @@ where /// All possible errors when exchanging a code for an access token. #[derive(Debug, Error)] pub enum TokenAuthorizationCodeError { - /// The nonce doesn't match the one that was sent. - #[error("wrong nonce")] - WrongNonce, - /// An error occurred requesting the access token. #[error(transparent)] Token(#[from] TokenRequestError), @@ -606,10 +602,6 @@ pub enum JwtVerificationError { #[error(transparent)] Claim(#[from] ClaimError), - /// The issuer is not the one that sent the JWT. - #[error("wrong issuer claim")] - WrongIssuer, - /// The audience of the JWT is not this client. #[error("wrong aud claim")] WrongAudience, diff --git a/crates/oidc-client/src/requests/authorization_code.rs b/crates/oidc-client/src/requests/authorization_code.rs index 1fcde540..066a61b9 100644 --- a/crates/oidc-client/src/requests/authorization_code.rs +++ b/crates/oidc-client/src/requests/authorization_code.rs @@ -444,12 +444,9 @@ pub async fn access_token_with_authorization_code( .map_err(IdTokenError::from)?; // Nonce must match. - let token_nonce = claims::NONCE - .extract_required(&mut claims) + claims::NONCE + .extract_required_with_options(&mut claims, validation_data.nonce.as_str()) .map_err(IdTokenError::from)?; - if token_nonce != validation_data.nonce { - return Err(TokenAuthorizationCodeError::WrongNonce); - } Some(id_token.into_owned()) } else { diff --git a/crates/oidc-client/src/requests/jose.rs b/crates/oidc-client/src/requests/jose.rs index 2f671434..0c050b06 100644 --- a/crates/oidc-client/src/requests/jose.rs +++ b/crates/oidc-client/src/requests/jose.rs @@ -127,10 +127,7 @@ pub fn verify_signed_jwt<'a>( let (header, mut claims) = jwt.clone().into_parts(); // Must have the proper issuer. - let iss = claims::ISS.extract_required(&mut claims)?; - if iss != issuer.as_str() { - return Err(JwtVerificationError::WrongIssuer); - } + claims::ISS.extract_required_with_options(&mut claims, issuer.as_str())?; // Must have the proper audience. let aud = claims::AUD.extract_required(&mut claims)?; diff --git a/crates/oidc-client/tests/it/requests/authorization_code.rs b/crates/oidc-client/tests/it/requests/authorization_code.rs index 5ae8b5c8..caa1bc88 100644 --- a/crates/oidc-client/tests/it/requests/authorization_code.rs +++ b/crates/oidc-client/tests/it/requests/authorization_code.rs @@ -22,7 +22,7 @@ use chrono::Duration; use mas_iana::oauth::{ OAuthAccessTokenType, OAuthClientAuthenticationMethod, PkceCodeChallengeMethod, }; -use mas_jose::jwk::PublicJsonWebKeySet; +use mas_jose::{claims::ClaimError, jwk::PublicJsonWebKeySet}; use mas_oidc_client::{ error::{ AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError, @@ -358,7 +358,13 @@ async fn fail_access_token_with_authorization_code_wrong_nonce() { .await .unwrap_err(); - assert_matches!(error, TokenAuthorizationCodeError::WrongNonce); + assert_matches!( + error, + TokenAuthorizationCodeError::IdToken(IdTokenError::Claim(ClaimError::ValidationError { + claim: "nonce", + .. + })) + ); } #[tokio::test] diff --git a/crates/oidc-client/tests/it/requests/jose.rs b/crates/oidc-client/tests/it/requests/jose.rs index 833fb17a..fdcb1847 100644 --- a/crates/oidc-client/tests/it/requests/jose.rs +++ b/crates/oidc-client/tests/it/requests/jose.rs @@ -18,7 +18,7 @@ use assert_matches::assert_matches; use chrono::{DateTime, Duration, Utc}; use mas_iana::jose::JsonWebSignatureAlg; use mas_jose::{ - claims, + claims::{self, ClaimError}, constraints::Constrainable, jwk::PublicJsonWebKeySet, jwt::{JsonWebSignatureHeader, Jwt}, @@ -128,7 +128,13 @@ async fn fail_verify_id_token_wrong_issuer() { let error = verify_id_token(id_token.as_str(), verification_data, None, now).unwrap_err(); - assert_matches!(error, IdTokenError::Jwt(JwtVerificationError::WrongIssuer)); + assert_matches!( + error, + IdTokenError::Jwt(JwtVerificationError::Claim(ClaimError::ValidationError { + claim: "iss", + .. + })) + ); } #[tokio::test] diff --git a/crates/oidc-client/tests/it/types/client_credentials.rs b/crates/oidc-client/tests/it/types/client_credentials.rs index c82db0be..4b4acea3 100644 --- a/crates/oidc-client/tests/it/types/client_credentials.rs +++ b/crates/oidc-client/tests/it/types/client_credentials.rs @@ -467,10 +467,7 @@ fn verify_client_jwt( claims: &mut HashMap, token_endpoint: &String, ) -> Result<(), BoxError> { - let iss = claims::ISS.extract_required(claims)?; - if iss != CLIENT_ID { - return Err("Wrong iss".into()); - } + claims::ISS.extract_required_with_options(claims, CLIENT_ID)?; let sub = claims::SUB.extract_required(claims)?; if sub != CLIENT_ID {