You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
Save the authentication method on each authorization
This will help us logging out of the upstream.
This commit is contained in:
@@ -14,7 +14,10 @@
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User};
|
||||
use mas_data_model::{
|
||||
Authentication, AuthenticationMethod, BrowserSession, Password,
|
||||
UpstreamOAuthAuthorizationSession, User,
|
||||
};
|
||||
use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination};
|
||||
use rand::RngCore;
|
||||
use sea_query::{Expr, PostgresQueryBuilder};
|
||||
@@ -80,6 +83,42 @@ impl TryFrom<SessionLookup> for BrowserSession {
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationLookup {
|
||||
user_session_authentication_id: Uuid,
|
||||
created_at: DateTime<Utc>,
|
||||
user_password_id: Option<Uuid>,
|
||||
upstream_oauth_authorization_session_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
impl TryFrom<AuthenticationLookup> for Authentication {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
|
||||
let id = Ulid::from(value.user_session_authentication_id);
|
||||
let authentication_method = match (
|
||||
value.user_password_id.map(Into::into),
|
||||
value
|
||||
.upstream_oauth_authorization_session_id
|
||||
.map(Into::into),
|
||||
) {
|
||||
(Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
|
||||
(None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
|
||||
upstream_oauth2_session_id,
|
||||
},
|
||||
(None, None) => AuthenticationMethod::Unknown,
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Authentication {
|
||||
id,
|
||||
created_at: value.created_at,
|
||||
authentication_method,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
@@ -337,7 +376,6 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
user_session: &BrowserSession,
|
||||
user_password: &Password,
|
||||
) -> Result<Authentication, Self::Error> {
|
||||
let _user_password = user_password;
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record(
|
||||
@@ -348,18 +386,25 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO user_session_authentications
|
||||
(user_session_authentication_id, user_session_id, created_at)
|
||||
VALUES ($1, $2, $3)
|
||||
(user_session_authentication_id, user_session_id, created_at, user_password_id)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user_session.id),
|
||||
created_at,
|
||||
Uuid::from(user_password.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(Authentication { id, created_at })
|
||||
Ok(Authentication {
|
||||
id,
|
||||
created_at,
|
||||
authentication_method: AuthenticationMethod::Password {
|
||||
user_password_id: user_password.id,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@@ -368,7 +413,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
fields(
|
||||
db.statement,
|
||||
%user_session.id,
|
||||
%upstream_oauth_link.id,
|
||||
%upstream_oauth_session.id,
|
||||
user_session_authentication.id,
|
||||
),
|
||||
err,
|
||||
@@ -378,9 +423,8 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &dyn Clock,
|
||||
user_session: &BrowserSession,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
|
||||
) -> Result<Authentication, Self::Error> {
|
||||
let _upstream_oauth_link = upstream_oauth_link;
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
|
||||
tracing::Span::current().record(
|
||||
@@ -391,18 +435,25 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO user_session_authentications
|
||||
(user_session_authentication_id, user_session_id, created_at)
|
||||
VALUES ($1, $2, $3)
|
||||
(user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(user_session.id),
|
||||
created_at,
|
||||
Uuid::from(upstream_oauth_session.id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(Authentication { id, created_at })
|
||||
Ok(Authentication {
|
||||
id,
|
||||
created_at,
|
||||
authentication_method: AuthenticationMethod::UpstreamOAuth2 {
|
||||
upstream_oauth2_session_id: upstream_oauth_session.id,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@@ -419,10 +470,12 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
user_session: &BrowserSession,
|
||||
) -> Result<Option<Authentication>, Self::Error> {
|
||||
let authentication = sqlx::query_as!(
|
||||
Authentication,
|
||||
AuthenticationLookup,
|
||||
r#"
|
||||
SELECT user_session_authentication_id AS id
|
||||
SELECT user_session_authentication_id
|
||||
, created_at
|
||||
, user_password_id
|
||||
, upstream_oauth_authorization_session_id
|
||||
FROM user_session_authentications
|
||||
WHERE user_session_id = $1
|
||||
ORDER BY created_at DESC
|
||||
@@ -434,6 +487,11 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
.fetch_optional(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
Ok(authentication)
|
||||
let Some(authentication) = authentication else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let authentication = Authentication::try_from(authentication)?;
|
||||
Ok(Some(authentication))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user