1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Better data-model for compat sessions & devices

This commit is contained in:
Quentin Gliech
2022-05-18 14:03:14 +02:00
parent 33204b7cf8
commit c4fa87e457
9 changed files with 212 additions and 163 deletions

View File

@ -22,17 +22,19 @@
clippy::trait_duplication_in_bounds clippy::trait_duplication_in_bounds
)] )]
pub(crate) mod compat;
pub(crate) mod oauth2; pub(crate) mod oauth2;
pub(crate) mod tokens; pub(crate) mod tokens;
pub(crate) mod traits; pub(crate) mod traits;
pub(crate) mod users; pub(crate) mod users;
pub use self::{ pub use self::{
compat::{CompatAccessToken, CompatSession, Device},
oauth2::{ oauth2::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client,
InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session,
}, },
tokens::{AccessToken, CompatAccessToken, RefreshToken, TokenFormatError, TokenType}, tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
traits::{StorageBackend, StorageBackendMarker}, traits::{StorageBackend, StorageBackendMarker},
users::{ users::{
Authentication, BrowserSession, User, UserEmail, UserEmailVerification, Authentication, BrowserSession, User, UserEmail, UserEmailVerification,

View File

@ -66,15 +66,6 @@ impl<S: StorageBackendMarker> From<RefreshToken<S>> for RefreshToken<()> {
} }
} }
#[derive(Debug, Clone, PartialEq)]
pub struct CompatAccessToken<T: StorageBackend> {
pub data: T::CompatAccessTokenData,
pub token: String,
pub device_id: String,
pub created_at: DateTime<Utc>,
pub deleted_at: Option<DateTime<Utc>>,
}
/// Type of token to generate or validate /// Type of token to generate or validate
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType { pub enum TokenType {
@ -86,6 +77,9 @@ pub enum TokenType {
/// A legacy access token /// A legacy access token
CompatAccessToken, CompatAccessToken,
/// A legacy refresh token
CompatRefreshToken,
} }
impl TokenType { impl TokenType {
@ -94,6 +88,7 @@ impl TokenType {
TokenType::AccessToken => "mat", TokenType::AccessToken => "mat",
TokenType::RefreshToken => "mar", TokenType::RefreshToken => "mar",
TokenType::CompatAccessToken => "mct", TokenType::CompatAccessToken => "mct",
TokenType::CompatRefreshToken => "mcr",
} }
} }
@ -102,6 +97,7 @@ impl TokenType {
"mat" => Some(TokenType::AccessToken), "mat" => Some(TokenType::AccessToken),
"mar" => Some(TokenType::RefreshToken), "mar" => Some(TokenType::RefreshToken),
"mct" => Some(TokenType::CompatAccessToken), "mct" => Some(TokenType::CompatAccessToken),
"mcr" => Some(TokenType::CompatRefreshToken),
_ => None, _ => None,
} }
} }
@ -181,7 +177,10 @@ impl PartialEq<OAuthTokenTypeHint> for TokenType {
( (
TokenType::AccessToken | TokenType::CompatAccessToken, TokenType::AccessToken | TokenType::CompatAccessToken,
OAuthTokenTypeHint::AccessToken OAuthTokenTypeHint::AccessToken
) | (TokenType::RefreshToken, OAuthTokenTypeHint::RefreshToken) ) | (
TokenType::RefreshToken | TokenType::CompatRefreshToken,
OAuthTokenTypeHint::RefreshToken
)
) )
} }
} }
@ -234,13 +233,22 @@ mod tests {
#[test] #[test]
fn test_prefix_match() { fn test_prefix_match() {
use TokenType::{AccessToken, CompatAccessToken, RefreshToken}; use TokenType::{AccessToken, CompatAccessToken, CompatRefreshToken, RefreshToken};
assert_eq!(TokenType::match_prefix("mct"), Some(CompatAccessToken)); assert_eq!(TokenType::match_prefix("mct"), Some(CompatAccessToken));
assert_eq!(TokenType::match_prefix("mcr"), Some(CompatRefreshToken));
assert_eq!(TokenType::match_prefix("mat"), Some(AccessToken)); assert_eq!(TokenType::match_prefix("mat"), Some(AccessToken));
assert_eq!(TokenType::match_prefix("mar"), Some(RefreshToken)); assert_eq!(TokenType::match_prefix("mar"), Some(RefreshToken));
assert_eq!(TokenType::match_prefix("matt"), None); assert_eq!(TokenType::match_prefix("matt"), None);
assert_eq!(TokenType::match_prefix("marr"), None); assert_eq!(TokenType::match_prefix("marr"), None);
assert_eq!(TokenType::match_prefix("ma"), None); assert_eq!(TokenType::match_prefix("ma"), None);
assert_eq!(
TokenType::match_prefix(TokenType::CompatAccessToken.prefix()),
Some(TokenType::CompatAccessToken)
);
assert_eq!(
TokenType::match_prefix(TokenType::CompatRefreshToken.prefix()),
Some(TokenType::CompatRefreshToken)
);
assert_eq!( assert_eq!(
TokenType::match_prefix(TokenType::AccessToken.prefix()), TokenType::match_prefix(TokenType::AccessToken.prefix()),
Some(TokenType::AccessToken) Some(TokenType::AccessToken)
@ -255,28 +263,23 @@ mod tests {
fn test_generate_and_check() { fn test_generate_and_check() {
const COUNT: usize = 500; // Generate 500 of each token type const COUNT: usize = 500; // Generate 500 of each token type
let mut rng = thread_rng(); let mut rng = thread_rng();
// Generate many access tokens
let tokens: HashSet<String> = (0..COUNT)
.map(|_| TokenType::AccessToken.generate(&mut rng))
.collect();
// Check that they are all different for t in [
assert_eq!(tokens.len(), COUNT, "All tokens are unique"); TokenType::CompatAccessToken,
TokenType::CompatRefreshToken,
TokenType::AccessToken,
TokenType::RefreshToken,
] {
// Generate many tokens
let tokens: HashSet<String> = (0..COUNT).map(|_| t.generate(&mut rng)).collect();
// Check that they are all valid and detected as access tokens // Check that they are all different
for token in tokens { assert_eq!(tokens.len(), COUNT, "All tokens are unique");
assert_eq!(TokenType::check(&token).unwrap(), TokenType::AccessToken);
}
// Same, but for refresh tokens // Check that they are all valid and detected as the right token type
let tokens: HashSet<String> = (0..COUNT) for token in tokens {
.map(|_| TokenType::RefreshToken.generate(&mut rng)) assert_eq!(TokenType::check(&token).unwrap(), t);
.collect(); }
assert_eq!(tokens.len(), COUNT, "All tokens are unique");
for token in tokens {
assert_eq!(TokenType::check(&token).unwrap(), TokenType::RefreshToken);
} }
} }
} }

View File

@ -35,6 +35,8 @@ pub trait StorageBackend {
type AccessTokenData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default; type AccessTokenData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default;
type RefreshTokenData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default; type RefreshTokenData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default;
type CompatAccessTokenData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default; type CompatAccessTokenData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default;
type CompatRefreshTokenData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default;
type CompatSessionData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default;
} }
impl StorageBackend for () { impl StorageBackend for () {
@ -44,6 +46,8 @@ impl StorageBackend for () {
type BrowserSessionData = (); type BrowserSessionData = ();
type ClientData = (); type ClientData = ();
type CompatAccessTokenData = (); type CompatAccessTokenData = ();
type CompatRefreshTokenData = ();
type CompatSessionData = ();
type RefreshTokenData = (); type RefreshTokenData = ();
type SessionData = (); type SessionData = ();
type UserData = (); type UserData = ();

View File

@ -15,11 +15,12 @@
use axum::{response::IntoResponse, Extension, Json}; use axum::{response::IntoResponse, Extension, Json};
use hyper::StatusCode; use hyper::StatusCode;
use mas_config::MatrixConfig; use mas_config::MatrixConfig;
use mas_data_model::TokenType; use mas_data_model::{Device, TokenType};
use mas_storage::compat::compat_login; use mas_storage::compat::compat_login;
use rand::{distributions::Alphanumeric, thread_rng, Rng}; use rand::thread_rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::PgPool; use sqlx::PgPool;
use thiserror::Error;
use super::MatrixError; use super::MatrixError;
@ -69,13 +70,19 @@ pub enum Identifier {
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct ResponseBody { pub struct ResponseBody {
access_token: String, access_token: String,
device_id: String, device_id: Device,
user_id: String, user_id: String,
} }
#[derive(Debug, Error)]
pub enum RouteError { pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>), Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("unsupported login method")]
Unsupported, Unsupported,
#[error("login failed")]
LoginFailed, LoginFailed,
} }
@ -108,6 +115,7 @@ impl IntoResponse for RouteError {
} }
} }
#[tracing::instrument(skip_all, err)]
pub(crate) async fn post( pub(crate) async fn post(
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Extension(config): Extension<MatrixConfig>, Extension(config): Extension<MatrixConfig>,
@ -124,26 +132,22 @@ pub(crate) async fn post(
} }
}; };
let (token, device_id) = { let (token, device) = {
let mut rng = thread_rng(); let mut rng = thread_rng();
let token = TokenType::CompatAccessToken.generate(&mut rng); let token = TokenType::CompatAccessToken.generate(&mut rng);
let device_id: String = rng let device = Device::generate(&mut rng);
.sample_iter(&Alphanumeric) (token, device)
.take(10)
.map(char::from)
.collect();
(token, device_id)
}; };
let (token, user) = compat_login(&mut conn, &username, &password, device_id, token) let (token, session) = compat_login(&mut conn, &username, &password, device, token)
.await .await
.map_err(|_| RouteError::LoginFailed)?; .map_err(|_| RouteError::LoginFailed)?;
let user_id = format!("@{}:{}", user.username, config.homeserver); let user_id = format!("@{}:{}", session.user.username, config.homeserver);
Ok(Json(ResponseBody { Ok(Json(ResponseBody {
access_token: token.token, access_token: token.token,
device_id: token.device_id, device_id: session.device,
user_id, user_id,
})) }))
} }

View File

@ -21,7 +21,7 @@ use axum_extra::extract::PrivateCookieJar;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::SessionInfoExt; use mas_axum_utils::SessionInfoExt;
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::{AuthorizationCode, Pkce}; use mas_data_model::{AuthorizationCode, Device, Pkce};
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType; use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::oauth2::{ use mas_storage::oauth2::{
@ -38,7 +38,6 @@ use oauth2_types::{
pkce, pkce,
prelude::*, prelude::*,
requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode}, requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode},
scope::ScopeToken,
}; };
use rand::{distributions::Alphanumeric, thread_rng, Rng}; use rand::{distributions::Alphanumeric, thread_rng, Rng};
use serde::Deserialize; use serde::Deserialize;
@ -252,15 +251,8 @@ pub(crate) async fn get(
}; };
// Generate the device ID // Generate the device ID
// TODO: this should probably be done somewhere else? let device = Device::generate(&mut thread_rng());
let device_id: String = thread_rng() let device_scope = device.to_scope_token();
.sample_iter(&Alphanumeric)
.take(10)
.map(char::from)
.collect();
let device_scope: ScopeToken = format!("urn:matrix:device:{}", device_id)
.parse()
.context("could not parse generated device scope")?;
let scope = { let scope = {
let mut s = params.auth.scope.clone(); let mut s = params.auth.scope.clone();

View File

@ -26,10 +26,7 @@ use mas_storage::{
refresh_token::{lookup_active_refresh_token, RefreshTokenLookupError}, refresh_token::{lookup_active_refresh_token, RefreshTokenLookupError},
}, },
}; };
use oauth2_types::{ use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
requests::{IntrospectionRequest, IntrospectionResponse},
scope::ScopeToken,
};
use sqlx::PgPool; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
@ -217,28 +214,29 @@ pub(crate) async fn post(
} }
} }
TokenType::CompatAccessToken => { TokenType::CompatAccessToken => {
let (token, user) = lookup_active_compat_access_token(&mut conn, token).await?; let (token, session) = lookup_active_compat_access_token(&mut conn, token).await?;
let device_scope: ScopeToken = format!("urn:matrix:device:{}", token.device_id) let device_scope = session.device.to_scope_token();
.parse()
.unwrap();
let scope = [device_scope].into_iter().collect(); let scope = [device_scope].into_iter().collect();
IntrospectionResponse { IntrospectionResponse {
active: true, active: true,
scope: Some(scope), scope: Some(scope),
client_id: Some("legacy".into()), client_id: Some("legacy".into()),
username: Some(user.username), username: Some(session.user.username),
token_type: Some(OAuthTokenTypeHint::AccessToken), token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: None, exp: token.exp(),
iat: Some(token.created_at), iat: Some(token.created_at),
nbf: Some(token.created_at), nbf: Some(token.created_at),
sub: Some(user.sub), sub: Some(session.user.sub),
aud: None, aud: None,
iss: None, iss: None,
jti: None, jti: None,
} }
} }
TokenType::CompatRefreshToken => {
todo!()
}
}; };
Ok(Json(reply)) Ok(Json(reply))

View File

@ -1074,6 +1074,104 @@
}, },
"query": "\n INSERT INTO oauth2_sessions\n (user_session_id, oauth2_client_id, scope)\n SELECT\n $1,\n og.oauth2_client_id,\n og.scope\n FROM\n oauth2_authorization_grants og\n WHERE\n og.id = $2\n RETURNING id, created_at\n " "query": "\n INSERT INTO oauth2_sessions\n (user_session_id, oauth2_client_id, scope)\n SELECT\n $1,\n og.oauth2_client_id,\n og.scope\n FROM\n oauth2_authorization_grants og\n WHERE\n og.id = $2\n RETURNING id, created_at\n "
}, },
"7d94b7b6ed2f68479adb6247880b32bc378790174a81a05dff50b92e9be15bf8": {
"describe": {
"columns": [
{
"name": "compat_access_token_id",
"ordinal": 0,
"type_info": "Int8"
},
{
"name": "compat_access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "compat_access_token_created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "compat_access_token_expires_after",
"ordinal": 3,
"type_info": "Int4"
},
{
"name": "compat_session_id",
"ordinal": 4,
"type_info": "Int8"
},
{
"name": "compat_session_created_at",
"ordinal": 5,
"type_info": "Timestamptz"
},
{
"name": "compat_session_deleted_at",
"ordinal": 6,
"type_info": "Timestamptz"
},
{
"name": "compat_session_device_id",
"ordinal": 7,
"type_info": "Text"
},
{
"name": "user_id!",
"ordinal": 8,
"type_info": "Int8"
},
{
"name": "user_username!",
"ordinal": 9,
"type_info": "Text"
},
{
"name": "user_email_id?",
"ordinal": 10,
"type_info": "Int8"
},
{
"name": "user_email?",
"ordinal": 11,
"type_info": "Text"
},
{
"name": "user_email_created_at?",
"ordinal": 12,
"type_info": "Timestamptz"
},
{
"name": "user_email_confirmed_at?",
"ordinal": 13,
"type_info": "Timestamptz"
}
],
"nullable": [
false,
false,
false,
true,
false,
false,
true,
false,
false,
false,
false,
false,
false,
true
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT\n ct.id AS \"compat_access_token_id\",\n ct.token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_after AS \"compat_access_token_expires_after\",\n cs.id AS \"compat_session_id\",\n cs.created_at AS \"compat_session_created_at\",\n cs.deleted_at AS \"compat_session_deleted_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n ON cs.id = ct.compat_session_id\n INNER JOIN users u\n ON u.id = cs.user_id\n LEFT JOIN user_emails ue\n ON ue.id = u.primary_email_id\n\n WHERE ct.token = $1\n AND cs.deleted_at IS NULL\n "
},
"7de9cfa6e90ba20f5b298ea387cf13a7e40d0f5b3eb903a80d06fbe33074d596": { "7de9cfa6e90ba20f5b298ea387cf13a7e40d0f5b3eb903a80d06fbe33074d596": {
"describe": { "describe": {
"columns": [ "columns": [
@ -1875,86 +1973,6 @@
}, },
"query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.active\n " "query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.active\n "
}, },
"eb12b728e0d58f6bba1a20fbc9bd01f3a6cbae7e40961b39eac3b294609edf2f": {
"describe": {
"columns": [
{
"name": "compat_access_token_id",
"ordinal": 0,
"type_info": "Int8"
},
{
"name": "compat_access_token",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "compat_access_token_created_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "compat_session_deleted_at",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "compat_session_device_id",
"ordinal": 4,
"type_info": "Text"
},
{
"name": "user_id!",
"ordinal": 5,
"type_info": "Int8"
},
{
"name": "user_username!",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "user_email_id?",
"ordinal": 7,
"type_info": "Int8"
},
{
"name": "user_email?",
"ordinal": 8,
"type_info": "Text"
},
{
"name": "user_email_created_at?",
"ordinal": 9,
"type_info": "Timestamptz"
},
{
"name": "user_email_confirmed_at?",
"ordinal": 10,
"type_info": "Timestamptz"
}
],
"nullable": [
false,
false,
false,
true,
false,
false,
false,
false,
false,
false,
true
],
"parameters": {
"Left": [
"Text"
]
}
},
"query": "\n SELECT\n ct.id AS \"compat_access_token_id\",\n ct.token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n cs.deleted_at AS \"compat_session_deleted_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n ON cs.id = ct.compat_session_id\n INNER JOIN users u\n ON u.id = cs.user_id\n LEFT JOIN user_emails ue\n ON ue.id = u.primary_email_id\n\n WHERE ct.token = $1\n AND cs.deleted_at IS NULL\n "
},
"ebf73a609e81830b16700d2c315fffa93fd85b2886e29f234d9953b18a9f72b5": { "ebf73a609e81830b16700d2c315fffa93fd85b2886e29f234d9953b18a9f72b5": {
"describe": { "describe": {
"columns": [], "columns": [],

View File

@ -14,8 +14,8 @@
use anyhow::Context; use anyhow::Context;
use argon2::{Argon2, PasswordHash}; use argon2::{Argon2, PasswordHash};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Duration, Utc};
use mas_data_model::{CompatAccessToken, User, UserEmail}; use mas_data_model::{CompatAccessToken, CompatSession, Device, User, UserEmail};
use sqlx::{Acquire, PgExecutor, Postgres}; use sqlx::{Acquire, PgExecutor, Postgres};
use thiserror::Error; use thiserror::Error;
use tokio::task; use tokio::task;
@ -28,7 +28,10 @@ use crate::{
pub struct CompatAccessTokenLookup { pub struct CompatAccessTokenLookup {
compat_access_token_id: i64, compat_access_token_id: i64,
compat_access_token: String, compat_access_token: String,
compat_access_token_expires_after: Option<i32>,
compat_access_token_created_at: DateTime<Utc>, compat_access_token_created_at: DateTime<Utc>,
compat_session_id: i64,
compat_session_created_at: DateTime<Utc>,
compat_session_deleted_at: Option<DateTime<Utc>>, compat_session_deleted_at: Option<DateTime<Utc>>,
compat_session_device_id: String, compat_session_device_id: String,
user_id: i64, user_id: i64,
@ -53,14 +56,14 @@ impl CompatAccessTokenLookupError {
} }
} }
#[tracing::instrument(skip(executor))] #[tracing::instrument(skip(executor), err)]
pub async fn lookup_active_compat_access_token( pub async fn lookup_active_compat_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
token: &str, token: &str,
) -> Result< ) -> Result<
( (
CompatAccessToken<PostgresqlBackend>, CompatAccessToken<PostgresqlBackend>,
User<PostgresqlBackend>, CompatSession<PostgresqlBackend>,
), ),
CompatAccessTokenLookupError, CompatAccessTokenLookupError,
> { > {
@ -71,6 +74,9 @@ pub async fn lookup_active_compat_access_token(
ct.id AS "compat_access_token_id", ct.id AS "compat_access_token_id",
ct.token AS "compat_access_token", ct.token AS "compat_access_token",
ct.created_at AS "compat_access_token_created_at", ct.created_at AS "compat_access_token_created_at",
ct.expires_after AS "compat_access_token_expires_after",
cs.id AS "compat_session_id",
cs.created_at AS "compat_session_created_at",
cs.deleted_at AS "compat_session_deleted_at", cs.deleted_at AS "compat_session_deleted_at",
cs.device_id AS "compat_session_device_id", cs.device_id AS "compat_session_device_id",
u.id AS "user_id!", u.id AS "user_id!",
@ -101,8 +107,9 @@ pub async fn lookup_active_compat_access_token(
data: res.compat_access_token_id, data: res.compat_access_token_id,
token: res.compat_access_token, token: res.compat_access_token,
created_at: res.compat_access_token_created_at, created_at: res.compat_access_token_created_at,
deleted_at: res.compat_session_deleted_at, expires_after: res
device_id: res.compat_session_device_id, .compat_access_token_expires_after
.map(|d| Duration::seconds(d.into())),
}; };
let primary_email = match ( let primary_email = match (
@ -128,20 +135,30 @@ pub async fn lookup_active_compat_access_token(
primary_email, primary_email,
}; };
Ok((token, user)) let device = Device::try_from(res.compat_session_device_id).unwrap();
let session = CompatSession {
data: res.compat_session_id,
user,
device,
created_at: res.compat_session_created_at,
deleted_at: res.compat_session_deleted_at,
};
Ok((token, session))
} }
#[tracing::instrument(skip(conn, password, token))] #[tracing::instrument(skip(conn, password, token), err)]
pub async fn compat_login( pub async fn compat_login(
conn: impl Acquire<'_, Database = Postgres>, conn: impl Acquire<'_, Database = Postgres>,
username: &str, username: &str,
password: &str, password: &str,
device_id: String, device: Device,
token: String, token: String,
) -> Result< ) -> Result<
( (
CompatAccessToken<PostgresqlBackend>, CompatAccessToken<PostgresqlBackend>,
User<PostgresqlBackend>, CompatSession<PostgresqlBackend>,
), ),
anyhow::Error, anyhow::Error,
> { > {
@ -176,7 +193,7 @@ pub async fn compat_login(
.instrument(tracing::info_span!("Verify hashed password")) .instrument(tracing::info_span!("Verify hashed password"))
.await??; .await??;
let session = sqlx::query_as!( let res = sqlx::query_as!(
IdAndCreationTime, IdAndCreationTime,
r#" r#"
INSERT INTO compat_sessions (user_id, device_id) INSERT INTO compat_sessions (user_id, device_id)
@ -184,12 +201,21 @@ pub async fn compat_login(
RETURNING id, created_at RETURNING id, created_at
"#, "#,
user.data, user.data,
device_id, device.as_str(),
) )
.fetch_one(&mut txn) .fetch_one(&mut txn)
.instrument(tracing::info_span!("Insert compat session"))
.await .await
.context("could not insert compat session")?; .context("could not insert compat session")?;
let session = CompatSession {
data: res.id,
user,
device,
created_at: res.created_at,
deleted_at: None,
};
let res = sqlx::query_as!( let res = sqlx::query_as!(
IdAndCreationTime, IdAndCreationTime,
r#" r#"
@ -197,10 +223,11 @@ pub async fn compat_login(
VALUES ($1, $2) VALUES ($1, $2)
RETURNING id, created_at RETURNING id, created_at
"#, "#,
session.id, session.data,
token, token,
) )
.fetch_one(&mut txn) .fetch_one(&mut txn)
.instrument(tracing::info_span!("Insert compat access token"))
.await .await
.context("could not insert compat access token")?; .context("could not insert compat access token")?;
@ -208,15 +235,14 @@ pub async fn compat_login(
data: res.id, data: res.id,
token, token,
created_at: res.created_at, created_at: res.created_at,
deleted_at: None, expires_after: None,
device_id,
}; };
txn.commit().await.context("could not commit transaction")?; txn.commit().await.context("could not commit transaction")?;
Ok((token, user)) Ok((token, session))
} }
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all, err)]
pub async fn compat_logout( pub async fn compat_logout(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
token: &str, token: &str,

View File

@ -43,6 +43,8 @@ impl StorageBackend for PostgresqlBackend {
type BrowserSessionData = i64; type BrowserSessionData = i64;
type ClientData = i64; type ClientData = i64;
type CompatAccessTokenData = i64; type CompatAccessTokenData = i64;
type CompatRefreshTokenData = i64;
type CompatSessionData = i64;
type RefreshTokenData = i64; type RefreshTokenData = i64;
type SessionData = i64; type SessionData = i64;
type UserData = i64; type UserData = i64;