1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-20 12:02:22 +03:00

Save the authentication method on each authorization

This will help us logging out of the upstream.
This commit is contained in:
Quentin Gliech
2023-08-28 17:00:00 +02:00
parent 096386e9b9
commit d9a12de8a3
10 changed files with 183 additions and 58 deletions

View File

@@ -14,7 +14,10 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User};
use mas_data_model::{
Authentication, AuthenticationMethod, BrowserSession, Password,
UpstreamOAuthAuthorizationSession, User,
};
use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination};
use rand::RngCore;
use sea_query::{Expr, PostgresQueryBuilder};
@@ -80,6 +83,42 @@ impl TryFrom<SessionLookup> for BrowserSession {
}
}
struct AuthenticationLookup {
user_session_authentication_id: Uuid,
created_at: DateTime<Utc>,
user_password_id: Option<Uuid>,
upstream_oauth_authorization_session_id: Option<Uuid>,
}
impl TryFrom<AuthenticationLookup> for Authentication {
type Error = DatabaseInconsistencyError;
fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
let id = Ulid::from(value.user_session_authentication_id);
let authentication_method = match (
value.user_password_id.map(Into::into),
value
.upstream_oauth_authorization_session_id
.map(Into::into),
) {
(Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
(None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
upstream_oauth2_session_id,
},
(None, None) => AuthenticationMethod::Unknown,
_ => {
return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
}
};
Ok(Authentication {
id,
created_at: value.created_at,
authentication_method,
})
}
}
#[async_trait]
impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
type Error = DatabaseError;
@@ -337,7 +376,6 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
user_session: &BrowserSession,
user_password: &Password,
) -> Result<Authentication, Self::Error> {
let _user_password = user_password;
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record(
@@ -348,18 +386,25 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
sqlx::query!(
r#"
INSERT INTO user_session_authentications
(user_session_authentication_id, user_session_id, created_at)
VALUES ($1, $2, $3)
(user_session_authentication_id, user_session_id, created_at, user_password_id)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user_session.id),
created_at,
Uuid::from(user_password.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(Authentication { id, created_at })
Ok(Authentication {
id,
created_at,
authentication_method: AuthenticationMethod::Password {
user_password_id: user_password.id,
},
})
}
#[tracing::instrument(
@@ -368,7 +413,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
fields(
db.statement,
%user_session.id,
%upstream_oauth_link.id,
%upstream_oauth_session.id,
user_session_authentication.id,
),
err,
@@ -378,9 +423,8 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
user_session: &BrowserSession,
upstream_oauth_link: &UpstreamOAuthLink,
upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
) -> Result<Authentication, Self::Error> {
let _upstream_oauth_link = upstream_oauth_link;
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record(
@@ -391,18 +435,25 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
sqlx::query!(
r#"
INSERT INTO user_session_authentications
(user_session_authentication_id, user_session_id, created_at)
VALUES ($1, $2, $3)
(user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user_session.id),
created_at,
Uuid::from(upstream_oauth_session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(Authentication { id, created_at })
Ok(Authentication {
id,
created_at,
authentication_method: AuthenticationMethod::UpstreamOAuth2 {
upstream_oauth2_session_id: upstream_oauth_session.id,
},
})
}
#[tracing::instrument(
@@ -419,10 +470,12 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
user_session: &BrowserSession,
) -> Result<Option<Authentication>, Self::Error> {
let authentication = sqlx::query_as!(
Authentication,
AuthenticationLookup,
r#"
SELECT user_session_authentication_id AS id
SELECT user_session_authentication_id
, created_at
, user_password_id
, upstream_oauth_authorization_session_id
FROM user_session_authentications
WHERE user_session_id = $1
ORDER BY created_at DESC
@@ -434,6 +487,11 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
.fetch_optional(&mut *self.conn)
.await?;
Ok(authentication)
let Some(authentication) = authentication else {
return Ok(None);
};
let authentication = Authentication::try_from(authentication)?;
Ok(Some(authentication))
}
}