diff --git a/crates/handlers/src/health.rs b/crates/handlers/src/health.rs index 10638497..304f494f 100644 --- a/crates/handlers/src/health.rs +++ b/crates/handlers/src/health.rs @@ -31,24 +31,18 @@ pub async fn get(State(pool): State) -> Result Result<(), anyhow::Error> { - let state = crate::test_state(pool).await?; - let app = crate::healthcheck_router().with_state(state); + async fn test_get_health(pool: PgPool) { + let state = TestState::from_pool(pool).await.unwrap(); + let request = Request::get("/health").empty(); - let request = Request::builder().uri("/health").body(Body::empty())?; - - let response = app.oneshot(request).await?; - - assert_eq!(response.status(), StatusCode::OK); - let body = hyper::body::to_bytes(response.into_body()).await?; - assert_eq!(body, "ok"); - - Ok(()) + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + assert_eq!(response.body(), "ok"); } } diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 1ab8af34..de5e0d87 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -59,6 +59,9 @@ pub mod passwords; mod upstream_oauth2; mod views; +#[cfg(test)] +mod test_utils; + /// Implement `From` for `RouteError`, for "internal server error" kind of /// errors. #[macro_export] @@ -363,68 +366,3 @@ where }, )) } - -#[cfg(test)] -async fn test_state(pool: sqlx::PgPool) -> Result { - use mas_email::MailTransport; - use mas_keystore::{JsonWebKey, JsonWebKeySet, PrivateKey}; - - use crate::passwords::Hasher; - - let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR")) - .join("..") - .join(".."); - - let url_builder = UrlBuilder::new("https://example.com/".parse()?); - - let templates = Templates::load(workspace_root.join("templates"), url_builder.clone()).await?; - - // TODO: add more test keys to the store - let rsa = - PrivateKey::load_pem(include_str!("../../keystore/tests/keys/rsa.pkcs1.pem")).unwrap(); - let rsa = JsonWebKey::new(rsa).with_kid("test-rsa"); - - let jwks = JsonWebKeySet::new(vec![rsa]); - let key_store = Keystore::new(jwks); - - let encrypter = Encrypter::new(&[0x42; 32]); - - let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?; - - let transport = MailTransport::blackhole(); - let mailbox: lettre::message::Mailbox = "server@example.com".parse()?; - let mailer = Mailer::new(templates.clone(), transport, mailbox.clone(), mailbox); - - let homeserver = MatrixHomeserver::new("example.com".to_owned()); - - let file = tokio::fs::File::open(workspace_root.join("policies").join("policy.wasm")).await?; - - let policy_factory = PolicyFactory::load( - file, - serde_json::json!({}), - "register/violation".to_owned(), - "client_registration/violation".to_owned(), - "authorization_grant/violation".to_owned(), - ) - .await?; - - let policy_factory = Arc::new(policy_factory); - - let graphql_schema = graphql_schema(); - - let http_client_factory = HttpClientFactory::new(10); - - Ok(AppState { - pool, - templates, - key_store, - encrypter, - url_builder, - mailer, - homeserver, - policy_factory, - graphql_schema, - http_client_factory, - password_manager, - }) -} diff --git a/crates/handlers/src/oauth2/revoke.rs b/crates/handlers/src/oauth2/revoke.rs index ed9aa0b2..3758dc1f 100644 --- a/crates/handlers/src/oauth2/revoke.rs +++ b/crates/handlers/src/oauth2/revoke.rs @@ -201,24 +201,19 @@ pub(crate) async fn post( #[cfg(test)] mod tests { - use hyper::{ - header::{AUTHORIZATION, CONTENT_TYPE}, - Request, - }; + use hyper::Request; use mas_data_model::AuthorizationCode; use mas_router::SimpleRoute; - use mas_storage::{RepositoryAccess, RepositoryTransaction, SystemClock}; - use mas_storage_pg::PgRepository; + use mas_storage::RepositoryAccess; use oauth2_types::{ registration::ClientRegistrationResponse, requests::{AccessTokenResponse, ResponseMode}, scope::{Scope, OPENID}, }; - use rand::SeedableRng; use sqlx::PgPool; - use tower::{Service, ServiceExt}; use super::*; + use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState}; #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_revoke_access_token(pool: PgPool) { @@ -227,50 +222,40 @@ mod tests { .with_test_writer() .init(); - let clock = SystemClock::default(); - let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + let state = TestState::from_pool(pool).await.unwrap(); - let state = crate::test_state(pool.clone()).await.unwrap(); - let mut app = crate::api_router().with_state(state); + let request = + Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({ + "client_uri": "https://example.com/", + "redirect_uris": ["https://example.com/callback"], + "contacts": ["contact@example.com"], + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code"], + "grant_types": ["authorization_code"], + })); - let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH) - .header(CONTENT_TYPE, "application/json") - .body( - serde_json::json!({ - "client_uri": "https://example.com/", - "redirect_uris": ["https://example.com/callback"], - "contacts": ["contact@example.com"], - "token_endpoint_auth_method": "client_secret_post", - "response_types": ["code"], - "grant_types": ["authorization_code"], - }) - .to_string(), - ) - .unwrap(); + let response = state.request(request).await; + response.assert_status(StatusCode::CREATED); - let response = app.ready().await.unwrap().call(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::CREATED); - - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); let client_registration: ClientRegistrationResponse = - serde_json::from_slice(&body).unwrap(); + serde_json::from_str(response.body()).unwrap(); let client_id = client_registration.client_id; let client_secret = client_registration.client_secret.unwrap(); // Let's provision a user and create a session for them. This part is hard to // test with just HTTP requests, so we'll use the repository directly. - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut repo = state.repository().await.unwrap(); let user = repo .user() - .add(&mut rng, &clock, "alice".to_owned()) + .add(&mut state.rng(), &state.clock, "alice".to_owned()) .await .unwrap(); let browser_session = repo .browser_session() - .add(&mut rng, &clock, &user) + .add(&mut state.rng(), &state.clock, &user) .await .unwrap(); @@ -286,8 +271,8 @@ mod tests { let grant = repo .oauth2_authorization_grant() .add( - &mut rng, - &clock, + &mut state.rng(), + &state.clock, &client, "https://example.com/redirect".parse().unwrap(), Scope::from_iter([OPENID]), @@ -299,7 +284,7 @@ mod tests { Some("nonce".to_owned()), None, ResponseMode::Query, - true, + false, false, ) .await @@ -307,69 +292,169 @@ mod tests { let session = repo .oauth2_session() - .create_from_grant(&mut rng, &clock, &grant, &browser_session) + .create_from_grant(&mut state.rng(), &state.clock, &grant, &browser_session) .await .unwrap(); let grant = repo .oauth2_authorization_grant() - .fulfill(&clock, &session, grant) + .fulfill(&state.clock, &session, grant) .await .unwrap(); - Box::new(repo).save().await.unwrap(); + repo.save().await.unwrap(); // Now call the token endpoint to get an access token. - let request = Request::post(mas_router::OAuth2TokenEndpoint::PATH) - .header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .body( - format!( - "grant_type=authorization_code&code={code}&redirect_uri={redirect_uri}&client_id={client_id}&client_secret={client_secret}", - code = grant.code.unwrap().code, - redirect_uri = grant.redirect_uri, - ), - ) - .unwrap(); + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "authorization_code", + "code": grant.code.unwrap().code, + "redirect_uri": grant.redirect_uri, + "client_id": client_id, + "client_secret": client_secret, + })); - let response = app.ready().await.unwrap().call(request).await.unwrap(); - let status = response.status(); - assert_eq!(status, StatusCode::OK); + let response = state.request(request).await; + response.assert_status(StatusCode::OK); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - let token: AccessTokenResponse = serde_json::from_slice(&body).unwrap(); + let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap(); - // Let's call the userinfo endpoint to make sure we can access it. - let request = Request::get(mas_router::OidcUserinfo::PATH) - .header(AUTHORIZATION, format!("Bearer {}", token.access_token)) - .body(String::new()) - .unwrap(); - - let response = app.ready().await.unwrap().call(request).await.unwrap(); - let status = response.status(); - assert_eq!(status, StatusCode::OK); + // Check that the token is valid + assert!(state.is_access_token_valid(&token.access_token).await); // Now let's revoke the access token. - let request = Request::post(mas_router::OAuth2Revocation::PATH) - .header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .body(format!( - "token={token}&token_type_hint=access_token&client_id={client_id}&client_secret={client_secret}", - token = token.access_token - )) + let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({ + "token": token.access_token, + "token_type_hint": "access_token", + "client_id": client_id, + "client_secret": client_secret, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + + // Check that the token is no longer valid + assert!(!state.is_access_token_valid(&token.access_token).await); + + // Try using the refresh token to get a new access token, it should fail. + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "refresh_token", + "refresh_token": token.refresh_token, + "client_id": client_id, + "client_secret": client_secret, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::BAD_REQUEST); + + // Now try with a new grant, and by revoking the refresh token instead + let mut repo = state.repository().await.unwrap(); + let grant = repo + .oauth2_authorization_grant() + .add( + &mut state.rng(), + &state.clock, + &client, + "https://example.com/redirect".parse().unwrap(), + Scope::from_iter([OPENID]), + Some(AuthorizationCode { + code: "anotherverysecretcode".to_owned(), + pkce: None, + }), + Some("state".to_owned()), + Some("nonce".to_owned()), + None, + ResponseMode::Query, + false, + false, + ) + .await .unwrap(); - let response = app.ready().await.unwrap().call(request).await.unwrap(); - let status = response.status(); - assert_eq!(status, StatusCode::OK); - - // Call the userinfo endpoint again to make sure we can't access it anymore. - let request = Request::get(mas_router::OidcUserinfo::PATH) - .header(AUTHORIZATION, format!("Bearer {}", token.access_token)) - .body(String::new()) + let session = repo + .oauth2_session() + .create_from_grant(&mut state.rng(), &state.clock, &grant, &browser_session) + .await .unwrap(); - let response = app.ready().await.unwrap().call(request).await.unwrap(); - let status = response.status(); - assert_eq!(status, StatusCode::UNAUTHORIZED); - // TODO: test refreshing the access token, test refresh token revocation + let grant = repo + .oauth2_authorization_grant() + .fulfill(&state.clock, &session, grant) + .await + .unwrap(); + + repo.save().await.unwrap(); + + // Now call the token endpoint to get an access token. + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "authorization_code", + "code": grant.code.unwrap().code, + "redirect_uri": grant.redirect_uri, + "client_id": client_id, + "client_secret": client_secret, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + + let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap(); + + // Use the refresh token to get a new access token. + let request = + Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({ + "grant_type": "refresh_token", + "refresh_token": token.refresh_token, + "client_id": client_id, + "client_secret": client_secret, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + + let old_token = token; + let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap(); + assert!(state.is_access_token_valid(&token.access_token).await); + assert!(!state.is_access_token_valid(&old_token.access_token).await); + + // Revoking the old access token shouldn't do anything. + let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({ + "token": old_token.access_token, + "token_type_hint": "access_token", + "client_id": client_id, + "client_secret": client_secret, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + + assert!(state.is_access_token_valid(&token.access_token).await); + + // Revoking the old refresh token shouldn't do anything. + let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({ + "token": old_token.refresh_token, + "token_type_hint": "refresh_token", + "client_id": client_id, + "client_secret": client_secret, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + + assert!(state.is_access_token_valid(&token.access_token).await); + + // Revoking the new refresh token should invalidate the session + let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({ + "token": token.refresh_token, + "token_type_hint": "refresh_token", + "client_id": client_id, + "client_secret": client_secret, + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + + assert!(!state.is_access_token_valid(&token.access_token).await); } } diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs new file mode 100644 index 00000000..3bfbfe7b --- /dev/null +++ b/crates/handlers/src/test_utils.rs @@ -0,0 +1,376 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{convert::Infallible, sync::Arc}; + +use axum::{ + async_trait, + body::HttpBody, + extract::{FromRef, FromRequestParts}, +}; +use headers::{Authorization, ContentType, HeaderMapExt}; +use hyper::{Request, Response, StatusCode}; +use mas_axum_utils::http_client_factory::HttpClientFactory; +use mas_email::{MailTransport, Mailer}; +use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; +use mas_policy::PolicyFactory; +use mas_router::{SimpleRoute, UrlBuilder}; +use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; +use mas_storage_pg::PgRepository; +use mas_templates::Templates; +use rand::SeedableRng; +use rand_chacha::ChaChaRng; +use serde::Serialize; +use sqlx::PgPool; +use tokio::sync::Mutex; +use tower::{Service, ServiceExt}; + +use crate::{ + app_state::RepositoryError, + graphql_schema, + passwords::{Hasher, PasswordManager}, + MatrixHomeserver, +}; + +#[derive(Clone)] +pub(crate) struct TestState { + pub pool: PgPool, + pub templates: Templates, + pub key_store: Keystore, + pub encrypter: Encrypter, + pub url_builder: UrlBuilder, + pub mailer: Mailer, + pub homeserver: MatrixHomeserver, + pub policy_factory: Arc, + pub graphql_schema: mas_graphql::Schema, + pub http_client_factory: HttpClientFactory, + pub password_manager: PasswordManager, + pub clock: Arc, + pub rng: Arc>, +} + +impl TestState { + /// Create a new test state from the given database pool + pub async fn from_pool(pool: PgPool) -> Result { + let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR")) + .join("..") + .join(".."); + + let url_builder = UrlBuilder::new("https://example.com/".parse()?); + + let templates = + Templates::load(workspace_root.join("templates"), url_builder.clone()).await?; + + // TODO: add more test keys to the store + let rsa = + PrivateKey::load_pem(include_str!("../../keystore/tests/keys/rsa.pkcs1.pem")).unwrap(); + let rsa = JsonWebKey::new(rsa).with_kid("test-rsa"); + + let jwks = JsonWebKeySet::new(vec![rsa]); + let key_store = Keystore::new(jwks); + + let encrypter = Encrypter::new(&[0x42; 32]); + + let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?; + + let transport = MailTransport::blackhole(); + let mailbox: lettre::message::Mailbox = "server@example.com".parse()?; + let mailer = Mailer::new(templates.clone(), transport, mailbox.clone(), mailbox); + + let homeserver = MatrixHomeserver::new("example.com".to_owned()); + + let file = + tokio::fs::File::open(workspace_root.join("policies").join("policy.wasm")).await?; + + let policy_factory = PolicyFactory::load( + file, + serde_json::json!({}), + "register/violation".to_owned(), + "client_registration/violation".to_owned(), + "authorization_grant/violation".to_owned(), + ) + .await?; + + let policy_factory = Arc::new(policy_factory); + + let graphql_schema = graphql_schema(); + + let http_client_factory = HttpClientFactory::new(10); + + let clock = Arc::new(MockClock::default()); + let rng = Arc::new(Mutex::new(ChaChaRng::seed_from_u64(42))); + + Ok(Self { + pool, + templates, + key_store, + encrypter, + url_builder, + mailer, + homeserver, + policy_factory, + graphql_schema, + http_client_factory, + password_manager, + clock, + rng, + }) + } + + pub async fn request(&self, request: Request) -> Response + where + B: HttpBody + Send + 'static, + B::Error: std::error::Error + Send + Sync, + B::Data: Send, + { + let app = crate::healthcheck_router() + .merge(crate::discovery_router()) + .merge(crate::api_router()) + .merge(crate::compat_router()) + .merge(crate::human_router(self.templates.clone())) + .with_state(self.clone()); + + // Both unwrap are on Infallible, so this is safe + let response = app + .ready_oneshot() + .await + .unwrap() + .call(request) + .await + .unwrap(); + + let (parts, body) = response.into_parts(); + + // This could actually fail, but do we really care about that? + let body = hyper::body::to_bytes(body) + .await + .expect("Failed to read response body"); + let body = std::str::from_utf8(&body) + .expect("Response body is not valid UTF-8") + .to_owned(); + + Response::from_parts(parts, body) + } + + pub async fn repository(&self) -> Result { + let repo = PgRepository::from_pool(&self.pool).await?; + Ok(repo + .map_err(mas_storage::RepositoryError::from_error) + .boxed()) + } + + /// Returns a new random number generator. + /// + /// # Panics + /// + /// Panics if the RNG is already locked. + pub fn rng(&self) -> ChaChaRng { + let mut parent_rng = self.rng.try_lock().expect("Failed to lock RNG"); + ChaChaRng::from_rng(&mut *parent_rng).unwrap() + } + + /// Do a call to the userinfo endpoint to check if the given token is valid. + /// Returns true if the token is valid. + /// + /// # Panics + /// + /// Panics if the response status code is not 200 or 401. + pub async fn is_access_token_valid(&self, token: &str) -> bool { + let request = Request::get(mas_router::OidcUserinfo::PATH) + .bearer(token) + .empty(); + + let response = self.request(request).await; + + match response.status() { + StatusCode::OK => true, + StatusCode::UNAUTHORIZED => false, + _ => panic!("Unexpected status code: {}", response.status()), + } + } +} + +impl FromRef for PgPool { + fn from_ref(input: &TestState) -> Self { + input.pool.clone() + } +} + +impl FromRef for mas_graphql::Schema { + fn from_ref(input: &TestState) -> Self { + input.graphql_schema.clone() + } +} + +impl FromRef for Templates { + fn from_ref(input: &TestState) -> Self { + input.templates.clone() + } +} + +impl FromRef for Keystore { + fn from_ref(input: &TestState) -> Self { + input.key_store.clone() + } +} + +impl FromRef for Encrypter { + fn from_ref(input: &TestState) -> Self { + input.encrypter.clone() + } +} + +impl FromRef for UrlBuilder { + fn from_ref(input: &TestState) -> Self { + input.url_builder.clone() + } +} + +impl FromRef for Mailer { + fn from_ref(input: &TestState) -> Self { + input.mailer.clone() + } +} + +impl FromRef for MatrixHomeserver { + fn from_ref(input: &TestState) -> Self { + input.homeserver.clone() + } +} + +impl FromRef for Arc { + fn from_ref(input: &TestState) -> Self { + input.policy_factory.clone() + } +} + +impl FromRef for HttpClientFactory { + fn from_ref(input: &TestState) -> Self { + input.http_client_factory.clone() + } +} + +impl FromRef for PasswordManager { + fn from_ref(input: &TestState) -> Self { + input.password_manager.clone() + } +} + +#[async_trait] +impl FromRequestParts for BoxClock { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &TestState, + ) -> Result { + Ok(Box::new(state.clock.clone())) + } +} + +#[async_trait] +impl FromRequestParts for BoxRng { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &TestState, + ) -> Result { + let mut parent_rng = state.rng.lock().await; + let rng = ChaChaRng::from_rng(&mut *parent_rng).expect("Failed to seed RNG"); + Ok(Box::new(rng)) + } +} + +#[async_trait] +impl FromRequestParts for BoxRepository { + type Rejection = RepositoryError; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &TestState, + ) -> Result { + let repo = PgRepository::from_pool(&state.pool).await?; + Ok(repo + .map_err(mas_storage::RepositoryError::from_error) + .boxed()) + } +} + +pub(crate) trait RequestBuilderExt { + /// Builds the request with the given JSON value as body. + fn json(self, body: T) -> hyper::Request; + + /// Builds the request with the given form value as body. + fn form(self, body: T) -> hyper::Request; + + /// Sets the request Authorization header to the given bearer token. + fn bearer(self, token: &str) -> Self; + + /// Builds the request with an empty body. + fn empty(self) -> hyper::Request; +} + +impl RequestBuilderExt for hyper::http::request::Builder { + fn json(mut self, body: T) -> hyper::Request { + self.headers_mut() + .unwrap() + .typed_insert(ContentType::json()); + + self.body(serde_json::to_string(&body).unwrap()).unwrap() + } + + fn form(mut self, body: T) -> hyper::Request { + self.headers_mut() + .unwrap() + .typed_insert(ContentType::form_url_encoded()); + + self.body(serde_urlencoded::to_string(&body).unwrap()) + .unwrap() + } + + fn bearer(mut self, token: &str) -> Self { + self.headers_mut() + .unwrap() + .typed_insert(Authorization::bearer(token).unwrap()); + self + } + + fn empty(self) -> hyper::Request { + self.body(String::new()).unwrap() + } +} + +pub(crate) trait ResponseExt { + /// Asserts that the response has the given status code. + /// + /// # Panics + /// + /// Panics if the response has a different status code. + fn assert_status(&self, status: StatusCode); +} + +impl ResponseExt for Response { + #[track_caller] + fn assert_status(&self, status: StatusCode) { + assert_eq!( + self.status(), + status, + "HTTP status code mismatch: got {}, expected {}. Body: {}", + self.status(), + status, + self.body() + ); + } +} diff --git a/crates/storage/src/clock.rs b/crates/storage/src/clock.rs index 04c69f25..a338d1d8 100644 --- a/crates/storage/src/clock.rs +++ b/crates/storage/src/clock.rs @@ -18,7 +18,7 @@ //! [`SystemClock`] which uses the system time, and a [`MockClock`], which can //! be used and freely manipulated in tests. -use std::sync::atomic::AtomicI64; +use std::sync::{atomic::AtomicI64, Arc}; use chrono::{DateTime, TimeZone, Utc}; @@ -28,6 +28,12 @@ pub trait Clock: Sync { fn now(&self) -> DateTime; } +impl Clock for Arc { + fn now(&self) -> DateTime { + (**self).now() + } +} + impl Clock for Box { fn now(&self) -> DateTime { (**self).now()