diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index e1d18c35..3793fbb6 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -22,17 +22,19 @@ clippy::trait_duplication_in_bounds )] +pub(crate) mod compat; pub(crate) mod oauth2; pub(crate) mod tokens; pub(crate) mod traits; pub(crate) mod users; pub use self::{ + compat::{CompatAccessToken, CompatSession, Device}, oauth2::{ AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, }, - tokens::{AccessToken, CompatAccessToken, RefreshToken, TokenFormatError, TokenType}, + tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, traits::{StorageBackend, StorageBackendMarker}, users::{ Authentication, BrowserSession, User, UserEmail, UserEmailVerification, diff --git a/crates/data-model/src/tokens.rs b/crates/data-model/src/tokens.rs index e98fa4fd..a5f93c35 100644 --- a/crates/data-model/src/tokens.rs +++ b/crates/data-model/src/tokens.rs @@ -66,15 +66,6 @@ impl From> for RefreshToken<()> { } } -#[derive(Debug, Clone, PartialEq)] -pub struct CompatAccessToken { - pub data: T::CompatAccessTokenData, - pub token: String, - pub device_id: String, - pub created_at: DateTime, - pub deleted_at: Option>, -} - /// Type of token to generate or validate #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TokenType { @@ -86,6 +77,9 @@ pub enum TokenType { /// A legacy access token CompatAccessToken, + + /// A legacy refresh token + CompatRefreshToken, } impl TokenType { @@ -94,6 +88,7 @@ impl TokenType { TokenType::AccessToken => "mat", TokenType::RefreshToken => "mar", TokenType::CompatAccessToken => "mct", + TokenType::CompatRefreshToken => "mcr", } } @@ -102,6 +97,7 @@ impl TokenType { "mat" => Some(TokenType::AccessToken), "mar" => Some(TokenType::RefreshToken), "mct" => Some(TokenType::CompatAccessToken), + "mcr" => Some(TokenType::CompatRefreshToken), _ => None, } } @@ -181,7 +177,10 @@ impl PartialEq for TokenType { ( TokenType::AccessToken | TokenType::CompatAccessToken, OAuthTokenTypeHint::AccessToken - ) | (TokenType::RefreshToken, OAuthTokenTypeHint::RefreshToken) + ) | ( + TokenType::RefreshToken | TokenType::CompatRefreshToken, + OAuthTokenTypeHint::RefreshToken + ) ) } } @@ -234,13 +233,22 @@ mod tests { #[test] fn test_prefix_match() { - use TokenType::{AccessToken, CompatAccessToken, RefreshToken}; + use TokenType::{AccessToken, CompatAccessToken, CompatRefreshToken, RefreshToken}; assert_eq!(TokenType::match_prefix("mct"), Some(CompatAccessToken)); + assert_eq!(TokenType::match_prefix("mcr"), Some(CompatRefreshToken)); assert_eq!(TokenType::match_prefix("mat"), Some(AccessToken)); assert_eq!(TokenType::match_prefix("mar"), Some(RefreshToken)); assert_eq!(TokenType::match_prefix("matt"), None); assert_eq!(TokenType::match_prefix("marr"), None); assert_eq!(TokenType::match_prefix("ma"), None); + assert_eq!( + TokenType::match_prefix(TokenType::CompatAccessToken.prefix()), + Some(TokenType::CompatAccessToken) + ); + assert_eq!( + TokenType::match_prefix(TokenType::CompatRefreshToken.prefix()), + Some(TokenType::CompatRefreshToken) + ); assert_eq!( TokenType::match_prefix(TokenType::AccessToken.prefix()), Some(TokenType::AccessToken) @@ -255,28 +263,23 @@ mod tests { fn test_generate_and_check() { const COUNT: usize = 500; // Generate 500 of each token type let mut rng = thread_rng(); - // Generate many access tokens - let tokens: HashSet = (0..COUNT) - .map(|_| TokenType::AccessToken.generate(&mut rng)) - .collect(); - // Check that they are all different - assert_eq!(tokens.len(), COUNT, "All tokens are unique"); + for t in [ + TokenType::CompatAccessToken, + TokenType::CompatRefreshToken, + TokenType::AccessToken, + TokenType::RefreshToken, + ] { + // Generate many tokens + let tokens: HashSet = (0..COUNT).map(|_| t.generate(&mut rng)).collect(); - // Check that they are all valid and detected as access tokens - for token in tokens { - assert_eq!(TokenType::check(&token).unwrap(), TokenType::AccessToken); - } + // Check that they are all different + assert_eq!(tokens.len(), COUNT, "All tokens are unique"); - // Same, but for refresh tokens - let tokens: HashSet = (0..COUNT) - .map(|_| TokenType::RefreshToken.generate(&mut rng)) - .collect(); - - assert_eq!(tokens.len(), COUNT, "All tokens are unique"); - - for token in tokens { - assert_eq!(TokenType::check(&token).unwrap(), TokenType::RefreshToken); + // Check that they are all valid and detected as the right token type + for token in tokens { + assert_eq!(TokenType::check(&token).unwrap(), t); + } } } } diff --git a/crates/data-model/src/traits.rs b/crates/data-model/src/traits.rs index 0cf3dfe2..569ffe2b 100644 --- a/crates/data-model/src/traits.rs +++ b/crates/data-model/src/traits.rs @@ -35,6 +35,8 @@ pub trait StorageBackend { type AccessTokenData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default; type RefreshTokenData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default; type CompatAccessTokenData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default; + type CompatRefreshTokenData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default; + type CompatSessionData: Clone + Debug + PartialEq + Serialize + DeserializeOwned + Default; } impl StorageBackend for () { @@ -44,6 +46,8 @@ impl StorageBackend for () { type BrowserSessionData = (); type ClientData = (); type CompatAccessTokenData = (); + type CompatRefreshTokenData = (); + type CompatSessionData = (); type RefreshTokenData = (); type SessionData = (); type UserData = (); diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 5605f25b..212d9464 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -15,11 +15,12 @@ use axum::{response::IntoResponse, Extension, Json}; use hyper::StatusCode; use mas_config::MatrixConfig; -use mas_data_model::TokenType; +use mas_data_model::{Device, TokenType}; use mas_storage::compat::compat_login; -use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use rand::thread_rng; use serde::{Deserialize, Serialize}; use sqlx::PgPool; +use thiserror::Error; use super::MatrixError; @@ -69,13 +70,19 @@ pub enum Identifier { #[derive(Debug, Serialize)] pub struct ResponseBody { access_token: String, - device_id: String, + device_id: Device, user_id: String, } +#[derive(Debug, Error)] pub enum RouteError { + #[error(transparent)] Internal(Box), + + #[error("unsupported login method")] Unsupported, + + #[error("login failed")] LoginFailed, } @@ -108,6 +115,7 @@ impl IntoResponse for RouteError { } } +#[tracing::instrument(skip_all, err)] pub(crate) async fn post( Extension(pool): Extension, Extension(config): Extension, @@ -124,26 +132,22 @@ pub(crate) async fn post( } }; - let (token, device_id) = { + let (token, device) = { let mut rng = thread_rng(); let token = TokenType::CompatAccessToken.generate(&mut rng); - let device_id: String = rng - .sample_iter(&Alphanumeric) - .take(10) - .map(char::from) - .collect(); - (token, device_id) + let device = Device::generate(&mut rng); + (token, device) }; - let (token, user) = compat_login(&mut conn, &username, &password, device_id, token) + let (token, session) = compat_login(&mut conn, &username, &password, device, token) .await .map_err(|_| RouteError::LoginFailed)?; - let user_id = format!("@{}:{}", user.username, config.homeserver); + let user_id = format!("@{}:{}", session.user.username, config.homeserver); Ok(Json(ResponseBody { access_token: token.token, - device_id: token.device_id, + device_id: session.device, user_id, })) } diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 07233279..e6b72b3f 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -21,7 +21,7 @@ use axum_extra::extract::PrivateCookieJar; use hyper::StatusCode; use mas_axum_utils::SessionInfoExt; use mas_config::Encrypter; -use mas_data_model::{AuthorizationCode, Pkce}; +use mas_data_model::{AuthorizationCode, Device, Pkce}; use mas_iana::oauth::OAuthAuthorizationEndpointResponseType; use mas_router::{PostAuthAction, Route}; use mas_storage::oauth2::{ @@ -38,7 +38,6 @@ use oauth2_types::{ pkce, prelude::*, requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode}, - scope::ScopeToken, }; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use serde::Deserialize; @@ -252,15 +251,8 @@ pub(crate) async fn get( }; // Generate the device ID - // TODO: this should probably be done somewhere else? - let device_id: String = thread_rng() - .sample_iter(&Alphanumeric) - .take(10) - .map(char::from) - .collect(); - let device_scope: ScopeToken = format!("urn:matrix:device:{}", device_id) - .parse() - .context("could not parse generated device scope")?; + let device = Device::generate(&mut thread_rng()); + let device_scope = device.to_scope_token(); let scope = { let mut s = params.auth.scope.clone(); diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 3e9b9665..c34512c1 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -26,10 +26,7 @@ use mas_storage::{ refresh_token::{lookup_active_refresh_token, RefreshTokenLookupError}, }, }; -use oauth2_types::{ - requests::{IntrospectionRequest, IntrospectionResponse}, - scope::ScopeToken, -}; +use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse}; use sqlx::PgPool; use thiserror::Error; @@ -217,28 +214,29 @@ pub(crate) async fn post( } } TokenType::CompatAccessToken => { - let (token, user) = lookup_active_compat_access_token(&mut conn, token).await?; + let (token, session) = lookup_active_compat_access_token(&mut conn, token).await?; - let device_scope: ScopeToken = format!("urn:matrix:device:{}", token.device_id) - .parse() - .unwrap(); + let device_scope = session.device.to_scope_token(); let scope = [device_scope].into_iter().collect(); IntrospectionResponse { active: true, scope: Some(scope), client_id: Some("legacy".into()), - username: Some(user.username), + username: Some(session.user.username), token_type: Some(OAuthTokenTypeHint::AccessToken), - exp: None, + exp: token.exp(), iat: Some(token.created_at), nbf: Some(token.created_at), - sub: Some(user.sub), + sub: Some(session.user.sub), aud: None, iss: None, jti: None, } } + TokenType::CompatRefreshToken => { + todo!() + } }; Ok(Json(reply)) diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 94b8bb28..f42fd1a3 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -1074,6 +1074,104 @@ }, "query": "\n INSERT INTO oauth2_sessions\n (user_session_id, oauth2_client_id, scope)\n SELECT\n $1,\n og.oauth2_client_id,\n og.scope\n FROM\n oauth2_authorization_grants og\n WHERE\n og.id = $2\n RETURNING id, created_at\n " }, + "7d94b7b6ed2f68479adb6247880b32bc378790174a81a05dff50b92e9be15bf8": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "compat_access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_access_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "compat_access_token_expires_after", + "ordinal": 3, + "type_info": "Int4" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Int8" + }, + { + "name": "compat_session_created_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_deleted_at", + "ordinal": 6, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_device_id", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "user_id!", + "ordinal": 8, + "type_info": "Int8" + }, + { + "name": "user_username!", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "user_email_id?", + "ordinal": 10, + "type_info": "Int8" + }, + { + "name": "user_email?", + "ordinal": 11, + "type_info": "Text" + }, + { + "name": "user_email_created_at?", + "ordinal": 12, + "type_info": "Timestamptz" + }, + { + "name": "user_email_confirmed_at?", + "ordinal": 13, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false, + true, + false, + false, + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT\n ct.id AS \"compat_access_token_id\",\n ct.token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_after AS \"compat_access_token_expires_after\",\n cs.id AS \"compat_session_id\",\n cs.created_at AS \"compat_session_created_at\",\n cs.deleted_at AS \"compat_session_deleted_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n ON cs.id = ct.compat_session_id\n INNER JOIN users u\n ON u.id = cs.user_id\n LEFT JOIN user_emails ue\n ON ue.id = u.primary_email_id\n\n WHERE ct.token = $1\n AND cs.deleted_at IS NULL\n " + }, "7de9cfa6e90ba20f5b298ea387cf13a7e40d0f5b3eb903a80d06fbe33074d596": { "describe": { "columns": [ @@ -1875,86 +1973,6 @@ }, "query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.active\n " }, - "eb12b728e0d58f6bba1a20fbc9bd01f3a6cbae7e40961b39eac3b294609edf2f": { - "describe": { - "columns": [ - { - "name": "compat_access_token_id", - "ordinal": 0, - "type_info": "Int8" - }, - { - "name": "compat_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_deleted_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "user_id!", - "ordinal": 5, - "type_info": "Int8" - }, - { - "name": "user_username!", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 7, - "type_info": "Int8" - }, - { - "name": "user_email?", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 9, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 10, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n ct.id AS \"compat_access_token_id\",\n ct.token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n cs.deleted_at AS \"compat_session_deleted_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n ON cs.id = ct.compat_session_id\n INNER JOIN users u\n ON u.id = cs.user_id\n LEFT JOIN user_emails ue\n ON ue.id = u.primary_email_id\n\n WHERE ct.token = $1\n AND cs.deleted_at IS NULL\n " - }, "ebf73a609e81830b16700d2c315fffa93fd85b2886e29f234d9953b18a9f72b5": { "describe": { "columns": [], diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs index c4395d50..072836ea 100644 --- a/crates/storage/src/compat.rs +++ b/crates/storage/src/compat.rs @@ -14,8 +14,8 @@ use anyhow::Context; use argon2::{Argon2, PasswordHash}; -use chrono::{DateTime, Utc}; -use mas_data_model::{CompatAccessToken, User, UserEmail}; +use chrono::{DateTime, Duration, Utc}; +use mas_data_model::{CompatAccessToken, CompatSession, Device, User, UserEmail}; use sqlx::{Acquire, PgExecutor, Postgres}; use thiserror::Error; use tokio::task; @@ -28,7 +28,10 @@ use crate::{ pub struct CompatAccessTokenLookup { compat_access_token_id: i64, compat_access_token: String, + compat_access_token_expires_after: Option, compat_access_token_created_at: DateTime, + compat_session_id: i64, + compat_session_created_at: DateTime, compat_session_deleted_at: Option>, compat_session_device_id: String, user_id: i64, @@ -53,14 +56,14 @@ impl CompatAccessTokenLookupError { } } -#[tracing::instrument(skip(executor))] +#[tracing::instrument(skip(executor), err)] pub async fn lookup_active_compat_access_token( executor: impl PgExecutor<'_>, token: &str, ) -> Result< ( CompatAccessToken, - User, + CompatSession, ), CompatAccessTokenLookupError, > { @@ -71,6 +74,9 @@ pub async fn lookup_active_compat_access_token( ct.id AS "compat_access_token_id", ct.token AS "compat_access_token", ct.created_at AS "compat_access_token_created_at", + ct.expires_after AS "compat_access_token_expires_after", + cs.id AS "compat_session_id", + cs.created_at AS "compat_session_created_at", cs.deleted_at AS "compat_session_deleted_at", cs.device_id AS "compat_session_device_id", u.id AS "user_id!", @@ -101,8 +107,9 @@ pub async fn lookup_active_compat_access_token( data: res.compat_access_token_id, token: res.compat_access_token, created_at: res.compat_access_token_created_at, - deleted_at: res.compat_session_deleted_at, - device_id: res.compat_session_device_id, + expires_after: res + .compat_access_token_expires_after + .map(|d| Duration::seconds(d.into())), }; let primary_email = match ( @@ -128,20 +135,30 @@ pub async fn lookup_active_compat_access_token( primary_email, }; - Ok((token, user)) + let device = Device::try_from(res.compat_session_device_id).unwrap(); + + let session = CompatSession { + data: res.compat_session_id, + user, + device, + created_at: res.compat_session_created_at, + deleted_at: res.compat_session_deleted_at, + }; + + Ok((token, session)) } -#[tracing::instrument(skip(conn, password, token))] +#[tracing::instrument(skip(conn, password, token), err)] pub async fn compat_login( conn: impl Acquire<'_, Database = Postgres>, username: &str, password: &str, - device_id: String, + device: Device, token: String, ) -> Result< ( CompatAccessToken, - User, + CompatSession, ), anyhow::Error, > { @@ -176,7 +193,7 @@ pub async fn compat_login( .instrument(tracing::info_span!("Verify hashed password")) .await??; - let session = sqlx::query_as!( + let res = sqlx::query_as!( IdAndCreationTime, r#" INSERT INTO compat_sessions (user_id, device_id) @@ -184,12 +201,21 @@ pub async fn compat_login( RETURNING id, created_at "#, user.data, - device_id, + device.as_str(), ) .fetch_one(&mut txn) + .instrument(tracing::info_span!("Insert compat session")) .await .context("could not insert compat session")?; + let session = CompatSession { + data: res.id, + user, + device, + created_at: res.created_at, + deleted_at: None, + }; + let res = sqlx::query_as!( IdAndCreationTime, r#" @@ -197,10 +223,11 @@ pub async fn compat_login( VALUES ($1, $2) RETURNING id, created_at "#, - session.id, + session.data, token, ) .fetch_one(&mut txn) + .instrument(tracing::info_span!("Insert compat access token")) .await .context("could not insert compat access token")?; @@ -208,15 +235,14 @@ pub async fn compat_login( data: res.id, token, created_at: res.created_at, - deleted_at: None, - device_id, + expires_after: None, }; txn.commit().await.context("could not commit transaction")?; - Ok((token, user)) + Ok((token, session)) } -#[tracing::instrument(skip_all)] +#[tracing::instrument(skip_all, err)] pub async fn compat_logout( executor: impl PgExecutor<'_>, token: &str, diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index cd351f2f..b0b75f3b 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -43,6 +43,8 @@ impl StorageBackend for PostgresqlBackend { type BrowserSessionData = i64; type ClientData = i64; type CompatAccessTokenData = i64; + type CompatRefreshTokenData = i64; + type CompatSessionData = i64; type RefreshTokenData = i64; type SessionData = i64; type UserData = i64;