1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: repository pattern for the compat layer

This commit is contained in:
Quentin Gliech
2023-01-12 15:41:26 +01:00
parent 9f0c9f1466
commit 36396c0b45
18 changed files with 1738 additions and 1191 deletions

View File

@ -20,8 +20,9 @@ use url::Url;
use super::CompatSession;
use crate::InvalidTransitionError;
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
pub enum CompatSsoLoginState {
#[default]
Pending,
Fulfilled {
fulfilled_at: DateTime<Utc>,

View File

@ -15,7 +15,7 @@
use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID};
use chrono::{DateTime, Utc};
use mas_storage::{compat::lookup_compat_session, user::UserRepository, Repository};
use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository};
use sqlx::PgPool;
use url::Url;
@ -101,7 +101,9 @@ impl CompatSsoLogin {
let Some(session_id) = self.0.session_id() else { return Ok(None) };
let mut conn = ctx.data::<PgPool>()?.acquire().await?;
let session = lookup_compat_session(&mut conn, session_id)
let session = conn
.compat_session()
.lookup(session_id)
.await?
.context("Could not load compat session")?;

View File

@ -18,6 +18,7 @@ use async_graphql::{
};
use chrono::{DateTime, Utc};
use mas_storage::{
compat::CompatSsoLoginRepository,
oauth2::OAuth2SessionRepository,
user::{BrowserSessionRepository, UserEmailRepository},
Repository, UpstreamOAuthLinkRepository,
@ -96,14 +97,13 @@ impl User {
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin))
.transpose()?;
let (has_previous_page, has_next_page, edges) =
mas_storage::compat::get_paginated_user_compat_sso_logins(
&mut conn, &self.0, before_id, after_id, first, last,
)
let page = conn
.compat_sso_login()
.list_paginated(&self.0, before_id, after_id, first, last)
.await?;
let mut connection = Connection::new(has_previous_page, has_next_page);
connection.edges.extend(edges.into_iter().map(|u| {
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
connection.edges.extend(page.edges.into_iter().map(|u| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)),
CompatSsoLogin(u),

View File

@ -18,8 +18,8 @@ use hyper::StatusCode;
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User};
use mas_storage::{
compat::{
add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token,
lookup_compat_session, mark_compat_sso_login_as_exchanged, start_compat_session,
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository,
},
user::{UserPasswordRepository, UserRepository},
Clock, Repository,
@ -224,27 +224,17 @@ pub(crate) async fn post(
};
let access_token = TokenType::CompatAccessToken.generate(&mut rng);
let access_token = add_compat_access_token(
&mut txn,
&mut rng,
&clock,
&session,
access_token,
expires_in,
)
.await?;
let access_token = txn
.compat_access_token()
.add(&mut rng, &clock, &session, access_token, expires_in)
.await?;
let refresh_token = if input.refresh_token {
let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng);
let refresh_token = add_compat_refresh_token(
&mut txn,
&mut rng,
&clock,
&session,
&access_token,
refresh_token,
)
.await?;
let refresh_token = txn
.compat_refresh_token()
.add(&mut rng, &clock, &session, &access_token, refresh_token)
.await?;
Some(refresh_token.token)
} else {
None
@ -266,7 +256,9 @@ async fn token_login(
clock: &Clock,
token: &str,
) -> Result<(CompatSession, User), RouteError> {
let login = get_compat_sso_login_by_token(&mut *txn, token)
let login = txn
.compat_sso_login()
.find_by_token(token)
.await?
.ok_or(RouteError::InvalidLoginToken)?;
@ -308,7 +300,9 @@ async fn token_login(
}
};
let session = lookup_compat_session(&mut *txn, session_id)
let session = txn
.compat_session()
.lookup(session_id)
.await?
.ok_or(RouteError::SessionNotFound)?;
@ -318,7 +312,7 @@ async fn token_login(
.await?
.ok_or(RouteError::UserNotFound)?;
mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?;
txn.compat_sso_login().exchange(clock, login).await?;
Ok((session, user))
}
@ -374,7 +368,10 @@ async fn user_password_login(
// Now that the user credentials have been verified, start a new compat session
let device = Device::generate(&mut rng);
let session = start_compat_session(&mut *txn, &mut rng, &clock, &user, device).await?;
let session = txn
.compat_session()
.add(&mut rng, &clock, &user, device)
.await?;
Ok((session, user))
}

View File

@ -29,7 +29,10 @@ use mas_axum_utils::{
use mas_data_model::Device;
use mas_keystore::Encrypter;
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id};
use mas_storage::{
compat::{CompatSessionRepository, CompatSsoLoginRepository},
Repository,
};
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
@ -87,7 +90,9 @@ pub async fn get(
return Ok((cookie_jar, destination.go()).into_response());
}
let login = get_compat_sso_login_by_id(&mut conn, id)
let login = conn
.compat_sso_login()
.lookup(id)
.await?
.context("Could not find compat SSO login")?;
@ -149,7 +154,9 @@ pub async fn post(
return Ok((cookie_jar, destination.go()).into_response());
}
let login = get_compat_sso_login_by_id(&mut txn, id)
let login = txn
.compat_sso_login()
.lookup(id)
.await?
.context("Could not find compat SSO login")?;
@ -181,8 +188,14 @@ pub async fn post(
};
let device = Device::generate(&mut rng);
let _login =
fullfill_compat_sso_login(&mut txn, &mut rng, &clock, &session.user, login, device).await?;
let compat_session = txn
.compat_session()
.add(&mut rng, &clock, &session.user, device)
.await?;
txn.compat_sso_login()
.fulfill(&clock, login, &compat_session)
.await?;
txn.commit().await?;

View File

@ -19,7 +19,7 @@ use axum::{
};
use hyper::StatusCode;
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::compat::insert_compat_sso_login;
use mas_storage::{compat::CompatSsoLoginRepository, Repository};
use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize;
use serde_with::serde;
@ -49,6 +49,7 @@ pub enum RouteError {
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@ -80,7 +81,10 @@ pub async fn get(
let token = Alphanumeric.sample_string(&mut rng, 32);
let mut conn = pool.acquire().await?;
let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?;
let login = conn
.compat_sso_login()
.add(&mut rng, &clock, token, redirect_url)
.await?;
Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.id, params.action)))
}

View File

@ -17,8 +17,8 @@ use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode;
use mas_data_model::TokenType;
use mas_storage::{
compat::{end_compat_session, find_compat_access_token, lookup_compat_session},
Clock,
compat::{CompatAccessTokenRepository, CompatSessionRepository},
Clock, Repository,
};
use sqlx::PgPool;
use thiserror::Error;
@ -83,17 +83,21 @@ pub(crate) async fn post(
return Err(RouteError::InvalidAuthorization);
}
let token = find_compat_access_token(&mut txn, token)
let token = txn
.compat_access_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::InvalidAuthorization)?;
let session = lookup_compat_session(&mut txn, token.session_id)
let session = txn
.compat_session()
.lookup(token.session_id)
.await?
.filter(|s| s.is_valid())
.ok_or(RouteError::InvalidAuthorization)?;
end_compat_session(&mut txn, &clock, session).await?;
txn.compat_session().finish(&clock, session).await?;
txn.commit().await?;

View File

@ -16,10 +16,9 @@ use axum::{extract::State, response::IntoResponse, Json};
use chrono::Duration;
use hyper::StatusCode;
use mas_data_model::{TokenFormatError, TokenType};
use mas_storage::compat::{
add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token,
expire_compat_access_token, find_compat_refresh_token, lookup_compat_access_token,
lookup_compat_session,
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
Repository,
};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DurationMilliSeconds};
@ -101,7 +100,9 @@ pub(crate) async fn post(
return Err(RouteError::InvalidToken);
}
let refresh_token = find_compat_refresh_token(&mut txn, &input.refresh_token)
let refresh_token = txn
.compat_refresh_token()
.find_by_token(&input.refresh_token)
.await?
.ok_or(RouteError::InvalidToken)?;
@ -109,7 +110,9 @@ pub(crate) async fn post(
return Err(RouteError::RefreshTokenConsumed);
}
let session = lookup_compat_session(&mut txn, refresh_token.session_id)
let session = txn
.compat_session()
.lookup(refresh_token.session_id)
.await?
.ok_or(RouteError::UnknownSession)?;
@ -117,7 +120,9 @@ pub(crate) async fn post(
return Err(RouteError::InvalidSession);
}
let access_token = lookup_compat_access_token(&mut txn, refresh_token.access_token_id)
let access_token = txn
.compat_access_token()
.lookup(refresh_token.access_token_id)
.await?
.filter(|t| t.is_valid(clock.now()));
@ -125,29 +130,35 @@ pub(crate) async fn post(
let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
let expires_in = Duration::minutes(5);
let new_access_token = add_compat_access_token(
&mut txn,
&mut rng,
&clock,
&session,
new_access_token_str,
Some(expires_in),
)
.await?;
let new_refresh_token = add_compat_refresh_token(
&mut txn,
&mut rng,
&clock,
&session,
&new_access_token,
new_refresh_token_str,
)
.await?;
let new_access_token = txn
.compat_access_token()
.add(
&mut rng,
&clock,
&session,
new_access_token_str,
Some(expires_in),
)
.await?;
let new_refresh_token = txn
.compat_refresh_token()
.add(
&mut rng,
&clock,
&session,
&new_access_token,
new_refresh_token_str,
)
.await?;
consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?;
txn.compat_refresh_token()
.consume(&clock, refresh_token)
.await?;
if let Some(access_token) = access_token {
expire_compat_access_token(&mut txn, &clock, access_token).await?;
txn.compat_access_token()
.expire(&clock, access_token)
.await?;
}
txn.commit().await?;

View File

@ -22,7 +22,7 @@ use mas_data_model::{TokenFormatError, TokenType};
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_keystore::Encrypter;
use mas_storage::{
compat::{find_compat_access_token, find_compat_refresh_token, lookup_compat_session},
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
oauth2::{
access_token::find_access_token, refresh_token::lookup_refresh_token,
OAuth2SessionRepository,
@ -243,12 +243,16 @@ pub(crate) async fn post(
}
TokenType::CompatAccessToken => {
let token = find_compat_access_token(&mut conn, token)
let access_token = conn
.compat_access_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid(clock.now()))
.ok_or(RouteError::UnknownToken)?;
let session = lookup_compat_session(&mut conn, token.session_id)
let session = conn
.compat_session()
.lookup(access_token.session_id)
.await?
.filter(|s| s.is_valid())
.ok_or(RouteError::UnknownToken)?;
@ -269,9 +273,9 @@ pub(crate) async fn post(
client_id: Some("legacy".into()),
username: Some(user.username),
token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: token.expires_at,
iat: Some(token.created_at),
nbf: Some(token.created_at),
exp: access_token.expires_at,
iat: Some(access_token.created_at),
nbf: Some(access_token.created_at),
sub: Some(user.sub),
aud: None,
iss: None,
@ -280,12 +284,16 @@ pub(crate) async fn post(
}
TokenType::CompatRefreshToken => {
let refresh_token = find_compat_refresh_token(&mut conn, token)
let refresh_token = conn
.compat_refresh_token()
.find_by_token(token)
.await?
.filter(|t| t.is_valid())
.ok_or(RouteError::UnknownToken)?;
let session = lookup_compat_session(&mut conn, refresh_token.session_id)
let session = conn
.compat_session()
.lookup(refresh_token.session_id)
.await?
.filter(|s| s.is_valid())
.ok_or(RouteError::UnknownToken)?;

View File

@ -15,7 +15,7 @@
use anyhow::Context;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
compat::CompatSsoLoginRepository, oauth2::authorization_grant::get_grant_by_id,
upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository,
};
use mas_templates::{PostAuthContext, PostAuthContextInner};
@ -54,7 +54,9 @@ impl OptionalPostAuthAction {
}
PostAuthAction::ContinueCompatSsoLogin { id } => {
let login = get_compat_sso_login_by_id(conn, id)
let login = conn
.compat_sso_login()
.lookup(id)
.await?
.context("Failed to load compat SSO login")?;
let login = Box::new(login);

View File

@ -98,6 +98,21 @@
},
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n "
},
"18c3e56c72ef26bd42653c379767ffdd97bb06cb1686dfbf4099f3ad3d7b22c8": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Text",
"Timestamptz"
]
}
},
"query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n "
},
"1d372f36c382ab16264cea54537af3544ea6d6d75d10b432b07dbd0dadd2fa4e": {
"describe": {
"columns": [
@ -168,22 +183,6 @@
},
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n "
},
"2153118b364a33582e7f598acce3789fcb8d938948a819b15cf0b6d37edf58b2": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Text",
"Timestamptz",
"Timestamptz"
]
}
},
"query": "\n INSERT INTO compat_access_tokens\n (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n "
},
"262bee715889dc3e608639549600a131e641951ff979634e7c97afc74bbc1605": {
"describe": {
"columns": [],
@ -197,79 +196,6 @@
},
"query": "\n UPDATE oauth2_authorization_grants\n SET exchanged_at = $2\n WHERE oauth2_authorization_grant_id = $1\n "
},
"2e756fe7be50128c0acc5f79df3a084230e9ca13cd45bd0858f97e59da20006e": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz"
]
}
},
"query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n "
},
"360466ff599c67c9af2ac75399c0b536a22c1178972a0172b707bcc81d47357b": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Uuid",
"Text",
"Timestamptz"
]
}
},
"query": "\n INSERT INTO compat_refresh_tokens\n (compat_refresh_token_id, compat_session_id,\n compat_access_token_id, refresh_token, created_at)\n VALUES ($1, $2, $3, $4, $5)\n "
},
"3cf8e061206620071b39d0262cd165bb367b12b8e904180730d8acfa5af3d4b9": {
"describe": {
"columns": [
{
"name": "compat_session_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "device_id",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "user_id",
"ordinal": 2,
"type_info": "Uuid"
},
{
"name": "created_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "finished_at",
"ordinal": 4,
"type_info": "Timestamptz"
}
],
"nullable": [
false,
false,
false,
false,
true
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT compat_session_id\n , device_id\n , user_id\n , created_at\n , finished_at\n FROM compat_sessions\n WHERE compat_session_id = $1\n "
},
"3d66f3121b11ce923b9c60609b510a8ca899640e78cc8f5b03168622928ffe94": {
"describe": {
"columns": [],
@ -384,6 +310,56 @@
},
"query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n "
},
"432e199b0d47fe299d840c91159726c0a4f89f65b4dc3e33ddad58aabf6b148b": {
"describe": {
"columns": [
{
"name": "compat_refresh_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "refresh_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "consumed_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "compat_session_id",
"ordinal": 4,
"type_info": "Uuid"
},
{
"name": "compat_access_token_id",
"ordinal": 5,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
true,
false,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE refresh_token = $1\n "
},
"43a5cafbdc8037e9fb779812a0793cf0859902aa0dc8d25d4c33d231d3d1118b": {
"describe": {
"columns": [],
@ -465,20 +441,7 @@
},
"query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.oauth2_session_id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.oauth2_authorization_grant_id = $1\n AND os.oauth2_session_id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime<Utc>\"\n "
},
"4c4dbb846bb98d84f6b7f886f8af9833c7efe27b8b4f297077887232bef322ee": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz"
]
}
},
"query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n WHERE compat_session_id = $1\n "
},
"4f080990eb6dd9f6128f3a1aee195b99d5f286fa0f6c27d744f73848343879d4": {
"478f0ad710da8bfd803c6cddd982bc504d1b6bd0f5283de53c8c7b1b4b7dafd4": {
"describe": {
"columns": [
{
@ -487,27 +450,27 @@
"type_info": "Uuid"
},
{
"name": "compat_sso_login_token",
"name": "login_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "compat_sso_login_redirect_uri",
"name": "redirect_uri",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "compat_sso_login_created_at",
"name": "created_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "compat_sso_login_fulfilled_at",
"name": "fulfilled_at",
"ordinal": 4,
"type_info": "Timestamptz"
},
{
"name": "compat_sso_login_exchanged_at",
"name": "exchanged_at",
"ordinal": 5,
"type_info": "Timestamptz"
},
@ -528,11 +491,25 @@
],
"parameters": {
"Left": [
"Uuid"
"Text"
]
}
},
"query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cl.compat_session_id AS \"compat_session_id\"\n\n FROM compat_sso_logins cl\n WHERE cl.compat_sso_login_id = $1\n "
"query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE login_token = $1\n "
},
"4d79ce892e4595edb8b801e94fb0cbef28facdfd2e45d1c72c57f47418fbe24b": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Timestamptz"
]
}
},
"query": "\n UPDATE compat_sso_logins\n SET\n compat_session_id = $2,\n fulfilled_at = $3\n WHERE\n compat_sso_login_id = $1\n "
},
"51158bfcaa1a8d8e051bffe7c5ba0369bf53fb162f7622626054e89e68fc07bd": {
"describe": {
@ -555,6 +532,50 @@
},
"query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n "
},
"53ad718642644b47a2d49f768d81bd993088526923769a9147281686c2d47591": {
"describe": {
"columns": [
{
"name": "compat_access_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "expires_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "compat_session_id",
"ordinal": 4,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
true,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE access_token = $1\n "
},
"583ae9a0db9cd55fa57a179339550f3dab1bfc76f35ad488e1560ea37f7ed029": {
"describe": {
"columns": [],
@ -598,20 +619,6 @@
},
"query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n "
},
"60d039442cfa57e187602c0ff5e386e32fb774b5ad2d2f2c616040819b76873e": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Timestamptz"
]
}
},
"query": "\n UPDATE compat_sso_logins\n SET\n compat_session_id = $2,\n fulfilled_at = $3\n WHERE\n compat_sso_login_id = $1\n "
},
"62d05e8e4317bdb180298737d422e64d161c5ad3813913a6f7d67a53569ea76a": {
"describe": {
"columns": [],
@ -745,6 +752,21 @@
},
"query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n "
},
"6e21e7d816f806da9bb5176931bdb550dee05c44c9d93f53df95fe3b4a840347": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Timestamptz"
]
}
},
"query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n "
},
"6f97b5f9ad0d4d15387150bea3839fb7f81015f7ceef61ecaadba64521895cff": {
"describe": {
"columns": [],
@ -782,6 +804,50 @@
},
"query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n "
},
"77dfa9fae1a9c77b70476d7da19d3313a02886994cfff0690451229fb5ae2f77": {
"describe": {
"columns": [
{
"name": "compat_access_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "expires_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "compat_session_id",
"ordinal": 4,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
true,
false
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE compat_access_token_id = $1\n "
},
"79295f3d3a75f831e9469aabfa720d381a254d00dbe39fef1e9652029d51b89b": {
"describe": {
"columns": [
@ -871,19 +937,6 @@
},
"query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n "
},
"7e3247e35ecf5335f0656c53bcde27264a9efb8dccb6246344950614f487dcaf": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz"
]
}
},
"query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n "
},
"836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c": {
"describe": {
"columns": [
@ -1154,6 +1207,19 @@
},
"query": "\n UPDATE user_email_confirmation_codes\n SET consumed_at = $2\n WHERE user_email_confirmation_code_id = $1\n "
},
"9348d87f9e06b614c7e90bdc93bcf38236766aaf4d894bf768debdff2b59fae2": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz"
]
}
},
"query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n "
},
"94fd96446b237c87bd6bf741f3c42b37ee751b87b7fcc459602bdf8c46962443": {
"describe": {
"columns": [
@ -1174,18 +1240,21 @@
},
"query": "\n SELECT EXISTS(\n SELECT 1 FROM users WHERE username = $1\n ) AS \"exists!\"\n "
},
"99f5f9eb0adc5ec120ed8194cbf6a8545155bef09e6d94d92fb67fd1b14d4f28": {
"9f7bdc034c618e47e49c467d0d7f5b8c297d055abe248cc876dbc12c5a7dc920": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Uuid",
"Text",
"Timestamptz"
]
}
},
"query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_refresh_token_id = $1\n "
"query": "\n INSERT INTO compat_refresh_tokens\n (compat_refresh_token_id, compat_session_id,\n compat_access_token_id, refresh_token, created_at)\n VALUES ($1, $2, $3, $4, $5)\n "
},
"a300fe99c95679c5664646a6a525c0491829e97db45f3234483872ed38436322": {
"describe": {
@ -1243,6 +1312,22 @@
},
"query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n requires_consent = 'f'\n WHERE\n og.oauth2_authorization_grant_id = $1\n "
},
"a7f780528882a2ae66c45435215763eed0582264861436eab3f862e3eb12cab1": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Text",
"Timestamptz",
"Timestamptz"
]
}
},
"query": "\n INSERT INTO compat_access_tokens\n (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n "
},
"aa2fd69c595f94d8598715766a79671dba8f87b9d7af6ac30e3fa1fbc8cce28a": {
"describe": {
"columns": [
@ -1371,6 +1456,19 @@
},
"query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE authorization_code = $1\n "
},
"ab34912b42a48a8b5c8d63e271b99b7d0b690a2471873c6654b1b6cf2079b95c": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz"
]
}
},
"query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n WHERE compat_session_id = $1\n "
},
"aff08a8caabeb62f4929e6e901e7ca7c55e284c18c5c1d1e78821dd9bc961412": {
"describe": {
"columns": [
@ -1652,6 +1750,19 @@
},
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n "
},
"bbf62633c561706a762089bbab2f76a9ba3e2ed3539ef16accb601fb609c2ec9": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz"
]
}
},
"query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n "
},
"bd1f6daa5fa1b10250c01f8b3fbe451646a9ceeefa6f72b9c4e29b6d05f17641": {
"describe": {
"columns": [],
@ -1696,106 +1807,6 @@
},
"query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n "
},
"c3e60701299be7728108b8967ec5396fb186adaac360d6a0152d25e4a4f46f46": {
"describe": {
"columns": [
{
"name": "compat_access_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "expires_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "compat_session_id",
"ordinal": 4,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
true,
false
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE compat_access_token_id = $1\n "
},
"c78246fc8737491352f71ea9410e79df8de88596c8197405cda36eb8c8187810": {
"describe": {
"columns": [
{
"name": "compat_sso_login_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "compat_sso_login_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "compat_sso_login_redirect_uri",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "compat_sso_login_created_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "compat_sso_login_fulfilled_at",
"ordinal": 4,
"type_info": "Timestamptz"
},
{
"name": "compat_sso_login_exchanged_at",
"ordinal": 5,
"type_info": "Timestamptz"
},
{
"name": "compat_session_id",
"ordinal": 6,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
false,
true,
true,
true
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cl.compat_session_id AS \"compat_session_id\"\n FROM compat_sso_logins cl\n WHERE cl.login_token = $1\n "
},
"c88376abdba124ff0487a9a69d2345c7d69d7394f355111ec369cfa6d45fb40f": {
"describe": {
"columns": [],
@ -1822,114 +1833,18 @@
},
"query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n "
},
"ca63558e877bd115aa7ca24de0cc3f78a13cb55105758fe0bd930da513f75504": {
"describe": {
"columns": [
{
"name": "compat_refresh_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "refresh_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "consumed_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "compat_session_id",
"ordinal": 4,
"type_info": "Uuid"
},
{
"name": "compat_access_token_id",
"ordinal": 5,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
true,
false,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE refresh_token = $1\n "
},
"caf54e4659306a746747aa61906bdb2cb8da51176e90435aa8b9754ebf3e4d60": {
"d0b403e9c843ef19fa5ad60bec32ebf14a1ba0d01681c3836366d3f55e7851f4": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Text",
"Timestamptz"
]
}
},
"query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n "
},
"cf43b82bdf534400f900cff3c5356083db0f9e5407e288b64f43d7ac100de058": {
"describe": {
"columns": [
{
"name": "compat_access_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "expires_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "compat_session_id",
"ordinal": 4,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
true,
false
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE access_token = $1\n "
"query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_refresh_token_id = $1\n "
},
"d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c": {
"describe": {
@ -1951,21 +1866,6 @@
},
"query": "\n SELECT COUNT(*)\n FROM user_emails\n WHERE user_id = $1\n "
},
"d1738c27339b81f0844da4bd9b040b9b07a91aa4d9b199b98f24c9cee5709b2b": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Timestamptz"
]
}
},
"query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n "
},
"d1f1aac41bb2f0d194f9b3d846663c267865d0d22dd5fa8a668daf29dca88d36": {
"describe": {
"columns": [
@ -2211,6 +2111,112 @@
},
"query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n "
},
"ddb22dd9ae9367af65a607e1fdc48b3d9581d67deea0c168f24e02090082bb82": {
"describe": {
"columns": [
{
"name": "compat_sso_login_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "login_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "redirect_uri",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "fulfilled_at",
"ordinal": 4,
"type_info": "Timestamptz"
},
{
"name": "exchanged_at",
"ordinal": 5,
"type_info": "Timestamptz"
},
{
"name": "compat_session_id",
"ordinal": 6,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
false,
true,
true,
true
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE compat_sso_login_id = $1\n "
},
"e35d56de7136d43d0803ec825b0612e4185cef838f105d66f18cb24865e45140": {
"describe": {
"columns": [
{
"name": "compat_refresh_token_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "refresh_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "consumed_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "compat_session_id",
"ordinal": 4,
"type_info": "Uuid"
},
{
"name": "compat_access_token_id",
"ordinal": 5,
"type_info": "Uuid"
}
],
"nullable": [
false,
false,
false,
true,
false,
false
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE compat_refresh_token_id = $1\n "
},
"e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": {
"describe": {
"columns": [
@ -2306,6 +2312,50 @@
},
"query": "\n SELECT oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n , finished_at\n FROM oauth2_sessions\n\n WHERE oauth2_session_id = $1\n "
},
"f3ee06958d827b152c57328caa0a6030c372cb99cdb60e4b75a28afeb5096f45": {
"describe": {
"columns": [
{
"name": "compat_session_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "device_id",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "user_id",
"ordinal": 2,
"type_info": "Uuid"
},
{
"name": "created_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "finished_at",
"ordinal": 4,
"type_info": "Timestamptz"
}
],
"nullable": [
false,
false,
false,
false,
true
],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n SELECT compat_session_id\n , device_id\n , user_id\n , created_at\n , finished_at\n FROM compat_sessions\n WHERE compat_session_id = $1\n "
},
"f5edcd4c306ca8179cdf9d4aab59fbba971b54611c91345849920954dd8089b3": {
"describe": {
"columns": [],

View File

@ -1,757 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, User,
};
use rand::Rng;
use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder};
use tracing::{info_span, Instrument};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{
pagination::{process_page, QueryBuilderExt},
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
struct CompatSessionLookup {
compat_session_id: Uuid,
device_id: String,
user_id: Uuid,
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>,
}
#[tracing::instrument(skip_all, err)]
pub async fn lookup_compat_session(
executor: impl PgExecutor<'_>,
session_id: Ulid,
) -> Result<Option<CompatSession>, DatabaseError> {
let res = sqlx::query_as!(
CompatSessionLookup,
r#"
SELECT compat_session_id
, device_id
, user_id
, created_at
, finished_at
FROM compat_sessions
WHERE compat_session_id = $1
"#,
Uuid::from(session_id),
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
let id = res.compat_session_id.into();
let device = Device::try_from(res.device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device_id")
.row(id)
.source(e)
})?;
let state = match res.finished_at {
None => CompatSessionState::Valid,
Some(finished_at) => CompatSessionState::Finished { finished_at },
};
let session = CompatSession {
id,
state,
user_id: res.user_id.into(),
device,
created_at: res.created_at,
};
Ok(Some(session))
}
struct CompatAccessTokenLookup {
compat_access_token_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: Option<DateTime<Utc>>,
compat_session_id: Uuid,
}
impl From<CompatAccessTokenLookup> for CompatAccessToken {
fn from(value: CompatAccessTokenLookup) -> Self {
Self {
id: value.compat_access_token_id.into(),
session_id: value.compat_session_id.into(),
token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[tracing::instrument(skip_all, err)]
pub async fn find_compat_access_token(
executor: impl PgExecutor<'_>,
token: &str,
) -> Result<Option<CompatAccessToken>, DatabaseError> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE access_token = $1
"#,
token,
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
skip_all,
fields(
compat_access_token.id = %id,
),
err,
)]
pub async fn lookup_compat_access_token(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<Option<CompatAccessToken>, DatabaseError> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE compat_access_token_id = $1
"#,
Uuid::from(id),
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
pub struct CompatRefreshTokenLookup {
compat_refresh_token_id: Uuid,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
compat_access_token_id: Uuid,
compat_session_id: Uuid,
}
#[tracing::instrument(skip_all, err)]
#[allow(clippy::type_complexity)]
pub async fn find_compat_refresh_token(
executor: impl PgExecutor<'_>,
token: &str,
) -> Result<Option<CompatRefreshToken>, DatabaseError> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE refresh_token = $1
"#,
token,
)
.fetch_one(executor)
.await
.to_option()?;
let Some(res) = res else { return Ok(None); };
let state = match res.consumed_at {
None => CompatRefreshTokenState::Valid,
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
};
let refresh_token = CompatRefreshToken {
id: res.compat_refresh_token_id.into(),
state,
session_id: res.compat_session_id.into(),
access_token_id: res.compat_access_token_id.into(),
token: res.refresh_token,
created_at: res.created_at,
};
Ok(Some(refresh_token))
}
#[tracing::instrument(
skip_all,
fields(
compat_session.id = %session.id,
compat_session.device.id = session.device.as_str(),
compat_access_token.id,
user.id = %session.user_id,
),
err,
)]
pub async fn add_compat_access_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &CompatSession,
token: String,
expires_after: Option<Duration>,
) -> Result<CompatAccessToken, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
sqlx::query!(
r#"
INSERT INTO compat_access_tokens
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
token,
created_at,
expires_at,
)
.execute(executor)
.instrument(tracing::info_span!("Insert compat access token"))
.await?;
Ok(CompatAccessToken {
id,
session_id: session.id,
token,
created_at,
expires_at,
})
}
#[tracing::instrument(
skip_all,
fields(
compat_access_token.id = %access_token.id,
),
err,
)]
pub async fn expire_compat_access_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
access_token: CompatAccessToken,
) -> Result<(), DatabaseError> {
let expires_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_access_tokens
SET expires_at = $2
WHERE compat_access_token_id = $1
"#,
Uuid::from(access_token.id),
expires_at,
)
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)
}
#[tracing::instrument(
skip_all,
fields(
compat_session.id = %session.id,
compat_session.device.id = session.device.as_str(),
compat_access_token.id = %access_token.id,
compat_refresh_token.id,
user.id = %session.user_id,
),
err,
)]
pub async fn add_compat_refresh_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &CompatSession,
access_token: &CompatAccessToken,
token: String,
) -> Result<CompatRefreshToken, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_refresh_tokens
(compat_refresh_token_id, compat_session_id,
compat_access_token_id, refresh_token, created_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.id),
Uuid::from(access_token.id),
token,
created_at,
)
.execute(executor)
.instrument(tracing::info_span!("Insert compat refresh token"))
.await?;
Ok(CompatRefreshToken {
id,
state: CompatRefreshTokenState::default(),
session_id: session.id,
access_token_id: access_token.id,
token,
created_at,
})
}
#[tracing::instrument(
skip_all,
fields(%compat_session.id),
err,
)]
pub async fn end_compat_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
compat_session: CompatSession,
) -> Result<CompatSession, DatabaseError> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_sessions cs
SET finished_at = $2
WHERE compat_session_id = $1
"#,
Uuid::from(compat_session.id),
finished_at,
)
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_session = compat_session
.finish(finished_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_session)
}
#[tracing::instrument(
skip_all,
fields(
compat_refresh_token.id = %refresh_token.id,
),
err,
)]
pub async fn consume_compat_refresh_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
refresh_token: CompatRefreshToken,
) -> Result<(), DatabaseError> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_refresh_tokens
SET consumed_at = $2
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(refresh_token.id),
consumed_at,
)
.execute(executor)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)
}
#[tracing::instrument(
skip_all,
fields(
compat_sso_login.id,
compat_sso_login.redirect_uri = %redirect_uri,
),
err,
)]
pub async fn insert_compat_sso_login(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
login_token: String,
redirect_uri: Url,
) -> Result<CompatSsoLogin, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sso_logins
(compat_sso_login_id, login_token, redirect_uri, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
&login_token,
redirect_uri.as_str(),
created_at,
)
.execute(executor)
.instrument(tracing::info_span!("Insert compat SSO login"))
.await?;
Ok(CompatSsoLogin {
id,
login_token,
redirect_uri,
created_at,
state: CompatSsoLoginState::Pending,
})
}
#[derive(sqlx::FromRow)]
struct CompatSsoLoginLookup {
compat_sso_login_id: Uuid,
compat_sso_login_token: String,
compat_sso_login_redirect_uri: String,
compat_sso_login_created_at: DateTime<Utc>,
compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
compat_session_id: Option<Uuid>,
}
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
type Error = DatabaseInconsistencyError;
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
let id = res.compat_sso_login_id.into();
let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri).map_err(|e| {
DatabaseInconsistencyError::on("compat_sso_logins")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let state = match (
res.compat_sso_login_fulfilled_at,
res.compat_sso_login_exchanged_at,
res.compat_session_id,
) {
(None, None, None) => CompatSsoLoginState::Pending,
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
fulfilled_at,
session_id: session_id.into(),
},
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
CompatSsoLoginState::Exchanged {
fulfilled_at,
exchanged_at,
session_id: session_id.into(),
}
}
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
};
Ok(CompatSsoLogin {
id,
login_token: res.compat_sso_login_token,
redirect_uri,
created_at: res.compat_sso_login_created_at,
state,
})
}
}
#[tracing::instrument(
skip_all,
fields(
compat_sso_login.id = %id,
),
err,
)]
pub async fn get_compat_sso_login_by_id(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<Option<CompatSsoLogin>, DatabaseError> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT cl.compat_sso_login_id
, cl.login_token AS "compat_sso_login_token"
, cl.redirect_uri AS "compat_sso_login_redirect_uri"
, cl.created_at AS "compat_sso_login_created_at"
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
, cl.compat_session_id AS "compat_session_id"
FROM compat_sso_logins cl
WHERE cl.compat_sso_login_id = $1
"#,
Uuid::from(id),
)
.fetch_one(executor)
.instrument(tracing::info_span!("Lookup compat SSO login"))
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
skip_all,
fields(
%user.id,
%user.username,
),
err,
)]
pub async fn get_paginated_user_compat_sso_logins(
executor: impl PgExecutor<'_>,
user: &User,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<CompatSsoLogin>), DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT cl.compat_sso_login_id
, cl.login_token AS "compat_sso_login_token"
, cl.redirect_uri AS "compat_sso_login_redirect_uri"
, cl.created_at AS "compat_sso_login_created_at"
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
, cl.compat_session_id AS "compat_session_id"
FROM compat_sso_logins cl
"#,
);
query
.push(" WHERE cs.user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("cl.compat_sso_login_id", before, after, first, last)?;
let span = info_span!(
"Fetch paginated user compat SSO logins",
db.statement = query.sql()
);
let page: Vec<CompatSsoLoginLookup> = query
.build_query_as()
.fetch_all(executor)
.instrument(span)
.await?;
let (has_previous_page, has_next_page, page) = process_page(page, first, last)?;
let page: Result<Vec<_>, _> = page.into_iter().map(TryInto::try_into).collect();
Ok((has_previous_page, has_next_page, page?))
}
#[tracing::instrument(skip_all, err)]
pub async fn get_compat_sso_login_by_token(
executor: impl PgExecutor<'_>,
token: &str,
) -> Result<Option<CompatSsoLogin>, DatabaseError> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT cl.compat_sso_login_id
, cl.login_token AS "compat_sso_login_token"
, cl.redirect_uri AS "compat_sso_login_redirect_uri"
, cl.created_at AS "compat_sso_login_created_at"
, cl.fulfilled_at AS "compat_sso_login_fulfilled_at"
, cl.exchanged_at AS "compat_sso_login_exchanged_at"
, cl.compat_session_id AS "compat_session_id"
FROM compat_sso_logins cl
WHERE cl.login_token = $1
"#,
token,
)
.fetch_one(executor)
.instrument(tracing::info_span!("Lookup compat SSO login"))
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
skip_all,
fields(
%user.id,
compat_session.id,
compat_session.device.id = device.as_str(),
),
err,
)]
pub async fn start_compat_session(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User,
device: Device,
) -> Result<CompatSession, DatabaseError> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user.id),
device.as_str(),
created_at,
)
.execute(executor)
.await?;
Ok(CompatSession {
id,
state: CompatSessionState::default(),
user_id: user.id,
device,
created_at,
})
}
#[tracing::instrument(
skip_all,
fields(
%user.id,
%compat_sso_login.id,
%compat_sso_login.redirect_uri,
compat_session.id,
compat_session.device.id = device.as_str(),
),
err,
)]
pub async fn fullfill_compat_sso_login(
conn: impl Acquire<'_, Database = Postgres> + Send,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User,
compat_sso_login: CompatSsoLogin,
device: Device,
) -> Result<CompatSsoLogin, DatabaseError> {
if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) {
return Err(DatabaseError::invalid_operation());
};
let mut txn = conn.begin().await?;
let session = start_compat_session(&mut txn, &mut rng, clock, user, device).await?;
let session_id = session.id;
let fulfilled_at = clock.now();
let compat_sso_login = compat_sso_login
.fulfill(fulfilled_at, &session)
.map_err(DatabaseError::to_invalid_operation)?;
sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
compat_session_id = $2,
fulfilled_at = $3
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
Uuid::from(session_id),
fulfilled_at,
)
.execute(&mut txn)
.instrument(tracing::info_span!("Update compat SSO login"))
.await?;
txn.commit().await?;
Ok(compat_sso_login)
}
#[tracing::instrument(
skip_all,
fields(
%compat_sso_login.id,
%compat_sso_login.redirect_uri,
),
err,
)]
pub async fn mark_compat_sso_login_as_exchanged(
executor: impl PgExecutor<'_>,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, DatabaseError> {
let exchanged_at = clock.now();
let compat_sso_login = compat_sso_login
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
exchanged_at = $2
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
exchanged_at,
)
.execute(executor)
.instrument(tracing::info_span!("Update compat SSO login"))
.await?;
Ok(compat_sso_login)
}

View File

@ -0,0 +1,246 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{CompatAccessToken, CompatSession};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
#[async_trait]
pub trait CompatAccessTokenRepository: Send + Sync {
type Error;
/// Lookup a compat access token by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error>;
/// Find a compat access token by its token
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<CompatAccessToken>, Self::Error>;
/// Add a new compat access token to the database
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
token: String,
expires_after: Option<Duration>,
) -> Result<CompatAccessToken, Self::Error>;
/// Set the expiration time of the compat access token to now
async fn expire(
&mut self,
clock: &Clock,
compat_access_token: CompatAccessToken,
) -> Result<CompatAccessToken, Self::Error>;
}
pub struct PgCompatAccessTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatAccessTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatAccessTokenLookup {
compat_access_token_id: Uuid,
access_token: String,
created_at: DateTime<Utc>,
expires_at: Option<DateTime<Utc>>,
compat_session_id: Uuid,
}
impl From<CompatAccessTokenLookup> for CompatAccessToken {
fn from(value: CompatAccessTokenLookup) -> Self {
Self {
id: value.compat_access_token_id.into(),
session_id: value.compat_session_id.into(),
token: value.access_token,
created_at: value.created_at,
expires_at: value.expires_at,
}
}
}
#[async_trait]
impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_access_token.lookup",
skip_all,
fields(
db.statement,
compat_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE compat_access_token_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_access_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
access_token: &str,
) -> Result<Option<CompatAccessToken>, Self::Error> {
let res = sqlx::query_as!(
CompatAccessTokenLookup,
r#"
SELECT compat_access_token_id
, access_token
, created_at
, expires_at
, compat_session_id
FROM compat_access_tokens
WHERE access_token = $1
"#,
access_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_access_token.add",
skip_all,
fields(
db.statement,
compat_access_token.id,
%compat_session.id,
user.id = %compat_session.user_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
token: String,
expires_after: Option<Duration>,
) -> Result<CompatAccessToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
sqlx::query!(
r#"
INSERT INTO compat_access_tokens
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(compat_session.id),
token,
created_at,
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatAccessToken {
id,
session_id: compat_session.id,
token,
created_at,
expires_at,
})
}
#[tracing::instrument(
name = "db.compat_access_token.expire",
skip_all,
fields(
db.statement,
%compat_access_token.id,
compat_session.id = %compat_access_token.session_id,
),
err,
)]
async fn expire(
&mut self,
clock: &Clock,
mut compat_access_token: CompatAccessToken,
) -> Result<CompatAccessToken, Self::Error> {
let expires_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_access_tokens
SET expires_at = $2
WHERE compat_access_token_id = $1
"#,
Uuid::from(compat_access_token.id),
expires_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
compat_access_token.expires_at = Some(expires_at);
Ok(compat_access_token)
}
}

View File

@ -0,0 +1,25 @@
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod access_token;
mod refresh_token;
mod session;
mod sso_login;
pub use self::{
access_token::{CompatAccessTokenRepository, PgCompatAccessTokenRepository},
refresh_token::{CompatRefreshTokenRepository, PgCompatRefreshTokenRepository},
session::{CompatSessionRepository, PgCompatSessionRepository},
sso_login::{CompatSsoLoginRepository, PgCompatSsoLoginRepository},
};

View File

@ -0,0 +1,260 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt};
#[async_trait]
pub trait CompatRefreshTokenRepository: Send + Sync {
type Error;
/// Lookup a compat refresh token by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error>;
/// Find a compat refresh token by its token
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<CompatRefreshToken>, Self::Error>;
/// Add a new compat refresh token to the database
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
compat_access_token: &CompatAccessToken,
token: String,
) -> Result<CompatRefreshToken, Self::Error>;
/// Consume a compat refresh token
async fn consume(
&mut self,
clock: &Clock,
compat_refresh_token: CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error>;
}
pub struct PgCompatRefreshTokenRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatRefreshTokenRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatRefreshTokenLookup {
compat_refresh_token_id: Uuid,
refresh_token: String,
created_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
compat_access_token_id: Uuid,
compat_session_id: Uuid,
}
impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
fn from(value: CompatRefreshTokenLookup) -> Self {
let state = match value.consumed_at {
Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
None => CompatRefreshTokenState::Valid,
};
Self {
id: value.compat_refresh_token_id.into(),
state,
session_id: value.compat_session_id.into(),
token: value.refresh_token,
created_at: value.created_at,
access_token_id: value.compat_access_token_id.into(),
}
}
}
#[async_trait]
impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_refresh_token.lookup",
skip_all,
fields(
db.statement,
compat_refresh_token.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_refresh_token.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
refresh_token: &str,
) -> Result<Option<CompatRefreshToken>, Self::Error> {
let res = sqlx::query_as!(
CompatRefreshTokenLookup,
r#"
SELECT compat_refresh_token_id
, refresh_token
, created_at
, consumed_at
, compat_session_id
, compat_access_token_id
FROM compat_refresh_tokens
WHERE refresh_token = $1
"#,
refresh_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
name = "db.compat_refresh_token.add",
skip_all,
fields(
db.statement,
compat_refresh_token.id,
%compat_session.id,
user.id = %compat_session.user_id,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
compat_session: &CompatSession,
compat_access_token: &CompatAccessToken,
token: String,
) -> Result<CompatRefreshToken, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_refresh_tokens
(compat_refresh_token_id, compat_session_id,
compat_access_token_id, refresh_token, created_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(compat_session.id),
Uuid::from(compat_access_token.id),
token,
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatRefreshToken {
id,
state: CompatRefreshTokenState::default(),
session_id: compat_session.id,
access_token_id: compat_access_token.id,
token,
created_at,
})
}
#[tracing::instrument(
name = "db.compat_refresh_token.consume",
skip_all,
fields(
db.statement,
%compat_refresh_token.id,
compat_session.id = %compat_refresh_token.session_id,
),
err,
)]
async fn consume(
&mut self,
clock: &Clock,
compat_refresh_token: CompatRefreshToken,
) -> Result<CompatRefreshToken, Self::Error> {
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_refresh_tokens
SET consumed_at = $2
WHERE compat_refresh_token_id = $1
"#,
Uuid::from(compat_refresh_token.id),
consumed_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_refresh_token = compat_refresh_token
.consume(consumed_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_refresh_token)
}
}

View File

@ -0,0 +1,220 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSessionState, Device, User};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
#[async_trait]
pub trait CompatSessionRepository: Send + Sync {
type Error;
/// Lookup a compat session by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error>;
/// Start a new compat session
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
device: Device,
) -> Result<CompatSession, Self::Error>;
/// End a compat session
async fn finish(
&mut self,
clock: &Clock,
compat_session: CompatSession,
) -> Result<CompatSession, Self::Error>;
}
pub struct PgCompatSessionRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatSessionRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
struct CompatSessionLookup {
compat_session_id: Uuid,
device_id: String,
user_id: Uuid,
created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>,
}
impl TryFrom<CompatSessionLookup> for CompatSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: CompatSessionLookup) -> Result<Self, Self::Error> {
let id = value.compat_session_id.into();
let device = Device::try_from(value.device_id).map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device_id")
.row(id)
.source(e)
})?;
let state = match value.finished_at {
None => CompatSessionState::Valid,
Some(finished_at) => CompatSessionState::Finished { finished_at },
};
let session = CompatSession {
id,
state,
user_id: value.user_id.into(),
device,
created_at: value.created_at,
};
Ok(session)
}
}
#[async_trait]
impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_session.lookup",
skip_all,
fields(
db.statement,
compat_session.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
let res = sqlx::query_as!(
CompatSessionLookup,
r#"
SELECT compat_session_id
, device_id
, user_id
, created_at
, finished_at
FROM compat_sessions
WHERE compat_session_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_session.add",
skip_all,
fields(
db.statement,
compat_session.id,
%user.id,
%user.username,
compat_session.device.id = device.as_str(),
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
user: &User,
device: Device,
) -> Result<CompatSession, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user.id),
device.as_str(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatSession {
id,
state: CompatSessionState::default(),
user_id: user.id,
device,
created_at,
})
}
#[tracing::instrument(
name = "db.compat_session.finish",
skip_all,
fields(
db.statement,
%compat_session.id,
user.id = %compat_session.user_id,
compat_session.device.id = compat_session.device.as_str(),
),
err,
)]
async fn finish(
&mut self,
clock: &Clock,
compat_session: CompatSession,
) -> Result<CompatSession, Self::Error> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_sessions cs
SET finished_at = $2
WHERE compat_session_id = $1
"#,
Uuid::from(compat_session.id),
finished_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
let compat_session = compat_session
.finish(finished_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(compat_session)
}
}

View File

@ -0,0 +1,397 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{
pagination::{process_page, Page, QueryBuilderExt},
tracing::ExecuteExt,
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
};
#[async_trait]
pub trait CompatSsoLoginRepository: Send + Sync {
type Error;
/// Lookup a compat SSO login by its ID
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error>;
/// Find a compat SSO login by its login token
async fn find_by_token(
&mut self,
login_token: &str,
) -> Result<Option<CompatSsoLogin>, Self::Error>;
/// Start a new compat SSO login token
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
login_token: String,
redirect_uri: Url,
) -> Result<CompatSsoLogin, Self::Error>;
/// Fulfill a compat SSO login by providing a compat session
async fn fulfill(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
compat_session: &CompatSession,
) -> Result<CompatSsoLogin, Self::Error>;
/// Mark a compat SSO login as exchanged
async fn exchange(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, Self::Error>;
/// Get a paginated list of compat SSO logins for a user
async fn list_paginated(
&mut self,
user: &User,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<Page<CompatSsoLogin>, Self::Error>;
}
pub struct PgCompatSsoLoginRepository<'c> {
conn: &'c mut PgConnection,
}
impl<'c> PgCompatSsoLoginRepository<'c> {
pub fn new(conn: &'c mut PgConnection) -> Self {
Self { conn }
}
}
#[derive(sqlx::FromRow)]
struct CompatSsoLoginLookup {
compat_sso_login_id: Uuid,
login_token: String,
redirect_uri: String,
created_at: DateTime<Utc>,
fulfilled_at: Option<DateTime<Utc>>,
exchanged_at: Option<DateTime<Utc>>,
compat_session_id: Option<Uuid>,
}
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
type Error = DatabaseInconsistencyError;
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
let id = res.compat_sso_login_id.into();
let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
DatabaseInconsistencyError::on("compat_sso_logins")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) {
(None, None, None) => CompatSsoLoginState::Pending,
(Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
fulfilled_at,
session_id: session_id.into(),
},
(Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
CompatSsoLoginState::Exchanged {
fulfilled_at,
exchanged_at,
session_id: session_id.into(),
}
}
_ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
};
Ok(CompatSsoLogin {
id,
login_token: res.login_token,
redirect_uri,
created_at: res.created_at,
state,
})
}
}
#[async_trait]
impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
type Error = DatabaseError;
#[tracing::instrument(
name = "db.compat_sso_login.lookup",
skip_all,
fields(
db.statement,
compat_sso_login.id = %id,
),
err,
)]
async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT compat_sso_login_id
, login_token
, redirect_uri
, created_at
, fulfilled_at
, exchanged_at
, compat_session_id
FROM compat_sso_logins
WHERE compat_sso_login_id = $1
"#,
Uuid::from(id),
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_sso_login.find_by_token",
skip_all,
fields(
db.statement,
),
err,
)]
async fn find_by_token(
&mut self,
login_token: &str,
) -> Result<Option<CompatSsoLogin>, Self::Error> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT compat_sso_login_id
, login_token
, redirect_uri
, created_at
, fulfilled_at
, exchanged_at
, compat_session_id
FROM compat_sso_logins
WHERE login_token = $1
"#,
login_token,
)
.traced()
.fetch_one(&mut *self.conn)
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
name = "db.compat_sso_login.add",
skip_all,
fields(
db.statement,
compat_sso_login.id,
compat_sso_login.redirect_uri = %redirect_uri,
),
err,
)]
async fn add(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &Clock,
login_token: String,
redirect_uri: Url,
) -> Result<CompatSsoLogin, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO compat_sso_logins
(compat_sso_login_id, login_token, redirect_uri, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
&login_token,
redirect_uri.as_str(),
created_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(CompatSsoLogin {
id,
login_token,
redirect_uri,
created_at,
state: CompatSsoLoginState::default(),
})
}
#[tracing::instrument(
name = "db.compat_sso_login.fulfill",
skip_all,
fields(
db.statement,
%compat_sso_login.id,
%compat_session.id,
compat_session.device.id = compat_session.device.as_str(),
user.id = %compat_session.user_id,
),
err,
)]
async fn fulfill(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
compat_session: &CompatSession,
) -> Result<CompatSsoLogin, Self::Error> {
let fulfilled_at = clock.now();
let compat_sso_login = compat_sso_login
.fulfill(fulfilled_at, compat_session)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
compat_session_id = $2,
fulfilled_at = $3
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
Uuid::from(compat_session.id),
fulfilled_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(compat_sso_login)
}
#[tracing::instrument(
name = "db.compat_sso_login.exchange",
skip_all,
fields(
db.statement,
%compat_sso_login.id,
),
err,
)]
async fn exchange(
&mut self,
clock: &Clock,
compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, Self::Error> {
let exchanged_at = clock.now();
let compat_sso_login = compat_sso_login
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
let res = sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
exchanged_at = $2
WHERE
compat_sso_login_id = $1
"#,
Uuid::from(compat_sso_login.id),
exchanged_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(compat_sso_login)
}
#[tracing::instrument(
name = "db.compat_sso_login.list_paginated",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err
)]
async fn list_paginated(
&mut self,
user: &User,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<Page<CompatSsoLogin>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT cl.compat_sso_login_id
, cl.login_token
, cl.redirect_uri
, cl.created_at
, cl.fulfilled_at
, cl.exchanged_at
, cl.compat_session_id
FROM compat_sso_logins cl
INNER JOIN compat_sessions ON compat_session_id
"#,
);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("cl.compat_sso_login_id", before, after, first, last)?;
let page: Vec<CompatSsoLoginLookup> = query
.build_query_as()
.traced()
.fetch_all(&mut *self.conn)
.await?;
let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?;
let edges: Result<Vec<_>, _> = edges.into_iter().map(TryInto::try_into).collect();
Ok(Page {
has_next_page,
has_previous_page,
edges: edges?,
})
}
}

View File

@ -15,6 +15,10 @@
use sqlx::{PgConnection, Postgres, Transaction};
use crate::{
compat::{
PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
PgCompatSsoLoginRepository,
},
oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository},
upstream_oauth2::{
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
@ -63,6 +67,22 @@ pub trait Repository {
where
Self: 'c;
type CompatSessionRepository<'c>
where
Self: 'c;
type CompatSsoLoginRepository<'c>
where
Self: 'c;
type CompatAccessTokenRepository<'c>
where
Self: 'c;
type CompatRefreshTokenRepository<'c>
where
Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>;
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>;
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>;
@ -72,6 +92,10 @@ pub trait Repository {
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>;
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>;
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>;
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>;
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>;
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>;
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>;
}
impl Repository for PgConnection {
@ -84,6 +108,10 @@ impl Repository for PgConnection {
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self)
@ -120,6 +148,22 @@ impl Repository for PgConnection {
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self)
}
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
PgCompatSsoLoginRepository::new(self)
}
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
PgCompatAccessTokenRepository::new(self)
}
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
PgCompatRefreshTokenRepository::new(self)
}
}
impl<'t> Repository for Transaction<'t, Postgres> {
@ -132,6 +176,10 @@ impl<'t> Repository for Transaction<'t, Postgres> {
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c;
type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c;
type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c;
type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c;
type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c;
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
PgUpstreamOAuthLinkRepository::new(self)
@ -168,4 +216,20 @@ impl<'t> Repository for Transaction<'t, Postgres> {
fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> {
PgOAuth2SessionRepository::new(self)
}
fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> {
PgCompatSessionRepository::new(self)
}
fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> {
PgCompatSsoLoginRepository::new(self)
}
fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> {
PgCompatAccessTokenRepository::new(self)
}
fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> {
PgCompatRefreshTokenRepository::new(self)
}
}