diff --git a/crates/core/src/filters/authenticate.rs b/crates/core/src/filters/authenticate.rs index 1af5ebe4..680b8790 100644 --- a/crates/core/src/filters/authenticate.rs +++ b/crates/core/src/filters/authenticate.rs @@ -34,7 +34,7 @@ use crate::{ storage::oauth2::access_token::{ lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup, }, - tokens::{self, TokenFormatError, TokenType}, + tokens::{TokenFormatError, TokenType}, }; /// Bearer token authentication failed @@ -89,9 +89,9 @@ async fn authenticate( auth: Authorization, ) -> Result { let token = auth.0.token(); - let token_type = tokens::check(token).map_err(AuthenticationError::TokenFormat)?; + let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?; - if token_type != tokens::TokenType::AccessToken { + if token_type != TokenType::AccessToken { return Err(AuthenticationError::WrongTokenType(token_type).into()); } diff --git a/crates/core/src/handlers/oauth2/authorization.rs b/crates/core/src/handlers/oauth2/authorization.rs index 03a4db4c..a3b27130 100644 --- a/crates/core/src/handlers/oauth2/authorization.rs +++ b/crates/core/src/handlers/oauth2/authorization.rs @@ -62,7 +62,7 @@ use crate::{ SessionInfo, }, templates::{FormPostContext, Templates}, - tokens, + tokens::{AccessToken, RefreshToken}, }; #[derive(Deserialize)] @@ -428,8 +428,8 @@ async fn step( let (access_token, refresh_token) = { let mut rng = thread_rng(); ( - tokens::generate(&mut rng, tokens::TokenType::AccessToken), - tokens::generate(&mut rng, tokens::TokenType::RefreshToken), + AccessToken.generate(&mut rng), + RefreshToken.generate(&mut rng), ) }; diff --git a/crates/core/src/handlers/oauth2/introspection.rs b/crates/core/src/handlers/oauth2/introspection.rs index 673ca49a..ea2f56e6 100644 --- a/crates/core/src/handlers/oauth2/introspection.rs +++ b/crates/core/src/handlers/oauth2/introspection.rs @@ -26,7 +26,7 @@ use crate::{ database::connection, }, storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token}, - tokens, + tokens::{self, TokenType}, }; pub fn filter( @@ -70,7 +70,7 @@ async fn introspect( } let token = ¶ms.token; - let token_type = tokens::check(token).wrap_error()?; + let token_type = TokenType::check(token).wrap_error()?; if let Some(hint) = params.token_type_hint { if token_type != hint { info!("Token type hint did not match"); diff --git a/crates/core/src/handlers/oauth2/token.rs b/crates/core/src/handlers/oauth2/token.rs index f3526dc7..9f84463d 100644 --- a/crates/core/src/handlers/oauth2/token.rs +++ b/crates/core/src/handlers/oauth2/token.rs @@ -50,7 +50,7 @@ use crate::{ authorization_code::{consume_code, lookup_code}, refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token}, }, - tokens, + tokens::{AccessToken, RefreshToken}, }; #[skip_serializing_none] @@ -164,8 +164,8 @@ async fn authorization_code_grant( let (access_token, refresh_token) = { let mut rng = thread_rng(); ( - tokens::generate(&mut rng, tokens::TokenType::AccessToken), - tokens::generate(&mut rng, tokens::TokenType::RefreshToken), + AccessToken.generate(&mut rng), + RefreshToken.generate(&mut rng), ) }; @@ -234,8 +234,8 @@ async fn refresh_token_grant( let (access_token, refresh_token) = { let mut rng = thread_rng(); ( - tokens::generate(&mut rng, tokens::TokenType::AccessToken), - tokens::generate(&mut rng, tokens::TokenType::RefreshToken), + AccessToken.generate(&mut rng), + RefreshToken.generate(&mut rng), ) }; diff --git a/crates/core/src/handlers/oauth2/userinfo.rs b/crates/core/src/handlers/oauth2/userinfo.rs index fbd08a6d..fdac9ed4 100644 --- a/crates/core/src/handlers/oauth2/userinfo.rs +++ b/crates/core/src/handlers/oauth2/userinfo.rs @@ -18,7 +18,7 @@ use warp::{Filter, Rejection, Reply}; use crate::{ config::OAuth2Config, - filters::authenticate::{recover_unauthorized, authentication}, + filters::authenticate::{authentication, recover_unauthorized}, storage::oauth2::access_token::OAuth2AccessTokenLookup, }; diff --git a/crates/core/src/tokens.rs b/crates/core/src/tokens.rs index 69fb8bb7..d52a5a1d 100644 --- a/crates/core/src/tokens.rs +++ b/crates/core/src/tokens.rs @@ -27,6 +27,8 @@ pub enum TokenType { RefreshToken, } +pub use TokenType::*; + impl TokenType { fn prefix(self) -> &'static str { match self { @@ -42,6 +44,47 @@ impl TokenType { _ => None, } } + + pub fn generate(self, rng: impl Rng) -> String { + let random_part: String = rng + .sample_iter(&Alphanumeric) + .take(30) + .map(char::from) + .collect(); + + let base = format!("{}_{}", self.prefix(), random_part); + let crc = CRC.checksum(base.as_bytes()); + let crc = base62_encode(crc); + format!("{}_{}", base, crc) + } + + pub fn check(token: &str) -> Result { + let split: Vec<&str> = token.split('_').collect(); + let [prefix, random_part, crc]: [&str; 3] = split + .try_into() + .map_err(|_| TokenFormatError::InvalidFormat)?; + + if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 { + return Err(TokenFormatError::InvalidFormat); + } + + let token_type = + TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix { + prefix: prefix.to_string(), + })?; + + let base = format!("{}_{}", token_type.prefix(), random_part); + let expected_crc = CRC.checksum(base.as_bytes()); + let expected_crc = base62_encode(expected_crc); + if crc != expected_crc { + return Err(TokenFormatError::InvalidCrc { + expected: expected_crc, + got: crc.to_string(), + }); + } + + Ok(token_type) + } } impl PartialEq for TokenType { @@ -68,19 +111,6 @@ fn base62_encode(mut num: u32) -> String { const CRC: Crc = Crc::::new(&CRC_32_ISO_HDLC); -pub fn generate(rng: impl Rng, token_type: TokenType) -> String { - let random_part: String = rng - .sample_iter(&Alphanumeric) - .take(30) - .map(char::from) - .collect(); - - let base = format!("{}_{}", token_type.prefix(), random_part); - let crc = CRC.checksum(base.as_bytes()); - let crc = base62_encode(crc); - format!("{}_{}", base, crc) -} - #[derive(Debug, Error)] pub enum TokenFormatError { #[error("invalid token format")] @@ -93,34 +123,6 @@ pub enum TokenFormatError { InvalidCrc { expected: String, got: String }, } -pub fn check(token: &str) -> Result { - let split: Vec<&str> = token.split('_').collect(); - let [prefix, random_part, crc]: [&str; 3] = split - .try_into() - .map_err(|_| TokenFormatError::InvalidFormat)?; - - if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 { - return Err(TokenFormatError::InvalidFormat); - } - - let token_type = - TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix { - prefix: prefix.to_string(), - })?; - - let base = format!("{}_{}", token_type.prefix(), random_part); - let expected_crc = CRC.checksum(base.as_bytes()); - let expected_crc = base62_encode(expected_crc); - if crc != expected_crc { - return Err(TokenFormatError::InvalidCrc { - expected: expected_crc, - got: crc.to_string(), - }); - } - - Ok(token_type) -} - #[cfg(test)] mod tests { use std::collections::HashSet; @@ -153,7 +155,7 @@ mod tests { let mut rng = thread_rng(); // Generate many access tokens let tokens: HashSet = (0..COUNT) - .map(|_| generate(&mut rng, TokenType::AccessToken)) + .map(|_| TokenType::AccessToken.generate(&mut rng)) .collect(); // Check that they are all different @@ -161,18 +163,18 @@ mod tests { // Check that they are all valid and detected as access tokens for token in tokens { - assert_eq!(check(&token).unwrap(), TokenType::AccessToken); + assert_eq!(TokenType::check(&token).unwrap(), TokenType::AccessToken); } // Same, but for refresh tokens let tokens: HashSet = (0..COUNT) - .map(|_| generate(&mut rng, TokenType::RefreshToken)) + .map(|_| TokenType::RefreshToken.generate(&mut rng)) .collect(); assert_eq!(tokens.len(), COUNT, "All tokens are unique"); for token in tokens { - assert_eq!(check(&token).unwrap(), TokenType::RefreshToken); + assert_eq!(TokenType::check(&token).unwrap(), TokenType::RefreshToken); } } }