You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-07 17:03:01 +03:00
Refactor token generation a bit
This commit is contained in:
@@ -34,7 +34,7 @@ use crate::{
|
|||||||
storage::oauth2::access_token::{
|
storage::oauth2::access_token::{
|
||||||
lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup,
|
lookup_access_token, AccessTokenLookupError, OAuth2AccessTokenLookup,
|
||||||
},
|
},
|
||||||
tokens::{self, TokenFormatError, TokenType},
|
tokens::{TokenFormatError, TokenType},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Bearer token authentication failed
|
/// Bearer token authentication failed
|
||||||
@@ -89,9 +89,9 @@ async fn authenticate(
|
|||||||
auth: Authorization<Bearer>,
|
auth: Authorization<Bearer>,
|
||||||
) -> Result<OAuth2AccessTokenLookup, Rejection> {
|
) -> Result<OAuth2AccessTokenLookup, Rejection> {
|
||||||
let token = auth.0.token();
|
let token = auth.0.token();
|
||||||
let token_type = tokens::check(token).map_err(AuthenticationError::TokenFormat)?;
|
let token_type = TokenType::check(token).map_err(AuthenticationError::TokenFormat)?;
|
||||||
|
|
||||||
if token_type != tokens::TokenType::AccessToken {
|
if token_type != TokenType::AccessToken {
|
||||||
return Err(AuthenticationError::WrongTokenType(token_type).into());
|
return Err(AuthenticationError::WrongTokenType(token_type).into());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -62,7 +62,7 @@ use crate::{
|
|||||||
SessionInfo,
|
SessionInfo,
|
||||||
},
|
},
|
||||||
templates::{FormPostContext, Templates},
|
templates::{FormPostContext, Templates},
|
||||||
tokens,
|
tokens::{AccessToken, RefreshToken},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@@ -428,8 +428,8 @@ async fn step(
|
|||||||
let (access_token, refresh_token) = {
|
let (access_token, refresh_token) = {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
(
|
(
|
||||||
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
|
AccessToken.generate(&mut rng),
|
||||||
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
|
RefreshToken.generate(&mut rng),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -26,7 +26,7 @@ use crate::{
|
|||||||
database::connection,
|
database::connection,
|
||||||
},
|
},
|
||||||
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
|
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
|
||||||
tokens,
|
tokens::{self, TokenType},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn filter(
|
pub fn filter(
|
||||||
@@ -70,7 +70,7 @@ async fn introspect(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let token = ¶ms.token;
|
let token = ¶ms.token;
|
||||||
let token_type = tokens::check(token).wrap_error()?;
|
let token_type = TokenType::check(token).wrap_error()?;
|
||||||
if let Some(hint) = params.token_type_hint {
|
if let Some(hint) = params.token_type_hint {
|
||||||
if token_type != hint {
|
if token_type != hint {
|
||||||
info!("Token type hint did not match");
|
info!("Token type hint did not match");
|
||||||
|
@@ -50,7 +50,7 @@ use crate::{
|
|||||||
authorization_code::{consume_code, lookup_code},
|
authorization_code::{consume_code, lookup_code},
|
||||||
refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token},
|
refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token},
|
||||||
},
|
},
|
||||||
tokens,
|
tokens::{AccessToken, RefreshToken},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[skip_serializing_none]
|
#[skip_serializing_none]
|
||||||
@@ -164,8 +164,8 @@ async fn authorization_code_grant(
|
|||||||
let (access_token, refresh_token) = {
|
let (access_token, refresh_token) = {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
(
|
(
|
||||||
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
|
AccessToken.generate(&mut rng),
|
||||||
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
|
RefreshToken.generate(&mut rng),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -234,8 +234,8 @@ async fn refresh_token_grant(
|
|||||||
let (access_token, refresh_token) = {
|
let (access_token, refresh_token) = {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
(
|
(
|
||||||
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
|
AccessToken.generate(&mut rng),
|
||||||
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
|
RefreshToken.generate(&mut rng),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -18,7 +18,7 @@ use warp::{Filter, Rejection, Reply};
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
config::OAuth2Config,
|
config::OAuth2Config,
|
||||||
filters::authenticate::{recover_unauthorized, authentication},
|
filters::authenticate::{authentication, recover_unauthorized},
|
||||||
storage::oauth2::access_token::OAuth2AccessTokenLookup,
|
storage::oauth2::access_token::OAuth2AccessTokenLookup,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -27,6 +27,8 @@ pub enum TokenType {
|
|||||||
RefreshToken,
|
RefreshToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub use TokenType::*;
|
||||||
|
|
||||||
impl TokenType {
|
impl TokenType {
|
||||||
fn prefix(self) -> &'static str {
|
fn prefix(self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
@@ -42,57 +44,20 @@ impl TokenType {
|
|||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialEq<TokenTypeHint> for TokenType {
|
pub fn generate(self, rng: impl Rng) -> String {
|
||||||
fn eq(&self, other: &TokenTypeHint) -> bool {
|
|
||||||
matches!(
|
|
||||||
(self, other),
|
|
||||||
(TokenType::AccessToken, TokenTypeHint::AccessToken)
|
|
||||||
| (TokenType::RefreshToken, TokenTypeHint::RefreshToken)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const NUM: [u8; 62] = *b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
|
|
||||||
|
|
||||||
fn base62_encode(mut num: u32) -> String {
|
|
||||||
let mut res = String::with_capacity(6);
|
|
||||||
while num > 0 {
|
|
||||||
res.push(NUM[(num % 62) as usize] as char);
|
|
||||||
num /= 62;
|
|
||||||
}
|
|
||||||
|
|
||||||
format!("{:0>6}", res)
|
|
||||||
}
|
|
||||||
|
|
||||||
const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
|
|
||||||
|
|
||||||
pub fn generate(rng: impl Rng, token_type: TokenType) -> String {
|
|
||||||
let random_part: String = rng
|
let random_part: String = rng
|
||||||
.sample_iter(&Alphanumeric)
|
.sample_iter(&Alphanumeric)
|
||||||
.take(30)
|
.take(30)
|
||||||
.map(char::from)
|
.map(char::from)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let base = format!("{}_{}", token_type.prefix(), random_part);
|
let base = format!("{}_{}", self.prefix(), random_part);
|
||||||
let crc = CRC.checksum(base.as_bytes());
|
let crc = CRC.checksum(base.as_bytes());
|
||||||
let crc = base62_encode(crc);
|
let crc = base62_encode(crc);
|
||||||
format!("{}_{}", base, crc)
|
format!("{}_{}", base, crc)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum TokenFormatError {
|
|
||||||
#[error("invalid token format")]
|
|
||||||
InvalidFormat,
|
|
||||||
|
|
||||||
#[error("unknown token prefix {prefix:?}")]
|
|
||||||
UnknownPrefix { prefix: String },
|
|
||||||
|
|
||||||
#[error("invalid crc {got:?}, expected {expected:?}")]
|
|
||||||
InvalidCrc { expected: String, got: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
|
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
|
||||||
let split: Vec<&str> = token.split('_').collect();
|
let split: Vec<&str> = token.split('_').collect();
|
||||||
let [prefix, random_part, crc]: [&str; 3] = split
|
let [prefix, random_part, crc]: [&str; 3] = split
|
||||||
@@ -120,6 +85,43 @@ pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
|
|||||||
|
|
||||||
Ok(token_type)
|
Ok(token_type)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq<TokenTypeHint> for TokenType {
|
||||||
|
fn eq(&self, other: &TokenTypeHint) -> bool {
|
||||||
|
matches!(
|
||||||
|
(self, other),
|
||||||
|
(TokenType::AccessToken, TokenTypeHint::AccessToken)
|
||||||
|
| (TokenType::RefreshToken, TokenTypeHint::RefreshToken)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const NUM: [u8; 62] = *b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
|
||||||
|
|
||||||
|
fn base62_encode(mut num: u32) -> String {
|
||||||
|
let mut res = String::with_capacity(6);
|
||||||
|
while num > 0 {
|
||||||
|
res.push(NUM[(num % 62) as usize] as char);
|
||||||
|
num /= 62;
|
||||||
|
}
|
||||||
|
|
||||||
|
format!("{:0>6}", res)
|
||||||
|
}
|
||||||
|
|
||||||
|
const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum TokenFormatError {
|
||||||
|
#[error("invalid token format")]
|
||||||
|
InvalidFormat,
|
||||||
|
|
||||||
|
#[error("unknown token prefix {prefix:?}")]
|
||||||
|
UnknownPrefix { prefix: String },
|
||||||
|
|
||||||
|
#[error("invalid crc {got:?}, expected {expected:?}")]
|
||||||
|
InvalidCrc { expected: String, got: String },
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
@@ -153,7 +155,7 @@ mod tests {
|
|||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
// Generate many access tokens
|
// Generate many access tokens
|
||||||
let tokens: HashSet<String> = (0..COUNT)
|
let tokens: HashSet<String> = (0..COUNT)
|
||||||
.map(|_| generate(&mut rng, TokenType::AccessToken))
|
.map(|_| TokenType::AccessToken.generate(&mut rng))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Check that they are all different
|
// Check that they are all different
|
||||||
@@ -161,18 +163,18 @@ mod tests {
|
|||||||
|
|
||||||
// Check that they are all valid and detected as access tokens
|
// Check that they are all valid and detected as access tokens
|
||||||
for token in tokens {
|
for token in tokens {
|
||||||
assert_eq!(check(&token).unwrap(), TokenType::AccessToken);
|
assert_eq!(TokenType::check(&token).unwrap(), TokenType::AccessToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same, but for refresh tokens
|
// Same, but for refresh tokens
|
||||||
let tokens: HashSet<String> = (0..COUNT)
|
let tokens: HashSet<String> = (0..COUNT)
|
||||||
.map(|_| generate(&mut rng, TokenType::RefreshToken))
|
.map(|_| TokenType::RefreshToken.generate(&mut rng))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
assert_eq!(tokens.len(), COUNT, "All tokens are unique");
|
assert_eq!(tokens.len(), COUNT, "All tokens are unique");
|
||||||
|
|
||||||
for token in tokens {
|
for token in tokens {
|
||||||
assert_eq!(check(&token).unwrap(), TokenType::RefreshToken);
|
assert_eq!(TokenType::check(&token).unwrap(), TokenType::RefreshToken);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user