1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

storage: unify user operations errors

This commit is contained in:
Quentin Gliech
2022-12-07 19:07:53 +01:00
parent f7f65e314b
commit b7cad48bbd
11 changed files with 165 additions and 206 deletions

View File

@ -14,7 +14,7 @@
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use mas_data_model::BrowserSession;
use mas_storage::user::{lookup_active_session, ActiveSessionLookupError};
use mas_storage::{user::lookup_active_session, DatabaseError};
use serde::{Deserialize, Serialize};
use sqlx::{Executor, Postgres};
use ulid::Ulid;
@ -47,7 +47,7 @@ impl SessionInfo {
pub async fn load_session(
&self,
executor: impl Executor<'_, Database = Postgres>,
) -> Result<Option<BrowserSession>, ActiveSessionLookupError> {
) -> Result<Option<BrowserSession>, DatabaseError> {
let session_id = if let Some(id) = self.current {
id
} else {
@ -55,7 +55,7 @@ impl SessionInfo {
};
let res = lookup_active_session(executor, session_id).await?;
Ok(Some(res))
Ok(res)
}
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use argon2::Argon2;
use clap::{Parser, ValueEnum};
use mas_config::{DatabaseConfig, RootConfig};
@ -214,7 +215,9 @@ impl Options {
let pool = config.connect().await?;
let mut txn = pool.begin().await?;
let hasher = Argon2::default();
let user = lookup_user_by_username(&mut txn, username).await?;
let user = lookup_user_by_username(&mut txn, username)
.await?
.context("User not found")?;
set_password(&mut txn, &mut rng, &clock, hasher, &user, password).await?;
info!(%user.id, %user.username, "Password changed");
@ -228,8 +231,12 @@ impl Options {
let pool = config.connect().await?;
let mut txn = pool.begin().await?;
let user = lookup_user_by_username(&mut txn, username).await?;
let email = lookup_user_email(&mut txn, &user, email).await?;
let user = lookup_user_by_username(&mut txn, username)
.await?
.context("User not found")?;
let email = lookup_user_email(&mut txn, &user, email)
.await?
.context("Email not found")?;
let email = mark_user_email_as_verified(&mut txn, &clock, email).await?;
txn.commit().await?;

View File

@ -136,9 +136,7 @@ impl RootQuery {
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
let browser_session = mas_storage::user::lookup_active_session(&mut conn, id)
.await
.to_option()?;
let browser_session = mas_storage::user::lookup_active_session(&mut conn, id).await?;
let ret = browser_session.and_then(|browser_session| {
if browser_session.user.id == current_user.id {
@ -166,9 +164,8 @@ impl RootQuery {
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
let user_email = mas_storage::user::lookup_user_email_by_id(&mut conn, &current_user, id)
.await
.to_option()?;
let user_email =
mas_storage::user::lookup_user_email_by_id(&mut conn, &current_user, id).await?;
Ok(user_email.map(UserEmail))
}

View File

@ -59,13 +59,16 @@ mod views;
/// errors.
#[macro_export]
macro_rules! impl_from_error_for_route {
($error:ty) => {
impl From<$error> for self::RouteError {
($route_error:ty : $error:ty) => {
impl From<$error> for $route_error {
fn from(e: $error) -> Self {
Self::Internal(Box::new(e))
}
}
};
($error:ty) => {
impl_from_error_for_route!(self::RouteError: $error);
};
}
pub use mas_axum_utils::http_client_factory::HttpClientFactory;

View File

@ -26,12 +26,9 @@ use mas_data_model::{AuthorizationGrant, BrowserSession};
use mas_keystore::Encrypter;
use mas_policy::PolicyFactory;
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{
use mas_storage::oauth2::{
authorization_grant::{derive_session, fulfill_grant, get_grant_by_id},
consent::fetch_client_consent,
},
user::ActiveSessionLookupError,
};
use mas_templates::Templates;
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
@ -39,9 +36,8 @@ use sqlx::{PgPool, Postgres, Transaction};
use thiserror::Error;
use ulid::Ulid;
use super::callback::{
CallbackDestination, CallbackDestinationError, IntoCallbackDestinationError,
};
use super::callback::CallbackDestination;
use crate::impl_from_error_for_route;
#[derive(Debug, Error)]
pub enum RouteError {
@ -49,7 +45,7 @@ pub enum RouteError {
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error(transparent)]
Anyhow(anyhow::Error),
Anyhow(#[from] anyhow::Error),
#[error("authorization grant is not in a pending state")]
NotPending,
@ -74,35 +70,10 @@ impl IntoResponse for RouteError {
}
}
impl From<anyhow::Error> for RouteError {
fn from(e: anyhow::Error) -> Self {
Self::Anyhow(e)
}
}
impl From<sqlx::Error> for RouteError {
fn from(e: sqlx::Error) -> Self {
Self::Internal(Box::new(e))
}
}
impl From<ActiveSessionLookupError> for RouteError {
fn from(e: ActiveSessionLookupError) -> Self {
Self::Internal(Box::new(e))
}
}
impl From<IntoCallbackDestinationError> for RouteError {
fn from(e: IntoCallbackDestinationError) -> Self {
Self::Internal(Box::new(e))
}
}
impl From<CallbackDestinationError> for RouteError {
fn from(e: CallbackDestinationError) -> Self {
Self::Internal(Box::new(e))
}
}
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_storage::DatabaseError);
impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(super::callback::CallbackDestinationError);
pub(crate) async fn get(
State(policy_factory): State<Arc<PolicyFactory>>,
@ -171,17 +142,8 @@ pub enum GrantCompletionError {
PolicyViolation,
}
impl From<sqlx::Error> for GrantCompletionError {
fn from(e: sqlx::Error) -> Self {
Self::Internal(Box::new(e))
}
}
impl From<IntoCallbackDestinationError> for GrantCompletionError {
fn from(e: IntoCallbackDestinationError) -> Self {
Self::Internal(Box::new(e))
}
}
impl_from_error_for_route!(GrantCompletionError: sqlx::Error);
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
pub(crate) async fn complete(
grant: AuthorizationGrant,

View File

@ -75,8 +75,6 @@ pub(crate) enum RouteError {
impl_from_error_for_route!(sqlx::Error);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::GenericLookupError);
impl_from_error_for_route!(mas_storage::user::ActiveSessionLookupError);
impl_from_error_for_route!(mas_storage::user::UserLookupError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::DatabaseError);

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use axum::{
extract::{Form, Path, Query, State},
response::{Html, IntoResponse, Response},
@ -64,7 +65,9 @@ pub(crate) async fn get(
return Ok((cookie_jar, login.go()).into_response());
};
let user_email = lookup_user_email_by_id(&mut conn, &session.user, id).await?;
let user_email = lookup_user_email_by_id(&mut conn, &session.user, id)
.await?
.context("Could not find user email")?;
if user_email.confirmed_at.is_some() {
// This email was already verified, skip
@ -103,15 +106,18 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response());
};
let email = lookup_user_email_by_id(&mut txn, &session.user, id).await?;
let email = lookup_user_email_by_id(&mut txn, &session.user, id)
.await?
.context("Could not find user email")?;
if session.user.primary_email.is_none() {
set_user_email_as_primary(&mut txn, &email).await?;
}
// TODO: make those 8 hours configurable
let verification =
lookup_user_email_verification_code(&mut txn, &clock, email, &form.code).await?;
let verification = lookup_user_email_verification_code(&mut txn, &clock, email, &form.code)
.await?
.context("Invalid code")?;
// TODO: display nice errors if the code was already consumed or expired
let verification = consume_email_verification(&mut txn, &clock, verification).await?;

View File

@ -306,7 +306,9 @@ pub async fn compat_login(
let mut txn = conn.begin().await.context("could not start transaction")?;
// First, lookup the user
let user = lookup_user_by_username(&mut txn, username).await?;
let user = lookup_user_by_username(&mut txn, username)
.await?
.context("Could not lookup username")?;
tracing::Span::current().record("user.id", tracing::field::display(user.id));
// Now, fetch the hashed password from the user associated with that session

View File

@ -137,7 +137,9 @@ pub async fn get_paginated_user_oauth_sessions(
// ideal
let mut browser_sessions: HashMap<Ulid, BrowserSession> = HashMap::new();
for id in browser_session_ids {
let v = lookup_active_session(&mut *conn, id).await?;
let v = lookup_active_session(&mut *conn, id)
.await?
.context("Failed to load active session")?;
browser_sessions.insert(id, v);
}

View File

@ -14,7 +14,7 @@
use std::borrow::BorrowMut;
use anyhow::{bail, Context};
use anyhow::Context;
use argon2::Argon2;
use chrono::{DateTime, Utc};
use mas_data_model::{
@ -30,10 +30,9 @@ use tracing::{info_span, Instrument};
use ulid::Ulid;
use uuid::Uuid;
use super::DatabaseInconsistencyError;
use crate::{
pagination::{process_page, QueryBuilderExt},
Clock, GenericLookupError, LookupError,
Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt,
};
#[derive(Debug, Clone)]
@ -49,11 +48,7 @@ struct UserLookup {
#[derive(Debug, Error)]
pub enum LoginError {
#[error("could not find user {username:?}")]
NotFound {
username: String,
#[source]
source: UserLookupError,
},
NotFound { username: String },
#[error("authentication failed for {username:?}")]
Authentication {
@ -81,18 +76,16 @@ pub async fn login(
let mut txn = conn.begin().await.context("could not start transaction")?;
let user = lookup_user_by_username(&mut txn, username)
.await
.map_err(|source| {
if source.not_found() {
LoginError::NotFound {
username: username.to_owned(),
source,
}
} else {
LoginError::Other(source.into())
}
})?;
.context("Could not find user by username")?;
let Some(user) = user else {
return Err(LoginError::NotFound { username: username.to_owned() });
};
let mut session = start_session(&mut txn, &mut rng, clock, user)
.await
.context("Could not start session")?;
let mut session = start_session(&mut txn, &mut rng, clock, user).await?;
authenticate_session(&mut txn, &mut rng, clock, &mut session, password)
.await
.map_err(|source| {
@ -110,19 +103,6 @@ pub async fn login(
Ok(session)
}
#[derive(Debug, Error)]
#[error("could not fetch session")]
pub enum ActiveSessionLookupError {
Fetch(#[from] sqlx::Error),
Conversion(#[from] DatabaseInconsistencyError),
}
impl LookupError for ActiveSessionLookupError {
fn not_found(&self) -> bool {
matches!(self, Self::Fetch(sqlx::Error::RowNotFound))
}
}
#[derive(sqlx::FromRow)]
struct SessionLookup {
user_session_id: Uuid,
@ -138,9 +118,10 @@ struct SessionLookup {
}
impl TryInto<BrowserSession> for SessionLookup {
type Error = DatabaseInconsistencyError;
type Error = DatabaseInconsistencyError2;
fn try_into(self) -> Result<BrowserSession, Self::Error> {
let id = Ulid::from(self.user_id);
let primary_email = match (
self.user_email_id,
self.user_email,
@ -154,10 +135,13 @@ impl TryInto<BrowserSession> for SessionLookup {
confirmed_at,
}),
(None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError),
_ => {
return Err(DatabaseInconsistencyError2::on("users")
.column("primary_user_email_id")
.row(id))
}
};
let id = Ulid::from(self.user_id);
let user = User {
id,
username: self.username,
@ -171,7 +155,11 @@ impl TryInto<BrowserSession> for SessionLookup {
created_at,
}),
(None, None) => None,
_ => return Err(DatabaseInconsistencyError),
_ => {
return Err(DatabaseInconsistencyError2::on(
"user_session_authentications",
))
}
};
Ok(BrowserSession {
@ -191,7 +179,7 @@ impl TryInto<BrowserSession> for SessionLookup {
pub async fn lookup_active_session(
executor: impl PgExecutor<'_>,
id: Ulid,
) -> Result<BrowserSession, ActiveSessionLookupError> {
) -> Result<Option<BrowserSession>, DatabaseError> {
let res = sqlx::query_as!(
SessionLookup,
r#"
@ -220,10 +208,12 @@ pub async fn lookup_active_session(
Uuid::from(id),
)
.fetch_one(executor)
.await?
.try_into()?;
.await
.to_option()?;
Ok(res)
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
}
#[tracing::instrument(
@ -232,7 +222,7 @@ pub async fn lookup_active_session(
%user.id,
%user.username,
),
err(Display),
err,
)]
pub async fn get_paginated_user_sessions(
executor: impl PgExecutor<'_>,
@ -241,7 +231,7 @@ pub async fn get_paginated_user_sessions(
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<BrowserSession>), anyhow::Error> {
) -> Result<(bool, bool, Vec<BrowserSession>), DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT
@ -289,14 +279,14 @@ pub async fn get_paginated_user_sessions(
%user.id,
user_session.id,
),
err(Display),
err,
)]
pub async fn start_session(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: User,
) -> Result<BrowserSession, anyhow::Error> {
) -> Result<BrowserSession, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_session.id", tracing::field::display(id));
@ -311,8 +301,7 @@ pub async fn start_session(
created_at,
)
.execute(executor)
.await
.context("could not create session")?;
.await?;
let session = BrowserSession {
id,
@ -327,12 +316,12 @@ pub async fn start_session(
#[tracing::instrument(
skip_all,
fields(%user.id),
err(Display),
err,
)]
pub async fn count_active_sessions(
executor: impl PgExecutor<'_>,
user: &User,
) -> Result<usize, anyhow::Error> {
) -> Result<i64, DatabaseError> {
let res = sqlx::query_scalar!(
r#"
SELECT COUNT(*) as "count!"
@ -342,8 +331,7 @@ pub async fn count_active_sessions(
Uuid::from(user.id),
)
.fetch_one(executor)
.await?
.try_into()?;
.await?;
Ok(res)
}
@ -485,7 +473,7 @@ pub async fn authenticate_session_with_upstream(
user.username = username,
user.id,
),
err(Display),
err(Debug),
)]
pub async fn register_user(
txn: &mut Transaction<'_, Postgres>,
@ -569,7 +557,7 @@ pub async fn register_passwordless_user(
%user.id,
user_password.id,
),
err(Display),
err(Debug),
)]
pub async fn set_password(
executor: impl PgExecutor<'_>,
@ -607,13 +595,13 @@ pub async fn set_password(
#[tracing::instrument(
skip_all,
fields(%user_session.id),
err(Display),
err,
)]
pub async fn end_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
user_session: &BrowserSession,
) -> Result<(), anyhow::Error> {
) -> Result<(), DatabaseError> {
let now = clock.now();
let res = sqlx::query!(
r#"
@ -626,27 +614,9 @@ pub async fn end_session(
)
.execute(executor)
.instrument(info_span!("End session"))
.await
.context("could not end session")?;
.await?;
match res.rows_affected() {
1 => Ok(()),
0 => Err(anyhow::anyhow!("no row affected")),
_ => Err(anyhow::anyhow!("too many row affected")),
}
}
#[derive(Debug, Error)]
#[error("failed to lookup user")]
pub enum UserLookupError {
Database(#[from] sqlx::Error),
Inconsistency(#[from] DatabaseInconsistencyError),
}
impl LookupError for UserLookupError {
fn not_found(&self) -> bool {
matches!(self, Self::Database(sqlx::Error::RowNotFound))
}
DatabaseError::ensure_affected_rows(&res, 1)
}
#[tracing::instrument(
@ -657,7 +627,7 @@ impl LookupError for UserLookupError {
pub async fn lookup_user_by_username(
executor: impl PgExecutor<'_>,
username: &str,
) -> Result<User, UserLookupError> {
) -> Result<Option<User>, DatabaseError> {
let res = sqlx::query_as!(
UserLookup,
r#"
@ -679,8 +649,12 @@ pub async fn lookup_user_by_username(
)
.fetch_one(executor)
.instrument(info_span!("Fetch user"))
.await?;
.await
.to_option()?;
let Some(res) = res else { return Ok(None) };
let id = Ulid::from(res.user_id);
let primary_email = match (
res.user_email_id,
res.user_email,
@ -694,16 +668,20 @@ pub async fn lookup_user_by_username(
confirmed_at,
}),
(None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError.into()),
_ => {
return Err(DatabaseInconsistencyError2::on("users")
.column("primary_user_email_id")
.row(id)
.into())
}
};
let id = Ulid::from(res.user_id);
Ok(User {
Ok(Some(User {
id,
username: res.user_username,
sub: id.to_string(),
primary_email,
})
}))
}
#[tracing::instrument(
@ -711,7 +689,7 @@ pub async fn lookup_user_by_username(
fields(user.id = %id),
err,
)]
pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User, UserLookupError> {
pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User, DatabaseError> {
let res = sqlx::query_as!(
UserLookup,
r#"
@ -735,6 +713,7 @@ pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User
.instrument(info_span!("Fetch user"))
.await?;
let id = Ulid::from(res.user_id);
let primary_email = match (
res.user_email_id,
res.user_email,
@ -748,10 +727,14 @@ pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User
confirmed_at,
}),
(None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError.into()),
_ => {
return Err(DatabaseInconsistencyError2::on("users")
.column("primary_user_email_id")
.row(id)
.into())
}
};
let id = Ulid::from(res.user_id);
Ok(User {
id,
username: res.user_username,
@ -803,12 +786,12 @@ impl From<UserEmailLookup> for UserEmail {
#[tracing::instrument(
skip_all,
fields(%user.id, %user.username),
err(Display),
err,
)]
pub async fn get_user_emails(
executor: impl PgExecutor<'_>,
user: &User,
) -> Result<Vec<UserEmail>, anyhow::Error> {
) -> Result<Vec<UserEmail>, sqlx::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
@ -835,12 +818,12 @@ pub async fn get_user_emails(
#[tracing::instrument(
skip_all,
fields(%user.id, %user.username),
err(Display),
err,
)]
pub async fn count_user_emails(
executor: impl PgExecutor<'_>,
user: &User,
) -> Result<i64, anyhow::Error> {
) -> Result<i64, sqlx::Error> {
let res = sqlx::query_scalar!(
r#"
SELECT COUNT(*)
@ -859,7 +842,7 @@ pub async fn count_user_emails(
#[tracing::instrument(
skip_all,
fields(%user.id, %user.username),
err(Display),
err,
)]
pub async fn get_paginated_user_emails(
executor: impl PgExecutor<'_>,
@ -868,7 +851,7 @@ pub async fn get_paginated_user_emails(
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(bool, bool, Vec<UserEmail>), anyhow::Error> {
) -> Result<(bool, bool, Vec<UserEmail>), DatabaseError> {
let mut query = QueryBuilder::new(
r#"
SELECT
@ -908,13 +891,13 @@ pub async fn get_paginated_user_emails(
%user.username,
user_email.id = %id,
),
err(Display),
err,
)]
pub async fn get_user_email(
executor: impl PgExecutor<'_>,
user: &User,
id: Ulid,
) -> Result<UserEmail, anyhow::Error> {
) -> Result<UserEmail, sqlx::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
@ -946,7 +929,7 @@ pub async fn get_user_email(
user_email.id,
user_email.email = %email,
),
err(Display),
err,
)]
pub async fn add_user_email(
executor: impl PgExecutor<'_>,
@ -954,7 +937,7 @@ pub async fn add_user_email(
clock: &Clock,
user: &User,
email: String,
) -> Result<UserEmail, anyhow::Error> {
) -> Result<UserEmail, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_email.id", tracing::field::display(id));
@ -971,8 +954,7 @@ pub async fn add_user_email(
)
.execute(executor)
.instrument(info_span!("Add user email"))
.await
.context("could not insert user email")?;
.await?;
Ok(UserEmail {
id,
@ -993,7 +975,7 @@ pub async fn add_user_email(
pub async fn set_user_email_as_primary(
executor: impl PgExecutor<'_>,
user_email: &UserEmail,
) -> Result<(), anyhow::Error> {
) -> Result<(), sqlx::Error> {
sqlx::query!(
r#"
UPDATE users
@ -1006,8 +988,7 @@ pub async fn set_user_email_as_primary(
)
.execute(executor)
.instrument(info_span!("Add user email"))
.await
.context("could not set user email as primary")?;
.await?;
Ok(())
}
@ -1018,12 +999,12 @@ pub async fn set_user_email_as_primary(
%user_email.id,
%user_email.email,
),
err(Display),
err,
)]
pub async fn remove_user_email(
executor: impl PgExecutor<'_>,
user_email: UserEmail,
) -> Result<(), anyhow::Error> {
) -> Result<(), sqlx::Error> {
sqlx::query!(
r#"
DELETE FROM user_emails
@ -1033,8 +1014,7 @@ pub async fn remove_user_email(
)
.execute(executor)
.instrument(info_span!("Remove user email"))
.await
.context("could not remove user email")?;
.await?;
Ok(())
}
@ -1045,13 +1025,13 @@ pub async fn remove_user_email(
%user.id,
user_email.email = email,
),
err(Display),
err,
)]
pub async fn lookup_user_email(
executor: impl PgExecutor<'_>,
user: &User,
email: &str,
) -> Result<UserEmail, anyhow::Error> {
) -> Result<Option<UserEmail>, sqlx::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
@ -1071,9 +1051,11 @@ pub async fn lookup_user_email(
.fetch_one(executor)
.instrument(info_span!("Lookup user email"))
.await
.context("could not lookup user email")?;
.to_option()?;
Ok(res.into())
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
@ -1088,7 +1070,7 @@ pub async fn lookup_user_email_by_id(
executor: impl PgExecutor<'_>,
user: &User,
id: Ulid,
) -> Result<UserEmail, GenericLookupError> {
) -> Result<Option<UserEmail>, DatabaseError> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
@ -1108,21 +1090,23 @@ pub async fn lookup_user_email_by_id(
.fetch_one(executor)
.instrument(info_span!("Lookup user email"))
.await
.map_err(GenericLookupError::what("user email"))?;
.to_option()?;
Ok(res.into())
let Some(res) = res else { return Ok(None) };
Ok(Some(res.into()))
}
#[tracing::instrument(
skip_all,
fields(%user_email.id),
err(Display),
err,
)]
pub async fn mark_user_email_as_verified(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut user_email: UserEmail,
) -> Result<UserEmail, anyhow::Error> {
) -> Result<UserEmail, sqlx::Error> {
let confirmed_at = clock.now();
sqlx::query!(
r#"
@ -1135,8 +1119,7 @@ pub async fn mark_user_email_as_verified(
)
.execute(executor)
.instrument(info_span!("Confirm user email"))
.await
.context("could not update user email")?;
.await?;
user_email.confirmed_at = Some(confirmed_at);
@ -1154,14 +1137,14 @@ struct UserEmailConfirmationCodeLookup {
#[tracing::instrument(
skip_all,
fields(%user_email.id),
err(Display),
err,
)]
pub async fn lookup_user_email_verification_code(
executor: impl PgExecutor<'_>,
clock: &Clock,
user_email: UserEmail,
code: &str,
) -> Result<UserEmailVerification, anyhow::Error> {
) -> Result<Option<UserEmailVerification>, DatabaseError> {
let now = clock.now();
let res = sqlx::query_as!(
@ -1183,7 +1166,9 @@ pub async fn lookup_user_email_verification_code(
.fetch_one(executor)
.instrument(info_span!("Lookup user email verification"))
.await
.context("could not lookup user email verification")?;
.to_option()?;
let Some(res) = res else { return Ok(None) };
let state = if let Some(when) = res.consumed_at {
UserEmailVerificationState::AlreadyUsed { when }
@ -1195,13 +1180,13 @@ pub async fn lookup_user_email_verification_code(
UserEmailVerificationState::Valid
};
Ok(UserEmailVerification {
Ok(Some(UserEmailVerification {
id: res.user_email_confirmation_code_id.into(),
code: res.code,
email: user_email,
state,
created_at: res.created_at,
})
}))
}
#[tracing::instrument(
@ -1209,18 +1194,18 @@ pub async fn lookup_user_email_verification_code(
fields(
%user_email_verification.id,
),
err(Display),
err,
)]
pub async fn consume_email_verification(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut user_email_verification: UserEmailVerification,
) -> Result<UserEmailVerification, anyhow::Error> {
) -> Result<UserEmailVerification, DatabaseError> {
if !matches!(
user_email_verification.state,
UserEmailVerificationState::Valid
) {
bail!("user email verification in wrong state");
return Err(DatabaseError::InvalidOperation);
}
let consumed_at = clock.now();
@ -1236,8 +1221,7 @@ pub async fn consume_email_verification(
)
.execute(executor)
.instrument(info_span!("Consume user email verification"))
.await
.context("could not update user email verification")?;
.await?;
user_email_verification.state = UserEmailVerificationState::AlreadyUsed { when: consumed_at };
@ -1252,7 +1236,7 @@ pub async fn consume_email_verification(
user_email_confirmation.id,
user_email_confirmation.code = code,
),
err(Display),
err,
)]
pub async fn add_user_email_verification_code(
executor: impl PgExecutor<'_>,
@ -1261,7 +1245,7 @@ pub async fn add_user_email_verification_code(
user_email: UserEmail,
max_age: chrono::Duration,
code: String,
) -> Result<UserEmailVerification, anyhow::Error> {
) -> Result<UserEmailVerification, sqlx::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id));
@ -1281,8 +1265,7 @@ pub async fn add_user_email_verification_code(
)
.execute(executor)
.instrument(info_span!("Add user email verification code"))
.await
.context("could not insert user email verification code")?;
.await?;
let verification = UserEmailVerification {
id,
@ -1320,7 +1303,9 @@ mod tests {
let session = login(&mut txn, &mut rng, &clock, "john", "hunter2").await?;
assert_eq!(session.user.id, user.id);
let user2 = lookup_user_by_username(&mut txn, "john").await?;
let user2 = lookup_user_by_username(&mut txn, "john")
.await?
.context("Could not find user")?;
assert_eq!(user.id, user2.id);
txn.commit().await?;

View File

@ -535,20 +535,17 @@ where {
/// Context used by the `account/index.html` template
#[derive(Serialize)]
pub struct AccountContext {
active_sessions: usize,
active_sessions: i64,
emails: Vec<UserEmail>,
}
impl AccountContext {
/// Constructs a context for the "my account" page
#[must_use]
pub fn new<T>(active_sessions: usize, emails: Vec<T>) -> Self
where
T: Into<UserEmail>,
{
pub fn new(active_sessions: i64, emails: Vec<UserEmail>) -> Self {
Self {
active_sessions,
emails: emails.into_iter().map(Into::into).collect(),
emails,
}
}
}