1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

storage: unify most of the remaining errors

This commit is contained in:
Quentin Gliech
2022-12-08 12:19:28 +01:00
parent 102571512e
commit a836cc864a
14 changed files with 238 additions and 133 deletions

View File

@ -911,7 +911,7 @@ pub async fn fullfill_compat_sso_login(
device: Device,
) -> Result<CompatSsoLogin, DatabaseError> {
if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) {
return Err(DatabaseError::InvalidOperation);
return Err(DatabaseError::invalid_operation());
};
let mut txn = conn.begin().await?;
@ -986,7 +986,7 @@ pub async fn mark_compat_sso_login_as_exchanged(
mut compat_sso_login: CompatSsoLogin,
) -> Result<CompatSsoLogin, DatabaseError> {
let CompatSsoLoginState::Fulfilled { fulfilled_at, session } = compat_sso_login.state else {
return Err(DatabaseError::InvalidOperation);
return Err(DatabaseError::invalid_operation());
};
let exchanged_at = clock.now();

View File

@ -104,7 +104,10 @@ pub enum DatabaseError {
/// An error which happened because the requested database operation is
/// invalid
#[error("Invalid database operation")]
InvalidOperation,
InvalidOperation {
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
},
/// An error which happens when an operation affects not enough or too many
/// rows
@ -124,6 +127,16 @@ impl DatabaseError {
Err(DatabaseError::RowsAffected { expected, actual })
}
}
pub(crate) fn to_invalid_operation<E: std::error::Error + Send + Sync + 'static>(e: E) -> Self {
Self::InvalidOperation {
source: Some(Box::new(e)),
}
}
pub(crate) const fn invalid_operation() -> Self {
Self::InvalidOperation { source: None }
}
}
#[derive(Debug, Error)]

View File

@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
use rand::Rng;
@ -40,7 +39,7 @@ pub async fn add_access_token(
session: &Session,
access_token: String,
expires_after: Duration,
) -> Result<AccessToken, anyhow::Error> {
) -> Result<AccessToken, sqlx::Error> {
let created_at = clock.now();
let expires_at = created_at + expires_after;
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
@ -61,8 +60,7 @@ pub async fn add_access_token(
expires_at,
)
.execute(executor)
.await
.context("could not insert oauth2 access token")?;
.await?;
Ok(AccessToken {
id,

View File

@ -12,11 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::unused_async)]
use std::num::NonZeroU32;
use anyhow::Context;
use chrono::{DateTime, Utc};
use mas_data_model::{
Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession,
@ -31,7 +28,7 @@ use url::Url;
use uuid::Uuid;
use super::client::lookup_client;
use crate::{Clock, DatabaseInconsistencyError};
use crate::{Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt};
#[tracing::instrument(
skip_all,
@ -39,7 +36,7 @@ use crate::{Clock, DatabaseInconsistencyError};
%client.id,
grant.id,
),
err(Debug),
err,
)]
#[allow(clippy::too_many_arguments)]
pub async fn new_authorization_grant(
@ -57,7 +54,7 @@ pub async fn new_authorization_grant(
response_mode: ResponseMode,
response_type_id_token: bool,
requires_consent: bool,
) -> Result<AuthorizationGrant, anyhow::Error> {
) -> Result<AuthorizationGrant, sqlx::Error> {
let code_challenge = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
@ -113,8 +110,7 @@ pub async fn new_authorization_grant(
created_at,
)
.execute(executor)
.await
.context("could not insert oauth2 authorization grant")?;
.await?;
Ok(AuthorizationGrant {
id,
@ -171,17 +167,23 @@ impl GrantLookup {
async fn into_authorization_grant(
self,
executor: impl PgExecutor<'_>,
) -> Result<AuthorizationGrant, DatabaseInconsistencyError> {
let scope: Scope = self
.oauth2_authorization_grant_scope
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
) -> Result<AuthorizationGrant, DatabaseError> {
let id = self.oauth2_authorization_grant_id.into();
let scope: Scope = self.oauth2_authorization_grant_scope.parse().map_err(|e| {
DatabaseInconsistencyError2::on("oauth2_authorization_grants")
.column("scope")
.row(id)
.source(e)
})?;
// TODO: don't unwrap
let client = lookup_client(executor, self.oauth2_client_id.into())
.await
.unwrap()
.unwrap();
.await?
.ok_or_else(|| {
DatabaseInconsistencyError2::on("oauth2_authorization_grants")
.column("client_id")
.row(id)
})?;
let last_authentication = match (
self.user_session_last_authentication_id,
@ -192,7 +194,9 @@ impl GrantLookup {
created_at,
}),
(None, None) => None,
_ => return Err(DatabaseInconsistencyError),
_ => {
return Err(DatabaseInconsistencyError2::on("user_session_authentications").into())
}
};
let primary_email = match (
@ -208,7 +212,11 @@ impl GrantLookup {
confirmed_at,
}),
(None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError),
_ => {
return Err(DatabaseInconsistencyError2::on("users")
.column("primary_user_email_id")
.into())
}
};
let session = match (
@ -257,7 +265,14 @@ impl GrantLookup {
Some(session)
}
(None, None, None, None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError),
_ => {
return Err(
DatabaseInconsistencyError2::on("oauth2_authorization_grants")
.column("oauth2_session_id")
.row(id)
.into(),
)
}
};
let stage = match (
@ -282,7 +297,12 @@ impl GrantLookup {
AuthorizationGrantStage::Cancelled { cancelled_at }
}
_ => {
return Err(DatabaseInconsistencyError);
return Err(
DatabaseInconsistencyError2::on("oauth2_authorization_grants")
.column("stage")
.row(id)
.into(),
);
}
};
@ -302,7 +322,12 @@ impl GrantLookup {
}),
(None, None) => None,
_ => {
return Err(DatabaseInconsistencyError);
return Err(
DatabaseInconsistencyError2::on("oauth2_authorization_grants")
.column("code_challenge_method")
.row(id)
.into(),
);
}
};
@ -314,38 +339,63 @@ impl GrantLookup {
(false, None, None) => None,
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
_ => {
return Err(DatabaseInconsistencyError);
return Err(
DatabaseInconsistencyError2::on("oauth2_authorization_grants")
.column("authorization_code")
.row(id)
.into(),
);
}
};
let redirect_uri = self
.oauth2_authorization_grant_redirect_uri
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
.map_err(|e| {
DatabaseInconsistencyError2::on("oauth2_authorization_grants")
.column("redirect_uri")
.row(id)
.source(e)
})?;
let response_mode = self
.oauth2_authorization_grant_response_mode
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
.map_err(|e| {
DatabaseInconsistencyError2::on("oauth2_authorization_grants")
.column("response_mode")
.row(id)
.source(e)
})?;
let max_age = self
.oauth2_authorization_grant_max_age
.map(u32::try_from)
.transpose()
.map_err(|_e| DatabaseInconsistencyError)?
.map_err(|e| {
DatabaseInconsistencyError2::on("oauth2_authorization_grants")
.column("max_age")
.row(id)
.source(e)
})?
.map(NonZeroU32::try_from)
.transpose()
.map_err(|_e| DatabaseInconsistencyError)?;
.map_err(|e| {
DatabaseInconsistencyError2::on("oauth2_authorization_grants")
.column("max_age")
.row(id)
.source(e)
})?;
Ok(AuthorizationGrant {
id: self.oauth2_authorization_grant_id.into(),
id,
stage,
client,
code,
scope,
state: self.oauth2_authorization_grant_state,
nonce: self.oauth2_authorization_grant_nonce,
max_age, // TODO
max_age,
response_mode,
redirect_uri,
created_at: self.oauth2_authorization_grant_created_at,
@ -358,13 +408,12 @@ impl GrantLookup {
#[tracing::instrument(
skip_all,
fields(grant.id = %id),
err(Debug),
err,
)]
pub async fn get_grant_by_id(
conn: &mut PgConnection,
id: Ulid,
) -> Result<AuthorizationGrant, anyhow::Error> {
// TODO: handle "not found" cases
) -> Result<Option<AuthorizationGrant>, DatabaseError> {
let res = sqlx::query_as!(
GrantLookup,
r#"
@ -420,19 +469,20 @@ pub async fn get_grant_by_id(
)
.fetch_one(&mut *conn)
.await
.context("failed to get grant by id")?;
.to_option()?;
let Some(res) = res else { return Ok(None) };
let grant = res.into_authorization_grant(&mut *conn).await?;
Ok(grant)
Ok(Some(grant))
}
#[tracing::instrument(skip_all, err(Debug))]
#[tracing::instrument(skip_all, err)]
pub async fn lookup_grant_by_code(
conn: &mut PgConnection,
code: &str,
) -> Result<AuthorizationGrant, anyhow::Error> {
// TODO: handle "not found" cases
) -> Result<Option<AuthorizationGrant>, DatabaseError> {
let res = sqlx::query_as!(
GrantLookup,
r#"
@ -488,11 +538,13 @@ pub async fn lookup_grant_by_code(
)
.fetch_one(&mut *conn)
.await
.context("failed to lookup grant by code")?;
.to_option()?;
let Some(res) = res else { return Ok(None) };
let grant = res.into_authorization_grant(&mut *conn).await?;
Ok(grant)
Ok(Some(grant))
}
#[tracing::instrument(
@ -504,7 +556,7 @@ pub async fn lookup_grant_by_code(
user_session.id = %browser_session.id,
user.id = %browser_session.user.id,
),
err(Debug),
err,
)]
pub async fn derive_session(
executor: impl PgExecutor<'_>,
@ -512,7 +564,7 @@ pub async fn derive_session(
clock: &Clock,
grant: &AuthorizationGrant,
browser_session: BrowserSession,
) -> Result<Session, anyhow::Error> {
) -> Result<Session, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("session.id", tracing::field::display(id));
@ -538,8 +590,7 @@ pub async fn derive_session(
Uuid::from(grant.id),
)
.execute(executor)
.await
.context("could not insert oauth2 session")?;
.await?;
Ok(Session {
id,
@ -558,13 +609,13 @@ pub async fn derive_session(
user_session.id = %session.browser_session.id,
user.id = %session.browser_session.user.id,
),
err(Debug),
err,
)]
pub async fn fulfill_grant(
executor: impl PgExecutor<'_>,
mut grant: AuthorizationGrant,
session: Session,
) -> Result<AuthorizationGrant, anyhow::Error> {
) -> Result<AuthorizationGrant, DatabaseError> {
let fulfilled_at = sqlx::query_scalar!(
r#"
UPDATE oauth2_authorization_grants AS og
@ -581,10 +632,12 @@ pub async fn fulfill_grant(
Uuid::from(session.id),
)
.fetch_one(executor)
.await
.context("could not mark grant as fulfilled")?;
.await?;
grant.stage = grant.stage.fulfill(fulfilled_at, session)?;
grant.stage = grant
.stage
.fulfill(fulfilled_at, session)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}
@ -595,7 +648,7 @@ pub async fn fulfill_grant(
%grant.id,
client.id = %grant.client.id,
),
err(Debug),
err,
)]
pub async fn give_consent_to_grant(
executor: impl PgExecutor<'_>,
@ -625,13 +678,13 @@ pub async fn give_consent_to_grant(
%grant.id,
client.id = %grant.client.id,
),
err(Debug),
err,
)]
pub async fn exchange_grant(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut grant: AuthorizationGrant,
) -> Result<AuthorizationGrant, anyhow::Error> {
) -> Result<AuthorizationGrant, DatabaseError> {
let exchanged_at = clock.now();
sqlx::query!(
r#"
@ -643,10 +696,12 @@ pub async fn exchange_grant(
exchanged_at,
)
.execute(executor)
.await
.context("could not mark grant as exchanged")?;
.await?;
grant.stage = grant.stage.exchange(exchanged_at)?;
grant.stage = grant
.stage
.exchange(exchanged_at)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(grant)
}

View File

@ -468,8 +468,12 @@ pub async fn insert_client_from_config(
jwks: Option<&PublicJsonWebKeySet>,
jwks_uri: Option<&Url>,
redirect_uris: &[Url],
) -> Result<(), anyhow::Error> {
let jwks = jwks.map(serde_json::to_value).transpose()?;
) -> Result<(), DatabaseError> {
let jwks = jwks
.map(serde_json::to_value)
.transpose()
.map_err(DatabaseError::to_invalid_operation)?;
let jwks_uri = jwks_uri.map(Url::as_str);
let client_auth_method = client_auth_method.to_string();
@ -526,7 +530,7 @@ pub async fn insert_client_from_config(
Ok(())
}
pub async fn truncate_clients(executor: impl PgExecutor<'_>) -> Result<(), anyhow::Error> {
pub async fn truncate_clients(executor: impl PgExecutor<'_>) -> Result<(), sqlx::Error> {
sqlx::query!("TRUNCATE oauth2_client_redirect_uris, oauth2_clients CASCADE")
.execute(executor)
.await?;

View File

@ -21,7 +21,7 @@ use sqlx::PgExecutor;
use ulid::Ulid;
use uuid::Uuid;
use crate::Clock;
use crate::{Clock, DatabaseError, DatabaseInconsistencyError2};
#[tracing::instrument(
skip_all,
@ -29,13 +29,13 @@ use crate::Clock;
%user.id,
%client.id,
),
err(Debug),
err,
)]
pub async fn fetch_client_consent(
executor: impl PgExecutor<'_>,
user: &User,
client: &Client,
) -> Result<Scope, anyhow::Error> {
) -> Result<Scope, DatabaseError> {
let scope_tokens: Vec<String> = sqlx::query_scalar!(
r#"
SELECT scope_token
@ -53,7 +53,13 @@ pub async fn fetch_client_consent(
.map(|s| ScopeToken::from_str(&s))
.collect();
Ok(scope?)
let scope = scope.map_err(|e| {
DatabaseInconsistencyError2::on("oauth2_consents")
.column("scope_token")
.source(e)
})?;
Ok(scope)
}
#[tracing::instrument(
@ -63,7 +69,7 @@ pub async fn fetch_client_consent(
%client.id,
%scope,
),
err(Debug),
err,
)]
pub async fn insert_client_consent(
executor: impl PgExecutor<'_>,
@ -72,7 +78,7 @@ pub async fn insert_client_consent(
user: &User,
client: &Client,
scope: &Scope,
) -> Result<(), anyhow::Error> {
) -> Result<(), sqlx::Error> {
let now = clock.now();
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
.iter()

View File

@ -14,7 +14,6 @@
use std::collections::{BTreeSet, HashMap};
use anyhow::Context;
use mas_data_model::{BrowserSession, Session, User};
use sqlx::{PgConnection, PgExecutor, QueryBuilder};
use tracing::{info_span, Instrument};
@ -25,7 +24,7 @@ use self::client::lookup_clients;
use crate::{
pagination::{process_page, QueryBuilderExt},
user::lookup_active_session,
Clock,
Clock, DatabaseError, DatabaseInconsistencyError2,
};
pub mod access_token;
@ -42,13 +41,13 @@ pub mod refresh_token;
user_session.id = %session.browser_session.id,
client.id = %session.client.id,
),
err(Debug),
err,
)]
pub async fn end_oauth_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
session: Session,
) -> Result<(), anyhow::Error> {
) -> Result<(), DatabaseError> {
let finished_at = clock.now();
let res = sqlx::query!(
r#"
@ -62,9 +61,7 @@ pub async fn end_oauth_session(
.execute(executor)
.await?;
anyhow::ensure!(res.rows_affected() == 1);
Ok(())
DatabaseError::ensure_affected_rows(&res, 1)
}
#[derive(sqlx::FromRow)]
@ -81,7 +78,7 @@ struct OAuthSessionLookup {
%user.id,
%user.username,
),
err(Display),
err,
)]
pub async fn get_paginated_user_oauth_sessions(
conn: &mut PgConnection,
@ -90,7 +87,7 @@ pub async fn get_paginated_user_oauth_sessions(
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<Session>), anyhow::Error> {
) -> Result<(bool, bool, Vec<Session>), DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT
@ -139,26 +136,42 @@ pub async fn get_paginated_user_oauth_sessions(
for id in browser_session_ids {
let v = lookup_active_session(&mut *conn, id)
.await?
.context("Failed to load active session")?;
.ok_or_else(|| {
DatabaseInconsistencyError2::on("oauth2_sessions").column("user_session_id")
})?;
browser_sessions.insert(id, v);
}
let page: Result<Vec<_>, _> = page
let page: Result<Vec<_>, DatabaseInconsistencyError2> = page
.into_iter()
.map(|item| {
let id = Ulid::from(item.oauth2_session_id);
let client = clients
.get(&Ulid::from(item.oauth2_client_id))
.context("client was not fetched")?
.ok_or_else(|| {
DatabaseInconsistencyError2::on("oauth2_sessions")
.column("oauth2_client_id")
.row(id)
})?
.clone();
let browser_session = browser_sessions
.get(&Ulid::from(item.user_session_id))
.context("browser session was not fetched")?
.ok_or_else(|| {
DatabaseInconsistencyError2::on("oauth2_sessions")
.column("user_session_id")
.row(id)
})?
.clone();
let scope = item.scope.parse()?;
let scope = item.scope.parse().map_err(|e| {
DatabaseInconsistencyError2::on("oauth2_sessions")
.column("scope")
.row(id)
.source(e)
})?;
anyhow::Ok(Session {
Ok(Session {
id: Ulid::from(item.oauth2_session_id),
client,
browser_session,

View File

@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use chrono::{DateTime, Utc};
use mas_data_model::{
AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail,
@ -43,7 +42,7 @@ pub async fn add_refresh_token(
session: &Session,
access_token: AccessToken,
refresh_token: String,
) -> anyhow::Result<RefreshToken> {
) -> Result<RefreshToken, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
@ -63,8 +62,7 @@ pub async fn add_refresh_token(
created_at,
)
.execute(executor)
.await
.context("could not insert oauth2 refresh token")?;
.await?;
Ok(RefreshToken {
id,

View File

@ -179,14 +179,14 @@ pub async fn add_provider(
})
}
#[tracing::instrument(skip_all, err(Display))]
#[tracing::instrument(skip_all, err)]
pub async fn get_paginated_providers(
executor: impl PgExecutor<'_>,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<UpstreamOAuthProvider>), anyhow::Error> {
) -> Result<(bool, bool, Vec<UpstreamOAuthProvider>), DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT
@ -224,7 +224,7 @@ pub async fn get_paginated_providers(
#[tracing::instrument(skip_all, err)]
pub async fn get_providers(
executor: impl PgExecutor<'_>,
) -> Result<Vec<UpstreamOAuthProvider>, anyhow::Error> {
) -> Result<Vec<UpstreamOAuthProvider>, DatabaseError> {
let res = sqlx::query_as!(
ProviderLookup,
r#"

View File

@ -1205,7 +1205,7 @@ pub async fn consume_email_verification(
user_email_verification.state,
UserEmailVerificationState::Valid
) {
return Err(DatabaseError::InvalidOperation);
return Err(DatabaseError::invalid_operation());
}
let consumed_at = clock.now();