You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
storage: do less joins in compat sessions
This commit is contained in:
@ -14,8 +14,8 @@
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{
|
||||
CompatAccessToken, CompatRefreshToken, CompatSession, CompatSessionState, CompatSsoLogin,
|
||||
CompatSsoLoginState, Device, User,
|
||||
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
|
||||
CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, User,
|
||||
};
|
||||
use rand::Rng;
|
||||
use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder};
|
||||
@ -29,71 +29,47 @@ use crate::{
|
||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
|
||||
struct CompatAccessTokenLookup {
|
||||
compat_access_token_id: Uuid,
|
||||
compat_access_token: String,
|
||||
compat_access_token_created_at: DateTime<Utc>,
|
||||
compat_access_token_expires_at: Option<DateTime<Utc>>,
|
||||
struct CompatSessionLookup {
|
||||
compat_session_id: Uuid,
|
||||
compat_session_created_at: DateTime<Utc>,
|
||||
compat_session_finished_at: Option<DateTime<Utc>>,
|
||||
compat_session_device_id: String,
|
||||
device_id: String,
|
||||
user_id: Uuid,
|
||||
created_at: DateTime<Utc>,
|
||||
finished_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub async fn lookup_active_compat_access_token(
|
||||
pub async fn lookup_compat_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
token: &str,
|
||||
) -> Result<Option<(CompatAccessToken, CompatSession)>, DatabaseError> {
|
||||
session_id: Ulid,
|
||||
) -> Result<Option<CompatSession>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatAccessTokenLookup,
|
||||
CompatSessionLookup,
|
||||
r#"
|
||||
SELECT ct.compat_access_token_id
|
||||
, ct.access_token AS "compat_access_token"
|
||||
, ct.created_at AS "compat_access_token_created_at"
|
||||
, ct.expires_at AS "compat_access_token_expires_at"
|
||||
, cs.compat_session_id
|
||||
, cs.created_at AS "compat_session_created_at"
|
||||
, cs.finished_at AS "compat_session_finished_at"
|
||||
, cs.device_id AS "compat_session_device_id"
|
||||
, cs.user_id AS "user_id!"
|
||||
|
||||
FROM compat_access_tokens ct
|
||||
INNER JOIN compat_sessions cs
|
||||
USING (compat_session_id)
|
||||
|
||||
WHERE ct.access_token = $1
|
||||
AND (ct.expires_at < $2 OR ct.expires_at IS NULL)
|
||||
AND cs.finished_at IS NULL
|
||||
SELECT compat_session_id
|
||||
, device_id
|
||||
, user_id
|
||||
, created_at
|
||||
, finished_at
|
||||
FROM compat_sessions
|
||||
WHERE compat_session_id = $1
|
||||
"#,
|
||||
token,
|
||||
clock.now(),
|
||||
Uuid::from(session_id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Fetch compat access token"))
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
let token = CompatAccessToken {
|
||||
id: res.compat_access_token_id.into(),
|
||||
token: res.compat_access_token,
|
||||
created_at: res.compat_access_token_created_at,
|
||||
expires_at: res.compat_access_token_expires_at,
|
||||
};
|
||||
|
||||
let id = res.compat_session_id.into();
|
||||
let device = Device::try_from(res.compat_session_device_id).map_err(|e| {
|
||||
let device = Device::try_from(res.device_id).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("compat_sessions")
|
||||
.column("device_id")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match res.compat_session_finished_at {
|
||||
let state = match res.finished_at {
|
||||
None => CompatSessionState::Valid,
|
||||
Some(finished_at) => CompatSessionState::Finished { finished_at },
|
||||
};
|
||||
@ -103,103 +79,148 @@ pub async fn lookup_active_compat_access_token(
|
||||
state,
|
||||
user_id: res.user_id.into(),
|
||||
device,
|
||||
created_at: res.compat_session_created_at,
|
||||
created_at: res.created_at,
|
||||
};
|
||||
|
||||
Ok(Some((token, session)))
|
||||
Ok(Some(session))
|
||||
}
|
||||
|
||||
pub struct CompatRefreshTokenLookup {
|
||||
compat_refresh_token_id: Uuid,
|
||||
compat_refresh_token: String,
|
||||
compat_refresh_token_created_at: DateTime<Utc>,
|
||||
struct CompatAccessTokenLookup {
|
||||
compat_access_token_id: Uuid,
|
||||
compat_access_token: String,
|
||||
compat_access_token_created_at: DateTime<Utc>,
|
||||
compat_access_token_expires_at: Option<DateTime<Utc>>,
|
||||
access_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
expires_at: Option<DateTime<Utc>>,
|
||||
compat_session_id: Uuid,
|
||||
compat_session_created_at: DateTime<Utc>,
|
||||
compat_session_finished_at: Option<DateTime<Utc>>,
|
||||
compat_session_device_id: String,
|
||||
user_id: Uuid,
|
||||
}
|
||||
|
||||
impl From<CompatAccessTokenLookup> for CompatAccessToken {
|
||||
fn from(value: CompatAccessTokenLookup) -> Self {
|
||||
Self {
|
||||
id: value.compat_access_token_id.into(),
|
||||
session_id: value.compat_session_id.into(),
|
||||
token: value.access_token,
|
||||
created_at: value.created_at,
|
||||
expires_at: value.expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub async fn lookup_active_compat_refresh_token(
|
||||
pub async fn find_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
token: &str,
|
||||
) -> Result<Option<(CompatRefreshToken, CompatAccessToken, CompatSession)>, DatabaseError> {
|
||||
) -> Result<Option<CompatAccessToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatRefreshTokenLookup,
|
||||
CompatAccessTokenLookup,
|
||||
r#"
|
||||
SELECT cr.compat_refresh_token_id
|
||||
, cr.refresh_token AS "compat_refresh_token"
|
||||
, cr.created_at AS "compat_refresh_token_created_at"
|
||||
, ct.compat_access_token_id
|
||||
, ct.access_token AS "compat_access_token"
|
||||
, ct.created_at AS "compat_access_token_created_at"
|
||||
, ct.expires_at AS "compat_access_token_expires_at"
|
||||
, cs.compat_session_id
|
||||
, cs.created_at AS "compat_session_created_at"
|
||||
, cs.finished_at AS "compat_session_finished_at"
|
||||
, cs.device_id AS "compat_session_device_id"
|
||||
, cs.user_id
|
||||
SELECT compat_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_refresh_tokens cr
|
||||
INNER JOIN compat_sessions cs
|
||||
USING (compat_session_id)
|
||||
INNER JOIN compat_access_tokens ct
|
||||
USING (compat_access_token_id)
|
||||
FROM compat_access_tokens
|
||||
|
||||
WHERE cr.refresh_token = $1
|
||||
AND cr.consumed_at IS NULL
|
||||
AND cs.finished_at IS NULL
|
||||
WHERE access_token = $1
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
compat_access_token.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: Ulid,
|
||||
) -> Result<Option<CompatAccessToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatAccessTokenLookup,
|
||||
r#"
|
||||
SELECT compat_access_token_id
|
||||
, access_token
|
||||
, created_at
|
||||
, expires_at
|
||||
, compat_session_id
|
||||
|
||||
FROM compat_access_tokens
|
||||
|
||||
WHERE compat_access_token_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
pub struct CompatRefreshTokenLookup {
|
||||
compat_refresh_token_id: Uuid,
|
||||
refresh_token: String,
|
||||
created_at: DateTime<Utc>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
compat_access_token_id: Uuid,
|
||||
compat_session_id: Uuid,
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub async fn find_compat_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
token: &str,
|
||||
) -> Result<Option<CompatRefreshToken>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatRefreshTokenLookup,
|
||||
r#"
|
||||
SELECT compat_refresh_token_id
|
||||
, refresh_token
|
||||
, created_at
|
||||
, consumed_at
|
||||
, compat_session_id
|
||||
, compat_access_token_id
|
||||
|
||||
FROM compat_refresh_tokens
|
||||
|
||||
WHERE refresh_token = $1
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Fetch compat refresh token"))
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None); };
|
||||
|
||||
let state = match res.consumed_at {
|
||||
None => CompatRefreshTokenState::Valid,
|
||||
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
|
||||
};
|
||||
|
||||
let refresh_token = CompatRefreshToken {
|
||||
id: res.compat_refresh_token_id.into(),
|
||||
token: res.compat_refresh_token,
|
||||
created_at: res.compat_refresh_token_created_at,
|
||||
};
|
||||
|
||||
let access_token = CompatAccessToken {
|
||||
id: res.compat_access_token_id.into(),
|
||||
token: res.compat_access_token,
|
||||
created_at: res.compat_access_token_created_at,
|
||||
expires_at: res.compat_access_token_expires_at,
|
||||
};
|
||||
|
||||
let id = res.compat_session_id.into();
|
||||
let device = Device::try_from(res.compat_session_device_id).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("compat_sessions")
|
||||
.column("device_id")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match res.compat_session_finished_at {
|
||||
None => CompatSessionState::Valid,
|
||||
Some(finished_at) => CompatSessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
let session = CompatSession {
|
||||
id,
|
||||
state,
|
||||
user_id: res.user_id.into(),
|
||||
device,
|
||||
created_at: res.compat_session_created_at,
|
||||
session_id: res.compat_session_id.into(),
|
||||
access_token_id: res.compat_access_token_id.into(),
|
||||
token: res.refresh_token,
|
||||
created_at: res.created_at,
|
||||
};
|
||||
|
||||
Ok(Some((refresh_token, access_token, session)))
|
||||
Ok(Some(refresh_token))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -244,6 +265,7 @@ pub async fn add_compat_access_token(
|
||||
|
||||
Ok(CompatAccessToken {
|
||||
id,
|
||||
session_id: session.id,
|
||||
token,
|
||||
created_at,
|
||||
expires_at,
|
||||
@ -320,6 +342,9 @@ pub async fn add_compat_refresh_token(
|
||||
|
||||
Ok(CompatRefreshToken {
|
||||
id,
|
||||
state: CompatRefreshTokenState::default(),
|
||||
session_id: session.id,
|
||||
access_token_id: access_token.id,
|
||||
token,
|
||||
created_at,
|
||||
})
|
||||
@ -327,42 +352,35 @@ pub async fn add_compat_refresh_token(
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(compat_session.id),
|
||||
fields(%compat_session.id),
|
||||
err,
|
||||
)]
|
||||
pub async fn compat_logout(
|
||||
pub async fn end_compat_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
token: &str,
|
||||
) -> Result<bool, sqlx::Error> {
|
||||
compat_session: CompatSession,
|
||||
) -> Result<CompatSession, DatabaseError> {
|
||||
let finished_at = clock.now();
|
||||
// TODO: this does not check for token expiration
|
||||
let res = sqlx::query_scalar!(
|
||||
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sessions cs
|
||||
SET finished_at = $2
|
||||
FROM compat_access_tokens ca
|
||||
WHERE ca.access_token = $1
|
||||
AND ca.compat_session_id = cs.compat_session_id
|
||||
AND cs.finished_at IS NULL
|
||||
RETURNING cs.compat_session_id
|
||||
WHERE compat_session_id = $1
|
||||
"#,
|
||||
token,
|
||||
Uuid::from(compat_session.id),
|
||||
finished_at,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.to_option()?;
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
if let Some(compat_session_id) = res {
|
||||
tracing::Span::current().record(
|
||||
"compat_session.id",
|
||||
tracing::field::display(compat_session_id),
|
||||
);
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
let compat_session = compat_session
|
||||
.finish(finished_at)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(compat_session)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -445,10 +463,6 @@ struct CompatSsoLoginLookup {
|
||||
compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
|
||||
compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
|
||||
compat_session_id: Option<Uuid>,
|
||||
compat_session_created_at: Option<DateTime<Utc>>,
|
||||
compat_session_finished_at: Option<DateTime<Utc>>,
|
||||
compat_session_device_id: Option<String>,
|
||||
user_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
@ -463,58 +477,21 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let session = match (
|
||||
res.compat_session_id,
|
||||
res.compat_session_device_id,
|
||||
res.compat_session_created_at,
|
||||
res.compat_session_finished_at,
|
||||
res.user_id,
|
||||
) {
|
||||
(Some(id), Some(device_id), Some(created_at), finished_at, Some(user_id)) => {
|
||||
let id = id.into();
|
||||
let device = Device::try_from(device_id).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("compat_sessions")
|
||||
.column("device")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let state = match finished_at {
|
||||
None => CompatSessionState::Valid,
|
||||
Some(finished_at) => CompatSessionState::Finished { finished_at },
|
||||
};
|
||||
|
||||
Some(CompatSession {
|
||||
id,
|
||||
state,
|
||||
user_id: user_id.into(),
|
||||
device,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
(None, None, None, None, None) => None,
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError::on("compat_sso_logins")
|
||||
.column("compat_session_id")
|
||||
.row(id))
|
||||
}
|
||||
};
|
||||
|
||||
let state = match (
|
||||
res.compat_sso_login_fulfilled_at,
|
||||
res.compat_sso_login_exchanged_at,
|
||||
session,
|
||||
res.compat_session_id,
|
||||
) {
|
||||
(None, None, None) => CompatSsoLoginState::Pending,
|
||||
(Some(fulfilled_at), None, Some(session)) => CompatSsoLoginState::Fulfilled {
|
||||
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
|
||||
fulfilled_at,
|
||||
session,
|
||||
session_id: session_id.into(),
|
||||
},
|
||||
(Some(fulfilled_at), Some(exchanged_at), Some(session)) => {
|
||||
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
|
||||
CompatSsoLoginState::Exchanged {
|
||||
fulfilled_at,
|
||||
exchanged_at,
|
||||
session,
|
||||
session_id: session_id.into(),
|
||||
}
|
||||
}
|
||||
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
|
||||
@ -550,15 +527,9 @@ pub async fn get_compat_sso_login_by_id(
|
||||
, cl.created_at AS "compat_sso_login_created_at"
|
||||
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
|
||||
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
|
||||
, cs.compat_session_id AS "compat_session_id?"
|
||||
, cs.created_at AS "compat_session_created_at?"
|
||||
, cs.finished_at AS "compat_session_finished_at?"
|
||||
, cs.device_id AS "compat_session_device_id?"
|
||||
, cs.user_id AS "user_id?"
|
||||
, cl.compat_session_id AS "compat_session_id"
|
||||
|
||||
FROM compat_sso_logins cl
|
||||
LEFT JOIN compat_sessions cs
|
||||
USING (compat_session_id)
|
||||
WHERE cl.compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
@ -589,8 +560,6 @@ pub async fn get_paginated_user_compat_sso_logins(
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<(bool, bool, Vec<CompatSsoLogin>), DatabaseError> {
|
||||
// TODO: this queries too much (like user info) which we probably don't need
|
||||
// because we already have them
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT cl.compat_sso_login_id
|
||||
@ -599,14 +568,8 @@ pub async fn get_paginated_user_compat_sso_logins(
|
||||
, cl.created_at AS "compat_sso_login_created_at"
|
||||
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
|
||||
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
|
||||
, cs.compat_session_id AS "compat_session_id"
|
||||
, cs.created_at AS "compat_session_created_at"
|
||||
, cs.finished_at AS "compat_session_finished_at"
|
||||
, cs.device_id AS "compat_session_device_id"
|
||||
, cs.user_id
|
||||
, cl.compat_session_id AS "compat_session_id"
|
||||
FROM compat_sso_logins cl
|
||||
LEFT JOIN compat_sessions cs
|
||||
USING (compat_session_id)
|
||||
"#,
|
||||
);
|
||||
|
||||
@ -645,14 +608,8 @@ pub async fn get_compat_sso_login_by_token(
|
||||
, cl.created_at AS "compat_sso_login_created_at"
|
||||
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
|
||||
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
|
||||
, cs.compat_session_id AS "compat_session_id?"
|
||||
, cs.created_at AS "compat_session_created_at?"
|
||||
, cs.finished_at AS "compat_session_finished_at?"
|
||||
, cs.device_id AS "compat_session_device_id?"
|
||||
, cs.user_id AS "user_id?"
|
||||
, cl.compat_session_id AS "compat_session_id"
|
||||
FROM compat_sso_logins cl
|
||||
LEFT JOIN compat_sessions cs
|
||||
USING (compat_session_id)
|
||||
WHERE cl.login_token = $1
|
||||
"#,
|
||||
token,
|
||||
@ -739,7 +696,7 @@ pub async fn fullfill_compat_sso_login(
|
||||
|
||||
let fulfilled_at = clock.now();
|
||||
let compat_sso_login = compat_sso_login
|
||||
.fulfill(fulfilled_at, session)
|
||||
.fulfill(fulfilled_at, &session)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
sqlx::query!(
|
||||
r#"
|
||||
|
Reference in New Issue
Block a user