diff --git a/Cargo.lock b/Cargo.lock index ffaafd89..73bf436b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2788,6 +2788,7 @@ dependencies = [ name = "mas-graphql" version = "0.1.0" dependencies = [ + "anyhow", "async-graphql", "chrono", "mas-axum-utils", diff --git a/crates/graphql/Cargo.toml b/crates/graphql/Cargo.toml index be755d79..f7fe2420 100644 --- a/crates/graphql/Cargo.toml +++ b/crates/graphql/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" license = "Apache-2.0" [dependencies] +anyhow = "1.0.66" async-graphql = { version = "5.0.2", features = ["chrono", "url"] } chrono = "0.4.23" serde = { version = "1.0.149", features = ["derive"] } diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 0d5af1ba..267f1aaa 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -188,9 +188,7 @@ impl RootQuery { let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id) - .await - .to_option()?; + let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id).await?; // Ensure that the link belongs to the current user let link = link.filter(|link| link.user_id == Some(current_user.id)); @@ -208,9 +206,7 @@ impl RootQuery { let database = ctx.data::()?; let mut conn = database.acquire().await?; - let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id) - .await - .to_option()?; + let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id).await?; Ok(provider.map(UpstreamOAuth2Provider::new)) } diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index 41790161..87164dd4 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; use sqlx::PgPool; @@ -100,7 +101,9 @@ impl UpstreamOAuth2Link { // Fetch on-the-fly let database = ctx.data::()?; let mut conn = database.acquire().await?; - mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id).await? + mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id) + .await? + .context("Upstream OAuth 2.0 provider not found")? }; Ok(UpstreamOAuth2Provider::new(provider)) diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index eebdccce..ae7dddd1 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -22,7 +22,7 @@ use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_keystore::Encrypter; use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; -use mas_storage::{upstream_oauth2::lookup_provider, LookupResultExt}; +use mas_storage::upstream_oauth2::lookup_provider; use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -46,7 +46,7 @@ impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError); -impl_from_error_for_route!(mas_storage::upstream_oauth2::ProviderLookupError); +impl_from_error_for_route!(mas_storage::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -75,8 +75,7 @@ pub(crate) async fn get( let mut txn = pool.begin().await?; let provider = lookup_provider(&mut txn, provider_id) - .await - .to_option()? + .await? .ok_or(RouteError::ProviderNotFound)?; let http_service = http_client_factory diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 18a6a44b..7b01945c 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -96,6 +96,7 @@ pub(crate) enum RouteError { Anyhow(#[from] anyhow::Error), } +impl_from_error_for_route!(mas_storage::DatabaseError); impl_from_error_for_route!(mas_storage::GenericLookupError); impl_from_error_for_route!(mas_storage::upstream_oauth2::SessionLookupError); impl_from_error_for_route!(mas_http::ClientInitError); @@ -242,9 +243,7 @@ pub(crate) async fn get( let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?; // Look for an existing link - let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject) - .await - .to_option()?; + let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject).await?; let link = if let Some(link) = maybe_link { link diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 329ea6e2..36c1d078 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -79,6 +79,7 @@ impl_from_error_for_route!(mas_storage::user::ActiveSessionLookupError); impl_from_error_for_route!(mas_storage::user::UserLookupError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); +impl_from_error_for_route!(mas_storage::DatabaseError); impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { @@ -118,8 +119,7 @@ pub(crate) async fn get( .map_err(|_| RouteError::MissingCookie)?; let link = lookup_link(&mut txn, link_id) - .await - .to_option()? + .await? .ok_or(RouteError::LinkNotFound)?; // This checks that we're in a browser session which is allowed to consume this @@ -221,8 +221,7 @@ pub(crate) async fn post( }; let link = lookup_link(&mut txn, link_id) - .await - .to_option()? + .await? .ok_or(RouteError::LinkNotFound)?; // This checks that we're in a browser session which is allowed to consume this diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index 7165e377..ab759885 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::Context; use mas_router::{PostAuthAction, Route}; use mas_storage::{ compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, @@ -58,11 +59,14 @@ impl OptionalPostAuthAction { PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword, PostAuthAction::LinkUpstream { id } => { - let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id).await?; + let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id) + .await? + .context("Failed to load upstream OAuth 2.0 link")?; let provider = mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id) - .await?; + .await? + .context("Failed to load upstream OAuth 2.0 provider")?; let provider = Box::new(provider); let link = Box::new(link); diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 6ad9d366..7905eece 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -29,8 +29,10 @@ )] use chrono::{DateTime, Utc}; +use pagination::InvalidPagination; use sqlx::migrate::Migrator; use thiserror::Error; +use ulid::Ulid; #[derive(Debug, Error)] #[error("failed to lookup {what}")] @@ -52,6 +54,12 @@ impl LookupError for GenericLookupError { } } +impl LookupError for sqlx::Error { + fn not_found(&self) -> bool { + matches!(self, sqlx::Error::RowNotFound) + } +} + pub trait LookupError { fn not_found(&self) -> bool; } @@ -80,6 +88,76 @@ where } } +/// Generic error when interacting with the database +#[derive(Debug, Error)] +#[error(transparent)] +pub enum DatabaseError { + /// An error which came from the database itself + Driver(#[from] sqlx::Error), + + /// An error which occured while converting the data from the database + Inconsistency(#[from] DatabaseInconsistencyError2), + + /// An error which occured while generating the paginated query + Pagination(#[from] InvalidPagination), +} + +#[derive(Debug, Error)] +pub struct DatabaseInconsistencyError2 { + table: &'static str, + column: Option<&'static str>, + row: Option, + + #[source] + source: Option>, +} + +impl std::fmt::Display for DatabaseInconsistencyError2 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Database inconsistency on table {}", self.table)?; + if let Some(column) = self.column { + write!(f, " column {column}")?; + } + if let Some(row) = self.row { + write!(f, " row {row}")?; + } + + Ok(()) + } +} + +impl DatabaseInconsistencyError2 { + #[must_use] + pub(crate) const fn on(table: &'static str) -> Self { + Self { + table, + column: None, + row: None, + source: None, + } + } + + #[must_use] + pub(crate) const fn column(mut self, column: &'static str) -> Self { + self.column = Some(column); + self + } + + #[must_use] + pub(crate) const fn row(mut self, row: Ulid) -> Self { + self.row = Some(row); + self + } + + pub(crate) fn source( + mut self, + source: E, + ) -> Self { + self.source = Some(Box::new(source)); + self + } +} + #[derive(Default, Debug, Clone, Copy)] pub struct Clock { _private: (), diff --git a/crates/storage/src/pagination.rs b/crates/storage/src/pagination.rs index 0c898158..95655675 100644 --- a/crates/storage/src/pagination.rs +++ b/crates/storage/src/pagination.rs @@ -119,7 +119,7 @@ pub trait QueryBuilderExt { after: Option, first: Option, last: Option, - ) -> Result<&mut Self, anyhow::Error>; + ) -> Result<&mut Self, InvalidPagination>; } impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB> @@ -135,7 +135,7 @@ where after: Option, first: Option, last: Option, - ) -> Result<&mut Self, anyhow::Error> { + ) -> Result<&mut Self, InvalidPagination> { generate_pagination(self, id_field, before, after, first, last)?; Ok(self) } diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 69c0daff..520ed3f0 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -22,7 +22,7 @@ use uuid::Uuid; use crate::{ pagination::{process_page, QueryBuilderExt}, - Clock, GenericLookupError, + Clock, DatabaseError, LookupResultExt, }; #[derive(sqlx::FromRow)] @@ -54,7 +54,7 @@ impl From for UpstreamOAuthLink { pub async fn lookup_link( executor: impl PgExecutor<'_>, id: Ulid, -) -> Result { +) -> Result, DatabaseError> { let res = sqlx::query_as!( LinkLookup, r#" @@ -71,9 +71,10 @@ pub async fn lookup_link( ) .fetch_one(executor) .await - .map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?; + .to_option()? + .map(Into::into); - Ok(res.into()) + Ok(res) } #[tracing::instrument( @@ -90,7 +91,7 @@ pub async fn lookup_link_by_subject( executor: impl PgExecutor<'_>, upstream_oauth_provider: &UpstreamOAuthProvider, subject: &str, -) -> Result { +) -> Result, DatabaseError> { let res = sqlx::query_as!( LinkLookup, r#" @@ -109,9 +110,10 @@ pub async fn lookup_link_by_subject( ) .fetch_one(executor) .await - .map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?; + .to_option()? + .map(Into::into); - Ok(res.into()) + Ok(res) } #[tracing::instrument( @@ -131,7 +133,7 @@ pub async fn add_link( clock: &Clock, upstream_oauth_provider: &UpstreamOAuthProvider, subject: String, -) -> Result { +) -> Result { let created_at = clock.now(); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); @@ -205,7 +207,7 @@ pub async fn get_paginated_user_links( after: Option, first: Option, last: Option, -) -> Result<(bool, bool, Vec), anyhow::Error> { +) -> Result<(bool, bool, Vec), DatabaseError> { let mut query = QueryBuilder::new( r#" SELECT diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 20cd34fd..503a9df1 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -21,9 +21,7 @@ pub use self::{ add_link, associate_link_to_user, get_paginated_user_links, lookup_link, lookup_link_by_subject, }, - provider::{ - add_provider, get_paginated_providers, get_providers, lookup_provider, ProviderLookupError, - }, + provider::{add_provider, get_paginated_providers, get_providers, lookup_provider}, session::{ add_session, complete_session, consume_session, lookup_session, lookup_session_on_link, SessionLookupError, diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 7fcb526e..351fb3f2 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -18,29 +18,15 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod use oauth2_types::scope::Scope; use rand::Rng; use sqlx::{PgExecutor, QueryBuilder}; -use thiserror::Error; use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; use crate::{ pagination::{process_page, QueryBuilderExt}, - Clock, DatabaseInconsistencyError, LookupError, + Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt, }; -#[derive(Debug, Error)] -#[error("Failed to lookup upstream OAuth 2.0 provider")] -pub enum ProviderLookupError { - Driver(#[from] sqlx::Error), - Inconcistency(#[from] DatabaseInconsistencyError), -} - -impl LookupError for ProviderLookupError { - fn not_found(&self) -> bool { - matches!(self, Self::Driver(sqlx::Error::RowNotFound)) - } -} - #[derive(sqlx::FromRow)] struct ProviderLookup { upstream_oauth_provider_id: Uuid, @@ -54,22 +40,31 @@ struct ProviderLookup { } impl TryFrom for UpstreamOAuthProvider { - type Error = DatabaseInconsistencyError; + type Error = DatabaseInconsistencyError2; fn try_from(value: ProviderLookup) -> Result { let id = value.upstream_oauth_provider_id.into(); - let scope = value - .scope - .parse() - .map_err(|_| DatabaseInconsistencyError)?; - let token_endpoint_auth_method = value - .token_endpoint_auth_method - .parse() - .map_err(|_| DatabaseInconsistencyError)?; + let scope = value.scope.parse().map_err(|e| { + DatabaseInconsistencyError2::on("upstream_oauth_providers") + .column("scope") + .row(id) + .source(e) + })?; + let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| { + DatabaseInconsistencyError2::on("upstream_oauth_providers") + .column("token_endpoint_auth_method") + .row(id) + .source(e) + })?; let token_endpoint_signing_alg = value .token_endpoint_signing_alg .map(|x| x.parse()) .transpose() - .map_err(|_| DatabaseInconsistencyError)?; + .map_err(|e| { + DatabaseInconsistencyError2::on("upstream_oauth_providers") + .column("token_endpoint_signing_alg") + .row(id) + .source(e) + })?; Ok(UpstreamOAuthProvider { id, @@ -92,7 +87,7 @@ impl TryFrom for UpstreamOAuthProvider { pub async fn lookup_provider( executor: impl PgExecutor<'_>, id: Ulid, -) -> Result { +) -> Result, DatabaseError> { let res = sqlx::query_as!( ProviderLookup, r#" @@ -111,9 +106,15 @@ pub async fn lookup_provider( Uuid::from(id), ) .fetch_one(executor) - .await?; + .await + .to_option()?; - Ok(res.try_into()?) + let res = res + .map(UpstreamOAuthProvider::try_from) + .transpose() + .map_err(DatabaseError::from)?; + + Ok(res) } #[tracing::instrument(