You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-28 11:02:02 +03:00
storage: start unifying database errors
This commit is contained in:
@ -22,7 +22,7 @@ use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{process_page, QueryBuilderExt},
|
||||
Clock, GenericLookupError,
|
||||
Clock, DatabaseError, LookupResultExt,
|
||||
};
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
@ -54,7 +54,7 @@ impl From<LinkLookup> for UpstreamOAuthLink {
|
||||
pub async fn lookup_link(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: Ulid,
|
||||
) -> Result<UpstreamOAuthLink, GenericLookupError> {
|
||||
) -> Result<Option<UpstreamOAuthLink>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
LinkLookup,
|
||||
r#"
|
||||
@ -71,9 +71,10 @@ pub async fn lookup_link(
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
|
||||
.to_option()?
|
||||
.map(Into::into);
|
||||
|
||||
Ok(res.into())
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -90,7 +91,7 @@ pub async fn lookup_link_by_subject(
|
||||
executor: impl PgExecutor<'_>,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
subject: &str,
|
||||
) -> Result<UpstreamOAuthLink, GenericLookupError> {
|
||||
) -> Result<Option<UpstreamOAuthLink>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
LinkLookup,
|
||||
r#"
|
||||
@ -109,9 +110,10 @@ pub async fn lookup_link_by_subject(
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
|
||||
.to_option()?
|
||||
.map(Into::into);
|
||||
|
||||
Ok(res.into())
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -131,7 +133,7 @@ pub async fn add_link(
|
||||
clock: &Clock,
|
||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||
subject: String,
|
||||
) -> Result<UpstreamOAuthLink, sqlx::Error> {
|
||||
) -> Result<UpstreamOAuthLink, DatabaseError> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
|
||||
@ -205,7 +207,7 @@ pub async fn get_paginated_user_links(
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<(bool, bool, Vec<UpstreamOAuthLink>), anyhow::Error> {
|
||||
) -> Result<(bool, bool, Vec<UpstreamOAuthLink>), DatabaseError> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT
|
||||
|
@ -21,9 +21,7 @@ pub use self::{
|
||||
add_link, associate_link_to_user, get_paginated_user_links, lookup_link,
|
||||
lookup_link_by_subject,
|
||||
},
|
||||
provider::{
|
||||
add_provider, get_paginated_providers, get_providers, lookup_provider, ProviderLookupError,
|
||||
},
|
||||
provider::{add_provider, get_paginated_providers, get_providers, lookup_provider},
|
||||
session::{
|
||||
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
|
||||
SessionLookupError,
|
||||
|
@ -18,29 +18,15 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod
|
||||
use oauth2_types::scope::Scope;
|
||||
use rand::Rng;
|
||||
use sqlx::{PgExecutor, QueryBuilder};
|
||||
use thiserror::Error;
|
||||
use tracing::{info_span, Instrument};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::{process_page, QueryBuilderExt},
|
||||
Clock, DatabaseInconsistencyError, LookupError,
|
||||
Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt,
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Failed to lookup upstream OAuth 2.0 provider")]
|
||||
pub enum ProviderLookupError {
|
||||
Driver(#[from] sqlx::Error),
|
||||
Inconcistency(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl LookupError for ProviderLookupError {
|
||||
fn not_found(&self) -> bool {
|
||||
matches!(self, Self::Driver(sqlx::Error::RowNotFound))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct ProviderLookup {
|
||||
upstream_oauth_provider_id: Uuid,
|
||||
@ -54,22 +40,31 @@ struct ProviderLookup {
|
||||
}
|
||||
|
||||
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
type Error = DatabaseInconsistencyError2;
|
||||
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.upstream_oauth_provider_id.into();
|
||||
let scope = value
|
||||
.scope
|
||||
.parse()
|
||||
.map_err(|_| DatabaseInconsistencyError)?;
|
||||
let token_endpoint_auth_method = value
|
||||
.token_endpoint_auth_method
|
||||
.parse()
|
||||
.map_err(|_| DatabaseInconsistencyError)?;
|
||||
let scope = value.scope.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||
.column("scope")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||
.column("token_endpoint_auth_method")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
let token_endpoint_signing_alg = value
|
||||
.token_endpoint_signing_alg
|
||||
.map(|x| x.parse())
|
||||
.transpose()
|
||||
.map_err(|_| DatabaseInconsistencyError)?;
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||
.column("token_endpoint_signing_alg")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Ok(UpstreamOAuthProvider {
|
||||
id,
|
||||
@ -92,7 +87,7 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||
pub async fn lookup_provider(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: Ulid,
|
||||
) -> Result<UpstreamOAuthProvider, ProviderLookupError> {
|
||||
) -> Result<Option<UpstreamOAuthProvider>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
ProviderLookup,
|
||||
r#"
|
||||
@ -111,9 +106,15 @@ pub async fn lookup_provider(
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await?;
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
Ok(res.try_into()?)
|
||||
let res = res
|
||||
.map(UpstreamOAuthProvider::try_from)
|
||||
.transpose()
|
||||
.map_err(DatabaseError::from)?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
|
Reference in New Issue
Block a user