You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Use associated error type in claims validator instead of anyhow.
This commit is contained in:
@ -23,7 +23,7 @@ use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerif
|
||||
use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
|
||||
use mas_iana::jose::JsonWebSignatureAlg;
|
||||
use mas_jose::{
|
||||
claims::{self, hash_token, ClaimError},
|
||||
claims::{self, hash_token, ClaimError, TokenHashError},
|
||||
constraints::Constrainable,
|
||||
jwt::{JsonWebSignatureHeader, Jwt, JwtSignatureError},
|
||||
};
|
||||
@ -177,6 +177,12 @@ impl From<ClaimError> for RouteError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TokenHashError> for RouteError {
|
||||
fn from(e: TokenHashError) -> Self {
|
||||
Self::Internal(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<JwtSignatureError> for RouteError {
|
||||
fn from(e: JwtSignatureError) -> Self {
|
||||
Self::Internal(Box::new(e))
|
||||
|
@ -6,7 +6,6 @@ edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.66"
|
||||
base64ct = { version = "1.5.3", features = ["std"] }
|
||||
chrono = { version = "0.4.23", features = ["serde"] }
|
||||
digest = "0.10.6"
|
||||
|
@ -12,9 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashMap, marker::PhantomData, ops::Deref};
|
||||
use std::{collections::HashMap, convert::Infallible, marker::PhantomData, ops::Deref};
|
||||
|
||||
use anyhow::Context;
|
||||
use base64ct::{Base64UrlUnpadded, Encoding};
|
||||
use mas_iana::jose::JsonWebSignatureAlg;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
@ -33,16 +32,19 @@ pub enum ClaimError {
|
||||
ValidationError {
|
||||
claim: &'static str,
|
||||
#[source]
|
||||
source: anyhow::Error,
|
||||
source: Box<dyn std::error::Error + Send + Sync + 'static>,
|
||||
},
|
||||
}
|
||||
|
||||
pub trait Validator<T> {
|
||||
fn validate(&self, value: &T) -> Result<(), anyhow::Error>;
|
||||
type Error;
|
||||
fn validate(&self, value: &T) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
impl<T> Validator<T> for () {
|
||||
fn validate(&self, _value: &T) -> Result<(), anyhow::Error> {
|
||||
type Error = Infallible;
|
||||
|
||||
fn validate(&self, _value: &T) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -53,7 +55,10 @@ pub struct Claim<T, V = ()> {
|
||||
v: PhantomData<V>,
|
||||
}
|
||||
|
||||
impl<T, V> Claim<T, V> {
|
||||
impl<T, V> Claim<T, V>
|
||||
where
|
||||
V: Validator<T>,
|
||||
{
|
||||
#[must_use]
|
||||
pub const fn new(claim: &'static str) -> Self {
|
||||
Self {
|
||||
@ -86,7 +91,8 @@ impl<T, V> Claim<T, V> {
|
||||
) -> Result<T, ClaimError>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
V: Default + Validator<T>,
|
||||
V: Default,
|
||||
V::Error: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
let validator = V::default();
|
||||
self.extract_required_with_options(claims, validator)
|
||||
@ -100,7 +106,7 @@ impl<T, V> Claim<T, V> {
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
I: Into<V>,
|
||||
V: Validator<T>,
|
||||
V::Error: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
let validator: V = validator.into();
|
||||
let claim = claims
|
||||
@ -113,7 +119,7 @@ impl<T, V> Claim<T, V> {
|
||||
.validate(&res)
|
||||
.map_err(|source| ClaimError::ValidationError {
|
||||
claim: self.claim,
|
||||
source,
|
||||
source: Box::new(source),
|
||||
})?;
|
||||
Ok(res)
|
||||
}
|
||||
@ -124,7 +130,8 @@ impl<T, V> Claim<T, V> {
|
||||
) -> Result<Option<T>, ClaimError>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
V: Default + Validator<T>,
|
||||
V: Default,
|
||||
V::Error: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
let validator = V::default();
|
||||
self.extract_optional_with_options(claims, validator)
|
||||
@ -138,7 +145,7 @@ impl<T, V> Claim<T, V> {
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
I: Into<V>,
|
||||
V: Validator<T>,
|
||||
V::Error: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
match self.extract_required_with_options(claims, validator) {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
@ -170,15 +177,20 @@ impl TimeOptions {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Error)]
|
||||
#[error("Current time is too far away")]
|
||||
pub struct TimeTooFarError;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TimeNotAfter(TimeOptions);
|
||||
|
||||
impl Validator<Timestamp> for TimeNotAfter {
|
||||
fn validate(&self, value: &Timestamp) -> Result<(), anyhow::Error> {
|
||||
type Error = TimeTooFarError;
|
||||
fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
|
||||
if self.0.when <= value.0 + self.0.leeway {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("current time is too far away"))
|
||||
Err(TimeTooFarError)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -199,11 +211,13 @@ impl From<&TimeOptions> for TimeNotAfter {
|
||||
pub struct TimeNotBefore(TimeOptions);
|
||||
|
||||
impl Validator<Timestamp> for TimeNotBefore {
|
||||
fn validate(&self, value: &Timestamp) -> Result<(), anyhow::Error> {
|
||||
type Error = TimeTooFarError;
|
||||
|
||||
fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
|
||||
if self.0.when >= value.0 - self.0.leeway {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("current time is too far before"))
|
||||
Err(TimeTooFarError)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -229,7 +243,7 @@ impl From<&TimeOptions> for TimeNotBefore {
|
||||
/// Returns an error if the algorithm is not supported.
|
||||
///
|
||||
/// [OpenID Connect Core 1.0 specification]: https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
|
||||
pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<String> {
|
||||
pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> Result<String, TokenHashError> {
|
||||
let bits = match alg {
|
||||
JsonWebSignatureAlg::Hs256
|
||||
| JsonWebSignatureAlg::Rs256
|
||||
@ -238,9 +252,9 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<Stri
|
||||
| JsonWebSignatureAlg::Es256K => {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token);
|
||||
let hash = hasher.finalize();
|
||||
let hash: [u8; 32] = hasher.finalize().into();
|
||||
// Left-most half
|
||||
hash.get(..16).map(ToOwned::to_owned)
|
||||
hash[..16].to_owned()
|
||||
}
|
||||
JsonWebSignatureAlg::Hs384
|
||||
| JsonWebSignatureAlg::Rs384
|
||||
@ -248,9 +262,9 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<Stri
|
||||
| JsonWebSignatureAlg::Ps384 => {
|
||||
let mut hasher = Sha384::new();
|
||||
hasher.update(token);
|
||||
let hash = hasher.finalize();
|
||||
let hash: [u8; 48] = hasher.finalize().into();
|
||||
// Left-most half
|
||||
hash.get(..24).map(ToOwned::to_owned)
|
||||
hash[..24].to_owned()
|
||||
}
|
||||
JsonWebSignatureAlg::Hs512
|
||||
| JsonWebSignatureAlg::Rs512
|
||||
@ -258,17 +272,25 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result<Stri
|
||||
| JsonWebSignatureAlg::Ps512 => {
|
||||
let mut hasher = Sha512::new();
|
||||
hasher.update(token);
|
||||
let hash = hasher.finalize();
|
||||
let hash: [u8; 64] = hasher.finalize().into();
|
||||
// Left-most half
|
||||
hash.get(..32).map(ToOwned::to_owned)
|
||||
hash[..32].to_owned()
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("unsupported algorithm for hashing")),
|
||||
}
|
||||
.context("failed to get first half of hash")?;
|
||||
_ => return Err(TokenHashError::UnsupportedAlgorithm),
|
||||
};
|
||||
|
||||
Ok(Base64UrlUnpadded::encode_string(&bits))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Error)]
|
||||
pub enum TokenHashError {
|
||||
#[error("Hashes don't match")]
|
||||
HashMismatch,
|
||||
|
||||
#[error("Unsupported algorithm for hashing")]
|
||||
UnsupportedAlgorithm,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TokenHash<'a> {
|
||||
alg: &'a JsonWebSignatureAlg,
|
||||
@ -284,15 +306,20 @@ impl<'a> TokenHash<'a> {
|
||||
}
|
||||
|
||||
impl<'a> Validator<String> for TokenHash<'a> {
|
||||
fn validate(&self, value: &String) -> Result<(), anyhow::Error> {
|
||||
type Error = TokenHashError;
|
||||
fn validate(&self, value: &String) -> Result<(), Self::Error> {
|
||||
if hash_token(self.alg, self.token)? == *value {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("hashes don't match"))
|
||||
Err(TokenHashError::HashMismatch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Error)]
|
||||
#[error("Values don't match")]
|
||||
pub struct EqualityError;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Equality<'a, T: ?Sized> {
|
||||
value: &'a T,
|
||||
@ -310,11 +337,12 @@ impl<'a, T1, T2: ?Sized> Validator<T1> for Equality<'a, T2>
|
||||
where
|
||||
T2: PartialEq<T1>,
|
||||
{
|
||||
fn validate(&self, value: &T1) -> Result<(), anyhow::Error> {
|
||||
type Error = EqualityError;
|
||||
fn validate(&self, value: &T1) -> Result<(), Self::Error> {
|
||||
if *self.value == *value {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("values don't match"))
|
||||
Err(EqualityError)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -338,15 +366,20 @@ impl<'a, T> Contains<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Error)]
|
||||
#[error("OneOrMany doesn't contains value")]
|
||||
pub struct ContainsError;
|
||||
|
||||
impl<'a, T> Validator<OneOrMany<T>> for Contains<'a, T>
|
||||
where
|
||||
T: PartialEq,
|
||||
{
|
||||
fn validate(&self, value: &OneOrMany<T>) -> Result<(), anyhow::Error> {
|
||||
type Error = ContainsError;
|
||||
fn validate(&self, value: &OneOrMany<T>) -> Result<(), Self::Error> {
|
||||
if value.contains(self.value) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("OneOrMany doesn't contain value"))
|
||||
Err(ContainsError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user