1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-06 06:02:40 +03:00

Use associated error type in claims validator instead of anyhow.

This commit is contained in:
Quentin Gliech
2022-12-01 12:48:22 +01:00
parent 88f6e0ff28
commit 0ca4366f75
4 changed files with 71 additions and 34 deletions

1
Cargo.lock generated
View File

@@ -2911,7 +2911,6 @@ dependencies = [
name = "mas-jose" name = "mas-jose"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow",
"base64ct", "base64ct",
"chrono", "chrono",
"digest 0.10.6", "digest 0.10.6",

View File

@@ -23,7 +23,7 @@ use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerif
use mas_data_model::{AuthorizationGrantStage, Client, TokenType}; use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
use mas_iana::jose::JsonWebSignatureAlg; use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::{ use mas_jose::{
claims::{self, hash_token, ClaimError}, claims::{self, hash_token, ClaimError, TokenHashError},
constraints::Constrainable, constraints::Constrainable,
jwt::{JsonWebSignatureHeader, Jwt, JwtSignatureError}, jwt::{JsonWebSignatureHeader, Jwt, JwtSignatureError},
}; };
@@ -177,6 +177,12 @@ impl From<ClaimError> for RouteError {
} }
} }
impl From<TokenHashError> for RouteError {
fn from(e: TokenHashError) -> Self {
Self::Internal(Box::new(e))
}
}
impl From<JwtSignatureError> for RouteError { impl From<JwtSignatureError> for RouteError {
fn from(e: JwtSignatureError) -> Self { fn from(e: JwtSignatureError) -> Self {
Self::Internal(Box::new(e)) Self::Internal(Box::new(e))

View File

@@ -6,7 +6,6 @@ edition = "2021"
license = "Apache-2.0" license = "Apache-2.0"
[dependencies] [dependencies]
anyhow = "1.0.66"
base64ct = { version = "1.5.3", features = ["std"] } base64ct = { version = "1.5.3", features = ["std"] }
chrono = { version = "0.4.23", features = ["serde"] } chrono = { version = "0.4.23", features = ["serde"] }
digest = "0.10.6" digest = "0.10.6"

View File

@@ -12,9 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::HashMap, marker::PhantomData, ops::Deref}; use std::{collections::HashMap, convert::Infallible, marker::PhantomData, ops::Deref};
use anyhow::Context;
use base64ct::{Base64UrlUnpadded, Encoding}; use base64ct::{Base64UrlUnpadded, Encoding};
use mas_iana::jose::JsonWebSignatureAlg; use mas_iana::jose::JsonWebSignatureAlg;
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
@@ -33,16 +32,19 @@ pub enum ClaimError {
ValidationError { ValidationError {
claim: &'static str, claim: &'static str,
#[source] #[source]
source: anyhow::Error, source: Box<dyn std::error::Error + Send + Sync + 'static>,
}, },
} }
pub trait Validator<T> { pub trait Validator<T> {
fn validate(&self, value: &T) -> Result<(), anyhow::Error>; type Error;
fn validate(&self, value: &T) -> Result<(), Self::Error>;
} }
impl<T> Validator<T> for () { impl<T> Validator<T> for () {
fn validate(&self, _value: &T) -> Result<(), anyhow::Error> { type Error = Infallible;
fn validate(&self, _value: &T) -> Result<(), Self::Error> {
Ok(()) Ok(())
} }
} }
@@ -53,7 +55,10 @@ pub struct Claim<T, V = ()> {
v: PhantomData<V>, v: PhantomData<V>,
} }
impl<T, V> Claim<T, V> { impl<T, V> Claim<T, V>
where
V: Validator<T>,
{
#[must_use] #[must_use]
pub const fn new(claim: &'static str) -> Self { pub const fn new(claim: &'static str) -> Self {
Self { Self {
@@ -86,7 +91,8 @@ impl<T, V> Claim<T, V> {
) -> Result<T, ClaimError> ) -> Result<T, ClaimError>
where where
T: DeserializeOwned, T: DeserializeOwned,
V: Default + Validator<T>, V: Default,
V::Error: std::error::Error + Send + Sync + 'static,
{ {
let validator = V::default(); let validator = V::default();
self.extract_required_with_options(claims, validator) self.extract_required_with_options(claims, validator)
@@ -100,7 +106,7 @@ impl<T, V> Claim<T, V> {
where where
T: DeserializeOwned, T: DeserializeOwned,
I: Into<V>, I: Into<V>,
V: Validator<T>, V::Error: std::error::Error + Send + Sync + 'static,
{ {
let validator: V = validator.into(); let validator: V = validator.into();
let claim = claims let claim = claims
@@ -113,7 +119,7 @@ impl<T, V> Claim<T, V> {
.validate(&res) .validate(&res)
.map_err(|source| ClaimError::ValidationError { .map_err(|source| ClaimError::ValidationError {
claim: self.claim, claim: self.claim,
source, source: Box::new(source),
})?; })?;
Ok(res) Ok(res)
} }
@@ -124,7 +130,8 @@ impl<T, V> Claim<T, V> {
) -> Result<Option<T>, ClaimError> ) -> Result<Option<T>, ClaimError>
where where
T: DeserializeOwned, T: DeserializeOwned,
V: Default + Validator<T>, V: Default,
V::Error: std::error::Error + Send + Sync + 'static,
{ {
let validator = V::default(); let validator = V::default();
self.extract_optional_with_options(claims, validator) self.extract_optional_with_options(claims, validator)
@@ -138,7 +145,7 @@ impl<T, V> Claim<T, V> {
where where
T: DeserializeOwned, T: DeserializeOwned,
I: Into<V>, I: Into<V>,
V: Validator<T>, V::Error: std::error::Error + Send + Sync + 'static,
{ {
match self.extract_required_with_options(claims, validator) { match self.extract_required_with_options(claims, validator) {
Ok(v) => Ok(Some(v)), Ok(v) => Ok(Some(v)),
@@ -170,15 +177,20 @@ impl TimeOptions {
} }
} }
#[derive(Debug, Clone, Copy, Error)]
#[error("Current time is too far away")]
pub struct TimeTooFarError;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct TimeNotAfter(TimeOptions); pub struct TimeNotAfter(TimeOptions);
impl Validator<Timestamp> for TimeNotAfter { impl Validator<Timestamp> for TimeNotAfter {
fn validate(&self, value: &Timestamp) -> Result<(), anyhow::Error> { type Error = TimeTooFarError;
fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
if self.0.when <= value.0 + self.0.leeway { if self.0.when <= value.0 + self.0.leeway {
Ok(()) Ok(())
} else { } else {
Err(anyhow::anyhow!("current time is too far away")) Err(TimeTooFarError)
} }
} }
} }
@@ -199,11 +211,13 @@ impl From<&TimeOptions> for TimeNotAfter {
pub struct TimeNotBefore(TimeOptions); pub struct TimeNotBefore(TimeOptions);
impl Validator<Timestamp> for TimeNotBefore { impl Validator<Timestamp> for TimeNotBefore {
fn validate(&self, value: &Timestamp) -> Result<(), anyhow::Error> { type Error = TimeTooFarError;
fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
if self.0.when >= value.0 - self.0.leeway { if self.0.when >= value.0 - self.0.leeway {
Ok(()) Ok(())
} else { } else {
Err(anyhow::anyhow!("current time is too far before")) Err(TimeTooFarError)
} }
} }
} }
@@ -229,7 +243,7 @@ impl From<&TimeOptions> for TimeNotBefore {
/// Returns an error if the algorithm is not supported. /// Returns an error if the algorithm is not supported.
/// ///
/// [OpenID Connect Core 1.0 specification]: https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken /// [OpenID Connect Core 1.0 specification]: https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<String> { pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> Result<String, TokenHashError> {
let bits = match alg { let bits = match alg {
JsonWebSignatureAlg::Hs256 JsonWebSignatureAlg::Hs256
| JsonWebSignatureAlg::Rs256 | JsonWebSignatureAlg::Rs256
@@ -238,9 +252,9 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<Stri
| JsonWebSignatureAlg::Es256K => { | JsonWebSignatureAlg::Es256K => {
let mut hasher = Sha256::new(); let mut hasher = Sha256::new();
hasher.update(token); hasher.update(token);
let hash = hasher.finalize(); let hash: [u8; 32] = hasher.finalize().into();
// Left-most half // Left-most half
hash.get(..16).map(ToOwned::to_owned) hash[..16].to_owned()
} }
JsonWebSignatureAlg::Hs384 JsonWebSignatureAlg::Hs384
| JsonWebSignatureAlg::Rs384 | JsonWebSignatureAlg::Rs384
@@ -248,9 +262,9 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<Stri
| JsonWebSignatureAlg::Ps384 => { | JsonWebSignatureAlg::Ps384 => {
let mut hasher = Sha384::new(); let mut hasher = Sha384::new();
hasher.update(token); hasher.update(token);
let hash = hasher.finalize(); let hash: [u8; 48] = hasher.finalize().into();
// Left-most half // Left-most half
hash.get(..24).map(ToOwned::to_owned) hash[..24].to_owned()
} }
JsonWebSignatureAlg::Hs512 JsonWebSignatureAlg::Hs512
| JsonWebSignatureAlg::Rs512 | JsonWebSignatureAlg::Rs512
@@ -258,17 +272,25 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<Stri
| JsonWebSignatureAlg::Ps512 => { | JsonWebSignatureAlg::Ps512 => {
let mut hasher = Sha512::new(); let mut hasher = Sha512::new();
hasher.update(token); hasher.update(token);
let hash = hasher.finalize(); let hash: [u8; 64] = hasher.finalize().into();
// Left-most half // Left-most half
hash.get(..32).map(ToOwned::to_owned) hash[..32].to_owned()
} }
_ => return Err(anyhow::anyhow!("unsupported algorithm for hashing")), _ => return Err(TokenHashError::UnsupportedAlgorithm),
} };
.context("failed to get first half of hash")?;
Ok(Base64UrlUnpadded::encode_string(&bits)) Ok(Base64UrlUnpadded::encode_string(&bits))
} }
#[derive(Debug, Clone, Copy, Error)]
pub enum TokenHashError {
#[error("Hashes don't match")]
HashMismatch,
#[error("Unsupported algorithm for hashing")]
UnsupportedAlgorithm,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct TokenHash<'a> { pub struct TokenHash<'a> {
alg: &'a JsonWebSignatureAlg, alg: &'a JsonWebSignatureAlg,
@@ -284,15 +306,20 @@ impl<'a> TokenHash<'a> {
} }
impl<'a> Validator<String> for TokenHash<'a> { impl<'a> Validator<String> for TokenHash<'a> {
fn validate(&self, value: &String) -> Result<(), anyhow::Error> { type Error = TokenHashError;
fn validate(&self, value: &String) -> Result<(), Self::Error> {
if hash_token(self.alg, self.token)? == *value { if hash_token(self.alg, self.token)? == *value {
Ok(()) Ok(())
} else { } else {
Err(anyhow::anyhow!("hashes don't match")) Err(TokenHashError::HashMismatch)
} }
} }
} }
#[derive(Debug, Clone, Copy, Error)]
#[error("Values don't match")]
pub struct EqualityError;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Equality<'a, T: ?Sized> { pub struct Equality<'a, T: ?Sized> {
value: &'a T, value: &'a T,
@@ -310,11 +337,12 @@ impl<'a, T1, T2: ?Sized> Validator<T1> for Equality<'a, T2>
where where
T2: PartialEq<T1>, T2: PartialEq<T1>,
{ {
fn validate(&self, value: &T1) -> Result<(), anyhow::Error> { type Error = EqualityError;
fn validate(&self, value: &T1) -> Result<(), Self::Error> {
if *self.value == *value { if *self.value == *value {
Ok(()) Ok(())
} else { } else {
Err(anyhow::anyhow!("values don't match")) Err(EqualityError)
} }
} }
} }
@@ -338,15 +366,20 @@ impl<'a, T> Contains<'a, T> {
} }
} }
#[derive(Debug, Clone, Copy, Error)]
#[error("OneOrMany doesn't contains value")]
pub struct ContainsError;
impl<'a, T> Validator<OneOrMany<T>> for Contains<'a, T> impl<'a, T> Validator<OneOrMany<T>> for Contains<'a, T>
where where
T: PartialEq, T: PartialEq,
{ {
fn validate(&self, value: &OneOrMany<T>) -> Result<(), anyhow::Error> { type Error = ContainsError;
fn validate(&self, value: &OneOrMany<T>) -> Result<(), Self::Error> {
if value.contains(self.value) { if value.contains(self.value) {
Ok(()) Ok(())
} else { } else {
Err(anyhow::anyhow!("OneOrMany doesn't contain value")) Err(ContainsError)
} }
} }
} }