1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Save the authentication method on each authorization

This will help us logging out of the upstream.
This commit is contained in:
Quentin Gliech
2023-08-28 17:00:00 +02:00
parent 096386e9b9
commit d9a12de8a3
10 changed files with 183 additions and 58 deletions

View File

@@ -51,7 +51,7 @@ pub use self::{
UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference,
}, },
users::{ users::{
Authentication, BrowserSession, Password, User, UserEmail, UserEmailVerification, Authentication, AuthenticationMethod, BrowserSession, Password, User, UserEmail,
UserEmailVerificationState, UserEmailVerification, UserEmailVerificationState,
}, },
}; };

View File

@@ -64,6 +64,14 @@ pub struct Password {
pub struct Authentication { pub struct Authentication {
pub id: Ulid, pub id: Ulid,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub authentication_method: AuthenticationMethod,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum AuthenticationMethod {
Password { user_password_id: Ulid },
UpstreamOAuth2 { upstream_oauth2_session_id: Ulid },
Unknown,
} }
#[derive(Debug, Clone, PartialEq, Eq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]

View File

@@ -211,12 +211,13 @@ pub(crate) async fn get(
(Some(session), Some(user_id)) if session.user.id == user_id => { (Some(session), Some(user_id)) if session.user.id == user_id => {
// Session already linked, and link matches the currently logged // Session already linked, and link matches the currently logged
// user. Mark the session as consumed and renew the authentication. // user. Mark the session as consumed and renew the authentication.
repo.upstream_oauth_session() let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session) .consume(&clock, upstream_session)
.await?; .await?;
repo.browser_session() repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &link) .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?; .await?;
cookie_jar = cookie_jar.set_session(&session); cookie_jar = cookie_jar.set_session(&session);
@@ -265,12 +266,13 @@ pub(crate) async fn get(
let session = repo.browser_session().add(&mut rng, &clock, &user).await?; let session = repo.browser_session().add(&mut rng, &clock, &user).await?;
repo.upstream_oauth_session() let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session) .consume(&clock, upstream_session)
.await?; .await?;
repo.browser_session() repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &link) .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?; .await?;
cookie_jar = sessions_cookie cookie_jar = sessions_cookie
@@ -507,12 +509,13 @@ pub(crate) async fn post(
_ => return Err(RouteError::InvalidFormAction), _ => return Err(RouteError::InvalidFormAction),
}; };
repo.upstream_oauth_session() let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session) .consume(&clock, upstream_session)
.await?; .await?;
repo.browser_session() repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &link) .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?; .await?;
let cookie_jar = sessions_cookie let cookie_jar = sessions_cookie

View File

@@ -0,0 +1,40 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT user_session_authentication_id\n , created_at\n , user_password_id\n , upstream_oauth_authorization_session_id\n FROM user_session_authentications\n WHERE user_session_id = $1\n ORDER BY created_at DESC\n LIMIT 1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "user_session_authentication_id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 2,
"name": "user_password_id",
"type_info": "Uuid"
},
{
"ordinal": 3,
"name": "upstream_oauth_authorization_session_id",
"type_info": "Uuid"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false,
false,
true,
true
]
},
"hash": "4c2064fed8fa464ea3d2a1258fb0544dbf1493cad31a21c0cd7ddb57ed12de16"
}

View File

@@ -1,16 +1,17 @@
{ {
"db_name": "PostgreSQL", "db_name": "PostgreSQL",
"query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n ", "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at, user_password_id)\n VALUES ($1, $2, $3, $4)\n ",
"describe": { "describe": {
"columns": [], "columns": [],
"parameters": { "parameters": {
"Left": [ "Left": [
"Uuid", "Uuid",
"Uuid", "Uuid",
"Timestamptz" "Timestamptz",
"Uuid"
] ]
}, },
"nullable": [] "nullable": []
}, },
"hash": "41c1aafbd338c24476f27d342cf80eef7de2836e85b078232d143d6712fc2be4" "hash": "608366f45ecaf392ab69cddb12252b5efcc103c3383fa68b552295e2289d1f55"
} }

View File

@@ -0,0 +1,17 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)\n VALUES ($1, $2, $3, $4)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Timestamptz",
"Uuid"
]
},
"nullable": []
},
"hash": "9c9c65d4ca6847761d8f999253590082672b3782875cf3f5ba0b2f9d26e3a507"
}

View File

@@ -1,28 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT user_session_authentication_id AS id\n , created_at\n FROM user_session_authentications\n WHERE user_session_id = $1\n ORDER BY created_at DESC\n LIMIT 1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false,
false
]
},
"hash": "ce0dbf84b23f4d5cfbd068811149d88898d4c5df8ab557846e2f9184636f2dcf"
}

View File

@@ -0,0 +1,23 @@
-- Copyright 2023 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- This adds the source of each authentication of a user_session
ALTER TABLE "user_session_authentications"
ADD COLUMN "user_password_id" UUID
REFERENCES "user_passwords" ("user_password_id")
ON DELETE SET NULL,
ADD COLUMN "upstream_oauth_authorization_session_id" UUID
REFERENCES "upstream_oauth_authorization_sessions" ("upstream_oauth_authorization_session_id")
ON DELETE SET NULL;

View File

@@ -14,7 +14,10 @@
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User}; use mas_data_model::{
Authentication, AuthenticationMethod, BrowserSession, Password,
UpstreamOAuthAuthorizationSession, User,
};
use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination}; use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination};
use rand::RngCore; use rand::RngCore;
use sea_query::{Expr, PostgresQueryBuilder}; use sea_query::{Expr, PostgresQueryBuilder};
@@ -80,6 +83,42 @@ impl TryFrom<SessionLookup> for BrowserSession {
} }
} }
struct AuthenticationLookup {
user_session_authentication_id: Uuid,
created_at: DateTime<Utc>,
user_password_id: Option<Uuid>,
upstream_oauth_authorization_session_id: Option<Uuid>,
}
impl TryFrom<AuthenticationLookup> for Authentication {
type Error = DatabaseInconsistencyError;
fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
let id = Ulid::from(value.user_session_authentication_id);
let authentication_method = match (
value.user_password_id.map(Into::into),
value
.upstream_oauth_authorization_session_id
.map(Into::into),
) {
(Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
(None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
upstream_oauth2_session_id,
},
(None, None) => AuthenticationMethod::Unknown,
_ => {
return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
}
};
Ok(Authentication {
id,
created_at: value.created_at,
authentication_method,
})
}
}
#[async_trait] #[async_trait]
impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
type Error = DatabaseError; type Error = DatabaseError;
@@ -337,7 +376,6 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
user_session: &BrowserSession, user_session: &BrowserSession,
user_password: &Password, user_password: &Password,
) -> Result<Authentication, Self::Error> { ) -> Result<Authentication, Self::Error> {
let _user_password = user_password;
let created_at = clock.now(); let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng); let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record( tracing::Span::current().record(
@@ -348,18 +386,25 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
sqlx::query!( sqlx::query!(
r#" r#"
INSERT INTO user_session_authentications INSERT INTO user_session_authentications
(user_session_authentication_id, user_session_id, created_at) (user_session_authentication_id, user_session_id, created_at, user_password_id)
VALUES ($1, $2, $3) VALUES ($1, $2, $3, $4)
"#, "#,
Uuid::from(id), Uuid::from(id),
Uuid::from(user_session.id), Uuid::from(user_session.id),
created_at, created_at,
Uuid::from(user_password.id),
) )
.traced() .traced()
.execute(&mut *self.conn) .execute(&mut *self.conn)
.await?; .await?;
Ok(Authentication { id, created_at }) Ok(Authentication {
id,
created_at,
authentication_method: AuthenticationMethod::Password {
user_password_id: user_password.id,
},
})
} }
#[tracing::instrument( #[tracing::instrument(
@@ -368,7 +413,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
fields( fields(
db.statement, db.statement,
%user_session.id, %user_session.id,
%upstream_oauth_link.id, %upstream_oauth_session.id,
user_session_authentication.id, user_session_authentication.id,
), ),
err, err,
@@ -378,9 +423,8 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
rng: &mut (dyn RngCore + Send), rng: &mut (dyn RngCore + Send),
clock: &dyn Clock, clock: &dyn Clock,
user_session: &BrowserSession, user_session: &BrowserSession,
upstream_oauth_link: &UpstreamOAuthLink, upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
) -> Result<Authentication, Self::Error> { ) -> Result<Authentication, Self::Error> {
let _upstream_oauth_link = upstream_oauth_link;
let created_at = clock.now(); let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng); let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record( tracing::Span::current().record(
@@ -391,18 +435,25 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
sqlx::query!( sqlx::query!(
r#" r#"
INSERT INTO user_session_authentications INSERT INTO user_session_authentications
(user_session_authentication_id, user_session_id, created_at) (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
VALUES ($1, $2, $3) VALUES ($1, $2, $3, $4)
"#, "#,
Uuid::from(id), Uuid::from(id),
Uuid::from(user_session.id), Uuid::from(user_session.id),
created_at, created_at,
Uuid::from(upstream_oauth_session.id),
) )
.traced() .traced()
.execute(&mut *self.conn) .execute(&mut *self.conn)
.await?; .await?;
Ok(Authentication { id, created_at }) Ok(Authentication {
id,
created_at,
authentication_method: AuthenticationMethod::UpstreamOAuth2 {
upstream_oauth2_session_id: upstream_oauth_session.id,
},
})
} }
#[tracing::instrument( #[tracing::instrument(
@@ -419,10 +470,12 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
user_session: &BrowserSession, user_session: &BrowserSession,
) -> Result<Option<Authentication>, Self::Error> { ) -> Result<Option<Authentication>, Self::Error> {
let authentication = sqlx::query_as!( let authentication = sqlx::query_as!(
Authentication, AuthenticationLookup,
r#" r#"
SELECT user_session_authentication_id AS id SELECT user_session_authentication_id
, created_at , created_at
, user_password_id
, upstream_oauth_authorization_session_id
FROM user_session_authentications FROM user_session_authentications
WHERE user_session_id = $1 WHERE user_session_id = $1
ORDER BY created_at DESC ORDER BY created_at DESC
@@ -434,6 +487,11 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
.fetch_optional(&mut *self.conn) .fetch_optional(&mut *self.conn)
.await?; .await?;
Ok(authentication) let Some(authentication) = authentication else {
return Ok(None);
};
let authentication = Authentication::try_from(authentication)?;
Ok(Some(authentication))
} }
} }

View File

@@ -13,7 +13,9 @@
// limitations under the License. // limitations under the License.
use async_trait::async_trait; use async_trait::async_trait;
use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User}; use mas_data_model::{
Authentication, BrowserSession, Password, UpstreamOAuthAuthorizationSession, User,
};
use rand_core::RngCore; use rand_core::RngCore;
use ulid::Ulid; use ulid::Ulid;
@@ -188,14 +190,15 @@ pub trait BrowserSessionRepository: Send + Sync {
user_password: &Password, user_password: &Password,
) -> Result<Authentication, Self::Error>; ) -> Result<Authentication, Self::Error>;
/// Authenticate a [`BrowserSession`] with the given [`UpstreamOAuthLink`] /// Authenticate a [`BrowserSession`] with the given
/// [`UpstreamOAuthAuthorizationSession`]
/// ///
/// # Parameters /// # Parameters
/// ///
/// * `rng`: The random number generator to use /// * `rng`: The random number generator to use
/// * `clock`: The clock used to generate timestamps /// * `clock`: The clock used to generate timestamps
/// * `user_session`: The session to authenticate /// * `user_session`: The session to authenticate
/// * `upstream_oauth_link`: The upstream OAuth link which was used to /// * `upstream_oauth_session`: The upstream OAuth session which was used to
/// authenticate /// authenticate
/// ///
/// # Errors /// # Errors
@@ -206,7 +209,7 @@ pub trait BrowserSessionRepository: Send + Sync {
rng: &mut (dyn RngCore + Send), rng: &mut (dyn RngCore + Send),
clock: &dyn Clock, clock: &dyn Clock,
user_session: &BrowserSession, user_session: &BrowserSession,
upstream_oauth_link: &UpstreamOAuthLink, upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
) -> Result<Authentication, Self::Error>; ) -> Result<Authentication, Self::Error>;
/// Get the last successful authentication for a [`BrowserSession`] /// Get the last successful authentication for a [`BrowserSession`]
@@ -259,7 +262,7 @@ repository_impl!(BrowserSessionRepository:
rng: &mut (dyn RngCore + Send), rng: &mut (dyn RngCore + Send),
clock: &dyn Clock, clock: &dyn Clock,
user_session: &BrowserSession, user_session: &BrowserSession,
upstream_oauth_link: &UpstreamOAuthLink, upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
) -> Result<Authentication, Self::Error>; ) -> Result<Authentication, Self::Error>;
async fn get_last_authentication( async fn get_last_authentication(