1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

data-model: simplify the compat sessions

This commit is contained in:
Quentin Gliech
2022-12-06 18:05:32 +01:00
parent feebbd0e97
commit 479e009931
9 changed files with 85 additions and 190 deletions

View File

@ -20,9 +20,10 @@ use rand::{
}; };
use serde::Serialize; use serde::Serialize;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid;
use url::Url; use url::Url;
use crate::{StorageBackend, StorageBackendMarker, User}; use crate::User;
static DEVICE_ID_LENGTH: usize = 10; static DEVICE_ID_LENGTH: usize = 10;
@ -81,123 +82,49 @@ impl TryFrom<String> for Device {
} }
} }
#[derive(Debug, Clone, PartialEq, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(bound = "T: StorageBackend")] pub struct CompatSession {
pub struct CompatSession<T: StorageBackend> { pub id: Ulid,
#[serde(skip_serializing)]
pub data: T::CompatSessionData,
pub user: User, pub user: User,
pub device: Device, pub device: Device,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub finished_at: Option<DateTime<Utc>>, pub finished_at: Option<DateTime<Utc>>,
} }
impl<S: StorageBackendMarker> From<CompatSession<S>> for CompatSession<()> { #[derive(Debug, Clone, PartialEq, Eq)]
fn from(t: CompatSession<S>) -> Self { pub struct CompatAccessToken {
Self { pub id: Ulid,
data: (),
user: t.user,
device: t.device,
created_at: t.created_at,
finished_at: t.finished_at,
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct CompatAccessToken<T: StorageBackend> {
pub data: T::CompatAccessTokenData,
pub token: String, pub token: String,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub expires_at: Option<DateTime<Utc>>, pub expires_at: Option<DateTime<Utc>>,
} }
impl<S: StorageBackendMarker> From<CompatAccessToken<S>> for CompatAccessToken<()> { #[derive(Debug, Clone, PartialEq, Eq)]
fn from(t: CompatAccessToken<S>) -> Self { pub struct CompatRefreshToken {
Self { pub id: Ulid,
data: (),
token: t.token,
created_at: t.created_at,
expires_at: t.expires_at,
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct CompatRefreshToken<T: StorageBackend> {
pub data: T::CompatRefreshTokenData,
pub token: String, pub token: String,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
} }
impl<S: StorageBackendMarker> From<CompatRefreshToken<S>> for CompatRefreshToken<()> { #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
fn from(t: CompatRefreshToken<S>) -> Self { pub enum CompatSsoLoginState {
Self {
data: (),
token: t.token,
created_at: t.created_at,
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(bound = "T: StorageBackend")]
pub enum CompatSsoLoginState<T: StorageBackend> {
Pending, Pending,
Fulfilled { Fulfilled {
fulfilled_at: DateTime<Utc>, fulfilled_at: DateTime<Utc>,
session: CompatSession<T>, session: CompatSession,
}, },
Exchanged { Exchanged {
fulfilled_at: DateTime<Utc>, fulfilled_at: DateTime<Utc>,
exchanged_at: DateTime<Utc>, exchanged_at: DateTime<Utc>,
session: CompatSession<T>, session: CompatSession,
}, },
} }
impl<S: StorageBackendMarker> From<CompatSsoLoginState<S>> for CompatSsoLoginState<()> { #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
fn from(t: CompatSsoLoginState<S>) -> Self { pub struct CompatSsoLogin {
match t { pub id: Ulid,
CompatSsoLoginState::Pending => Self::Pending,
CompatSsoLoginState::Fulfilled {
fulfilled_at,
session,
} => Self::Fulfilled {
fulfilled_at,
session: session.into(),
},
CompatSsoLoginState::Exchanged {
fulfilled_at,
exchanged_at,
session,
} => Self::Exchanged {
fulfilled_at,
exchanged_at,
session: session.into(),
},
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(bound = "T: StorageBackend")]
pub struct CompatSsoLogin<T: StorageBackend> {
#[serde(skip_serializing)]
pub data: T::CompatSsoLoginData,
pub redirect_uri: Url, pub redirect_uri: Url,
pub login_token: String, pub login_token: String,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub state: CompatSsoLoginState<T>, pub state: CompatSsoLoginState,
}
impl<S: StorageBackendMarker> From<CompatSsoLogin<S>> for CompatSsoLogin<()> {
fn from(t: CompatSsoLogin<S>) -> Self {
Self {
data: (),
redirect_uri: t.redirect_uri,
login_token: t.login_token,
created_at: t.created_at,
state: t.state.into(),
}
}
} }

View File

@ -33,18 +33,10 @@ pub trait StorageBackend {
type ClientData: Data; type ClientData: Data;
type SessionData: Data; type SessionData: Data;
type AuthorizationGrantData: Data; type AuthorizationGrantData: Data;
type CompatAccessTokenData: Data;
type CompatRefreshTokenData: Data;
type CompatSessionData: Data;
type CompatSsoLoginData: Data;
} }
impl StorageBackend for () { impl StorageBackend for () {
type AuthorizationGrantData = (); type AuthorizationGrantData = ();
type ClientData = (); type ClientData = ();
type CompatAccessTokenData = ();
type CompatRefreshTokenData = ();
type CompatSessionData = ();
type CompatSsoLoginData = ();
type SessionData = (); type SessionData = ();
} }

View File

@ -15,7 +15,6 @@
use async_graphql::{Description, Object, ID}; use async_graphql::{Description, Object, ID};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::CompatSsoLoginState; use mas_data_model::CompatSsoLoginState;
use mas_storage::PostgresqlBackend;
use url::Url; use url::Url;
use super::{NodeType, User}; use super::{NodeType, User};
@ -23,13 +22,13 @@ use super::{NodeType, User};
/// A compat session represents a client session which used the legacy Matrix /// A compat session represents a client session which used the legacy Matrix
/// login API. /// login API.
#[derive(Description)] #[derive(Description)]
pub struct CompatSession(pub mas_data_model::CompatSession<PostgresqlBackend>); pub struct CompatSession(pub mas_data_model::CompatSession);
#[Object(use_type_description)] #[Object(use_type_description)]
impl CompatSession { impl CompatSession {
/// ID of the object. /// ID of the object.
pub async fn id(&self) -> ID { pub async fn id(&self) -> ID {
NodeType::CompatSession.id(self.0.data) NodeType::CompatSession.id(self.0.id)
} }
/// The user authorized for this session. /// The user authorized for this session.
@ -56,13 +55,13 @@ impl CompatSession {
/// A compat SSO login represents a login done through the legacy Matrix login /// A compat SSO login represents a login done through the legacy Matrix login
/// API, via the `m.login.sso` login method. /// API, via the `m.login.sso` login method.
#[derive(Description)] #[derive(Description)]
pub struct CompatSsoLogin(pub mas_data_model::CompatSsoLogin<PostgresqlBackend>); pub struct CompatSsoLogin(pub mas_data_model::CompatSsoLogin);
#[Object(use_type_description)] #[Object(use_type_description)]
impl CompatSsoLogin { impl CompatSsoLogin {
/// ID of the object. /// ID of the object.
pub async fn id(&self) -> ID { pub async fn id(&self) -> ID {
NodeType::CompatSsoLogin.id(self.0.data) NodeType::CompatSsoLogin.id(self.0.id)
} }
/// When the object was created. /// When the object was created.

View File

@ -94,7 +94,7 @@ impl User {
let mut connection = Connection::new(has_previous_page, has_next_page); let mut connection = Connection::new(has_previous_page, has_next_page);
connection.edges.extend(edges.into_iter().map(|u| { connection.edges.extend(edges.into_iter().map(|u| {
Edge::new( Edge::new(
OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.data)), OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)),
CompatSsoLogin(u), CompatSsoLogin(u),
) )
})); }));

View File

@ -22,7 +22,7 @@ use mas_storage::{
get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged, get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged,
CompatSsoLoginLookupError, CompatSsoLoginLookupError,
}, },
Clock, LookupError, PostgresqlBackend, Clock, LookupError,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
@ -267,14 +267,14 @@ async fn token_login(
txn: &mut Transaction<'_, Postgres>, txn: &mut Transaction<'_, Postgres>,
clock: &Clock, clock: &Clock,
token: &str, token: &str,
) -> Result<CompatSession<PostgresqlBackend>, RouteError> { ) -> Result<CompatSession, RouteError> {
let login = get_compat_sso_login_by_token(&mut *txn, token).await?; let login = get_compat_sso_login_by_token(&mut *txn, token).await?;
let now = clock.now(); let now = clock.now();
match login.state { match login.state {
CompatSsoLoginState::Pending => { CompatSsoLoginState::Pending => {
tracing::error!( tracing::error!(
compat_sso_login.id = %login.data, compat_sso_login.id = %login.id,
"Exchanged a token for a login that was not fullfilled yet" "Exchanged a token for a login that was not fullfilled yet"
); );
return Err(RouteError::InvalidLoginToken); return Err(RouteError::InvalidLoginToken);
@ -291,7 +291,7 @@ async fn token_login(
if now > exchanged_at + Duration::seconds(30) { if now > exchanged_at + Duration::seconds(30) {
// TODO: log that session out // TODO: log that session out
tracing::error!( tracing::error!(
compat_sso_login.id = %login.data, compat_sso_login.id = %login.id,
"Login token exchanged a second time more than 30s after" "Login token exchanged a second time more than 30s after"
); );
} }
@ -312,7 +312,7 @@ async fn user_password_login(
txn: &mut Transaction<'_, Postgres>, txn: &mut Transaction<'_, Postgres>,
username: String, username: String,
password: String, password: String,
) -> Result<CompatSession<PostgresqlBackend>, RouteError> { ) -> Result<CompatSession, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?; let (clock, mut rng) = crate::rng_and_clock()?;
let device = Device::generate(&mut rng); let device = Device::generate(&mut rng);

View File

@ -87,5 +87,5 @@ pub async fn get(
let mut conn = pool.acquire().await?; let mut conn = pool.acquire().await?;
let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?; let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?;
Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.data, params.action))) Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.id, params.action)))
} }

View File

@ -31,7 +31,7 @@ use uuid::Uuid;
use crate::{ use crate::{
pagination::{process_page, QueryBuilderExt}, pagination::{process_page, QueryBuilderExt},
user::lookup_user_by_username, user::lookup_user_by_username,
Clock, DatabaseInconsistencyError, LookupError, PostgresqlBackend, Clock, DatabaseInconsistencyError, LookupError,
}; };
struct CompatAccessTokenLookup { struct CompatAccessTokenLookup {
@ -73,13 +73,7 @@ pub async fn lookup_active_compat_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
token: &str, token: &str,
) -> Result< ) -> Result<(CompatAccessToken, CompatSession), CompatAccessTokenLookupError> {
(
CompatAccessToken<PostgresqlBackend>,
CompatSession<PostgresqlBackend>,
),
CompatAccessTokenLookupError,
> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
CompatAccessTokenLookup, CompatAccessTokenLookup,
r#" r#"
@ -123,7 +117,7 @@ pub async fn lookup_active_compat_access_token(
} }
let token = CompatAccessToken { let token = CompatAccessToken {
data: res.compat_access_token_id.into(), id: res.compat_access_token_id.into(),
token: res.compat_access_token, token: res.compat_access_token,
created_at: res.compat_access_token_created_at, created_at: res.compat_access_token_created_at,
expires_at: res.compat_access_token_expires_at, expires_at: res.compat_access_token_expires_at,
@ -156,7 +150,7 @@ pub async fn lookup_active_compat_access_token(
let device = Device::try_from(res.compat_session_device_id).unwrap(); let device = Device::try_from(res.compat_session_device_id).unwrap();
let session = CompatSession { let session = CompatSession {
data: res.compat_session_id.into(), id: res.compat_session_id.into(),
user, user,
device, device,
created_at: res.compat_session_created_at, created_at: res.compat_session_created_at,
@ -204,14 +198,7 @@ impl LookupError for CompatRefreshTokenLookupError {
pub async fn lookup_active_compat_refresh_token( pub async fn lookup_active_compat_refresh_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
token: &str, token: &str,
) -> Result< ) -> Result<(CompatRefreshToken, CompatAccessToken, CompatSession), CompatRefreshTokenLookupError> {
(
CompatRefreshToken<PostgresqlBackend>,
CompatAccessToken<PostgresqlBackend>,
CompatSession<PostgresqlBackend>,
),
CompatRefreshTokenLookupError,
> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
CompatRefreshTokenLookup, CompatRefreshTokenLookup,
r#" r#"
@ -255,13 +242,13 @@ pub async fn lookup_active_compat_refresh_token(
.await?; .await?;
let refresh_token = CompatRefreshToken { let refresh_token = CompatRefreshToken {
data: res.compat_refresh_token_id.into(), id: res.compat_refresh_token_id.into(),
token: res.compat_refresh_token, token: res.compat_refresh_token,
created_at: res.compat_refresh_token_created_at, created_at: res.compat_refresh_token_created_at,
}; };
let access_token = CompatAccessToken { let access_token = CompatAccessToken {
data: res.compat_access_token_id.into(), id: res.compat_access_token_id.into(),
token: res.compat_access_token, token: res.compat_access_token,
created_at: res.compat_access_token_created_at, created_at: res.compat_access_token_created_at,
expires_at: res.compat_access_token_expires_at, expires_at: res.compat_access_token_expires_at,
@ -294,7 +281,7 @@ pub async fn lookup_active_compat_refresh_token(
let device = Device::try_from(res.compat_session_device_id).unwrap(); let device = Device::try_from(res.compat_session_device_id).unwrap();
let session = CompatSession { let session = CompatSession {
data: res.compat_session_id.into(), id: res.compat_session_id.into(),
user, user,
device, device,
created_at: res.compat_session_created_at, created_at: res.compat_session_created_at,
@ -321,7 +308,7 @@ pub async fn compat_login(
username: &str, username: &str,
password: &str, password: &str,
device: Device, device: Device,
) -> Result<CompatSession<PostgresqlBackend>, anyhow::Error> { ) -> Result<CompatSession, anyhow::Error> {
let mut txn = conn.begin().await.context("could not start transaction")?; let mut txn = conn.begin().await.context("could not start transaction")?;
// First, lookup the user // First, lookup the user
@ -375,7 +362,7 @@ pub async fn compat_login(
.context("could not insert compat session")?; .context("could not insert compat session")?;
let session = CompatSession { let session = CompatSession {
data: id, id,
user, user,
device, device,
created_at, created_at,
@ -389,7 +376,7 @@ pub async fn compat_login(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
compat_session.id = %session.data, compat_session.id = %session.id,
compat_session.device.id = session.device.as_str(), compat_session.device.id = session.device.as_str(),
compat_access_token.id, compat_access_token.id,
user.id = %session.user.id, user.id = %session.user.id,
@ -400,10 +387,10 @@ pub async fn add_compat_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
session: &CompatSession<PostgresqlBackend>, session: &CompatSession,
token: String, token: String,
expires_after: Option<Duration>, expires_after: Option<Duration>,
) -> Result<CompatAccessToken<PostgresqlBackend>, anyhow::Error> { ) -> Result<CompatAccessToken, anyhow::Error> {
let created_at = clock.now(); let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
@ -417,7 +404,7 @@ pub async fn add_compat_access_token(
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
"#, "#,
Uuid::from(id), Uuid::from(id),
Uuid::from(session.data), Uuid::from(session.id),
token, token,
created_at, created_at,
expires_at, expires_at,
@ -428,7 +415,7 @@ pub async fn add_compat_access_token(
.context("could not insert compat access token")?; .context("could not insert compat access token")?;
Ok(CompatAccessToken { Ok(CompatAccessToken {
data: id, id,
token, token,
created_at, created_at,
expires_at, expires_at,
@ -438,14 +425,14 @@ pub async fn add_compat_access_token(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
compat_access_token.id = %access_token.data, compat_access_token.id = %access_token.id,
), ),
err(Display), err(Display),
)] )]
pub async fn expire_compat_access_token( pub async fn expire_compat_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
access_token: CompatAccessToken<PostgresqlBackend>, access_token: CompatAccessToken,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let expires_at = clock.now(); let expires_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
@ -454,7 +441,7 @@ pub async fn expire_compat_access_token(
SET expires_at = $2 SET expires_at = $2
WHERE compat_access_token_id = $1 WHERE compat_access_token_id = $1
"#, "#,
Uuid::from(access_token.data), Uuid::from(access_token.id),
expires_at, expires_at,
) )
.execute(executor) .execute(executor)
@ -473,9 +460,9 @@ pub async fn expire_compat_access_token(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
compat_session.id = %session.data, compat_session.id = %session.id,
compat_session.device.id = session.device.as_str(), compat_session.device.id = session.device.as_str(),
compat_access_token.id = %access_token.data, compat_access_token.id = %access_token.id,
compat_refresh_token.id, compat_refresh_token.id,
user.id = %session.user.id, user.id = %session.user.id,
), ),
@ -485,10 +472,10 @@ pub async fn add_compat_refresh_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
session: &CompatSession<PostgresqlBackend>, session: &CompatSession,
access_token: &CompatAccessToken<PostgresqlBackend>, access_token: &CompatAccessToken,
token: String, token: String,
) -> Result<CompatRefreshToken<PostgresqlBackend>, anyhow::Error> { ) -> Result<CompatRefreshToken, anyhow::Error> {
let created_at = clock.now(); let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
@ -501,8 +488,8 @@ pub async fn add_compat_refresh_token(
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
"#, "#,
Uuid::from(id), Uuid::from(id),
Uuid::from(session.data), Uuid::from(session.id),
Uuid::from(access_token.data), Uuid::from(access_token.id),
token, token,
created_at, created_at,
) )
@ -512,7 +499,7 @@ pub async fn add_compat_refresh_token(
.context("could not insert compat refresh token")?; .context("could not insert compat refresh token")?;
Ok(CompatRefreshToken { Ok(CompatRefreshToken {
data: id, id,
token, token,
created_at, created_at,
}) })
@ -558,14 +545,14 @@ pub async fn compat_logout(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
compat_refresh_token.id = %refresh_token.data, compat_refresh_token.id = %refresh_token.id,
), ),
err(Display), err(Display),
)] )]
pub async fn consume_compat_refresh_token( pub async fn consume_compat_refresh_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
refresh_token: CompatRefreshToken<PostgresqlBackend>, refresh_token: CompatRefreshToken,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let consumed_at = clock.now(); let consumed_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
@ -574,7 +561,7 @@ pub async fn consume_compat_refresh_token(
SET consumed_at = $2 SET consumed_at = $2
WHERE compat_refresh_token_id = $1 WHERE compat_refresh_token_id = $1
"#, "#,
Uuid::from(refresh_token.data), Uuid::from(refresh_token.id),
consumed_at, consumed_at,
) )
.execute(executor) .execute(executor)
@ -604,7 +591,7 @@ pub async fn insert_compat_sso_login(
clock: &Clock, clock: &Clock,
login_token: String, login_token: String,
redirect_uri: Url, redirect_uri: Url,
) -> Result<CompatSsoLogin<PostgresqlBackend>, anyhow::Error> { ) -> Result<CompatSsoLogin, anyhow::Error> {
let created_at = clock.now(); let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
@ -626,7 +613,7 @@ pub async fn insert_compat_sso_login(
.context("could not insert compat SSO login")?; .context("could not insert compat SSO login")?;
Ok(CompatSsoLogin { Ok(CompatSsoLogin {
data: id, id,
login_token, login_token,
redirect_uri, redirect_uri,
created_at, created_at,
@ -654,7 +641,7 @@ struct CompatSsoLoginLookup {
user_email_confirmed_at: Option<DateTime<Utc>>, user_email_confirmed_at: Option<DateTime<Utc>>,
} }
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin<PostgresqlBackend> { impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
type Error = DatabaseInconsistencyError; type Error = DatabaseInconsistencyError;
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> { fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
@ -702,7 +689,7 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin<PostgresqlBackend> {
(Some(id), Some(device_id), Some(created_at), finished_at, Some(user)) => { (Some(id), Some(device_id), Some(created_at), finished_at, Some(user)) => {
let device = Device::try_from(device_id).map_err(|_| DatabaseInconsistencyError)?; let device = Device::try_from(device_id).map_err(|_| DatabaseInconsistencyError)?;
Some(CompatSession { Some(CompatSession {
data: id.into(), id: id.into(),
user, user,
device, device,
created_at, created_at,
@ -734,7 +721,7 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin<PostgresqlBackend> {
}; };
Ok(CompatSsoLogin { Ok(CompatSsoLogin {
data: res.compat_sso_login_id.into(), id: res.compat_sso_login_id.into(),
login_token: res.compat_sso_login_token, login_token: res.compat_sso_login_token,
redirect_uri, redirect_uri,
created_at: res.compat_sso_login_created_at, created_at: res.compat_sso_login_created_at,
@ -766,7 +753,7 @@ impl LookupError for CompatSsoLoginLookupError {
pub async fn get_compat_sso_login_by_id( pub async fn get_compat_sso_login_by_id(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
id: Ulid, id: Ulid,
) -> Result<CompatSsoLogin<PostgresqlBackend>, CompatSsoLoginLookupError> { ) -> Result<CompatSsoLogin, CompatSsoLoginLookupError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
CompatSsoLoginLookup, CompatSsoLoginLookup,
r#" r#"
@ -820,7 +807,7 @@ pub async fn get_paginated_user_compat_sso_logins(
after: Option<Ulid>, after: Option<Ulid>,
first: Option<usize>, first: Option<usize>,
last: Option<usize>, last: Option<usize>,
) -> Result<(bool, bool, Vec<CompatSsoLogin<PostgresqlBackend>>), anyhow::Error> { ) -> Result<(bool, bool, Vec<CompatSsoLogin>), anyhow::Error> {
// TODO: this queries too much (like user info) which we probably don't need // TODO: this queries too much (like user info) which we probably don't need
// because we already have them // because we already have them
let mut query = QueryBuilder::new( let mut query = QueryBuilder::new(
@ -877,7 +864,7 @@ pub async fn get_paginated_user_compat_sso_logins(
pub async fn get_compat_sso_login_by_token( pub async fn get_compat_sso_login_by_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
token: &str, token: &str,
) -> Result<CompatSsoLogin<PostgresqlBackend>, CompatSsoLoginLookupError> { ) -> Result<CompatSsoLogin, CompatSsoLoginLookupError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
CompatSsoLoginLookup, CompatSsoLoginLookup,
r#" r#"
@ -920,7 +907,7 @@ pub async fn get_compat_sso_login_by_token(
skip_all, skip_all,
fields( fields(
%user.id, %user.id,
compat_sso_login.id = %login.data, compat_sso_login.id = %login.id,
compat_sso_login.redirect_uri = %login.redirect_uri, compat_sso_login.redirect_uri = %login.redirect_uri,
compat_session.id, compat_session.id,
compat_session.device.id = device.as_str(), compat_session.device.id = device.as_str(),
@ -932,9 +919,9 @@ pub async fn fullfill_compat_sso_login(
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
user: User, user: User,
mut login: CompatSsoLogin<PostgresqlBackend>, mut login: CompatSsoLogin,
device: Device, device: Device,
) -> Result<CompatSsoLogin<PostgresqlBackend>, anyhow::Error> { ) -> Result<CompatSsoLogin, anyhow::Error> {
if !matches!(login.state, CompatSsoLoginState::Pending) { if !matches!(login.state, CompatSsoLoginState::Pending) {
bail!("sso login in wrong state"); bail!("sso login in wrong state");
}; };
@ -961,7 +948,7 @@ pub async fn fullfill_compat_sso_login(
.context("could not insert compat session")?; .context("could not insert compat session")?;
let session = CompatSession { let session = CompatSession {
data: id, id,
user, user,
device, device,
created_at, created_at,
@ -978,8 +965,8 @@ pub async fn fullfill_compat_sso_login(
WHERE WHERE
compat_sso_login_id = $1 compat_sso_login_id = $1
"#, "#,
Uuid::from(login.data), Uuid::from(login.id),
Uuid::from(session.data), Uuid::from(session.id),
fulfilled_at, fulfilled_at,
) )
.execute(&mut txn) .execute(&mut txn)
@ -1002,7 +989,7 @@ pub async fn fullfill_compat_sso_login(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
compat_sso_login.id = %login.data, compat_sso_login.id = %login.id,
compat_sso_login.redirect_uri = %login.redirect_uri, compat_sso_login.redirect_uri = %login.redirect_uri,
), ),
err(Display), err(Display),
@ -1010,8 +997,8 @@ pub async fn fullfill_compat_sso_login(
pub async fn mark_compat_sso_login_as_exchanged( pub async fn mark_compat_sso_login_as_exchanged(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
mut login: CompatSsoLogin<PostgresqlBackend>, mut login: CompatSsoLogin,
) -> Result<CompatSsoLogin<PostgresqlBackend>, anyhow::Error> { ) -> Result<CompatSsoLogin, anyhow::Error> {
let (fulfilled_at, session) = match login.state { let (fulfilled_at, session) = match login.state {
CompatSsoLoginState::Fulfilled { CompatSsoLoginState::Fulfilled {
fulfilled_at, fulfilled_at,
@ -1029,7 +1016,7 @@ pub async fn mark_compat_sso_login_as_exchanged(
WHERE WHERE
compat_sso_login_id = $1 compat_sso_login_id = $1
"#, "#,
Uuid::from(login.data), Uuid::from(login.id),
exchanged_at, exchanged_at,
) )
.execute(executor) .execute(executor)

View File

@ -107,10 +107,6 @@ pub struct PostgresqlBackend;
impl StorageBackend for PostgresqlBackend { impl StorageBackend for PostgresqlBackend {
type AuthorizationGrantData = Ulid; type AuthorizationGrantData = Ulid;
type ClientData = Ulid; type ClientData = Ulid;
type CompatAccessTokenData = Ulid;
type CompatRefreshTokenData = Ulid;
type CompatSessionData = Ulid;
type CompatSsoLoginData = Ulid;
type SessionData = Ulid; type SessionData = Ulid;
} }

View File

@ -256,7 +256,7 @@ pub enum PostAuthContextInner {
/// TODO: add the login context in there /// TODO: add the login context in there
ContinueCompatSsoLogin { ContinueCompatSsoLogin {
/// The compat SSO login request /// The compat SSO login request
login: Box<CompatSsoLogin<()>>, login: Box<CompatSsoLogin>,
}, },
/// Change the account password /// Change the account password
@ -512,7 +512,7 @@ impl ReauthContext {
/// Context used by the `sso.html` template /// Context used by the `sso.html` template
#[derive(Serialize)] #[derive(Serialize)]
pub struct CompatSsoContext { pub struct CompatSsoContext {
login: CompatSsoLogin<()>, login: CompatSsoLogin,
action: PostAuthAction, action: PostAuthAction,
} }
@ -521,17 +521,16 @@ impl TemplateContext for CompatSsoContext {
where where
Self: Sized, Self: Sized,
{ {
let id = Ulid::from_datetime_with_source(now.into(), rng);
vec![CompatSsoContext { vec![CompatSsoContext {
login: CompatSsoLogin { login: CompatSsoLogin {
data: (), id,
redirect_uri: Url::parse("https://app.element.io/").unwrap(), redirect_uri: Url::parse("https://app.element.io/").unwrap(),
login_token: "abcdefghijklmnopqrstuvwxyz012345".into(), login_token: "abcdefghijklmnopqrstuvwxyz012345".into(),
created_at: now, created_at: now,
state: CompatSsoLoginState::Pending, state: CompatSsoLoginState::Pending,
}, },
action: PostAuthAction::ContinueCompatSsoLogin { action: PostAuthAction::ContinueCompatSsoLogin { data: id },
data: Ulid::from_datetime_with_source(now.into(), rng),
},
}] }]
} }
} }
@ -539,14 +538,9 @@ impl TemplateContext for CompatSsoContext {
impl CompatSsoContext { impl CompatSsoContext {
/// Constructs a context for the legacy SSO login page /// Constructs a context for the legacy SSO login page
#[must_use] #[must_use]
pub fn new<T>(login: T, action: PostAuthAction) -> Self pub fn new(login: CompatSsoLogin, action: PostAuthAction) -> Self
where where {
T: Into<CompatSsoLogin<()>>, Self { login, action }
{
Self {
login: login.into(),
action,
}
} }
} }