diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 67c3f609..00baf4a5 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -31,7 +31,7 @@ use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; -use mas_storage::oauth2::client::{lookup_client_by_client_id, ClientFetchError}; +use mas_storage::{oauth2::client::lookup_client_by_client_id, DatabaseError}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; use sqlx::PgExecutor; @@ -73,7 +73,10 @@ pub enum Credentials { } impl Credentials { - pub async fn fetch(&self, executor: impl PgExecutor<'_>) -> Result { + pub async fn fetch( + &self, + executor: impl PgExecutor<'_>, + ) -> Result, DatabaseError> { let client_id = match self { Credentials::None { client_id } | Credentials::ClientSecretBasic { client_id, .. } diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs index 576f0ae7..3161bd89 100644 --- a/crates/axum-utils/src/user_authorization.rs +++ b/crates/axum-utils/src/user_authorization.rs @@ -27,10 +27,7 @@ use axum::{ use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName}; use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode}; use mas_data_model::Session; -use mas_storage::{ - oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError}, - LookupError, -}; +use mas_storage::{oauth2::access_token::lookup_active_access_token, DatabaseError}; use serde::{de::DeserializeOwned, Deserialize}; use sqlx::PgConnection; use thiserror::Error; @@ -61,7 +58,9 @@ impl AccessToken { AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), }; - let (token, session) = lookup_active_access_token(conn, token.as_str()).await?; + let (token, session) = lookup_active_access_token(conn, token.as_str()) + .await? + .ok_or(AuthorizationVerificationError::InvalidToken)?; Ok((token, session)) } @@ -119,17 +118,7 @@ pub enum AuthorizationVerificationError { MissingForm, #[error(transparent)] - Internal(Box), -} - -impl From for AuthorizationVerificationError { - fn from(e: AccessTokenLookupError) -> Self { - if e.not_found() { - Self::InvalidToken - } else { - Self::Internal(Box::new(e)) - } - } + Internal(#[from] DatabaseError), } enum BearerError { diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 0f033ecd..72356acd 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -24,7 +24,7 @@ use mas_storage::{ lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user, set_password, }, - Clock, LookupError, + Clock, }; use oauth2_types::scope::Scope; use rand::SeedableRng; @@ -259,14 +259,10 @@ impl Options { for client in config.clients.iter() { let client_id = client.client_id; - let res = lookup_client(&mut txn, client_id).await; - match res { - Ok(_) => { - warn!(%client_id, "Skipping already imported client"); - continue; - } - Err(e) if e.not_found() => {} - Err(e) => anyhow::bail!(e), + let res = lookup_client(&mut txn, client_id).await?; + if res.is_some() { + warn!(%client_id, "Skipping already imported client"); + continue; } info!(%client_id, "Importing client"); diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 0a2992d6..e04ae763 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -27,7 +27,6 @@ use async_graphql::{ Context, Description, EmptyMutation, EmptySubscription, ID, }; use mas_axum_utils::SessionInfo; -use mas_storage::LookupResultExt; use sqlx::PgPool; use self::model::{ @@ -96,9 +95,7 @@ impl RootQuery { let database = ctx.data::()?; let mut conn = database.acquire().await?; - let client = mas_storage::oauth2::client::lookup_client(&mut conn, id) - .await - .to_option()?; + let client = mas_storage::oauth2::client::lookup_client(&mut conn, id).await?; Ok(client.map(OAuth2Client)) } diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index fd0c5eab..89598ffa 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use mas_storage::oauth2::client::lookup_client; use oauth2_types::scope::Scope; @@ -114,7 +115,9 @@ impl OAuth2Consent { /// OAuth 2.0 client for which the user granted access. pub async fn client(&self, ctx: &Context<'_>) -> Result { let mut conn = ctx.data::()?.acquire().await?; - let client = lookup_client(&mut conn, self.client_id).await?; + let client = lookup_client(&mut conn, self.client_id) + .await? + .context("Could not load client")?; Ok(OAuth2Client(client)) } } diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index bd2906a9..0cf5ed63 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -26,12 +26,8 @@ use mas_data_model::{AuthorizationCode, Pkce}; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; -use mas_storage::{ - oauth2::{ - authorization_grant::new_authorization_grant, - client::{lookup_client_by_client_id, ClientFetchError}, - }, - LookupError, +use mas_storage::oauth2::{ + authorization_grant::new_authorization_grant, client::lookup_client_by_client_id, }; use mas_templates::Templates; use oauth2_types::{ @@ -46,6 +42,7 @@ use sqlx::PgPool; use thiserror::Error; use self::{callback::CallbackDestination, complete::GrantCompletionError}; +use crate::impl_from_error_for_route; mod callback; pub mod complete; @@ -56,7 +53,7 @@ pub enum RouteError { Internal(Box), #[error(transparent)] - Anyhow(anyhow::Error), + Anyhow(#[from] anyhow::Error), #[error("could not find client")] ClientNotFound, @@ -93,33 +90,9 @@ impl IntoResponse for RouteError { } } -impl From for RouteError { - fn from(e: sqlx::Error) -> Self { - Self::Internal(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: self::callback::CallbackDestinationError) -> Self { - Self::Internal(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: ClientFetchError) -> Self { - if e.not_found() { - Self::ClientNotFound - } else { - Self::Internal(Box::new(e)) - } - } -} - -impl From for RouteError { - fn from(e: anyhow::Error) -> Self { - Self::Anyhow(e) - } -} +impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(self::callback::CallbackDestinationError); #[derive(Deserialize)] pub(crate) struct Params { @@ -166,7 +139,9 @@ pub(crate) async fn get( let mut txn = pool.begin().await?; // First, figure out what client it is - let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id).await?; + let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id) + .await? + .ok_or(RouteError::ClientNotFound)?; // And resolve the redirect_uri and response_mode let redirect_uri = client diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 993da628..09b0e61b 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -24,11 +24,9 @@ use mas_keystore::Encrypter; use mas_storage::{ compat::{lookup_active_compat_access_token, lookup_active_compat_refresh_token}, oauth2::{ - access_token::{lookup_active_access_token, AccessTokenLookupError}, - client::ClientFetchError, - refresh_token::{lookup_active_refresh_token, RefreshTokenLookupError}, + access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, }, - Clock, LookupError, + Clock, }; use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse}; use sqlx::PgPool; @@ -87,36 +85,6 @@ impl From for RouteError { } } -impl From for RouteError { - fn from(e: ClientFetchError) -> Self { - if e.not_found() { - Self::ClientNotFound - } else { - Self::Internal(Box::new(e)) - } - } -} - -impl From for RouteError { - fn from(e: AccessTokenLookupError) -> Self { - if e.not_found() { - Self::UnknownToken - } else { - Self::Internal(Box::new(e)) - } - } -} - -impl From for RouteError { - fn from(e: RefreshTokenLookupError) -> Self { - if e.not_found() { - Self::UnknownToken - } else { - Self::Internal(Box::new(e)) - } - } -} - const INACTIVE: IntrospectionResponse = IntrospectionResponse { active: false, scope: None, @@ -142,7 +110,11 @@ pub(crate) async fn post( let clock = Clock::default(); let mut conn = pool.acquire().await?; - let client = client_authorization.credentials.fetch(&mut conn).await?; + let client = client_authorization + .credentials + .fetch(&mut conn) + .await? + .ok_or(RouteError::ClientNotFound)?; let method = match &client.token_endpoint_auth_method { None | Some(OAuthClientAuthenticationMethod::None) => { @@ -172,7 +144,9 @@ pub(crate) async fn post( let reply = match token_type { TokenType::AccessToken => { - let (token, session) = lookup_active_access_token(&mut conn, token).await?; + let (token, session) = lookup_active_access_token(&mut conn, token) + .await? + .ok_or(RouteError::UnknownToken)?; IntrospectionResponse { active: true, @@ -190,7 +164,9 @@ pub(crate) async fn post( } } TokenType::RefreshToken => { - let (token, session) = lookup_active_refresh_token(&mut conn, token).await?; + let (token, session) = lookup_active_refresh_token(&mut conn, token) + .await? + .ok_or(RouteError::UnknownToken)?; IntrospectionResponse { active: true, diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index c58d53b2..b9f4c2c9 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -26,9 +26,9 @@ use mas_axum_utils::{ use mas_data_model::{AuthorizationGrantStage, Client, TokenType}; use mas_iana::jose::JsonWebSignatureAlg; use mas_jose::{ - claims::{self, hash_token, ClaimError, TokenHashError}, + claims::{self, hash_token}, constraints::Constrainable, - jwt::{JsonWebSignatureHeader, Jwt, JwtSignatureError}, + jwt::{JsonWebSignatureHeader, Jwt}, }; use mas_keystore::{Encrypter, Keystore}; use mas_router::UrlBuilder; @@ -36,14 +36,10 @@ use mas_storage::{ oauth2::{ access_token::{add_access_token, revoke_access_token}, authorization_grant::{exchange_grant, lookup_grant_by_code}, - client::ClientFetchError, end_oauth_session, - refresh_token::{ - add_refresh_token, consume_refresh_token, lookup_active_refresh_token, - RefreshTokenLookupError, - }, + refresh_token::{add_refresh_token, consume_refresh_token, lookup_active_refresh_token}, }, - DatabaseInconsistencyError, LookupError, + DatabaseInconsistencyError, }; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -60,6 +56,8 @@ use thiserror::Error; use tracing::debug; use url::Url; +use crate::impl_from_error_for_route; + #[serde_as] #[skip_serializing_none] #[derive(Serialize, Debug)] @@ -107,26 +105,6 @@ pub(crate) enum RouteError { UnauthorizedClient, } -impl From for RouteError { - fn from(e: ClientFetchError) -> Self { - if e.not_found() { - Self::ClientNotFound - } else { - Self::Internal(Box::new(e)) - } - } -} - -impl From for RouteError { - fn from(e: RefreshTokenLookupError) -> Self { - if e.not_found() { - Self::InvalidGrant - } else { - Self::Internal(Box::new(e)) - } - } -} - impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { @@ -162,35 +140,12 @@ impl IntoResponse for RouteError { } } -impl From for RouteError { - fn from(e: mas_keystore::WrongAlgorithmError) -> Self { - Self::Internal(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: sqlx::Error) -> Self { - Self::Internal(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: ClaimError) -> Self { - Self::Internal(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: TokenHashError) -> Self { - Self::Internal(Box::new(e)) - } -} - -impl From for RouteError { - fn from(e: JwtSignatureError) -> Self { - Self::Internal(Box::new(e)) - } -} +impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::DatabaseError); +impl_from_error_for_route!(mas_keystore::WrongAlgorithmError); +impl_from_error_for_route!(mas_jose::claims::ClaimError); +impl_from_error_for_route!(mas_jose::claims::TokenHashError); +impl_from_error_for_route!(mas_jose::jwt::JwtSignatureError); #[tracing::instrument(skip_all, err)] pub(crate) async fn post( @@ -203,7 +158,11 @@ pub(crate) async fn post( ) -> Result { let mut txn = pool.begin().await?; - let client = client_authorization.credentials.fetch(&mut txn).await?; + let client = client_authorization + .credentials + .fetch(&mut txn) + .await? + .ok_or(RouteError::ClientNotFound)?; let method = client .token_endpoint_auth_method @@ -396,8 +355,9 @@ async fn refresh_token_grant( ) -> Result { let (clock, mut rng) = crate::rng_and_clock()?; - let (refresh_token, session) = - lookup_active_refresh_token(&mut txn, &grant.refresh_token).await?; + let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token) + .await? + .ok_or(RouteError::InvalidGrant)?; if client.client_id != session.client.client_id { // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 2fdfb45d..848055f5 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -25,9 +25,8 @@ use mas_oidc_client::requests::{ authorization_code::AuthorizationValidationData, jose::JwtVerificationData, }; use mas_router::{Route, UrlBuilder}; -use mas_storage::{ - upstream_oauth2::{add_link, complete_session, lookup_link_by_subject, lookup_session}, - LookupResultExt, +use mas_storage::upstream_oauth2::{ + add_link, complete_session, lookup_link_by_subject, lookup_session, }; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; @@ -97,8 +96,6 @@ pub(crate) enum RouteError { } impl_from_error_for_route!(mas_storage::DatabaseError); -impl_from_error_for_route!(mas_storage::GenericLookupError); -impl_from_error_for_route!(mas_storage::upstream_oauth2::SessionLookupError); impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); @@ -141,8 +138,7 @@ pub(crate) async fn get( .map_err(|_| RouteError::MissingCookie)?; let (provider, session) = lookup_session(&mut txn, session_id) - .await - .to_option()? + .await? .ok_or(RouteError::SessionNotFound)?; if provider.id != provider_id { diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index cf147596..dccc7635 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -17,12 +17,11 @@ use chrono::{DateTime, Duration, Utc}; use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail}; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; -use thiserror::Error; use ulid::Ulid; use uuid::Uuid; -use super::client::{lookup_client, ClientFetchError}; -use crate::{Clock, DatabaseInconsistencyError, LookupError}; +use super::client::lookup_client; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError2}; #[tracing::instrument( skip_all, @@ -95,25 +94,11 @@ pub struct OAuth2AccessTokenLookup { user_email_confirmed_at: Option>, } -#[derive(Debug, Error)] -#[error("failed to lookup access token")] -pub enum AccessTokenLookupError { - Database(#[from] sqlx::Error), - ClientFetch(#[from] ClientFetchError), - Inconsistency(#[from] DatabaseInconsistencyError), -} - -impl LookupError for AccessTokenLookupError { - fn not_found(&self) -> bool { - matches!(self, Self::Database(sqlx::Error::RowNotFound)) - } -} - #[allow(clippy::too_many_lines)] pub async fn lookup_active_access_token( conn: &mut PgConnection, token: &str, -) -> Result<(AccessToken, Session), AccessTokenLookupError> { +) -> Result, DatabaseError> { let res = sqlx::query_as!( OAuth2AccessTokenLookup, r#" @@ -160,17 +145,25 @@ pub async fn lookup_active_access_token( .fetch_one(&mut *conn) .await?; - let id = Ulid::from(res.oauth2_access_token_id); + let access_token_id = Ulid::from(res.oauth2_access_token_id); let access_token = AccessToken { - id, - jti: id.to_string(), + id: access_token_id, + jti: access_token_id.to_string(), access_token: res.oauth2_access_token, created_at: res.oauth2_access_token_created_at, expires_at: res.oauth2_access_token_expires_at, }; - let client = lookup_client(&mut *conn, res.oauth2_client_id.into()).await?; + let session_id = res.oauth2_session_id.into(); + let client = lookup_client(&mut *conn, res.oauth2_client_id.into()) + .await? + .ok_or_else(|| { + DatabaseInconsistencyError2::on("oauth2_sessions") + .column("client_id") + .row(session_id) + })?; + let user_id = Ulid::from(res.user_id); let primary_email = match ( res.user_email_id, res.user_email, @@ -184,14 +177,18 @@ pub async fn lookup_active_access_token( confirmed_at, }), (None, None, None, None) => None, - _ => return Err(DatabaseInconsistencyError.into()), + _ => { + return Err(DatabaseInconsistencyError2::on("users") + .column("primary_user_email_id") + .row(user_id) + .into()) + } }; - let id = Ulid::from(res.user_id); let user = User { - id, + id: user_id, username: res.user_username, - sub: id.to_string(), + sub: user_id.to_string(), primary_email, }; @@ -204,7 +201,7 @@ pub async fn lookup_active_access_token( id: id.into(), created_at, }), - _ => return Err(DatabaseInconsistencyError.into()), + _ => return Err(DatabaseInconsistencyError2::on("user_session_authentications").into()), }; let browser_session = BrowserSession { @@ -214,28 +211,33 @@ pub async fn lookup_active_access_token( last_authentication, }; - let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?; + let scope = res.scope.parse().map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_sessions") + .column("scope") + .row(session_id) + .source(e) + })?; let session = Session { - id: res.oauth2_session_id.into(), + id: session_id, client, browser_session, scope, }; - Ok((access_token, session)) + Ok(Some((access_token, session))) } #[tracing::instrument( skip_all, fields(%access_token.id), - err(Debug), + err, )] pub async fn revoke_access_token( executor: impl PgExecutor<'_>, clock: &Clock, access_token: AccessToken, -) -> anyhow::Result<()> { +) -> Result<(), DatabaseError> { let revoked_at = clock.now(); let res = sqlx::query!( r#" @@ -247,17 +249,15 @@ pub async fn revoke_access_token( revoked_at, ) .execute(executor) - .await - .context("could not revoke access tokens")?; + .await?; - if res.rows_affected() == 1 { - Ok(()) - } else { - Err(anyhow::anyhow!("no row were affected when revoking token")) - } + DatabaseError::ensure_affected_rows(&res, 1) } -pub async fn cleanup_expired(executor: impl PgExecutor<'_>, clock: &Clock) -> anyhow::Result { +pub async fn cleanup_expired( + executor: impl PgExecutor<'_>, + clock: &Clock, +) -> Result { // Cleanup token which expired more than 15 minutes ago let threshold = clock.now() - Duration::minutes(15); let res = sqlx::query!( @@ -268,8 +268,7 @@ pub async fn cleanup_expired(executor: impl PgExecutor<'_>, clock: &Clock) -> an threshold, ) .execute(executor) - .await - .context("could not cleanup expired access tokens")?; + .await?; Ok(res.rows_affected()) } diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 39476250..5917f0d8 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -180,6 +180,7 @@ impl GrantLookup { // TODO: don't unwrap let client = lookup_client(executor, self.oauth2_client_id.into()) .await + .unwrap() .unwrap(); let last_authentication = match ( diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 65f3b704..a1619bf2 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -23,12 +23,11 @@ use mas_jose::jwk::PublicJsonWebKeySet; use oauth2_types::requests::GrantType; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; -use thiserror::Error; use ulid::Ulid; use url::Url; use uuid::Uuid; -use crate::{Clock, LookupError}; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt}; // XXX: response_types & contacts #[derive(Debug)] @@ -54,52 +53,20 @@ pub struct OAuth2ClientLookup { initiate_login_uri: Option, } -#[derive(Debug, Error)] -pub enum ClientFetchError { - #[error("invalid client ID")] - InvalidClientId(#[from] ulid::DecodeError), - - #[error("malformed jwks column")] - MalformedJwks(#[source] serde_json::Error), - - #[error("entry has both a jwks and a jwks_uri")] - BothJwksAndJwksUri, - - #[error("could not parse URL in field {field:?}")] - ParseUrl { - field: &'static str, - source: url::ParseError, - }, - - #[error("could not parse field {field:?}")] - ParseField { - field: &'static str, - source: mas_iana::ParseError, - }, - - #[error(transparent)] - Database(#[from] sqlx::Error), -} - -impl LookupError for ClientFetchError { - fn not_found(&self) -> bool { - matches!( - self, - Self::Database(sqlx::Error::RowNotFound) | Self::InvalidClientId(_) - ) - } -} - impl TryInto for OAuth2ClientLookup { - type Error = ClientFetchError; + type Error = DatabaseInconsistencyError2; #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing fn try_into(self) -> Result { + let id = Ulid::from(self.oauth2_client_id); + let redirect_uris: Result, _> = self.redirect_uris.iter().map(|s| s.parse()).collect(); - let redirect_uris = redirect_uris.map_err(|source| ClientFetchError::ParseUrl { - field: "redirect_uris", - source, + let redirect_uris = redirect_uris.map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("redirect_uris") + .row(id) + .source(e) })?; let response_types = vec![ @@ -124,107 +91,125 @@ impl TryInto for OAuth2ClientLookup { grant_types.push(GrantType::RefreshToken); } - let logo_uri = self - .logo_uri - .map(|s| s.parse()) - .transpose() - .map_err(|source| ClientFetchError::ParseUrl { - field: "logo_uri", - source, - })?; + let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("logo_uri") + .row(id) + .source(e) + })?; let client_uri = self .client_uri .map(|s| s.parse()) .transpose() - .map_err(|source| ClientFetchError::ParseUrl { - field: "client_uri", - source, + .map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("client_uri") + .row(id) + .source(e) })?; let policy_uri = self .policy_uri .map(|s| s.parse()) .transpose() - .map_err(|source| ClientFetchError::ParseUrl { - field: "policy_uri", - source, + .map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("policy_uri") + .row(id) + .source(e) })?; - let tos_uri = self - .tos_uri - .map(|s| s.parse()) - .transpose() - .map_err(|source| ClientFetchError::ParseUrl { - field: "tos_uri", - source, - })?; + let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("tos_uri") + .row(id) + .source(e) + })?; let id_token_signed_response_alg = self .id_token_signed_response_alg .map(|s| s.parse()) .transpose() - .map_err(|source| ClientFetchError::ParseField { - field: "id_token_signed_response_alg", - source, + .map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("id_token_signed_response_alg") + .row(id) + .source(e) })?; let userinfo_signed_response_alg = self .userinfo_signed_response_alg .map(|s| s.parse()) .transpose() - .map_err(|source| ClientFetchError::ParseField { - field: "userinfo_signed_response_alg", - source, + .map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("userinfo_signed_response_alg") + .row(id) + .source(e) })?; let token_endpoint_auth_method = self .token_endpoint_auth_method .map(|s| s.parse()) .transpose() - .map_err(|source| ClientFetchError::ParseField { - field: "token_endpoint_auth_method", - source, + .map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("token_endpoint_auth_method") + .row(id) + .source(e) })?; let token_endpoint_auth_signing_alg = self .token_endpoint_auth_signing_alg .map(|s| s.parse()) .transpose() - .map_err(|source| ClientFetchError::ParseField { - field: "token_endpoint_auth_signing_alg", - source, + .map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("token_endpoint_auth_signing_alg") + .row(id) + .source(e) })?; let initiate_login_uri = self .initiate_login_uri .map(|s| s.parse()) .transpose() - .map_err(|source| ClientFetchError::ParseUrl { - field: "initiate_login_uri", - source, + .map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("initiate_login_uri") + .row(id) + .source(e) })?; let jwks = match (self.jwks, self.jwks_uri) { (None, None) => None, (Some(jwks), None) => { - let jwks = serde_json::from_value(jwks).map_err(ClientFetchError::MalformedJwks)?; + let jwks = serde_json::from_value(jwks).map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("jwks") + .row(id) + .source(e) + })?; Some(JwksOrJwksUri::Jwks(jwks)) } (None, Some(jwks_uri)) => { - let jwks_uri = jwks_uri - .parse() - .map_err(|source| ClientFetchError::ParseUrl { - field: "jwks_uri", - source, - })?; + let jwks_uri = jwks_uri.parse().map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_clients") + .column("jwks_uri") + .row(id) + .source(e) + })?; Some(JwksOrJwksUri::JwksUri(jwks_uri)) } - _ => return Err(ClientFetchError::BothJwksAndJwksUri), + _ => { + return Err(DatabaseInconsistencyError2::on("oauth2_clients") + .column("jwks(_uri)") + .row(id)) + } }; - let id = Ulid::from(self.oauth2_client_id); Ok(Client { id, client_id: id.to_string(), @@ -253,7 +238,7 @@ impl TryInto for OAuth2ClientLookup { pub async fn lookup_clients( executor: impl PgExecutor<'_>, ids: impl IntoIterator + Send, -) -> Result, ClientFetchError> { +) -> Result, DatabaseError> { let ids: Vec = ids.into_iter().map(Uuid::from).collect(); let res = sqlx::query_as!( OAuth2ClientLookup, @@ -289,12 +274,13 @@ pub async fn lookup_clients( .fetch_all(executor) .await?; - let clients: Result, _> = res - .into_iter() - .map(|r| r.try_into().map(|c: Client| (c.id, c))) - .collect(); - - clients + res.into_iter() + .map(|r| { + r.try_into() + .map(|c: Client| (c.id, c)) + .map_err(DatabaseError::from) + }) + .collect() } #[tracing::instrument( @@ -305,7 +291,7 @@ pub async fn lookup_clients( pub async fn lookup_client( executor: impl PgExecutor<'_>, id: Ulid, -) -> Result { +) -> Result, DatabaseError> { let res = sqlx::query_as!( OAuth2ClientLookup, r#" @@ -338,11 +324,12 @@ pub async fn lookup_client( Uuid::from(id), ) .fetch_one(executor) - .await?; + .await + .to_option()?; - let client = res.try_into()?; + let Some(res) = res else { return Ok(None) }; - Ok(client) + Ok(Some(res.try_into()?)) } #[tracing::instrument( @@ -353,8 +340,8 @@ pub async fn lookup_client( pub async fn lookup_client_by_client_id( executor: impl PgExecutor<'_>, client_id: &str, -) -> Result { - let id: Ulid = client_id.parse()?; +) -> Result, DatabaseError> { + let Ok(id) = client_id.parse() else { return Ok(None) }; lookup_client(executor, id).await } diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 60f41d14..6a97c2f6 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -19,12 +19,11 @@ use mas_data_model::{ }; use rand::Rng; use sqlx::{PgConnection, PgExecutor}; -use thiserror::Error; use ulid::Ulid; use uuid::Uuid; -use super::client::{lookup_client, ClientFetchError}; -use crate::{Clock, DatabaseInconsistencyError, LookupError}; +use super::client::lookup_client; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError2}; #[tracing::instrument( skip_all, @@ -98,26 +97,12 @@ struct OAuth2RefreshTokenLookup { user_email_confirmed_at: Option>, } -#[derive(Error, Debug)] -#[error("could not lookup refresh token")] -pub enum RefreshTokenLookupError { - Fetch(#[from] sqlx::Error), - ClientFetch(#[from] ClientFetchError), - Conversion(#[from] DatabaseInconsistencyError), -} - -impl LookupError for RefreshTokenLookupError { - fn not_found(&self) -> bool { - matches!(self, Self::Fetch(sqlx::Error::RowNotFound)) - } -} - #[tracing::instrument(skip_all, err)] #[allow(clippy::too_many_lines)] pub async fn lookup_active_refresh_token( conn: &mut PgConnection, token: &str, -) -> Result<(RefreshToken, Session), RefreshTokenLookupError> { +) -> Result, DatabaseError> { let res = sqlx::query_as!( OAuth2RefreshTokenLookup, r#" @@ -187,7 +172,7 @@ pub async fn lookup_active_refresh_token( expires_at, }) } - _ => return Err(DatabaseInconsistencyError.into()), + _ => return Err(DatabaseInconsistencyError2::on("oauth2_access_tokens").into()), }; let refresh_token = RefreshToken { @@ -197,8 +182,16 @@ pub async fn lookup_active_refresh_token( access_token, }; - let client = lookup_client(&mut *conn, res.oauth2_client_id.into()).await?; + let session_id = res.oauth2_session_id.into(); + let client = lookup_client(&mut *conn, res.oauth2_client_id.into()) + .await? + .ok_or_else(|| { + DatabaseInconsistencyError2::on("oauth2_sessions") + .column("client_id") + .row(session_id) + })?; + let user_id = Ulid::from(res.user_id); let primary_email = match ( res.user_email_id, res.user_email, @@ -212,14 +205,18 @@ pub async fn lookup_active_refresh_token( confirmed_at, }), (None, None, None, None) => None, - _ => return Err(DatabaseInconsistencyError.into()), + _ => { + return Err(DatabaseInconsistencyError2::on("users") + .column("primary_user_email_id") + .row(user_id) + .into()) + } }; - let id = Ulid::from(res.user_id); let user = User { - id, + id: user_id, username: res.user_username, - sub: id.to_string(), + sub: user_id.to_string(), primary_email, }; @@ -232,7 +229,7 @@ pub async fn lookup_active_refresh_token( id: id.into(), created_at, }), - _ => return Err(DatabaseInconsistencyError.into()), + _ => return Err(DatabaseInconsistencyError2::on("user_session_authentications").into()), }; let browser_session = BrowserSession { @@ -242,19 +239,21 @@ pub async fn lookup_active_refresh_token( last_authentication, }; - let scope = res - .oauth2_session_scope - .parse() - .map_err(|_e| DatabaseInconsistencyError)?; + let scope = res.oauth2_session_scope.parse().map_err(|e| { + DatabaseInconsistencyError2::on("oauth2_sessions") + .column("scope") + .row(session_id) + .source(e) + })?; let session = Session { - id: res.oauth2_session_id.into(), + id: session_id, client, browser_session, scope, }; - Ok((refresh_token, session)) + Ok(Some((refresh_token, session))) } #[tracing::instrument( @@ -268,7 +267,7 @@ pub async fn consume_refresh_token( executor: impl PgExecutor<'_>, clock: &Clock, refresh_token: &RefreshToken, -) -> Result<(), anyhow::Error> { +) -> Result<(), DatabaseError> { let consumed_at = clock.now(); let res = sqlx::query!( r#" @@ -280,14 +279,7 @@ pub async fn consume_refresh_token( consumed_at, ) .execute(executor) - .await - .context("failed to update oauth2 refresh token")?; + .await?; - if res.rows_affected() == 1 { - Ok(()) - } else { - Err(anyhow::anyhow!( - "no row were affected when updating refresh token" - )) - } + DatabaseError::ensure_affected_rows(&res, 1) } diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 503a9df1..4b1d517a 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -24,6 +24,5 @@ pub use self::{ provider::{add_provider, get_paginated_providers, get_providers, lookup_provider}, session::{ add_session, complete_session, consume_session, lookup_session, lookup_session_on_link, - SessionLookupError, }, }; diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index 651f5a3a..a4a0d8a7 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -16,24 +16,12 @@ use chrono::{DateTime, Utc}; use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider}; use rand::Rng; use sqlx::PgExecutor; -use thiserror::Error; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseInconsistencyError, GenericLookupError, LookupError}; - -#[derive(Debug, Error)] -#[error("Failed to lookup upstream OAuth 2.0 authorization session")] -pub enum SessionLookupError { - Driver(#[from] sqlx::Error), - Inconcistency(#[from] DatabaseInconsistencyError), -} - -impl LookupError for SessionLookupError { - fn not_found(&self) -> bool { - matches!(self, Self::Driver(sqlx::Error::RowNotFound)) - } -} +use crate::{ + Clock, DatabaseError, DatabaseInconsistencyError2, GenericLookupError, LookupResultExt, +}; struct SessionAndProviderLookup { upstream_oauth_authorization_session_id: Uuid, @@ -64,7 +52,7 @@ struct SessionAndProviderLookup { pub async fn lookup_session( executor: impl PgExecutor<'_>, id: Ulid, -) -> Result<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession), SessionLookupError> { +) -> Result, DatabaseError> { let res = sqlx::query_as!( SessionAndProviderLookup, r#" @@ -94,29 +82,41 @@ pub async fn lookup_session( Uuid::from(id), ) .fetch_one(executor) - .await?; + .await + .to_option()?; + let Some(res) = res else { return Ok(None) }; + + let id = res.upstream_oauth_provider_id.into(); let provider = UpstreamOAuthProvider { - id: res.upstream_oauth_provider_id.into(), - issuer: res - .provider_issuer - .parse() - .map_err(|_| DatabaseInconsistencyError)?, - scope: res - .provider_scope - .parse() - .map_err(|_| DatabaseInconsistencyError)?, + id, + issuer: res.provider_issuer, + scope: res.provider_scope.parse().map_err(|e| { + DatabaseInconsistencyError2::on("upstream_oauth_providers") + .column("scope") + .row(id) + .source(e) + })?, client_id: res.provider_client_id, encrypted_client_secret: res.provider_encrypted_client_secret, - token_endpoint_auth_method: res - .provider_token_endpoint_auth_method - .parse() - .map_err(|_| DatabaseInconsistencyError)?, + token_endpoint_auth_method: res.provider_token_endpoint_auth_method.parse().map_err( + |e| { + DatabaseInconsistencyError2::on("upstream_oauth_providers") + .column("token_endpoint_auth_method") + .row(id) + .source(e) + }, + )?, token_endpoint_signing_alg: res .provider_token_endpoint_signing_alg .map(|x| x.parse()) .transpose() - .map_err(|_| DatabaseInconsistencyError)?, + .map_err(|e| { + DatabaseInconsistencyError2::on("upstream_oauth_providers") + .column("token_endpoint_signing_alg") + .row(id) + .source(e) + })?, created_at: res.provider_created_at, }; @@ -133,7 +133,7 @@ pub async fn lookup_session( consumed_at: res.consumed_at, }; - Ok((provider, session)) + Ok(Some((provider, session))) } /// Add a session to the database