diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 20f8af07..878996da 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -54,6 +54,6 @@ pub use self::{ user_agent::{DeviceType, UserAgent}, users::{ Authentication, AuthenticationMethod, BrowserSession, Password, User, UserEmail, - UserEmailVerification, UserEmailVerificationState, + UserEmailVerification, UserEmailVerificationState, UserRecoverySession, UserRecoveryTicket, }, }; diff --git a/crates/data-model/src/users.rs b/crates/data-model/src/users.rs index 2be860ae..9e8bc94c 100644 --- a/crates/data-model/src/users.rs +++ b/crates/data-model/src/users.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021-2024 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -79,6 +79,44 @@ pub enum AuthenticationMethod { Unknown, } +/// A session to recover a user if they have lost their credentials +/// +/// For each session intiated, there may be multiple [`UserRecoveryTicket`]s +/// sent to the user, either because multiple [`User`] have the same email +/// address, or because the user asked to send the recovery email again. +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UserRecoverySession { + pub id: Ulid, + pub email: String, + pub user_agent: UserAgent, + pub ip_address: Option, + pub locale: String, + pub created_at: DateTime, + pub consumed_at: Option>, +} + +/// A single recovery ticket for a user recovery session +/// +/// Whenever a new recovery session is initiated, a new ticket is created for +/// each email address matching in the database. That ticket is sent by email, +/// as a link that the user can click to recover their account. +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UserRecoveryTicket { + pub id: Ulid, + pub user_recovery_session_id: Ulid, + pub user_email_id: Ulid, + pub ticket: String, + pub created_at: DateTime, + pub expires_at: DateTime, +} + +impl UserRecoveryTicket { + #[must_use] + pub fn active(&self, now: DateTime) -> bool { + now < self.expires_at + } +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct BrowserSession { pub id: Ulid, diff --git a/crates/storage-pg/.sqlx/query-1764715e59f879f6b917ca30f8e3c1de5910c7a46e7fe52d1fb3bfd5561ac320.json b/crates/storage-pg/.sqlx/query-1764715e59f879f6b917ca30f8e3c1de5910c7a46e7fe52d1fb3bfd5561ac320.json new file mode 100644 index 00000000..3a1232f2 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-1764715e59f879f6b917ca30f8e3c1de5910c7a46e7fe52d1fb3bfd5561ac320.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE user_recovery_sessions\n SET consumed_at = $1\n WHERE user_recovery_session_id = $2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "1764715e59f879f6b917ca30f8e3c1de5910c7a46e7fe52d1fb3bfd5561ac320" +} diff --git a/crates/storage-pg/.sqlx/query-607262ccf28b672df51e4e5d371e5cc5119a7d6e7fe784112703c0406f28300f.json b/crates/storage-pg/.sqlx/query-607262ccf28b672df51e4e5d371e5cc5119a7d6e7fe784112703c0406f28300f.json new file mode 100644 index 00000000..3efb2e7f --- /dev/null +++ b/crates/storage-pg/.sqlx/query-607262ccf28b672df51e4e5d371e5cc5119a7d6e7fe784112703c0406f28300f.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n user_recovery_ticket_id\n , user_recovery_session_id\n , user_email_id\n , ticket\n , created_at\n , expires_at\n FROM user_recovery_tickets\n WHERE ticket = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "user_recovery_ticket_id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "user_recovery_session_id", + "type_info": "Uuid" + }, + { + "ordinal": 2, + "name": "user_email_id", + "type_info": "Uuid" + }, + { + "ordinal": 3, + "name": "ticket", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "expires_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + false + ] + }, + "hash": "607262ccf28b672df51e4e5d371e5cc5119a7d6e7fe784112703c0406f28300f" +} diff --git a/crates/storage-pg/.sqlx/query-8275a440640ea28fd8f82e7df672e45a6eba981a0d621665ed8f8b60354b3389.json b/crates/storage-pg/.sqlx/query-8275a440640ea28fd8f82e7df672e45a6eba981a0d621665ed8f8b60354b3389.json new file mode 100644 index 00000000..ee2e27cd --- /dev/null +++ b/crates/storage-pg/.sqlx/query-8275a440640ea28fd8f82e7df672e45a6eba981a0d621665ed8f8b60354b3389.json @@ -0,0 +1,19 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO user_recovery_sessions (\n user_recovery_session_id\n , email\n , user_agent\n , ip_address\n , locale\n , created_at\n )\n VALUES ($1, $2, $3, $4, $5, $6)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Inet", + "Text", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "8275a440640ea28fd8f82e7df672e45a6eba981a0d621665ed8f8b60354b3389" +} diff --git a/crates/storage-pg/.sqlx/query-d7a0e4fa2f168976505405c7e7800847f3379f7b57c0972659a35bfb68b0f6cd.json b/crates/storage-pg/.sqlx/query-d7a0e4fa2f168976505405c7e7800847f3379f7b57c0972659a35bfb68b0f6cd.json new file mode 100644 index 00000000..c9cebfad --- /dev/null +++ b/crates/storage-pg/.sqlx/query-d7a0e4fa2f168976505405c7e7800847f3379f7b57c0972659a35bfb68b0f6cd.json @@ -0,0 +1,19 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO user_recovery_tickets (\n user_recovery_ticket_id\n , user_recovery_session_id\n , user_email_id\n , ticket\n , created_at\n , expires_at\n )\n VALUES ($1, $2, $3, $4, $5, $6)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Uuid", + "Text", + "Timestamptz", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "d7a0e4fa2f168976505405c7e7800847f3379f7b57c0972659a35bfb68b0f6cd" +} diff --git a/crates/storage-pg/.sqlx/query-f46e87bbb149b35e1d13b2b3cd2bdeab3c28a56a395f52f001a7bb013a5dfece.json b/crates/storage-pg/.sqlx/query-f46e87bbb149b35e1d13b2b3cd2bdeab3c28a56a395f52f001a7bb013a5dfece.json new file mode 100644 index 00000000..8d036abc --- /dev/null +++ b/crates/storage-pg/.sqlx/query-f46e87bbb149b35e1d13b2b3cd2bdeab3c28a56a395f52f001a7bb013a5dfece.json @@ -0,0 +1,58 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n user_recovery_session_id\n , email\n , user_agent\n , ip_address as \"ip_address: IpAddr\"\n , locale\n , created_at\n , consumed_at\n FROM user_recovery_sessions\n WHERE user_recovery_session_id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "user_recovery_session_id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "email", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "user_agent", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "ip_address: IpAddr", + "type_info": "Inet" + }, + { + "ordinal": 4, + "name": "locale", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "consumed_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + false, + true, + false, + false, + true + ] + }, + "hash": "f46e87bbb149b35e1d13b2b3cd2bdeab3c28a56a395f52f001a7bb013a5dfece" +} diff --git a/crates/storage-pg/migrations/20240621080509_user_recovery.sql b/crates/storage-pg/migrations/20240621080509_user_recovery.sql new file mode 100644 index 00000000..b75c49b0 --- /dev/null +++ b/crates/storage-pg/migrations/20240621080509_user_recovery.sql @@ -0,0 +1,64 @@ +-- Copyright 2024 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- Stores user recovery sessions for when the user lost their credentials. +CREATE TABLE "user_recovery_sessions" ( + "user_recovery_session_id" UUID NOT NULL + CONSTRAINT "user_recovery_sessions_pkey" + PRIMARY KEY, + + -- The email address for which the recovery session was requested + "email" TEXT NOT NULL, + + -- The user agent of the client that requested the recovery session + "user_agent" TEXT NOT NULL, + + -- The IP address of the client that requested the recovery session + "ip_address" INET, + + -- The language of the client that requested the recovery session + "locale" TEXT NOT NULL, + + -- When the recovery session was created + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL, + + -- When the recovery session was consumed + "consumed_at" TIMESTAMP WITH TIME ZONE +); + +-- Stores the recovery tickets for a user recovery session. +CREATE TABLE "user_recovery_tickets" ( + "user_recovery_ticket_id" UUID NOT NULL + CONSTRAINT "user_recovery_tickets_pkey" + PRIMARY KEY, + + -- The recovery session this ticket belongs to + "user_recovery_session_id" UUID NOT NULL + REFERENCES "user_recovery_sessions" ("user_recovery_session_id") + ON DELETE CASCADE, + + -- The user_email for which the recovery ticket was generated + "user_email_id" UUID NOT NULL + REFERENCES "user_emails" ("user_email_id") + ON DELETE CASCADE, + + -- The recovery ticket + "ticket" TEXT NOT NULL, + + -- When the recovery ticket was created + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL, + + -- When the recovery ticket expires + "expires_at" TIMESTAMP WITH TIME ZONE NOT NULL +); diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index b51004d0..6e951051 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -1,4 +1,4 @@ -// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// Copyright 2022-2024 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -54,7 +54,7 @@ use crate::{ }, user::{ PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository, - PgUserRepository, PgUserTermsRepository, + PgUserRecoveryRepository, PgUserRepository, PgUserTermsRepository, }, DatabaseError, }; @@ -179,6 +179,12 @@ where Box::new(PgUserPasswordRepository::new(self.conn.as_mut())) } + fn user_recovery<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(PgUserRecoveryRepository::new(self.conn.as_mut())) + } + fn user_terms<'c>( &'c mut self, ) -> Box + 'c> { diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs index 01ba7529..f3291ed6 100644 --- a/crates/storage-pg/src/user/mod.rs +++ b/crates/storage-pg/src/user/mod.rs @@ -28,6 +28,7 @@ use crate::{tracing::ExecuteExt, DatabaseError}; mod email; mod password; +mod recovery; mod session; mod terms; @@ -36,7 +37,8 @@ mod tests; pub use self::{ email::PgUserEmailRepository, password::PgUserPasswordRepository, - session::PgBrowserSessionRepository, terms::PgUserTermsRepository, + recovery::PgUserRecoveryRepository, session::PgBrowserSessionRepository, + terms::PgUserTermsRepository, }; /// An implementation of [`UserRepository`] for a PostgreSQL connection diff --git a/crates/storage-pg/src/user/recovery.rs b/crates/storage-pg/src/user/recovery.rs new file mode 100644 index 00000000..2f7829df --- /dev/null +++ b/crates/storage-pg/src/user/recovery.rs @@ -0,0 +1,337 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::net::IpAddr; + +use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; +use mas_data_model::{UserAgent, UserEmail, UserRecoverySession, UserRecoveryTicket}; +use mas_storage::{user::UserRecoveryRepository, Clock}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{DatabaseError, ExecuteExt}; + +/// An implementation of [`UserRecoveryRepository`] for a PostgreSQL connection +pub struct PgUserRecoveryRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUserRecoveryRepository<'c> { + /// Create a new [`PgUserRecoveryRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct UserRecoverySessionRow { + user_recovery_session_id: Uuid, + email: String, + user_agent: String, + ip_address: Option, + locale: String, + created_at: DateTime, + consumed_at: Option>, +} + +impl From for UserRecoverySession { + fn from(row: UserRecoverySessionRow) -> Self { + UserRecoverySession { + id: row.user_recovery_session_id.into(), + email: row.email, + user_agent: UserAgent::parse(row.user_agent), + ip_address: row.ip_address, + locale: row.locale, + created_at: row.created_at, + consumed_at: row.consumed_at, + } + } +} + +struct UserRecoveryTicketRow { + user_recovery_ticket_id: Uuid, + user_recovery_session_id: Uuid, + user_email_id: Uuid, + ticket: String, + created_at: DateTime, + expires_at: DateTime, +} + +impl From for UserRecoveryTicket { + fn from(row: UserRecoveryTicketRow) -> Self { + Self { + id: row.user_recovery_ticket_id.into(), + user_recovery_session_id: row.user_recovery_session_id.into(), + user_email_id: row.user_email_id.into(), + ticket: row.ticket, + created_at: row.created_at, + expires_at: row.expires_at, + } + } +} + +#[async_trait] +impl<'c> UserRecoveryRepository for PgUserRecoveryRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.user_recovery.lookup_session", + skip_all, + fields( + db.statement, + user_recovery_session.id = %id, + ), + err, + )] + async fn lookup_session( + &mut self, + id: Ulid, + ) -> Result, Self::Error> { + let row = sqlx::query_as!( + UserRecoverySessionRow, + r#" + SELECT + user_recovery_session_id + , email + , user_agent + , ip_address as "ip_address: IpAddr" + , locale + , created_at + , consumed_at + FROM user_recovery_sessions + WHERE user_recovery_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_optional(&mut *self.conn) + .await?; + + let Some(row) = row else { + return Ok(None); + }; + + Ok(Some(row.into())) + } + + #[tracing::instrument( + name = "db.user_recovery.add_session", + skip_all, + fields( + db.statement, + user_recovery_session.id, + user_recovery_session.email = email, + user_recovery_session.user_agent = &*user_agent, + user_recovery_session.ip_address = ip_address.map(|ip| ip.to_string()), + ) + )] + async fn add_session( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + email: String, + user_agent: UserAgent, + ip_address: Option, + locale: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_recovery_session.id", tracing::field::display(id)); + sqlx::query!( + r#" + INSERT INTO user_recovery_sessions ( + user_recovery_session_id + , email + , user_agent + , ip_address + , locale + , created_at + ) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + Uuid::from(id), + &email, + &*user_agent, + ip_address as Option, + &locale, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let user_recovery_session = UserRecoverySession { + id, + email, + user_agent, + ip_address, + locale, + created_at, + consumed_at: None, + }; + + Ok(user_recovery_session) + } + + #[tracing::instrument( + name = "db.user_recovery.find_ticket", + skip_all, + fields( + db.statement, + user_recovery_ticket.id = ticket, + ), + err, + )] + async fn find_ticket( + &mut self, + ticket: &str, + ) -> Result, Self::Error> { + let row = sqlx::query_as!( + UserRecoveryTicketRow, + r#" + SELECT + user_recovery_ticket_id + , user_recovery_session_id + , user_email_id + , ticket + , created_at + , expires_at + FROM user_recovery_tickets + WHERE ticket = $1 + "#, + ticket, + ) + .traced() + .fetch_optional(&mut *self.conn) + .await?; + + let Some(row) = row else { + return Ok(None); + }; + + Ok(Some(row.into())) + } + + #[tracing::instrument( + name = "db.user_recovery.add_ticket", + skip_all, + fields( + db.statement, + user_recovery_ticket.id, + user_recovery_ticket.id = ticket, + %user_recovery_session.id, + %user_email.id, + ) + )] + async fn add_ticket( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_recovery_session: &UserRecoverySession, + user_email: &UserEmail, + ticket: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("user_recovery_ticket.id", tracing::field::display(id)); + + // TODO: move that to a parameter + let expires_at = created_at + Duration::minutes(10); + + sqlx::query!( + r#" + INSERT INTO user_recovery_tickets ( + user_recovery_ticket_id + , user_recovery_session_id + , user_email_id + , ticket + , created_at + , expires_at + ) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + Uuid::from(id), + Uuid::from(user_recovery_session.id), + Uuid::from(user_email.id), + &ticket, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + let ticket = UserRecoveryTicket { + id, + user_recovery_session_id: user_recovery_session.id, + user_email_id: user_email.id, + ticket, + created_at, + expires_at, + }; + + Ok(ticket) + } + + #[tracing::instrument( + name = "db.user_recovery.consume_ticket", + skip_all, + fields( + db.statement, + %user_recovery_ticket.id, + user_email.id = %user_recovery_ticket.user_email_id, + %user_recovery_session.id, + %user_recovery_session.email, + ), + err, + )] + async fn consume_ticket( + &mut self, + clock: &dyn Clock, + user_recovery_ticket: UserRecoveryTicket, + mut user_recovery_session: UserRecoverySession, + ) -> Result { + // We don't really use the ticket, we just want to make sure we drop it + let _ = user_recovery_ticket; + + // This should have been checked by the caller + if user_recovery_session.consumed_at.is_some() { + return Err(DatabaseError::invalid_operation()); + } + + let consumed_at = clock.now(); + + let res = sqlx::query!( + r#" + UPDATE user_recovery_sessions + SET consumed_at = $1 + WHERE user_recovery_session_id = $2 + "#, + consumed_at, + Uuid::from(user_recovery_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + user_recovery_session.consumed_at = Some(consumed_at); + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(user_recovery_session) + } +} diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 65f4526d..dae46e9a 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -31,8 +31,8 @@ use crate::{ UpstreamOAuthSessionRepository, }, user::{ - BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, UserRepository, - UserTermsRepository, + BrowserSessionRepository, UserEmailRepository, UserPasswordRepository, + UserRecoveryRepository, UserRepository, UserTermsRepository, }, MapErr, }; @@ -149,6 +149,10 @@ pub trait RepositoryAccess: Send { fn user_password<'c>(&'c mut self) -> Box + 'c>; + /// Get an [`UserRecoveryRepository`] + fn user_recovery<'c>(&'c mut self) + -> Box + 'c>; + /// Get an [`UserTermsRepository`] fn user_terms<'c>(&'c mut self) -> Box + 'c>; @@ -322,6 +326,12 @@ mod impls { Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper)) } + fn user_recovery<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.user_recovery(), &mut self.mapper)) + } + fn user_terms<'c>(&'c mut self) -> Box + 'c> { Box::new(MapErr::new(self.inner.user_terms(), &mut self.mapper)) } @@ -456,6 +466,12 @@ mod impls { (**self).user_password() } + fn user_recovery<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).user_recovery() + } + fn user_terms<'c>(&'c mut self) -> Box + 'c> { (**self).user_terms() } diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 9f845f8e..80bc505a 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -23,12 +23,14 @@ use crate::{repository_impl, Clock}; mod email; mod password; +mod recovery; mod session; mod terms; pub use self::{ email::{UserEmailFilter, UserEmailRepository}, password::UserPasswordRepository, + recovery::UserRecoveryRepository, session::{BrowserSessionFilter, BrowserSessionRepository}, terms::UserTermsRepository, }; diff --git a/crates/storage/src/user/recovery.rs b/crates/storage/src/user/recovery.rs new file mode 100644 index 00000000..485bd913 --- /dev/null +++ b/crates/storage/src/user/recovery.rs @@ -0,0 +1,167 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::net::IpAddr; + +use async_trait::async_trait; +use mas_data_model::{UserAgent, UserEmail, UserRecoverySession, UserRecoveryTicket}; +use rand_core::RngCore; +use ulid::Ulid; + +use crate::{repository_impl, Clock}; + +/// A [`UserRecoveryRepository`] helps interacting with [`UserRecoverySession`] +/// and [`UserRecoveryTicket`] saved in the storage backend +#[async_trait] +pub trait UserRecoveryRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// Lookup an [`UserRecoverySession`] by its ID + /// + /// Returns `None` if no [`UserRecoverySession`] was found + /// + /// # Parameters + /// + /// * `id`: The ID of the [`UserRecoverySession`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn lookup_session( + &mut self, + id: Ulid, + ) -> Result, Self::Error>; + + /// Create a new [`UserRecoverySession`] for the given email + /// + /// Returns the newly created [`UserRecoverySession`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock to use + /// * `email`: The email to create the session for + /// * `user_agent`: The user agent of the browser which initiated the + /// session + /// * `ip_address`: The IP address of the browser which initiated the + /// session, if known + /// * `locale`: The locale of the browser which initiated the session + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add_session( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + email: String, + user_agent: UserAgent, + ip_address: Option, + locale: String, + ) -> Result; + + /// Find a [`UserRecoveryTicket`] by its ticket + /// + /// Returns `None` if no [`UserRecoveryTicket`] was found + /// + /// # Parameters + /// + /// * `ticket`: The ticket of the [`UserRecoveryTicket`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_ticket( + &mut self, + ticket: &str, + ) -> Result, Self::Error>; + + /// Add a [`UserRecoveryTicket`] to the given [`UserRecoverySession`] for + /// the given [`UserEmail`] + /// + /// # Parameters + /// + /// * `rng`: The random number generator to use + /// * `clock`: The clock to use + /// * `session`: The [`UserRecoverySession`] to add the ticket to + /// * `user_email`: The [`UserEmail`] to add the ticket for + /// * `ticket`: The ticket to add + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn add_ticket( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_recovery_session: &UserRecoverySession, + user_email: &UserEmail, + ticket: String, + ) -> Result; + + /// Consume a [`UserRecoveryTicket`] and mark the session as used + /// + /// # Parameters + /// + /// * `clock`: The clock to use to record the time of consumption + /// * `ticket`: The [`UserRecoveryTicket`] to consume + /// * `session`: The [`UserRecoverySession`] to mark as used + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails or if the + /// recovery session was already used + async fn consume_ticket( + &mut self, + clock: &dyn Clock, + user_recovery_ticket: UserRecoveryTicket, + user_recovery_session: UserRecoverySession, + ) -> Result; +} + +repository_impl!(UserRecoveryRepository: + async fn lookup_session(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn add_session( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + email: String, + user_agent: UserAgent, + ip_address: Option, + locale: String, + ) -> Result; + + async fn find_ticket( + &mut self, + ticket: &str, + ) -> Result, Self::Error>; + + async fn add_ticket( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &dyn Clock, + user_recovery_session: &UserRecoverySession, + user_email: &UserEmail, + ticket: String, + ) -> Result; + + async fn consume_ticket( + &mut self, + clock: &dyn Clock, + user_recovery_ticket: UserRecoveryTicket, + user_recovery_session: UserRecoverySession, + ) -> Result; +);