You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Refactor token generation a bit
This commit is contained in:
@ -34,7 +34,7 @@ use crate::{
|
||||
storage::oauth2::access_token::{
|
||||
lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup,
|
||||
},
|
||||
tokens::{self, TokenFormatError, TokenType},
|
||||
tokens::{TokenFormatError, TokenType},
|
||||
};
|
||||
|
||||
/// Bearer token authentication failed
|
||||
@ -89,9 +89,9 @@ async fn authenticate(
|
||||
auth: Authorization<Bearer>,
|
||||
) -> Result<OAuth2AccessTokenLookup, Rejection> {
|
||||
let token = auth.0.token();
|
||||
let token_type = tokens::check(token).map_err(AuthenticationError::TokenFormat)?;
|
||||
let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?;
|
||||
|
||||
if token_type != tokens::TokenType::AccessToken {
|
||||
if token_type != TokenType::AccessToken {
|
||||
return Err(AuthenticationError::WrongTokenType(token_type).into());
|
||||
}
|
||||
|
||||
|
@ -62,7 +62,7 @@ use crate::{
|
||||
SessionInfo,
|
||||
},
|
||||
templates::{FormPostContext, Templates},
|
||||
tokens,
|
||||
tokens::{AccessToken, RefreshToken},
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@ -428,8 +428,8 @@ async fn step(
|
||||
let (access_token, refresh_token) = {
|
||||
let mut rng = thread_rng();
|
||||
(
|
||||
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
|
||||
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
|
||||
AccessToken.generate(&mut rng),
|
||||
RefreshToken.generate(&mut rng),
|
||||
)
|
||||
};
|
||||
|
||||
|
@ -26,7 +26,7 @@ use crate::{
|
||||
database::connection,
|
||||
},
|
||||
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
|
||||
tokens,
|
||||
tokens::{self, TokenType},
|
||||
};
|
||||
|
||||
pub fn filter(
|
||||
@ -70,7 +70,7 @@ async fn introspect(
|
||||
}
|
||||
|
||||
let token = ¶ms.token;
|
||||
let token_type = tokens::check(token).wrap_error()?;
|
||||
let token_type = TokenType::check(token).wrap_error()?;
|
||||
if let Some(hint) = params.token_type_hint {
|
||||
if token_type != hint {
|
||||
info!("Token type hint did not match");
|
||||
|
@ -50,7 +50,7 @@ use crate::{
|
||||
authorization_code::{consume_code, lookup_code},
|
||||
refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token},
|
||||
},
|
||||
tokens,
|
||||
tokens::{AccessToken, RefreshToken},
|
||||
};
|
||||
|
||||
#[skip_serializing_none]
|
||||
@ -164,8 +164,8 @@ async fn authorization_code_grant(
|
||||
let (access_token, refresh_token) = {
|
||||
let mut rng = thread_rng();
|
||||
(
|
||||
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
|
||||
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
|
||||
AccessToken.generate(&mut rng),
|
||||
RefreshToken.generate(&mut rng),
|
||||
)
|
||||
};
|
||||
|
||||
@ -234,8 +234,8 @@ async fn refresh_token_grant(
|
||||
let (access_token, refresh_token) = {
|
||||
let mut rng = thread_rng();
|
||||
(
|
||||
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
|
||||
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
|
||||
AccessToken.generate(&mut rng),
|
||||
RefreshToken.generate(&mut rng),
|
||||
)
|
||||
};
|
||||
|
||||
|
@ -18,7 +18,7 @@ use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::OAuth2Config,
|
||||
filters::authenticate::{recover_unauthorized, authentication},
|
||||
filters::authenticate::{authentication, recover_unauthorized},
|
||||
storage::oauth2::access_token::OAuth2AccessTokenLookup,
|
||||
};
|
||||
|
||||
|
@ -27,6 +27,8 @@ pub enum TokenType {
|
||||
RefreshToken,
|
||||
}
|
||||
|
||||
pub use TokenType::*;
|
||||
|
||||
impl TokenType {
|
||||
fn prefix(self) -> &'static str {
|
||||
match self {
|
||||
@ -42,6 +44,47 @@ impl TokenType {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate(self, rng: impl Rng) -> String {
|
||||
let random_part: String = rng
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(30)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let base = format!("{}_{}", self.prefix(), random_part);
|
||||
let crc = CRC.checksum(base.as_bytes());
|
||||
let crc = base62_encode(crc);
|
||||
format!("{}_{}", base, crc)
|
||||
}
|
||||
|
||||
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
|
||||
let split: Vec<&str> = token.split('_').collect();
|
||||
let [prefix, random_part, crc]: [&str; 3] = split
|
||||
.try_into()
|
||||
.map_err(|_| TokenFormatError::InvalidFormat)?;
|
||||
|
||||
if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
|
||||
return Err(TokenFormatError::InvalidFormat);
|
||||
}
|
||||
|
||||
let token_type =
|
||||
TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
|
||||
prefix: prefix.to_string(),
|
||||
})?;
|
||||
|
||||
let base = format!("{}_{}", token_type.prefix(), random_part);
|
||||
let expected_crc = CRC.checksum(base.as_bytes());
|
||||
let expected_crc = base62_encode(expected_crc);
|
||||
if crc != expected_crc {
|
||||
return Err(TokenFormatError::InvalidCrc {
|
||||
expected: expected_crc,
|
||||
got: crc.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(token_type)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<TokenTypeHint> for TokenType {
|
||||
@ -68,19 +111,6 @@ fn base62_encode(mut num: u32) -> String {
|
||||
|
||||
const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
|
||||
|
||||
pub fn generate(rng: impl Rng, token_type: TokenType) -> String {
|
||||
let random_part: String = rng
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(30)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let base = format!("{}_{}", token_type.prefix(), random_part);
|
||||
let crc = CRC.checksum(base.as_bytes());
|
||||
let crc = base62_encode(crc);
|
||||
format!("{}_{}", base, crc)
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TokenFormatError {
|
||||
#[error("invalid token format")]
|
||||
@ -93,34 +123,6 @@ pub enum TokenFormatError {
|
||||
InvalidCrc { expected: String, got: String },
|
||||
}
|
||||
|
||||
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
|
||||
let split: Vec<&str> = token.split('_').collect();
|
||||
let [prefix, random_part, crc]: [&str; 3] = split
|
||||
.try_into()
|
||||
.map_err(|_| TokenFormatError::InvalidFormat)?;
|
||||
|
||||
if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
|
||||
return Err(TokenFormatError::InvalidFormat);
|
||||
}
|
||||
|
||||
let token_type =
|
||||
TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
|
||||
prefix: prefix.to_string(),
|
||||
})?;
|
||||
|
||||
let base = format!("{}_{}", token_type.prefix(), random_part);
|
||||
let expected_crc = CRC.checksum(base.as_bytes());
|
||||
let expected_crc = base62_encode(expected_crc);
|
||||
if crc != expected_crc {
|
||||
return Err(TokenFormatError::InvalidCrc {
|
||||
expected: expected_crc,
|
||||
got: crc.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(token_type)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashSet;
|
||||
@ -153,7 +155,7 @@ mod tests {
|
||||
let mut rng = thread_rng();
|
||||
// Generate many access tokens
|
||||
let tokens: HashSet<String> = (0..COUNT)
|
||||
.map(|_| generate(&mut rng, TokenType::AccessToken))
|
||||
.map(|_| TokenType::AccessToken.generate(&mut rng))
|
||||
.collect();
|
||||
|
||||
// Check that they are all different
|
||||
@ -161,18 +163,18 @@ mod tests {
|
||||
|
||||
// Check that they are all valid and detected as access tokens
|
||||
for token in tokens {
|
||||
assert_eq!(check(&token).unwrap(), TokenType::AccessToken);
|
||||
assert_eq!(TokenType::check(&token).unwrap(), TokenType::AccessToken);
|
||||
}
|
||||
|
||||
// Same, but for refresh tokens
|
||||
let tokens: HashSet<String> = (0..COUNT)
|
||||
.map(|_| generate(&mut rng, TokenType::RefreshToken))
|
||||
.map(|_| TokenType::RefreshToken.generate(&mut rng))
|
||||
.collect();
|
||||
|
||||
assert_eq!(tokens.len(), COUNT, "All tokens are unique");
|
||||
|
||||
for token in tokens {
|
||||
assert_eq!(check(&token).unwrap(), TokenType::RefreshToken);
|
||||
assert_eq!(TokenType::check(&token).unwrap(), TokenType::RefreshToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user