You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
storage: start unifying database errors
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -2788,6 +2788,7 @@ dependencies = [
|
|||||||
name = "mas-graphql"
|
name = "mas-graphql"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"async-graphql",
|
"async-graphql",
|
||||||
"chrono",
|
"chrono",
|
||||||
"mas-axum-utils",
|
"mas-axum-utils",
|
||||||
|
@ -6,6 +6,7 @@ edition = "2021"
|
|||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
anyhow = "1.0.66"
|
||||||
async-graphql = { version = "5.0.2", features = ["chrono", "url"] }
|
async-graphql = { version = "5.0.2", features = ["chrono", "url"] }
|
||||||
chrono = "0.4.23"
|
chrono = "0.4.23"
|
||||||
serde = { version = "1.0.149", features = ["derive"] }
|
serde = { version = "1.0.149", features = ["derive"] }
|
||||||
|
@ -188,9 +188,7 @@ impl RootQuery {
|
|||||||
let Some(session) = session else { return Ok(None) };
|
let Some(session) = session else { return Ok(None) };
|
||||||
let current_user = session.user;
|
let current_user = session.user;
|
||||||
|
|
||||||
let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id)
|
let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id).await?;
|
||||||
.await
|
|
||||||
.to_option()?;
|
|
||||||
|
|
||||||
// Ensure that the link belongs to the current user
|
// Ensure that the link belongs to the current user
|
||||||
let link = link.filter(|link| link.user_id == Some(current_user.id));
|
let link = link.filter(|link| link.user_id == Some(current_user.id));
|
||||||
@ -208,9 +206,7 @@ impl RootQuery {
|
|||||||
let database = ctx.data::<PgPool>()?;
|
let database = ctx.data::<PgPool>()?;
|
||||||
let mut conn = database.acquire().await?;
|
let mut conn = database.acquire().await?;
|
||||||
|
|
||||||
let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id)
|
let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id).await?;
|
||||||
.await
|
|
||||||
.to_option()?;
|
|
||||||
|
|
||||||
Ok(provider.map(UpstreamOAuth2Provider::new))
|
Ok(provider.map(UpstreamOAuth2Provider::new))
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use anyhow::Context as _;
|
||||||
use async_graphql::{Context, Object, ID};
|
use async_graphql::{Context, Object, ID};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
@ -100,7 +101,9 @@ impl UpstreamOAuth2Link {
|
|||||||
// Fetch on-the-fly
|
// Fetch on-the-fly
|
||||||
let database = ctx.data::<PgPool>()?;
|
let database = ctx.data::<PgPool>()?;
|
||||||
let mut conn = database.acquire().await?;
|
let mut conn = database.acquire().await?;
|
||||||
mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id).await?
|
mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id)
|
||||||
|
.await?
|
||||||
|
.context("Upstream OAuth 2.0 provider not found")?
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(UpstreamOAuth2Provider::new(provider))
|
Ok(UpstreamOAuth2Provider::new(provider))
|
||||||
|
@ -22,7 +22,7 @@ use mas_axum_utils::http_client_factory::HttpClientFactory;
|
|||||||
use mas_keystore::Encrypter;
|
use mas_keystore::Encrypter;
|
||||||
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
|
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
|
||||||
use mas_router::UrlBuilder;
|
use mas_router::UrlBuilder;
|
||||||
use mas_storage::{upstream_oauth2::lookup_provider, LookupResultExt};
|
use mas_storage::upstream_oauth2::lookup_provider;
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
@ -46,7 +46,7 @@ impl_from_error_for_route!(sqlx::Error);
|
|||||||
impl_from_error_for_route!(mas_http::ClientInitError);
|
impl_from_error_for_route!(mas_http::ClientInitError);
|
||||||
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
|
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
|
||||||
impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError);
|
impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError);
|
||||||
impl_from_error_for_route!(mas_storage::upstream_oauth2::ProviderLookupError);
|
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||||
|
|
||||||
impl IntoResponse for RouteError {
|
impl IntoResponse for RouteError {
|
||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
@ -75,8 +75,7 @@ pub(crate) async fn get(
|
|||||||
let mut txn = pool.begin().await?;
|
let mut txn = pool.begin().await?;
|
||||||
|
|
||||||
let provider = lookup_provider(&mut txn, provider_id)
|
let provider = lookup_provider(&mut txn, provider_id)
|
||||||
.await
|
.await?
|
||||||
.to_option()?
|
|
||||||
.ok_or(RouteError::ProviderNotFound)?;
|
.ok_or(RouteError::ProviderNotFound)?;
|
||||||
|
|
||||||
let http_service = http_client_factory
|
let http_service = http_client_factory
|
||||||
|
@ -96,6 +96,7 @@ pub(crate) enum RouteError {
|
|||||||
Anyhow(#[from] anyhow::Error),
|
Anyhow(#[from] anyhow::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||||
impl_from_error_for_route!(mas_storage::GenericLookupError);
|
impl_from_error_for_route!(mas_storage::GenericLookupError);
|
||||||
impl_from_error_for_route!(mas_storage::upstream_oauth2::SessionLookupError);
|
impl_from_error_for_route!(mas_storage::upstream_oauth2::SessionLookupError);
|
||||||
impl_from_error_for_route!(mas_http::ClientInitError);
|
impl_from_error_for_route!(mas_http::ClientInitError);
|
||||||
@ -242,9 +243,7 @@ pub(crate) async fn get(
|
|||||||
let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?;
|
let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?;
|
||||||
|
|
||||||
// Look for an existing link
|
// Look for an existing link
|
||||||
let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject)
|
let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject).await?;
|
||||||
.await
|
|
||||||
.to_option()?;
|
|
||||||
|
|
||||||
let link = if let Some(link) = maybe_link {
|
let link = if let Some(link) = maybe_link {
|
||||||
link
|
link
|
||||||
|
@ -79,6 +79,7 @@ impl_from_error_for_route!(mas_storage::user::ActiveSessionLookupError);
|
|||||||
impl_from_error_for_route!(mas_storage::user::UserLookupError);
|
impl_from_error_for_route!(mas_storage::user::UserLookupError);
|
||||||
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
|
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
|
||||||
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||||
|
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||||
|
|
||||||
impl IntoResponse for RouteError {
|
impl IntoResponse for RouteError {
|
||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
@ -118,8 +119,7 @@ pub(crate) async fn get(
|
|||||||
.map_err(|_| RouteError::MissingCookie)?;
|
.map_err(|_| RouteError::MissingCookie)?;
|
||||||
|
|
||||||
let link = lookup_link(&mut txn, link_id)
|
let link = lookup_link(&mut txn, link_id)
|
||||||
.await
|
.await?
|
||||||
.to_option()?
|
|
||||||
.ok_or(RouteError::LinkNotFound)?;
|
.ok_or(RouteError::LinkNotFound)?;
|
||||||
|
|
||||||
// This checks that we're in a browser session which is allowed to consume this
|
// This checks that we're in a browser session which is allowed to consume this
|
||||||
@ -221,8 +221,7 @@ pub(crate) async fn post(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let link = lookup_link(&mut txn, link_id)
|
let link = lookup_link(&mut txn, link_id)
|
||||||
.await
|
.await?
|
||||||
.to_option()?
|
|
||||||
.ok_or(RouteError::LinkNotFound)?;
|
.ok_or(RouteError::LinkNotFound)?;
|
||||||
|
|
||||||
// This checks that we're in a browser session which is allowed to consume this
|
// This checks that we're in a browser session which is allowed to consume this
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
use mas_router::{PostAuthAction, Route};
|
use mas_router::{PostAuthAction, Route};
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
|
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
|
||||||
@ -58,11 +59,14 @@ impl OptionalPostAuthAction {
|
|||||||
PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
|
PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
|
||||||
|
|
||||||
PostAuthAction::LinkUpstream { id } => {
|
PostAuthAction::LinkUpstream { id } => {
|
||||||
let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id).await?;
|
let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id)
|
||||||
|
.await?
|
||||||
|
.context("Failed to load upstream OAuth 2.0 link")?;
|
||||||
|
|
||||||
let provider =
|
let provider =
|
||||||
mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id)
|
mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id)
|
||||||
.await?;
|
.await?
|
||||||
|
.context("Failed to load upstream OAuth 2.0 provider")?;
|
||||||
|
|
||||||
let provider = Box::new(provider);
|
let provider = Box::new(provider);
|
||||||
let link = Box::new(link);
|
let link = Box::new(link);
|
||||||
|
@ -29,8 +29,10 @@
|
|||||||
)]
|
)]
|
||||||
|
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
|
use pagination::InvalidPagination;
|
||||||
use sqlx::migrate::Migrator;
|
use sqlx::migrate::Migrator;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
use ulid::Ulid;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
#[error("failed to lookup {what}")]
|
#[error("failed to lookup {what}")]
|
||||||
@ -52,6 +54,12 @@ impl LookupError for GenericLookupError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl LookupError for sqlx::Error {
|
||||||
|
fn not_found(&self) -> bool {
|
||||||
|
matches!(self, sqlx::Error::RowNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait LookupError {
|
pub trait LookupError {
|
||||||
fn not_found(&self) -> bool;
|
fn not_found(&self) -> bool;
|
||||||
}
|
}
|
||||||
@ -80,6 +88,76 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generic error when interacting with the database
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
#[error(transparent)]
|
||||||
|
pub enum DatabaseError {
|
||||||
|
/// An error which came from the database itself
|
||||||
|
Driver(#[from] sqlx::Error),
|
||||||
|
|
||||||
|
/// An error which occured while converting the data from the database
|
||||||
|
Inconsistency(#[from] DatabaseInconsistencyError2),
|
||||||
|
|
||||||
|
/// An error which occured while generating the paginated query
|
||||||
|
Pagination(#[from] InvalidPagination),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub struct DatabaseInconsistencyError2 {
|
||||||
|
table: &'static str,
|
||||||
|
column: Option<&'static str>,
|
||||||
|
row: Option<Ulid>,
|
||||||
|
|
||||||
|
#[source]
|
||||||
|
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for DatabaseInconsistencyError2 {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "Database inconsistency on table {}", self.table)?;
|
||||||
|
if let Some(column) = self.column {
|
||||||
|
write!(f, " column {column}")?;
|
||||||
|
}
|
||||||
|
if let Some(row) = self.row {
|
||||||
|
write!(f, " row {row}")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DatabaseInconsistencyError2 {
|
||||||
|
#[must_use]
|
||||||
|
pub(crate) const fn on(table: &'static str) -> Self {
|
||||||
|
Self {
|
||||||
|
table,
|
||||||
|
column: None,
|
||||||
|
row: None,
|
||||||
|
source: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub(crate) const fn column(mut self, column: &'static str) -> Self {
|
||||||
|
self.column = Some(column);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub(crate) const fn row(mut self, row: Ulid) -> Self {
|
||||||
|
self.row = Some(row);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn source<E: std::error::Error + Send + Sync + 'static>(
|
||||||
|
mut self,
|
||||||
|
source: E,
|
||||||
|
) -> Self {
|
||||||
|
self.source = Some(Box::new(source));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone, Copy)]
|
#[derive(Default, Debug, Clone, Copy)]
|
||||||
pub struct Clock {
|
pub struct Clock {
|
||||||
_private: (),
|
_private: (),
|
||||||
|
@ -119,7 +119,7 @@ pub trait QueryBuilderExt {
|
|||||||
after: Option<Ulid>,
|
after: Option<Ulid>,
|
||||||
first: Option<usize>,
|
first: Option<usize>,
|
||||||
last: Option<usize>,
|
last: Option<usize>,
|
||||||
) -> Result<&mut Self, anyhow::Error>;
|
) -> Result<&mut Self, InvalidPagination>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
|
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
|
||||||
@ -135,7 +135,7 @@ where
|
|||||||
after: Option<Ulid>,
|
after: Option<Ulid>,
|
||||||
first: Option<usize>,
|
first: Option<usize>,
|
||||||
last: Option<usize>,
|
last: Option<usize>,
|
||||||
) -> Result<&mut Self, anyhow::Error> {
|
) -> Result<&mut Self, InvalidPagination> {
|
||||||
generate_pagination(self, id_field, before, after, first, last)?;
|
generate_pagination(self, id_field, before, after, first, last)?;
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
@ -22,7 +22,7 @@ use uuid::Uuid;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
pagination::{process_page, QueryBuilderExt},
|
pagination::{process_page, QueryBuilderExt},
|
||||||
Clock, GenericLookupError,
|
Clock, DatabaseError, LookupResultExt,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(sqlx::FromRow)]
|
#[derive(sqlx::FromRow)]
|
||||||
@ -54,7 +54,7 @@ impl From<LinkLookup> for UpstreamOAuthLink {
|
|||||||
pub async fn lookup_link(
|
pub async fn lookup_link(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
id: Ulid,
|
id: Ulid,
|
||||||
) -> Result<UpstreamOAuthLink, GenericLookupError> {
|
) -> Result<Option<UpstreamOAuthLink>, DatabaseError> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
LinkLookup,
|
LinkLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -71,9 +71,10 @@ pub async fn lookup_link(
|
|||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.fetch_one(executor)
|
||||||
.await
|
.await
|
||||||
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
|
.to_option()?
|
||||||
|
.map(Into::into);
|
||||||
|
|
||||||
Ok(res.into())
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
@ -90,7 +91,7 @@ pub async fn lookup_link_by_subject(
|
|||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||||
subject: &str,
|
subject: &str,
|
||||||
) -> Result<UpstreamOAuthLink, GenericLookupError> {
|
) -> Result<Option<UpstreamOAuthLink>, DatabaseError> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
LinkLookup,
|
LinkLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -109,9 +110,10 @@ pub async fn lookup_link_by_subject(
|
|||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.fetch_one(executor)
|
||||||
.await
|
.await
|
||||||
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
|
.to_option()?
|
||||||
|
.map(Into::into);
|
||||||
|
|
||||||
Ok(res.into())
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
@ -131,7 +133,7 @@ pub async fn add_link(
|
|||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
upstream_oauth_provider: &UpstreamOAuthProvider,
|
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||||
subject: String,
|
subject: String,
|
||||||
) -> Result<UpstreamOAuthLink, sqlx::Error> {
|
) -> Result<UpstreamOAuthLink, DatabaseError> {
|
||||||
let created_at = clock.now();
|
let created_at = clock.now();
|
||||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||||
tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
|
tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
|
||||||
@ -205,7 +207,7 @@ pub async fn get_paginated_user_links(
|
|||||||
after: Option<Ulid>,
|
after: Option<Ulid>,
|
||||||
first: Option<usize>,
|
first: Option<usize>,
|
||||||
last: Option<usize>,
|
last: Option<usize>,
|
||||||
) -> Result<(bool, bool, Vec<UpstreamOAuthLink>), anyhow::Error> {
|
) -> Result<(bool, bool, Vec<UpstreamOAuthLink>), DatabaseError> {
|
||||||
let mut query = QueryBuilder::new(
|
let mut query = QueryBuilder::new(
|
||||||
r#"
|
r#"
|
||||||
SELECT
|
SELECT
|
||||||
|
@ -21,9 +21,7 @@ pub use self::{
|
|||||||
add_link, associate_link_to_user, get_paginated_user_links, lookup_link,
|
add_link, associate_link_to_user, get_paginated_user_links, lookup_link,
|
||||||
lookup_link_by_subject,
|
lookup_link_by_subject,
|
||||||
},
|
},
|
||||||
provider::{
|
provider::{add_provider, get_paginated_providers, get_providers, lookup_provider},
|
||||||
add_provider, get_paginated_providers, get_providers, lookup_provider, ProviderLookupError,
|
|
||||||
},
|
|
||||||
session::{
|
session::{
|
||||||
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
|
add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,
|
||||||
SessionLookupError,
|
SessionLookupError,
|
||||||
|
@ -18,29 +18,15 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod
|
|||||||
use oauth2_types::scope::Scope;
|
use oauth2_types::scope::Scope;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use sqlx::{PgExecutor, QueryBuilder};
|
use sqlx::{PgExecutor, QueryBuilder};
|
||||||
use thiserror::Error;
|
|
||||||
use tracing::{info_span, Instrument};
|
use tracing::{info_span, Instrument};
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
pagination::{process_page, QueryBuilderExt},
|
pagination::{process_page, QueryBuilderExt},
|
||||||
Clock, DatabaseInconsistencyError, LookupError,
|
Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
#[error("Failed to lookup upstream OAuth 2.0 provider")]
|
|
||||||
pub enum ProviderLookupError {
|
|
||||||
Driver(#[from] sqlx::Error),
|
|
||||||
Inconcistency(#[from] DatabaseInconsistencyError),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LookupError for ProviderLookupError {
|
|
||||||
fn not_found(&self) -> bool {
|
|
||||||
matches!(self, Self::Driver(sqlx::Error::RowNotFound))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(sqlx::FromRow)]
|
#[derive(sqlx::FromRow)]
|
||||||
struct ProviderLookup {
|
struct ProviderLookup {
|
||||||
upstream_oauth_provider_id: Uuid,
|
upstream_oauth_provider_id: Uuid,
|
||||||
@ -54,22 +40,31 @@ struct ProviderLookup {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||||
type Error = DatabaseInconsistencyError;
|
type Error = DatabaseInconsistencyError2;
|
||||||
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
|
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
|
||||||
let id = value.upstream_oauth_provider_id.into();
|
let id = value.upstream_oauth_provider_id.into();
|
||||||
let scope = value
|
let scope = value.scope.parse().map_err(|e| {
|
||||||
.scope
|
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||||
.parse()
|
.column("scope")
|
||||||
.map_err(|_| DatabaseInconsistencyError)?;
|
.row(id)
|
||||||
let token_endpoint_auth_method = value
|
.source(e)
|
||||||
.token_endpoint_auth_method
|
})?;
|
||||||
.parse()
|
let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
|
||||||
.map_err(|_| DatabaseInconsistencyError)?;
|
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||||
|
.column("token_endpoint_auth_method")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
let token_endpoint_signing_alg = value
|
let token_endpoint_signing_alg = value
|
||||||
.token_endpoint_signing_alg
|
.token_endpoint_signing_alg
|
||||||
.map(|x| x.parse())
|
.map(|x| x.parse())
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(|_| DatabaseInconsistencyError)?;
|
.map_err(|e| {
|
||||||
|
DatabaseInconsistencyError2::on("upstream_oauth_providers")
|
||||||
|
.column("token_endpoint_signing_alg")
|
||||||
|
.row(id)
|
||||||
|
.source(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(UpstreamOAuthProvider {
|
Ok(UpstreamOAuthProvider {
|
||||||
id,
|
id,
|
||||||
@ -92,7 +87,7 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
|||||||
pub async fn lookup_provider(
|
pub async fn lookup_provider(
|
||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
id: Ulid,
|
id: Ulid,
|
||||||
) -> Result<UpstreamOAuthProvider, ProviderLookupError> {
|
) -> Result<Option<UpstreamOAuthProvider>, DatabaseError> {
|
||||||
let res = sqlx::query_as!(
|
let res = sqlx::query_as!(
|
||||||
ProviderLookup,
|
ProviderLookup,
|
||||||
r#"
|
r#"
|
||||||
@ -111,9 +106,15 @@ pub async fn lookup_provider(
|
|||||||
Uuid::from(id),
|
Uuid::from(id),
|
||||||
)
|
)
|
||||||
.fetch_one(executor)
|
.fetch_one(executor)
|
||||||
.await?;
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
Ok(res.try_into()?)
|
let res = res
|
||||||
|
.map(UpstreamOAuthProvider::try_from)
|
||||||
|
.transpose()
|
||||||
|
.map_err(DatabaseError::from)?;
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
|
Reference in New Issue
Block a user