1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Refactor token generation a bit

This commit is contained in:
Quentin Gliech
2021-09-23 14:24:44 +02:00
parent 29126e336e
commit a9f1f8bb71
6 changed files with 61 additions and 59 deletions

View File

@ -34,7 +34,7 @@ use crate::{
storage::oauth2::access_token::{
lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup,
},
tokens::{self, TokenFormatError, TokenType},
tokens::{TokenFormatError, TokenType},
};
/// Bearer token authentication failed
@ -89,9 +89,9 @@ async fn authenticate(
auth: Authorization<Bearer>,
) -> Result<OAuth2AccessTokenLookup, Rejection> {
let token = auth.0.token();
let token_type = tokens::check(token).map_err(AuthenticationError::TokenFormat)?;
let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?;
if token_type != tokens::TokenType::AccessToken {
if token_type != TokenType::AccessToken {
return Err(AuthenticationError::WrongTokenType(token_type).into());
}

View File

@ -62,7 +62,7 @@ use crate::{
SessionInfo,
},
templates::{FormPostContext, Templates},
tokens,
tokens::{AccessToken, RefreshToken},
};
#[derive(Deserialize)]
@ -428,8 +428,8 @@ async fn step(
let (access_token, refresh_token) = {
let mut rng = thread_rng();
(
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
AccessToken.generate(&mut rng),
RefreshToken.generate(&mut rng),
)
};

View File

@ -26,7 +26,7 @@ use crate::{
database::connection,
},
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
tokens,
tokens::{self, TokenType},
};
pub fn filter(
@ -70,7 +70,7 @@ async fn introspect(
}
let token = &params.token;
let token_type = tokens::check(token).wrap_error()?;
let token_type = TokenType::check(token).wrap_error()?;
if let Some(hint) = params.token_type_hint {
if token_type != hint {
info!("Token type hint did not match");

View File

@ -50,7 +50,7 @@ use crate::{
authorization_code::{consume_code, lookup_code},
refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token},
},
tokens,
tokens::{AccessToken, RefreshToken},
};
#[skip_serializing_none]
@ -164,8 +164,8 @@ async fn authorization_code_grant(
let (access_token, refresh_token) = {
let mut rng = thread_rng();
(
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
AccessToken.generate(&mut rng),
RefreshToken.generate(&mut rng),
)
};
@ -234,8 +234,8 @@ async fn refresh_token_grant(
let (access_token, refresh_token) = {
let mut rng = thread_rng();
(
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
AccessToken.generate(&mut rng),
RefreshToken.generate(&mut rng),
)
};

View File

@ -18,7 +18,7 @@ use warp::{Filter, Rejection, Reply};
use crate::{
config::OAuth2Config,
filters::authenticate::{recover_unauthorized, authentication},
filters::authenticate::{authentication, recover_unauthorized},
storage::oauth2::access_token::OAuth2AccessTokenLookup,
};

View File

@ -27,6 +27,8 @@ pub enum TokenType {
RefreshToken,
}
pub use TokenType::*;
impl TokenType {
fn prefix(self) -> &'static str {
match self {
@ -42,6 +44,47 @@ impl TokenType {
_ => None,
}
}
pub fn generate(self, rng: impl Rng) -> String {
let random_part: String = rng
.sample_iter(&Alphanumeric)
.take(30)
.map(char::from)
.collect();
let base = format!("{}_{}", self.prefix(), random_part);
let crc = CRC.checksum(base.as_bytes());
let crc = base62_encode(crc);
format!("{}_{}", base, crc)
}
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
let split: Vec<&str> = token.split('_').collect();
let [prefix, random_part, crc]: [&str; 3] = split
.try_into()
.map_err(|_| TokenFormatError::InvalidFormat)?;
if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
return Err(TokenFormatError::InvalidFormat);
}
let token_type =
TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
prefix: prefix.to_string(),
})?;
let base = format!("{}_{}", token_type.prefix(), random_part);
let expected_crc = CRC.checksum(base.as_bytes());
let expected_crc = base62_encode(expected_crc);
if crc != expected_crc {
return Err(TokenFormatError::InvalidCrc {
expected: expected_crc,
got: crc.to_string(),
});
}
Ok(token_type)
}
}
impl PartialEq<TokenTypeHint> for TokenType {
@ -68,19 +111,6 @@ fn base62_encode(mut num: u32) -> String {
const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
pub fn generate(rng: impl Rng, token_type: TokenType) -> String {
let random_part: String = rng
.sample_iter(&Alphanumeric)
.take(30)
.map(char::from)
.collect();
let base = format!("{}_{}", token_type.prefix(), random_part);
let crc = CRC.checksum(base.as_bytes());
let crc = base62_encode(crc);
format!("{}_{}", base, crc)
}
#[derive(Debug, Error)]
pub enum TokenFormatError {
#[error("invalid token format")]
@ -93,34 +123,6 @@ pub enum TokenFormatError {
InvalidCrc { expected: String, got: String },
}
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
let split: Vec<&str> = token.split('_').collect();
let [prefix, random_part, crc]: [&str; 3] = split
.try_into()
.map_err(|_| TokenFormatError::InvalidFormat)?;
if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
return Err(TokenFormatError::InvalidFormat);
}
let token_type =
TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
prefix: prefix.to_string(),
})?;
let base = format!("{}_{}", token_type.prefix(), random_part);
let expected_crc = CRC.checksum(base.as_bytes());
let expected_crc = base62_encode(expected_crc);
if crc != expected_crc {
return Err(TokenFormatError::InvalidCrc {
expected: expected_crc,
got: crc.to_string(),
});
}
Ok(token_type)
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
@ -153,7 +155,7 @@ mod tests {
let mut rng = thread_rng();
// Generate many access tokens
let tokens: HashSet<String> = (0..COUNT)
.map(|_| generate(&mut rng, TokenType::AccessToken))
.map(|_| TokenType::AccessToken.generate(&mut rng))
.collect();
// Check that they are all different
@ -161,18 +163,18 @@ mod tests {
// Check that they are all valid and detected as access tokens
for token in tokens {
assert_eq!(check(&token).unwrap(), TokenType::AccessToken);
assert_eq!(TokenType::check(&token).unwrap(), TokenType::AccessToken);
}
// Same, but for refresh tokens
let tokens: HashSet<String> = (0..COUNT)
.map(|_| generate(&mut rng, TokenType::RefreshToken))
.map(|_| TokenType::RefreshToken.generate(&mut rng))
.collect();
assert_eq!(tokens.len(), COUNT, "All tokens are unique");
for token in tokens {
assert_eq!(check(&token).unwrap(), TokenType::RefreshToken);
assert_eq!(TokenType::check(&token).unwrap(), TokenType::RefreshToken);
}
}
}