1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Use associated error type in claims validator instead of anyhow.

This commit is contained in:
Quentin Gliech
2022-12-01 12:48:22 +01:00
parent 88f6e0ff28
commit 0ca4366f75
4 changed files with 71 additions and 34 deletions

View File

@ -23,7 +23,7 @@ use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerif
use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::{
claims::{self, hash_token, ClaimError},
claims::{self, hash_token, ClaimError, TokenHashError},
constraints::Constrainable,
jwt::{JsonWebSignatureHeader, Jwt, JwtSignatureError},
};
@ -177,6 +177,12 @@ impl From<ClaimError> for RouteError {
}
}
impl From<TokenHashError> for RouteError {
fn from(e: TokenHashError) -> Self {
Self::Internal(Box::new(e))
}
}
impl From<JwtSignatureError> for RouteError {
fn from(e: JwtSignatureError) -> Self {
Self::Internal(Box::new(e))

View File

@ -6,7 +6,6 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
anyhow = "1.0.66"
base64ct = { version = "1.5.3", features = ["std"] }
chrono = { version = "0.4.23", features = ["serde"] }
digest = "0.10.6"

View File

@ -12,9 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashMap, marker::PhantomData, ops::Deref};
use std::{collections::HashMap, convert::Infallible, marker::PhantomData, ops::Deref};
use anyhow::Context;
use base64ct::{Base64UrlUnpadded, Encoding};
use mas_iana::jose::JsonWebSignatureAlg;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
@ -33,16 +32,19 @@ pub enum ClaimError {
ValidationError {
claim: &'static str,
#[source]
source: anyhow::Error,
source: Box<dyn std::error::Error + Send + Sync + 'static>,
},
}
pub trait Validator<T> {
fn validate(&self, value: &T) -> Result<(), anyhow::Error>;
type Error;
fn validate(&self, value: &T) -> Result<(), Self::Error>;
}
impl<T> Validator<T> for () {
fn validate(&self, _value: &T) -> Result<(), anyhow::Error> {
type Error = Infallible;
fn validate(&self, _value: &T) -> Result<(), Self::Error> {
Ok(())
}
}
@ -53,7 +55,10 @@ pub struct Claim<T, V = ()> {
v: PhantomData<V>,
}
impl<T, V> Claim<T, V> {
impl<T, V> Claim<T, V>
where
V: Validator<T>,
{
#[must_use]
pub const fn new(claim: &'static str) -> Self {
Self {
@ -86,7 +91,8 @@ impl<T, V> Claim<T, V> {
) -> Result<T, ClaimError>
where
T: DeserializeOwned,
V: Default + Validator<T>,
V: Default,
V::Error: std::error::Error + Send + Sync + 'static,
{
let validator = V::default();
self.extract_required_with_options(claims, validator)
@ -100,7 +106,7 @@ impl<T, V> Claim<T, V> {
where
T: DeserializeOwned,
I: Into<V>,
V: Validator<T>,
V::Error: std::error::Error + Send + Sync + 'static,
{
let validator: V = validator.into();
let claim = claims
@ -113,7 +119,7 @@ impl<T, V> Claim<T, V> {
.validate(&res)
.map_err(|source| ClaimError::ValidationError {
claim: self.claim,
source,
source: Box::new(source),
})?;
Ok(res)
}
@ -124,7 +130,8 @@ impl<T, V> Claim<T, V> {
) -> Result<Option<T>, ClaimError>
where
T: DeserializeOwned,
V: Default + Validator<T>,
V: Default,
V::Error: std::error::Error + Send + Sync + 'static,
{
let validator = V::default();
self.extract_optional_with_options(claims, validator)
@ -138,7 +145,7 @@ impl<T, V> Claim<T, V> {
where
T: DeserializeOwned,
I: Into<V>,
V: Validator<T>,
V::Error: std::error::Error + Send + Sync + 'static,
{
match self.extract_required_with_options(claims, validator) {
Ok(v) => Ok(Some(v)),
@ -170,15 +177,20 @@ impl TimeOptions {
}
}
#[derive(Debug, Clone, Copy, Error)]
#[error("Current time is too far away")]
pub struct TimeTooFarError;
#[derive(Debug, Clone)]
pub struct TimeNotAfter(TimeOptions);
impl Validator<Timestamp> for TimeNotAfter {
fn validate(&self, value: &Timestamp) -> Result<(), anyhow::Error> {
type Error = TimeTooFarError;
fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
if self.0.when <= value.0 + self.0.leeway {
Ok(())
} else {
Err(anyhow::anyhow!("current time is too far away"))
Err(TimeTooFarError)
}
}
}
@ -199,11 +211,13 @@ impl From<&TimeOptions> for TimeNotAfter {
pub struct TimeNotBefore(TimeOptions);
impl Validator<Timestamp> for TimeNotBefore {
fn validate(&self, value: &Timestamp) -> Result<(), anyhow::Error> {
type Error = TimeTooFarError;
fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
if self.0.when >= value.0 - self.0.leeway {
Ok(())
} else {
Err(anyhow::anyhow!("current time is too far before"))
Err(TimeTooFarError)
}
}
}
@ -229,7 +243,7 @@ impl From<&TimeOptions> for TimeNotBefore {
/// Returns an error if the algorithm is not supported.
///
/// [OpenID Connect Core 1.0 specification]: https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<String> {
pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> Result<String, TokenHashError> {
let bits = match alg {
JsonWebSignatureAlg::Hs256
| JsonWebSignatureAlg::Rs256
@ -238,9 +252,9 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<Stri
| JsonWebSignatureAlg::Es256K => {
let mut hasher = Sha256::new();
hasher.update(token);
let hash = hasher.finalize();
let hash: [u8; 32] = hasher.finalize().into();
// Left-most half
hash.get(..16).map(ToOwned::to_owned)
hash[..16].to_owned()
}
JsonWebSignatureAlg::Hs384
| JsonWebSignatureAlg::Rs384
@ -248,9 +262,9 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<Stri
| JsonWebSignatureAlg::Ps384 => {
let mut hasher = Sha384::new();
hasher.update(token);
let hash = hasher.finalize();
let hash: [u8; 48] = hasher.finalize().into();
// Left-most half
hash.get(..24).map(ToOwned::to_owned)
hash[..24].to_owned()
}
JsonWebSignatureAlg::Hs512
| JsonWebSignatureAlg::Rs512
@ -258,17 +272,25 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<Stri
| JsonWebSignatureAlg::Ps512 => {
let mut hasher = Sha512::new();
hasher.update(token);
let hash = hasher.finalize();
let hash: [u8; 64] = hasher.finalize().into();
// Left-most half
hash.get(..32).map(ToOwned::to_owned)
hash[..32].to_owned()
}
_ => return Err(anyhow::anyhow!("unsupported algorithm for hashing")),
}
.context("failed to get first half of hash")?;
_ => return Err(TokenHashError::UnsupportedAlgorithm),
};
Ok(Base64UrlUnpadded::encode_string(&bits))
}
#[derive(Debug, Clone, Copy, Error)]
pub enum TokenHashError {
#[error("Hashes don't match")]
HashMismatch,
#[error("Unsupported algorithm for hashing")]
UnsupportedAlgorithm,
}
#[derive(Debug, Clone)]
pub struct TokenHash<'a> {
alg: &'a JsonWebSignatureAlg,
@ -284,15 +306,20 @@ impl<'a> TokenHash<'a> {
}
impl<'a> Validator<String> for TokenHash<'a> {
fn validate(&self, value: &String) -> Result<(), anyhow::Error> {
type Error = TokenHashError;
fn validate(&self, value: &String) -> Result<(), Self::Error> {
if hash_token(self.alg, self.token)? == *value {
Ok(())
} else {
Err(anyhow::anyhow!("hashes don't match"))
Err(TokenHashError::HashMismatch)
}
}
}
#[derive(Debug, Clone, Copy, Error)]
#[error("Values don't match")]
pub struct EqualityError;
#[derive(Debug, Clone)]
pub struct Equality<'a, T: ?Sized> {
value: &'a T,
@ -310,11 +337,12 @@ impl<'a, T1, T2: ?Sized> Validator<T1> for Equality<'a, T2>
where
T2: PartialEq<T1>,
{
fn validate(&self, value: &T1) -> Result<(), anyhow::Error> {
type Error = EqualityError;
fn validate(&self, value: &T1) -> Result<(), Self::Error> {
if *self.value == *value {
Ok(())
} else {
Err(anyhow::anyhow!("values don't match"))
Err(EqualityError)
}
}
}
@ -338,15 +366,20 @@ impl<'a, T> Contains<'a, T> {
}
}
#[derive(Debug, Clone, Copy, Error)]
#[error("OneOrMany doesn't contains value")]
pub struct ContainsError;
impl<'a, T> Validator<OneOrMany<T>> for Contains<'a, T>
where
T: PartialEq,
{
fn validate(&self, value: &OneOrMany<T>) -> Result<(), anyhow::Error> {
type Error = ContainsError;
fn validate(&self, value: &OneOrMany<T>) -> Result<(), Self::Error> {
if value.contains(self.value) {
Ok(())
} else {
Err(anyhow::anyhow!("OneOrMany doesn't contain value"))
Err(ContainsError)
}
}
}