diff --git a/Cargo.lock b/Cargo.lock index 95b5c584..b5c2ed41 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2468,6 +2468,8 @@ dependencies = [ "opentelemetry-semantic-conventions", "opentelemetry-zipkin", "prometheus", + "rand", + "rand_chacha", "rustls", "serde_json", "serde_yaml", @@ -2497,6 +2499,7 @@ dependencies = [ "mas-keystore", "pem-rfc7468", "rand", + "rand_chacha", "rustls-pemfile", "schemars", "serde", @@ -2567,6 +2570,7 @@ dependencies = [ "mime", "oauth2-types", "rand", + "rand_chacha", "serde", "serde_json", "serde_urlencoded", @@ -2779,6 +2783,7 @@ dependencies = [ "oauth2-types", "password-hash", "rand", + "rand_chacha", "serde", "serde_json", "sqlx", @@ -3068,7 +3073,7 @@ checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860" [[package]] name = "opa-wasm" version = "0.1.0" -source = "git+https://github.com/matrix-org/rust-opa-wasm.git#325071ee8a2a7d18cc611365edd3235945a8cdf4" +source = "git+https://github.com/matrix-org/rust-opa-wasm.git#f838595670747b0644b6bfd9829fca5d63bbee66" dependencies = [ "anyhow", "base64", @@ -3079,6 +3084,7 @@ dependencies = [ "md-5", "parse-size", "rand", + "rayon-core", "semver", "serde", "serde_json", diff --git a/clippy.toml b/clippy.toml index 87582113..93fce8d5 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1,2 +1,11 @@ msrv = "1.61.0" doc-valid-idents = ["OpenID", "OAuth", ".."] + +disallowed-methods = [ + { path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" }, + { path = "chrono::Utc::now", reason = "source the current time from the clock instead" }, +] + +disallowed-types = [ + "rand::OsRng", +] diff --git a/crates/axum-utils/src/csrf.rs b/crates/axum-utils/src/csrf.rs index b69f935d..c12a2ed8 100644 --- a/crates/axum-utils/src/csrf.rs +++ b/crates/axum-utils/src/csrf.rs @@ -15,6 +15,7 @@ use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use chrono::{DateTime, Duration, Utc}; use data_encoding::{DecodeError, BASE64URL_NOPAD}; +use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, TimestampSeconds}; use thiserror::Error; @@ -56,20 +57,20 @@ pub struct CsrfToken { impl CsrfToken { /// Create a new token from a defined value valid for a specified duration - fn new(token: [u8; 32], ttl: Duration) -> Self { - let expiration = Utc::now() + ttl; + fn new(token: [u8; 32], now: DateTime, ttl: Duration) -> Self { + let expiration = now + ttl; Self { expiration, token } } /// Generate a new random token valid for a specified duration - fn generate(ttl: Duration) -> Self { - let token = rand::random(); - Self::new(token, ttl) + fn generate(now: DateTime, mut rng: impl Rng, ttl: Duration) -> Self { + let token = rng.gen(); + Self::new(token, now, ttl) } /// Generate a new token with the same value but an up to date expiration - fn refresh(self, ttl: Duration) -> Self { - Self::new(self.token, ttl) + fn refresh(self, now: DateTime, ttl: Duration) -> Self { + Self::new(self.token, now, ttl) } /// Get the value to include in HTML forms @@ -88,8 +89,8 @@ impl CsrfToken { } } - fn verify_expiration(self) -> Result { - if Utc::now() < self.expiration { + fn verify_expiration(self, now: DateTime) -> Result { + if now < self.expiration { Ok(self) } else { Err(CsrfError::Expired) @@ -118,12 +119,18 @@ impl CsrfExt for PrivateCookieJar { cookie.set_path("/"); cookie.set_http_only(true); + // XXX: the rng source and clock should come from somewhere else + #[allow(clippy::disallowed_methods)] + let now = Utc::now(); + #[allow(clippy::disallowed_methods)] + let rng = thread_rng(); + let new_token = cookie .decode() .ok() - .and_then(|token: CsrfToken| token.verify_expiration().ok()) - .unwrap_or_else(|| CsrfToken::generate(Duration::hours(1))) - .refresh(Duration::hours(1)); + .and_then(|token: CsrfToken| token.verify_expiration(now).ok()) + .unwrap_or_else(|| CsrfToken::generate(now, rng, Duration::hours(1))) + .refresh(now, Duration::hours(1)); let cookie = cookie.encode(&new_token); let jar = jar.add(cookie); @@ -131,9 +138,13 @@ impl CsrfExt for PrivateCookieJar { } fn verify_form(&self, form: ProtectedForm) -> Result { + // XXX: the clock should come from somewhere else + #[allow(clippy::disallowed_methods)] + let now = Utc::now(); + let cookie = self.get("csrf").ok_or(CsrfError::Missing)?; let token: CsrfToken = cookie.decode()?; - let token = token.verify_expiration()?; + let token = token.verify_expiration(now)?; token.verify_form_value(&form.csrf)?; Ok(form.inner) } diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index b3484ef0..3a588157 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -6,23 +6,25 @@ edition = "2021" license = "Apache-2.0" [dependencies] -axum = "0.6.0-rc.2" -tokio = { version = "1.21.2", features = ["full"] } -futures-util = "0.3.25" anyhow = "1.0.66" +argon2 = { version = "0.4.1", features = ["password-hash"] } +atty = "0.2.14" +axum = "0.6.0-rc.2" clap = { version = "4.0.18", features = ["derive"] } dotenv = "0.15.0" -tower = { version = "0.4.13", features = ["full"] } +futures-util = "0.3.25" hyper = { version = "0.14.22", features = ["full"] } -serde_yaml = "0.9.14" -serde_json = "1.0.87" -url = "2.3.1" -argon2 = { version = "0.4.1", features = ["password-hash"] } -watchman_client = "0.8.0" -atty = "0.2.14" -listenfd = "1.0.0" -rustls = "0.20.7" itertools = "0.10.5" +listenfd = "1.0.0" +rand = "0.8.5" +rand_chacha = "0.3.1" +rustls = "0.20.7" +serde_json = "1.0.87" +serde_yaml = "0.9.14" +tokio = { version = "1.21.2", features = ["full"] } +tower = { version = "0.4.13", features = ["full"] } +url = "2.3.1" +watchman_client = "0.8.0" tracing = "0.1.37" tracing-appender = "0.2.2" diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 3cd5f38d..8afd3b5a 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -20,7 +20,9 @@ use mas_storage::{ user::{ lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user, }, + Clock, }; +use rand::SeedableRng; use tracing::{info, warn}; #[derive(Parser, Debug)] @@ -51,14 +53,17 @@ enum Subcommand { impl Options { pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { use Subcommand as SC; + let clock = Clock::default(); + match &self.subcommand { SC::Register { username, password } => { let config: DatabaseConfig = root.load_config()?; let pool = config.connect().await?; let mut txn = pool.begin().await?; let hasher = Argon2::default(); + let rng = rand_chacha::ChaChaRng::from_entropy(); - let user = register_user(&mut txn, hasher, username, password).await?; + let user = register_user(&mut txn, rng, &clock, hasher, username, password).await?; txn.commit().await?; info!(?user, "User registered"); @@ -76,7 +81,7 @@ impl Options { let user = lookup_user_by_username(&mut txn, username).await?; let email = lookup_user_email(&mut txn, &user, email).await?; - let email = mark_user_email_as_verified(&mut txn, email).await?; + let email = mark_user_email_as_verified(&mut txn, &clock, email).await?; txn.commit().await?; info!(?email, "Email marked as verified"); diff --git a/crates/config/Cargo.toml b/crates/config/Cargo.toml index b78d401b..ddaf3db9 100644 --- a/crates/config/Cargo.toml +++ b/crates/config/Cargo.toml @@ -28,6 +28,7 @@ lettre = { version = "0.10.1", default-features = false, features = ["serde", "b pem-rfc7468 = "0.6.0" rustls-pemfile = "1.0.1" rand = "0.8.5" +rand_chacha = "0.3.1" indoc = "1.0.7" diff --git a/crates/config/src/sections/secrets.rs b/crates/config/src/sections/secrets.rs index f9fc9905..8ba5e667 100644 --- a/crates/config/src/sections/secrets.rs +++ b/crates/config/src/sections/secrets.rs @@ -20,7 +20,7 @@ use mas_jose::jwk::{JsonWebKey, JsonWebKeySet}; use mas_keystore::{Encrypter, Keystore, PrivateKey}; use rand::{ distributions::{Alphanumeric, DistString}, - thread_rng, + thread_rng, SeedableRng, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -139,64 +139,72 @@ impl ConfigurationSection<'_> for SecretsConfig { #[tracing::instrument] async fn generate() -> anyhow::Result { + // XXX: that RNG should come from somewhere else + #[allow(clippy::disallowed_methods)] + let mut rng = rand_chacha::ChaChaRng::from_rng(thread_rng())?; + info!("Generating keys..."); let span = tracing::info_span!("rsa"); + let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?; let rsa_key = task::spawn_blocking(move || { let _entered = span.enter(); - let ret = PrivateKey::generate_rsa(thread_rng()).unwrap(); + let ret = PrivateKey::generate_rsa(key_rng).unwrap(); info!("Done generating RSA key"); ret }) .await .context("could not join blocking task")?; let rsa_key = KeyConfig { - kid: Alphanumeric.sample_string(&mut thread_rng(), 10), + kid: Alphanumeric.sample_string(&mut rng, 10), password: None, key: KeyOrFile::Key(rsa_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()), }; let span = tracing::info_span!("ec_p256"); + let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?; let ec_p256_key = task::spawn_blocking(move || { let _entered = span.enter(); - let ret = PrivateKey::generate_ec_p256(thread_rng()); + let ret = PrivateKey::generate_ec_p256(key_rng); info!("Done generating EC P-256 key"); ret }) .await .context("could not join blocking task")?; let ec_p256_key = KeyConfig { - kid: Alphanumeric.sample_string(&mut thread_rng(), 10), + kid: Alphanumeric.sample_string(&mut rng, 10), password: None, key: KeyOrFile::Key(ec_p256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()), }; let span = tracing::info_span!("ec_p384"); + let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?; let ec_p384_key = task::spawn_blocking(move || { let _entered = span.enter(); - let ret = PrivateKey::generate_ec_p384(thread_rng()); + let ret = PrivateKey::generate_ec_p384(key_rng); info!("Done generating EC P-256 key"); ret }) .await .context("could not join blocking task")?; let ec_p384_key = KeyConfig { - kid: Alphanumeric.sample_string(&mut thread_rng(), 10), + kid: Alphanumeric.sample_string(&mut rng, 10), password: None, key: KeyOrFile::Key(ec_p384_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()), }; let span = tracing::info_span!("ec_k256"); + let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?; let ec_k256_key = task::spawn_blocking(move || { let _entered = span.enter(); - let ret = PrivateKey::generate_ec_k256(thread_rng()); + let ret = PrivateKey::generate_ec_k256(key_rng); info!("Done generating EC secp256k1 key"); ret }) .await .context("could not join blocking task")?; let ec_k256_key = KeyConfig { - kid: Alphanumeric.sample_string(&mut thread_rng(), 10), + kid: Alphanumeric.sample_string(&mut rng, 10), password: None, key: KeyOrFile::Key(ec_k256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()), }; diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index f11ccb0a..6e1c7658 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -263,6 +263,8 @@ mod tests { #[test] fn test_generate_and_check() { const COUNT: usize = 500; // Generate 500 of each token type + + #[allow(clippy::disallowed_methods)] let mut rng = thread_rng(); for t in [ diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index e64cd379..43c736f2 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -44,6 +44,7 @@ chrono = { version = "0.4.22", features = ["serde"] } url = { version = "2.3.1", features = ["serde"] } mime = "0.3.16" rand = "0.8.5" +rand_chacha = "0.3.1" headers = "0.3.8" ulid = "1.0.0" diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 490e097d..d3af3402 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -13,7 +13,7 @@ // limitations under the License. use axum::{extract::State, response::IntoResponse, Json}; -use chrono::{Duration, Utc}; +use chrono::Duration; use hyper::StatusCode; use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType}; use mas_storage::{ @@ -22,9 +22,8 @@ use mas_storage::{ get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged, CompatSsoLoginLookupError, }, - PostgresqlBackend, + Clock, PostgresqlBackend, }; -use rand::thread_rng; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; use sqlx::{PgPool, Postgres, Transaction}; @@ -201,6 +200,7 @@ pub(crate) async fn post( State(homeserver): State, Json(input): Json, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; let mut txn = pool.begin().await?; let session = match input.credentials { Credentials::Password { @@ -208,7 +208,7 @@ pub(crate) async fn post( password, } => user_password_login(&mut txn, user, password).await?, - Credentials::Token { token } => token_login(&mut txn, &token).await?, + Credentials::Token { token } => token_login(&mut txn, &clock, &token).await?, _ => { return Err(RouteError::Unsupported); @@ -225,14 +225,28 @@ pub(crate) async fn post( None }; - let access_token = TokenType::CompatAccessToken.generate(&mut thread_rng()); - let access_token = - add_compat_access_token(&mut txn, &session, access_token, expires_in).await?; + let access_token = TokenType::CompatAccessToken.generate(&mut rng); + let access_token = add_compat_access_token( + &mut txn, + &mut rng, + &clock, + &session, + access_token, + expires_in, + ) + .await?; let refresh_token = if input.refresh_token { - let refresh_token = TokenType::CompatRefreshToken.generate(&mut thread_rng()); - let refresh_token = - add_compat_refresh_token(&mut txn, &session, &access_token, refresh_token).await?; + let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng); + let refresh_token = add_compat_refresh_token( + &mut txn, + &mut rng, + &clock, + &session, + &access_token, + refresh_token, + ) + .await?; Some(refresh_token.token) } else { None @@ -251,11 +265,12 @@ pub(crate) async fn post( async fn token_login( txn: &mut Transaction<'_, Postgres>, + clock: &Clock, token: &str, ) -> Result, RouteError> { let login = get_compat_sso_login_by_token(&mut *txn, token).await?; - let now = Utc::now(); + let now = clock.now(); match login.state { CompatSsoLoginState::Pending => { tracing::error!( @@ -285,7 +300,7 @@ async fn token_login( } } - let login = mark_compat_sso_login_as_exchanged(&mut *txn, login).await?; + let login = mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?; match login.state { CompatSsoLoginState::Exchanged { session, .. } => Ok(session), @@ -298,8 +313,10 @@ async fn user_password_login( username: String, password: String, ) -> Result, RouteError> { - let device = Device::generate(&mut thread_rng()); - let session = compat_login(txn, &username, &password, device) + let (clock, mut rng) = crate::rng_and_clock()?; + + let device = Device::generate(&mut rng); + let session = compat_login(txn, &mut rng, &clock, &username, &password, device) .await .map_err(|_| RouteError::LoginFailed)?; diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 6083d944..97dd0e00 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -20,7 +20,7 @@ use axum::{ response::{Html, IntoResponse, Redirect, Response}, }; use axum_extra::extract::PrivateCookieJar; -use chrono::{Duration, Utc}; +use chrono::Duration; use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, FancyError, SessionInfoExt, @@ -28,9 +28,11 @@ use mas_axum_utils::{ use mas_data_model::Device; use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; -use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id}; +use mas_storage::{ + compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id}, + Clock, +}; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; -use rand::thread_rng; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use ulid::Ulid; @@ -56,6 +58,7 @@ pub async fn get( Path(id): Path, Query(params): Query, ) -> Result { + let clock = Clock::default(); let mut conn = pool.acquire().await?; let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -95,7 +98,7 @@ pub async fn get( let login = get_compat_sso_login_by_id(&mut conn, id).await?; // Bail out if that login session is more than 30min old - if Utc::now() > login.created_at + Duration::minutes(30) { + if clock.now() > login.created_at + Duration::minutes(30) { let ctx = ErrorContext::new() .with_code("compat_sso_login_expired") .with_description("This login session expired.".to_owned()); @@ -121,6 +124,7 @@ pub async fn post( Query(params): Query, Form(form): Form>, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; let mut txn = pool.begin().await?; let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -160,7 +164,7 @@ pub async fn post( let login = get_compat_sso_login_by_id(&mut txn, id).await?; // Bail out if that login session is more than 30min old - if Utc::now() > login.created_at + Duration::minutes(30) { + if clock.now() > login.created_at + Duration::minutes(30) { let ctx = ErrorContext::new() .with_code("compat_sso_login_expired") .with_description("This login session expired.".to_owned()); @@ -186,8 +190,9 @@ pub async fn post( redirect_uri }; - let device = Device::generate(&mut thread_rng()); - let _login = fullfill_compat_sso_login(&mut txn, session.user, login, device).await?; + let device = Device::generate(&mut rng); + let _login = + fullfill_compat_sso_login(&mut txn, &mut rng, &clock, session.user, login, device).await?; txn.commit().await?; diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index 2a4c4676..9a146cd4 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -20,10 +20,7 @@ use axum::{ use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; use mas_storage::compat::insert_compat_sso_login; -use rand::{ - distributions::{Alphanumeric, DistString}, - thread_rng, -}; +use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; use sqlx::PgPool; @@ -70,6 +67,8 @@ pub async fn get( State(url_builder): State, Query(params): Query, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; + // Check the redirectUrl parameter let redirect_url = params.redirect_url.ok_or(RouteError::MissingRedirectUrl)?; let redirect_url = Url::parse(&redirect_url).map_err(|_| RouteError::InvalidRedirectUrl)?; @@ -84,9 +83,9 @@ pub async fn get( return Err(RouteError::InvalidRedirectUrl); } - let token = Alphanumeric.sample_string(&mut thread_rng(), 32); + let token = Alphanumeric.sample_string(&mut rng, 32); let mut conn = pool.acquire().await?; - let login = insert_compat_sso_login(&mut conn, token, redirect_url).await?; + let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?; Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.data, params.action))) } diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 36e64c47..e613376c 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -16,7 +16,7 @@ use axum::{extract::State, response::IntoResponse, Json, TypedHeader}; use headers::{authorization::Bearer, Authorization}; use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; -use mas_storage::compat::compat_logout; +use mas_storage::{compat::compat_logout, Clock}; use sqlx::PgPool; use super::MatrixError; @@ -67,6 +67,7 @@ pub(crate) async fn post( State(pool): State, maybe_authorization: Option>>, ) -> Result { + let clock = Clock::default(); let mut conn = pool.acquire().await?; let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; @@ -78,7 +79,7 @@ pub(crate) async fn post( return Err(RouteError::InvalidAuthorization); } - compat_logout(&mut conn, token) + compat_logout(&mut conn, &clock, token) .await .map_err(|_| RouteError::LogoutFailed)?; diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index e07a5768..fb1297bb 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -20,7 +20,6 @@ use mas_storage::compat::{ add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token, expire_compat_access_token, lookup_active_compat_refresh_token, CompatRefreshTokenLookupError, }; -use rand::thread_rng; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; use sqlx::PgPool; @@ -98,6 +97,7 @@ pub(crate) async fn post( State(pool): State, Json(input): Json, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; let mut txn = pool.begin().await?; let token_type = TokenType::check(&input.refresh_token)?; @@ -109,23 +109,31 @@ pub(crate) async fn post( let (refresh_token, access_token, session) = lookup_active_compat_refresh_token(&mut txn, &input.refresh_token).await?; - let (new_refresh_token_str, new_access_token_str) = { - let mut rng = thread_rng(); - ( - TokenType::CompatRefreshToken.generate(&mut rng), - TokenType::CompatAccessToken.generate(&mut rng), - ) - }; + let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng); + let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng); let expires_in = Duration::minutes(5); - let new_access_token = - add_compat_access_token(&mut txn, &session, new_access_token_str, Some(expires_in)).await?; - let new_refresh_token = - add_compat_refresh_token(&mut txn, &session, &new_access_token, new_refresh_token_str) - .await?; + let new_access_token = add_compat_access_token( + &mut txn, + &mut rng, + &clock, + &session, + new_access_token_str, + Some(expires_in), + ) + .await?; + let new_refresh_token = add_compat_refresh_token( + &mut txn, + &mut rng, + &clock, + &session, + &new_access_token, + new_refresh_token_str, + ) + .await?; - consume_compat_refresh_token(&mut txn, refresh_token).await?; - expire_compat_access_token(&mut txn, access_token).await?; + consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?; + expire_compat_access_token(&mut txn, &clock, access_token).await?; txn.commit().await?; diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 18aca49e..730055f6 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -21,6 +21,7 @@ use std::{convert::Infallible, sync::Arc, time::Duration}; +use anyhow::Context; use axum::{ body::HttpBody, extract::FromRef, @@ -36,6 +37,7 @@ use mas_keystore::{Encrypter, Keystore}; use mas_policy::PolicyFactory; use mas_router::{Route, UrlBuilder}; use mas_templates::{ErrorContext, Templates}; +use rand::SeedableRng; use sqlx::PgPool; use tower::util::AndThenLayer; use tower_http::cors::{Any, CorsLayer}; @@ -356,3 +358,15 @@ async fn test_state(pool: PgPool) -> Result, anyhow::Error> { policy_factory, })) } + +// XXX: that should be moved somewhere else +fn rng_and_clock() -> Result<(mas_storage::Clock, rand_chacha::ChaChaRng), anyhow::Error> { + let clock = mas_storage::Clock::default(); + + // This rng is used to source the local rng + #[allow(clippy::disallowed_methods)] + let rng = rand::thread_rng(); + + let rng = rand_chacha::ChaChaRng::from_rng(rng).context("Failed to seed RNG")?; + Ok((clock, rng)) +} diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 23709c2f..a592dc40 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -190,6 +190,8 @@ pub(crate) async fn complete( policy_factory: &PolicyFactory, mut txn: Transaction<'_, Postgres>, ) -> Result>, GrantCompletionError> { + let (clock, mut rng) = crate::rng_and_clock()?; + // Verify that the grant is in a pending stage if !grant.stage.is_pending() { return Err(GrantCompletionError::NotPending); @@ -226,7 +228,7 @@ pub(crate) async fn complete( } // All good, let's start the session - let session = derive_session(&mut txn, &grant, browser_session).await?; + let session = derive_session(&mut txn, &mut rng, &clock, &grant, browser_session).await?; let grant = fulfill_grant(&mut txn, grant, session.clone()).await?; diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index ad5a347b..69149206 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -37,7 +37,7 @@ use oauth2_types::{ requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode}, response_type::ResponseType, }; -use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use rand::{distributions::Alphanumeric, Rng}; use serde::Deserialize; use sqlx::PgPool; use thiserror::Error; @@ -159,6 +159,7 @@ pub(crate) async fn get( cookie_jar: PrivateCookieJar, Form(params): Form, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; let mut txn = pool.begin().await?; // First, figure out what client it is @@ -265,7 +266,7 @@ pub(crate) async fn get( } // 32 random alphanumeric characters, about 190bit of entropy - let code: String = thread_rng() + let code: String = (&mut rng) .sample_iter(&Alphanumeric) .take(32) .map(char::from) @@ -296,6 +297,8 @@ pub(crate) async fn get( let grant = new_authorization_grant( &mut txn, + &mut rng, + &clock, client, redirect_uri.clone(), params.auth.scope, diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 1fa88ad6..e3f9f0bf 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -119,6 +119,7 @@ pub(crate) async fn post( Path(grant_id): Path, Form(form): Form>, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; let mut txn = pool .begin() .await @@ -163,6 +164,8 @@ pub(crate) async fn post( .collect(); insert_client_consent( &mut txn, + &mut rng, + &clock, &session.user, &grant.client, &scope_without_device, diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index edbd01d9..c99ad063 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -28,6 +28,7 @@ use mas_storage::{ client::ClientFetchError, refresh_token::{lookup_active_refresh_token, RefreshTokenLookupError}, }, + Clock, }; use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse}; use sqlx::PgPool; @@ -158,6 +159,7 @@ pub(crate) async fn post( State(encrypter): State, client_authorization: ClientAuthorization, ) -> Result { + let clock = Clock::default(); let mut conn = pool.acquire().await?; let client = client_authorization.credentials.fetch(&mut conn).await?; @@ -227,7 +229,8 @@ pub(crate) async fn post( } } TokenType::CompatAccessToken => { - let (token, session) = lookup_active_compat_access_token(&mut conn, token).await?; + let (token, session) = + lookup_active_compat_access_token(&mut conn, &clock, token).await?; let device_scope = session.device.to_scope_token(); let scope = [device_scope].into_iter().collect(); diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index fb70e7fd..0b14cedf 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -50,7 +50,6 @@ use oauth2_types::{ }, scope, }; -use rand::thread_rng; use serde::Serialize; use serde_with::{serde_as, skip_serializing_none}; use sqlx::{PgPool, Postgres, Transaction}; @@ -235,12 +234,13 @@ async fn authorization_code_grant( url_builder: &UrlBuilder, mut txn: Transaction<'_, Postgres>, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; + // TODO: there is a bunch of unnecessary cloning here // TODO: handle "not found" cases let authz_grant = lookup_grant_by_code(&mut txn, &grant.code).await?; - // TODO: that's not a timestamp from the DB. Let's assume they are in sync - let now = Utc::now(); + let now = clock.now(); let session = match authz_grant.stage { AuthorizationGrantStage::Cancelled { cancelled_at } => { @@ -257,7 +257,7 @@ async fn authorization_code_grant( // Ending the session if the token was already exchanged more than 20s ago if now - exchanged_at > Duration::seconds(20) { debug!("Ending potentially compromised session"); - end_oauth_session(&mut txn, session).await?; + end_oauth_session(&mut txn, &clock, session).await?; txn.commit().await?; } @@ -303,22 +303,32 @@ async fn authorization_code_grant( let browser_session = &session.browser_session; let ttl = Duration::minutes(5); - let (access_token_str, refresh_token_str) = { - let mut rng = thread_rng(); - ( - TokenType::AccessToken.generate(&mut rng), - TokenType::RefreshToken.generate(&mut rng), - ) - }; + let access_token_str = TokenType::AccessToken.generate(&mut rng); + let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let access_token = add_access_token(&mut txn, session, access_token_str.clone(), ttl).await?; + let access_token = add_access_token( + &mut txn, + &mut rng, + &clock, + session, + access_token_str.clone(), + ttl, + ) + .await?; - let _refresh_token = - add_refresh_token(&mut txn, session, access_token, refresh_token_str.clone()).await?; + let _refresh_token = add_refresh_token( + &mut txn, + &mut rng, + &clock, + session, + access_token, + refresh_token_str.clone(), + ) + .await?; let id_token = if session.scope.contains(&scope::OPENID) { let mut claims = HashMap::new(); - let now = Utc::now(); + let now = clock.now(); claims::ISS.insert(&mut claims, url_builder.oidc_issuer().to_string())?; claims::SUB.insert(&mut claims, &browser_session.user.sub)?; claims::AUD.insert(&mut claims, client.client_id.clone())?; @@ -346,7 +356,7 @@ async fn authorization_code_grant( let signer = key.params().signing_key_for_alg(&alg)?; let header = JsonWebSignatureHeader::new(alg) .with_kid(key.kid().context("key has no `kid` for some reason")?); - let id_token = Jwt::sign(header, claims, &signer)?; + let id_token = Jwt::sign_with_rng(&mut rng, header, claims, &signer)?; Some(id_token.as_str().to_owned()) } else { @@ -362,7 +372,7 @@ async fn authorization_code_grant( params = params.with_id_token(id_token); } - exchange_grant(&mut txn, authz_grant).await?; + exchange_grant(&mut txn, &clock, authz_grant).await?; txn.commit().await?; @@ -374,6 +384,8 @@ async fn refresh_token_grant( client: &Client, mut txn: Transaction<'_, Postgres>, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; + let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token).await?; @@ -383,24 +395,33 @@ async fn refresh_token_grant( } let ttl = Duration::minutes(5); - let (access_token_str, refresh_token_str) = { - let mut rng = thread_rng(); - ( - TokenType::AccessToken.generate(&mut rng), - TokenType::RefreshToken.generate(&mut rng), - ) - }; + let access_token_str = TokenType::AccessToken.generate(&mut rng); + let refresh_token_str = TokenType::RefreshToken.generate(&mut rng); - let new_access_token = - add_access_token(&mut txn, &session, access_token_str.clone(), ttl).await?; + let new_access_token = add_access_token( + &mut txn, + &mut rng, + &clock, + &session, + access_token_str.clone(), + ttl, + ) + .await?; - let new_refresh_token = - add_refresh_token(&mut txn, &session, new_access_token, refresh_token_str).await?; + let new_refresh_token = add_refresh_token( + &mut txn, + &mut rng, + &clock, + &session, + new_access_token, + refresh_token_str, + ) + .await?; - consume_refresh_token(&mut txn, &refresh_token).await?; + consume_refresh_token(&mut txn, &clock, &refresh_token).await?; if let Some(access_token) = refresh_token.access_token { - revoke_access_token(&mut txn, access_token).await?; + revoke_access_token(&mut txn, &clock, access_token).await?; } let params = AccessTokenResponse::new(access_token_str) diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index a2c1bf69..3c47bcb4 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -54,6 +54,7 @@ pub async fn get( user_authorization: UserAuthorization, ) -> Result { // TODO: error handling + let (_clock, mut rng) = crate::rng_and_clock()?; let mut conn = pool.acquire().await?; let session = user_authorization.protected(&mut conn).await?; @@ -88,7 +89,7 @@ pub async fn get( user_info, }; - let token = Jwt::sign(header, user_info, &signer)?; + let token = Jwt::sign_with_rng(&mut rng, header, user_info, &signer)?; Ok(JwtResponse(token).into_response()) } else { Ok(Json(user_info).into_response()) diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 1165d02c..805f2e5a 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -72,6 +72,7 @@ pub(crate) async fn post( Query(query): Query, Form(form): Form>, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; let mut txn = pool.begin().await?; let form = cookie_jar.verify_form(form)?; @@ -86,14 +87,22 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = add_user_email(&mut txn, &session.user, form.email).await?; + let user_email = add_user_email(&mut txn, &mut rng, &clock, &session.user, form.email).await?; let next = mas_router::AccountVerifyEmail::new(user_email.data); let next = if let Some(action) = query.post_auth_action { next.and_then(action) } else { next }; - start_email_verification(&mailer, &mut txn, &session.user, user_email).await?; + start_email_verification( + &mailer, + &mut txn, + &mut rng, + &clock, + &session.user, + user_email, + ) + .await?; txn.commit().await?; diff --git a/crates/handlers/src/views/account/emails/mod.rs b/crates/handlers/src/views/account/emails/mod.rs index 3ab0b950..8d7b6d2b 100644 --- a/crates/handlers/src/views/account/emails/mod.rs +++ b/crates/handlers/src/views/account/emails/mod.rs @@ -32,10 +32,10 @@ use mas_storage::{ add_user_email, add_user_email_verification_code, get_user_email, get_user_emails, remove_user_email, set_user_email_as_primary, }, - PostgresqlBackend, + Clock, PostgresqlBackend, }; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; -use rand::{distributions::Uniform, thread_rng, Rng}; +use rand::{distributions::Uniform, Rng}; use serde::Deserialize; use sqlx::{PgExecutor, PgPool}; use tracing::info; @@ -93,17 +93,26 @@ async fn render( async fn start_email_verification( mailer: &Mailer, executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, user: &User, user_email: UserEmail, ) -> anyhow::Result<()> { // First, generate a code let range = Uniform::::from(0..1_000_000); - let code = thread_rng().sample(range).to_string(); + let code = rng.sample(range).to_string(); let address: Address = user_email.email.parse()?; - let verification = - add_user_email_verification_code(executor, user_email, Duration::hours(8), code).await?; + let verification = add_user_email_verification_code( + executor, + &mut rng, + clock, + user_email, + Duration::hours(8), + code, + ) + .await?; // And send the verification email let mailbox = Mailbox::new(Some(user.username.clone()), address); @@ -126,6 +135,7 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; let mut txn = pool.begin().await?; let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -143,9 +153,18 @@ pub(crate) async fn post( match form { ManagementForm::Add { email } => { - let user_email = add_user_email(&mut txn, &session.user, email).await?; + let user_email = + add_user_email(&mut txn, &mut rng, &clock, &session.user, email).await?; let next = mas_router::AccountVerifyEmail::new(user_email.data); - start_email_verification(&mailer, &mut txn, &session.user, user_email).await?; + start_email_verification( + &mailer, + &mut txn, + &mut rng, + &clock, + &session.user, + user_email, + ) + .await?; txn.commit().await?; return Ok((cookie_jar, next.go()).into_response()); } @@ -154,7 +173,15 @@ pub(crate) async fn post( let user_email = get_user_email(&mut txn, &session.user, id).await?; let next = mas_router::AccountVerifyEmail::new(user_email.data); - start_email_verification(&mailer, &mut txn, &session.user, user_email).await?; + start_email_verification( + &mailer, + &mut txn, + &mut rng, + &clock, + &session.user, + user_email, + ) + .await?; txn.commit().await?; return Ok((cookie_jar, next.go()).into_response()); } diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 22e427d3..ff1addea 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -23,9 +23,12 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::Route; -use mas_storage::user::{ - consume_email_verification, lookup_user_email_by_id, lookup_user_email_verification_code, - mark_user_email_as_verified, set_user_email_as_primary, +use mas_storage::{ + user::{ + consume_email_verification, lookup_user_email_by_id, lookup_user_email_verification_code, + mark_user_email_as_verified, set_user_email_as_primary, + }, + Clock, }; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use serde::Deserialize; @@ -84,6 +87,7 @@ pub(crate) async fn post( Path(id): Path, Form(form): Form>, ) -> Result { + let clock = Clock::default(); let mut txn = pool.begin().await?; let form = cookie_jar.verify_form(form)?; @@ -105,12 +109,13 @@ pub(crate) async fn post( } // TODO: make those 8 hours configurable - let verification = lookup_user_email_verification_code(&mut txn, email, &form.code).await?; + let verification = + lookup_user_email_verification_code(&mut txn, &clock, email, &form.code).await?; // TODO: display nice errors if the code was already consumed or expired - let verification = consume_email_verification(&mut txn, verification).await?; + let verification = consume_email_verification(&mut txn, &clock, verification).await?; - let _email = mark_user_email_as_verified(&mut txn, verification.email).await?; + let _email = mark_user_email_as_verified(&mut txn, &clock, verification.email).await?; txn.commit().await?; diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 55fcecd4..7f317868 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -81,6 +81,7 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; let mut txn = pool.begin().await?; let form = cookie_jar.verify_form(form)?; @@ -96,7 +97,14 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - authenticate_session(&mut txn, &mut session, &form.current_password).await?; + authenticate_session( + &mut txn, + &mut rng, + &clock, + &mut session, + &form.current_password, + ) + .await?; // TODO: display nice form errors if form.new_password != form.new_password_confirm { @@ -104,7 +112,15 @@ pub(crate) async fn post( } let phf = Argon2::default(); - set_password(&mut txn, phf, &session.user, &form.new_password).await?; + set_password( + &mut txn, + &mut rng, + &clock, + phf, + &session.user, + &form.new_password, + ) + .await?; let reply = render(templates.clone(), session, cookie_jar).await?; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index b43932b4..c62f5098 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -80,6 +80,7 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; let mut conn = pool.acquire().await?; let form = cookie_jar.verify_form(form)?; @@ -114,7 +115,7 @@ pub(crate) async fn post( return Ok((cookie_jar, Html(content)).into_response()); } - match login(&mut conn, &form.username, &form.password).await { + match login(&mut conn, &mut rng, &clock, &form.username, &form.password).await { Ok(session_info) => { let cookie_jar = cookie_jar.set_session(&session_info); let reply = query.go_next(); diff --git a/crates/handlers/src/views/logout.rs b/crates/handlers/src/views/logout.rs index 34d4a69c..a742a9f3 100644 --- a/crates/handlers/src/views/logout.rs +++ b/crates/handlers/src/views/logout.rs @@ -23,7 +23,7 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_router::{PostAuthAction, Route}; -use mas_storage::user::end_session; +use mas_storage::{user::end_session, Clock}; use sqlx::PgPool; pub(crate) async fn post( @@ -31,6 +31,7 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>>, ) -> Result { + let clock = Clock::default(); let mut txn = pool.begin().await?; let form = cookie_jar.verify_form(form)?; @@ -40,7 +41,7 @@ pub(crate) async fn post( let maybe_session = session_info.load_session(&mut txn).await?; if let Some(session) = maybe_session { - end_session(&mut txn, &session).await?; + end_session(&mut txn, &clock, &session).await?; cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended()); } diff --git a/crates/handlers/src/views/reauth.rs b/crates/handlers/src/views/reauth.rs index 8cb2463f..dfaa8e0e 100644 --- a/crates/handlers/src/views/reauth.rs +++ b/crates/handlers/src/views/reauth.rs @@ -80,6 +80,7 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; let mut txn = pool.begin().await?; let form = cookie_jar.verify_form(form)?; @@ -98,7 +99,7 @@ pub(crate) async fn post( }; // TODO: recover from errors here - authenticate_session(&mut txn, &mut session, &form.password).await?; + authenticate_session(&mut txn, &mut rng, &clock, &mut session, &form.password).await?; let cookie_jar = cookie_jar.set_session(&session); txn.commit().await?; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 57da405b..d1957d33 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -39,7 +39,7 @@ use mas_templates::{ EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, TemplateContext, Templates, ToFormState, }; -use rand::{distributions::Uniform, thread_rng, Rng}; +use rand::{distributions::Uniform, Rng}; use serde::{Deserialize, Serialize}; use sqlx::{PgConnection, PgPool}; @@ -87,6 +87,7 @@ pub(crate) async fn get( } } +#[allow(clippy::too_many_lines)] pub(crate) async fn post( State(mailer): State, State(policy_factory): State>, @@ -96,6 +97,7 @@ pub(crate) async fn post( cookie_jar: PrivateCookieJar, Form(form): Form>, ) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; let mut txn = pool.begin().await?; let form = cookie_jar.verify_form(form)?; @@ -180,18 +182,34 @@ pub(crate) async fn post( } let pfh = Argon2::default(); - let user = register_user(&mut txn, pfh, &form.username, &form.password).await?; + let user = register_user( + &mut txn, + &mut rng, + &clock, + pfh, + &form.username, + &form.password, + ) + .await?; - let user_email = add_user_email(&mut txn, &user, form.email).await?; + let user_email = add_user_email(&mut txn, &mut rng, &clock, &user, form.email).await?; // First, generate a code let range = Uniform::::from(0..1_000_000); - let code = thread_rng().sample(range).to_string(); + let code = rng.sample(range); + let code = format!("{code:06}"); let address: Address = user_email.email.parse()?; - let verification = - add_user_email_verification_code(&mut txn, user_email, Duration::hours(8), code).await?; + let verification = add_user_email_verification_code( + &mut txn, + &mut rng, + &clock, + user_email, + Duration::hours(8), + code, + ) + .await?; // And send the verification email let mailbox = Mailbox::new(Some(user.username.clone()), address); @@ -203,7 +221,7 @@ pub(crate) async fn post( let next = mas_router::AccountVerifyEmail::new(verification.email.data) .and_maybe(query.post_auth_action); - let session = start_session(&mut txn, user).await?; + let session = start_session(&mut txn, &mut rng, &clock, user).await?; txn.commit().await?; diff --git a/crates/jose/src/jwt/signed.rs b/crates/jose/src/jwt/signed.rs index c55a8e3c..89a794bb 100644 --- a/crates/jose/src/jwt/signed.rs +++ b/crates/jose/src/jwt/signed.rs @@ -309,6 +309,7 @@ impl Jwt<'static, T> { S: Signature, T: Serialize, { + #[allow(clippy::disallowed_methods)] Self::sign_with_rng(thread_rng(), header, payload, key) } @@ -357,6 +358,7 @@ impl Jwt<'static, T> { #[cfg(test)] mod tests { + #![allow(clippy::disallowed_methods)] use mas_iana::jose::JsonWebSignatureAlg; use rand::thread_rng; diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index f3a154b8..05f7aaa1 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -19,6 +19,7 @@ tracing = "0.1.37" argon2 = { version = "0.4.1", features = ["password-hash"] } password-hash = { version = "0.4.2", features = ["std"] } rand = "0.8.5" +rand_chacha = "0.3.1" url = { version = "2.3.1", features = ["serde"] } uuid = "1.2.1" ulid = { version = "1.0.0", features = ["uuid", "serde"] } diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs index f8084b56..6b0b9679 100644 --- a/crates/storage/src/compat.rs +++ b/crates/storage/src/compat.rs @@ -19,6 +19,7 @@ use mas_data_model::{ CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, Device, User, UserEmail, }; +use rand::Rng; use sqlx::{Acquire, PgExecutor, Postgres}; use thiserror::Error; use tokio::task; @@ -27,7 +28,7 @@ use ulid::Ulid; use url::Url; use uuid::Uuid; -use crate::{user::lookup_user_by_username, DatabaseInconsistencyError, PostgresqlBackend}; +use crate::{user::lookup_user_by_username, Clock, DatabaseInconsistencyError, PostgresqlBackend}; struct CompatAccessTokenLookup { compat_access_token_id: Uuid, @@ -67,6 +68,7 @@ impl CompatAccessTokenLookupError { #[tracing::instrument(skip_all, err)] pub async fn lookup_active_compat_access_token( executor: impl PgExecutor<'_>, + clock: &Clock, token: &str, ) -> Result< ( @@ -112,7 +114,7 @@ pub async fn lookup_active_compat_access_token( // Check for token expiration if let Some(expires_at) = res.compat_access_token_expires_at { - if expires_at < Utc::now() { + if expires_at < clock.now() { return Err(CompatAccessTokenLookupError::Expired { when: expires_at }); } } @@ -311,7 +313,9 @@ pub async fn lookup_active_compat_refresh_token( err(Display), )] pub async fn compat_login( - conn: impl Acquire<'_, Database = Postgres>, + conn: impl Acquire<'_, Database = Postgres> + Send, + mut rng: impl Rng + Send, + clock: &Clock, username: &str, password: &str, device: Device, @@ -348,8 +352,8 @@ pub async fn compat_login( .instrument(tracing::info_span!("Verify hashed password")) .await??; - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("compat_session.id", tracing::field::display(id)); sqlx::query!( @@ -392,12 +396,14 @@ pub async fn compat_login( )] pub async fn add_compat_access_token( executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, session: &CompatSession, token: String, expires_after: Option, ) -> Result, anyhow::Error> { - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); let expires_at = expires_after.map(|expires_after| created_at + expires_after); @@ -436,9 +442,10 @@ pub async fn add_compat_access_token( )] pub async fn expire_compat_access_token( executor: impl PgExecutor<'_>, + clock: &Clock, access_token: CompatAccessToken, ) -> Result<(), anyhow::Error> { - let expires_at = Utc::now(); + let expires_at = clock.now(); let res = sqlx::query!( r#" UPDATE compat_access_tokens @@ -474,12 +481,14 @@ pub async fn expire_compat_access_token( )] pub async fn add_compat_refresh_token( executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, session: &CompatSession, access_token: &CompatAccessToken, token: String, ) -> Result, anyhow::Error> { - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); sqlx::query!( @@ -514,9 +523,10 @@ pub async fn add_compat_refresh_token( )] pub async fn compat_logout( executor: impl PgExecutor<'_>, + clock: &Clock, token: &str, ) -> Result<(), anyhow::Error> { - let finished_at = Utc::now(); + let finished_at = clock.now(); // TODO: this does not check for token expiration let compat_session_id = sqlx::query_scalar!( r#" @@ -552,9 +562,10 @@ pub async fn compat_logout( )] pub async fn consume_compat_refresh_token( executor: impl PgExecutor<'_>, + clock: &Clock, refresh_token: CompatRefreshToken, ) -> Result<(), anyhow::Error> { - let consumed_at = Utc::now(); + let consumed_at = clock.now(); let res = sqlx::query!( r#" UPDATE compat_refresh_tokens @@ -587,11 +598,13 @@ pub async fn consume_compat_refresh_token( )] pub async fn insert_compat_sso_login( executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, login_token: String, redirect_uri: Url, ) -> Result, anyhow::Error> { - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); sqlx::query!( @@ -845,7 +858,9 @@ pub async fn get_compat_sso_login_by_token( err(Display), )] pub async fn fullfill_compat_sso_login( - conn: impl Acquire<'_, Database = Postgres>, + conn: impl Acquire<'_, Database = Postgres> + Send, + mut rng: impl Rng + Send, + clock: &Clock, user: User, mut login: CompatSsoLogin, device: Device, @@ -856,8 +871,8 @@ pub async fn fullfill_compat_sso_login( let mut txn = conn.begin().await.context("could not start transaction")?; - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("user.id", tracing::field::display(user.data)); sqlx::query!( @@ -883,7 +898,7 @@ pub async fn fullfill_compat_sso_login( finished_at: None, }; - let fulfilled_at = Utc::now(); + let fulfilled_at = clock.now(); sqlx::query!( r#" UPDATE compat_sso_logins @@ -924,6 +939,7 @@ pub async fn fullfill_compat_sso_login( )] pub async fn mark_compat_sso_login_as_exchanged( executor: impl PgExecutor<'_>, + clock: &Clock, mut login: CompatSsoLogin, ) -> Result, anyhow::Error> { let (fulfilled_at, session) = match login.state { @@ -934,7 +950,7 @@ pub async fn mark_compat_sso_login_as_exchanged( _ => bail!("sso login in wrong state"), }; - let exchanged_at = Utc::now(); + let exchanged_at = clock.now(); sqlx::query!( r#" UPDATE compat_sso_logins diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 6c5d8471..06a71bb7 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -15,7 +15,12 @@ //! Interactions with the database #![forbid(unsafe_code)] -#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)] +#![deny( + clippy::all, + clippy::str_to_string, + clippy::future_not_send, + rustdoc::broken_intra_doc_links +)] #![warn(clippy::pedantic)] #![allow( clippy::missing_errors_doc, @@ -23,12 +28,27 @@ clippy::module_name_repetitions )] +use chrono::{DateTime, Utc}; use mas_data_model::{StorageBackend, StorageBackendMarker}; use serde::Serialize; use sqlx::migrate::Migrator; use thiserror::Error; use ulid::Ulid; +#[derive(Default, Debug, Clone, Copy)] +pub struct Clock { + _private: (), +} + +impl Clock { + #[must_use] + pub fn now(&self) -> DateTime { + // This is the clock used elsewhere, it's fine to call Utc::now here + #[allow(clippy::disallowed_methods)] + Utc::now() + } +} + #[derive(Debug, Error)] #[error("database query returned an inconsistent state")] pub struct DatabaseInconsistencyError; diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index b7702b5c..1e57637c 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -15,13 +15,14 @@ use anyhow::Context; use chrono::{DateTime, Duration, Utc}; use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail}; +use rand::Rng; use sqlx::{Acquire, PgExecutor, Postgres}; use thiserror::Error; use ulid::Ulid; use uuid::Uuid; use super::client::{lookup_client, ClientFetchError}; -use crate::{DatabaseInconsistencyError, PostgresqlBackend}; +use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend}; #[tracing::instrument( skip_all, @@ -35,13 +36,15 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend}; )] pub async fn add_access_token( executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, session: &Session, access_token: String, expires_after: Duration, ) -> Result, anyhow::Error> { - let created_at = Utc::now(); + let created_at = clock.now(); let expires_at = created_at + expires_after; - let id = Ulid::from_datetime(created_at.into()); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("access_token.id", tracing::field::display(id)); @@ -243,9 +246,10 @@ where )] pub async fn revoke_access_token( executor: impl PgExecutor<'_>, + clock: &Clock, access_token: AccessToken, ) -> anyhow::Result<()> { - let revoked_at = Utc::now(); + let revoked_at = clock.now(); let res = sqlx::query!( r#" UPDATE oauth2_access_tokens @@ -266,9 +270,9 @@ pub async fn revoke_access_token( } } -pub async fn cleanup_expired(executor: impl PgExecutor<'_>) -> anyhow::Result { +pub async fn cleanup_expired(executor: impl PgExecutor<'_>, clock: &Clock) -> anyhow::Result { // Cleanup token which expired more than 15 minutes ago - let threshold = Utc::now() - Duration::minutes(15); + let threshold = clock.now() - Duration::minutes(15); let res = sqlx::query!( r#" DELETE FROM oauth2_access_tokens diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 850e6bea..f8a7041b 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -24,13 +24,14 @@ use mas_data_model::{ }; use mas_iana::oauth::PkceCodeChallengeMethod; use oauth2_types::{requests::ResponseMode, scope::Scope}; +use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; use url::Url; use uuid::Uuid; use super::client::lookup_client; -use crate::{DatabaseInconsistencyError, PostgresqlBackend}; +use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend}; #[tracing::instrument( skip_all, @@ -43,6 +44,8 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend}; #[allow(clippy::too_many_arguments)] pub async fn new_authorization_grant( executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, client: Client, redirect_uri: Url, scope: Scope, @@ -67,8 +70,8 @@ pub async fn new_authorization_grant( let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)); let code_str = code.as_ref().map(|c| &c.code); - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("grant.id", tracing::field::display(id)); sqlx::query!( @@ -504,11 +507,13 @@ pub async fn lookup_grant_by_code( )] pub async fn derive_session( executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, grant: &AuthorizationGrant, browser_session: BrowserSession, ) -> Result, anyhow::Error> { - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("session.id", tracing::field::display(id)); sqlx::query!( @@ -623,9 +628,10 @@ pub async fn give_consent_to_grant( )] pub async fn exchange_grant( executor: impl PgExecutor<'_>, + clock: &Clock, mut grant: AuthorizationGrant, ) -> Result, anyhow::Error> { - let exchanged_at = Utc::now(); + let exchanged_at = clock.now(); sqlx::query!( r#" UPDATE oauth2_authorization_grants diff --git a/crates/storage/src/oauth2/consent.rs b/crates/storage/src/oauth2/consent.rs index dd531994..b19100db 100644 --- a/crates/storage/src/oauth2/consent.rs +++ b/crates/storage/src/oauth2/consent.rs @@ -14,14 +14,14 @@ use std::str::FromStr; -use chrono::Utc; use mas_data_model::{Client, User}; use oauth2_types::scope::{Scope, ScopeToken}; +use rand::Rng; use sqlx::PgExecutor; use ulid::Ulid; use uuid::Uuid; -use crate::PostgresqlBackend; +use crate::{Clock, PostgresqlBackend}; #[tracing::instrument( skip_all, @@ -67,17 +67,19 @@ pub async fn fetch_client_consent( )] pub async fn insert_client_consent( executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, user: &User, client: &Client, scope: &Scope, ) -> Result<(), anyhow::Error> { - let now = Utc::now(); + let now = clock.now(); let (tokens, ids): (Vec, Vec) = scope .iter() .map(|token| { ( token.to_string(), - Uuid::from(Ulid::from_datetime(now.into())), + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), ) }) .unzip(); diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index 38f59f22..6b5822b1 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use chrono::Utc; use mas_data_model::Session; use sqlx::PgExecutor; use uuid::Uuid; -use crate::PostgresqlBackend; +use crate::{Clock, PostgresqlBackend}; pub mod access_token; pub mod authorization_grant; @@ -37,9 +36,10 @@ pub mod refresh_token; )] pub async fn end_oauth_session( executor: impl PgExecutor<'_>, + clock: &Clock, session: Session, ) -> Result<(), anyhow::Error> { - let finished_at = Utc::now(); + let finished_at = clock.now(); let res = sqlx::query!( r#" UPDATE oauth2_sessions diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 9423e21b..94c2cef7 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -17,13 +17,14 @@ use chrono::{DateTime, Utc}; use mas_data_model::{ AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail, }; +use rand::Rng; use sqlx::{PgConnection, PgExecutor}; use thiserror::Error; use ulid::Ulid; use uuid::Uuid; use super::client::{lookup_client, ClientFetchError}; -use crate::{DatabaseInconsistencyError, PostgresqlBackend}; +use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend}; #[tracing::instrument( skip_all, @@ -38,12 +39,14 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend}; )] pub async fn add_refresh_token( executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, session: &Session, access_token: AccessToken, refresh_token: String, ) -> anyhow::Result> { - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("refresh_token.id", tracing::field::display(id)); sqlx::query!( @@ -263,9 +266,10 @@ pub async fn lookup_active_refresh_token( )] pub async fn consume_refresh_token( executor: impl PgExecutor<'_>, + clock: &Clock, refresh_token: &RefreshToken, ) -> Result<(), anyhow::Error> { - let consumed_at = Utc::now(); + let consumed_at = clock.now(); let res = sqlx::query!( r#" UPDATE oauth2_refresh_tokens diff --git a/crates/storage/src/user.rs b/crates/storage/src/user.rs index 6d5f3816..b9070ec6 100644 --- a/crates/storage/src/user.rs +++ b/crates/storage/src/user.rs @@ -22,7 +22,7 @@ use mas_data_model::{ UserEmailVerificationState, }; use password_hash::{PasswordHash, PasswordHasher, SaltString}; -use rand::thread_rng; +use rand::{CryptoRng, Rng}; use sqlx::{Acquire, PgExecutor, Postgres, Transaction}; use thiserror::Error; use tokio::task; @@ -31,6 +31,7 @@ use ulid::Ulid; use uuid::Uuid; use super::{DatabaseInconsistencyError, PostgresqlBackend}; +use crate::Clock; #[derive(Debug, Clone)] struct UserLookup { @@ -68,7 +69,9 @@ pub enum LoginError { err, )] pub async fn login( - conn: impl Acquire<'_, Database = Postgres>, + conn: impl Acquire<'_, Database = Postgres> + Send, + mut rng: impl Rng + Send, + clock: &Clock, username: &str, password: &str, ) -> Result, LoginError> { @@ -86,8 +89,8 @@ pub async fn login( } })?; - let mut session = start_session(&mut txn, user).await?; - authenticate_session(&mut txn, &mut session, password) + let mut session = start_session(&mut txn, &mut rng, clock, user).await?; + authenticate_session(&mut txn, &mut rng, clock, &mut session, password) .await .map_err(|source| { if matches!(source, AuthenticationError::Password { .. }) { @@ -230,10 +233,12 @@ pub async fn lookup_active_session( )] pub async fn start_session( executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, user: User, ) -> Result, anyhow::Error> { - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("user_session.id", tracing::field::display(id)); sqlx::query!( @@ -301,13 +306,16 @@ pub enum AuthenticationError { #[tracing::instrument( skip_all, fields( - session.id = %session.data, - user.id = %session.user.data + user.id = %session.user.data, + user_session.id = %session.data, + user_session_authentication.id, ), err, )] pub async fn authenticate_session( txn: &mut Transaction<'_, Postgres>, + mut rng: impl Rng + Send, + clock: &Clock, session: &mut BrowserSession, password: &str, ) -> Result<(), AuthenticationError> { @@ -341,8 +349,13 @@ pub async fn authenticate_session( .await??; // That went well, let's insert the auth info - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); + tracing::Span::current().record( + "user_session_authentication.id", + tracing::field::display(id), + ); + sqlx::query!( r#" INSERT INTO user_session_authentications @@ -376,12 +389,14 @@ pub async fn authenticate_session( )] pub async fn register_user( txn: &mut Transaction<'_, Postgres>, - phf: impl PasswordHasher, + mut rng: impl CryptoRng + Rng + Send, + clock: &Clock, + phf: impl PasswordHasher + Send, username: &str, password: &str, ) -> Result, anyhow::Error> { - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("user.id", tracing::field::display(id)); sqlx::query!( @@ -405,7 +420,7 @@ pub async fn register_user( primary_email: None, }; - set_password(txn.borrow_mut(), phf, &user, password).await?; + set_password(txn.borrow_mut(), &mut rng, clock, phf, &user, password).await?; Ok(user) } @@ -420,15 +435,17 @@ pub async fn register_user( )] pub async fn set_password( executor: impl PgExecutor<'_>, - phf: impl PasswordHasher, + mut rng: impl CryptoRng + Rng + Send, + clock: &Clock, + phf: impl PasswordHasher + Send, user: &User, password: &str, ) -> Result<(), anyhow::Error> { - let created_at = Utc::now(); + let created_at = clock.now(); let id = Ulid::from_datetime(created_at.into()); tracing::Span::current().record("user_password.id", tracing::field::display(id)); - let salt = SaltString::generate(thread_rng()); + let salt = SaltString::generate(&mut rng); let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?; sqlx::query_scalar!( @@ -456,9 +473,10 @@ pub async fn set_password( )] pub async fn end_session( executor: impl PgExecutor<'_>, + clock: &Clock, session: &BrowserSession, ) -> Result<(), anyhow::Error> { - let now = Utc::now(); + let now = clock.now(); let res = sqlx::query!( r#" UPDATE user_sessions @@ -672,11 +690,13 @@ pub async fn get_user_email( )] pub async fn add_user_email( executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, user: &User, email: String, ) -> Result, anyhow::Error> { - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("user_email.id", tracing::field::display(id)); sqlx::query!( @@ -842,9 +862,10 @@ pub async fn lookup_user_email_by_id( )] pub async fn mark_user_email_as_verified( executor: impl PgExecutor<'_>, + clock: &Clock, mut email: UserEmail, ) -> Result, anyhow::Error> { - let confirmed_at = Utc::now(); + let confirmed_at = clock.now(); sqlx::query!( r#" UPDATE user_emails @@ -881,10 +902,11 @@ struct UserEmailConfirmationCodeLookup { )] pub async fn lookup_user_email_verification_code( executor: impl PgExecutor<'_>, + clock: &Clock, email: UserEmail, code: &str, ) -> Result, anyhow::Error> { - let now = Utc::now(); + let now = clock.now(); let res = sqlx::query_as!( UserEmailConfirmationCodeLookup, @@ -935,13 +957,14 @@ pub async fn lookup_user_email_verification_code( )] pub async fn consume_email_verification( executor: impl PgExecutor<'_>, + clock: &Clock, mut verification: UserEmailVerification, ) -> Result, anyhow::Error> { if !matches!(verification.state, UserEmailVerificationState::Valid) { bail!("user email verification in wrong state"); } - let consumed_at = Utc::now(); + let consumed_at = clock.now(); sqlx::query!( r#" @@ -974,12 +997,14 @@ pub async fn consume_email_verification( )] pub async fn add_user_email_verification_code( executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, email: UserEmail, max_age: chrono::Duration, code: String, ) -> Result, anyhow::Error> { - let created_at = Utc::now(); - let id = Ulid::from_datetime(created_at.into()); + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id)); let expires_at = created_at + max_age; @@ -1013,23 +1038,27 @@ pub async fn add_user_email_verification_code( #[cfg(test)] mod tests { + use rand::SeedableRng; + use super::*; #[sqlx::test(migrator = "crate::MIGRATOR")] async fn test_user_registration_and_login(pool: sqlx::PgPool) -> anyhow::Result<()> { + let clock = Clock::default(); + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); let mut txn = pool.begin().await?; let exists = username_exists(&mut txn, "john").await?; assert!(!exists); let hasher = Argon2::default(); - let user = register_user(&mut txn, hasher, "john", "hunter2").await?; + let user = register_user(&mut txn, &mut rng, &clock, hasher, "john", "hunter2").await?; assert_eq!(user.username, "john"); let exists = username_exists(&mut txn, "john").await?; assert!(exists); - let session = login(&mut txn, "john", "hunter2").await?; + let session = login(&mut txn, &mut rng, &clock, "john", "hunter2").await?; assert_eq!(session.user.data, user.data); let user2 = lookup_user_by_username(&mut txn, "john").await?; diff --git a/crates/tasks/src/database.rs b/crates/tasks/src/database.rs index b80eec0d..5e72141e 100644 --- a/crates/tasks/src/database.rs +++ b/crates/tasks/src/database.rs @@ -14,13 +14,14 @@ //! Database-related tasks +use mas_storage::Clock; use sqlx::{Pool, Postgres}; use tracing::{debug, error, info}; use super::Task; #[derive(Clone)] -struct CleanupExpired(Pool); +struct CleanupExpired(Pool, Clock); impl std::fmt::Debug for CleanupExpired { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -31,7 +32,7 @@ impl std::fmt::Debug for CleanupExpired { #[async_trait::async_trait] impl Task for CleanupExpired { async fn run(&self) { - let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0).await; + let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0, &self.1).await; match res { Ok(0) => { debug!("no token to clean up"); @@ -49,5 +50,6 @@ impl Task for CleanupExpired { /// Cleanup expired tokens #[must_use] pub fn cleanup_expired(pool: &Pool) -> impl Task + Clone { - CleanupExpired(pool.clone()) + // XXX: the clock should come from somewhere else + CleanupExpired(pool.clone(), Clock::default()) }