From f71f68c926c05b10f4ae7d9c3ccce3aaa2035b1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= Date: Mon, 28 Nov 2022 11:54:20 +0100 Subject: [PATCH] Add OneOrMany contains claim validator --- crates/jose/src/claims.rs | 66 +++++++++++++++++-- crates/oidc-client/src/error.rs | 4 -- crates/oidc-client/src/requests/jose.rs | 5 +- crates/oidc-client/tests/it/requests/jose.rs | 5 +- .../tests/it/types/client_credentials.rs | 5 +- 5 files changed, 66 insertions(+), 19 deletions(-) diff --git a/crates/jose/src/claims.rs b/crates/jose/src/claims.rs index 04b364ad..d16f923f 100644 --- a/crates/jose/src/claims.rs +++ b/crates/jose/src/claims.rs @@ -325,6 +325,38 @@ impl<'a, T: ?Sized> From<&'a T> for Equality<'a, T> { } } +#[derive(Debug, Clone)] +pub struct Contains<'a, T> { + value: &'a T, +} + +impl<'a, T> Contains<'a, T> { + /// Creates a new `Contains` validator for the given value. + #[must_use] + pub fn new(value: &'a T) -> Self { + Self { value } + } +} + +impl<'a, T> Validator> for Contains<'a, T> +where + T: PartialEq, +{ + fn validate(&self, value: &OneOrMany) -> Result<(), anyhow::Error> { + if value.contains(self.value) { + Ok(()) + } else { + Err(anyhow::anyhow!("OneOrMany doesn't contain value")) + } + } +} + +impl<'a, T> From<&'a T> for Contains<'a, T> { + fn from(value: &'a T) -> Self { + Self::new(value) + } +} + #[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] #[serde(transparent)] pub struct Timestamp(#[serde(with = "chrono::serde::ts_seconds")] chrono::DateTime); @@ -381,11 +413,11 @@ impl From for OneOrMany { /// Claims defined in RFC7519 sec. 4.1 /// mod rfc7519 { - use super::{Claim, Equality, OneOrMany, TimeNotAfter, TimeNotBefore, Timestamp}; + use super::{Claim, Contains, Equality, OneOrMany, TimeNotAfter, TimeNotBefore, Timestamp}; pub const ISS: Claim> = Claim::new("iss"); pub const SUB: Claim = Claim::new("sub"); - pub const AUD: Claim> = Claim::new("aud"); + pub const AUD: Claim, Contains> = Claim::new("aud"); pub const NBF: Claim = Claim::new("nbf"); pub const EXP: Claim = Claim::new("exp"); pub const IAT: Claim = Claim::new("iat"); @@ -502,7 +534,9 @@ mod tests { .extract_required_with_options(&mut claims, "https://foo.com") .unwrap(); let sub = SUB.extract_optional(&mut claims).unwrap(); - let aud = AUD.extract_optional(&mut claims).unwrap(); + let aud = AUD + .extract_optional_with_options(&mut claims, &"abcd-efgh".to_owned()) + .unwrap(); let nbf = NBF .extract_optional_with_options(&mut claims, &time_options) .unwrap(); @@ -659,7 +693,7 @@ mod tests { Err(ClaimError::InvalidClaim("sub")) )); assert!(matches!( - AUD.extract_required(&mut claims), + AUD.extract_required_with_options(&mut claims, &"abcd-efgh".to_owned()), Err(ClaimError::InvalidClaim("aud")) )); assert!(matches!( @@ -694,7 +728,7 @@ mod tests { Err(ClaimError::MissingClaim("sub")) )); assert!(matches!( - AUD.extract_required(&mut claims), + AUD.extract_required_with_options(&mut claims, &"abcd-efgh".to_owned()), Err(ClaimError::MissingClaim("aud")) )); @@ -703,7 +737,10 @@ mod tests { Ok(None) )); assert!(matches!(SUB.extract_optional(&mut claims), Ok(None))); - assert!(matches!(AUD.extract_optional(&mut claims), Ok(None))); + assert!(matches!( + AUD.extract_optional_with_options(&mut claims, &"abcd-efgh".to_owned()), + Ok(None) + )); } #[test] @@ -722,4 +759,21 @@ mod tests { Err(ClaimError::ValidationError { claim: "iss", .. }), )); } + + #[test] + fn contains_validation() { + let claims = serde_json::json!({ + "aud": "abcd-efgh", + }); + let mut claims: HashMap = + serde_json::from_value(claims).unwrap(); + + AUD.extract_required_with_options(&mut claims.clone(), &"abcd-efgh".to_owned()) + .unwrap(); + + assert!(matches!( + AUD.extract_required_with_options(&mut claims, &"wxyz".to_owned()), + Err(ClaimError::ValidationError { claim: "aud", .. }), + )); + } } diff --git a/crates/oidc-client/src/error.rs b/crates/oidc-client/src/error.rs index dae64eec..4787b9c1 100644 --- a/crates/oidc-client/src/error.rs +++ b/crates/oidc-client/src/error.rs @@ -602,10 +602,6 @@ pub enum JwtVerificationError { #[error(transparent)] Claim(#[from] ClaimError), - /// The audience of the JWT is not this client. - #[error("wrong aud claim")] - WrongAudience, - /// The algorithm used for signing the JWT is not the one that was /// requested. #[error("wrong signature alg")] diff --git a/crates/oidc-client/src/requests/jose.rs b/crates/oidc-client/src/requests/jose.rs index 0c050b06..081f5b23 100644 --- a/crates/oidc-client/src/requests/jose.rs +++ b/crates/oidc-client/src/requests/jose.rs @@ -130,10 +130,7 @@ pub fn verify_signed_jwt<'a>( claims::ISS.extract_required_with_options(&mut claims, issuer.as_str())?; // Must have the proper audience. - let aud = claims::AUD.extract_required(&mut claims)?; - if !aud.contains(client_id) { - return Err(JwtVerificationError::WrongAudience); - } + claims::AUD.extract_required_with_options(&mut claims, client_id)?; // Must use the proper algorithm. if header.alg() != signing_algorithm { diff --git a/crates/oidc-client/tests/it/requests/jose.rs b/crates/oidc-client/tests/it/requests/jose.rs index fdcb1847..f3950b83 100644 --- a/crates/oidc-client/tests/it/requests/jose.rs +++ b/crates/oidc-client/tests/it/requests/jose.rs @@ -154,7 +154,10 @@ async fn fail_verify_id_token_wrong_audience() { assert_matches!( error, - IdTokenError::Jwt(JwtVerificationError::WrongAudience) + IdTokenError::Jwt(JwtVerificationError::Claim(ClaimError::ValidationError { + claim: "aud", + .. + })) ); } diff --git a/crates/oidc-client/tests/it/types/client_credentials.rs b/crates/oidc-client/tests/it/types/client_credentials.rs index 4b4acea3..55c2cc5c 100644 --- a/crates/oidc-client/tests/it/types/client_credentials.rs +++ b/crates/oidc-client/tests/it/types/client_credentials.rs @@ -474,10 +474,7 @@ fn verify_client_jwt( return Err("Wrong sub".into()); } - let aud = claims::AUD.extract_required(claims)?; - if !aud.contains(token_endpoint) { - return Err("Wrong aud".into()); - } + claims::AUD.extract_required_with_options(claims, token_endpoint)?; claims::EXP.extract_required_with_options(claims, TimeOptions::new(now()))?;