diff --git a/crates/data-model/src/compat/sso_login.rs b/crates/data-model/src/compat/sso_login.rs index 54fd96b3..ccc7bb37 100644 --- a/crates/data-model/src/compat/sso_login.rs +++ b/crates/data-model/src/compat/sso_login.rs @@ -20,8 +20,9 @@ use url::Url; use super::CompatSession; use crate::InvalidTransitionError; -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] pub enum CompatSsoLoginState { + #[default] Pending, Fulfilled { fulfilled_at: DateTime, diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index 394639cf..3c94c672 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,7 +15,7 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{compat::lookup_compat_session, user::UserRepository, Repository}; +use mas_storage::{compat::CompatSessionRepository, user::UserRepository, Repository}; use sqlx::PgPool; use url::Url; @@ -101,7 +101,9 @@ impl CompatSsoLogin { let Some(session_id) = self.0.session_id() else { return Ok(None) }; let mut conn = ctx.data::()?.acquire().await?; - let session = lookup_compat_session(&mut conn, session_id) + let session = conn + .compat_session() + .lookup(session_id) .await? .context("Could not load compat session")?; diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 2f241ced..b19a1ae1 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -18,6 +18,7 @@ use async_graphql::{ }; use chrono::{DateTime, Utc}; use mas_storage::{ + compat::CompatSsoLoginRepository, oauth2::OAuth2SessionRepository, user::{BrowserSessionRepository, UserEmailRepository}, Repository, UpstreamOAuthLinkRepository, @@ -96,14 +97,13 @@ impl User { .map(|x: OpaqueCursor| x.extract_for_type(NodeType::CompatSsoLogin)) .transpose()?; - let (has_previous_page, has_next_page, edges) = - mas_storage::compat::get_paginated_user_compat_sso_logins( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = conn + .compat_sso_login() + .list_paginated(&self.0, before_id, after_id, first, last) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|u| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|u| { Edge::new( OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)), CompatSsoLogin(u), diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index f36d520b..e7376f72 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -18,8 +18,8 @@ use hyper::StatusCode; use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType, User}; use mas_storage::{ compat::{ - add_compat_access_token, add_compat_refresh_token, get_compat_sso_login_by_token, - lookup_compat_session, mark_compat_sso_login_as_exchanged, start_compat_session, + CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, + CompatSsoLoginRepository, }, user::{UserPasswordRepository, UserRepository}, Clock, Repository, @@ -224,27 +224,17 @@ pub(crate) async fn post( }; let access_token = TokenType::CompatAccessToken.generate(&mut rng); - let access_token = add_compat_access_token( - &mut txn, - &mut rng, - &clock, - &session, - access_token, - expires_in, - ) - .await?; + let access_token = txn + .compat_access_token() + .add(&mut rng, &clock, &session, access_token, expires_in) + .await?; let refresh_token = if input.refresh_token { let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng); - let refresh_token = add_compat_refresh_token( - &mut txn, - &mut rng, - &clock, - &session, - &access_token, - refresh_token, - ) - .await?; + let refresh_token = txn + .compat_refresh_token() + .add(&mut rng, &clock, &session, &access_token, refresh_token) + .await?; Some(refresh_token.token) } else { None @@ -266,7 +256,9 @@ async fn token_login( clock: &Clock, token: &str, ) -> Result<(CompatSession, User), RouteError> { - let login = get_compat_sso_login_by_token(&mut *txn, token) + let login = txn + .compat_sso_login() + .find_by_token(token) .await? .ok_or(RouteError::InvalidLoginToken)?; @@ -308,7 +300,9 @@ async fn token_login( } }; - let session = lookup_compat_session(&mut *txn, session_id) + let session = txn + .compat_session() + .lookup(session_id) .await? .ok_or(RouteError::SessionNotFound)?; @@ -318,7 +312,7 @@ async fn token_login( .await? .ok_or(RouteError::UserNotFound)?; - mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?; + txn.compat_sso_login().exchange(clock, login).await?; Ok((session, user)) } @@ -374,7 +368,10 @@ async fn user_password_login( // Now that the user credentials have been verified, start a new compat session let device = Device::generate(&mut rng); - let session = start_compat_session(&mut *txn, &mut rng, &clock, &user, device).await?; + let session = txn + .compat_session() + .add(&mut rng, &clock, &user, device) + .await?; Ok((session, user)) } diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index f31856d6..33352424 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -29,7 +29,10 @@ use mas_axum_utils::{ use mas_data_model::Device; use mas_keystore::Encrypter; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; -use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id}; +use mas_storage::{ + compat::{CompatSessionRepository, CompatSsoLoginRepository}, + Repository, +}; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; use sqlx::PgPool; @@ -87,7 +90,9 @@ pub async fn get( return Ok((cookie_jar, destination.go()).into_response()); } - let login = get_compat_sso_login_by_id(&mut conn, id) + let login = conn + .compat_sso_login() + .lookup(id) .await? .context("Could not find compat SSO login")?; @@ -149,7 +154,9 @@ pub async fn post( return Ok((cookie_jar, destination.go()).into_response()); } - let login = get_compat_sso_login_by_id(&mut txn, id) + let login = txn + .compat_sso_login() + .lookup(id) .await? .context("Could not find compat SSO login")?; @@ -181,8 +188,14 @@ pub async fn post( }; let device = Device::generate(&mut rng); - let _login = - fullfill_compat_sso_login(&mut txn, &mut rng, &clock, &session.user, login, device).await?; + let compat_session = txn + .compat_session() + .add(&mut rng, &clock, &session.user, device) + .await?; + + txn.compat_sso_login() + .fulfill(&clock, login, &compat_session) + .await?; txn.commit().await?; diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index f90862c7..9c23b733 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -19,7 +19,7 @@ use axum::{ }; use hyper::StatusCode; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; -use mas_storage::compat::insert_compat_sso_login; +use mas_storage::{compat::CompatSsoLoginRepository, Repository}; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; use serde_with::serde; @@ -49,6 +49,7 @@ pub enum RouteError { } impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -80,7 +81,10 @@ pub async fn get( let token = Alphanumeric.sample_string(&mut rng, 32); let mut conn = pool.acquire().await?; - let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?; + let login = conn + .compat_sso_login() + .add(&mut rng, &clock, token, redirect_url) + .await?; Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.id, params.action))) } diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index e16c8c98..25125c72 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -17,8 +17,8 @@ use headers::{authorization::Bearer, Authorization}; use hyper::StatusCode; use mas_data_model::TokenType; use mas_storage::{ - compat::{end_compat_session, find_compat_access_token, lookup_compat_session}, - Clock, + compat::{CompatAccessTokenRepository, CompatSessionRepository}, + Clock, Repository, }; use sqlx::PgPool; use thiserror::Error; @@ -83,17 +83,21 @@ pub(crate) async fn post( return Err(RouteError::InvalidAuthorization); } - let token = find_compat_access_token(&mut txn, token) + let token = txn + .compat_access_token() + .find_by_token(token) .await? .filter(|t| t.is_valid(clock.now())) .ok_or(RouteError::InvalidAuthorization)?; - let session = lookup_compat_session(&mut txn, token.session_id) + let session = txn + .compat_session() + .lookup(token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::InvalidAuthorization)?; - end_compat_session(&mut txn, &clock, session).await?; + txn.compat_session().finish(&clock, session).await?; txn.commit().await?; diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 58e9eb8e..7bfc940a 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -16,10 +16,9 @@ use axum::{extract::State, response::IntoResponse, Json}; use chrono::Duration; use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; -use mas_storage::compat::{ - add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token, - expire_compat_access_token, find_compat_refresh_token, lookup_compat_access_token, - lookup_compat_session, +use mas_storage::{ + compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, + Repository, }; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; @@ -101,7 +100,9 @@ pub(crate) async fn post( return Err(RouteError::InvalidToken); } - let refresh_token = find_compat_refresh_token(&mut txn, &input.refresh_token) + let refresh_token = txn + .compat_refresh_token() + .find_by_token(&input.refresh_token) .await? .ok_or(RouteError::InvalidToken)?; @@ -109,7 +110,9 @@ pub(crate) async fn post( return Err(RouteError::RefreshTokenConsumed); } - let session = lookup_compat_session(&mut txn, refresh_token.session_id) + let session = txn + .compat_session() + .lookup(refresh_token.session_id) .await? .ok_or(RouteError::UnknownSession)?; @@ -117,7 +120,9 @@ pub(crate) async fn post( return Err(RouteError::InvalidSession); } - let access_token = lookup_compat_access_token(&mut txn, refresh_token.access_token_id) + let access_token = txn + .compat_access_token() + .lookup(refresh_token.access_token_id) .await? .filter(|t| t.is_valid(clock.now())); @@ -125,29 +130,35 @@ pub(crate) async fn post( let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng); let expires_in = Duration::minutes(5); - let new_access_token = add_compat_access_token( - &mut txn, - &mut rng, - &clock, - &session, - new_access_token_str, - Some(expires_in), - ) - .await?; - let new_refresh_token = add_compat_refresh_token( - &mut txn, - &mut rng, - &clock, - &session, - &new_access_token, - new_refresh_token_str, - ) - .await?; + let new_access_token = txn + .compat_access_token() + .add( + &mut rng, + &clock, + &session, + new_access_token_str, + Some(expires_in), + ) + .await?; + let new_refresh_token = txn + .compat_refresh_token() + .add( + &mut rng, + &clock, + &session, + &new_access_token, + new_refresh_token_str, + ) + .await?; - consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?; + txn.compat_refresh_token() + .consume(&clock, refresh_token) + .await?; if let Some(access_token) = access_token { - expire_compat_access_token(&mut txn, &clock, access_token).await?; + txn.compat_access_token() + .expire(&clock, access_token) + .await?; } txn.commit().await?; diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 3dec02db..ef6ba5b2 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -22,7 +22,7 @@ use mas_data_model::{TokenFormatError, TokenType}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_keystore::Encrypter; use mas_storage::{ - compat::{find_compat_access_token, find_compat_refresh_token, lookup_compat_session}, + compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, oauth2::{ access_token::find_access_token, refresh_token::lookup_refresh_token, OAuth2SessionRepository, @@ -243,12 +243,16 @@ pub(crate) async fn post( } TokenType::CompatAccessToken => { - let token = find_compat_access_token(&mut conn, token) + let access_token = conn + .compat_access_token() + .find_by_token(token) .await? .filter(|t| t.is_valid(clock.now())) .ok_or(RouteError::UnknownToken)?; - let session = lookup_compat_session(&mut conn, token.session_id) + let session = conn + .compat_session() + .lookup(access_token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::UnknownToken)?; @@ -269,9 +273,9 @@ pub(crate) async fn post( client_id: Some("legacy".into()), username: Some(user.username), token_type: Some(OAuthTokenTypeHint::AccessToken), - exp: token.expires_at, - iat: Some(token.created_at), - nbf: Some(token.created_at), + exp: access_token.expires_at, + iat: Some(access_token.created_at), + nbf: Some(access_token.created_at), sub: Some(user.sub), aud: None, iss: None, @@ -280,12 +284,16 @@ pub(crate) async fn post( } TokenType::CompatRefreshToken => { - let refresh_token = find_compat_refresh_token(&mut conn, token) + let refresh_token = conn + .compat_refresh_token() + .find_by_token(token) .await? .filter(|t| t.is_valid()) .ok_or(RouteError::UnknownToken)?; - let session = lookup_compat_session(&mut conn, refresh_token.session_id) + let session = conn + .compat_session() + .lookup(refresh_token.session_id) .await? .filter(|s| s.is_valid()) .ok_or(RouteError::UnknownToken)?; diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index 6035c74d..3872588f 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -15,7 +15,7 @@ use anyhow::Context; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, + compat::CompatSsoLoginRepository, oauth2::authorization_grant::get_grant_by_id, upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; @@ -54,7 +54,9 @@ impl OptionalPostAuthAction { } PostAuthAction::ContinueCompatSsoLogin { id } => { - let login = get_compat_sso_login_by_id(conn, id) + let login = conn + .compat_sso_login() + .lookup(id) .await? .context("Failed to load compat SSO login")?; let login = Box::new(login); diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 5324fa2a..6e7082dd 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -98,6 +98,21 @@ }, "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n " }, + "18c3e56c72ef26bd42653c379767ffdd97bb06cb1686dfbf4099f3ad3d7b22c8": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " + }, "1d372f36c382ab16264cea54537af3544ea6d6d75d10b432b07dbd0dadd2fa4e": { "describe": { "columns": [ @@ -168,22 +183,6 @@ }, "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n " }, - "2153118b364a33582e7f598acce3789fcb8d938948a819b15cf0b6d37edf58b2": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_access_tokens\n (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " - }, "262bee715889dc3e608639549600a131e641951ff979634e7c97afc74bbc1605": { "describe": { "columns": [], @@ -197,79 +196,6 @@ }, "query": "\n UPDATE oauth2_authorization_grants\n SET exchanged_at = $2\n WHERE oauth2_authorization_grant_id = $1\n " }, - "2e756fe7be50128c0acc5f79df3a084230e9ca13cd45bd0858f97e59da20006e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n " - }, - "360466ff599c67c9af2ac75399c0b536a22c1178972a0172b707bcc81d47357b": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_refresh_tokens\n (compat_refresh_token_id, compat_session_id,\n compat_access_token_id, refresh_token, created_at)\n VALUES ($1, $2, $3, $4, $5)\n " - }, - "3cf8e061206620071b39d0262cd165bb367b12b8e904180730d8acfa5af3d4b9": { - "describe": { - "columns": [ - { - "name": "compat_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "device_id", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "user_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "created_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "finished_at", - "ordinal": 4, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT compat_session_id\n , device_id\n , user_id\n , created_at\n , finished_at\n FROM compat_sessions\n WHERE compat_session_id = $1\n " - }, "3d66f3121b11ce923b9c60609b510a8ca899640e78cc8f5b03168622928ffe94": { "describe": { "columns": [], @@ -384,6 +310,56 @@ }, "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n " }, + "432e199b0d47fe299d840c91159726c0a4f89f65b4dc3e33ddad58aabf6b148b": { + "describe": { + "columns": [ + { + "name": "compat_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "compat_access_token_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE refresh_token = $1\n " + }, "43a5cafbdc8037e9fb779812a0793cf0859902aa0dc8d25d4c33d231d3d1118b": { "describe": { "columns": [], @@ -465,20 +441,7 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.oauth2_session_id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.oauth2_authorization_grant_id = $1\n AND os.oauth2_session_id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime\"\n " }, - "4c4dbb846bb98d84f6b7f886f8af9833c7efe27b8b4f297077887232bef322ee": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n WHERE compat_session_id = $1\n " - }, - "4f080990eb6dd9f6128f3a1aee195b99d5f286fa0f6c27d744f73848343879d4": { + "478f0ad710da8bfd803c6cddd982bc504d1b6bd0f5283de53c8c7b1b4b7dafd4": { "describe": { "columns": [ { @@ -487,27 +450,27 @@ "type_info": "Uuid" }, { - "name": "compat_sso_login_token", + "name": "login_token", "ordinal": 1, "type_info": "Text" }, { - "name": "compat_sso_login_redirect_uri", + "name": "redirect_uri", "ordinal": 2, "type_info": "Text" }, { - "name": "compat_sso_login_created_at", + "name": "created_at", "ordinal": 3, "type_info": "Timestamptz" }, { - "name": "compat_sso_login_fulfilled_at", + "name": "fulfilled_at", "ordinal": 4, "type_info": "Timestamptz" }, { - "name": "compat_sso_login_exchanged_at", + "name": "exchanged_at", "ordinal": 5, "type_info": "Timestamptz" }, @@ -528,11 +491,25 @@ ], "parameters": { "Left": [ - "Uuid" + "Text" ] } }, - "query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cl.compat_session_id AS \"compat_session_id\"\n\n FROM compat_sso_logins cl\n WHERE cl.compat_sso_login_id = $1\n " + "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE login_token = $1\n " + }, + "4d79ce892e4595edb8b801e94fb0cbef28facdfd2e45d1c72c57f47418fbe24b": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_sso_logins\n SET\n compat_session_id = $2,\n fulfilled_at = $3\n WHERE\n compat_sso_login_id = $1\n " }, "51158bfcaa1a8d8e051bffe7c5ba0369bf53fb162f7622626054e89e68fc07bd": { "describe": { @@ -555,6 +532,50 @@ }, "query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n " }, + "53ad718642644b47a2d49f768d81bd993088526923769a9147281686c2d47591": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE access_token = $1\n " + }, "583ae9a0db9cd55fa57a179339550f3dab1bfc76f35ad488e1560ea37f7ed029": { "describe": { "columns": [], @@ -598,20 +619,6 @@ }, "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " }, - "60d039442cfa57e187602c0ff5e386e32fb774b5ad2d2f2c616040819b76873e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_sso_logins\n SET\n compat_session_id = $2,\n fulfilled_at = $3\n WHERE\n compat_sso_login_id = $1\n " - }, "62d05e8e4317bdb180298737d422e64d161c5ad3813913a6f7d67a53569ea76a": { "describe": { "columns": [], @@ -745,6 +752,21 @@ }, "query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n " }, + "6e21e7d816f806da9bb5176931bdb550dee05c44c9d93f53df95fe3b4a840347": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " + }, "6f97b5f9ad0d4d15387150bea3839fb7f81015f7ceef61ecaadba64521895cff": { "describe": { "columns": [], @@ -782,6 +804,50 @@ }, "query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n " }, + "77dfa9fae1a9c77b70476d7da19d3313a02886994cfff0690451229fb5ae2f77": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE compat_access_token_id = $1\n " + }, "79295f3d3a75f831e9469aabfa720d381a254d00dbe39fef1e9652029d51b89b": { "describe": { "columns": [ @@ -871,19 +937,6 @@ }, "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " }, - "7e3247e35ecf5335f0656c53bcde27264a9efb8dccb6246344950614f487dcaf": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n " - }, "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c": { "describe": { "columns": [ @@ -1154,6 +1207,19 @@ }, "query": "\n UPDATE user_email_confirmation_codes\n SET consumed_at = $2\n WHERE user_email_confirmation_code_id = $1\n " }, + "9348d87f9e06b614c7e90bdc93bcf38236766aaf4d894bf768debdff2b59fae2": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n " + }, "94fd96446b237c87bd6bf741f3c42b37ee751b87b7fcc459602bdf8c46962443": { "describe": { "columns": [ @@ -1174,18 +1240,21 @@ }, "query": "\n SELECT EXISTS(\n SELECT 1 FROM users WHERE username = $1\n ) AS \"exists!\"\n " }, - "99f5f9eb0adc5ec120ed8194cbf6a8545155bef09e6d94d92fb67fd1b14d4f28": { + "9f7bdc034c618e47e49c467d0d7f5b8c297d055abe248cc876dbc12c5a7dc920": { "describe": { "columns": [], "nullable": [], "parameters": { "Left": [ "Uuid", + "Uuid", + "Uuid", + "Text", "Timestamptz" ] } }, - "query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_refresh_token_id = $1\n " + "query": "\n INSERT INTO compat_refresh_tokens\n (compat_refresh_token_id, compat_session_id,\n compat_access_token_id, refresh_token, created_at)\n VALUES ($1, $2, $3, $4, $5)\n " }, "a300fe99c95679c5664646a6a525c0491829e97db45f3234483872ed38436322": { "describe": { @@ -1243,6 +1312,22 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n requires_consent = 'f'\n WHERE\n og.oauth2_authorization_grant_id = $1\n " }, + "a7f780528882a2ae66c45435215763eed0582264861436eab3f862e3eb12cab1": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO compat_access_tokens\n (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)\n VALUES ($1, $2, $3, $4, $5)\n " + }, "aa2fd69c595f94d8598715766a79671dba8f87b9d7af6ac30e3fa1fbc8cce28a": { "describe": { "columns": [ @@ -1371,6 +1456,19 @@ }, "query": "\n SELECT oauth2_authorization_grant_id\n , created_at AS oauth2_authorization_grant_created_at\n , cancelled_at AS oauth2_authorization_grant_cancelled_at\n , fulfilled_at AS oauth2_authorization_grant_fulfilled_at\n , exchanged_at AS oauth2_authorization_grant_exchanged_at\n , scope AS oauth2_authorization_grant_scope\n , state AS oauth2_authorization_grant_state\n , redirect_uri AS oauth2_authorization_grant_redirect_uri\n , response_mode AS oauth2_authorization_grant_response_mode\n , nonce AS oauth2_authorization_grant_nonce\n , max_age AS oauth2_authorization_grant_max_age\n , oauth2_client_id AS oauth2_client_id\n , authorization_code AS oauth2_authorization_grant_code\n , response_type_code AS oauth2_authorization_grant_response_type_code\n , response_type_id_token AS oauth2_authorization_grant_response_type_id_token\n , code_challenge AS oauth2_authorization_grant_code_challenge\n , code_challenge_method AS oauth2_authorization_grant_code_challenge_method\n , requires_consent AS oauth2_authorization_grant_requires_consent\n , oauth2_session_id AS \"oauth2_session_id?\"\n FROM\n oauth2_authorization_grants\n\n WHERE authorization_code = $1\n " }, + "ab34912b42a48a8b5c8d63e271b99b7d0b690a2471873c6654b1b6cf2079b95c": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_sessions cs\n SET finished_at = $2\n WHERE compat_session_id = $1\n " + }, "aff08a8caabeb62f4929e6e901e7ca7c55e284c18c5c1d1e78821dd9bc961412": { "describe": { "columns": [ @@ -1652,6 +1750,19 @@ }, "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n " }, + "bbf62633c561706a762089bbab2f76a9ba3e2ed3539ef16accb601fb609c2ec9": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n " + }, "bd1f6daa5fa1b10250c01f8b3fbe451646a9ceeefa6f72b9c4e29b6d05f17641": { "describe": { "columns": [], @@ -1696,106 +1807,6 @@ }, "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " }, - "c3e60701299be7728108b8967ec5396fb186adaac360d6a0152d25e4a4f46f46": { - "describe": { - "columns": [ - { - "name": "compat_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 4, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - true, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE compat_access_token_id = $1\n " - }, - "c78246fc8737491352f71ea9410e79df8de88596c8197405cda36eb8c8187810": { - "describe": { - "columns": [ - { - "name": "compat_sso_login_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_sso_login_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_sso_login_redirect_uri", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "compat_sso_login_created_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_fulfilled_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "compat_sso_login_exchanged_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 6, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT cl.compat_sso_login_id\n , cl.login_token AS \"compat_sso_login_token\"\n , cl.redirect_uri AS \"compat_sso_login_redirect_uri\"\n , cl.created_at AS \"compat_sso_login_created_at\"\n , cl.fulfilled_at AS \"compat_sso_login_fulfilled_at\"\n , cl.exchanged_at AS \"compat_sso_login_exchanged_at\"\n , cl.compat_session_id AS \"compat_session_id\"\n FROM compat_sso_logins cl\n WHERE cl.login_token = $1\n " - }, "c88376abdba124ff0487a9a69d2345c7d69d7394f355111ec369cfa6d45fb40f": { "describe": { "columns": [], @@ -1822,114 +1833,18 @@ }, "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n " }, - "ca63558e877bd115aa7ca24de0cc3f78a13cb55105758fe0bd930da513f75504": { - "describe": { - "columns": [ - { - "name": "compat_refresh_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "consumed_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "compat_access_token_id", - "ordinal": 5, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE refresh_token = $1\n " - }, - "caf54e4659306a746747aa61906bdb2cb8da51176e90435aa8b9754ebf3e4d60": { + "d0b403e9c843ef19fa5ad60bec32ebf14a1ba0d01681c3836366d3f55e7851f4": { "describe": { "columns": [], "nullable": [], "parameters": { "Left": [ "Uuid", - "Uuid", - "Text", "Timestamptz" ] } }, - "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " - }, - "cf43b82bdf534400f900cff3c5356083db0f9e5407e288b64f43d7ac100de058": { - "describe": { - "columns": [ - { - "name": "compat_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 4, - "type_info": "Uuid" - } - ], - "nullable": [ - false, - false, - false, - true, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT compat_access_token_id\n , access_token\n , created_at\n , expires_at\n , compat_session_id\n\n FROM compat_access_tokens\n\n WHERE access_token = $1\n " + "query": "\n UPDATE compat_refresh_tokens\n SET consumed_at = $2\n WHERE compat_refresh_token_id = $1\n " }, "d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c": { "describe": { @@ -1951,21 +1866,6 @@ }, "query": "\n SELECT COUNT(*)\n FROM user_emails\n WHERE user_id = $1\n " }, - "d1738c27339b81f0844da4bd9b040b9b07a91aa4d9b199b98f24c9cee5709b2b": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO compat_sso_logins\n (compat_sso_login_id, login_token, redirect_uri, created_at)\n VALUES ($1, $2, $3, $4)\n " - }, "d1f1aac41bb2f0d194f9b3d846663c267865d0d22dd5fa8a668daf29dca88d36": { "describe": { "columns": [ @@ -2211,6 +2111,112 @@ }, "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " }, + "ddb22dd9ae9367af65a607e1fdc48b3d9581d67deea0c168f24e02090082bb82": { + "describe": { + "columns": [ + { + "name": "compat_sso_login_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "login_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "redirect_uri", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "fulfilled_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "exchanged_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 6, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE compat_sso_login_id = $1\n " + }, + "e35d56de7136d43d0803ec825b0612e4185cef838f105d66f18cb24865e45140": { + "describe": { + "columns": [ + { + "name": "compat_refresh_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "compat_access_token_id", + "ordinal": 5, + "type_info": "Uuid" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , compat_session_id\n , compat_access_token_id\n\n FROM compat_refresh_tokens\n\n WHERE compat_refresh_token_id = $1\n " + }, "e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": { "describe": { "columns": [ @@ -2306,6 +2312,50 @@ }, "query": "\n SELECT oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n , finished_at\n FROM oauth2_sessions\n\n WHERE oauth2_session_id = $1\n " }, + "f3ee06958d827b152c57328caa0a6030c372cb99cdb60e4b75a28afeb5096f45": { + "describe": { + "columns": [ + { + "name": "compat_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "device_id", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "finished_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT compat_session_id\n , device_id\n , user_id\n , created_at\n , finished_at\n FROM compat_sessions\n WHERE compat_session_id = $1\n " + }, "f5edcd4c306ca8179cdf9d4aab59fbba971b54611c91345849920954dd8089b3": { "describe": { "columns": [], diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs deleted file mode 100644 index 3befa8dd..00000000 --- a/crates/storage/src/compat.rs +++ /dev/null @@ -1,757 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{ - CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, - CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, User, -}; -use rand::Rng; -use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder}; -use tracing::{info_span, Instrument}; -use ulid::Ulid; -use url::Url; -use uuid::Uuid; - -use crate::{ - pagination::{process_page, QueryBuilderExt}, - Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, -}; - -struct CompatSessionLookup { - compat_session_id: Uuid, - device_id: String, - user_id: Uuid, - created_at: DateTime, - finished_at: Option>, -} - -#[tracing::instrument(skip_all, err)] -pub async fn lookup_compat_session( - executor: impl PgExecutor<'_>, - session_id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatSessionLookup, - r#" - SELECT compat_session_id - , device_id - , user_id - , created_at - , finished_at - FROM compat_sessions - WHERE compat_session_id = $1 - "#, - Uuid::from(session_id), - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - let id = res.compat_session_id.into(); - let device = Device::try_from(res.device_id).map_err(|e| { - DatabaseInconsistencyError::on("compat_sessions") - .column("device_id") - .row(id) - .source(e) - })?; - - let state = match res.finished_at { - None => CompatSessionState::Valid, - Some(finished_at) => CompatSessionState::Finished { finished_at }, - }; - - let session = CompatSession { - id, - state, - user_id: res.user_id.into(), - device, - created_at: res.created_at, - }; - - Ok(Some(session)) -} - -struct CompatAccessTokenLookup { - compat_access_token_id: Uuid, - access_token: String, - created_at: DateTime, - expires_at: Option>, - compat_session_id: Uuid, -} - -impl From for CompatAccessToken { - fn from(value: CompatAccessTokenLookup) -> Self { - Self { - id: value.compat_access_token_id.into(), - session_id: value.compat_session_id.into(), - token: value.access_token, - created_at: value.created_at, - expires_at: value.expires_at, - } - } -} - -#[tracing::instrument(skip_all, err)] -pub async fn find_compat_access_token( - executor: impl PgExecutor<'_>, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatAccessTokenLookup, - r#" - SELECT compat_access_token_id - , access_token - , created_at - , expires_at - , compat_session_id - - FROM compat_access_tokens - - WHERE access_token = $1 - "#, - token, - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) -} - -#[tracing::instrument( - skip_all, - fields( - compat_access_token.id = %id, - ), - err, -)] -pub async fn lookup_compat_access_token( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatAccessTokenLookup, - r#" - SELECT compat_access_token_id - , access_token - , created_at - , expires_at - , compat_session_id - - FROM compat_access_tokens - - WHERE compat_access_token_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.into())) -} - -pub struct CompatRefreshTokenLookup { - compat_refresh_token_id: Uuid, - refresh_token: String, - created_at: DateTime, - consumed_at: Option>, - compat_access_token_id: Uuid, - compat_session_id: Uuid, -} - -#[tracing::instrument(skip_all, err)] -#[allow(clippy::type_complexity)] -pub async fn find_compat_refresh_token( - executor: impl PgExecutor<'_>, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatRefreshTokenLookup, - r#" - SELECT compat_refresh_token_id - , refresh_token - , created_at - , consumed_at - , compat_session_id - , compat_access_token_id - - FROM compat_refresh_tokens - - WHERE refresh_token = $1 - "#, - token, - ) - .fetch_one(executor) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None); }; - - let state = match res.consumed_at { - None => CompatRefreshTokenState::Valid, - Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at }, - }; - - let refresh_token = CompatRefreshToken { - id: res.compat_refresh_token_id.into(), - state, - session_id: res.compat_session_id.into(), - access_token_id: res.compat_access_token_id.into(), - token: res.refresh_token, - created_at: res.created_at, - }; - - Ok(Some(refresh_token)) -} - -#[tracing::instrument( - skip_all, - fields( - compat_session.id = %session.id, - compat_session.device.id = session.device.as_str(), - compat_access_token.id, - user.id = %session.user_id, - ), - err, -)] -pub async fn add_compat_access_token( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - session: &CompatSession, - token: String, - expires_after: Option, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); - - let expires_at = expires_after.map(|expires_after| created_at + expires_after); - - sqlx::query!( - r#" - INSERT INTO compat_access_tokens - (compat_access_token_id, compat_session_id, access_token, created_at, expires_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - token, - created_at, - expires_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Insert compat access token")) - .await?; - - Ok(CompatAccessToken { - id, - session_id: session.id, - token, - created_at, - expires_at, - }) -} - -#[tracing::instrument( - skip_all, - fields( - compat_access_token.id = %access_token.id, - ), - err, -)] -pub async fn expire_compat_access_token( - executor: impl PgExecutor<'_>, - clock: &Clock, - access_token: CompatAccessToken, -) -> Result<(), DatabaseError> { - let expires_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE compat_access_tokens - SET expires_at = $2 - WHERE compat_access_token_id = $1 - "#, - Uuid::from(access_token.id), - expires_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} - -#[tracing::instrument( - skip_all, - fields( - compat_session.id = %session.id, - compat_session.device.id = session.device.as_str(), - compat_access_token.id = %access_token.id, - compat_refresh_token.id, - user.id = %session.user_id, - ), - err, -)] -pub async fn add_compat_refresh_token( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - session: &CompatSession, - access_token: &CompatAccessToken, - token: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_refresh_tokens - (compat_refresh_token_id, compat_session_id, - compat_access_token_id, refresh_token, created_at) - VALUES ($1, $2, $3, $4, $5) - "#, - Uuid::from(id), - Uuid::from(session.id), - Uuid::from(access_token.id), - token, - created_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Insert compat refresh token")) - .await?; - - Ok(CompatRefreshToken { - id, - state: CompatRefreshTokenState::default(), - session_id: session.id, - access_token_id: access_token.id, - token, - created_at, - }) -} - -#[tracing::instrument( - skip_all, - fields(%compat_session.id), - err, -)] -pub async fn end_compat_session( - executor: impl PgExecutor<'_>, - clock: &Clock, - compat_session: CompatSession, -) -> Result { - let finished_at = clock.now(); - - let res = sqlx::query!( - r#" - UPDATE compat_sessions cs - SET finished_at = $2 - WHERE compat_session_id = $1 - "#, - Uuid::from(compat_session.id), - finished_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1)?; - - let compat_session = compat_session - .finish(finished_at) - .map_err(DatabaseError::to_invalid_operation)?; - - Ok(compat_session) -} - -#[tracing::instrument( - skip_all, - fields( - compat_refresh_token.id = %refresh_token.id, - ), - err, -)] -pub async fn consume_compat_refresh_token( - executor: impl PgExecutor<'_>, - clock: &Clock, - refresh_token: CompatRefreshToken, -) -> Result<(), DatabaseError> { - let consumed_at = clock.now(); - let res = sqlx::query!( - r#" - UPDATE compat_refresh_tokens - SET consumed_at = $2 - WHERE compat_refresh_token_id = $1 - "#, - Uuid::from(refresh_token.id), - consumed_at, - ) - .execute(executor) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) -} - -#[tracing::instrument( - skip_all, - fields( - compat_sso_login.id, - compat_sso_login.redirect_uri = %redirect_uri, - ), - err, -)] -pub async fn insert_compat_sso_login( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - login_token: String, - redirect_uri: Url, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_sso_logins - (compat_sso_login_id, login_token, redirect_uri, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - &login_token, - redirect_uri.as_str(), - created_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Insert compat SSO login")) - .await?; - - Ok(CompatSsoLogin { - id, - login_token, - redirect_uri, - created_at, - state: CompatSsoLoginState::Pending, - }) -} - -#[derive(sqlx::FromRow)] -struct CompatSsoLoginLookup { - compat_sso_login_id: Uuid, - compat_sso_login_token: String, - compat_sso_login_redirect_uri: String, - compat_sso_login_created_at: DateTime, - compat_sso_login_fulfilled_at: Option>, - compat_sso_login_exchanged_at: Option>, - compat_session_id: Option, -} - -impl TryFrom for CompatSsoLogin { - type Error = DatabaseInconsistencyError; - - fn try_from(res: CompatSsoLoginLookup) -> Result { - let id = res.compat_sso_login_id.into(); - let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri).map_err(|e| { - DatabaseInconsistencyError::on("compat_sso_logins") - .column("redirect_uri") - .row(id) - .source(e) - })?; - - let state = match ( - res.compat_sso_login_fulfilled_at, - res.compat_sso_login_exchanged_at, - res.compat_session_id, - ) { - (None, None, None) => CompatSsoLoginState::Pending, - (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled { - fulfilled_at, - session_id: session_id.into(), - }, - (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => { - CompatSsoLoginState::Exchanged { - fulfilled_at, - exchanged_at, - session_id: session_id.into(), - } - } - _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), - }; - - Ok(CompatSsoLogin { - id, - login_token: res.compat_sso_login_token, - redirect_uri, - created_at: res.compat_sso_login_created_at, - state, - }) - } -} - -#[tracing::instrument( - skip_all, - fields( - compat_sso_login.id = %id, - ), - err, -)] -pub async fn get_compat_sso_login_by_id( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatSsoLoginLookup, - r#" - SELECT cl.compat_sso_login_id - , cl.login_token AS "compat_sso_login_token" - , cl.redirect_uri AS "compat_sso_login_redirect_uri" - , cl.created_at AS "compat_sso_login_created_at" - , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" - , cl.exchanged_at AS "compat_sso_login_exchanged_at" - , cl.compat_session_id AS "compat_session_id" - - FROM compat_sso_logins cl - WHERE cl.compat_sso_login_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .instrument(tracing::info_span!("Lookup compat SSO login")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %user.username, - ), - err, -)] -pub async fn get_paginated_user_compat_sso_logins( - executor: impl PgExecutor<'_>, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT cl.compat_sso_login_id - , cl.login_token AS "compat_sso_login_token" - , cl.redirect_uri AS "compat_sso_login_redirect_uri" - , cl.created_at AS "compat_sso_login_created_at" - , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" - , cl.exchanged_at AS "compat_sso_login_exchanged_at" - , cl.compat_session_id AS "compat_session_id" - FROM compat_sso_logins cl - "#, - ); - - query - .push(" WHERE cs.user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("cl.compat_sso_login_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated user compat SSO logins", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) - .await?; - - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; - - let page: Result, _> = page.into_iter().map(TryInto::try_into).collect(); - Ok((has_previous_page, has_next_page, page?)) -} - -#[tracing::instrument(skip_all, err)] -pub async fn get_compat_sso_login_by_token( - executor: impl PgExecutor<'_>, - token: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - CompatSsoLoginLookup, - r#" - SELECT cl.compat_sso_login_id - , cl.login_token AS "compat_sso_login_token" - , cl.redirect_uri AS "compat_sso_login_redirect_uri" - , cl.created_at AS "compat_sso_login_created_at" - , cl.fulfilled_at AS "compat_sso_login_fulfilled_at" - , cl.exchanged_at AS "compat_sso_login_exchanged_at" - , cl.compat_session_id AS "compat_session_id" - FROM compat_sso_logins cl - WHERE cl.login_token = $1 - "#, - token, - ) - .fetch_one(executor) - .instrument(tracing::info_span!("Lookup compat SSO login")) - .await - .to_option()?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - compat_session.id, - compat_session.device.id = device.as_str(), - ), - err, -)] -pub async fn start_compat_session( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - user: &User, - device: Device, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("compat_session.id", tracing::field::display(id)); - - sqlx::query!( - r#" - INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at) - VALUES ($1, $2, $3, $4) - "#, - Uuid::from(id), - Uuid::from(user.id), - device.as_str(), - created_at, - ) - .execute(executor) - .await?; - - Ok(CompatSession { - id, - state: CompatSessionState::default(), - user_id: user.id, - device, - created_at, - }) -} - -#[tracing::instrument( - skip_all, - fields( - %user.id, - %compat_sso_login.id, - %compat_sso_login.redirect_uri, - compat_session.id, - compat_session.device.id = device.as_str(), - ), - err, -)] -pub async fn fullfill_compat_sso_login( - conn: impl Acquire<'_, Database = Postgres> + Send, - mut rng: impl Rng + Send, - clock: &Clock, - user: &User, - compat_sso_login: CompatSsoLogin, - device: Device, -) -> Result { - if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) { - return Err(DatabaseError::invalid_operation()); - }; - - let mut txn = conn.begin().await?; - - let session = start_compat_session(&mut txn, &mut rng, clock, user, device).await?; - let session_id = session.id; - - let fulfilled_at = clock.now(); - let compat_sso_login = compat_sso_login - .fulfill(fulfilled_at, &session) - .map_err(DatabaseError::to_invalid_operation)?; - sqlx::query!( - r#" - UPDATE compat_sso_logins - SET - compat_session_id = $2, - fulfilled_at = $3 - WHERE - compat_sso_login_id = $1 - "#, - Uuid::from(compat_sso_login.id), - Uuid::from(session_id), - fulfilled_at, - ) - .execute(&mut txn) - .instrument(tracing::info_span!("Update compat SSO login")) - .await?; - - txn.commit().await?; - - Ok(compat_sso_login) -} - -#[tracing::instrument( - skip_all, - fields( - %compat_sso_login.id, - %compat_sso_login.redirect_uri, - ), - err, -)] -pub async fn mark_compat_sso_login_as_exchanged( - executor: impl PgExecutor<'_>, - clock: &Clock, - compat_sso_login: CompatSsoLogin, -) -> Result { - let exchanged_at = clock.now(); - let compat_sso_login = compat_sso_login - .exchange(exchanged_at) - .map_err(DatabaseError::to_invalid_operation)?; - - sqlx::query!( - r#" - UPDATE compat_sso_logins - SET - exchanged_at = $2 - WHERE - compat_sso_login_id = $1 - "#, - Uuid::from(compat_sso_login.id), - exchanged_at, - ) - .execute(executor) - .instrument(tracing::info_span!("Update compat SSO login")) - .await?; - - Ok(compat_sso_login) -} diff --git a/crates/storage/src/compat/access_token.rs b/crates/storage/src/compat/access_token.rs new file mode 100644 index 00000000..86d2dd19 --- /dev/null +++ b/crates/storage/src/compat/access_token.rs @@ -0,0 +1,246 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; +use mas_data_model::{CompatAccessToken, CompatSession}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; + +#[async_trait] +pub trait CompatAccessTokenRepository: Send + Sync { + type Error; + + /// Lookup a compat access token by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find a compat access token by its token + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error>; + + /// Add a new compat access token to the database + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + token: String, + expires_after: Option, + ) -> Result; + + /// Set the expiration time of the compat access token to now + async fn expire( + &mut self, + clock: &Clock, + compat_access_token: CompatAccessToken, + ) -> Result; +} + +pub struct PgCompatAccessTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatAccessTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatAccessTokenLookup { + compat_access_token_id: Uuid, + access_token: String, + created_at: DateTime, + expires_at: Option>, + compat_session_id: Uuid, +} + +impl From for CompatAccessToken { + fn from(value: CompatAccessTokenLookup) -> Self { + Self { + id: value.compat_access_token_id.into(), + session_id: value.compat_session_id.into(), + token: value.access_token, + created_at: value.created_at, + expires_at: value.expires_at, + } + } +} + +#[async_trait] +impl<'c> CompatAccessTokenRepository for PgCompatAccessTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_access_token.lookup", + skip_all, + fields( + db.statement, + compat_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatAccessTokenLookup, + r#" + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id + + FROM compat_access_tokens + + WHERE compat_access_token_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_access_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + access_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatAccessTokenLookup, + r#" + SELECT compat_access_token_id + , access_token + , created_at + , expires_at + , compat_session_id + + FROM compat_access_tokens + + WHERE access_token = $1 + "#, + access_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_access_token.add", + skip_all, + fields( + db.statement, + compat_access_token.id, + %compat_session.id, + user.id = %compat_session.user_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + token: String, + expires_after: Option, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); + + let expires_at = expires_after.map(|expires_after| created_at + expires_after); + + sqlx::query!( + r#" + INSERT INTO compat_access_tokens + (compat_access_token_id, compat_session_id, access_token, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(compat_session.id), + token, + created_at, + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatAccessToken { + id, + session_id: compat_session.id, + token, + created_at, + expires_at, + }) + } + + #[tracing::instrument( + name = "db.compat_access_token.expire", + skip_all, + fields( + db.statement, + %compat_access_token.id, + compat_session.id = %compat_access_token.session_id, + ), + err, + )] + async fn expire( + &mut self, + clock: &Clock, + mut compat_access_token: CompatAccessToken, + ) -> Result { + let expires_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE compat_access_tokens + SET expires_at = $2 + WHERE compat_access_token_id = $1 + "#, + Uuid::from(compat_access_token.id), + expires_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + compat_access_token.expires_at = Some(expires_at); + Ok(compat_access_token) + } +} diff --git a/crates/storage/src/compat/mod.rs b/crates/storage/src/compat/mod.rs new file mode 100644 index 00000000..3a91f8c7 --- /dev/null +++ b/crates/storage/src/compat/mod.rs @@ -0,0 +1,25 @@ +// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod access_token; +mod refresh_token; +mod session; +mod sso_login; + +pub use self::{ + access_token::{CompatAccessTokenRepository, PgCompatAccessTokenRepository}, + refresh_token::{CompatRefreshTokenRepository, PgCompatRefreshTokenRepository}, + session::{CompatSessionRepository, PgCompatSessionRepository}, + sso_login::{CompatSsoLoginRepository, PgCompatSsoLoginRepository}, +}; diff --git a/crates/storage/src/compat/refresh_token.rs b/crates/storage/src/compat/refresh_token.rs new file mode 100644 index 00000000..30054622 --- /dev/null +++ b/crates/storage/src/compat/refresh_token.rs @@ -0,0 +1,260 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{ + CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession, +}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{tracing::ExecuteExt, Clock, DatabaseError, LookupResultExt}; + +#[async_trait] +pub trait CompatRefreshTokenRepository: Send + Sync { + type Error; + + /// Lookup a compat refresh token by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find a compat refresh token by its token + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error>; + + /// Add a new compat refresh token to the database + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + compat_access_token: &CompatAccessToken, + token: String, + ) -> Result; + + /// Consume a compat refresh token + async fn consume( + &mut self, + clock: &Clock, + compat_refresh_token: CompatRefreshToken, + ) -> Result; +} + +pub struct PgCompatRefreshTokenRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatRefreshTokenRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatRefreshTokenLookup { + compat_refresh_token_id: Uuid, + refresh_token: String, + created_at: DateTime, + consumed_at: Option>, + compat_access_token_id: Uuid, + compat_session_id: Uuid, +} + +impl From for CompatRefreshToken { + fn from(value: CompatRefreshTokenLookup) -> Self { + let state = match value.consumed_at { + Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at }, + None => CompatRefreshTokenState::Valid, + }; + + Self { + id: value.compat_refresh_token_id.into(), + state, + session_id: value.compat_session_id.into(), + token: value.refresh_token, + created_at: value.created_at, + access_token_id: value.compat_access_token_id.into(), + } + } +} + +#[async_trait] +impl<'c> CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_refresh_token.lookup", + skip_all, + fields( + db.statement, + compat_refresh_token.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT compat_refresh_token_id + , refresh_token + , created_at + , consumed_at + , compat_session_id + , compat_access_token_id + + FROM compat_refresh_tokens + + WHERE compat_refresh_token_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + refresh_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT compat_refresh_token_id + , refresh_token + , created_at + , consumed_at + , compat_session_id + , compat_access_token_id + + FROM compat_refresh_tokens + + WHERE refresh_token = $1 + "#, + refresh_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.into())) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.add", + skip_all, + fields( + db.statement, + compat_refresh_token.id, + %compat_session.id, + user.id = %compat_session.user_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + compat_session: &CompatSession, + compat_access_token: &CompatAccessToken, + token: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_refresh_tokens + (compat_refresh_token_id, compat_session_id, + compat_access_token_id, refresh_token, created_at) + VALUES ($1, $2, $3, $4, $5) + "#, + Uuid::from(id), + Uuid::from(compat_session.id), + Uuid::from(compat_access_token.id), + token, + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatRefreshToken { + id, + state: CompatRefreshTokenState::default(), + session_id: compat_session.id, + access_token_id: compat_access_token.id, + token, + created_at, + }) + } + + #[tracing::instrument( + name = "db.compat_refresh_token.consume", + skip_all, + fields( + db.statement, + %compat_refresh_token.id, + compat_session.id = %compat_refresh_token.session_id, + ), + err, + )] + async fn consume( + &mut self, + clock: &Clock, + compat_refresh_token: CompatRefreshToken, + ) -> Result { + let consumed_at = clock.now(); + let res = sqlx::query!( + r#" + UPDATE compat_refresh_tokens + SET consumed_at = $2 + WHERE compat_refresh_token_id = $1 + "#, + Uuid::from(compat_refresh_token.id), + consumed_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let compat_refresh_token = compat_refresh_token + .consume(consumed_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(compat_refresh_token) + } +} diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs new file mode 100644 index 00000000..3068be73 --- /dev/null +++ b/crates/storage/src/compat/session.rs @@ -0,0 +1,220 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{CompatSession, CompatSessionState, Device, User}; +use rand::RngCore; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; + +#[async_trait] +pub trait CompatSessionRepository: Send + Sync { + type Error; + + /// Lookup a compat session by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Start a new compat session + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + device: Device, + ) -> Result; + + /// End a compat session + async fn finish( + &mut self, + clock: &Clock, + compat_session: CompatSession, + ) -> Result; +} + +pub struct PgCompatSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatSessionRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +struct CompatSessionLookup { + compat_session_id: Uuid, + device_id: String, + user_id: Uuid, + created_at: DateTime, + finished_at: Option>, +} + +impl TryFrom for CompatSession { + type Error = DatabaseInconsistencyError; + + fn try_from(value: CompatSessionLookup) -> Result { + let id = value.compat_session_id.into(); + let device = Device::try_from(value.device_id).map_err(|e| { + DatabaseInconsistencyError::on("compat_sessions") + .column("device_id") + .row(id) + .source(e) + })?; + + let state = match value.finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + + let session = CompatSession { + id, + state, + user_id: value.user_id.into(), + device, + created_at: value.created_at, + }; + + Ok(session) + } +} + +#[async_trait] +impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_session.lookup", + skip_all, + fields( + db.statement, + compat_session.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSessionLookup, + r#" + SELECT compat_session_id + , device_id + , user_id + , created_at + , finished_at + FROM compat_sessions + WHERE compat_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_session.add", + skip_all, + fields( + db.statement, + compat_session.id, + %user.id, + %user.username, + compat_session.device.id = device.as_str(), + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + user: &User, + device: Device, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_session.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + Uuid::from(user.id), + device.as_str(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatSession { + id, + state: CompatSessionState::default(), + user_id: user.id, + device, + created_at, + }) + } + + #[tracing::instrument( + name = "db.compat_session.finish", + skip_all, + fields( + db.statement, + %compat_session.id, + user.id = %compat_session.user_id, + compat_session.device.id = compat_session.device.as_str(), + ), + err, + )] + async fn finish( + &mut self, + clock: &Clock, + compat_session: CompatSession, + ) -> Result { + let finished_at = clock.now(); + + let res = sqlx::query!( + r#" + UPDATE compat_sessions cs + SET finished_at = $2 + WHERE compat_session_id = $1 + "#, + Uuid::from(compat_session.id), + finished_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + let compat_session = compat_session + .finish(finished_at) + .map_err(DatabaseError::to_invalid_operation)?; + + Ok(compat_session) + } +} diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs new file mode 100644 index 00000000..cba777d3 --- /dev/null +++ b/crates/storage/src/compat/sso_login.rs @@ -0,0 +1,397 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState, User}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; +use ulid::Ulid; +use url::Url; +use uuid::Uuid; + +use crate::{ + pagination::{process_page, Page, QueryBuilderExt}, + tracing::ExecuteExt, + Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; + +#[async_trait] +pub trait CompatSsoLoginRepository: Send + Sync { + type Error; + + /// Lookup a compat SSO login by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find a compat SSO login by its login token + async fn find_by_token( + &mut self, + login_token: &str, + ) -> Result, Self::Error>; + + /// Start a new compat SSO login token + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + login_token: String, + redirect_uri: Url, + ) -> Result; + + /// Fulfill a compat SSO login by providing a compat session + async fn fulfill( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, + ) -> Result; + + /// Mark a compat SSO login as exchanged + async fn exchange( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + ) -> Result; + + /// Get a paginated list of compat SSO logins for a user + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error>; +} + +pub struct PgCompatSsoLoginRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgCompatSsoLoginRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +#[derive(sqlx::FromRow)] +struct CompatSsoLoginLookup { + compat_sso_login_id: Uuid, + login_token: String, + redirect_uri: String, + created_at: DateTime, + fulfilled_at: Option>, + exchanged_at: Option>, + compat_session_id: Option, +} + +impl TryFrom for CompatSsoLogin { + type Error = DatabaseInconsistencyError; + + fn try_from(res: CompatSsoLoginLookup) -> Result { + let id = res.compat_sso_login_id.into(); + let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| { + DatabaseInconsistencyError::on("compat_sso_logins") + .column("redirect_uri") + .row(id) + .source(e) + })?; + + let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) { + (None, None, None) => CompatSsoLoginState::Pending, + (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled { + fulfilled_at, + session_id: session_id.into(), + }, + (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => { + CompatSsoLoginState::Exchanged { + fulfilled_at, + exchanged_at, + session_id: session_id.into(), + } + } + _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), + }; + + Ok(CompatSsoLogin { + id, + login_token: res.login_token, + redirect_uri, + created_at: res.created_at, + state, + }) + } +} + +#[async_trait] +impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { + type Error = DatabaseError; + + #[tracing::instrument( + name = "db.compat_sso_login.lookup", + skip_all, + fields( + db.statement, + compat_sso_login.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSsoLoginLookup, + r#" + SELECT compat_sso_login_id + , login_token + , redirect_uri + , created_at + , fulfilled_at + , exchanged_at + , compat_session_id + + FROM compat_sso_logins + WHERE compat_sso_login_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_sso_login.find_by_token", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn find_by_token( + &mut self, + login_token: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + CompatSsoLoginLookup, + r#" + SELECT compat_sso_login_id + , login_token + , redirect_uri + , created_at + , fulfilled_at + , exchanged_at + , compat_session_id + + FROM compat_sso_logins + WHERE login_token = $1 + "#, + login_token, + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; + + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) + } + + #[tracing::instrument( + name = "db.compat_sso_login.add", + skip_all, + fields( + db.statement, + compat_sso_login.id, + compat_sso_login.redirect_uri = %redirect_uri, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + login_token: String, + redirect_uri: Url, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO compat_sso_logins + (compat_sso_login_id, login_token, redirect_uri, created_at) + VALUES ($1, $2, $3, $4) + "#, + Uuid::from(id), + &login_token, + redirect_uri.as_str(), + created_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(CompatSsoLogin { + id, + login_token, + redirect_uri, + created_at, + state: CompatSsoLoginState::default(), + }) + } + + #[tracing::instrument( + name = "db.compat_sso_login.fulfill", + skip_all, + fields( + db.statement, + %compat_sso_login.id, + %compat_session.id, + compat_session.device.id = compat_session.device.as_str(), + user.id = %compat_session.user_id, + ), + err, + )] + async fn fulfill( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, + ) -> Result { + let fulfilled_at = clock.now(); + let compat_sso_login = compat_sso_login + .fulfill(fulfilled_at, compat_session) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE compat_sso_logins + SET + compat_session_id = $2, + fulfilled_at = $3 + WHERE + compat_sso_login_id = $1 + "#, + Uuid::from(compat_sso_login.id), + Uuid::from(compat_session.id), + fulfilled_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(compat_sso_login) + } + + #[tracing::instrument( + name = "db.compat_sso_login.exchange", + skip_all, + fields( + db.statement, + %compat_sso_login.id, + ), + err, + )] + async fn exchange( + &mut self, + clock: &Clock, + compat_sso_login: CompatSsoLogin, + ) -> Result { + let exchanged_at = clock.now(); + let compat_sso_login = compat_sso_login + .exchange(exchanged_at) + .map_err(DatabaseError::to_invalid_operation)?; + + let res = sqlx::query!( + r#" + UPDATE compat_sso_logins + SET + exchanged_at = $2 + WHERE + compat_sso_login_id = $1 + "#, + Uuid::from(compat_sso_login.id), + exchanged_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(compat_sso_login) + } + + #[tracing::instrument( + name = "db.compat_sso_login.list_paginated", + skip_all, + fields( + db.statement, + %user.id, + %user.username, + ), + err + )] + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT cl.compat_sso_login_id + , cl.login_token + , cl.redirect_uri + , cl.created_at + , cl.fulfilled_at + , cl.exchanged_at + , cl.compat_session_id + + FROM compat_sso_logins cl + INNER JOIN compat_sessions ON compat_session_id + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("cl.compat_sso_login_id", before, after, first, last)?; + + let page: Vec = query + .build_query_as() + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; + + let edges: Result, _> = edges.into_iter().map(TryInto::try_into).collect(); + Ok(Page { + has_next_page, + has_previous_page, + edges: edges?, + }) + } +} diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 8eda5701..ddd6e1ea 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -15,6 +15,10 @@ use sqlx::{PgConnection, Postgres, Transaction}; use crate::{ + compat::{ + PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, + PgCompatSsoLoginRepository, + }, oauth2::{PgOAuth2ClientRepository, PgOAuth2SessionRepository}, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, @@ -63,6 +67,22 @@ pub trait Repository { where Self: 'c; + type CompatSessionRepository<'c> + where + Self: 'c; + + type CompatSsoLoginRepository<'c> + where + Self: 'c; + + type CompatAccessTokenRepository<'c> + where + Self: 'c; + + type CompatRefreshTokenRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; @@ -72,6 +92,10 @@ pub trait Repository { fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>; fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>; fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_>; + fn compat_session(&mut self) -> Self::CompatSessionRepository<'_>; + fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_>; + fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_>; + fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_>; } impl Repository for PgConnection { @@ -84,6 +108,10 @@ impl Repository for PgConnection { type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; + type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; + type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; + type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; + type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -120,6 +148,22 @@ impl Repository for PgConnection { fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { PgOAuth2SessionRepository::new(self) } + + fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { + PgCompatSessionRepository::new(self) + } + + fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { + PgCompatSsoLoginRepository::new(self) + } + + fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { + PgCompatAccessTokenRepository::new(self) + } + + fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { + PgCompatRefreshTokenRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { @@ -132,6 +176,10 @@ impl<'t> Repository for Transaction<'t, Postgres> { type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; type OAuth2SessionRepository<'c> = PgOAuth2SessionRepository<'c> where Self: 'c; + type CompatSessionRepository<'c> = PgCompatSessionRepository<'c> where Self: 'c; + type CompatSsoLoginRepository<'c> = PgCompatSsoLoginRepository<'c> where Self: 'c; + type CompatAccessTokenRepository<'c> = PgCompatAccessTokenRepository<'c> where Self: 'c; + type CompatRefreshTokenRepository<'c> = PgCompatRefreshTokenRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -168,4 +216,20 @@ impl<'t> Repository for Transaction<'t, Postgres> { fn oauth2_session(&mut self) -> Self::OAuth2SessionRepository<'_> { PgOAuth2SessionRepository::new(self) } + + fn compat_session(&mut self) -> Self::CompatSessionRepository<'_> { + PgCompatSessionRepository::new(self) + } + + fn compat_sso_login(&mut self) -> Self::CompatSsoLoginRepository<'_> { + PgCompatSsoLoginRepository::new(self) + } + + fn compat_access_token(&mut self) -> Self::CompatAccessTokenRepository<'_> { + PgCompatAccessTokenRepository::new(self) + } + + fn compat_refresh_token(&mut self) -> Self::CompatRefreshTokenRepository<'_> { + PgCompatRefreshTokenRepository::new(self) + } }