1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: start unifying database errors

This commit is contained in:
Quentin Gliech
2022-12-07 16:04:46 +01:00
parent 12ce2a3d04
commit 1ddc05ff01
13 changed files with 143 additions and 62 deletions

1
Cargo.lock generated
View File

@ -2788,6 +2788,7 @@ dependencies = [
name = "mas-graphql" name = "mas-graphql"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow",
"async-graphql", "async-graphql",
"chrono", "chrono",
"mas-axum-utils", "mas-axum-utils",

View File

@ -6,6 +6,7 @@ edition = "2021"
license = "Apache-2.0" license = "Apache-2.0"
[dependencies] [dependencies]
anyhow = "1.0.66"
async-graphql = { version = "5.0.2", features = ["chrono", "url"] } async-graphql = { version = "5.0.2", features = ["chrono", "url"] }
chrono = "0.4.23" chrono = "0.4.23"
serde = { version = "1.0.149", features = ["derive"] } serde = { version = "1.0.149", features = ["derive"] }

View File

@ -188,9 +188,7 @@ impl RootQuery {
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id) let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id).await?;
.await
.to_option()?;
// Ensure that the link belongs to the current user // Ensure that the link belongs to the current user
let link = link.filter(|link| link.user_id == Some(current_user.id)); let link = link.filter(|link| link.user_id == Some(current_user.id));
@ -208,9 +206,7 @@ impl RootQuery {
let database = ctx.data::<PgPool>()?; let database = ctx.data::<PgPool>()?;
let mut conn = database.acquire().await?; let mut conn = database.acquire().await?;
let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id) let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id).await?;
.await
.to_option()?;
Ok(provider.map(UpstreamOAuth2Provider::new)) Ok(provider.map(UpstreamOAuth2Provider::new))
} }

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use anyhow::Context as _;
use async_graphql::{Context, Object, ID}; use async_graphql::{Context, Object, ID};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use sqlx::PgPool; use sqlx::PgPool;
@ -100,7 +101,9 @@ impl UpstreamOAuth2Link {
// Fetch on-the-fly // Fetch on-the-fly
let database = ctx.data::<PgPool>()?; let database = ctx.data::<PgPool>()?;
let mut conn = database.acquire().await?; let mut conn = database.acquire().await?;
mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id).await? mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id)
.await?
.context("Upstream OAuth 2.0 provider not found")?
}; };
Ok(UpstreamOAuth2Provider::new(provider)) Ok(UpstreamOAuth2Provider::new(provider))

View File

@ -22,7 +22,7 @@ use mas_axum_utils::http_client_factory::HttpClientFactory;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use mas_storage::{upstream_oauth2::lookup_provider, LookupResultExt}; use mas_storage::upstream_oauth2::lookup_provider;
use sqlx::PgPool; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
@ -46,7 +46,7 @@ impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_http::ClientInitError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError); impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError);
impl_from_error_for_route!(mas_storage::upstream_oauth2::ProviderLookupError); impl_from_error_for_route!(mas_storage::DatabaseError);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
@ -75,8 +75,7 @@ pub(crate) async fn get(
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let provider = lookup_provider(&mut txn, provider_id) let provider = lookup_provider(&mut txn, provider_id)
.await .await?
.to_option()?
.ok_or(RouteError::ProviderNotFound)?; .ok_or(RouteError::ProviderNotFound)?;
let http_service = http_client_factory let http_service = http_client_factory

View File

@ -96,6 +96,7 @@ pub(crate) enum RouteError {
Anyhow(#[from] anyhow::Error), Anyhow(#[from] anyhow::Error),
} }
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(mas_storage::GenericLookupError); impl_from_error_for_route!(mas_storage::GenericLookupError);
impl_from_error_for_route!(mas_storage::upstream_oauth2::SessionLookupError); impl_from_error_for_route!(mas_storage::upstream_oauth2::SessionLookupError);
impl_from_error_for_route!(mas_http::ClientInitError); impl_from_error_for_route!(mas_http::ClientInitError);
@ -242,9 +243,7 @@ pub(crate) async fn get(
let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?; let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?;
// Look for an existing link // Look for an existing link
let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject) let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject).await?;
.await
.to_option()?;
let link = if let Some(link) = maybe_link { let link = if let Some(link) = maybe_link {
link link

View File

@ -79,6 +79,7 @@ impl_from_error_for_route!(mas_storage::user::ActiveSessionLookupError);
impl_from_error_for_route!(mas_storage::user::UserLookupError); impl_from_error_for_route!(mas_storage::user::UserLookupError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
@ -118,8 +119,7 @@ pub(crate) async fn get(
.map_err(|_| RouteError::MissingCookie)?; .map_err(|_| RouteError::MissingCookie)?;
let link = lookup_link(&mut txn, link_id) let link = lookup_link(&mut txn, link_id)
.await .await?
.to_option()?
.ok_or(RouteError::LinkNotFound)?; .ok_or(RouteError::LinkNotFound)?;
// This checks that we're in a browser session which is allowed to consume this // This checks that we're in a browser session which is allowed to consume this
@ -221,8 +221,7 @@ pub(crate) async fn post(
}; };
let link = lookup_link(&mut txn, link_id) let link = lookup_link(&mut txn, link_id)
.await .await?
.to_option()?
.ok_or(RouteError::LinkNotFound)?; .ok_or(RouteError::LinkNotFound)?;
// This checks that we're in a browser session which is allowed to consume this // This checks that we're in a browser session which is allowed to consume this

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use anyhow::Context;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
@ -58,11 +59,14 @@ impl OptionalPostAuthAction {
PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword, PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
PostAuthAction::LinkUpstream { id } => { PostAuthAction::LinkUpstream { id } => {
let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id).await?; let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id)
.await?
.context("Failed to load upstream OAuth 2.0 link")?;
let provider = let provider =
mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id) mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id)
.await?; .await?
.context("Failed to load upstream OAuth 2.0 provider")?;
let provider = Box::new(provider); let provider = Box::new(provider);
let link = Box::new(link); let link = Box::new(link);

View File

@ -29,8 +29,10 @@
)] )]
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use pagination::InvalidPagination;
use sqlx::migrate::Migrator; use sqlx::migrate::Migrator;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid;
#[derive(Debug, Error)] #[derive(Debug, Error)]
#[error("failed to lookup {what}")] #[error("failed to lookup {what}")]
@ -52,6 +54,12 @@ impl LookupError for GenericLookupError {
} }
} }
impl LookupError for sqlx::Error {
fn not_found(&self) -> bool {
matches!(self, sqlx::Error::RowNotFound)
}
}
pub trait LookupError { pub trait LookupError {
fn not_found(&self) -> bool; fn not_found(&self) -> bool;
} }
@ -80,6 +88,76 @@ where
} }
} }
/// Generic error when interacting with the database
#[derive(Debug, Error)]
#[error(transparent)]
pub enum DatabaseError {
/// An error which came from the database itself
Driver(#[from] sqlx::Error),
/// An error which occured while converting the data from the database
Inconsistency(#[from] DatabaseInconsistencyError2),
/// An error which occured while generating the paginated query
Pagination(#[from] InvalidPagination),
}
#[derive(Debug, Error)]
pub struct DatabaseInconsistencyError2 {
table: &'static str,
column: Option<&'static str>,
row: Option<Ulid>,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
}
impl std::fmt::Display for DatabaseInconsistencyError2 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Database inconsistency on table {}", self.table)?;
if let Some(column) = self.column {
write!(f, " column {column}")?;
}
if let Some(row) = self.row {
write!(f, " row {row}")?;
}
Ok(())
}
}
impl DatabaseInconsistencyError2 {
#[must_use]
pub(crate) const fn on(table: &'static str) -> Self {
Self {
table,
column: None,
row: None,
source: None,
}
}
#[must_use]
pub(crate) const fn column(mut self, column: &'static str) -> Self {
self.column = Some(column);
self
}
#[must_use]
pub(crate) const fn row(mut self, row: Ulid) -> Self {
self.row = Some(row);
self
}
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
mut self,
source: E,
) -> Self {
self.source = Some(Box::new(source));
self
}
}
#[derive(Default, Debug, Clone, Copy)] #[derive(Default, Debug, Clone, Copy)]
pub struct Clock { pub struct Clock {
_private: (), _private: (),

View File

@ -119,7 +119,7 @@ pub trait QueryBuilderExt {
after: Option<Ulid>, after: Option<Ulid>,
first: Option<usize>, first: Option<usize>,
last: Option<usize>, last: Option<usize>,
) -> Result<&mut Self, anyhow::Error>; ) -> Result<&mut Self, InvalidPagination>;
} }
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB> impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
@ -135,7 +135,7 @@ where
after: Option<Ulid>, after: Option<Ulid>,
first: Option<usize>, first: Option<usize>,
last: Option<usize>, last: Option<usize>,
) -> Result<&mut Self, anyhow::Error> { ) -> Result<&mut Self, InvalidPagination> {
generate_pagination(self, id_field, before, after, first, last)?; generate_pagination(self, id_field, before, after, first, last)?;
Ok(self) Ok(self)
} }

View File

@ -22,7 +22,7 @@ use uuid::Uuid;
use crate::{ use crate::{
pagination::{process_page, QueryBuilderExt}, pagination::{process_page, QueryBuilderExt},
Clock, GenericLookupError, Clock, DatabaseError, LookupResultExt,
}; };
#[derive(sqlx::FromRow)] #[derive(sqlx::FromRow)]
@ -54,7 +54,7 @@ impl From<LinkLookup> for UpstreamOAuthLink {
pub async fn lookup_link( pub async fn lookup_link(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
id: Ulid, id: Ulid,
) -> Result<UpstreamOAuthLink, GenericLookupError> { ) -> Result<Option<UpstreamOAuthLink>, DatabaseError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
LinkLookup, LinkLookup,
r#" r#"
@ -71,9 +71,10 @@ pub async fn lookup_link(
) )
.fetch_one(executor) .fetch_one(executor)
.await .await
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?; .to_option()?
.map(Into::into);
Ok(res.into()) Ok(res)
} }
#[tracing::instrument( #[tracing::instrument(
@ -90,7 +91,7 @@ pub async fn lookup_link_by_subject(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
upstream_oauth_provider: &UpstreamOAuthProvider, upstream_oauth_provider: &UpstreamOAuthProvider,
subject: &str, subject: &str,
) -> Result<UpstreamOAuthLink, GenericLookupError> { ) -> Result<Option<UpstreamOAuthLink>, DatabaseError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
LinkLookup, LinkLookup,
r#" r#"
@ -109,9 +110,10 @@ pub async fn lookup_link_by_subject(
) )
.fetch_one(executor) .fetch_one(executor)
.await .await
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?; .to_option()?
.map(Into::into);
Ok(res.into()) Ok(res)
} }
#[tracing::instrument( #[tracing::instrument(
@ -131,7 +133,7 @@ pub async fn add_link(
clock: &Clock, clock: &Clock,
upstream_oauth_provider: &UpstreamOAuthProvider, upstream_oauth_provider: &UpstreamOAuthProvider,
subject: String, subject: String,
) -> Result<UpstreamOAuthLink, sqlx::Error> { ) -> Result<UpstreamOAuthLink, DatabaseError> {
let created_at = clock.now(); let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
@ -205,7 +207,7 @@ pub async fn get_paginated_user_links(
after: Option<Ulid>, after: Option<Ulid>,
first: Option<usize>, first: Option<usize>,
last: Option<usize>, last: Option<usize>,
) -> Result<(bool, bool, Vec<UpstreamOAuthLink>), anyhow::Error> { ) -> Result<(bool, bool, Vec<UpstreamOAuthLink>), DatabaseError> {
let mut query = QueryBuilder::new( let mut query = QueryBuilder::new(
r#" r#"
SELECT SELECT

View File

@ -21,9 +21,7 @@ pub use self::{
add_link, associate_link_to_user, get_paginated_user_links, lookup_link, add_link, associate_link_to_user, get_paginated_user_links, lookup_link,
lookup_link_by_subject, lookup_link_by_subject,
}, },
provider::{ provider::{add_provider, get_paginated_providers, get_providers, lookup_provider},
add_provider, get_paginated_providers, get_providers, lookup_provider, ProviderLookupError,
},
session::{ session::{
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link, add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
SessionLookupError, SessionLookupError,

View File

@ -18,29 +18,15 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod
use oauth2_types::scope::Scope; use oauth2_types::scope::Scope;
use rand::Rng; use rand::Rng;
use sqlx::{PgExecutor, QueryBuilder}; use sqlx::{PgExecutor, QueryBuilder};
use thiserror::Error;
use tracing::{info_span, Instrument}; use tracing::{info_span, Instrument};
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
pagination::{process_page, QueryBuilderExt}, pagination::{process_page, QueryBuilderExt},
Clock, DatabaseInconsistencyError, LookupError, Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt,
}; };
#[derive(Debug, Error)]
#[error("Failed to lookup upstream OAuth 2.0 provider")]
pub enum ProviderLookupError {
Driver(#[from] sqlx::Error),
Inconcistency(#[from] DatabaseInconsistencyError),
}
impl LookupError for ProviderLookupError {
fn not_found(&self) -> bool {
matches!(self, Self::Driver(sqlx::Error::RowNotFound))
}
}
#[derive(sqlx::FromRow)] #[derive(sqlx::FromRow)]
struct ProviderLookup { struct ProviderLookup {
upstream_oauth_provider_id: Uuid, upstream_oauth_provider_id: Uuid,
@ -54,22 +40,31 @@ struct ProviderLookup {
} }
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider { impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
type Error = DatabaseInconsistencyError; type Error = DatabaseInconsistencyError2;
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> { fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
let id = value.upstream_oauth_provider_id.into(); let id = value.upstream_oauth_provider_id.into();
let scope = value let scope = value.scope.parse().map_err(|e| {
.scope DatabaseInconsistencyError2::on("upstream_oauth_providers")
.parse() .column("scope")
.map_err(|_| DatabaseInconsistencyError)?; .row(id)
let token_endpoint_auth_method = value .source(e)
.token_endpoint_auth_method })?;
.parse() let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
.map_err(|_| DatabaseInconsistencyError)?; DatabaseInconsistencyError2::on("upstream_oauth_providers")
.column("token_endpoint_auth_method")
.row(id)
.source(e)
})?;
let token_endpoint_signing_alg = value let token_endpoint_signing_alg = value
.token_endpoint_signing_alg .token_endpoint_signing_alg
.map(|x| x.parse()) .map(|x| x.parse())
.transpose() .transpose()
.map_err(|_| DatabaseInconsistencyError)?; .map_err(|e| {
DatabaseInconsistencyError2::on("upstream_oauth_providers")
.column("token_endpoint_signing_alg")
.row(id)
.source(e)
})?;
Ok(UpstreamOAuthProvider { Ok(UpstreamOAuthProvider {
id, id,
@ -92,7 +87,7 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
pub async fn lookup_provider( pub async fn lookup_provider(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
id: Ulid, id: Ulid,
) -> Result<UpstreamOAuthProvider, ProviderLookupError> { ) -> Result<Option<UpstreamOAuthProvider>, DatabaseError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
ProviderLookup, ProviderLookup,
r#" r#"
@ -111,9 +106,15 @@ pub async fn lookup_provider(
Uuid::from(id), Uuid::from(id),
) )
.fetch_one(executor) .fetch_one(executor)
.await?; .await
.to_option()?;
Ok(res.try_into()?) let res = res
.map(UpstreamOAuthProvider::try_from)
.transpose()
.map_err(DatabaseError::from)?;
Ok(res)
} }
#[tracing::instrument( #[tracing::instrument(