diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 5dd18250..8148f796 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -209,6 +209,18 @@ }, "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n " }, + "1f6297fb323e9f2fbfa1c9e3225c0b3037c8c4714533a6240c62275332aa58dc": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM user_email_confirmation_codes\n WHERE user_email_id = $1\n " + }, "2564bf6366eb59268c41fb25bb40d0e4e9e1fd1f9ea53b7a359c9025d7304223": { "describe": { "columns": [], diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 97aeee24..49c9c3a9 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -161,18 +161,88 @@ impl DatabaseInconsistencyError { } } -#[derive(Default, Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct Clock { _private: (), + + #[cfg(test)] + mock: Option>, } impl Clock { #[must_use] pub fn now(&self) -> DateTime { + #[cfg(test)] + if let Some(timestamp) = &self.mock { + let timestamp = timestamp.load(std::sync::atomic::Ordering::Relaxed); + return chrono::TimeZone::timestamp_opt(&Utc, timestamp, 0).unwrap(); + } + // This is the clock used elsewhere, it's fine to call Utc::now here #[allow(clippy::disallowed_methods)] Utc::now() } + + #[cfg(test)] + pub fn mock() -> Self { + use std::sync::{atomic::AtomicI64, Arc}; + + use chrono::TimeZone; + + let datetime = Utc.with_ymd_and_hms(2022, 01, 16, 14, 40, 0).unwrap(); + let timestamp = datetime.timestamp(); + + Self { + mock: Some(Arc::new(AtomicI64::new(timestamp))), + _private: (), + } + } + + #[cfg(test)] + pub fn advance(&self, duration: chrono::Duration) { + let timestamp = self + .mock + .as_ref() + .expect("Clock::advance should only be called on mocked clocks in tests"); + timestamp.fetch_add(duration.num_seconds(), std::sync::atomic::Ordering::Relaxed); + } +} + +#[cfg(test)] +mod tests { + use chrono::Duration; + + use super::*; + + #[test] + fn test_mocked_clock() { + let clock = Clock::mock(); + + // Time should be frozen, and give out the same timestamp on each call + let first = clock.now(); + std::thread::sleep(std::time::Duration::from_millis(10)); + let second = clock.now(); + + assert_eq!(first, second); + + // Clock can be advanced by a fixed duration + clock.advance(Duration::seconds(10)); + let third = clock.now(); + assert_eq!(first + Duration::seconds(10), third); + } + + #[test] + fn test_real_clock() { + let clock = Clock::default(); + + // Time should not be frozen + let first = clock.now(); + std::thread::sleep(std::time::Duration::from_millis(10)); + let second = clock.now(); + + assert_ne!(first, second); + assert!(first < second); + } } pub mod compat; diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index d725dea5..2d5ad987 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -17,6 +17,7 @@ use chrono::{DateTime, Utc}; use mas_data_model::{User, UserEmail, UserEmailVerification, UserEmailVerificationState}; use rand::RngCore; use sqlx::{PgConnection, QueryBuilder}; +use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; @@ -405,7 +406,23 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { err, )] async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> { + let span = info_span!( + "db.user_email.remove.codes", + db.statement = tracing::field::Empty + ); sqlx::query!( + r#" + DELETE FROM user_email_confirmation_codes + WHERE user_email_id = $1 + "#, + Uuid::from(user_email.id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + + let res = sqlx::query!( r#" DELETE FROM user_emails WHERE user_email_id = $1 @@ -416,6 +433,8 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { .execute(&mut *self.conn) .await?; + DatabaseError::ensure_affected_rows(&res, 1)?; + Ok(()) } diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 592cb59d..9dd3d2ca 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. +// Copyright 2021-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,6 +26,9 @@ mod email; mod password; mod session; +#[cfg(test)] +mod tests; + pub use self::{ email::{PgUserEmailRepository, UserEmailRepository}, password::{PgUserPasswordRepository, UserPasswordRepository}, diff --git a/crates/storage/src/user/tests.rs b/crates/storage/src/user/tests.rs new file mode 100644 index 00000000..fca35ce0 --- /dev/null +++ b/crates/storage/src/user/tests.rs @@ -0,0 +1,394 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::Duration; +use rand::SeedableRng; +use rand_chacha::ChaChaRng; +use sqlx::PgPool; + +use crate::{ + user::{BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository}, + Clock, PgRepository, Repository, +}; + +/// Test the user repository, by adding and looking up a user +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_repo(pool: PgPool) { + const USERNAME: &str = "john"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + + // Initially, the user shouldn't exist + assert!(!repo.user().exists(USERNAME).await.unwrap()); + assert!(repo + .user() + .find_by_username(USERNAME) + .await + .unwrap() + .is_none()); + + // Adding the user should work + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + // And now it should exist + assert!(repo.user().exists(USERNAME).await.unwrap()); + assert!(repo + .user() + .find_by_username(USERNAME) + .await + .unwrap() + .is_some()); + assert!(repo.user().lookup(user.id).await.unwrap().is_some()); + + // Adding a second time should give a conflict + assert!(repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .is_err()); + + repo.save().await.unwrap(); +} + +/// Test the user email repository, by trying out most of its methods +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_email_repo(pool: PgPool) { + const USERNAME: &str = "john"; + const CODE: &str = "012345"; + const CODE2: &str = "543210"; + const EMAIL: &str = "john@example.com"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + // The user email should not exist yet + assert!(repo + .user_email() + .find(&user, &EMAIL) + .await + .unwrap() + .is_none()); + + assert_eq!(repo.user_email().count(&user).await.unwrap(), 0); + + let user_email = repo + .user_email() + .add(&mut rng, &clock, &user, EMAIL.to_owned()) + .await + .unwrap(); + + assert_eq!(user_email.user_id, user.id); + assert_eq!(user_email.email, EMAIL); + assert!(user_email.confirmed_at.is_none()); + + assert_eq!(repo.user_email().count(&user).await.unwrap(), 1); + + assert!(repo + .user_email() + .find(&user, &EMAIL) + .await + .unwrap() + .is_some()); + + let user_email = repo + .user_email() + .lookup(user_email.id) + .await + .unwrap() + .expect("user email was not found"); + + assert_eq!(user_email.user_id, user.id); + assert_eq!(user_email.email, EMAIL); + + let verification = repo + .user_email() + .add_verification_code( + &mut rng, + &clock, + &user_email, + Duration::hours(8), + CODE.to_owned(), + ) + .await + .unwrap(); + + let verification_id = verification.id; + assert_eq!(verification.user_email_id, user_email.id); + assert_eq!(verification.code, CODE); + + // A single user email can have multiple verification at the same time + let _verification2 = repo + .user_email() + .add_verification_code( + &mut rng, + &clock, + &user_email, + Duration::hours(8), + CODE2.to_owned(), + ) + .await + .unwrap(); + + let verification = repo + .user_email() + .find_verification_code(&clock, &user_email, CODE) + .await + .unwrap() + .expect("user email verification was not found"); + + assert_eq!(verification.id, verification_id); + assert_eq!(verification.user_email_id, user_email.id); + assert_eq!(verification.code, CODE); + + // Consuming the verification code + repo.user_email() + .consume_verification_code(&clock, verification) + .await + .unwrap(); + + // Mark the email as verified + repo.user_email() + .mark_as_verified(&clock, user_email) + .await + .unwrap(); + + // Reload the user_email + let user_email = repo + .user_email() + .find(&user, &EMAIL) + .await + .unwrap() + .expect("user email was not found"); + + // The email should be marked as verified now + assert!(user_email.confirmed_at.is_some()); + + // Reload the verification + let verification = repo + .user_email() + .find_verification_code(&clock, &user_email, CODE) + .await + .unwrap() + .expect("user email verification was not found"); + + // Consuming a second time should not work + assert!(repo + .user_email() + .consume_verification_code(&clock, verification) + .await + .is_err()); + + // The user shouldn't have a primary email yet + assert!(repo + .user_email() + .get_primary(&user) + .await + .unwrap() + .is_none()); + + repo.user_email().set_as_primary(&user_email).await.unwrap(); + + // Reload the user + let user = repo + .user() + .lookup(user.id) + .await + .unwrap() + .expect("user was not found"); + + // Now it should have one + assert!(repo + .user_email() + .get_primary(&user) + .await + .unwrap() + .is_some()); + + // Deleting the user email should work + repo.user_email().remove(user_email).await.unwrap(); + assert_eq!(repo.user_email().count(&user).await.unwrap(), 0); + + // Reload the user + let user = repo + .user() + .lookup(user.id) + .await + .unwrap() + .expect("user was not found"); + + // The primary user email should be gone + assert!(repo + .user_email() + .get_primary(&user) + .await + .unwrap() + .is_none()); + + repo.save().await.unwrap(); +} + +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_password_repo(pool: PgPool) { + const USERNAME: &str = "john"; + const FIRST_PASSWORD_HASH: &str = "doesntmatter"; + const SECOND_PASSWORD_HASH: &str = "alsodoesntmatter"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + // User should have no active password + assert!(repo.user_password().active(&user).await.unwrap().is_none()); + + // Insert a first password + let first_password = repo + .user_password() + .add( + &mut rng, + &clock, + &user, + 1, + FIRST_PASSWORD_HASH.to_owned(), + None, + ) + .await + .unwrap(); + + // User should now have an active password + let first_password_lookup = repo + .user_password() + .active(&user) + .await + .unwrap() + .expect("user should have an active password"); + + assert_eq!(first_password.id, first_password_lookup.id); + assert_eq!(first_password_lookup.hashed_password, FIRST_PASSWORD_HASH); + assert_eq!(first_password_lookup.version, 1); + assert_eq!(first_password_lookup.upgraded_from_id, None); + + // Getting the last inserted password is based on the clock, so we need to + // advance it + clock.advance(Duration::seconds(10)); + + let second_password = repo + .user_password() + .add( + &mut rng, + &clock, + &user, + 2, + SECOND_PASSWORD_HASH.to_owned(), + Some(&first_password), + ) + .await + .unwrap(); + + // User should now have an active password + let second_password_lookup = repo + .user_password() + .active(&user) + .await + .unwrap() + .expect("user should have an active password"); + + assert_eq!(second_password.id, second_password_lookup.id); + assert_eq!(second_password_lookup.hashed_password, SECOND_PASSWORD_HASH); + assert_eq!(second_password_lookup.version, 2); + assert_eq!( + second_password_lookup.upgraded_from_id, + Some(first_password.id) + ); + + repo.save().await.unwrap(); +} + +#[sqlx::test(migrator = "crate::MIGRATOR")] +async fn test_user_session(pool: PgPool) { + const USERNAME: &str = "john"; + + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = Clock::mock(); + + let user = repo + .user() + .add(&mut rng, &clock, USERNAME.to_owned()) + .await + .unwrap(); + + assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 0); + + let session = repo + .browser_session() + .add(&mut rng, &clock, &user) + .await + .unwrap(); + assert_eq!(session.user.id, user.id); + assert!(session.finished_at.is_none()); + + assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 1); + + let session_lookup = repo + .browser_session() + .lookup(session.id) + .await + .unwrap() + .expect("user session not found"); + + assert_eq!(session_lookup.id, session.id); + assert_eq!(session_lookup.user.id, user.id); + assert!(session_lookup.finished_at.is_none()); + + // Finish the session + repo.browser_session() + .finish(&clock, session_lookup) + .await + .unwrap(); + + // The active session counter is back to 0 + assert_eq!(repo.browser_session().count_active(&user).await.unwrap(), 0); + + // Reload the session + let session_lookup = repo + .browser_session() + .lookup(session.id) + .await + .unwrap() + .expect("user session not found"); + + assert_eq!(session_lookup.id, session.id); + assert_eq!(session_lookup.user.id, user.id); + // This time the session is finished + assert!(session_lookup.finished_at.is_some()); +}