1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: unify user operations errors

This commit is contained in:
Quentin Gliech
2022-12-07 19:07:53 +01:00
parent f7f65e314b
commit b7cad48bbd
11 changed files with 165 additions and 206 deletions

View File

@ -14,7 +14,7 @@
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use mas_data_model::BrowserSession; use mas_data_model::BrowserSession;
use mas_storage::user::{lookup_active_session, ActiveSessionLookupError}; use mas_storage::{user::lookup_active_session, DatabaseError};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{Executor, Postgres}; use sqlx::{Executor, Postgres};
use ulid::Ulid; use ulid::Ulid;
@ -47,7 +47,7 @@ impl SessionInfo {
pub async fn load_session( pub async fn load_session(
&self, &self,
executor: impl Executor<'_, Database = Postgres>, executor: impl Executor<'_, Database = Postgres>,
) -> Result<Option<BrowserSession>, ActiveSessionLookupError> { ) -> Result<Option<BrowserSession>, DatabaseError> {
let session_id = if let Some(id) = self.current { let session_id = if let Some(id) = self.current {
id id
} else { } else {
@ -55,7 +55,7 @@ impl SessionInfo {
}; };
let res = lookup_active_session(executor, session_id).await?; let res = lookup_active_session(executor, session_id).await?;
Ok(Some(res)) Ok(res)
} }
} }

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use anyhow::Context;
use argon2::Argon2; use argon2::Argon2;
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use mas_config::{DatabaseConfig, RootConfig}; use mas_config::{DatabaseConfig, RootConfig};
@ -214,7 +215,9 @@ impl Options {
let pool = config.connect().await?; let pool = config.connect().await?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let hasher = Argon2::default(); let hasher = Argon2::default();
let user = lookup_user_by_username(&mut txn, username).await?; let user = lookup_user_by_username(&mut txn, username)
.await?
.context("User not found")?;
set_password(&mut txn, &mut rng, &clock, hasher, &user, password).await?; set_password(&mut txn, &mut rng, &clock, hasher, &user, password).await?;
info!(%user.id, %user.username, "Password changed"); info!(%user.id, %user.username, "Password changed");
@ -228,8 +231,12 @@ impl Options {
let pool = config.connect().await?; let pool = config.connect().await?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let user = lookup_user_by_username(&mut txn, username).await?; let user = lookup_user_by_username(&mut txn, username)
let email = lookup_user_email(&mut txn, &user, email).await?; .await?
.context("User not found")?;
let email = lookup_user_email(&mut txn, &user, email)
.await?
.context("Email not found")?;
let email = mark_user_email_as_verified(&mut txn, &clock, email).await?; let email = mark_user_email_as_verified(&mut txn, &clock, email).await?;
txn.commit().await?; txn.commit().await?;

View File

@ -136,9 +136,7 @@ impl RootQuery {
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
let browser_session = mas_storage::user::lookup_active_session(&mut conn, id) let browser_session = mas_storage::user::lookup_active_session(&mut conn, id).await?;
.await
.to_option()?;
let ret = browser_session.and_then(|browser_session| { let ret = browser_session.and_then(|browser_session| {
if browser_session.user.id == current_user.id { if browser_session.user.id == current_user.id {
@ -166,9 +164,8 @@ impl RootQuery {
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
let user_email = mas_storage::user::lookup_user_email_by_id(&mut conn, &current_user, id) let user_email =
.await mas_storage::user::lookup_user_email_by_id(&mut conn, &current_user, id).await?;
.to_option()?;
Ok(user_email.map(UserEmail)) Ok(user_email.map(UserEmail))
} }

View File

@ -59,13 +59,16 @@ mod views;
/// errors. /// errors.
#[macro_export] #[macro_export]
macro_rules! impl_from_error_for_route { macro_rules! impl_from_error_for_route {
($error:ty) => { ($route_error:ty : $error:ty) => {
impl From<$error> for self::RouteError { impl From<$error> for $route_error {
fn from(e: $error) -> Self { fn from(e: $error) -> Self {
Self::Internal(Box::new(e)) Self::Internal(Box::new(e))
} }
} }
}; };
($error:ty) => {
impl_from_error_for_route!(self::RouteError: $error);
};
} }
pub use mas_axum_utils::http_client_factory::HttpClientFactory; pub use mas_axum_utils::http_client_factory::HttpClientFactory;

View File

@ -26,12 +26,9 @@ use mas_data_model::{AuthorizationGrant, BrowserSession};
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::oauth2::{
oauth2::{ authorization_grant::{derive_session, fulfill_grant, get_grant_by_id},
authorization_grant::{derive_session, fulfill_grant, get_grant_by_id}, consent::fetch_client_consent,
consent::fetch_client_consent,
},
user::ActiveSessionLookupError,
}; };
use mas_templates::Templates; use mas_templates::Templates;
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse}; use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
@ -39,9 +36,8 @@ use sqlx::{PgPool, Postgres, Transaction};
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use super::callback::{ use super::callback::CallbackDestination;
CallbackDestination, CallbackDestinationError, IntoCallbackDestinationError, use crate::impl_from_error_for_route;
};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum RouteError { pub enum RouteError {
@ -49,7 +45,7 @@ pub enum RouteError {
Internal(Box<dyn std::error::Error + Send + Sync + 'static>), Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error(transparent)] #[error(transparent)]
Anyhow(anyhow::Error), Anyhow(#[from] anyhow::Error),
#[error("authorization grant is not in a pending state")] #[error("authorization grant is not in a pending state")]
NotPending, NotPending,
@ -74,35 +70,10 @@ impl IntoResponse for RouteError {
} }
} }
impl From<anyhow::Error> for RouteError { impl_from_error_for_route!(sqlx::Error);
fn from(e: anyhow::Error) -> Self { impl_from_error_for_route!(mas_storage::DatabaseError);
Self::Anyhow(e) impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
} impl_from_error_for_route!(super::callback::CallbackDestinationError);
}
impl From<sqlx::Error> for RouteError {
fn from(e: sqlx::Error) -> Self {
Self::Internal(Box::new(e))
}
}
impl From<ActiveSessionLookupError> for RouteError {
fn from(e: ActiveSessionLookupError) -> Self {
Self::Internal(Box::new(e))
}
}
impl From<IntoCallbackDestinationError> for RouteError {
fn from(e: IntoCallbackDestinationError) -> Self {
Self::Internal(Box::new(e))
}
}
impl From<CallbackDestinationError> for RouteError {
fn from(e: CallbackDestinationError) -> Self {
Self::Internal(Box::new(e))
}
}
pub(crate) async fn get( pub(crate) async fn get(
State(policy_factory): State<Arc<PolicyFactory>>, State(policy_factory): State<Arc<PolicyFactory>>,
@ -171,17 +142,8 @@ pub enum GrantCompletionError {
PolicyViolation, PolicyViolation,
} }
impl From<sqlx::Error> for GrantCompletionError { impl_from_error_for_route!(GrantCompletionError: sqlx::Error);
fn from(e: sqlx::Error) -> Self { impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
Self::Internal(Box::new(e))
}
}
impl From<IntoCallbackDestinationError> for GrantCompletionError {
fn from(e: IntoCallbackDestinationError) -> Self {
Self::Internal(Box::new(e))
}
}
pub(crate) async fn complete( pub(crate) async fn complete(
grant: AuthorizationGrant, grant: AuthorizationGrant,

View File

@ -75,8 +75,6 @@ pub(crate) enum RouteError {
impl_from_error_for_route!(sqlx::Error); impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::GenericLookupError); impl_from_error_for_route!(mas_storage::GenericLookupError);
impl_from_error_for_route!(mas_storage::user::ActiveSessionLookupError);
impl_from_error_for_route!(mas_storage::user::UserLookupError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::DatabaseError); impl_from_error_for_route!(mas_storage::DatabaseError);

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use anyhow::Context;
use axum::{ use axum::{
extract::{Form, Path, Query, State}, extract::{Form, Path, Query, State},
response::{Html, IntoResponse, Response}, response::{Html, IntoResponse, Response},
@ -64,7 +65,9 @@ pub(crate) async fn get(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let user_email = lookup_user_email_by_id(&mut conn, &session.user, id).await?; let user_email = lookup_user_email_by_id(&mut conn, &session.user, id)
.await?
.context("Could not find user email")?;
if user_email.confirmed_at.is_some() { if user_email.confirmed_at.is_some() {
// This email was already verified, skip // This email was already verified, skip
@ -103,15 +106,18 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let email = lookup_user_email_by_id(&mut txn, &session.user, id).await?; let email = lookup_user_email_by_id(&mut txn, &session.user, id)
.await?
.context("Could not find user email")?;
if session.user.primary_email.is_none() { if session.user.primary_email.is_none() {
set_user_email_as_primary(&mut txn, &email).await?; set_user_email_as_primary(&mut txn, &email).await?;
} }
// TODO: make those 8 hours configurable // TODO: make those 8 hours configurable
let verification = let verification = lookup_user_email_verification_code(&mut txn, &clock, email, &form.code)
lookup_user_email_verification_code(&mut txn, &clock, email, &form.code).await?; .await?
.context("Invalid code")?;
// TODO: display nice errors if the code was already consumed or expired // TODO: display nice errors if the code was already consumed or expired
let verification = consume_email_verification(&mut txn, &clock, verification).await?; let verification = consume_email_verification(&mut txn, &clock, verification).await?;

View File

@ -306,7 +306,9 @@ pub async fn compat_login(
let mut txn = conn.begin().await.context("could not start transaction")?; let mut txn = conn.begin().await.context("could not start transaction")?;
// First, lookup the user // First, lookup the user
let user = lookup_user_by_username(&mut txn, username).await?; let user = lookup_user_by_username(&mut txn, username)
.await?
.context("Could not lookup username")?;
tracing::Span::current().record("user.id", tracing::field::display(user.id)); tracing::Span::current().record("user.id", tracing::field::display(user.id));
// Now, fetch the hashed password from the user associated with that session // Now, fetch the hashed password from the user associated with that session

View File

@ -137,7 +137,9 @@ pub async fn get_paginated_user_oauth_sessions(
// ideal // ideal
let mut browser_sessions: HashMap<Ulid, BrowserSession> = HashMap::new(); let mut browser_sessions: HashMap<Ulid, BrowserSession> = HashMap::new();
for id in browser_session_ids { for id in browser_session_ids {
let v = lookup_active_session(&mut *conn, id).await?; let v = lookup_active_session(&mut *conn, id)
.await?
.context("Failed to load active session")?;
browser_sessions.insert(id, v); browser_sessions.insert(id, v);
} }

View File

@ -14,7 +14,7 @@
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use anyhow::{bail, Context}; use anyhow::Context;
use argon2::Argon2; use argon2::Argon2;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{ use mas_data_model::{
@ -30,10 +30,9 @@ use tracing::{info_span, Instrument};
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use super::DatabaseInconsistencyError;
use crate::{ use crate::{
pagination::{process_page, QueryBuilderExt}, pagination::{process_page, QueryBuilderExt},
Clock, GenericLookupError, LookupError, Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt,
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -49,11 +48,7 @@ struct UserLookup {
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum LoginError { pub enum LoginError {
#[error("could not find user {username:?}")] #[error("could not find user {username:?}")]
NotFound { NotFound { username: String },
username: String,
#[source]
source: UserLookupError,
},
#[error("authentication failed for {username:?}")] #[error("authentication failed for {username:?}")]
Authentication { Authentication {
@ -81,18 +76,16 @@ pub async fn login(
let mut txn = conn.begin().await.context("could not start transaction")?; let mut txn = conn.begin().await.context("could not start transaction")?;
let user = lookup_user_by_username(&mut txn, username) let user = lookup_user_by_username(&mut txn, username)
.await .await
.map_err(|source| { .context("Could not find user by username")?;
if source.not_found() {
LoginError::NotFound { let Some(user) = user else {
username: username.to_owned(), return Err(LoginError::NotFound { username: username.to_owned() });
source, };
}
} else { let mut session = start_session(&mut txn, &mut rng, clock, user)
LoginError::Other(source.into()) .await
} .context("Could not start session")?;
})?;
let mut session = start_session(&mut txn, &mut rng, clock, user).await?;
authenticate_session(&mut txn, &mut rng, clock, &mut session, password) authenticate_session(&mut txn, &mut rng, clock, &mut session, password)
.await .await
.map_err(|source| { .map_err(|source| {
@ -110,19 +103,6 @@ pub async fn login(
Ok(session) Ok(session)
} }
#[derive(Debug, Error)]
#[error("could not fetch session")]
pub enum ActiveSessionLookupError {
Fetch(#[from] sqlx::Error),
Conversion(#[from] DatabaseInconsistencyError),
}
impl LookupError for ActiveSessionLookupError {
fn not_found(&self) -> bool {
matches!(self, Self::Fetch(sqlx::Error::RowNotFound))
}
}
#[derive(sqlx::FromRow)] #[derive(sqlx::FromRow)]
struct SessionLookup { struct SessionLookup {
user_session_id: Uuid, user_session_id: Uuid,
@ -138,9 +118,10 @@ struct SessionLookup {
} }
impl TryInto<BrowserSession> for SessionLookup { impl TryInto<BrowserSession> for SessionLookup {
type Error = DatabaseInconsistencyError; type Error = DatabaseInconsistencyError2;
fn try_into(self) -> Result<BrowserSession, Self::Error> { fn try_into(self) -> Result<BrowserSession, Self::Error> {
let id = Ulid::from(self.user_id);
let primary_email = match ( let primary_email = match (
self.user_email_id, self.user_email_id,
self.user_email, self.user_email,
@ -154,10 +135,13 @@ impl TryInto<BrowserSession> for SessionLookup {
confirmed_at, confirmed_at,
}), }),
(None, None, None, None) => None, (None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError), _ => {
return Err(DatabaseInconsistencyError2::on("users")
.column("primary_user_email_id")
.row(id))
}
}; };
let id = Ulid::from(self.user_id);
let user = User { let user = User {
id, id,
username: self.username, username: self.username,
@ -171,7 +155,11 @@ impl TryInto<BrowserSession> for SessionLookup {
created_at, created_at,
}), }),
(None, None) => None, (None, None) => None,
_ => return Err(DatabaseInconsistencyError), _ => {
return Err(DatabaseInconsistencyError2::on(
"user_session_authentications",
))
}
}; };
Ok(BrowserSession { Ok(BrowserSession {
@ -191,7 +179,7 @@ impl TryInto<BrowserSession> for SessionLookup {
pub async fn lookup_active_session( pub async fn lookup_active_session(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
id: Ulid, id: Ulid,
) -> Result<BrowserSession, ActiveSessionLookupError> { ) -> Result<Option<BrowserSession>, DatabaseError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
SessionLookup, SessionLookup,
r#" r#"
@ -220,10 +208,12 @@ pub async fn lookup_active_session(
Uuid::from(id), Uuid::from(id),
) )
.fetch_one(executor) .fetch_one(executor)
.await? .await
.try_into()?; .to_option()?;
Ok(res) let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
} }
#[tracing::instrument( #[tracing::instrument(
@ -232,7 +222,7 @@ pub async fn lookup_active_session(
%user.id, %user.id,
%user.username, %user.username,
), ),
err(Display), err,
)] )]
pub async fn get_paginated_user_sessions( pub async fn get_paginated_user_sessions(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
@ -241,7 +231,7 @@ pub async fn get_paginated_user_sessions(
after: Option<Ulid>, after: Option<Ulid>,
first: Option<usize>, first: Option<usize>,
last: Option<usize>, last: Option<usize>,
) -> Result<(bool, bool, Vec<BrowserSession>), anyhow::Error> { ) -> Result<(bool, bool, Vec<BrowserSession>), DatabaseError> {
let mut query = QueryBuilder::new( let mut query = QueryBuilder::new(
r#" r#"
SELECT SELECT
@ -289,14 +279,14 @@ pub async fn get_paginated_user_sessions(
%user.id, %user.id,
user_session.id, user_session.id,
), ),
err(Display), err,
)] )]
pub async fn start_session( pub async fn start_session(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send, mut rng: impl Rng + Send,
clock: &Clock, clock: &Clock,
user: User, user: User,
) -> Result<BrowserSession, anyhow::Error> { ) -> Result<BrowserSession, sqlx::Error> {
let created_at = clock.now(); let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_session.id", tracing::field::display(id)); tracing::Span::current().record("user_session.id", tracing::field::display(id));
@ -311,8 +301,7 @@ pub async fn start_session(
created_at, created_at,
) )
.execute(executor) .execute(executor)
.await .await?;
.context("could not create session")?;
let session = BrowserSession { let session = BrowserSession {
id, id,
@ -327,12 +316,12 @@ pub async fn start_session(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields(%user.id), fields(%user.id),
err(Display), err,
)] )]
pub async fn count_active_sessions( pub async fn count_active_sessions(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
user: &User, user: &User,
) -> Result<usize, anyhow::Error> { ) -> Result<i64, DatabaseError> {
let res = sqlx::query_scalar!( let res = sqlx::query_scalar!(
r#" r#"
SELECT COUNT(*) as "count!" SELECT COUNT(*) as "count!"
@ -342,8 +331,7 @@ pub async fn count_active_sessions(
Uuid::from(user.id), Uuid::from(user.id),
) )
.fetch_one(executor) .fetch_one(executor)
.await? .await?;
.try_into()?;
Ok(res) Ok(res)
} }
@ -485,7 +473,7 @@ pub async fn authenticate_session_with_upstream(
user.username = username, user.username = username,
user.id, user.id,
), ),
err(Display), err(Debug),
)] )]
pub async fn register_user( pub async fn register_user(
txn: &mut Transaction<'_, Postgres>, txn: &mut Transaction<'_, Postgres>,
@ -569,7 +557,7 @@ pub async fn register_passwordless_user(
%user.id, %user.id,
user_password.id, user_password.id,
), ),
err(Display), err(Debug),
)] )]
pub async fn set_password( pub async fn set_password(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
@ -607,13 +595,13 @@ pub async fn set_password(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields(%user_session.id), fields(%user_session.id),
err(Display), err,
)] )]
pub async fn end_session( pub async fn end_session(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
user_session: &BrowserSession, user_session: &BrowserSession,
) -> Result<(), anyhow::Error> { ) -> Result<(), DatabaseError> {
let now = clock.now(); let now = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
@ -626,27 +614,9 @@ pub async fn end_session(
) )
.execute(executor) .execute(executor)
.instrument(info_span!("End session")) .instrument(info_span!("End session"))
.await .await?;
.context("could not end session")?;
match res.rows_affected() { DatabaseError::ensure_affected_rows(&res, 1)
1 => Ok(()),
0 => Err(anyhow::anyhow!("no row affected")),
_ => Err(anyhow::anyhow!("too many row affected")),
}
}
#[derive(Debug, Error)]
#[error("failed to lookup user")]
pub enum UserLookupError {
Database(#[from] sqlx::Error),
Inconsistency(#[from] DatabaseInconsistencyError),
}
impl LookupError for UserLookupError {
fn not_found(&self) -> bool {
matches!(self, Self::Database(sqlx::Error::RowNotFound))
}
} }
#[tracing::instrument( #[tracing::instrument(
@ -657,7 +627,7 @@ impl LookupError for UserLookupError {
pub async fn lookup_user_by_username( pub async fn lookup_user_by_username(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
username: &str, username: &str,
) -> Result<User, UserLookupError> { ) -> Result<Option<User>, DatabaseError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
UserLookup, UserLookup,
r#" r#"
@ -679,8 +649,12 @@ pub async fn lookup_user_by_username(
) )
.fetch_one(executor) .fetch_one(executor)
.instrument(info_span!("Fetch user")) .instrument(info_span!("Fetch user"))
.await?; .await
.to_option()?;
let Some(res) = res else { return Ok(None) };
let id = Ulid::from(res.user_id);
let primary_email = match ( let primary_email = match (
res.user_email_id, res.user_email_id,
res.user_email, res.user_email,
@ -694,16 +668,20 @@ pub async fn lookup_user_by_username(
confirmed_at, confirmed_at,
}), }),
(None, None, None, None) => None, (None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError.into()), _ => {
return Err(DatabaseInconsistencyError2::on("users")
.column("primary_user_email_id")
.row(id)
.into())
}
}; };
let id = Ulid::from(res.user_id); Ok(Some(User {
Ok(User {
id, id,
username: res.user_username, username: res.user_username,
sub: id.to_string(), sub: id.to_string(),
primary_email, primary_email,
}) }))
} }
#[tracing::instrument( #[tracing::instrument(
@ -711,7 +689,7 @@ pub async fn lookup_user_by_username(
fields(user.id = %id), fields(user.id = %id),
err, err,
)] )]
pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User, UserLookupError> { pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User, DatabaseError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
UserLookup, UserLookup,
r#" r#"
@ -735,6 +713,7 @@ pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User
.instrument(info_span!("Fetch user")) .instrument(info_span!("Fetch user"))
.await?; .await?;
let id = Ulid::from(res.user_id);
let primary_email = match ( let primary_email = match (
res.user_email_id, res.user_email_id,
res.user_email, res.user_email,
@ -748,10 +727,14 @@ pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User
confirmed_at, confirmed_at,
}), }),
(None, None, None, None) => None, (None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError.into()), _ => {
return Err(DatabaseInconsistencyError2::on("users")
.column("primary_user_email_id")
.row(id)
.into())
}
}; };
let id = Ulid::from(res.user_id);
Ok(User { Ok(User {
id, id,
username: res.user_username, username: res.user_username,
@ -803,12 +786,12 @@ impl From<UserEmailLookup> for UserEmail {
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields(%user.id, %user.username), fields(%user.id, %user.username),
err(Display), err,
)] )]
pub async fn get_user_emails( pub async fn get_user_emails(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
user: &User, user: &User,
) -> Result<Vec<UserEmail>, anyhow::Error> { ) -> Result<Vec<UserEmail>, sqlx::Error> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
UserEmailLookup, UserEmailLookup,
r#" r#"
@ -835,12 +818,12 @@ pub async fn get_user_emails(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields(%user.id, %user.username), fields(%user.id, %user.username),
err(Display), err,
)] )]
pub async fn count_user_emails( pub async fn count_user_emails(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
user: &User, user: &User,
) -> Result<i64, anyhow::Error> { ) -> Result<i64, sqlx::Error> {
let res = sqlx::query_scalar!( let res = sqlx::query_scalar!(
r#" r#"
SELECT COUNT(*) SELECT COUNT(*)
@ -859,7 +842,7 @@ pub async fn count_user_emails(
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields(%user.id, %user.username), fields(%user.id, %user.username),
err(Display), err,
)] )]
pub async fn get_paginated_user_emails( pub async fn get_paginated_user_emails(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
@ -868,7 +851,7 @@ pub async fn get_paginated_user_emails(
after: Option<Ulid>, after: Option<Ulid>,
first: Option<usize>, first: Option<usize>,
last: Option<usize>, last: Option<usize>,
) -> Result<(bool, bool, Vec<UserEmail>), anyhow::Error> { ) -> Result<(bool, bool, Vec<UserEmail>), DatabaseError> {
let mut query = QueryBuilder::new( let mut query = QueryBuilder::new(
r#" r#"
SELECT SELECT
@ -908,13 +891,13 @@ pub async fn get_paginated_user_emails(
%user.username, %user.username,
user_email.id = %id, user_email.id = %id,
), ),
err(Display), err,
)] )]
pub async fn get_user_email( pub async fn get_user_email(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
user: &User, user: &User,
id: Ulid, id: Ulid,
) -> Result<UserEmail, anyhow::Error> { ) -> Result<UserEmail, sqlx::Error> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
UserEmailLookup, UserEmailLookup,
r#" r#"
@ -946,7 +929,7 @@ pub async fn get_user_email(
user_email.id, user_email.id,
user_email.email = %email, user_email.email = %email,
), ),
err(Display), err,
)] )]
pub async fn add_user_email( pub async fn add_user_email(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
@ -954,7 +937,7 @@ pub async fn add_user_email(
clock: &Clock, clock: &Clock,
user: &User, user: &User,
email: String, email: String,
) -> Result<UserEmail, anyhow::Error> { ) -> Result<UserEmail, sqlx::Error> {
let created_at = clock.now(); let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_email.id", tracing::field::display(id)); tracing::Span::current().record("user_email.id", tracing::field::display(id));
@ -971,8 +954,7 @@ pub async fn add_user_email(
) )
.execute(executor) .execute(executor)
.instrument(info_span!("Add user email")) .instrument(info_span!("Add user email"))
.await .await?;
.context("could not insert user email")?;
Ok(UserEmail { Ok(UserEmail {
id, id,
@ -993,7 +975,7 @@ pub async fn add_user_email(
pub async fn set_user_email_as_primary( pub async fn set_user_email_as_primary(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
user_email: &UserEmail, user_email: &UserEmail,
) -> Result<(), anyhow::Error> { ) -> Result<(), sqlx::Error> {
sqlx::query!( sqlx::query!(
r#" r#"
UPDATE users UPDATE users
@ -1006,8 +988,7 @@ pub async fn set_user_email_as_primary(
) )
.execute(executor) .execute(executor)
.instrument(info_span!("Add user email")) .instrument(info_span!("Add user email"))
.await .await?;
.context("could not set user email as primary")?;
Ok(()) Ok(())
} }
@ -1018,12 +999,12 @@ pub async fn set_user_email_as_primary(
%user_email.id, %user_email.id,
%user_email.email, %user_email.email,
), ),
err(Display), err,
)] )]
pub async fn remove_user_email( pub async fn remove_user_email(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
user_email: UserEmail, user_email: UserEmail,
) -> Result<(), anyhow::Error> { ) -> Result<(), sqlx::Error> {
sqlx::query!( sqlx::query!(
r#" r#"
DELETE FROM user_emails DELETE FROM user_emails
@ -1033,8 +1014,7 @@ pub async fn remove_user_email(
) )
.execute(executor) .execute(executor)
.instrument(info_span!("Remove user email")) .instrument(info_span!("Remove user email"))
.await .await?;
.context("could not remove user email")?;
Ok(()) Ok(())
} }
@ -1045,13 +1025,13 @@ pub async fn remove_user_email(
%user.id, %user.id,
user_email.email = email, user_email.email = email,
), ),
err(Display), err,
)] )]
pub async fn lookup_user_email( pub async fn lookup_user_email(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
user: &User, user: &User,
email: &str, email: &str,
) -> Result<UserEmail, anyhow::Error> { ) -> Result<Option<UserEmail>, sqlx::Error> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
UserEmailLookup, UserEmailLookup,
r#" r#"
@ -1071,9 +1051,11 @@ pub async fn lookup_user_email(
.fetch_one(executor) .fetch_one(executor)
.instrument(info_span!("Lookup user email")) .instrument(info_span!("Lookup user email"))
.await .await
.context("could not lookup user email")?; .to_option()?;
Ok(res.into()) let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
} }
#[tracing::instrument( #[tracing::instrument(
@ -1088,7 +1070,7 @@ pub async fn lookup_user_email_by_id(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
user: &User, user: &User,
id: Ulid, id: Ulid,
) -> Result<UserEmail, GenericLookupError> { ) -> Result<Option<UserEmail>, DatabaseError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
UserEmailLookup, UserEmailLookup,
r#" r#"
@ -1108,21 +1090,23 @@ pub async fn lookup_user_email_by_id(
.fetch_one(executor) .fetch_one(executor)
.instrument(info_span!("Lookup user email")) .instrument(info_span!("Lookup user email"))
.await .await
.map_err(GenericLookupError::what("user email"))?; .to_option()?;
Ok(res.into()) let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
} }
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields(%user_email.id), fields(%user_email.id),
err(Display), err,
)] )]
pub async fn mark_user_email_as_verified( pub async fn mark_user_email_as_verified(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
mut user_email: UserEmail, mut user_email: UserEmail,
) -> Result<UserEmail, anyhow::Error> { ) -> Result<UserEmail, sqlx::Error> {
let confirmed_at = clock.now(); let confirmed_at = clock.now();
sqlx::query!( sqlx::query!(
r#" r#"
@ -1135,8 +1119,7 @@ pub async fn mark_user_email_as_verified(
) )
.execute(executor) .execute(executor)
.instrument(info_span!("Confirm user email")) .instrument(info_span!("Confirm user email"))
.await .await?;
.context("could not update user email")?;
user_email.confirmed_at = Some(confirmed_at); user_email.confirmed_at = Some(confirmed_at);
@ -1154,14 +1137,14 @@ struct UserEmailConfirmationCodeLookup {
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields(%user_email.id), fields(%user_email.id),
err(Display), err,
)] )]
pub async fn lookup_user_email_verification_code( pub async fn lookup_user_email_verification_code(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
user_email: UserEmail, user_email: UserEmail,
code: &str, code: &str,
) -> Result<UserEmailVerification, anyhow::Error> { ) -> Result<Option<UserEmailVerification>, DatabaseError> {
let now = clock.now(); let now = clock.now();
let res = sqlx::query_as!( let res = sqlx::query_as!(
@ -1183,7 +1166,9 @@ pub async fn lookup_user_email_verification_code(
.fetch_one(executor) .fetch_one(executor)
.instrument(info_span!("Lookup user email verification")) .instrument(info_span!("Lookup user email verification"))
.await .await
.context("could not lookup user email verification")?; .to_option()?;
let Some(res) = res else { return Ok(None) };
let state = if let Some(when) = res.consumed_at { let state = if let Some(when) = res.consumed_at {
UserEmailVerificationState::AlreadyUsed { when } UserEmailVerificationState::AlreadyUsed { when }
@ -1195,13 +1180,13 @@ pub async fn lookup_user_email_verification_code(
UserEmailVerificationState::Valid UserEmailVerificationState::Valid
}; };
Ok(UserEmailVerification { Ok(Some(UserEmailVerification {
id: res.user_email_confirmation_code_id.into(), id: res.user_email_confirmation_code_id.into(),
code: res.code, code: res.code,
email: user_email, email: user_email,
state, state,
created_at: res.created_at, created_at: res.created_at,
}) }))
} }
#[tracing::instrument( #[tracing::instrument(
@ -1209,18 +1194,18 @@ pub async fn lookup_user_email_verification_code(
fields( fields(
%user_email_verification.id, %user_email_verification.id,
), ),
err(Display), err,
)] )]
pub async fn consume_email_verification( pub async fn consume_email_verification(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock, clock: &Clock,
mut user_email_verification: UserEmailVerification, mut user_email_verification: UserEmailVerification,
) -> Result<UserEmailVerification, anyhow::Error> { ) -> Result<UserEmailVerification, DatabaseError> {
if !matches!( if !matches!(
user_email_verification.state, user_email_verification.state,
UserEmailVerificationState::Valid UserEmailVerificationState::Valid
) { ) {
bail!("user email verification in wrong state"); return Err(DatabaseError::InvalidOperation);
} }
let consumed_at = clock.now(); let consumed_at = clock.now();
@ -1236,8 +1221,7 @@ pub async fn consume_email_verification(
) )
.execute(executor) .execute(executor)
.instrument(info_span!("Consume user email verification")) .instrument(info_span!("Consume user email verification"))
.await .await?;
.context("could not update user email verification")?;
user_email_verification.state = UserEmailVerificationState::AlreadyUsed { when: consumed_at }; user_email_verification.state = UserEmailVerificationState::AlreadyUsed { when: consumed_at };
@ -1252,7 +1236,7 @@ pub async fn consume_email_verification(
user_email_confirmation.id, user_email_confirmation.id,
user_email_confirmation.code = code, user_email_confirmation.code = code,
), ),
err(Display), err,
)] )]
pub async fn add_user_email_verification_code( pub async fn add_user_email_verification_code(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
@ -1261,7 +1245,7 @@ pub async fn add_user_email_verification_code(
user_email: UserEmail, user_email: UserEmail,
max_age: chrono::Duration, max_age: chrono::Duration,
code: String, code: String,
) -> Result<UserEmailVerification, anyhow::Error> { ) -> Result<UserEmailVerification, sqlx::Error> {
let created_at = clock.now(); let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id)); tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id));
@ -1281,8 +1265,7 @@ pub async fn add_user_email_verification_code(
) )
.execute(executor) .execute(executor)
.instrument(info_span!("Add user email verification code")) .instrument(info_span!("Add user email verification code"))
.await .await?;
.context("could not insert user email verification code")?;
let verification = UserEmailVerification { let verification = UserEmailVerification {
id, id,
@ -1320,7 +1303,9 @@ mod tests {
let session = login(&mut txn, &mut rng, &clock, "john", "hunter2").await?; let session = login(&mut txn, &mut rng, &clock, "john", "hunter2").await?;
assert_eq!(session.user.id, user.id); assert_eq!(session.user.id, user.id);
let user2 = lookup_user_by_username(&mut txn, "john").await?; let user2 = lookup_user_by_username(&mut txn, "john")
.await?
.context("Could not find user")?;
assert_eq!(user.id, user2.id); assert_eq!(user.id, user2.id);
txn.commit().await?; txn.commit().await?;

View File

@ -535,20 +535,17 @@ where {
/// Context used by the `account/index.html` template /// Context used by the `account/index.html` template
#[derive(Serialize)] #[derive(Serialize)]
pub struct AccountContext { pub struct AccountContext {
active_sessions: usize, active_sessions: i64,
emails: Vec<UserEmail>, emails: Vec<UserEmail>,
} }
impl AccountContext { impl AccountContext {
/// Constructs a context for the "my account" page /// Constructs a context for the "my account" page
#[must_use] #[must_use]
pub fn new<T>(active_sessions: usize, emails: Vec<T>) -> Self pub fn new(active_sessions: i64, emails: Vec<UserEmail>) -> Self {
where
T: Into<UserEmail>,
{
Self { Self {
active_sessions, active_sessions,
emails: emails.into_iter().map(Into::into).collect(), emails,
} }
} }
} }