From 309c89fc4f15cdbf526d0c6e22af5fb37c4d371d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 18 May 2022 16:10:31 +0200 Subject: [PATCH] Handle legacy token expiration & refresh tokens --- crates/data-model/src/compat.rs | 116 ++++++ crates/data-model/src/lib.rs | 2 +- crates/handlers/src/compat/login.rs | 67 +++- crates/handlers/src/oauth2/introspection.rs | 37 +- .../20220512150806_compat_login.down.sql | 2 + .../20220512150806_compat_login.up.sql | 25 +- ...20517085913_compat_refresh_tokens.down.sql | 1 - ...0220517085913_compat_refresh_tokens.up.sql | 50 --- crates/storage/sqlx-data.json | 344 +++++++++++++----- crates/storage/src/compat.rs | 237 ++++++++++-- 10 files changed, 682 insertions(+), 199 deletions(-) create mode 100644 crates/data-model/src/compat.rs delete mode 100644 crates/storage/migrations/20220517085913_compat_refresh_tokens.down.sql delete mode 100644 crates/storage/migrations/20220517085913_compat_refresh_tokens.up.sql diff --git a/crates/data-model/src/compat.rs b/crates/data-model/src/compat.rs new file mode 100644 index 00000000..1fd6d09f --- /dev/null +++ b/crates/data-model/src/compat.rs @@ -0,0 +1,116 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use oauth2_types::scope::ScopeToken; +use rand::{ + distributions::{Alphanumeric, DistString}, + Rng, +}; +use serde::Serialize; +use thiserror::Error; + +use crate::{StorageBackend, StorageBackendMarker, User}; + +static DEVICE_ID_LENGTH: usize = 10; + +#[derive(Debug, Clone, PartialEq, Serialize)] +#[serde(transparent)] +pub struct Device { + id: String, +} + +#[derive(Debug, Error)] +pub enum InvalidDeviceID { + #[error("Device ID does not have the right size")] + InvalidLength, + + #[error("Device ID contains invalid characters")] + InvalidCharacters, +} + +impl Device { + /// Get the corresponding [`ScopeToken`] for that device + #[must_use] + pub fn to_scope_token(&self) -> ScopeToken { + // SAFETY: the inner id should only have valid scope characters + format!("urn:matrix:device:{}", self.id).parse().unwrap() + } + + /// Generate a random device ID + pub fn generate(rng: &mut R) -> Self { + let id: String = Alphanumeric.sample_string(rng, DEVICE_ID_LENGTH); + Self { id } + } + + /// Get the inner device ID as [`&str`] + #[must_use] + pub fn as_str(&self) -> &str { + &self.id + } +} + +impl TryFrom for Device { + type Error = InvalidDeviceID; + + /// Create a [`Device`] out of an ID, validating the ID has the right shape + fn try_from(id: String) -> Result { + if id.len() != DEVICE_ID_LENGTH { + return Err(InvalidDeviceID::InvalidLength); + } + + if !id.chars().all(|c| c.is_ascii_alphanumeric()) { + return Err(InvalidDeviceID::InvalidCharacters); + } + + Ok(Self { id }) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +#[serde(bound = "T: StorageBackend")] +pub struct CompatSession { + #[serde(skip_serializing)] + pub data: T::CompatSessionData, + pub user: User, + pub device: Device, + pub created_at: DateTime, + pub deleted_at: Option>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct CompatAccessToken { + pub data: T::CompatAccessTokenData, + pub token: String, + pub created_at: DateTime, + pub expires_at: Option>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct CompatRefreshToken { + pub data: T::RefreshTokenData, + pub token: String, + pub created_at: DateTime, +} + +impl From> for CompatAccessToken<()> { + fn from(t: CompatAccessToken) -> Self { + CompatAccessToken { + data: (), + token: t.token, + created_at: t.created_at, + expires_at: t.expires_at, + } + } +} diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 3793fbb6..6fd173c0 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -29,7 +29,7 @@ pub(crate) mod traits; pub(crate) mod users; pub use self::{ - compat::{CompatAccessToken, CompatSession, Device}, + compat::{CompatAccessToken, CompatRefreshToken, CompatSession, Device}, oauth2::{ AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, InvalidRedirectUriError, JwksOrJwksUri, Pkce, Session, diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 212d9464..3fe9fefc 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -13,12 +13,14 @@ // limitations under the License. use axum::{response::IntoResponse, Extension, Json}; +use chrono::Duration; use hyper::StatusCode; use mas_config::MatrixConfig; use mas_data_model::{Device, TokenType}; -use mas_storage::compat::compat_login; +use mas_storage::compat::{add_compat_access_token, add_compat_refresh_token, compat_login}; use rand::thread_rng; use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; use sqlx::PgPool; use thiserror::Error; @@ -44,9 +46,18 @@ pub(crate) async fn get() -> impl IntoResponse { Json(res) } +#[derive(Debug, Serialize, Deserialize)] +pub struct RequestBody { + #[serde(flatten)] + credentials: Credentials, + + #[serde(default)] + refresh_token: bool, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] -pub enum RequestBody { +pub enum Credentials { #[serde(rename = "m.login.password")] Password { identifier: Identifier, @@ -67,11 +78,16 @@ pub enum Identifier { Unsupported, } +#[skip_serializing_none] +#[serde_as] #[derive(Debug, Serialize)] pub struct ResponseBody { access_token: String, device_id: Device, user_id: String, + refresh_token: Option, + #[serde_as(as = "Option>")] + expires_in_ms: Option, } #[derive(Debug, Error)] @@ -79,6 +95,9 @@ pub enum RouteError { #[error(transparent)] Internal(Box), + #[error(transparent)] + Anyhow(#[from] anyhow::Error), + #[error("unsupported login method")] Unsupported, @@ -95,7 +114,7 @@ impl From for RouteError { impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { - Self::Internal(_e) => MatrixError { + Self::Internal(_) | Self::Anyhow(_) => MatrixError { errcode: "M_UNKNOWN", error: "Internal server error", status: StatusCode::INTERNAL_SERVER_ERROR, @@ -121,9 +140,8 @@ pub(crate) async fn post( Extension(config): Extension, Json(input): Json, ) -> Result { - let mut conn = pool.acquire().await?; - let (username, password) = match input { - RequestBody::Password { + let (username, password) = match input.credentials { + Credentials::Password { identifier: Identifier::User { user }, password, } => (user, password), @@ -132,22 +150,43 @@ pub(crate) async fn post( } }; - let (token, device) = { - let mut rng = thread_rng(); - let token = TokenType::CompatAccessToken.generate(&mut rng); - let device = Device::generate(&mut rng); - (token, device) - }; + let mut txn = pool.begin().await?; - let (token, session) = compat_login(&mut conn, &username, &password, device, token) + let device = Device::generate(&mut thread_rng()); + let session = compat_login(&mut txn, &username, &password, device) .await .map_err(|_| RouteError::LoginFailed)?; let user_id = format!("@{}:{}", session.user.username, config.homeserver); + // If the client asked for a refreshable token, make it expire + let expires_in = if input.refresh_token { + // TODO: this should be configurable + Some(Duration::minutes(5)) + } else { + None + }; + + let access_token = TokenType::CompatAccessToken.generate(&mut thread_rng()); + let access_token = + add_compat_access_token(&mut txn, &session, access_token, expires_in).await?; + + let refresh_token = if input.refresh_token { + let refresh_token = TokenType::CompatRefreshToken.generate(&mut thread_rng()); + let refresh_token = + add_compat_refresh_token(&mut txn, &session, &access_token, refresh_token).await?; + Some(refresh_token.token) + } else { + None + }; + + txn.commit().await?; + Ok(Json(ResponseBody { - access_token: token.token, + access_token: access_token.token, device_id: session.device, user_id, + refresh_token, + expires_in_ms: expires_in, })) } diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index c34512c1..4e7c537f 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -19,7 +19,10 @@ use mas_config::Encrypter; use mas_data_model::{TokenFormatError, TokenType}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_storage::{ - compat::{lookup_active_compat_access_token, CompatAccessTokenLookupError}, + compat::{ + lookup_active_compat_access_token, lookup_active_compat_refresh_token, + CompatAccessTokenLookupError, CompatRefreshTokenLookupError, + }, oauth2::{ access_token::{lookup_active_access_token, AccessTokenLookupError}, client::ClientFetchError, @@ -124,6 +127,16 @@ impl From for RouteError { } } +impl From for RouteError { + fn from(e: CompatRefreshTokenLookupError) -> Self { + if e.not_found() { + Self::UnknownToken + } else { + Self::Internal(Box::new(e)) + } + } +} + const INACTIVE: IntrospectionResponse = IntrospectionResponse { active: false, scope: None, @@ -225,7 +238,7 @@ pub(crate) async fn post( client_id: Some("legacy".into()), username: Some(session.user.username), token_type: Some(OAuthTokenTypeHint::AccessToken), - exp: token.exp(), + exp: token.expires_at, iat: Some(token.created_at), nbf: Some(token.created_at), sub: Some(session.user.sub), @@ -235,7 +248,25 @@ pub(crate) async fn post( } } TokenType::CompatRefreshToken => { - todo!() + let (token, session) = lookup_active_compat_refresh_token(&mut conn, token).await?; + + let device_scope = session.device.to_scope_token(); + let scope = [device_scope].into_iter().collect(); + + IntrospectionResponse { + active: true, + scope: Some(scope), + client_id: Some("legacy".into()), + username: Some(session.user.username), + token_type: Some(OAuthTokenTypeHint::RefreshToken), + exp: None, + iat: Some(token.created_at), + nbf: Some(token.created_at), + sub: Some(session.user.sub), + aud: None, + iss: None, + jti: None, + } } }; diff --git a/crates/storage/migrations/20220512150806_compat_login.down.sql b/crates/storage/migrations/20220512150806_compat_login.down.sql index 44f09300..d269544f 100644 --- a/crates/storage/migrations/20220512150806_compat_login.down.sql +++ b/crates/storage/migrations/20220512150806_compat_login.down.sql @@ -12,4 +12,6 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +DROP TABLE compat_refresh_tokens; DROP TABLE compat_access_tokens; +DROP TABLE compat_session; diff --git a/crates/storage/migrations/20220512150806_compat_login.up.sql b/crates/storage/migrations/20220512150806_compat_login.up.sql index cea31338..9b06e207 100644 --- a/crates/storage/migrations/20220512150806_compat_login.up.sql +++ b/crates/storage/migrations/20220512150806_compat_login.up.sql @@ -12,12 +12,31 @@ -- See the License for the specific language governing permissions and -- limitations under the License. -CREATE TABLE compat_access_tokens ( +CREATE TABLE compat_sessions ( "id" BIGSERIAL PRIMARY KEY, "user_id" BIGINT NOT NULL REFERENCES users (id) ON DELETE CASCADE, - "token" TEXT UNIQUE NOT NULL, "device_id" TEXT UNIQUE NOT NULL, "created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(), "deleted_at" TIMESTAMP WITH TIME ZONE -) +); + +CREATE TABLE compat_access_tokens ( + "id" BIGSERIAL PRIMARY KEY, + "compat_session_id" BIGINT NOT NULL REFERENCES compat_sessions (id) ON DELETE CASCADE, + "token" TEXT UNIQUE NOT NULL, + + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(), + "expires_at" TIMESTAMP WITH TIME ZONE +); + +CREATE TABLE compat_refresh_tokens ( + "id" BIGSERIAL PRIMARY KEY, + "compat_session_id" BIGINT NOT NULL REFERENCES compat_sessions (id) ON DELETE CASCADE, + "compat_access_token_id" BIGINT REFERENCES compat_access_tokens (id) ON DELETE SET NULL, + + "token" TEXT UNIQUE NOT NULL, + "next_token_id" BIGINT REFERENCES compat_refresh_tokens (id), + + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now() +); diff --git a/crates/storage/migrations/20220517085913_compat_refresh_tokens.down.sql b/crates/storage/migrations/20220517085913_compat_refresh_tokens.down.sql deleted file mode 100644 index d2f607c5..00000000 --- a/crates/storage/migrations/20220517085913_compat_refresh_tokens.down.sql +++ /dev/null @@ -1 +0,0 @@ --- Add down migration script here diff --git a/crates/storage/migrations/20220517085913_compat_refresh_tokens.up.sql b/crates/storage/migrations/20220517085913_compat_refresh_tokens.up.sql deleted file mode 100644 index a22df45f..00000000 --- a/crates/storage/migrations/20220517085913_compat_refresh_tokens.up.sql +++ /dev/null @@ -1,50 +0,0 @@ --- Copyright 2022 The Matrix.org Foundation C.I.C. --- --- Licensed under the Apache License, Version 2.0 (the "License"); --- you may not use this file except in compliance with the License. --- You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, software --- distributed under the License is distributed on an "AS IS" BASIS, --- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language governing permissions and --- limitations under the License. - -CREATE TABLE compat_sessions ( - "id" BIGSERIAL PRIMARY KEY, - "user_id" BIGINT NOT NULL REFERENCES users (id) ON DELETE CASCADE, - "device_id" TEXT UNIQUE NOT NULL, - - "created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(), - "deleted_at" TIMESTAMP WITH TIME ZONE -); - -INSERT INTO compat_sessions (user_id, device_id, created_at, deleted_at) - SELECT user_id, device_id, created_at, deleted_at - FROM compat_access_tokens; - -ALTER TABLE compat_access_tokens - ADD COLUMN "compat_session_id" BIGINT REFERENCES compat_sessions (id) ON DELETE CASCADE; - -UPDATE compat_access_tokens - SET compat_session_id = compat_sessions.id - FROM compat_sessions - WHERE compat_sessions.device_id = compat_access_tokens.device_id; - -ALTER TABLE compat_access_tokens - ALTER COLUMN "compat_session_id" SET NOT NULL, - DROP COLUMN "device_id", - DROP COLUMN "user_id", - DROP COLUMN "deleted_at", - ADD COLUMN "expires_after" INT; - -CREATE TABLE compat_refresh_tokens ( - "id" BIGSERIAL PRIMARY KEY, - "compat_session_id" BIGINT NOT NULL REFERENCES compat_sessions (id) ON DELETE CASCADE, - "compat_access_token_id" BIGINT REFERENCES compat_access_tokens (id) ON DELETE SET NULL, - "token" TEXT UNIQUE NOT NULL, - "next_token_id" BIGINT REFERENCES compat_refresh_tokens (id), - "created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now() -); diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index f42fd1a3..b2bba71c 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -699,6 +699,104 @@ }, "query": "\n INSERT INTO oauth2_clients\n (client_id,\n encrypted_client_secret,\n response_types,\n grant_type_authorization_code,\n grant_type_refresh_token,\n contacts,\n client_name,\n logo_uri,\n client_uri,\n policy_uri,\n tos_uri,\n jwks_uri,\n jwks,\n id_token_signed_response_alg,\n userinfo_signed_response_alg,\n token_endpoint_auth_method,\n token_endpoint_auth_signing_alg,\n initiate_login_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18)\n RETURNING id\n " }, + "5ee505120c3bfddccd7c933de356dd035d18d56316ddf4d0be0d13530b8a643c": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "compat_access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_access_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "compat_access_token_expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Int8" + }, + { + "name": "compat_session_created_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_deleted_at", + "ordinal": 6, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_device_id", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "user_id!", + "ordinal": 8, + "type_info": "Int8" + }, + { + "name": "user_username!", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "user_email_id?", + "ordinal": 10, + "type_info": "Int8" + }, + { + "name": "user_email?", + "ordinal": 11, + "type_info": "Text" + }, + { + "name": "user_email_created_at?", + "ordinal": 12, + "type_info": "Timestamptz" + }, + { + "name": "user_email_confirmed_at?", + "ordinal": 13, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false, + true, + false, + false, + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT\n ct.id AS \"compat_access_token_id\",\n ct.token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.id AS \"compat_session_id\",\n cs.created_at AS \"compat_session_created_at\",\n cs.deleted_at AS \"compat_session_deleted_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n ON cs.id = ct.compat_session_id\n INNER JOIN users u\n ON u.id = cs.user_id\n LEFT JOIN user_emails ue\n ON ue.id = u.primary_email_id\n\n WHERE ct.token = $1\n AND (ct.expires_at IS NULL OR ct.expires_at > NOW())\n AND cs.deleted_at IS NULL\n " + }, "647a2a5bbde39d0ed3931d0287b468bc7dedf6171e1dc6171a5d9f079b9ed0fa": { "describe": { "columns": [ @@ -1074,104 +1172,6 @@ }, "query": "\n INSERT INTO oauth2_sessions\n (user_session_id, oauth2_client_id, scope)\n SELECT\n $1,\n og.oauth2_client_id,\n og.scope\n FROM\n oauth2_authorization_grants og\n WHERE\n og.id = $2\n RETURNING id, created_at\n " }, - "7d94b7b6ed2f68479adb6247880b32bc378790174a81a05dff50b92e9be15bf8": { - "describe": { - "columns": [ - { - "name": "compat_access_token_id", - "ordinal": 0, - "type_info": "Int8" - }, - { - "name": "compat_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_expires_after", - "ordinal": 3, - "type_info": "Int4" - }, - { - "name": "compat_session_id", - "ordinal": 4, - "type_info": "Int8" - }, - { - "name": "compat_session_created_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_deleted_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_id!", - "ordinal": 8, - "type_info": "Int8" - }, - { - "name": "user_username!", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 10, - "type_info": "Int8" - }, - { - "name": "user_email?", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 12, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 13, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true, - false, - false, - true, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n ct.id AS \"compat_access_token_id\",\n ct.token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_after AS \"compat_access_token_expires_after\",\n cs.id AS \"compat_session_id\",\n cs.created_at AS \"compat_session_created_at\",\n cs.deleted_at AS \"compat_session_deleted_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n ON cs.id = ct.compat_session_id\n INNER JOIN users u\n ON u.id = cs.user_id\n LEFT JOIN user_emails ue\n ON ue.id = u.primary_email_id\n\n WHERE ct.token = $1\n AND cs.deleted_at IS NULL\n " - }, "7de9cfa6e90ba20f5b298ea387cf13a7e40d0f5b3eb903a80d06fbe33074d596": { "describe": { "columns": [ @@ -1231,6 +1231,34 @@ }, "query": "\n INSERT INTO compat_sessions (user_id, device_id)\n VALUES ($1, $2)\n RETURNING id, created_at\n " }, + "8aed8f0b7aec4854f8dfc88f43e3e6029ef563189eff6ed1e33c3421b395040c": { + "describe": { + "columns": [ + { + "name": "id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "created_at", + "ordinal": 1, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false + ], + "parameters": { + "Left": [ + "Int8", + "Text", + "Interval" + ] + } + }, + "query": "\n INSERT INTO compat_access_tokens (compat_session_id, token, created_at, expires_at)\n VALUES ($1, $2, NOW(), NOW() + $3)\n RETURNING id, created_at\n " + }, "9882e49f34dff80c1442565f035a1b47ed4dbae1a405f58cf2db198885bb9f47": { "describe": { "columns": [ @@ -1803,6 +1831,34 @@ }, "query": "\n SELECT \n ue.id AS \"user_email_id\",\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.email = $2\n " }, + "dbf9d2ee583d4dec07d7948c7540ff39b3e1de0c6abd168f47c02401f8417eec": { + "describe": { + "columns": [ + { + "name": "id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "created_at", + "ordinal": 1, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false + ], + "parameters": { + "Left": [ + "Int8", + "Int8", + "Text" + ] + } + }, + "query": "\n INSERT INTO compat_refresh_tokens (compat_session_id, compat_access_token_id, token)\n VALUES ($1, $2, $3)\n RETURNING id, created_at\n " + }, "dda03ba41249bff965cb8f129acc15f4e40807adb9b75dee0ac43edd7809de84": { "describe": { "columns": [ @@ -1982,5 +2038,97 @@ } }, "query": "TRUNCATE oauth2_client_redirect_uris, oauth2_clients RESTART IDENTITY CASCADE" + }, + "fc5d32bab9999ad383f906dbf20a45dafba1149e809155eccb4d94506ff6cf6f": { + "describe": { + "columns": [ + { + "name": "compat_refresh_token_id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "compat_refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_refresh_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 3, + "type_info": "Int8" + }, + { + "name": "compat_session_created_at", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_deleted_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_device_id", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "user_id!", + "ordinal": 7, + "type_info": "Int8" + }, + { + "name": "user_username!", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "user_email_id?", + "ordinal": 9, + "type_info": "Int8" + }, + { + "name": "user_email?", + "ordinal": 10, + "type_info": "Text" + }, + { + "name": "user_email_created_at?", + "ordinal": 11, + "type_info": "Timestamptz" + }, + { + "name": "user_email_confirmed_at?", + "ordinal": 12, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + true, + false, + false, + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT\n cr.id AS \"compat_refresh_token_id\",\n cr.token AS \"compat_refresh_token\",\n cr.created_at AS \"compat_refresh_token_created_at\",\n cs.id AS \"compat_session_id\",\n cs.created_at AS \"compat_session_created_at\",\n cs.deleted_at AS \"compat_session_deleted_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_refresh_tokens cr\n INNER JOIN compat_sessions cs\n ON cs.id = cr.compat_session_id\n INNER JOIN users u\n ON u.id = cs.user_id\n LEFT JOIN user_emails ue\n ON ue.id = u.primary_email_id\n\n WHERE cr.token = $1\n AND cs.deleted_at IS NULL\n " } } \ No newline at end of file diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs index 072836ea..43b9abf8 100644 --- a/crates/storage/src/compat.rs +++ b/crates/storage/src/compat.rs @@ -15,8 +15,10 @@ use anyhow::Context; use argon2::{Argon2, PasswordHash}; use chrono::{DateTime, Duration, Utc}; -use mas_data_model::{CompatAccessToken, CompatSession, Device, User, UserEmail}; -use sqlx::{Acquire, PgExecutor, Postgres}; +use mas_data_model::{ + CompatAccessToken, CompatRefreshToken, CompatSession, Device, User, UserEmail, +}; +use sqlx::{postgres::types::PgInterval, Acquire, PgExecutor, Postgres}; use thiserror::Error; use tokio::task; use tracing::{info_span, Instrument}; @@ -28,8 +30,8 @@ use crate::{ pub struct CompatAccessTokenLookup { compat_access_token_id: i64, compat_access_token: String, - compat_access_token_expires_after: Option, compat_access_token_created_at: DateTime, + compat_access_token_expires_at: Option>, compat_session_id: i64, compat_session_created_at: DateTime, compat_session_deleted_at: Option>, @@ -56,7 +58,7 @@ impl CompatAccessTokenLookupError { } } -#[tracing::instrument(skip(executor), err)] +#[tracing::instrument(skip_all, err)] pub async fn lookup_active_compat_access_token( executor: impl PgExecutor<'_>, token: &str, @@ -74,7 +76,7 @@ pub async fn lookup_active_compat_access_token( ct.id AS "compat_access_token_id", ct.token AS "compat_access_token", ct.created_at AS "compat_access_token_created_at", - ct.expires_after AS "compat_access_token_expires_after", + ct.expires_at AS "compat_access_token_expires_at", cs.id AS "compat_session_id", cs.created_at AS "compat_session_created_at", cs.deleted_at AS "compat_session_deleted_at", @@ -95,6 +97,7 @@ pub async fn lookup_active_compat_access_token( ON ue.id = u.primary_email_id WHERE ct.token = $1 + AND (ct.expires_at IS NULL OR ct.expires_at > NOW()) AND cs.deleted_at IS NULL "#, token, @@ -107,9 +110,7 @@ pub async fn lookup_active_compat_access_token( data: res.compat_access_token_id, token: res.compat_access_token, created_at: res.compat_access_token_created_at, - expires_after: res - .compat_access_token_expires_after - .map(|d| Duration::seconds(d.into())), + expires_at: res.compat_access_token_expires_at, }; let primary_email = match ( @@ -148,20 +149,131 @@ pub async fn lookup_active_compat_access_token( Ok((token, session)) } -#[tracing::instrument(skip(conn, password, token), err)] +pub struct CompatRefreshTokenLookup { + compat_refresh_token_id: i64, + compat_refresh_token: String, + compat_refresh_token_created_at: DateTime, + compat_session_id: i64, + compat_session_created_at: DateTime, + compat_session_deleted_at: Option>, + compat_session_device_id: String, + user_id: i64, + user_username: String, + user_email_id: Option, + user_email: Option, + user_email_created_at: Option>, + user_email_confirmed_at: Option>, +} + +#[derive(Debug, Error)] +#[error("failed to lookup compat refresh token")] +pub enum CompatRefreshTokenLookupError { + Database(#[from] sqlx::Error), + Inconsistency(#[from] DatabaseInconsistencyError), +} + +impl CompatRefreshTokenLookupError { + #[must_use] + pub fn not_found(&self) -> bool { + matches!(self, Self::Database(sqlx::Error::RowNotFound)) + } +} + +#[tracing::instrument(skip_all, err)] +pub async fn lookup_active_compat_refresh_token( + executor: impl PgExecutor<'_>, + token: &str, +) -> Result< + ( + CompatRefreshToken, + CompatSession, + ), + CompatRefreshTokenLookupError, +> { + let res = sqlx::query_as!( + CompatRefreshTokenLookup, + r#" + SELECT + cr.id AS "compat_refresh_token_id", + cr.token AS "compat_refresh_token", + cr.created_at AS "compat_refresh_token_created_at", + cs.id AS "compat_session_id", + cs.created_at AS "compat_session_created_at", + cs.deleted_at AS "compat_session_deleted_at", + cs.device_id AS "compat_session_device_id", + u.id AS "user_id!", + u.username AS "user_username!", + ue.id AS "user_email_id?", + ue.email AS "user_email?", + ue.created_at AS "user_email_created_at?", + ue.confirmed_at AS "user_email_confirmed_at?" + + FROM compat_refresh_tokens cr + INNER JOIN compat_sessions cs + ON cs.id = cr.compat_session_id + INNER JOIN users u + ON u.id = cs.user_id + LEFT JOIN user_emails ue + ON ue.id = u.primary_email_id + + WHERE cr.token = $1 + AND cs.deleted_at IS NULL + "#, + token, + ) + .fetch_one(executor) + .instrument(info_span!("Fetch compat refresh token")) + .await?; + + let token = CompatRefreshToken { + data: res.compat_refresh_token_id, + token: res.compat_refresh_token, + created_at: res.compat_refresh_token_created_at, + }; + + let primary_email = match ( + res.user_email_id, + res.user_email, + res.user_email_created_at, + res.user_email_confirmed_at, + ) { + (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { + data: id, + email, + created_at, + confirmed_at, + }), + (None, None, None, None) => None, + _ => return Err(DatabaseInconsistencyError.into()), + }; + + let user = User { + data: res.user_id, + username: res.user_username, + sub: format!("fake-sub-{}", res.user_id), + primary_email, + }; + + let device = Device::try_from(res.compat_session_device_id).unwrap(); + + let session = CompatSession { + data: res.compat_session_id, + user, + device, + created_at: res.compat_session_created_at, + deleted_at: res.compat_session_deleted_at, + }; + + Ok((token, session)) +} + +#[tracing::instrument(skip(conn, password), err)] pub async fn compat_login( conn: impl Acquire<'_, Database = Postgres>, username: &str, password: &str, device: Device, - token: String, -) -> Result< - ( - CompatAccessToken, - CompatSession, - ), - anyhow::Error, -> { +) -> Result, anyhow::Error> { let mut txn = conn.begin().await.context("could not start transaction")?; // First, lookup the user @@ -216,30 +328,97 @@ pub async fn compat_login( deleted_at: None, }; - let res = sqlx::query_as!( - IdAndCreationTime, - r#" + txn.commit().await.context("could not commit transaction")?; + Ok(session) +} + +#[tracing::instrument(skip(executor, token), err)] +pub async fn add_compat_access_token( + executor: impl PgExecutor<'_>, + session: &CompatSession, + token: String, + expires_after: Option, +) -> Result, anyhow::Error> { + if let Some(expires_after) = expires_after { + // For some reason, we need to convert the type first + let pg_expires_after = PgInterval::try_from(expires_after) + // For some reason, this error type does not let me to just bubble up the error here + .map_err(|e| anyhow::anyhow!("failed to encode duration: {}", e))?; + + let res = sqlx::query_as!( + IdAndCreationTime, + r#" + INSERT INTO compat_access_tokens (compat_session_id, token, created_at, expires_at) + VALUES ($1, $2, NOW(), NOW() + $3) + RETURNING id, created_at + "#, + session.data, + token, + pg_expires_after, + ) + .fetch_one(executor) + .instrument(tracing::info_span!("Insert compat access token")) + .await + .context("could not insert compat access token")?; + + Ok(CompatAccessToken { + data: res.id, + token, + created_at: res.created_at, + expires_at: Some(res.created_at + expires_after), + }) + } else { + let res = sqlx::query_as!( + IdAndCreationTime, + r#" INSERT INTO compat_access_tokens (compat_session_id, token) VALUES ($1, $2) RETURNING id, created_at + "#, + session.data, + token, + ) + .fetch_one(executor) + .instrument(tracing::info_span!("Insert compat access token")) + .await + .context("could not insert compat access token")?; + + Ok(CompatAccessToken { + data: res.id, + token, + created_at: res.created_at, + expires_at: None, + }) + } +} + +pub async fn add_compat_refresh_token( + executor: impl PgExecutor<'_>, + session: &CompatSession, + access_token: &CompatAccessToken, + token: String, +) -> Result, anyhow::Error> { + let res = sqlx::query_as!( + IdAndCreationTime, + r#" + INSERT INTO compat_refresh_tokens (compat_session_id, compat_access_token_id, token) + VALUES ($1, $2, $3) + RETURNING id, created_at "#, session.data, + access_token.data, token, ) - .fetch_one(&mut txn) - .instrument(tracing::info_span!("Insert compat access token")) + .fetch_one(executor) + .instrument(tracing::info_span!("Insert compat refresh token")) .await - .context("could not insert compat access token")?; + .context("could not insert compat refresh token")?; - let token = CompatAccessToken { + Ok(CompatRefreshToken { data: res.id, token, created_at: res.created_at, - expires_after: None, - }; - - txn.commit().await.context("could not commit transaction")?; - Ok((token, session)) + }) } #[tracing::instrument(skip_all, err)]