1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: make the Clock a trait

This commit is contained in:
Quentin Gliech
2023-01-18 12:20:30 +01:00
parent 73a921cc30
commit 142fdbd45a
62 changed files with 261 additions and 212 deletions

129
crates/storage/src/clock.rs Normal file
View File

@ -0,0 +1,129 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A [`Clock`] is a way to get the current date and time.
//!
//! This module defines two implemetation of the [`Clock`] trait:
//! [`SystemClock`] which uses the system time, and a [`MockClock`], which can
//! be used and freely manipulated in tests.
use std::sync::atomic::AtomicI64;
use chrono::{DateTime, TimeZone, Utc};
/// Represents a clock which can give the current date and time
pub trait Clock: Sync {
/// Get the current date and time
fn now(&self) -> DateTime<Utc>;
}
/// A clock which uses the system time
#[derive(Clone, Default)]
pub struct SystemClock {
_private: (),
}
impl Clock for SystemClock {
fn now(&self) -> DateTime<Utc> {
// This is the clock used elsewhere, it's fine to call Utc::now here
#[allow(clippy::disallowed_methods)]
Utc::now()
}
}
/// A fake clock, which uses a fixed timestamp, and can be advanced with the
/// [`MockClock::advance`] method.
///
/// ```rust
/// use mas_storage::clock::{Clock, MockClock};
/// use chrono::Duration;
///
/// let clock = MockClock::default();
/// let t1 = clock.now();
/// let t2 = clock.now();
/// assert_eq!(t1, t2);
///
/// clock.advance(Duration::seconds(10));
/// let t3 = clock.now();
/// assert_eq!(t2 + Duration::seconds(10), t3);
/// ```
pub struct MockClock {
timestamp: AtomicI64,
}
impl Default for MockClock {
fn default() -> Self {
let datetime = Utc.with_ymd_and_hms(2022, 1, 16, 14, 40, 0).unwrap();
Self::new(datetime)
}
}
impl MockClock {
/// Create a new clock which starts at the given datetime
#[must_use]
pub fn new(datetime: DateTime<Utc>) -> Self {
let timestamp = AtomicI64::new(datetime.timestamp());
Self { timestamp }
}
/// Move the clock forward by the given amount of time
pub fn advance(&self, duration: chrono::Duration) {
self.timestamp
.fetch_add(duration.num_seconds(), std::sync::atomic::Ordering::Relaxed);
}
}
impl Clock for MockClock {
fn now(&self) -> DateTime<Utc> {
let timestamp = self.timestamp.load(std::sync::atomic::Ordering::Relaxed);
chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap()
}
}
#[cfg(test)]
mod tests {
use chrono::Duration;
use super::*;
#[test]
fn test_mocked_clock() {
let clock = MockClock::default();
// Time should be frozen, and give out the same timestamp on each call
let first = clock.now();
std::thread::sleep(std::time::Duration::from_millis(10));
let second = clock.now();
assert_eq!(first, second);
// Clock can be advanced by a fixed duration
clock.advance(Duration::seconds(10));
let third = clock.now();
assert_eq!(first + Duration::seconds(10), third);
}
#[test]
fn test_real_clock() {
let clock = SystemClock::default();
// Time should not be frozen
let first = clock.now();
std::thread::sleep(std::time::Duration::from_millis(10));
let second = clock.now();
assert_ne!(first, second);
assert!(first < second);
}
}

View File

@ -37,7 +37,7 @@ pub trait CompatAccessTokenRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
compat_session: &CompatSession,
token: String,
expires_after: Option<Duration>,
@ -46,7 +46,7 @@ pub trait CompatAccessTokenRepository: Send + Sync {
/// Set the expiration time of the compat access token to now
async fn expire(
&mut self,
clock: &Clock,
clock: &dyn Clock,
compat_access_token: CompatAccessToken,
) -> Result<CompatAccessToken, Self::Error>;
}

View File

@ -36,7 +36,7 @@ pub trait CompatRefreshTokenRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
compat_session: &CompatSession,
compat_access_token: &CompatAccessToken,
token: String,
@ -45,7 +45,7 @@ pub trait CompatRefreshTokenRepository: Send + Sync {
/// Consume a compat refresh token
async fn consume(
&mut self,
clock: &Clock,
clock: &dyn Clock,
compat_refresh_token: CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error>;
}

View File

@ -30,7 +30,7 @@ pub trait CompatSessionRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
user: &User,
device: Device,
) -> Result<CompatSession, Self::Error>;
@ -38,7 +38,7 @@ pub trait CompatSessionRepository: Send + Sync {
/// End a compat session
async fn finish(
&mut self,
clock: &Clock,
clock: &dyn Clock,
compat_session: CompatSession,
) -> Result<CompatSession, Self::Error>;
}

View File

@ -37,7 +37,7 @@ pub trait CompatSsoLoginRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
login_token: String,
redirect_uri: Url,
) -> Result<CompatSsoLogin, Self::Error>;
@ -45,7 +45,7 @@ pub trait CompatSsoLoginRepository: Send + Sync {
/// Fulfill a compat SSO login by providing a compat session
async fn fulfill(
&mut self,
clock: &Clock,
clock: &dyn Clock,
compat_sso_login: CompatSsoLogin,
compat_session: &CompatSession,
) -> Result<CompatSsoLogin, Self::Error>;
@ -53,7 +53,7 @@ pub trait CompatSsoLoginRepository: Send + Sync {
/// Mark a compat SSO login as exchanged
async fn exchange(
&mut self,
clock: &Clock,
clock: &dyn Clock,
compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, Self::Error>;

View File

@ -28,92 +28,7 @@
clippy::module_name_repetitions
)]
use chrono::{DateTime, Utc};
#[derive(Debug, Clone, Default)]
pub struct Clock {
_private: (),
// #[cfg(test)]
mock: Option<std::sync::Arc<std::sync::atomic::AtomicI64>>,
}
impl Clock {
#[must_use]
pub fn now(&self) -> DateTime<Utc> {
// #[cfg(test)]
if let Some(timestamp) = &self.mock {
let timestamp = timestamp.load(std::sync::atomic::Ordering::Relaxed);
return chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap();
}
// This is the clock used elsewhere, it's fine to call Utc::now here
#[allow(clippy::disallowed_methods)]
Utc::now()
}
// #[cfg(test)]
#[must_use]
pub fn mock() -> Self {
use std::sync::{atomic::AtomicI64, Arc};
use chrono::TimeZone;
let datetime = Utc.with_ymd_and_hms(2022, 1, 16, 14, 40, 0).unwrap();
let timestamp = datetime.timestamp();
Self {
mock: Some(Arc::new(AtomicI64::new(timestamp))),
_private: (),
}
}
// #[cfg(test)]
pub fn advance(&self, duration: chrono::Duration) {
let timestamp = self
.mock
.as_ref()
.expect("Clock::advance should only be called on mocked clocks in tests");
timestamp.fetch_add(duration.num_seconds(), std::sync::atomic::Ordering::Relaxed);
}
}
#[cfg(test)]
mod tests {
use chrono::Duration;
use super::*;
#[test]
fn test_mocked_clock() {
let clock = Clock::mock();
// Time should be frozen, and give out the same timestamp on each call
let first = clock.now();
std::thread::sleep(std::time::Duration::from_millis(10));
let second = clock.now();
assert_eq!(first, second);
// Clock can be advanced by a fixed duration
clock.advance(Duration::seconds(10));
let third = clock.now();
assert_eq!(first + Duration::seconds(10), third);
}
#[test]
fn test_real_clock() {
let clock = Clock::default();
// Time should not be frozen
let first = clock.now();
std::thread::sleep(std::time::Duration::from_millis(10));
let second = clock.now();
assert_ne!(first, second);
assert!(first < second);
}
}
pub mod clock;
pub mod compat;
pub mod oauth2;
@ -123,6 +38,7 @@ pub mod upstream_oauth2;
pub mod user;
pub use self::{
clock::{Clock, SystemClock},
pagination::{Page, Pagination},
repository::Repository,
};

View File

@ -37,7 +37,7 @@ pub trait OAuth2AccessTokenRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
session: &Session,
access_token: String,
expires_after: Duration,
@ -46,10 +46,10 @@ pub trait OAuth2AccessTokenRepository: Send + Sync {
/// Revoke an access token
async fn revoke(
&mut self,
clock: &Clock,
clock: &dyn Clock,
access_token: AccessToken,
) -> Result<AccessToken, Self::Error>;
/// Cleanup expired access tokens
async fn cleanup_expired(&mut self, clock: &Clock) -> Result<usize, Self::Error>;
async fn cleanup_expired(&mut self, clock: &dyn Clock) -> Result<usize, Self::Error>;
}

View File

@ -31,7 +31,7 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
client: &Client,
redirect_uri: Url,
scope: Scope,
@ -51,14 +51,14 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync {
async fn fulfill(
&mut self,
clock: &Clock,
clock: &dyn Clock,
session: &Session,
authorization_grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error>;
async fn exchange(
&mut self,
clock: &Clock,
clock: &dyn Clock,
authorization_grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, Self::Error>;

View File

@ -45,7 +45,7 @@ pub trait OAuth2ClientRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
redirect_uris: Vec<Url>,
encrypted_client_secret: Option<String>,
grant_types: Vec<GrantType>,
@ -68,7 +68,7 @@ pub trait OAuth2ClientRepository: Send + Sync {
async fn add_from_config(
&mut self,
mut rng: impl Rng + Send,
clock: &Clock,
clock: &dyn Clock,
client_id: Ulid,
client_auth_method: OAuthClientAuthenticationMethod,
encrypted_client_secret: Option<String>,
@ -86,7 +86,7 @@ pub trait OAuth2ClientRepository: Send + Sync {
async fn give_consent_for_user(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
client: &Client,
user: &User,
scope: &Scope,

View File

@ -36,7 +36,7 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
session: &Session,
access_token: &AccessToken,
refresh_token: String,
@ -45,7 +45,7 @@ pub trait OAuth2RefreshTokenRepository: Send + Sync {
/// Consume a refresh token
async fn consume(
&mut self,
clock: &Clock,
clock: &dyn Clock,
refresh_token: RefreshToken,
) -> Result<RefreshToken, Self::Error>;
}

View File

@ -28,12 +28,12 @@ pub trait OAuth2SessionRepository: Send + Sync {
async fn create_from_grant(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
grant: &AuthorizationGrant,
user_session: &BrowserSession,
) -> Result<Session, Self::Error>;
async fn finish(&mut self, clock: &Clock, session: Session) -> Result<Session, Self::Error>;
async fn finish(&mut self, clock: &dyn Clock, session: Session) -> Result<Session, Self::Error>;
async fn list_paginated(
&mut self,

View File

@ -37,7 +37,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: String,
) -> Result<UpstreamOAuthLink, Self::Error>;

View File

@ -33,7 +33,7 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,

View File

@ -33,7 +33,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
state: String,
code_challenge_verifier: Option<String>,
@ -43,7 +43,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
/// Mark a session as completed and associate the given link
async fn complete_with_link(
&mut self,
clock: &Clock,
clock: &dyn Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
@ -52,7 +52,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
/// Mark a session as consumed
async fn consume(
&mut self,
clock: &Clock,
clock: &dyn Clock,
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
}

View File

@ -38,7 +38,7 @@ pub trait UserEmailRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
user: &User,
email: String,
) -> Result<UserEmail, Self::Error>;
@ -46,7 +46,7 @@ pub trait UserEmailRepository: Send + Sync {
async fn mark_as_verified(
&mut self,
clock: &Clock,
clock: &dyn Clock,
user_email: UserEmail,
) -> Result<UserEmail, Self::Error>;
@ -55,7 +55,7 @@ pub trait UserEmailRepository: Send + Sync {
async fn add_verification_code(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
user_email: &UserEmail,
max_age: chrono::Duration,
code: String,
@ -63,14 +63,14 @@ pub trait UserEmailRepository: Send + Sync {
async fn find_verification_code(
&mut self,
clock: &Clock,
clock: &dyn Clock,
user_email: &UserEmail,
code: &str,
) -> Result<Option<UserEmailVerification>, Self::Error>;
async fn consume_verification_code(
&mut self,
clock: &Clock,
clock: &dyn Clock,
verification: UserEmailVerification,
) -> Result<UserEmailVerification, Self::Error>;
}

View File

@ -36,7 +36,7 @@ pub trait UserRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
username: String,
) -> Result<User, Self::Error>;
async fn exists(&mut self, username: &str) -> Result<bool, Self::Error>;

View File

@ -26,7 +26,7 @@ pub trait UserPasswordRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
user: &User,
version: u16,
hashed_password: String,

View File

@ -27,12 +27,12 @@ pub trait BrowserSessionRepository: Send + Sync {
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
user: &User,
) -> Result<BrowserSession, Self::Error>;
async fn finish(
&mut self,
clock: &Clock,
clock: &dyn Clock,
user_session: BrowserSession,
) -> Result<BrowserSession, Self::Error>;
async fn list_active_paginated(
@ -45,7 +45,7 @@ pub trait BrowserSessionRepository: Send + Sync {
async fn authenticate_with_password(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
user_session: BrowserSession,
user_password: &Password,
) -> Result<BrowserSession, Self::Error>;
@ -53,7 +53,7 @@ pub trait BrowserSessionRepository: Send + Sync {
async fn authenticate_with_upstream(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
clock: &dyn Clock,
user_session: BrowserSession,
upstream_oauth_link: &UpstreamOAuthLink,
) -> Result<BrowserSession, Self::Error>;