diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 2e69fb9e..6dfcd04a 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -51,7 +51,7 @@ pub use self::{ UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference, }, users::{ - Authentication, BrowserSession, Password, User, UserEmail, UserEmailVerification, - UserEmailVerificationState, + Authentication, AuthenticationMethod, BrowserSession, Password, User, UserEmail, + UserEmailVerification, UserEmailVerificationState, }, }; diff --git a/crates/data-model/src/users.rs b/crates/data-model/src/users.rs index 9858356f..f71c0b26 100644 --- a/crates/data-model/src/users.rs +++ b/crates/data-model/src/users.rs @@ -64,6 +64,14 @@ pub struct Password { pub struct Authentication { pub id: Ulid, pub created_at: DateTime, + pub authentication_method: AuthenticationMethod, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub enum AuthenticationMethod { + Password { user_password_id: Ulid }, + UpstreamOAuth2 { upstream_oauth2_session_id: Ulid }, + Unknown, } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index b2db2712..9b205ded 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -211,12 +211,13 @@ pub(crate) async fn get( (Some(session), Some(user_id)) if session.user.id == user_id => { // Session already linked, and link matches the currently logged // user. Mark the session as consumed and renew the authentication. - repo.upstream_oauth_session() + let upstream_session = repo + .upstream_oauth_session() .consume(&clock, upstream_session) .await?; repo.browser_session() - .authenticate_with_upstream(&mut rng, &clock, &session, &link) + .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) .await?; cookie_jar = cookie_jar.set_session(&session); @@ -265,12 +266,13 @@ pub(crate) async fn get( let session = repo.browser_session().add(&mut rng, &clock, &user).await?; - repo.upstream_oauth_session() + let upstream_session = repo + .upstream_oauth_session() .consume(&clock, upstream_session) .await?; repo.browser_session() - .authenticate_with_upstream(&mut rng, &clock, &session, &link) + .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) .await?; cookie_jar = sessions_cookie @@ -507,12 +509,13 @@ pub(crate) async fn post( _ => return Err(RouteError::InvalidFormAction), }; - repo.upstream_oauth_session() + let upstream_session = repo + .upstream_oauth_session() .consume(&clock, upstream_session) .await?; repo.browser_session() - .authenticate_with_upstream(&mut rng, &clock, &session, &link) + .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) .await?; let cookie_jar = sessions_cookie diff --git a/crates/storage-pg/.sqlx/query-4c2064fed8fa464ea3d2a1258fb0544dbf1493cad31a21c0cd7ddb57ed12de16.json b/crates/storage-pg/.sqlx/query-4c2064fed8fa464ea3d2a1258fb0544dbf1493cad31a21c0cd7ddb57ed12de16.json new file mode 100644 index 00000000..4d1f69f1 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-4c2064fed8fa464ea3d2a1258fb0544dbf1493cad31a21c0cd7ddb57ed12de16.json @@ -0,0 +1,40 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT user_session_authentication_id\n , created_at\n , user_password_id\n , upstream_oauth_authorization_session_id\n FROM user_session_authentications\n WHERE user_session_id = $1\n ORDER BY created_at DESC\n LIMIT 1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "user_session_authentication_id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 2, + "name": "user_password_id", + "type_info": "Uuid" + }, + { + "ordinal": 3, + "name": "upstream_oauth_authorization_session_id", + "type_info": "Uuid" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + true, + true + ] + }, + "hash": "4c2064fed8fa464ea3d2a1258fb0544dbf1493cad31a21c0cd7ddb57ed12de16" +} diff --git a/crates/storage-pg/.sqlx/query-41c1aafbd338c24476f27d342cf80eef7de2836e85b078232d143d6712fc2be4.json b/crates/storage-pg/.sqlx/query-608366f45ecaf392ab69cddb12252b5efcc103c3383fa68b552295e2289d1f55.json similarity index 58% rename from crates/storage-pg/.sqlx/query-41c1aafbd338c24476f27d342cf80eef7de2836e85b078232d143d6712fc2be4.json rename to crates/storage-pg/.sqlx/query-608366f45ecaf392ab69cddb12252b5efcc103c3383fa68b552295e2289d1f55.json index 572f5748..4076e098 100644 --- a/crates/storage-pg/.sqlx/query-41c1aafbd338c24476f27d342cf80eef7de2836e85b078232d143d6712fc2be4.json +++ b/crates/storage-pg/.sqlx/query-608366f45ecaf392ab69cddb12252b5efcc103c3383fa68b552295e2289d1f55.json @@ -1,16 +1,17 @@ { "db_name": "PostgreSQL", - "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n ", + "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at, user_password_id)\n VALUES ($1, $2, $3, $4)\n ", "describe": { "columns": [], "parameters": { "Left": [ "Uuid", "Uuid", - "Timestamptz" + "Timestamptz", + "Uuid" ] }, "nullable": [] }, - "hash": "41c1aafbd338c24476f27d342cf80eef7de2836e85b078232d143d6712fc2be4" + "hash": "608366f45ecaf392ab69cddb12252b5efcc103c3383fa68b552295e2289d1f55" } diff --git a/crates/storage-pg/.sqlx/query-9c9c65d4ca6847761d8f999253590082672b3782875cf3f5ba0b2f9d26e3a507.json b/crates/storage-pg/.sqlx/query-9c9c65d4ca6847761d8f999253590082672b3782875cf3f5ba0b2f9d26e3a507.json new file mode 100644 index 00000000..7ffc942e --- /dev/null +++ b/crates/storage-pg/.sqlx/query-9c9c65d4ca6847761d8f999253590082672b3782875cf3f5ba0b2f9d26e3a507.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)\n VALUES ($1, $2, $3, $4)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "9c9c65d4ca6847761d8f999253590082672b3782875cf3f5ba0b2f9d26e3a507" +} diff --git a/crates/storage-pg/.sqlx/query-ce0dbf84b23f4d5cfbd068811149d88898d4c5df8ab557846e2f9184636f2dcf.json b/crates/storage-pg/.sqlx/query-ce0dbf84b23f4d5cfbd068811149d88898d4c5df8ab557846e2f9184636f2dcf.json deleted file mode 100644 index 4a715ab5..00000000 --- a/crates/storage-pg/.sqlx/query-ce0dbf84b23f4d5cfbd068811149d88898d4c5df8ab557846e2f9184636f2dcf.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT user_session_authentication_id AS id\n , created_at\n FROM user_session_authentications\n WHERE user_session_id = $1\n ORDER BY created_at DESC\n LIMIT 1\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Uuid" - }, - { - "ordinal": 1, - "name": "created_at", - "type_info": "Timestamptz" - } - ], - "parameters": { - "Left": [ - "Uuid" - ] - }, - "nullable": [ - false, - false - ] - }, - "hash": "ce0dbf84b23f4d5cfbd068811149d88898d4c5df8ab557846e2f9184636f2dcf" -} diff --git a/crates/storage-pg/migrations/20230828143553_user_session_authentication_source.sql b/crates/storage-pg/migrations/20230828143553_user_session_authentication_source.sql new file mode 100644 index 00000000..77a45348 --- /dev/null +++ b/crates/storage-pg/migrations/20230828143553_user_session_authentication_source.sql @@ -0,0 +1,23 @@ +-- Copyright 2023 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- This adds the source of each authentication of a user_session +ALTER TABLE "user_session_authentications" + ADD COLUMN "user_password_id" UUID + REFERENCES "user_passwords" ("user_password_id") + ON DELETE SET NULL, + + ADD COLUMN "upstream_oauth_authorization_session_id" UUID + REFERENCES "upstream_oauth_authorization_sessions" ("upstream_oauth_authorization_session_id") + ON DELETE SET NULL; \ No newline at end of file diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs index 3b094c48..bc838d7e 100644 --- a/crates/storage-pg/src/user/session.rs +++ b/crates/storage-pg/src/user/session.rs @@ -14,7 +14,10 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User}; +use mas_data_model::{ + Authentication, AuthenticationMethod, BrowserSession, Password, + UpstreamOAuthAuthorizationSession, User, +}; use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination}; use rand::RngCore; use sea_query::{Expr, PostgresQueryBuilder}; @@ -80,6 +83,42 @@ impl TryFrom for BrowserSession { } } +struct AuthenticationLookup { + user_session_authentication_id: Uuid, + created_at: DateTime, + user_password_id: Option, + upstream_oauth_authorization_session_id: Option, +} + +impl TryFrom for Authentication { + type Error = DatabaseInconsistencyError; + + fn try_from(value: AuthenticationLookup) -> Result { + let id = Ulid::from(value.user_session_authentication_id); + let authentication_method = match ( + value.user_password_id.map(Into::into), + value + .upstream_oauth_authorization_session_id + .map(Into::into), + ) { + (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id }, + (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 { + upstream_oauth2_session_id, + }, + (None, None) => AuthenticationMethod::Unknown, + _ => { + return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id)); + } + }; + + Ok(Authentication { + id, + created_at: value.created_at, + authentication_method, + }) + } +} + #[async_trait] impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { type Error = DatabaseError; @@ -337,7 +376,6 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { user_session: &BrowserSession, user_password: &Password, ) -> Result { - let _user_password = user_password; let created_at = clock.now(); let id = Ulid::from_datetime_with_source(created_at.into(), rng); tracing::Span::current().record( @@ -348,18 +386,25 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { sqlx::query!( r#" INSERT INTO user_session_authentications - (user_session_authentication_id, user_session_id, created_at) - VALUES ($1, $2, $3) + (user_session_authentication_id, user_session_id, created_at, user_password_id) + VALUES ($1, $2, $3, $4) "#, Uuid::from(id), Uuid::from(user_session.id), created_at, + Uuid::from(user_password.id), ) .traced() .execute(&mut *self.conn) .await?; - Ok(Authentication { id, created_at }) + Ok(Authentication { + id, + created_at, + authentication_method: AuthenticationMethod::Password { + user_password_id: user_password.id, + }, + }) } #[tracing::instrument( @@ -368,7 +413,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { fields( db.statement, %user_session.id, - %upstream_oauth_link.id, + %upstream_oauth_session.id, user_session_authentication.id, ), err, @@ -378,9 +423,8 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { rng: &mut (dyn RngCore + Send), clock: &dyn Clock, user_session: &BrowserSession, - upstream_oauth_link: &UpstreamOAuthLink, + upstream_oauth_session: &UpstreamOAuthAuthorizationSession, ) -> Result { - let _upstream_oauth_link = upstream_oauth_link; let created_at = clock.now(); let id = Ulid::from_datetime_with_source(created_at.into(), rng); tracing::Span::current().record( @@ -391,18 +435,25 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { sqlx::query!( r#" INSERT INTO user_session_authentications - (user_session_authentication_id, user_session_id, created_at) - VALUES ($1, $2, $3) + (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id) + VALUES ($1, $2, $3, $4) "#, Uuid::from(id), Uuid::from(user_session.id), created_at, + Uuid::from(upstream_oauth_session.id), ) .traced() .execute(&mut *self.conn) .await?; - Ok(Authentication { id, created_at }) + Ok(Authentication { + id, + created_at, + authentication_method: AuthenticationMethod::UpstreamOAuth2 { + upstream_oauth2_session_id: upstream_oauth_session.id, + }, + }) } #[tracing::instrument( @@ -419,10 +470,12 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { user_session: &BrowserSession, ) -> Result, Self::Error> { let authentication = sqlx::query_as!( - Authentication, + AuthenticationLookup, r#" - SELECT user_session_authentication_id AS id + SELECT user_session_authentication_id , created_at + , user_password_id + , upstream_oauth_authorization_session_id FROM user_session_authentications WHERE user_session_id = $1 ORDER BY created_at DESC @@ -434,6 +487,11 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { .fetch_optional(&mut *self.conn) .await?; - Ok(authentication) + let Some(authentication) = authentication else { + return Ok(None); + }; + + let authentication = Authentication::try_from(authentication)?; + Ok(Some(authentication)) } } diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index f517d9ed..d86a3b5e 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -13,7 +13,9 @@ // limitations under the License. use async_trait::async_trait; -use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User}; +use mas_data_model::{ + Authentication, BrowserSession, Password, UpstreamOAuthAuthorizationSession, User, +}; use rand_core::RngCore; use ulid::Ulid; @@ -188,14 +190,15 @@ pub trait BrowserSessionRepository: Send + Sync { user_password: &Password, ) -> Result; - /// Authenticate a [`BrowserSession`] with the given [`UpstreamOAuthLink`] + /// Authenticate a [`BrowserSession`] with the given + /// [`UpstreamOAuthAuthorizationSession`] /// /// # Parameters /// /// * `rng`: The random number generator to use /// * `clock`: The clock used to generate timestamps /// * `user_session`: The session to authenticate - /// * `upstream_oauth_link`: The upstream OAuth link which was used to + /// * `upstream_oauth_session`: The upstream OAuth session which was used to /// authenticate /// /// # Errors @@ -206,7 +209,7 @@ pub trait BrowserSessionRepository: Send + Sync { rng: &mut (dyn RngCore + Send), clock: &dyn Clock, user_session: &BrowserSession, - upstream_oauth_link: &UpstreamOAuthLink, + upstream_oauth_session: &UpstreamOAuthAuthorizationSession, ) -> Result; /// Get the last successful authentication for a [`BrowserSession`] @@ -259,7 +262,7 @@ repository_impl!(BrowserSessionRepository: rng: &mut (dyn RngCore + Send), clock: &dyn Clock, user_session: &BrowserSession, - upstream_oauth_link: &UpstreamOAuthLink, + upstream_oauth_session: &UpstreamOAuthAuthorizationSession, ) -> Result; async fn get_last_authentication(