1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Split the core crate

This commit is contained in:
Quentin Gliech
2021-12-17 18:04:30 +01:00
parent ceb17d3646
commit 2f97ca685d
45 changed files with 418 additions and 408 deletions

View File

@ -10,5 +10,7 @@ chrono = "0.4.19"
thiserror = "1.0.30"
serde = "1.0.131"
url = { version = "2.2.2", features = ["serde"] }
crc = "2.1.0"
rand = "0.8.4"
oauth2-types = { path = "../oauth2-types" }

View File

@ -31,7 +31,7 @@ pub use self::{
oauth2::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
},
tokens::{AccessToken, RefreshToken},
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
traits::{StorageBackend, StorageBackendMarker},
users::{Authentication, BrowserSession, User},
};

View File

@ -13,6 +13,10 @@
// limitations under the License.
use chrono::{DateTime, Duration, Utc};
use crc::{Crc, CRC_32_ISO_HDLC};
use oauth2_types::requests::TokenTypeHint;
use rand::{distributions::Alphanumeric, Rng};
use thiserror::Error;
use crate::traits::{StorageBackend, StorageBackendMarker};
@ -61,3 +65,200 @@ impl<S: StorageBackendMarker> From<RefreshToken<S>> for RefreshToken<()> {
}
}
}
/// Type of token to generate or validate
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
/// An access token, used by Relying Parties to authenticate requests
AccessToken,
/// A refresh token, used by the refresh token grant
RefreshToken,
}
impl TokenType {
fn prefix(self) -> &'static str {
match self {
TokenType::AccessToken => "mat",
TokenType::RefreshToken => "mar",
}
}
fn match_prefix(prefix: &str) -> Option<Self> {
match prefix {
"mat" => Some(TokenType::AccessToken),
"mar" => Some(TokenType::RefreshToken),
_ => None,
}
}
/// Generate a token for the given type
///
/// ```rust
/// extern crate rand;
///
/// use rand::thread_rng;
/// use mas_data_model::TokenType::{AccessToken, RefreshToken};
///
/// AccessToken.generate(thread_rng());
/// RefreshToken.generate(thread_rng());
/// ```
pub fn generate(self, rng: impl Rng) -> String {
let random_part: String = rng
.sample_iter(&Alphanumeric)
.take(30)
.map(char::from)
.collect();
let base = format!("{}_{}", self.prefix(), random_part);
let crc = CRC.checksum(base.as_bytes());
let crc = base62_encode(crc);
format!("{}_{}", base, crc)
}
/// Check the format of a token and determine its type
///
/// ```rust
/// use mas_data_model::TokenType;
///
/// assert_eq!(
/// TokenType::check("mat_kkLSacJDpek22jKWw4AcXG68b7U3W6_0Lg9yb"),
/// Ok(TokenType::AccessToken)
/// );
///
/// assert_eq!(
/// TokenType::check("mar_PkpplxPkfjsqvtdfUlYR1Afg2TpaHF_GaTQd2"),
/// Ok(TokenType::RefreshToken)
/// );
/// ```
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
let split: Vec<&str> = token.split('_').collect();
let [prefix, random_part, crc]: [&str; 3] = split
.try_into()
.map_err(|_| TokenFormatError::InvalidFormat)?;
if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
return Err(TokenFormatError::InvalidFormat);
}
let token_type =
TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
prefix: prefix.to_string(),
})?;
let base = format!("{}_{}", token_type.prefix(), random_part);
let expected_crc = CRC.checksum(base.as_bytes());
let expected_crc = base62_encode(expected_crc);
if crc != expected_crc {
return Err(TokenFormatError::InvalidCrc {
expected: expected_crc,
got: crc.to_string(),
});
}
Ok(token_type)
}
}
impl PartialEq<TokenTypeHint> for TokenType {
fn eq(&self, other: &TokenTypeHint) -> bool {
matches!(
(self, other),
(TokenType::AccessToken, TokenTypeHint::AccessToken)
| (TokenType::RefreshToken, TokenTypeHint::RefreshToken)
)
}
}
const NUM: [u8; 62] = *b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
fn base62_encode(mut num: u32) -> String {
let mut res = String::with_capacity(6);
while num > 0 {
res.push(NUM[(num % 62) as usize] as char);
num /= 62;
}
format!("{:0>6}", res)
}
const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
/// Invalid token
#[derive(Debug, Error, PartialEq)]
pub enum TokenFormatError {
/// Overall token format is invalid
#[error("invalid token format")]
InvalidFormat,
/// Token used an unknown prefix
#[error("unknown token prefix {prefix:?}")]
UnknownPrefix {
/// The prefix found in the token
prefix: String,
},
/// The CRC checksum in the token is invalid
#[error("invalid crc {got:?}, expected {expected:?}")]
InvalidCrc {
/// The CRC hash expected to be found in the token
expected: String,
/// The CRC found in the token
got: String,
},
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use rand::thread_rng;
use super::*;
#[test]
fn test_prefix_match() {
use TokenType::{AccessToken, RefreshToken};
assert_eq!(TokenType::match_prefix("mat"), Some(AccessToken));
assert_eq!(TokenType::match_prefix("mar"), Some(RefreshToken));
assert_eq!(TokenType::match_prefix("matt"), None);
assert_eq!(TokenType::match_prefix("marr"), None);
assert_eq!(TokenType::match_prefix("ma"), None);
assert_eq!(
TokenType::match_prefix(TokenType::AccessToken.prefix()),
Some(TokenType::AccessToken)
);
assert_eq!(
TokenType::match_prefix(TokenType::RefreshToken.prefix()),
Some(TokenType::RefreshToken)
);
}
#[test]
fn test_generate_and_check() {
const COUNT: usize = 500; // Generate 500 of each token type
let mut rng = thread_rng();
// Generate many access tokens
let tokens: HashSet<String> = (0..COUNT)
.map(|_| TokenType::AccessToken.generate(&mut rng))
.collect();
// Check that they are all different
assert_eq!(tokens.len(), COUNT, "All tokens are unique");
// Check that they are all valid and detected as access tokens
for token in tokens {
assert_eq!(TokenType::check(&token).unwrap(), TokenType::AccessToken);
}
// Same, but for refresh tokens
let tokens: HashSet<String> = (0..COUNT)
.map(|_| TokenType::RefreshToken.generate(&mut rng))
.collect();
assert_eq!(tokens.len(), COUNT, "All tokens are unique");
for token in tokens {
assert_eq!(TokenType::check(&token).unwrap(), TokenType::RefreshToken);
}
}
}