1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Add equality claim validator

This commit is contained in:
Kévin Commaille
2022-11-28 11:39:38 +01:00
committed by Quentin Gliech
parent db25574a96
commit a2a3b3954e
7 changed files with 82 additions and 33 deletions

View File

@ -293,6 +293,38 @@ impl<'a> Validator<String> for TokenHash<'a> {
} }
} }
#[derive(Debug, Clone)]
pub struct Equality<'a, T: ?Sized> {
value: &'a T,
}
impl<'a, T: ?Sized> Equality<'a, T> {
/// Creates a new `Equality` validator for the given value.
#[must_use]
pub fn new(value: &'a T) -> Self {
Self { value }
}
}
impl<'a, T1, T2: ?Sized> Validator<T1> for Equality<'a, T2>
where
T2: PartialEq<T1>,
{
fn validate(&self, value: &T1) -> Result<(), anyhow::Error> {
if *self.value == *value {
Ok(())
} else {
Err(anyhow::anyhow!("values don't match"))
}
}
}
impl<'a, T: ?Sized> From<&'a T> for Equality<'a, T> {
fn from(value: &'a T) -> Self {
Self::new(value)
}
}
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] #[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
#[serde(transparent)] #[serde(transparent)]
pub struct Timestamp(#[serde(with = "chrono::serde::ts_seconds")] chrono::DateTime<chrono::Utc>); pub struct Timestamp(#[serde(with = "chrono::serde::ts_seconds")] chrono::DateTime<chrono::Utc>);
@ -349,9 +381,9 @@ impl<T> From<T> for OneOrMany<T> {
/// Claims defined in RFC7519 sec. 4.1 /// Claims defined in RFC7519 sec. 4.1
/// <https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1> /// <https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1>
mod rfc7519 { mod rfc7519 {
use super::{Claim, OneOrMany, TimeNotAfter, TimeNotBefore, Timestamp}; use super::{Claim, Equality, OneOrMany, TimeNotAfter, TimeNotBefore, Timestamp};
pub const ISS: Claim<String> = Claim::new("iss"); pub const ISS: Claim<String, Equality<str>> = Claim::new("iss");
pub const SUB: Claim<String> = Claim::new("sub"); pub const SUB: Claim<String> = Claim::new("sub");
pub const AUD: Claim<OneOrMany<String>> = Claim::new("aud"); pub const AUD: Claim<OneOrMany<String>> = Claim::new("aud");
pub const NBF: Claim<Timestamp, TimeNotBefore> = Claim::new("nbf"); pub const NBF: Claim<Timestamp, TimeNotBefore> = Claim::new("nbf");
@ -366,10 +398,10 @@ mod rfc7519 {
mod oidc_core { mod oidc_core {
use url::Url; use url::Url;
use super::{Claim, Timestamp, TokenHash}; use super::{Claim, Equality, Timestamp, TokenHash};
pub const AUTH_TIME: Claim<Timestamp> = Claim::new("auth_time"); pub const AUTH_TIME: Claim<Timestamp> = Claim::new("auth_time");
pub const NONCE: Claim<String> = Claim::new("nonce"); pub const NONCE: Claim<String, Equality<str>> = Claim::new("nonce");
pub const AT_HASH: Claim<String, TokenHash> = Claim::new("at_hash"); pub const AT_HASH: Claim<String, TokenHash> = Claim::new("at_hash");
pub const C_HASH: Claim<String, TokenHash> = Claim::new("c_hash"); pub const C_HASH: Claim<String, TokenHash> = Claim::new("c_hash");
@ -466,7 +498,9 @@ mod tests {
}); });
let mut claims = serde_json::from_value(claims).unwrap(); let mut claims = serde_json::from_value(claims).unwrap();
let iss = ISS.extract_required(&mut claims).unwrap(); let iss = ISS
.extract_required_with_options(&mut claims, "https://foo.com")
.unwrap();
let sub = SUB.extract_optional(&mut claims).unwrap(); let sub = SUB.extract_optional(&mut claims).unwrap();
let aud = AUD.extract_optional(&mut claims).unwrap(); let aud = AUD.extract_optional(&mut claims).unwrap();
let nbf = NBF let nbf = NBF
@ -617,7 +651,7 @@ mod tests {
let mut claims = serde_json::from_value(claims).unwrap(); let mut claims = serde_json::from_value(claims).unwrap();
assert!(matches!( assert!(matches!(
ISS.extract_required(&mut claims), ISS.extract_required_with_options(&mut claims, "https://foo.com"),
Err(ClaimError::InvalidClaim("iss")) Err(ClaimError::InvalidClaim("iss"))
)); ));
assert!(matches!( assert!(matches!(
@ -652,7 +686,7 @@ mod tests {
let mut claims = HashMap::new(); let mut claims = HashMap::new();
assert!(matches!( assert!(matches!(
ISS.extract_required(&mut claims), ISS.extract_required_with_options(&mut claims, "https://foo.com"),
Err(ClaimError::MissingClaim("iss")) Err(ClaimError::MissingClaim("iss"))
)); ));
assert!(matches!( assert!(matches!(
@ -664,8 +698,28 @@ mod tests {
Err(ClaimError::MissingClaim("aud")) Err(ClaimError::MissingClaim("aud"))
)); ));
assert!(matches!(ISS.extract_optional(&mut claims), Ok(None))); assert!(matches!(
ISS.extract_optional_with_options(&mut claims, "https://foo.com"),
Ok(None)
));
assert!(matches!(SUB.extract_optional(&mut claims), Ok(None))); assert!(matches!(SUB.extract_optional(&mut claims), Ok(None)));
assert!(matches!(AUD.extract_optional(&mut claims), Ok(None))); assert!(matches!(AUD.extract_optional(&mut claims), Ok(None)));
} }
#[test]
fn string_eq_validation() {
let claims = serde_json::json!({
"iss": "https://foo.com",
});
let mut claims: HashMap<String, serde_json::Value> =
serde_json::from_value(claims).unwrap();
ISS.extract_required_with_options(&mut claims.clone(), "https://foo.com")
.unwrap();
assert!(matches!(
ISS.extract_required_with_options(&mut claims, "https://bar.com"),
Err(ClaimError::ValidationError { claim: "iss", .. }),
));
}
} }

View File

@ -335,10 +335,6 @@ where
/// All possible errors when exchanging a code for an access token. /// All possible errors when exchanging a code for an access token.
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum TokenAuthorizationCodeError { pub enum TokenAuthorizationCodeError {
/// The nonce doesn't match the one that was sent.
#[error("wrong nonce")]
WrongNonce,
/// An error occurred requesting the access token. /// An error occurred requesting the access token.
#[error(transparent)] #[error(transparent)]
Token(#[from] TokenRequestError), Token(#[from] TokenRequestError),
@ -606,10 +602,6 @@ pub enum JwtVerificationError {
#[error(transparent)] #[error(transparent)]
Claim(#[from] ClaimError), Claim(#[from] ClaimError),
/// The issuer is not the one that sent the JWT.
#[error("wrong issuer claim")]
WrongIssuer,
/// The audience of the JWT is not this client. /// The audience of the JWT is not this client.
#[error("wrong aud claim")] #[error("wrong aud claim")]
WrongAudience, WrongAudience,

View File

@ -444,12 +444,9 @@ pub async fn access_token_with_authorization_code(
.map_err(IdTokenError::from)?; .map_err(IdTokenError::from)?;
// Nonce must match. // Nonce must match.
let token_nonce = claims::NONCE claims::NONCE
.extract_required(&mut claims) .extract_required_with_options(&mut claims, validation_data.nonce.as_str())
.map_err(IdTokenError::from)?; .map_err(IdTokenError::from)?;
if token_nonce != validation_data.nonce {
return Err(TokenAuthorizationCodeError::WrongNonce);
}
Some(id_token.into_owned()) Some(id_token.into_owned())
} else { } else {

View File

@ -127,10 +127,7 @@ pub fn verify_signed_jwt<'a>(
let (header, mut claims) = jwt.clone().into_parts(); let (header, mut claims) = jwt.clone().into_parts();
// Must have the proper issuer. // Must have the proper issuer.
let iss = claims::ISS.extract_required(&mut claims)?; claims::ISS.extract_required_with_options(&mut claims, issuer.as_str())?;
if iss != issuer.as_str() {
return Err(JwtVerificationError::WrongIssuer);
}
// Must have the proper audience. // Must have the proper audience.
let aud = claims::AUD.extract_required(&mut claims)?; let aud = claims::AUD.extract_required(&mut claims)?;

View File

@ -22,7 +22,7 @@ use chrono::Duration;
use mas_iana::oauth::{ use mas_iana::oauth::{
OAuthAccessTokenType, OAuthClientAuthenticationMethod, PkceCodeChallengeMethod, OAuthAccessTokenType, OAuthClientAuthenticationMethod, PkceCodeChallengeMethod,
}; };
use mas_jose::jwk::PublicJsonWebKeySet; use mas_jose::{claims::ClaimError, jwk::PublicJsonWebKeySet};
use mas_oidc_client::{ use mas_oidc_client::{
error::{ error::{
AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError, AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError,
@ -358,7 +358,13 @@ async fn fail_access_token_with_authorization_code_wrong_nonce() {
.await .await
.unwrap_err(); .unwrap_err();
assert_matches!(error, TokenAuthorizationCodeError::WrongNonce); assert_matches!(
error,
TokenAuthorizationCodeError::IdToken(IdTokenError::Claim(ClaimError::ValidationError {
claim: "nonce",
..
}))
);
} }
#[tokio::test] #[tokio::test]

View File

@ -18,7 +18,7 @@ use assert_matches::assert_matches;
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use mas_iana::jose::JsonWebSignatureAlg; use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::{ use mas_jose::{
claims, claims::{self, ClaimError},
constraints::Constrainable, constraints::Constrainable,
jwk::PublicJsonWebKeySet, jwk::PublicJsonWebKeySet,
jwt::{JsonWebSignatureHeader, Jwt}, jwt::{JsonWebSignatureHeader, Jwt},
@ -128,7 +128,13 @@ async fn fail_verify_id_token_wrong_issuer() {
let error = verify_id_token(id_token.as_str(), verification_data, None, now).unwrap_err(); let error = verify_id_token(id_token.as_str(), verification_data, None, now).unwrap_err();
assert_matches!(error, IdTokenError::Jwt(JwtVerificationError::WrongIssuer)); assert_matches!(
error,
IdTokenError::Jwt(JwtVerificationError::Claim(ClaimError::ValidationError {
claim: "iss",
..
}))
);
} }
#[tokio::test] #[tokio::test]

View File

@ -467,10 +467,7 @@ fn verify_client_jwt(
claims: &mut HashMap<String, Value>, claims: &mut HashMap<String, Value>,
token_endpoint: &String, token_endpoint: &String,
) -> Result<(), BoxError> { ) -> Result<(), BoxError> {
let iss = claims::ISS.extract_required(claims)?; claims::ISS.extract_required_with_options(claims, CLIENT_ID)?;
if iss != CLIENT_ID {
return Err("Wrong iss".into());
}
let sub = claims::SUB.extract_required(claims)?; let sub = claims::SUB.extract_required(claims)?;
if sub != CLIENT_ID { if sub != CLIENT_ID {