diff --git a/Cargo.lock b/Cargo.lock index 6cf2ba73..de8dd489 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2911,7 +2911,6 @@ dependencies = [ name = "mas-jose" version = "0.1.0" dependencies = [ - "anyhow", "base64ct", "chrono", "digest 0.10.6", diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index f15416b0..bf37121d 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -23,7 +23,7 @@ use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerif use mas_data_model::{AuthorizationGrantStage, Client, TokenType}; use mas_iana::jose::JsonWebSignatureAlg; use mas_jose::{ - claims::{self, hash_token, ClaimError}, + claims::{self, hash_token, ClaimError, TokenHashError}, constraints::Constrainable, jwt::{JsonWebSignatureHeader, Jwt, JwtSignatureError}, }; @@ -177,6 +177,12 @@ impl From for RouteError { } } +impl From for RouteError { + fn from(e: TokenHashError) -> Self { + Self::Internal(Box::new(e)) + } +} + impl From for RouteError { fn from(e: JwtSignatureError) -> Self { Self::Internal(Box::new(e)) diff --git a/crates/jose/Cargo.toml b/crates/jose/Cargo.toml index b4a48b9e..62f09612 100644 --- a/crates/jose/Cargo.toml +++ b/crates/jose/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" license = "Apache-2.0" [dependencies] -anyhow = "1.0.66" base64ct = { version = "1.5.3", features = ["std"] } chrono = { version = "0.4.23", features = ["serde"] } digest = "0.10.6" diff --git a/crates/jose/src/claims.rs b/crates/jose/src/claims.rs index d16f923f..d4a2a8ff 100644 --- a/crates/jose/src/claims.rs +++ b/crates/jose/src/claims.rs @@ -12,9 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashMap, marker::PhantomData, ops::Deref}; +use std::{collections::HashMap, convert::Infallible, marker::PhantomData, ops::Deref}; -use anyhow::Context; use base64ct::{Base64UrlUnpadded, Encoding}; use mas_iana::jose::JsonWebSignatureAlg; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -33,16 +32,19 @@ pub enum ClaimError { ValidationError { claim: &'static str, #[source] - source: anyhow::Error, + source: Box, }, } pub trait Validator { - fn validate(&self, value: &T) -> Result<(), anyhow::Error>; + type Error; + fn validate(&self, value: &T) -> Result<(), Self::Error>; } impl Validator for () { - fn validate(&self, _value: &T) -> Result<(), anyhow::Error> { + type Error = Infallible; + + fn validate(&self, _value: &T) -> Result<(), Self::Error> { Ok(()) } } @@ -53,7 +55,10 @@ pub struct Claim { v: PhantomData, } -impl Claim { +impl Claim +where + V: Validator, +{ #[must_use] pub const fn new(claim: &'static str) -> Self { Self { @@ -86,7 +91,8 @@ impl Claim { ) -> Result where T: DeserializeOwned, - V: Default + Validator, + V: Default, + V::Error: std::error::Error + Send + Sync + 'static, { let validator = V::default(); self.extract_required_with_options(claims, validator) @@ -100,7 +106,7 @@ impl Claim { where T: DeserializeOwned, I: Into, - V: Validator, + V::Error: std::error::Error + Send + Sync + 'static, { let validator: V = validator.into(); let claim = claims @@ -113,7 +119,7 @@ impl Claim { .validate(&res) .map_err(|source| ClaimError::ValidationError { claim: self.claim, - source, + source: Box::new(source), })?; Ok(res) } @@ -124,7 +130,8 @@ impl Claim { ) -> Result, ClaimError> where T: DeserializeOwned, - V: Default + Validator, + V: Default, + V::Error: std::error::Error + Send + Sync + 'static, { let validator = V::default(); self.extract_optional_with_options(claims, validator) @@ -138,7 +145,7 @@ impl Claim { where T: DeserializeOwned, I: Into, - V: Validator, + V::Error: std::error::Error + Send + Sync + 'static, { match self.extract_required_with_options(claims, validator) { Ok(v) => Ok(Some(v)), @@ -170,15 +177,20 @@ impl TimeOptions { } } +#[derive(Debug, Clone, Copy, Error)] +#[error("Current time is too far away")] +pub struct TimeTooFarError; + #[derive(Debug, Clone)] pub struct TimeNotAfter(TimeOptions); impl Validator for TimeNotAfter { - fn validate(&self, value: &Timestamp) -> Result<(), anyhow::Error> { + type Error = TimeTooFarError; + fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> { if self.0.when <= value.0 + self.0.leeway { Ok(()) } else { - Err(anyhow::anyhow!("current time is too far away")) + Err(TimeTooFarError) } } } @@ -199,11 +211,13 @@ impl From<&TimeOptions> for TimeNotAfter { pub struct TimeNotBefore(TimeOptions); impl Validator for TimeNotBefore { - fn validate(&self, value: &Timestamp) -> Result<(), anyhow::Error> { + type Error = TimeTooFarError; + + fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> { if self.0.when >= value.0 - self.0.leeway { Ok(()) } else { - Err(anyhow::anyhow!("current time is too far before")) + Err(TimeTooFarError) } } } @@ -229,7 +243,7 @@ impl From<&TimeOptions> for TimeNotBefore { /// Returns an error if the algorithm is not supported. /// /// [OpenID Connect Core 1.0 specification]: https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken -pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result { +pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> Result { let bits = match alg { JsonWebSignatureAlg::Hs256 | JsonWebSignatureAlg::Rs256 @@ -238,9 +252,9 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result { let mut hasher = Sha256::new(); hasher.update(token); - let hash = hasher.finalize(); + let hash: [u8; 32] = hasher.finalize().into(); // Left-most half - hash.get(..16).map(ToOwned::to_owned) + hash[..16].to_owned() } JsonWebSignatureAlg::Hs384 | JsonWebSignatureAlg::Rs384 @@ -248,9 +262,9 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result { let mut hasher = Sha384::new(); hasher.update(token); - let hash = hasher.finalize(); + let hash: [u8; 48] = hasher.finalize().into(); // Left-most half - hash.get(..24).map(ToOwned::to_owned) + hash[..24].to_owned() } JsonWebSignatureAlg::Hs512 | JsonWebSignatureAlg::Rs512 @@ -258,17 +272,25 @@ pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> anyhow::Result { let mut hasher = Sha512::new(); hasher.update(token); - let hash = hasher.finalize(); + let hash: [u8; 64] = hasher.finalize().into(); // Left-most half - hash.get(..32).map(ToOwned::to_owned) + hash[..32].to_owned() } - _ => return Err(anyhow::anyhow!("unsupported algorithm for hashing")), - } - .context("failed to get first half of hash")?; + _ => return Err(TokenHashError::UnsupportedAlgorithm), + }; Ok(Base64UrlUnpadded::encode_string(&bits)) } +#[derive(Debug, Clone, Copy, Error)] +pub enum TokenHashError { + #[error("Hashes don't match")] + HashMismatch, + + #[error("Unsupported algorithm for hashing")] + UnsupportedAlgorithm, +} + #[derive(Debug, Clone)] pub struct TokenHash<'a> { alg: &'a JsonWebSignatureAlg, @@ -284,15 +306,20 @@ impl<'a> TokenHash<'a> { } impl<'a> Validator for TokenHash<'a> { - fn validate(&self, value: &String) -> Result<(), anyhow::Error> { + type Error = TokenHashError; + fn validate(&self, value: &String) -> Result<(), Self::Error> { if hash_token(self.alg, self.token)? == *value { Ok(()) } else { - Err(anyhow::anyhow!("hashes don't match")) + Err(TokenHashError::HashMismatch) } } } +#[derive(Debug, Clone, Copy, Error)] +#[error("Values don't match")] +pub struct EqualityError; + #[derive(Debug, Clone)] pub struct Equality<'a, T: ?Sized> { value: &'a T, @@ -310,11 +337,12 @@ impl<'a, T1, T2: ?Sized> Validator for Equality<'a, T2> where T2: PartialEq, { - fn validate(&self, value: &T1) -> Result<(), anyhow::Error> { + type Error = EqualityError; + fn validate(&self, value: &T1) -> Result<(), Self::Error> { if *self.value == *value { Ok(()) } else { - Err(anyhow::anyhow!("values don't match")) + Err(EqualityError) } } } @@ -338,15 +366,20 @@ impl<'a, T> Contains<'a, T> { } } +#[derive(Debug, Clone, Copy, Error)] +#[error("OneOrMany doesn't contains value")] +pub struct ContainsError; + impl<'a, T> Validator> for Contains<'a, T> where T: PartialEq, { - fn validate(&self, value: &OneOrMany) -> Result<(), anyhow::Error> { + type Error = ContainsError; + fn validate(&self, value: &OneOrMany) -> Result<(), Self::Error> { if value.contains(self.value) { Ok(()) } else { - Err(anyhow::anyhow!("OneOrMany doesn't contain value")) + Err(ContainsError) } } }