1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: unify user operations errors

This commit is contained in:
Quentin Gliech
2022-12-07 19:07:53 +01:00
parent f7f65e314b
commit b7cad48bbd
11 changed files with 165 additions and 206 deletions

View File

@ -306,7 +306,9 @@ pub async fn compat_login(
let mut txn = conn.begin().await.context("could not start transaction")?;
// First, lookup the user
let user = lookup_user_by_username(&mut txn, username).await?;
let user = lookup_user_by_username(&mut txn, username)
.await?
.context("Could not lookup username")?;
tracing::Span::current().record("user.id", tracing::field::display(user.id));
// Now, fetch the hashed password from the user associated with that session

View File

@ -137,7 +137,9 @@ pub async fn get_paginated_user_oauth_sessions(
// ideal
let mut browser_sessions: HashMap<Ulid, BrowserSession> = HashMap::new();
for id in browser_session_ids {
let v = lookup_active_session(&mut *conn, id).await?;
let v = lookup_active_session(&mut *conn, id)
.await?
.context("Failed to load active session")?;
browser_sessions.insert(id, v);
}

View File

@ -14,7 +14,7 @@
use std::borrow::BorrowMut;
use anyhow::{bail, Context};
use anyhow::Context;
use argon2::Argon2;
use chrono::{DateTime, Utc};
use mas_data_model::{
@ -30,10 +30,9 @@ use tracing::{info_span, Instrument};
use ulid::Ulid;
use uuid::Uuid;
use super::DatabaseInconsistencyError;
use crate::{
pagination::{process_page, QueryBuilderExt},
Clock, GenericLookupError, LookupError,
Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt,
};
#[derive(Debug, Clone)]
@ -49,11 +48,7 @@ struct UserLookup {
#[derive(Debug, Error)]
pub enum LoginError {
#[error("could not find user {username:?}")]
NotFound {
username: String,
#[source]
source: UserLookupError,
},
NotFound { username: String },
#[error("authentication failed for {username:?}")]
Authentication {
@ -81,18 +76,16 @@ pub async fn login(
let mut txn = conn.begin().await.context("could not start transaction")?;
let user = lookup_user_by_username(&mut txn, username)
.await
.map_err(|source| {
if source.not_found() {
LoginError::NotFound {
username: username.to_owned(),
source,
}
} else {
LoginError::Other(source.into())
}
})?;
.context("Could not find user by username")?;
let Some(user) = user else {
return Err(LoginError::NotFound { username: username.to_owned() });
};
let mut session = start_session(&mut txn, &mut rng, clock, user)
.await
.context("Could not start session")?;
let mut session = start_session(&mut txn, &mut rng, clock, user).await?;
authenticate_session(&mut txn, &mut rng, clock, &mut session, password)
.await
.map_err(|source| {
@ -110,19 +103,6 @@ pub async fn login(
Ok(session)
}
#[derive(Debug, Error)]
#[error("could not fetch session")]
pub enum ActiveSessionLookupError {
Fetch(#[from] sqlx::Error),
Conversion(#[from] DatabaseInconsistencyError),
}
impl LookupError for ActiveSessionLookupError {
fn not_found(&self) -> bool {
matches!(self, Self::Fetch(sqlx::Error::RowNotFound))
}
}
#[derive(sqlx::FromRow)]
struct SessionLookup {
user_session_id: Uuid,
@ -138,9 +118,10 @@ struct SessionLookup {
}
impl TryInto<BrowserSession> for SessionLookup {
type Error = DatabaseInconsistencyError;
type Error = DatabaseInconsistencyError2;
fn try_into(self) -> Result<BrowserSession, Self::Error> {
let id = Ulid::from(self.user_id);
let primary_email = match (
self.user_email_id,
self.user_email,
@ -154,10 +135,13 @@ impl TryInto<BrowserSession> for SessionLookup {
confirmed_at,
}),
(None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError),
_ => {
return Err(DatabaseInconsistencyError2::on("users")
.column("primary_user_email_id")
.row(id))
}
};
let id = Ulid::from(self.user_id);
let user = User {
id,
username: self.username,
@ -171,7 +155,11 @@ impl TryInto<BrowserSession> for SessionLookup {
created_at,
}),
(None, None) => None,
_ => return Err(DatabaseInconsistencyError),
_ => {
return Err(DatabaseInconsistencyError2::on(
"user_session_authentications",
))
}
};
Ok(BrowserSession {
@ -191,7 +179,7 @@ impl TryInto<BrowserSession> for SessionLookup {
pub async fn lookup_active_session(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<BrowserSession, ActiveSessionLookupError> {
) -> Result<Option<BrowserSession>, DatabaseError> {
let res = sqlx::query_as!(
SessionLookup,
r#"
@ -220,10 +208,12 @@ pub async fn lookup_active_session(
Uuid::from(id),
)
.fetch_one(executor)
.await?
.try_into()?;
.await
.to_option()?;
Ok(res)
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
@ -232,7 +222,7 @@ pub async fn lookup_active_session(
%user.id,
%user.username,
),
err(Display),
err,
)]
pub async fn get_paginated_user_sessions(
executor: impl PgExecutor<'_>,
@ -241,7 +231,7 @@ pub async fn get_paginated_user_sessions(
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<BrowserSession>), anyhow::Error> {
) -> Result<(bool, bool, Vec<BrowserSession>), DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT
@ -289,14 +279,14 @@ pub async fn get_paginated_user_sessions(
%user.id,
user_session.id,
),
err(Display),
err,
)]
pub async fn start_session(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: User,
) -> Result<BrowserSession, anyhow::Error> {
) -> Result<BrowserSession, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_session.id", tracing::field::display(id));
@ -311,8 +301,7 @@ pub async fn start_session(
created_at,
)
.execute(executor)
.await
.context("could not create session")?;
.await?;
let session = BrowserSession {
id,
@ -327,12 +316,12 @@ pub async fn start_session(
#[tracing::instrument(
skip_all,
fields(%user.id),
err(Display),
err,
)]
pub async fn count_active_sessions(
executor: impl PgExecutor<'_>,
user: &User,
) -> Result<usize, anyhow::Error> {
) -> Result<i64, DatabaseError> {
let res = sqlx::query_scalar!(
r#"
SELECT COUNT(*) as "count!"
@ -342,8 +331,7 @@ pub async fn count_active_sessions(
Uuid::from(user.id),
)
.fetch_one(executor)
.await?
.try_into()?;
.await?;
Ok(res)
}
@ -485,7 +473,7 @@ pub async fn authenticate_session_with_upstream(
user.username = username,
user.id,
),
err(Display),
err(Debug),
)]
pub async fn register_user(
txn: &mut Transaction<'_, Postgres>,
@ -569,7 +557,7 @@ pub async fn register_passwordless_user(
%user.id,
user_password.id,
),
err(Display),
err(Debug),
)]
pub async fn set_password(
executor: impl PgExecutor<'_>,
@ -607,13 +595,13 @@ pub async fn set_password(
#[tracing::instrument(
skip_all,
fields(%user_session.id),
err(Display),
err,
)]
pub async fn end_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
user_session: &BrowserSession,
) -> Result<(), anyhow::Error> {
) -> Result<(), DatabaseError> {
let now = clock.now();
let res = sqlx::query!(
r#"
@ -626,27 +614,9 @@ pub async fn end_session(
)
.execute(executor)
.instrument(info_span!("End session"))
.await
.context("could not end session")?;
.await?;
match res.rows_affected() {
1 => Ok(()),
0 => Err(anyhow::anyhow!("no row affected")),
_ => Err(anyhow::anyhow!("too many row affected")),
}
}
#[derive(Debug, Error)]
#[error("failed to lookup user")]
pub enum UserLookupError {
Database(#[from] sqlx::Error),
Inconsistency(#[from] DatabaseInconsistencyError),
}
impl LookupError for UserLookupError {
fn not_found(&self) -> bool {
matches!(self, Self::Database(sqlx::Error::RowNotFound))
}
DatabaseError::ensure_affected_rows(&res, 1)
}
#[tracing::instrument(
@ -657,7 +627,7 @@ impl LookupError for UserLookupError {
pub async fn lookup_user_by_username(
executor: impl PgExecutor<'_>,
username: &str,
) -> Result<User, UserLookupError> {
) -> Result<Option<User>, DatabaseError> {
let res = sqlx::query_as!(
UserLookup,
r#"
@ -679,8 +649,12 @@ pub async fn lookup_user_by_username(
)
.fetch_one(executor)
.instrument(info_span!("Fetch user"))
.await?;
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
let id = Ulid::from(res.user_id);
let primary_email = match (
res.user_email_id,
res.user_email,
@ -694,16 +668,20 @@ pub async fn lookup_user_by_username(
confirmed_at,
}),
(None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError.into()),
_ => {
return Err(DatabaseInconsistencyError2::on("users")
.column("primary_user_email_id")
.row(id)
.into())
}
};
let id = Ulid::from(res.user_id);
Ok(User {
Ok(Some(User {
id,
username: res.user_username,
sub: id.to_string(),
primary_email,
})
}))
}
#[tracing::instrument(
@ -711,7 +689,7 @@ pub async fn lookup_user_by_username(
fields(user.id = %id),
err,
)]
pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User, UserLookupError> {
pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User, DatabaseError> {
let res = sqlx::query_as!(
UserLookup,
r#"
@ -735,6 +713,7 @@ pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User
.instrument(info_span!("Fetch user"))
.await?;
let id = Ulid::from(res.user_id);
let primary_email = match (
res.user_email_id,
res.user_email,
@ -748,10 +727,14 @@ pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User
confirmed_at,
}),
(None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError.into()),
_ => {
return Err(DatabaseInconsistencyError2::on("users")
.column("primary_user_email_id")
.row(id)
.into())
}
};
let id = Ulid::from(res.user_id);
Ok(User {
id,
username: res.user_username,
@ -803,12 +786,12 @@ impl From<UserEmailLookup> for UserEmail {
#[tracing::instrument(
skip_all,
fields(%user.id, %user.username),
err(Display),
err,
)]
pub async fn get_user_emails(
executor: impl PgExecutor<'_>,
user: &User,
) -> Result<Vec<UserEmail>, anyhow::Error> {
) -> Result<Vec<UserEmail>, sqlx::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
@ -835,12 +818,12 @@ pub async fn get_user_emails(
#[tracing::instrument(
skip_all,
fields(%user.id, %user.username),
err(Display),
err,
)]
pub async fn count_user_emails(
executor: impl PgExecutor<'_>,
user: &User,
) -> Result<i64, anyhow::Error> {
) -> Result<i64, sqlx::Error> {
let res = sqlx::query_scalar!(
r#"
SELECT COUNT(*)
@ -859,7 +842,7 @@ pub async fn count_user_emails(
#[tracing::instrument(
skip_all,
fields(%user.id, %user.username),
err(Display),
err,
)]
pub async fn get_paginated_user_emails(
executor: impl PgExecutor<'_>,
@ -868,7 +851,7 @@ pub async fn get_paginated_user_emails(
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<UserEmail>), anyhow::Error> {
) -> Result<(bool, bool, Vec<UserEmail>), DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT
@ -908,13 +891,13 @@ pub async fn get_paginated_user_emails(
%user.username,
user_email.id = %id,
),
err(Display),
err,
)]
pub async fn get_user_email(
executor: impl PgExecutor<'_>,
user: &User,
id: Ulid,
) -> Result<UserEmail, anyhow::Error> {
) -> Result<UserEmail, sqlx::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
@ -946,7 +929,7 @@ pub async fn get_user_email(
user_email.id,
user_email.email = %email,
),
err(Display),
err,
)]
pub async fn add_user_email(
executor: impl PgExecutor<'_>,
@ -954,7 +937,7 @@ pub async fn add_user_email(
clock: &Clock,
user: &User,
email: String,
) -> Result<UserEmail, anyhow::Error> {
) -> Result<UserEmail, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_email.id", tracing::field::display(id));
@ -971,8 +954,7 @@ pub async fn add_user_email(
)
.execute(executor)
.instrument(info_span!("Add user email"))
.await
.context("could not insert user email")?;
.await?;
Ok(UserEmail {
id,
@ -993,7 +975,7 @@ pub async fn add_user_email(
pub async fn set_user_email_as_primary(
executor: impl PgExecutor<'_>,
user_email: &UserEmail,
) -> Result<(), anyhow::Error> {
) -> Result<(), sqlx::Error> {
sqlx::query!(
r#"
UPDATE users
@ -1006,8 +988,7 @@ pub async fn set_user_email_as_primary(
)
.execute(executor)
.instrument(info_span!("Add user email"))
.await
.context("could not set user email as primary")?;
.await?;
Ok(())
}
@ -1018,12 +999,12 @@ pub async fn set_user_email_as_primary(
%user_email.id,
%user_email.email,
),
err(Display),
err,
)]
pub async fn remove_user_email(
executor: impl PgExecutor<'_>,
user_email: UserEmail,
) -> Result<(), anyhow::Error> {
) -> Result<(), sqlx::Error> {
sqlx::query!(
r#"
DELETE FROM user_emails
@ -1033,8 +1014,7 @@ pub async fn remove_user_email(
)
.execute(executor)
.instrument(info_span!("Remove user email"))
.await
.context("could not remove user email")?;
.await?;
Ok(())
}
@ -1045,13 +1025,13 @@ pub async fn remove_user_email(
%user.id,
user_email.email = email,
),
err(Display),
err,
)]
pub async fn lookup_user_email(
executor: impl PgExecutor<'_>,
user: &User,
email: &str,
) -> Result<UserEmail, anyhow::Error> {
) -> Result<Option<UserEmail>, sqlx::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
@ -1071,9 +1051,11 @@ pub async fn lookup_user_email(
.fetch_one(executor)
.instrument(info_span!("Lookup user email"))
.await
.context("could not lookup user email")?;
.to_option()?;
Ok(res.into())
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
@ -1088,7 +1070,7 @@ pub async fn lookup_user_email_by_id(
executor: impl PgExecutor<'_>,
user: &User,
id: Ulid,
) -> Result<UserEmail, GenericLookupError> {
) -> Result<Option<UserEmail>, DatabaseError> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
@ -1108,21 +1090,23 @@ pub async fn lookup_user_email_by_id(
.fetch_one(executor)
.instrument(info_span!("Lookup user email"))
.await
.map_err(GenericLookupError::what("user email"))?;
.to_option()?;
Ok(res.into())
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
skip_all,
fields(%user_email.id),
err(Display),
err,
)]
pub async fn mark_user_email_as_verified(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut user_email: UserEmail,
) -> Result<UserEmail, anyhow::Error> {
) -> Result<UserEmail, sqlx::Error> {
let confirmed_at = clock.now();
sqlx::query!(
r#"
@ -1135,8 +1119,7 @@ pub async fn mark_user_email_as_verified(
)
.execute(executor)
.instrument(info_span!("Confirm user email"))
.await
.context("could not update user email")?;
.await?;
user_email.confirmed_at = Some(confirmed_at);
@ -1154,14 +1137,14 @@ struct UserEmailConfirmationCodeLookup {
#[tracing::instrument(
skip_all,
fields(%user_email.id),
err(Display),
err,
)]
pub async fn lookup_user_email_verification_code(
executor: impl PgExecutor<'_>,
clock: &Clock,
user_email: UserEmail,
code: &str,
) -> Result<UserEmailVerification, anyhow::Error> {
) -> Result<Option<UserEmailVerification>, DatabaseError> {
let now = clock.now();
let res = sqlx::query_as!(
@ -1183,7 +1166,9 @@ pub async fn lookup_user_email_verification_code(
.fetch_one(executor)
.instrument(info_span!("Lookup user email verification"))
.await
.context("could not lookup user email verification")?;
.to_option()?;
let Some(res) = res else { return Ok(None) };
let state = if let Some(when) = res.consumed_at {
UserEmailVerificationState::AlreadyUsed { when }
@ -1195,13 +1180,13 @@ pub async fn lookup_user_email_verification_code(
UserEmailVerificationState::Valid
};
Ok(UserEmailVerification {
Ok(Some(UserEmailVerification {
id: res.user_email_confirmation_code_id.into(),
code: res.code,
email: user_email,
state,
created_at: res.created_at,
})
}))
}
#[tracing::instrument(
@ -1209,18 +1194,18 @@ pub async fn lookup_user_email_verification_code(
fields(
%user_email_verification.id,
),
err(Display),
err,
)]
pub async fn consume_email_verification(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut user_email_verification: UserEmailVerification,
) -> Result<UserEmailVerification, anyhow::Error> {
) -> Result<UserEmailVerification, DatabaseError> {
if !matches!(
user_email_verification.state,
UserEmailVerificationState::Valid
) {
bail!("user email verification in wrong state");
return Err(DatabaseError::InvalidOperation);
}
let consumed_at = clock.now();
@ -1236,8 +1221,7 @@ pub async fn consume_email_verification(
)
.execute(executor)
.instrument(info_span!("Consume user email verification"))
.await
.context("could not update user email verification")?;
.await?;
user_email_verification.state = UserEmailVerificationState::AlreadyUsed { when: consumed_at };
@ -1252,7 +1236,7 @@ pub async fn consume_email_verification(
user_email_confirmation.id,
user_email_confirmation.code = code,
),
err(Display),
err,
)]
pub async fn add_user_email_verification_code(
executor: impl PgExecutor<'_>,
@ -1261,7 +1245,7 @@ pub async fn add_user_email_verification_code(
user_email: UserEmail,
max_age: chrono::Duration,
code: String,
) -> Result<UserEmailVerification, anyhow::Error> {
) -> Result<UserEmailVerification, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id));
@ -1281,8 +1265,7 @@ pub async fn add_user_email_verification_code(
)
.execute(executor)
.instrument(info_span!("Add user email verification code"))
.await
.context("could not insert user email verification code")?;
.await?;
let verification = UserEmailVerification {
id,
@ -1320,7 +1303,9 @@ mod tests {
let session = login(&mut txn, &mut rng, &clock, "john", "hunter2").await?;
assert_eq!(session.user.id, user.id);
let user2 = lookup_user_by_username(&mut txn, "john").await?;
let user2 = lookup_user_by_username(&mut txn, "john")
.await?
.context("Could not find user")?;
assert_eq!(user.id, user2.id);
txn.commit().await?;