1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Refactor token generation a bit

This commit is contained in:
Quentin Gliech
2021-09-23 14:24:44 +02:00
parent 29126e336e
commit a9f1f8bb71
6 changed files with 61 additions and 59 deletions

View File

@@ -34,7 +34,7 @@ use crate::{
storage::oauth2::access_token::{ storage::oauth2::access_token::{
lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup, lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup,
}, },
tokens::{self, TokenFormatError, TokenType}, tokens::{TokenFormatError, TokenType},
}; };
/// Bearer token authentication failed /// Bearer token authentication failed
@@ -89,9 +89,9 @@ async fn authenticate(
auth: Authorization<Bearer>, auth: Authorization<Bearer>,
) -> Result<OAuth2AccessTokenLookup, Rejection> { ) -> Result<OAuth2AccessTokenLookup, Rejection> {
let token = auth.0.token(); let token = auth.0.token();
let token_type = tokens::check(token).map_err(AuthenticationError::TokenFormat)?; let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?;
if token_type != tokens::TokenType::AccessToken { if token_type != TokenType::AccessToken {
return Err(AuthenticationError::WrongTokenType(token_type).into()); return Err(AuthenticationError::WrongTokenType(token_type).into());
} }

View File

@@ -62,7 +62,7 @@ use crate::{
SessionInfo, SessionInfo,
}, },
templates::{FormPostContext, Templates}, templates::{FormPostContext, Templates},
tokens, tokens::{AccessToken, RefreshToken},
}; };
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -428,8 +428,8 @@ async fn step(
let (access_token, refresh_token) = { let (access_token, refresh_token) = {
let mut rng = thread_rng(); let mut rng = thread_rng();
( (
tokens::generate(&mut rng, tokens::TokenType::AccessToken), AccessToken.generate(&mut rng),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken), RefreshToken.generate(&mut rng),
) )
}; };

View File

@@ -26,7 +26,7 @@ use crate::{
database::connection, database::connection,
}, },
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token}, storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
tokens, tokens::{self, TokenType},
}; };
pub fn filter( pub fn filter(
@@ -70,7 +70,7 @@ async fn introspect(
} }
let token = &params.token; let token = &params.token;
let token_type = tokens::check(token).wrap_error()?; let token_type = TokenType::check(token).wrap_error()?;
if let Some(hint) = params.token_type_hint { if let Some(hint) = params.token_type_hint {
if token_type != hint { if token_type != hint {
info!("Token type hint did not match"); info!("Token type hint did not match");

View File

@@ -50,7 +50,7 @@ use crate::{
authorization_code::{consume_code, lookup_code}, authorization_code::{consume_code, lookup_code},
refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token}, refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token},
}, },
tokens, tokens::{AccessToken, RefreshToken},
}; };
#[skip_serializing_none] #[skip_serializing_none]
@@ -164,8 +164,8 @@ async fn authorization_code_grant(
let (access_token, refresh_token) = { let (access_token, refresh_token) = {
let mut rng = thread_rng(); let mut rng = thread_rng();
( (
tokens::generate(&mut rng, tokens::TokenType::AccessToken), AccessToken.generate(&mut rng),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken), RefreshToken.generate(&mut rng),
) )
}; };
@@ -234,8 +234,8 @@ async fn refresh_token_grant(
let (access_token, refresh_token) = { let (access_token, refresh_token) = {
let mut rng = thread_rng(); let mut rng = thread_rng();
( (
tokens::generate(&mut rng, tokens::TokenType::AccessToken), AccessToken.generate(&mut rng),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken), RefreshToken.generate(&mut rng),
) )
}; };

View File

@@ -18,7 +18,7 @@ use warp::{Filter, Rejection, Reply};
use crate::{ use crate::{
config::OAuth2Config, config::OAuth2Config,
filters::authenticate::{recover_unauthorized, authentication}, filters::authenticate::{authentication, recover_unauthorized},
storage::oauth2::access_token::OAuth2AccessTokenLookup, storage::oauth2::access_token::OAuth2AccessTokenLookup,
}; };

View File

@@ -27,6 +27,8 @@ pub enum TokenType {
RefreshToken, RefreshToken,
} }
pub use TokenType::*;
impl TokenType { impl TokenType {
fn prefix(self) -> &'static str { fn prefix(self) -> &'static str {
match self { match self {
@@ -42,6 +44,47 @@ impl TokenType {
_ => None, _ => None,
} }
} }
pub fn generate(self, rng: impl Rng) -> String {
let random_part: String = rng
.sample_iter(&Alphanumeric)
.take(30)
.map(char::from)
.collect();
let base = format!("{}_{}", self.prefix(), random_part);
let crc = CRC.checksum(base.as_bytes());
let crc = base62_encode(crc);
format!("{}_{}", base, crc)
}
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
let split: Vec<&str> = token.split('_').collect();
let [prefix, random_part, crc]: [&str; 3] = split
.try_into()
.map_err(|_| TokenFormatError::InvalidFormat)?;
if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
return Err(TokenFormatError::InvalidFormat);
}
let token_type =
TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
prefix: prefix.to_string(),
})?;
let base = format!("{}_{}", token_type.prefix(), random_part);
let expected_crc = CRC.checksum(base.as_bytes());
let expected_crc = base62_encode(expected_crc);
if crc != expected_crc {
return Err(TokenFormatError::InvalidCrc {
expected: expected_crc,
got: crc.to_string(),
});
}
Ok(token_type)
}
} }
impl PartialEq<TokenTypeHint> for TokenType { impl PartialEq<TokenTypeHint> for TokenType {
@@ -68,19 +111,6 @@ fn base62_encode(mut num: u32) -> String {
const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC); const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
pub fn generate(rng: impl Rng, token_type: TokenType) -> String {
let random_part: String = rng
.sample_iter(&Alphanumeric)
.take(30)
.map(char::from)
.collect();
let base = format!("{}_{}", token_type.prefix(), random_part);
let crc = CRC.checksum(base.as_bytes());
let crc = base62_encode(crc);
format!("{}_{}", base, crc)
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum TokenFormatError { pub enum TokenFormatError {
#[error("invalid token format")] #[error("invalid token format")]
@@ -93,34 +123,6 @@ pub enum TokenFormatError {
InvalidCrc { expected: String, got: String }, InvalidCrc { expected: String, got: String },
} }
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
let split: Vec<&str> = token.split('_').collect();
let [prefix, random_part, crc]: [&str; 3] = split
.try_into()
.map_err(|_| TokenFormatError::InvalidFormat)?;
if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
return Err(TokenFormatError::InvalidFormat);
}
let token_type =
TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
prefix: prefix.to_string(),
})?;
let base = format!("{}_{}", token_type.prefix(), random_part);
let expected_crc = CRC.checksum(base.as_bytes());
let expected_crc = base62_encode(expected_crc);
if crc != expected_crc {
return Err(TokenFormatError::InvalidCrc {
expected: expected_crc,
got: crc.to_string(),
});
}
Ok(token_type)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::HashSet; use std::collections::HashSet;
@@ -153,7 +155,7 @@ mod tests {
let mut rng = thread_rng(); let mut rng = thread_rng();
// Generate many access tokens // Generate many access tokens
let tokens: HashSet<String> = (0..COUNT) let tokens: HashSet<String> = (0..COUNT)
.map(|_| generate(&mut rng, TokenType::AccessToken)) .map(|_| TokenType::AccessToken.generate(&mut rng))
.collect(); .collect();
// Check that they are all different // Check that they are all different
@@ -161,18 +163,18 @@ mod tests {
// Check that they are all valid and detected as access tokens // Check that they are all valid and detected as access tokens
for token in tokens { for token in tokens {
assert_eq!(check(&token).unwrap(), TokenType::AccessToken); assert_eq!(TokenType::check(&token).unwrap(), TokenType::AccessToken);
} }
// Same, but for refresh tokens // Same, but for refresh tokens
let tokens: HashSet<String> = (0..COUNT) let tokens: HashSet<String> = (0..COUNT)
.map(|_| generate(&mut rng, TokenType::RefreshToken)) .map(|_| TokenType::RefreshToken.generate(&mut rng))
.collect(); .collect();
assert_eq!(tokens.len(), COUNT, "All tokens are unique"); assert_eq!(tokens.len(), COUNT, "All tokens are unique");
for token in tokens { for token in tokens {
assert_eq!(check(&token).unwrap(), TokenType::RefreshToken); assert_eq!(TokenType::check(&token).unwrap(), TokenType::RefreshToken);
} }
} }
} }