diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 0475878e..a7539293 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -36,7 +36,7 @@ pub use self::{ }, oauth2::{ AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, - InvalidRedirectUriError, JwksOrJwksUri, Pkce, PkceVerificationError, Session, + InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, }, tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, traits::{StorageBackend, StorageBackendMarker}, diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index 4b73cb51..cf69f07b 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -16,7 +16,10 @@ use std::num::NonZeroU32; use chrono::{DateTime, Duration, Utc}; use mas_iana::oauth::PkceCodeChallengeMethod; -use oauth2_types::{pkce::CodeChallengeMethodExt, requests::ResponseMode}; +use oauth2_types::{ + pkce::{CodeChallengeError, CodeChallengeMethodExt}, + requests::ResponseMode, +}; use serde::Serialize; use thiserror::Error; use url::Url; @@ -24,21 +27,6 @@ use url::Url; use super::{client::Client, session::Session}; use crate::{traits::StorageBackend, StorageBackendMarker}; -#[derive(Debug, Error, PartialEq)] -pub enum PkceVerificationError { - #[error("code_verifier should be at least 43 characters long")] - TooShort, - - #[error("code_verifier should be at most 128 characters long")] - TooLong, - - #[error("code_verifier contains invalid characters")] - InvalidCharacters, - - #[error("challenge verification failed")] - VerificationFailed, -} - #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Pkce { pub challenge_method: PkceCodeChallengeMethod, @@ -54,27 +42,8 @@ impl Pkce { } } - pub fn verify(&self, verifier: &str) -> Result<(), PkceVerificationError> { - if verifier.len() < 43 { - return Err(PkceVerificationError::TooShort); - } - - if verifier.len() > 128 { - return Err(PkceVerificationError::TooLong); - } - - if !verifier - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~') - { - return Err(PkceVerificationError::InvalidCharacters); - } - - if !self.challenge_method.verify(&self.challenge, verifier) { - return Err(PkceVerificationError::VerificationFailed); - } - - Ok(()) + pub fn verify(&self, verifier: &str) -> Result<(), CodeChallengeError> { + self.challenge_method.verify(&self.challenge, verifier) } } @@ -238,42 +207,3 @@ impl AuthorizationGrant { self.created_at - Duration::seconds(max_age.unwrap_or(3600 * 24 * 365)) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_pkce_verification() { - // This challenge is taken from the RFC7636 appendices - let pkce = Pkce::new( - PkceCodeChallengeMethod::S256, - "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM".to_string(), - ); - - assert_eq!( - pkce.verify("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"), - Ok(()), - ); - - assert_eq!( - pkce.verify("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"), - Err(PkceVerificationError::VerificationFailed), - ); - - assert_eq!( - pkce.verify("tooshort"), - Err(PkceVerificationError::TooShort), - ); - - assert_eq!( - pkce.verify("toolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolong"), - Err(PkceVerificationError::TooLong), - ); - - assert_eq!( - pkce.verify("this is long enough but has invalid characters in it"), - Err(PkceVerificationError::InvalidCharacters), - ); - } -} diff --git a/crates/data-model/src/oauth2/mod.rs b/crates/data-model/src/oauth2/mod.rs index fdf11254..ef512260 100644 --- a/crates/data-model/src/oauth2/mod.rs +++ b/crates/data-model/src/oauth2/mod.rs @@ -17,9 +17,7 @@ pub(self) mod client; pub(self) mod session; pub use self::{ - authorization_grant::{ - AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce, PkceVerificationError, - }, + authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce}, client::{Client, InvalidRedirectUriError, JwksOrJwksUri}, session::Session, }; diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 2d6d2fc4..8cc8d965 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -22,7 +22,7 @@ use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma}; use hyper::StatusCode; use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerificationError}; use mas_config::Encrypter; -use mas_data_model::{AuthorizationGrantStage, Client, PkceVerificationError, TokenType}; +use mas_data_model::{AuthorizationGrantStage, Client, TokenType}; use mas_iana::jose::JsonWebSignatureAlg; use mas_jose::{ claims::{self, ClaimError}, @@ -44,6 +44,7 @@ use mas_storage::{ }; use oauth2_types::{ errors::{INVALID_CLIENT, INVALID_GRANT, INVALID_REQUEST, SERVER_ERROR, UNAUTHORIZED_CLIENT}, + pkce::CodeChallengeError, requests::{ AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant, }, @@ -87,7 +88,7 @@ pub(crate) enum RouteError { BadRequest, #[error("pkce verification failed")] - PkceVerification(#[from] PkceVerificationError), + PkceVerification(#[from] CodeChallengeError), #[error("client not found")] ClientNotFound, diff --git a/crates/oauth2-types/src/pkce.rs b/crates/oauth2-types/src/pkce.rs index 980f4a4f..0dc29ad0 100644 --- a/crates/oauth2-types/src/pkce.rs +++ b/crates/oauth2-types/src/pkce.rs @@ -18,18 +18,75 @@ use data_encoding::BASE64URL_NOPAD; use mas_iana::oauth::PkceCodeChallengeMethod; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; +use thiserror::Error; + +#[derive(Debug, Error, PartialEq)] +pub enum CodeChallengeError { + #[error("code_verifier should be at least 43 characters long")] + TooShort, + + #[error("code_verifier should be at most 128 characters long")] + TooLong, + + #[error("code_verifier contains invalid characters")] + InvalidCharacters, + + #[error("challenge verification failed")] + VerificationFailed, +} + +fn validate_verifier(verifier: &str) -> Result<(), CodeChallengeError> { + if verifier.len() < 43 { + return Err(CodeChallengeError::TooShort); + } + + if verifier.len() > 128 { + return Err(CodeChallengeError::TooLong); + } + + if !verifier + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~') + { + return Err(CodeChallengeError::InvalidCharacters); + } + + Ok(()) +} pub trait CodeChallengeMethodExt { - #[must_use] - fn compute_challenge(self, verifier: &str) -> Cow<'_, str>; + /// Compute the challenge for a given verifier + /// + /// # Errors + /// + /// Returns an error if the verifier did not adhere to the rules defined by + /// the RFC in terms of length and allowed characters + fn compute_challenge(self, verifier: &str) -> Result, CodeChallengeError>; - #[must_use] - fn verify(self, challenge: &str, verifier: &str) -> bool; + /// Verify that a given verifier is valid for the given challenge + /// + /// # Errors + /// + /// Returns an error if the verifier did not match the challenge, or if the + /// verifier did not adhere to the rules defined by the RFC in terms of + /// length and allowed characters + fn verify(self, challenge: &str, verifier: &str) -> Result<(), CodeChallengeError> + where + Self: Sized, + { + if self.compute_challenge(verifier)? == challenge { + Ok(()) + } else { + Err(CodeChallengeError::VerificationFailed) + } + } } impl CodeChallengeMethodExt for PkceCodeChallengeMethod { - fn compute_challenge(self, verifier: &str) -> Cow<'_, str> { - match self { + fn compute_challenge(self, verifier: &str) -> Result, CodeChallengeError> { + validate_verifier(verifier)?; + + let challenge = match self { Self::Plain => verifier.into(), Self::S256 => { let mut hasher = Sha256::new(); @@ -38,11 +95,9 @@ impl CodeChallengeMethodExt for PkceCodeChallengeMethod { let verifier = BASE64URL_NOPAD.encode(&hash); verifier.into() } - } - } + }; - fn verify(self, challenge: &str, verifier: &str) -> bool { - self.compute_challenge(verifier) == challenge + Ok(challenge) } } @@ -56,3 +111,44 @@ pub struct AuthorizationRequest { pub struct TokenRequest { pub code_challenge_verifier: String, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pkce_verification() { + use PkceCodeChallengeMethod::{Plain, S256}; + // This challenge comes from the RFC7636 appendices + let challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; + + assert!(S256 + .verify(challenge, "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk") + .is_ok()); + + assert!(Plain.verify(challenge, challenge).is_ok()); + + assert_eq!( + S256.verify(challenge, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"), + Err(CodeChallengeError::VerificationFailed), + ); + + assert_eq!( + S256.verify(challenge, "tooshort"), + Err(CodeChallengeError::TooShort), + ); + + assert_eq!( + S256.verify(challenge, "toolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolong"), + Err(CodeChallengeError::TooLong), + ); + + assert_eq!( + S256.verify( + challenge, + "this is long enough but has invalid characters in it" + ), + Err(CodeChallengeError::InvalidCharacters), + ); + } +}