1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Add OneOrMany contains claim validator

This commit is contained in:
Kévin Commaille
2022-11-28 11:54:20 +01:00
committed by Quentin Gliech
parent a2a3b3954e
commit f71f68c926
5 changed files with 66 additions and 19 deletions

View File

@ -325,6 +325,38 @@ impl<'a, T: ?Sized> From<&'a T> for Equality<'a, T> {
} }
} }
#[derive(Debug, Clone)]
pub struct Contains<'a, T> {
value: &'a T,
}
impl<'a, T> Contains<'a, T> {
/// Creates a new `Contains` validator for the given value.
#[must_use]
pub fn new(value: &'a T) -> Self {
Self { value }
}
}
impl<'a, T> Validator<OneOrMany<T>> for Contains<'a, T>
where
T: PartialEq,
{
fn validate(&self, value: &OneOrMany<T>) -> Result<(), anyhow::Error> {
if value.contains(self.value) {
Ok(())
} else {
Err(anyhow::anyhow!("OneOrMany doesn't contain value"))
}
}
}
impl<'a, T> From<&'a T> for Contains<'a, T> {
fn from(value: &'a T) -> Self {
Self::new(value)
}
}
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] #[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
#[serde(transparent)] #[serde(transparent)]
pub struct Timestamp(#[serde(with = "chrono::serde::ts_seconds")] chrono::DateTime<chrono::Utc>); pub struct Timestamp(#[serde(with = "chrono::serde::ts_seconds")] chrono::DateTime<chrono::Utc>);
@ -381,11 +413,11 @@ impl<T> From<T> for OneOrMany<T> {
/// Claims defined in RFC7519 sec. 4.1 /// Claims defined in RFC7519 sec. 4.1
/// <https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1> /// <https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1>
mod rfc7519 { mod rfc7519 {
use super::{Claim, Equality, OneOrMany, TimeNotAfter, TimeNotBefore, Timestamp}; use super::{Claim, Contains, Equality, OneOrMany, TimeNotAfter, TimeNotBefore, Timestamp};
pub const ISS: Claim<String, Equality<str>> = Claim::new("iss"); pub const ISS: Claim<String, Equality<str>> = Claim::new("iss");
pub const SUB: Claim<String> = Claim::new("sub"); pub const SUB: Claim<String> = Claim::new("sub");
pub const AUD: Claim<OneOrMany<String>> = Claim::new("aud"); pub const AUD: Claim<OneOrMany<String>, Contains<String>> = Claim::new("aud");
pub const NBF: Claim<Timestamp, TimeNotBefore> = Claim::new("nbf"); pub const NBF: Claim<Timestamp, TimeNotBefore> = Claim::new("nbf");
pub const EXP: Claim<Timestamp, TimeNotAfter> = Claim::new("exp"); pub const EXP: Claim<Timestamp, TimeNotAfter> = Claim::new("exp");
pub const IAT: Claim<Timestamp, TimeNotBefore> = Claim::new("iat"); pub const IAT: Claim<Timestamp, TimeNotBefore> = Claim::new("iat");
@ -502,7 +534,9 @@ mod tests {
.extract_required_with_options(&mut claims, "https://foo.com") .extract_required_with_options(&mut claims, "https://foo.com")
.unwrap(); .unwrap();
let sub = SUB.extract_optional(&mut claims).unwrap(); let sub = SUB.extract_optional(&mut claims).unwrap();
let aud = AUD.extract_optional(&mut claims).unwrap(); let aud = AUD
.extract_optional_with_options(&mut claims, &"abcd-efgh".to_owned())
.unwrap();
let nbf = NBF let nbf = NBF
.extract_optional_with_options(&mut claims, &time_options) .extract_optional_with_options(&mut claims, &time_options)
.unwrap(); .unwrap();
@ -659,7 +693,7 @@ mod tests {
Err(ClaimError::InvalidClaim("sub")) Err(ClaimError::InvalidClaim("sub"))
)); ));
assert!(matches!( assert!(matches!(
AUD.extract_required(&mut claims), AUD.extract_required_with_options(&mut claims, &"abcd-efgh".to_owned()),
Err(ClaimError::InvalidClaim("aud")) Err(ClaimError::InvalidClaim("aud"))
)); ));
assert!(matches!( assert!(matches!(
@ -694,7 +728,7 @@ mod tests {
Err(ClaimError::MissingClaim("sub")) Err(ClaimError::MissingClaim("sub"))
)); ));
assert!(matches!( assert!(matches!(
AUD.extract_required(&mut claims), AUD.extract_required_with_options(&mut claims, &"abcd-efgh".to_owned()),
Err(ClaimError::MissingClaim("aud")) Err(ClaimError::MissingClaim("aud"))
)); ));
@ -703,7 +737,10 @@ mod tests {
Ok(None) Ok(None)
)); ));
assert!(matches!(SUB.extract_optional(&mut claims), Ok(None))); assert!(matches!(SUB.extract_optional(&mut claims), Ok(None)));
assert!(matches!(AUD.extract_optional(&mut claims), Ok(None))); assert!(matches!(
AUD.extract_optional_with_options(&mut claims, &"abcd-efgh".to_owned()),
Ok(None)
));
} }
#[test] #[test]
@ -722,4 +759,21 @@ mod tests {
Err(ClaimError::ValidationError { claim: "iss", .. }), Err(ClaimError::ValidationError { claim: "iss", .. }),
)); ));
} }
#[test]
fn contains_validation() {
let claims = serde_json::json!({
"aud": "abcd-efgh",
});
let mut claims: HashMap<String, serde_json::Value> =
serde_json::from_value(claims).unwrap();
AUD.extract_required_with_options(&mut claims.clone(), &"abcd-efgh".to_owned())
.unwrap();
assert!(matches!(
AUD.extract_required_with_options(&mut claims, &"wxyz".to_owned()),
Err(ClaimError::ValidationError { claim: "aud", .. }),
));
}
} }

View File

@ -602,10 +602,6 @@ pub enum JwtVerificationError {
#[error(transparent)] #[error(transparent)]
Claim(#[from] ClaimError), Claim(#[from] ClaimError),
/// The audience of the JWT is not this client.
#[error("wrong aud claim")]
WrongAudience,
/// The algorithm used for signing the JWT is not the one that was /// The algorithm used for signing the JWT is not the one that was
/// requested. /// requested.
#[error("wrong signature alg")] #[error("wrong signature alg")]

View File

@ -130,10 +130,7 @@ pub fn verify_signed_jwt<'a>(
claims::ISS.extract_required_with_options(&mut claims, issuer.as_str())?; claims::ISS.extract_required_with_options(&mut claims, issuer.as_str())?;
// Must have the proper audience. // Must have the proper audience.
let aud = claims::AUD.extract_required(&mut claims)?; claims::AUD.extract_required_with_options(&mut claims, client_id)?;
if !aud.contains(client_id) {
return Err(JwtVerificationError::WrongAudience);
}
// Must use the proper algorithm. // Must use the proper algorithm.
if header.alg() != signing_algorithm { if header.alg() != signing_algorithm {

View File

@ -154,7 +154,10 @@ async fn fail_verify_id_token_wrong_audience() {
assert_matches!( assert_matches!(
error, error,
IdTokenError::Jwt(JwtVerificationError::WrongAudience) IdTokenError::Jwt(JwtVerificationError::Claim(ClaimError::ValidationError {
claim: "aud",
..
}))
); );
} }

View File

@ -474,10 +474,7 @@ fn verify_client_jwt(
return Err("Wrong sub".into()); return Err("Wrong sub".into());
} }
let aud = claims::AUD.extract_required(claims)?; claims::AUD.extract_required_with_options(claims, token_endpoint)?;
if !aud.contains(token_endpoint) {
return Err("Wrong aud".into());
}
claims::EXP.extract_required_with_options(claims, TimeOptions::new(now()))?; claims::EXP.extract_required_with_options(claims, TimeOptions::new(now()))?;