1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

storage: do less joins in compat sessions

This commit is contained in:
Quentin Gliech
2023-01-10 18:49:35 +01:00
parent 35787aa072
commit 920869b583
11 changed files with 616 additions and 542 deletions

View File

@ -14,8 +14,8 @@
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{
CompatAccessToken, CompatRefreshToken, CompatSession, CompatSessionState, CompatSsoLogin,
CompatSsoLoginState, Device, User,
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, User,
};
use rand::Rng;
use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder};
@ -29,71 +29,47 @@ use crate::{
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
struct CompatAccessTokenLookup {
compat_access_token_id: Uuid,
compat_access_token: String,
compat_access_token_created_at: DateTime<Utc>,
compat_access_token_expires_at: Option<DateTime<Utc>>,
struct CompatSessionLookup {
compat_session_id: Uuid,
compat_session_created_at: DateTime<Utc>,
compat_session_finished_at: Option<DateTime<Utc>>,
compat_session_device_id: String,
device_id: String,
user_id: Uuid,
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>,
}
#[tracing::instrument(skip_all, err)]
pub async fn lookup_active_compat_access_token(
pub async fn lookup_compat_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
token: &str,
) -> Result<Option<(CompatAccessToken, CompatSession)>, DatabaseError> {
session_id: Ulid,
) -> Result<Option<CompatSession>, DatabaseError> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
CompatSessionLookup,
r#"
SELECT ct.compat_access_token_id
, ct.access_token AS "compat_access_token"
, ct.created_at AS "compat_access_token_created_at"
, ct.expires_at AS "compat_access_token_expires_at"
, cs.compat_session_id
, cs.created_at AS "compat_session_created_at"
, cs.finished_at AS "compat_session_finished_at"
, cs.device_id AS "compat_session_device_id"
, cs.user_id AS "user_id!"
FROM compat_access_tokens ct
INNER JOIN compat_sessions cs
USING (compat_session_id)
WHERE ct.access_token = $1
AND (ct.expires_at < $2 OR ct.expires_at IS NULL)
AND cs.finished_at IS NULL
SELECT compat_session_id
, device_id
, user_id
, created_at
, finished_at
FROM compat_sessions
WHERE compat_session_id = $1
"#,
token,
clock.now(),
Uuid::from(session_id),
)
.fetch_one(executor)
.instrument(info_span!("Fetch compat access token"))
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
let token = CompatAccessToken {
id: res.compat_access_token_id.into(),
token: res.compat_access_token,
created_at: res.compat_access_token_created_at,
expires_at: res.compat_access_token_expires_at,
};
let id = res.compat_session_id.into();
let device = Device::try_from(res.compat_session_device_id).map_err(|e| {
let device = Device::try_from(res.device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device_id")
.row(id)
.source(e)
})?;
let state = match res.compat_session_finished_at {
let state = match res.finished_at {
None => CompatSessionState::Valid,
Some(finished_at) => CompatSessionState::Finished { finished_at },
};
@ -103,103 +79,148 @@ pub async fn lookup_active_compat_access_token(
state,
user_id: res.user_id.into(),
device,
created_at: res.compat_session_created_at,
created_at: res.created_at,
};
Ok(Some((token, session)))
Ok(Some(session))
}
pub struct CompatRefreshTokenLookup {
compat_refresh_token_id: Uuid,
compat_refresh_token: String,
compat_refresh_token_created_at: DateTime<Utc>,
struct CompatAccessTokenLookup {
compat_access_token_id: Uuid,
compat_access_token: String,
compat_access_token_created_at: DateTime<Utc>,
compat_access_token_expires_at: Option<DateTime<Utc>>,
access_token: String,
created_at: DateTime<Utc>,
expires_at: Option<DateTime<Utc>>,
compat_session_id: Uuid,
compat_session_created_at: DateTime<Utc>,
compat_session_finished_at: Option<DateTime<Utc>>,
compat_session_device_id: String,
user_id: Uuid,
}
impl From<CompatAccessTokenLookup> for CompatAccessToken {
fn from(value: CompatAccessTokenLookup) -> Self {
Self {
id: value.compat_access_token_id.into(),
session_id: value.compat_session_id.into(),
token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[tracing::instrument(skip_all, err)]
#[allow(clippy::type_complexity)]
pub async fn lookup_active_compat_refresh_token(
pub async fn find_compat_access_token(
executor: impl PgExecutor<'_>,
token: &str,
) -> Result<Option<(CompatRefreshToken, CompatAccessToken, CompatSession)>, DatabaseError> {
) -> Result<Option<CompatAccessToken>, DatabaseError> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
CompatAccessTokenLookup,
r#"
SELECT cr.compat_refresh_token_id
, cr.refresh_token AS "compat_refresh_token"
, cr.created_at AS "compat_refresh_token_created_at"
, ct.compat_access_token_id
, ct.access_token AS "compat_access_token"
, ct.created_at AS "compat_access_token_created_at"
, ct.expires_at AS "compat_access_token_expires_at"
, cs.compat_session_id
, cs.created_at AS "compat_session_created_at"
, cs.finished_at AS "compat_session_finished_at"
, cs.device_id AS "compat_session_device_id"
, cs.user_id
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_refresh_tokens cr
INNER JOIN compat_sessions cs
USING (compat_session_id)
INNER JOIN compat_access_tokens ct
USING (compat_access_token_id)
FROM compat_access_tokens
WHERE cr.refresh_token = $1
AND cr.consumed_at IS NULL
AND cs.finished_at IS NULL
WHERE access_token = $1
"#,
token,
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
skip_all,
fields(
compat_access_token.id = %id,
),
err,
)]
pub async fn lookup_compat_access_token(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<Option<CompatAccessToken>, DatabaseError> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE compat_access_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
pub struct CompatRefreshTokenLookup {
compat_refresh_token_id: Uuid,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
compat_access_token_id: Uuid,
compat_session_id: Uuid,
}
#[tracing::instrument(skip_all, err)]
#[allow(clippy::type_complexity)]
pub async fn find_compat_refresh_token(
executor: impl PgExecutor<'_>,
token: &str,
) -> Result<Option<CompatRefreshToken>, DatabaseError> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE refresh_token = $1
"#,
token,
)
.fetch_one(executor)
.instrument(info_span!("Fetch compat refresh token"))
.await
.to_option()?;
let Some(res) = res else { return Ok(None); };
let state = match res.consumed_at {
None => CompatRefreshTokenState::Valid,
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
};
let refresh_token = CompatRefreshToken {
id: res.compat_refresh_token_id.into(),
token: res.compat_refresh_token,
created_at: res.compat_refresh_token_created_at,
};
let access_token = CompatAccessToken {
id: res.compat_access_token_id.into(),
token: res.compat_access_token,
created_at: res.compat_access_token_created_at,
expires_at: res.compat_access_token_expires_at,
};
let id = res.compat_session_id.into();
let device = Device::try_from(res.compat_session_device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device_id")
.row(id)
.source(e)
})?;
let state = match res.compat_session_finished_at {
None => CompatSessionState::Valid,
Some(finished_at) => CompatSessionState::Finished { finished_at },
};
let session = CompatSession {
id,
state,
user_id: res.user_id.into(),
device,
created_at: res.compat_session_created_at,
session_id: res.compat_session_id.into(),
access_token_id: res.compat_access_token_id.into(),
token: res.refresh_token,
created_at: res.created_at,
};
Ok(Some((refresh_token, access_token, session)))
Ok(Some(refresh_token))
}
#[tracing::instrument(
@ -244,6 +265,7 @@ pub async fn add_compat_access_token(
Ok(CompatAccessToken {
id,
session_id: session.id,
token,
created_at,
expires_at,
@ -320,6 +342,9 @@ pub async fn add_compat_refresh_token(
Ok(CompatRefreshToken {
id,
state: CompatRefreshTokenState::default(),
session_id: session.id,
access_token_id: access_token.id,
token,
created_at,
})
@ -327,42 +352,35 @@ pub async fn add_compat_refresh_token(
#[tracing::instrument(
skip_all,
fields(compat_session.id),
fields(%compat_session.id),
err,
)]
pub async fn compat_logout(
pub async fn end_compat_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
token: &str,
) -> Result<bool, sqlx::Error> {
compat_session: CompatSession,
) -> Result<CompatSession, DatabaseError> {
let finished_at = clock.now();
// TODO: this does not check for token expiration
let res = sqlx::query_scalar!(
let res = sqlx::query!(
r#"
UPDATE compat_sessions cs
SET finished_at = $2
FROM compat_access_tokens ca
WHERE ca.access_token = $1
AND ca.compat_session_id = cs.compat_session_id
AND cs.finished_at IS NULL
RETURNING cs.compat_session_id
WHERE compat_session_id = $1
"#,
token,
Uuid::from(compat_session.id),
finished_at,
)
.fetch_one(executor)
.await
.to_option()?;
.execute(executor)
.await?;
if let Some(compat_session_id) = res {
tracing::Span::current().record(
"compat_session.id",
tracing::field::display(compat_session_id),
);
Ok(true)
} else {
Ok(false)
}
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_session = compat_session
.finish(finished_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_session)
}
#[tracing::instrument(
@ -445,10 +463,6 @@ struct CompatSsoLoginLookup {
compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
compat_session_id: Option<Uuid>,
compat_session_created_at: Option<DateTime<Utc>>,
compat_session_finished_at: Option<DateTime<Utc>>,
compat_session_device_id: Option<String>,
user_id: Option<Uuid>,
}
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
@ -463,58 +477,21 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
.source(e)
})?;
let session = match (
res.compat_session_id,
res.compat_session_device_id,
res.compat_session_created_at,
res.compat_session_finished_at,
res.user_id,
) {
(Some(id), Some(device_id), Some(created_at), finished_at, Some(user_id)) => {
let id = id.into();
let device = Device::try_from(device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device")
.row(id)
.source(e)
})?;
let state = match finished_at {
None => CompatSessionState::Valid,
Some(finished_at) => CompatSessionState::Finished { finished_at },
};
Some(CompatSession {
id,
state,
user_id: user_id.into(),
device,
created_at,
})
}
(None, None, None, None, None) => None,
_ => {
return Err(DatabaseInconsistencyError::on("compat_sso_logins")
.column("compat_session_id")
.row(id))
}
};
let state = match (
res.compat_sso_login_fulfilled_at,
res.compat_sso_login_exchanged_at,
session,
res.compat_session_id,
) {
(None, None, None) => CompatSsoLoginState::Pending,
(Some(fulfilled_at), None, Some(session)) => CompatSsoLoginState::Fulfilled {
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
fulfilled_at,
session,
session_id: session_id.into(),
},
(Some(fulfilled_at), Some(exchanged_at), Some(session)) => {
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
CompatSsoLoginState::Exchanged {
fulfilled_at,
exchanged_at,
session,
session_id: session_id.into(),
}
}
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
@ -550,15 +527,9 @@ pub async fn get_compat_sso_login_by_id(
, cl.created_at AS "compat_sso_login_created_at"
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
, cs.compat_session_id AS "compat_session_id?"
, cs.created_at AS "compat_session_created_at?"
, cs.finished_at AS "compat_session_finished_at?"
, cs.device_id AS "compat_session_device_id?"
, cs.user_id AS "user_id?"
, cl.compat_session_id AS "compat_session_id"
FROM compat_sso_logins cl
LEFT JOIN compat_sessions cs
USING (compat_session_id)
WHERE cl.compat_sso_login_id = $1
"#,
Uuid::from(id),
@ -589,8 +560,6 @@ pub async fn get_paginated_user_compat_sso_logins(
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<CompatSsoLogin>), DatabaseError> {
// TODO: this queries too much (like user info) which we probably don't need
// because we already have them
let mut query = QueryBuilder::new(
r#"
SELECT cl.compat_sso_login_id
@ -599,14 +568,8 @@ pub async fn get_paginated_user_compat_sso_logins(
, cl.created_at AS "compat_sso_login_created_at"
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
, cs.compat_session_id AS "compat_session_id"
, cs.created_at AS "compat_session_created_at"
, cs.finished_at AS "compat_session_finished_at"
, cs.device_id AS "compat_session_device_id"
, cs.user_id
, cl.compat_session_id AS "compat_session_id"
FROM compat_sso_logins cl
LEFT JOIN compat_sessions cs
USING (compat_session_id)
"#,
);
@ -645,14 +608,8 @@ pub async fn get_compat_sso_login_by_token(
, cl.created_at AS "compat_sso_login_created_at"
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
, cs.compat_session_id AS "compat_session_id?"
, cs.created_at AS "compat_session_created_at?"
, cs.finished_at AS "compat_session_finished_at?"
, cs.device_id AS "compat_session_device_id?"
, cs.user_id AS "user_id?"
, cl.compat_session_id AS "compat_session_id"
FROM compat_sso_logins cl
LEFT JOIN compat_sessions cs
USING (compat_session_id)
WHERE cl.login_token = $1
"#,
token,
@ -739,7 +696,7 @@ pub async fn fullfill_compat_sso_login(
let fulfilled_at = clock.now();
let compat_sso_login = compat_sso_login
.fulfill(fulfilled_at, session)
.fulfill(fulfilled_at, &session)
.map_err(DatabaseError::to_invalid_operation)?;
sqlx::query!(
r#"