diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index a7b9d027..67c3f609 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -224,7 +224,7 @@ pub enum ClientAuthorizationError { MissingCredentials, InvalidRequest, InvalidAssertion, - InternalError(Box), + Internal(Box), } impl IntoResponse for ClientAuthorizationError { @@ -289,7 +289,7 @@ where return Err(ClientAuthorizationError::BadForm(err)) } // Other errors (body read twice, byte stream broke) return an internal error - Err(e) => return Err(ClientAuthorizationError::InternalError(Box::new(e))), + Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))), }; // And now, figure out the actual auth method diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index 2c14fc63..576f0ae7 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -104,7 +104,7 @@ pub enum UserAuthorizationError { InvalidHeader, TokenInFormAndHeader, BadForm(FailedToDeserializeForm), - InternalError(Box), + Internal(Box), } #[derive(Debug, Error)] @@ -119,7 +119,7 @@ pub enum AuthorizationVerificationError { MissingForm, #[error(transparent)] - InternalError(Box), + Internal(Box), } impl From for AuthorizationVerificationError { @@ -127,7 +127,7 @@ impl From for AuthorizationVerificationError { if e.not_found() { Self::InvalidToken } else { - Self::InternalError(Box::new(e)) + Self::Internal(Box::new(e)) } } } @@ -232,9 +232,7 @@ impl IntoResponse for UserAuthorizationError { }); (StatusCode::BAD_REQUEST, headers).into_response() } - Self::InternalError(e) => { - (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() - } + Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), } } } @@ -262,9 +260,7 @@ impl IntoResponse for AuthorizationVerificationError { }); (StatusCode::BAD_REQUEST, headers).into_response() } - Self::InternalError(e) => { - (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() - } + Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), } } } @@ -309,7 +305,7 @@ where return Err(UserAuthorizationError::BadForm(err)) } // Other errors (body read twice, byte stream broke) return an internal error - Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))), + Err(e) => return Err(UserAuthorizationError::Internal(Box::new(e))), }; let access_token = match (token_from_header, token_from_form) { diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index d9386cdd..279f401d 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -20,9 +20,8 @@ use mas_storage::{ compat::{ add_compat_access_token, add_compat_refresh_token, compat_login, get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged, - CompatSsoLoginLookupError, }, - Clock, LookupError, + Clock, }; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; @@ -30,6 +29,7 @@ use sqlx::{PgPool, Postgres, Transaction}; use thiserror::Error; use super::{MatrixError, MatrixHomeserver}; +use crate::impl_from_error_for_route; #[derive(Debug, Serialize)] #[serde(tag = "type")] @@ -145,21 +145,8 @@ pub enum RouteError { InvalidLoginToken, } -impl From for RouteError { - fn from(e: sqlx::Error) -> Self { - Self::Internal(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: CompatSsoLoginLookupError) -> Self { - if e.not_found() { - Self::InvalidLoginToken - } else { - Self::Internal(Box::new(e)) - } - } -} +impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -268,7 +255,9 @@ async fn token_login( clock: &Clock, token: &str, ) -> Result { - let login = get_compat_sso_login_by_token(&mut *txn, token).await?; + let login = get_compat_sso_login_by_token(&mut *txn, token) + .await? + .ok_or(RouteError::InvalidLoginToken)?; let now = clock.now(); match login.state { diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 43602aca..0cbd6fd8 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -15,6 +15,7 @@ use std::collections::HashMap; +use anyhow::Context; use axum::{ extract::{Form, Path, Query, State}, response::{Html, IntoResponse, Redirect, Response}, @@ -92,7 +93,9 @@ pub async fn get( return Ok((cookie_jar, destination.go()).into_response()); } - let login = get_compat_sso_login_by_id(&mut conn, id).await?; + let login = get_compat_sso_login_by_id(&mut conn, id) + .await? + .context("Could not find compat SSO login")?; // Bail out if that login session is more than 30min old if clock.now() > login.created_at + Duration::minutes(30) { @@ -158,7 +161,9 @@ pub async fn post( return Ok((cookie_jar, destination.go()).into_response()); } - let login = get_compat_sso_login_by_id(&mut txn, id).await?; + let login = get_compat_sso_login_by_id(&mut txn, id) + .await? + .context("Could not find compat SSO login")?; // Bail out if that login session is more than 30min old if clock.now() > login.created_at + Duration::minutes(30) { diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs index 43e4be01..e358d595 100644 --- a/crates/handlers/src/compat/refresh.rs +++ b/crates/handlers/src/compat/refresh.rs @@ -16,13 +16,9 @@ use axum::{extract::State, response::IntoResponse, Json}; use chrono::Duration; use hyper::StatusCode; use mas_data_model::{TokenFormatError, TokenType}; -use mas_storage::{ - compat::{ - add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token, - expire_compat_access_token, lookup_active_compat_refresh_token, - CompatRefreshTokenLookupError, - }, - LookupError, +use mas_storage::compat::{ + add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token, + expire_compat_access_token, lookup_active_compat_refresh_token, }; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DurationMilliSeconds}; @@ -30,6 +26,7 @@ use sqlx::PgPool; use thiserror::Error; use super::MatrixError; +use crate::impl_from_error_for_route; #[derive(Debug, Deserialize)] pub struct RequestBody { @@ -66,11 +63,8 @@ impl IntoResponse for RouteError { } } -impl From for RouteError { - fn from(e: sqlx::Error) -> Self { - Self::Internal(Box::new(e)) - } -} +impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::DatabaseError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { @@ -78,16 +72,6 @@ impl From for RouteError { } } -impl From for RouteError { - fn from(e: CompatRefreshTokenLookupError) -> Self { - if e.not_found() { - Self::InvalidToken - } else { - Self::Internal(Box::new(e)) - } - } -} - #[serde_as] #[derive(Debug, Serialize)] pub struct ResponseBody { @@ -111,7 +95,9 @@ pub(crate) async fn post( } let (refresh_token, access_token, session) = - lookup_active_compat_refresh_token(&mut txn, &input.refresh_token).await?; + lookup_active_compat_refresh_token(&mut txn, &input.refresh_token) + .await? + .ok_or(RouteError::InvalidToken)?; let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng); let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng); diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 9c00b8f8..05927f48 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -62,7 +62,7 @@ macro_rules! impl_from_error_for_route { ($error:ty) => { impl From<$error> for self::RouteError { fn from(e: $error) -> Self { - Self::InternalError(Box::new(e)) + Self::Internal(Box::new(e)) } } }; diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 2a122ec6..993da628 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -22,10 +22,7 @@ use mas_data_model::{TokenFormatError, TokenType}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_keystore::Encrypter; use mas_storage::{ - compat::{ - lookup_active_compat_access_token, lookup_active_compat_refresh_token, - CompatAccessTokenLookupError, CompatRefreshTokenLookupError, - }, + compat::{lookup_active_compat_access_token, lookup_active_compat_refresh_token}, oauth2::{ access_token::{lookup_active_access_token, AccessTokenLookupError}, client::ClientFetchError, @@ -37,6 +34,8 @@ use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse}; use sqlx::PgPool; use thiserror::Error; +use crate::impl_from_error_for_route; + #[derive(Debug, Error)] pub enum RouteError { #[error(transparent)] @@ -79,11 +78,8 @@ impl IntoResponse for RouteError { } } -impl From for RouteError { - fn from(e: sqlx::Error) -> Self { - Self::Internal(Box::new(e)) - } -} +impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::DatabaseError); impl From for RouteError { fn from(_e: TokenFormatError) -> Self { @@ -111,16 +107,6 @@ impl From for RouteError { } } -impl From for RouteError { - fn from(e: CompatAccessTokenLookupError) -> Self { - if e.not_found() { - Self::UnknownToken - } else { - Self::Internal(Box::new(e)) - } - } -} - impl From for RouteError { fn from(e: RefreshTokenLookupError) -> Self { if e.not_found() { @@ -131,16 +117,6 @@ impl From for RouteError { } } -impl From for RouteError { - fn from(e: CompatRefreshTokenLookupError) -> Self { - if e.not_found() { - Self::UnknownToken - } else { - Self::Internal(Box::new(e)) - } - } -} - const INACTIVE: IntrospectionResponse = IntrospectionResponse { active: false, scope: None, @@ -232,8 +208,9 @@ pub(crate) async fn post( } } TokenType::CompatAccessToken => { - let (token, session) = - lookup_active_compat_access_token(&mut conn, &clock, token).await?; + let (token, session) = lookup_active_compat_access_token(&mut conn, &clock, token) + .await? + .ok_or(RouteError::UnknownToken)?; let device_scope = session.device.to_scope_token(); let scope = [device_scope].into_iter().collect(); @@ -255,7 +232,9 @@ pub(crate) async fn post( } TokenType::CompatRefreshToken => { let (refresh_token, _access_token, session) = - lookup_active_compat_refresh_token(&mut conn, token).await?; + lookup_active_compat_refresh_token(&mut conn, token) + .await? + .ok_or(RouteError::UnknownToken)?; let device_scope = session.device.to_scope_token(); let scope = [device_scope].into_iter().collect(); diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index ae7dddd1..3f605957 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -36,7 +36,7 @@ pub(crate) enum RouteError { ProviderNotFound, #[error(transparent)] - InternalError(Box), + Internal(Box), #[error(transparent)] Anyhow(#[from] anyhow::Error), @@ -52,9 +52,7 @@ impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(), - Self::InternalError(e) => { - (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() - } + Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), Self::Anyhow(e) => { (StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response() } diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 7b01945c..2fdfb45d 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -90,7 +90,7 @@ pub(crate) enum RouteError { MissingCookie, #[error(transparent)] - InternalError(Box), + Internal(Box), #[error(transparent)] Anyhow(#[from] anyhow::Error), @@ -111,9 +111,7 @@ impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(), - Self::InternalError(e) => { - (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() - } + Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), Self::Anyhow(e) => { (StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response() } diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 36c1d078..d82c51fc 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -66,7 +66,7 @@ pub(crate) enum RouteError { InvalidFormAction, #[error(transparent)] - InternalError(Box), + Internal(Box), #[error(transparent)] Anyhow(#[from] anyhow::Error), @@ -85,9 +85,7 @@ impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(), - Self::InternalError(e) => { - (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() - } + Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), Self::Anyhow(e) => { (StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response() } diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index ab759885..693fb4b4 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -51,7 +51,9 @@ impl OptionalPostAuthAction { } PostAuthAction::ContinueCompatSsoLogin { data } => { - let login = get_compat_sso_login_by_id(conn, data).await?; + let login = get_compat_sso_login_by_id(conn, data) + .await? + .context("Failed to load compat SSO login")?; let login = Box::new(login); PostAuthContextInner::ContinueCompatSsoLogin { login } } diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 83132760..cc8925f2 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -1,103 +1,5 @@ { "db": "PostgreSQL", - "0157f14a089d100bdfe245e51082526326b2f84b11da7901ca6c0aaae9e43efd": { - "describe": { - "columns": [ - { - "name": "compat_access_token_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "compat_access_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_access_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "compat_access_token_expires_at", - "ordinal": 3, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 4, - "type_info": "Uuid" - }, - { - "name": "compat_session_created_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_finished_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "user_id!", - "ordinal": 8, - "type_info": "Uuid" - }, - { - "name": "user_username!", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 10, - "type_info": "Uuid" - }, - { - "name": "user_email?", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 12, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 13, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true, - false, - false, - true, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE ct.access_token = $1 AND cs.finished_at IS NULL\n " - }, "05b50b7ae0109063c50fe70e83635a31920e44a7fbaa2b4f07552ba2f83a28d7": { "describe": { "columns": [ @@ -2169,6 +2071,105 @@ }, "query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n " }, + "a0ef64e3de97dc2d24efe235c289557018448957a4776197445eafec8b5fb7a9": { + "describe": { + "columns": [ + { + "name": "compat_access_token_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "compat_access_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_access_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "compat_access_token_expires_at", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 4, + "type_info": "Uuid" + }, + { + "name": "compat_session_created_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_finished_at", + "ordinal": 6, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_device_id", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "user_id!", + "ordinal": 8, + "type_info": "Uuid" + }, + { + "name": "user_username!", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "user_email_id?", + "ordinal": 10, + "type_info": "Uuid" + }, + { + "name": "user_email?", + "ordinal": 11, + "type_info": "Text" + }, + { + "name": "user_email_created_at?", + "ordinal": 12, + "type_info": "Timestamptz" + }, + { + "name": "user_email_confirmed_at?", + "ordinal": 13, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false, + true, + false, + false, + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Text", + "Timestamptz" + ] + } + }, + "query": "\n SELECT\n ct.compat_access_token_id,\n ct.access_token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.compat_session_id,\n cs.created_at AS \"compat_session_created_at\",\n cs.finished_at AS \"compat_session_finished_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.user_id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.user_email_id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_access_tokens ct\n INNER JOIN compat_sessions cs\n USING (compat_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_emails ue\n ON ue.user_email_id = u.primary_user_email_id\n\n WHERE ct.access_token = $1\n AND ct.expires_at < $2\n AND cs.finished_at IS NULL \n " + }, "a5a7dad633396e087239d5629092e4a305908ffce9c2610db07372f719070546": { "describe": { "columns": [], diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs index d6c31cc8..0e5176a0 100644 --- a/crates/storage/src/compat.rs +++ b/crates/storage/src/compat.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use anyhow::{bail, Context}; +use anyhow::Context; use argon2::{Argon2, PasswordHash}; use chrono::{DateTime, Duration, Utc}; use mas_data_model::{ @@ -21,7 +21,6 @@ use mas_data_model::{ }; use rand::Rng; use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder}; -use thiserror::Error; use tokio::task; use tracing::{info_span, Instrument}; use ulid::Ulid; @@ -31,7 +30,7 @@ use uuid::Uuid; use crate::{ pagination::{process_page, QueryBuilderExt}, user::lookup_user_by_username, - Clock, DatabaseInconsistencyError, LookupError, + Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt, }; struct CompatAccessTokenLookup { @@ -51,29 +50,12 @@ struct CompatAccessTokenLookup { user_email_confirmed_at: Option>, } -#[derive(Debug, Error)] -#[error("failed to lookup compat access token")] -pub enum CompatAccessTokenLookupError { - Expired { when: DateTime }, - Database(#[from] sqlx::Error), - Inconsistency(#[from] DatabaseInconsistencyError), -} - -impl LookupError for CompatAccessTokenLookupError { - fn not_found(&self) -> bool { - matches!( - self, - Self::Database(sqlx::Error::RowNotFound) | Self::Expired { .. } - ) - } -} - #[tracing::instrument(skip_all, err)] pub async fn lookup_active_compat_access_token( executor: impl PgExecutor<'_>, clock: &Clock, token: &str, -) -> Result<(CompatAccessToken, CompatSession), CompatAccessTokenLookupError> { +) -> Result, DatabaseError> { let res = sqlx::query_as!( CompatAccessTokenLookup, r#" @@ -101,20 +83,19 @@ pub async fn lookup_active_compat_access_token( LEFT JOIN user_emails ue ON ue.user_email_id = u.primary_user_email_id - WHERE ct.access_token = $1 AND cs.finished_at IS NULL + WHERE ct.access_token = $1 + AND ct.expires_at < $2 + AND cs.finished_at IS NULL "#, token, + clock.now(), ) .fetch_one(executor) .instrument(info_span!("Fetch compat access token")) - .await?; + .await + .to_option()?; - // Check for token expiration - if let Some(expires_at) = res.compat_access_token_expires_at { - if expires_at < clock.now() { - return Err(CompatAccessTokenLookupError::Expired { when: expires_at }); - } - } + let Some(res) = res else { return Ok(None) }; let token = CompatAccessToken { id: res.compat_access_token_id.into(), @@ -123,6 +104,7 @@ pub async fn lookup_active_compat_access_token( expires_at: res.compat_access_token_expires_at, }; + let user_id = Ulid::from(res.user_id); let primary_email = match ( res.user_email_id, res.user_email, @@ -136,28 +118,38 @@ pub async fn lookup_active_compat_access_token( confirmed_at, }), (None, None, None, None) => None, - _ => return Err(DatabaseInconsistencyError.into()), + _ => { + return Err(DatabaseInconsistencyError2::on("compat_sessions") + .column("user_id") + .row(user_id) + .into()) + } }; - let id = Ulid::from(res.user_id); let user = User { - id, + id: user_id, username: res.user_username, - sub: id.to_string(), + sub: user_id.to_string(), primary_email, }; - let device = Device::try_from(res.compat_session_device_id).unwrap(); + let id = res.compat_session_id.into(); + let device = Device::try_from(res.compat_session_device_id).map_err(|e| { + DatabaseInconsistencyError2::on("compat_sessions") + .column("device_id") + .row(id) + .source(e) + })?; let session = CompatSession { - id: res.compat_session_id.into(), + id, user, device, created_at: res.compat_session_created_at, finished_at: res.compat_session_finished_at, }; - Ok((token, session)) + Ok(Some((token, session))) } pub struct CompatRefreshTokenLookup { @@ -180,25 +172,12 @@ pub struct CompatRefreshTokenLookup { user_email_confirmed_at: Option>, } -#[derive(Debug, Error)] -#[error("failed to lookup compat refresh token")] -pub enum CompatRefreshTokenLookupError { - Database(#[from] sqlx::Error), - Inconsistency(#[from] DatabaseInconsistencyError), -} - -impl LookupError for CompatRefreshTokenLookupError { - fn not_found(&self) -> bool { - matches!(self, Self::Database(sqlx::Error::RowNotFound)) - } -} - #[tracing::instrument(skip_all, err)] #[allow(clippy::type_complexity)] pub async fn lookup_active_compat_refresh_token( executor: impl PgExecutor<'_>, token: &str, -) -> Result<(CompatRefreshToken, CompatAccessToken, CompatSession), CompatRefreshTokenLookupError> { +) -> Result, DatabaseError> { let res = sqlx::query_as!( CompatRefreshTokenLookup, r#" @@ -239,7 +218,10 @@ pub async fn lookup_active_compat_refresh_token( ) .fetch_one(executor) .instrument(info_span!("Fetch compat refresh token")) - .await?; + .await + .to_option()?; + + let Some(res) = res else { return Ok(None); }; let refresh_token = CompatRefreshToken { id: res.compat_refresh_token_id.into(), @@ -254,6 +236,7 @@ pub async fn lookup_active_compat_refresh_token( expires_at: res.compat_access_token_expires_at, }; + let user_id = Ulid::from(res.user_id); let primary_email = match ( res.user_email_id, res.user_email, @@ -267,28 +250,38 @@ pub async fn lookup_active_compat_refresh_token( confirmed_at, }), (None, None, None, None) => None, - _ => return Err(DatabaseInconsistencyError.into()), + _ => { + return Err(DatabaseInconsistencyError2::on("users") + .column("primary_user_email_id") + .row(user_id) + .into()) + } }; - let id = Ulid::from(res.user_id); let user = User { - id, + id: user_id, username: res.user_username, - sub: id.to_string(), + sub: user_id.to_string(), primary_email, }; - let device = Device::try_from(res.compat_session_device_id).unwrap(); + let session_id = res.compat_session_id.into(); + let device = Device::try_from(res.compat_session_device_id).map_err(|e| { + DatabaseInconsistencyError2::on("compat_sessions") + .column("device_id") + .row(session_id) + .source(e) + })?; let session = CompatSession { - id: res.compat_session_id.into(), + id: session_id, user, device, created_at: res.compat_session_created_at, finished_at: res.compat_session_finished_at, }; - Ok((refresh_token, access_token, session)) + Ok(Some((refresh_token, access_token, session))) } #[tracing::instrument( @@ -299,7 +292,7 @@ pub async fn lookup_active_compat_refresh_token( compat_session.id, compat_session.device.id = device.as_str(), ), - err(Display), + err(Debug), )] pub async fn compat_login( conn: impl Acquire<'_, Database = Postgres> + Send, @@ -309,6 +302,7 @@ pub async fn compat_login( password: &str, device: Device, ) -> Result { + // TODO: that should be split and not verify the password hash here let mut txn = conn.begin().await.context("could not start transaction")?; // First, lookup the user @@ -381,7 +375,7 @@ pub async fn compat_login( compat_access_token.id, user.id = %session.user.id, ), - err(Display), + err, )] pub async fn add_compat_access_token( executor: impl PgExecutor<'_>, @@ -390,7 +384,7 @@ pub async fn add_compat_access_token( session: &CompatSession, token: String, expires_after: Option, -) -> Result { +) -> Result { let created_at = clock.now(); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); @@ -411,8 +405,7 @@ pub async fn add_compat_access_token( ) .execute(executor) .instrument(tracing::info_span!("Insert compat access token")) - .await - .context("could not insert compat access token")?; + .await?; Ok(CompatAccessToken { id, @@ -427,13 +420,13 @@ pub async fn add_compat_access_token( fields( compat_access_token.id = %access_token.id, ), - err(Display), + err, )] pub async fn expire_compat_access_token( executor: impl PgExecutor<'_>, clock: &Clock, access_token: CompatAccessToken, -) -> Result<(), anyhow::Error> { +) -> Result<(), DatabaseError> { let expires_at = clock.now(); let res = sqlx::query!( r#" @@ -445,16 +438,9 @@ pub async fn expire_compat_access_token( expires_at, ) .execute(executor) - .await - .context("failed to update compat access token")?; + .await?; - if res.rows_affected() == 1 { - Ok(()) - } else { - Err(anyhow::anyhow!( - "no row were affected when updating access token" - )) - } + DatabaseError::ensure_affected_rows(&res, 1) } #[tracing::instrument( @@ -466,7 +452,7 @@ pub async fn expire_compat_access_token( compat_refresh_token.id, user.id = %session.user.id, ), - err(Display), + err, )] pub async fn add_compat_refresh_token( executor: impl PgExecutor<'_>, @@ -475,7 +461,7 @@ pub async fn add_compat_refresh_token( session: &CompatSession, access_token: &CompatAccessToken, token: String, -) -> Result { +) -> Result { let created_at = clock.now(); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); @@ -495,8 +481,7 @@ pub async fn add_compat_refresh_token( ) .execute(executor) .instrument(tracing::info_span!("Insert compat refresh token")) - .await - .context("could not insert compat refresh token")?; + .await?; Ok(CompatRefreshToken { id, @@ -508,13 +493,13 @@ pub async fn add_compat_refresh_token( #[tracing::instrument( skip_all, fields(compat_session.id), - err(Display), + err, )] pub async fn compat_logout( executor: impl PgExecutor<'_>, clock: &Clock, token: &str, -) -> Result<(), anyhow::Error> { +) -> Result<(), sqlx::Error> { let finished_at = clock.now(); // TODO: this does not check for token expiration let compat_session_id = sqlx::query_scalar!( @@ -531,8 +516,7 @@ pub async fn compat_logout( finished_at, ) .fetch_one(executor) - .await - .context("could not update compat access token")?; + .await?; tracing::Span::current().record( "compat_session.id", @@ -547,13 +531,13 @@ pub async fn compat_logout( fields( compat_refresh_token.id = %refresh_token.id, ), - err(Display), + err, )] pub async fn consume_compat_refresh_token( executor: impl PgExecutor<'_>, clock: &Clock, refresh_token: CompatRefreshToken, -) -> Result<(), anyhow::Error> { +) -> Result<(), DatabaseError> { let consumed_at = clock.now(); let res = sqlx::query!( r#" @@ -565,16 +549,9 @@ pub async fn consume_compat_refresh_token( consumed_at, ) .execute(executor) - .await - .context("failed to update compat refresh token")?; + .await?; - if res.rows_affected() == 1 { - Ok(()) - } else { - Err(anyhow::anyhow!( - "no row were affected when updating refresh token" - )) - } + DatabaseError::ensure_affected_rows(&res, 1) } #[tracing::instrument( @@ -583,7 +560,7 @@ pub async fn consume_compat_refresh_token( compat_sso_login.id, compat_sso_login.redirect_uri = %redirect_uri, ), - err(Display), + err, )] pub async fn insert_compat_sso_login( executor: impl PgExecutor<'_>, @@ -591,7 +568,7 @@ pub async fn insert_compat_sso_login( clock: &Clock, login_token: String, redirect_uri: Url, -) -> Result { +) -> Result { let created_at = clock.now(); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); @@ -609,8 +586,7 @@ pub async fn insert_compat_sso_login( ) .execute(executor) .instrument(tracing::info_span!("Insert compat SSO login")) - .await - .context("could not insert compat SSO login")?; + .await?; Ok(CompatSsoLogin { id, @@ -642,11 +618,16 @@ struct CompatSsoLoginLookup { } impl TryFrom for CompatSsoLogin { - type Error = DatabaseInconsistencyError; + type Error = DatabaseInconsistencyError2; fn try_from(res: CompatSsoLoginLookup) -> Result { - let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri) - .map_err(|_| DatabaseInconsistencyError)?; + let id = res.compat_sso_login_id.into(); + let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri).map_err(|e| { + DatabaseInconsistencyError2::on("compat_sso_logins") + .column("redirect_uri") + .row(id) + .source(e) + })?; let primary_email = match ( res.user_email_id, @@ -661,7 +642,9 @@ impl TryFrom for CompatSsoLogin { confirmed_at, }), (None, None, None, None) => None, - _ => return Err(DatabaseInconsistencyError), + _ => { + return Err(DatabaseInconsistencyError2::on("users").column("primary_user_email_id")) + } }; let user = match (res.user_id, res.user_username, primary_email) { @@ -676,7 +659,7 @@ impl TryFrom for CompatSsoLogin { } (None, None, None) => None, - _ => return Err(DatabaseInconsistencyError), + _ => return Err(DatabaseInconsistencyError2::on("compat_sessions").column("user_id")), }; let session = match ( @@ -687,9 +670,15 @@ impl TryFrom for CompatSsoLogin { user, ) { (Some(id), Some(device_id), Some(created_at), finished_at, Some(user)) => { - let device = Device::try_from(device_id).map_err(|_| DatabaseInconsistencyError)?; + let id = id.into(); + let device = Device::try_from(device_id).map_err(|e| { + DatabaseInconsistencyError2::on("compat_sessions") + .column("device") + .row(id) + .source(e) + })?; Some(CompatSession { - id: id.into(), + id, user, device, created_at, @@ -697,7 +686,11 @@ impl TryFrom for CompatSsoLogin { }) } (None, None, None, None, None) => None, - _ => return Err(DatabaseInconsistencyError), + _ => { + return Err(DatabaseInconsistencyError2::on("compat_sso_logins") + .column("compat_session_id") + .row(id)) + } }; let state = match ( @@ -717,11 +710,11 @@ impl TryFrom for CompatSsoLogin { session, } } - _ => return Err(DatabaseInconsistencyError), + _ => return Err(DatabaseInconsistencyError2::on("compat_sso_logins").row(id)), }; Ok(CompatSsoLogin { - id: res.compat_sso_login_id.into(), + id, login_token: res.compat_sso_login_token, redirect_uri, created_at: res.compat_sso_login_created_at, @@ -730,19 +723,6 @@ impl TryFrom for CompatSsoLogin { } } -#[derive(Debug, Error)] -#[error("failed to lookup compat SSO login")] -pub enum CompatSsoLoginLookupError { - Database(#[from] sqlx::Error), - Inconsistency(#[from] DatabaseInconsistencyError), -} - -impl LookupError for CompatSsoLoginLookupError { - fn not_found(&self) -> bool { - matches!(self, Self::Database(sqlx::Error::RowNotFound)) - } -} - #[tracing::instrument( skip_all, fields( @@ -753,7 +733,7 @@ impl LookupError for CompatSsoLoginLookupError { pub async fn get_compat_sso_login_by_id( executor: impl PgExecutor<'_>, id: Ulid, -) -> Result { +) -> Result, DatabaseError> { let res = sqlx::query_as!( CompatSsoLoginLookup, r#" @@ -787,9 +767,12 @@ pub async fn get_compat_sso_login_by_id( ) .fetch_one(executor) .instrument(tracing::info_span!("Lookup compat SSO login")) - .await?; + .await + .to_option()?; - Ok(res.try_into()?) + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) } #[tracing::instrument( @@ -798,7 +781,7 @@ pub async fn get_compat_sso_login_by_id( %user.id, %user.username, ), - err(Display), + err, )] pub async fn get_paginated_user_compat_sso_logins( executor: impl PgExecutor<'_>, @@ -807,7 +790,7 @@ pub async fn get_paginated_user_compat_sso_logins( after: Option, first: Option, last: Option, -) -> Result<(bool, bool, Vec), anyhow::Error> { +) -> Result<(bool, bool, Vec), DatabaseError> { // TODO: this queries too much (like user info) which we probably don't need // because we already have them let mut query = QueryBuilder::new( @@ -864,7 +847,7 @@ pub async fn get_paginated_user_compat_sso_logins( pub async fn get_compat_sso_login_by_token( executor: impl PgExecutor<'_>, token: &str, -) -> Result { +) -> Result, DatabaseError> { let res = sqlx::query_as!( CompatSsoLoginLookup, r#" @@ -898,35 +881,38 @@ pub async fn get_compat_sso_login_by_token( ) .fetch_one(executor) .instrument(tracing::info_span!("Lookup compat SSO login")) - .await?; + .await + .to_option()?; - Ok(res.try_into()?) + let Some(res) = res else { return Ok(None) }; + + Ok(Some(res.try_into()?)) } #[tracing::instrument( skip_all, fields( %user.id, - compat_sso_login.id = %login.id, - compat_sso_login.redirect_uri = %login.redirect_uri, + %compat_sso_login.id, + %compat_sso_login.redirect_uri, compat_session.id, compat_session.device.id = device.as_str(), ), - err(Display), + err, )] pub async fn fullfill_compat_sso_login( conn: impl Acquire<'_, Database = Postgres> + Send, mut rng: impl Rng + Send, clock: &Clock, user: User, - mut login: CompatSsoLogin, + mut compat_sso_login: CompatSsoLogin, device: Device, -) -> Result { - if !matches!(login.state, CompatSsoLoginState::Pending) { - bail!("sso login in wrong state"); +) -> Result { + if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) { + return Err(DatabaseError::InvalidOperation); }; - let mut txn = conn.begin().await.context("could not start transaction")?; + let mut txn = conn.begin().await?; let created_at = clock.now(); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); @@ -944,8 +930,7 @@ pub async fn fullfill_compat_sso_login( ) .execute(&mut txn) .instrument(tracing::info_span!("Insert compat session")) - .await - .context("could not insert compat session")?; + .await?; let session = CompatSession { id, @@ -965,46 +950,41 @@ pub async fn fullfill_compat_sso_login( WHERE compat_sso_login_id = $1 "#, - Uuid::from(login.id), + Uuid::from(compat_sso_login.id), Uuid::from(session.id), fulfilled_at, ) .execute(&mut txn) .instrument(tracing::info_span!("Update compat SSO login")) - .await - .context("could not update compat SSO login")?; + .await?; let state = CompatSsoLoginState::Fulfilled { fulfilled_at, session, }; - login.state = state; + compat_sso_login.state = state; txn.commit().await?; - Ok(login) + Ok(compat_sso_login) } #[tracing::instrument( skip_all, fields( - compat_sso_login.id = %login.id, - compat_sso_login.redirect_uri = %login.redirect_uri, + %compat_sso_login.id, + %compat_sso_login.redirect_uri, ), - err(Display), + err, )] pub async fn mark_compat_sso_login_as_exchanged( executor: impl PgExecutor<'_>, clock: &Clock, - mut login: CompatSsoLogin, -) -> Result { - let (fulfilled_at, session) = match login.state { - CompatSsoLoginState::Fulfilled { - fulfilled_at, - session, - } => (fulfilled_at, session), - _ => bail!("sso login in wrong state"), + mut compat_sso_login: CompatSsoLogin, +) -> Result { + let CompatSsoLoginState::Fulfilled { fulfilled_at, session } = compat_sso_login.state else { + return Err(DatabaseError::InvalidOperation); }; let exchanged_at = clock.now(); @@ -1016,19 +996,18 @@ pub async fn mark_compat_sso_login_as_exchanged( WHERE compat_sso_login_id = $1 "#, - Uuid::from(login.id), + Uuid::from(compat_sso_login.id), exchanged_at, ) .execute(executor) .instrument(tracing::info_span!("Update compat SSO login")) - .await - .context("could not update compat SSO login")?; + .await?; let state = CompatSsoLoginState::Exchanged { fulfilled_at, exchanged_at, session, }; - login.state = state; - Ok(login) + compat_sso_login.state = state; + Ok(compat_sso_login) } diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 7905eece..22fe1cb9 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -30,7 +30,7 @@ use chrono::{DateTime, Utc}; use pagination::InvalidPagination; -use sqlx::migrate::Migrator; +use sqlx::{migrate::Migrator, postgres::PgQueryResult}; use thiserror::Error; use ulid::Ulid; @@ -100,6 +100,30 @@ pub enum DatabaseError { /// An error which occured while generating the paginated query Pagination(#[from] InvalidPagination), + + /// An error which happened because the requested database operation is + /// invalid + #[error("Invalid database operation")] + InvalidOperation, + + /// An error which happens when an operation affects not enough or too many + /// rows + #[error("Expected {expected} rows to be affected, but {actual} rows were affected")] + RowsAffected { expected: u64, actual: u64 }, +} + +impl DatabaseError { + pub(crate) fn ensure_affected_rows( + result: &PgQueryResult, + expected: u64, + ) -> Result<(), DatabaseError> { + let actual = result.rows_affected(); + if actual == expected { + Ok(()) + } else { + Err(DatabaseError::RowsAffected { expected, actual }) + } + } } #[derive(Debug, Error)]