You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
storage: unify user operations errors
This commit is contained in:
@ -14,7 +14,7 @@
|
||||
|
||||
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
|
||||
use mas_data_model::BrowserSession;
|
||||
use mas_storage::user::{lookup_active_session, ActiveSessionLookupError};
|
||||
use mas_storage::{user::lookup_active_session, DatabaseError};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Executor, Postgres};
|
||||
use ulid::Ulid;
|
||||
@ -47,7 +47,7 @@ impl SessionInfo {
|
||||
pub async fn load_session(
|
||||
&self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> Result<Option<BrowserSession>, ActiveSessionLookupError> {
|
||||
) -> Result<Option<BrowserSession>, DatabaseError> {
|
||||
let session_id = if let Some(id) = self.current {
|
||||
id
|
||||
} else {
|
||||
@ -55,7 +55,7 @@ impl SessionInfo {
|
||||
};
|
||||
|
||||
let res = lookup_active_session(executor, session_id).await?;
|
||||
Ok(Some(res))
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use argon2::Argon2;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use mas_config::{DatabaseConfig, RootConfig};
|
||||
@ -214,7 +215,9 @@ impl Options {
|
||||
let pool = config.connect().await?;
|
||||
let mut txn = pool.begin().await?;
|
||||
let hasher = Argon2::default();
|
||||
let user = lookup_user_by_username(&mut txn, username).await?;
|
||||
let user = lookup_user_by_username(&mut txn, username)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
set_password(&mut txn, &mut rng, &clock, hasher, &user, password).await?;
|
||||
info!(%user.id, %user.username, "Password changed");
|
||||
@ -228,8 +231,12 @@ impl Options {
|
||||
let pool = config.connect().await?;
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let user = lookup_user_by_username(&mut txn, username).await?;
|
||||
let email = lookup_user_email(&mut txn, &user, email).await?;
|
||||
let user = lookup_user_by_username(&mut txn, username)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
let email = lookup_user_email(&mut txn, &user, email)
|
||||
.await?
|
||||
.context("Email not found")?;
|
||||
let email = mark_user_email_as_verified(&mut txn, &clock, email).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
@ -136,9 +136,7 @@ impl RootQuery {
|
||||
let Some(session) = session else { return Ok(None) };
|
||||
let current_user = session.user;
|
||||
|
||||
let browser_session = mas_storage::user::lookup_active_session(&mut conn, id)
|
||||
.await
|
||||
.to_option()?;
|
||||
let browser_session = mas_storage::user::lookup_active_session(&mut conn, id).await?;
|
||||
|
||||
let ret = browser_session.and_then(|browser_session| {
|
||||
if browser_session.user.id == current_user.id {
|
||||
@ -166,9 +164,8 @@ impl RootQuery {
|
||||
let Some(session) = session else { return Ok(None) };
|
||||
let current_user = session.user;
|
||||
|
||||
let user_email = mas_storage::user::lookup_user_email_by_id(&mut conn, ¤t_user, id)
|
||||
.await
|
||||
.to_option()?;
|
||||
let user_email =
|
||||
mas_storage::user::lookup_user_email_by_id(&mut conn, ¤t_user, id).await?;
|
||||
|
||||
Ok(user_email.map(UserEmail))
|
||||
}
|
||||
|
@ -59,13 +59,16 @@ mod views;
|
||||
/// errors.
|
||||
#[macro_export]
|
||||
macro_rules! impl_from_error_for_route {
|
||||
($error:ty) => {
|
||||
impl From<$error> for self::RouteError {
|
||||
($route_error:ty : $error:ty) => {
|
||||
impl From<$error> for $route_error {
|
||||
fn from(e: $error) -> Self {
|
||||
Self::Internal(Box::new(e))
|
||||
}
|
||||
}
|
||||
};
|
||||
($error:ty) => {
|
||||
impl_from_error_for_route!(self::RouteError: $error);
|
||||
};
|
||||
}
|
||||
|
||||
pub use mas_axum_utils::http_client_factory::HttpClientFactory;
|
||||
|
@ -26,12 +26,9 @@ use mas_data_model::{AuthorizationGrant, BrowserSession};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::{PostAuthAction, Route};
|
||||
use mas_storage::{
|
||||
oauth2::{
|
||||
use mas_storage::oauth2::{
|
||||
authorization_grant::{derive_session, fulfill_grant, get_grant_by_id},
|
||||
consent::fetch_client_consent,
|
||||
},
|
||||
user::ActiveSessionLookupError,
|
||||
};
|
||||
use mas_templates::Templates;
|
||||
use oauth2_types::requests::{AccessTokenResponse, AuthorizationResponse};
|
||||
@ -39,9 +36,8 @@ use sqlx::{PgPool, Postgres, Transaction};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::callback::{
|
||||
CallbackDestination, CallbackDestinationError, IntoCallbackDestinationError,
|
||||
};
|
||||
use super::callback::CallbackDestination;
|
||||
use crate::impl_from_error_for_route;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RouteError {
|
||||
@ -49,7 +45,7 @@ pub enum RouteError {
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
|
||||
#[error(transparent)]
|
||||
Anyhow(anyhow::Error),
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
|
||||
#[error("authorization grant is not in a pending state")]
|
||||
NotPending,
|
||||
@ -74,35 +70,10 @@ impl IntoResponse for RouteError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for RouteError {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Self::Anyhow(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<sqlx::Error> for RouteError {
|
||||
fn from(e: sqlx::Error) -> Self {
|
||||
Self::Internal(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ActiveSessionLookupError> for RouteError {
|
||||
fn from(e: ActiveSessionLookupError) -> Self {
|
||||
Self::Internal(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<IntoCallbackDestinationError> for RouteError {
|
||||
fn from(e: IntoCallbackDestinationError) -> Self {
|
||||
Self::Internal(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CallbackDestinationError> for RouteError {
|
||||
fn from(e: CallbackDestinationError) -> Self {
|
||||
Self::Internal(Box::new(e))
|
||||
}
|
||||
}
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
|
||||
impl_from_error_for_route!(super::callback::CallbackDestinationError);
|
||||
|
||||
pub(crate) async fn get(
|
||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||
@ -171,17 +142,8 @@ pub enum GrantCompletionError {
|
||||
PolicyViolation,
|
||||
}
|
||||
|
||||
impl From<sqlx::Error> for GrantCompletionError {
|
||||
fn from(e: sqlx::Error) -> Self {
|
||||
Self::Internal(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<IntoCallbackDestinationError> for GrantCompletionError {
|
||||
fn from(e: IntoCallbackDestinationError) -> Self {
|
||||
Self::Internal(Box::new(e))
|
||||
}
|
||||
}
|
||||
impl_from_error_for_route!(GrantCompletionError: sqlx::Error);
|
||||
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
|
||||
|
||||
pub(crate) async fn complete(
|
||||
grant: AuthorizationGrant,
|
||||
|
@ -75,8 +75,6 @@ pub(crate) enum RouteError {
|
||||
impl_from_error_for_route!(sqlx::Error);
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(mas_storage::GenericLookupError);
|
||||
impl_from_error_for_route!(mas_storage::user::ActiveSessionLookupError);
|
||||
impl_from_error_for_route!(mas_storage::user::UserLookupError);
|
||||
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
|
||||
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||
impl_from_error_for_route!(mas_storage::DatabaseError);
|
||||
|
@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use axum::{
|
||||
extract::{Form, Path, Query, State},
|
||||
response::{Html, IntoResponse, Response},
|
||||
@ -64,7 +65,9 @@ pub(crate) async fn get(
|
||||
return Ok((cookie_jar, login.go()).into_response());
|
||||
};
|
||||
|
||||
let user_email = lookup_user_email_by_id(&mut conn, &session.user, id).await?;
|
||||
let user_email = lookup_user_email_by_id(&mut conn, &session.user, id)
|
||||
.await?
|
||||
.context("Could not find user email")?;
|
||||
|
||||
if user_email.confirmed_at.is_some() {
|
||||
// This email was already verified, skip
|
||||
@ -103,15 +106,18 @@ pub(crate) async fn post(
|
||||
return Ok((cookie_jar, login.go()).into_response());
|
||||
};
|
||||
|
||||
let email = lookup_user_email_by_id(&mut txn, &session.user, id).await?;
|
||||
let email = lookup_user_email_by_id(&mut txn, &session.user, id)
|
||||
.await?
|
||||
.context("Could not find user email")?;
|
||||
|
||||
if session.user.primary_email.is_none() {
|
||||
set_user_email_as_primary(&mut txn, &email).await?;
|
||||
}
|
||||
|
||||
// TODO: make those 8 hours configurable
|
||||
let verification =
|
||||
lookup_user_email_verification_code(&mut txn, &clock, email, &form.code).await?;
|
||||
let verification = lookup_user_email_verification_code(&mut txn, &clock, email, &form.code)
|
||||
.await?
|
||||
.context("Invalid code")?;
|
||||
|
||||
// TODO: display nice errors if the code was already consumed or expired
|
||||
let verification = consume_email_verification(&mut txn, &clock, verification).await?;
|
||||
|
@ -306,7 +306,9 @@ pub async fn compat_login(
|
||||
let mut txn = conn.begin().await.context("could not start transaction")?;
|
||||
|
||||
// First, lookup the user
|
||||
let user = lookup_user_by_username(&mut txn, username).await?;
|
||||
let user = lookup_user_by_username(&mut txn, username)
|
||||
.await?
|
||||
.context("Could not lookup username")?;
|
||||
tracing::Span::current().record("user.id", tracing::field::display(user.id));
|
||||
|
||||
// Now, fetch the hashed password from the user associated with that session
|
||||
|
@ -137,7 +137,9 @@ pub async fn get_paginated_user_oauth_sessions(
|
||||
// ideal
|
||||
let mut browser_sessions: HashMap<Ulid, BrowserSession> = HashMap::new();
|
||||
for id in browser_session_ids {
|
||||
let v = lookup_active_session(&mut *conn, id).await?;
|
||||
let v = lookup_active_session(&mut *conn, id)
|
||||
.await?
|
||||
.context("Failed to load active session")?;
|
||||
browser_sessions.insert(id, v);
|
||||
}
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
use std::borrow::BorrowMut;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use anyhow::Context;
|
||||
use argon2::Argon2;
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
@ -30,10 +30,9 @@ use tracing::{info_span, Instrument};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::DatabaseInconsistencyError;
|
||||
use crate::{
|
||||
pagination::{process_page, QueryBuilderExt},
|
||||
Clock, GenericLookupError, LookupError,
|
||||
Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -49,11 +48,7 @@ struct UserLookup {
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LoginError {
|
||||
#[error("could not find user {username:?}")]
|
||||
NotFound {
|
||||
username: String,
|
||||
#[source]
|
||||
source: UserLookupError,
|
||||
},
|
||||
NotFound { username: String },
|
||||
|
||||
#[error("authentication failed for {username:?}")]
|
||||
Authentication {
|
||||
@ -81,18 +76,16 @@ pub async fn login(
|
||||
let mut txn = conn.begin().await.context("could not start transaction")?;
|
||||
let user = lookup_user_by_username(&mut txn, username)
|
||||
.await
|
||||
.map_err(|source| {
|
||||
if source.not_found() {
|
||||
LoginError::NotFound {
|
||||
username: username.to_owned(),
|
||||
source,
|
||||
}
|
||||
} else {
|
||||
LoginError::Other(source.into())
|
||||
}
|
||||
})?;
|
||||
.context("Could not find user by username")?;
|
||||
|
||||
let Some(user) = user else {
|
||||
return Err(LoginError::NotFound { username: username.to_owned() });
|
||||
};
|
||||
|
||||
let mut session = start_session(&mut txn, &mut rng, clock, user)
|
||||
.await
|
||||
.context("Could not start session")?;
|
||||
|
||||
let mut session = start_session(&mut txn, &mut rng, clock, user).await?;
|
||||
authenticate_session(&mut txn, &mut rng, clock, &mut session, password)
|
||||
.await
|
||||
.map_err(|source| {
|
||||
@ -110,19 +103,6 @@ pub async fn login(
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("could not fetch session")]
|
||||
pub enum ActiveSessionLookupError {
|
||||
Fetch(#[from] sqlx::Error),
|
||||
Conversion(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl LookupError for ActiveSessionLookupError {
|
||||
fn not_found(&self) -> bool {
|
||||
matches!(self, Self::Fetch(sqlx::Error::RowNotFound))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct SessionLookup {
|
||||
user_session_id: Uuid,
|
||||
@ -138,9 +118,10 @@ struct SessionLookup {
|
||||
}
|
||||
|
||||
impl TryInto<BrowserSession> for SessionLookup {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
type Error = DatabaseInconsistencyError2;
|
||||
|
||||
fn try_into(self) -> Result<BrowserSession, Self::Error> {
|
||||
let id = Ulid::from(self.user_id);
|
||||
let primary_email = match (
|
||||
self.user_email_id,
|
||||
self.user_email,
|
||||
@ -154,10 +135,13 @@ impl TryInto<BrowserSession> for SessionLookup {
|
||||
confirmed_at,
|
||||
}),
|
||||
(None, None, None, None) => None,
|
||||
_ => return Err(DatabaseInconsistencyError),
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError2::on("users")
|
||||
.column("primary_user_email_id")
|
||||
.row(id))
|
||||
}
|
||||
};
|
||||
|
||||
let id = Ulid::from(self.user_id);
|
||||
let user = User {
|
||||
id,
|
||||
username: self.username,
|
||||
@ -171,7 +155,11 @@ impl TryInto<BrowserSession> for SessionLookup {
|
||||
created_at,
|
||||
}),
|
||||
(None, None) => None,
|
||||
_ => return Err(DatabaseInconsistencyError),
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError2::on(
|
||||
"user_session_authentications",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(BrowserSession {
|
||||
@ -191,7 +179,7 @@ impl TryInto<BrowserSession> for SessionLookup {
|
||||
pub async fn lookup_active_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: Ulid,
|
||||
) -> Result<BrowserSession, ActiveSessionLookupError> {
|
||||
) -> Result<Option<BrowserSession>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
SessionLookup,
|
||||
r#"
|
||||
@ -220,10 +208,12 @@ pub async fn lookup_active_session(
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await?
|
||||
.try_into()?;
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
Ok(res)
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -232,7 +222,7 @@ pub async fn lookup_active_session(
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn get_paginated_user_sessions(
|
||||
executor: impl PgExecutor<'_>,
|
||||
@ -241,7 +231,7 @@ pub async fn get_paginated_user_sessions(
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<(bool, bool, Vec<BrowserSession>), anyhow::Error> {
|
||||
) -> Result<(bool, bool, Vec<BrowserSession>), DatabaseError> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT
|
||||
@ -289,14 +279,14 @@ pub async fn get_paginated_user_sessions(
|
||||
%user.id,
|
||||
user_session.id,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn start_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
user: User,
|
||||
) -> Result<BrowserSession, anyhow::Error> {
|
||||
) -> Result<BrowserSession, sqlx::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("user_session.id", tracing::field::display(id));
|
||||
@ -311,8 +301,7 @@ pub async fn start_session(
|
||||
created_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await
|
||||
.context("could not create session")?;
|
||||
.await?;
|
||||
|
||||
let session = BrowserSession {
|
||||
id,
|
||||
@ -327,12 +316,12 @@ pub async fn start_session(
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(%user.id),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn count_active_sessions(
|
||||
executor: impl PgExecutor<'_>,
|
||||
user: &User,
|
||||
) -> Result<usize, anyhow::Error> {
|
||||
) -> Result<i64, DatabaseError> {
|
||||
let res = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT COUNT(*) as "count!"
|
||||
@ -342,8 +331,7 @@ pub async fn count_active_sessions(
|
||||
Uuid::from(user.id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await?
|
||||
.try_into()?;
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
@ -485,7 +473,7 @@ pub async fn authenticate_session_with_upstream(
|
||||
user.username = username,
|
||||
user.id,
|
||||
),
|
||||
err(Display),
|
||||
err(Debug),
|
||||
)]
|
||||
pub async fn register_user(
|
||||
txn: &mut Transaction<'_, Postgres>,
|
||||
@ -569,7 +557,7 @@ pub async fn register_passwordless_user(
|
||||
%user.id,
|
||||
user_password.id,
|
||||
),
|
||||
err(Display),
|
||||
err(Debug),
|
||||
)]
|
||||
pub async fn set_password(
|
||||
executor: impl PgExecutor<'_>,
|
||||
@ -607,13 +595,13 @@ pub async fn set_password(
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(%user_session.id),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn end_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
user_session: &BrowserSession,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<(), DatabaseError> {
|
||||
let now = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
@ -626,27 +614,9 @@ pub async fn end_session(
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(info_span!("End session"))
|
||||
.await
|
||||
.context("could not end session")?;
|
||||
.await?;
|
||||
|
||||
match res.rows_affected() {
|
||||
1 => Ok(()),
|
||||
0 => Err(anyhow::anyhow!("no row affected")),
|
||||
_ => Err(anyhow::anyhow!("too many row affected")),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("failed to lookup user")]
|
||||
pub enum UserLookupError {
|
||||
Database(#[from] sqlx::Error),
|
||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl LookupError for UserLookupError {
|
||||
fn not_found(&self) -> bool {
|
||||
matches!(self, Self::Database(sqlx::Error::RowNotFound))
|
||||
}
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -657,7 +627,7 @@ impl LookupError for UserLookupError {
|
||||
pub async fn lookup_user_by_username(
|
||||
executor: impl PgExecutor<'_>,
|
||||
username: &str,
|
||||
) -> Result<User, UserLookupError> {
|
||||
) -> Result<Option<User>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
UserLookup,
|
||||
r#"
|
||||
@ -679,8 +649,12 @@ pub async fn lookup_user_by_username(
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Fetch user"))
|
||||
.await?;
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
let id = Ulid::from(res.user_id);
|
||||
let primary_email = match (
|
||||
res.user_email_id,
|
||||
res.user_email,
|
||||
@ -694,16 +668,20 @@ pub async fn lookup_user_by_username(
|
||||
confirmed_at,
|
||||
}),
|
||||
(None, None, None, None) => None,
|
||||
_ => return Err(DatabaseInconsistencyError.into()),
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError2::on("users")
|
||||
.column("primary_user_email_id")
|
||||
.row(id)
|
||||
.into())
|
||||
}
|
||||
};
|
||||
|
||||
let id = Ulid::from(res.user_id);
|
||||
Ok(User {
|
||||
Ok(Some(User {
|
||||
id,
|
||||
username: res.user_username,
|
||||
sub: id.to_string(),
|
||||
primary_email,
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -711,7 +689,7 @@ pub async fn lookup_user_by_username(
|
||||
fields(user.id = %id),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User, UserLookupError> {
|
||||
pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
UserLookup,
|
||||
r#"
|
||||
@ -735,6 +713,7 @@ pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User
|
||||
.instrument(info_span!("Fetch user"))
|
||||
.await?;
|
||||
|
||||
let id = Ulid::from(res.user_id);
|
||||
let primary_email = match (
|
||||
res.user_email_id,
|
||||
res.user_email,
|
||||
@ -748,10 +727,14 @@ pub async fn lookup_user(executor: impl PgExecutor<'_>, id: Ulid) -> Result<User
|
||||
confirmed_at,
|
||||
}),
|
||||
(None, None, None, None) => None,
|
||||
_ => return Err(DatabaseInconsistencyError.into()),
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError2::on("users")
|
||||
.column("primary_user_email_id")
|
||||
.row(id)
|
||||
.into())
|
||||
}
|
||||
};
|
||||
|
||||
let id = Ulid::from(res.user_id);
|
||||
Ok(User {
|
||||
id,
|
||||
username: res.user_username,
|
||||
@ -803,12 +786,12 @@ impl From<UserEmailLookup> for UserEmail {
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(%user.id, %user.username),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn get_user_emails(
|
||||
executor: impl PgExecutor<'_>,
|
||||
user: &User,
|
||||
) -> Result<Vec<UserEmail>, anyhow::Error> {
|
||||
) -> Result<Vec<UserEmail>, sqlx::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
UserEmailLookup,
|
||||
r#"
|
||||
@ -835,12 +818,12 @@ pub async fn get_user_emails(
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(%user.id, %user.username),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn count_user_emails(
|
||||
executor: impl PgExecutor<'_>,
|
||||
user: &User,
|
||||
) -> Result<i64, anyhow::Error> {
|
||||
) -> Result<i64, sqlx::Error> {
|
||||
let res = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
@ -859,7 +842,7 @@ pub async fn count_user_emails(
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(%user.id, %user.username),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn get_paginated_user_emails(
|
||||
executor: impl PgExecutor<'_>,
|
||||
@ -868,7 +851,7 @@ pub async fn get_paginated_user_emails(
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<(bool, bool, Vec<UserEmail>), anyhow::Error> {
|
||||
) -> Result<(bool, bool, Vec<UserEmail>), DatabaseError> {
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT
|
||||
@ -908,13 +891,13 @@ pub async fn get_paginated_user_emails(
|
||||
%user.username,
|
||||
user_email.id = %id,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn get_user_email(
|
||||
executor: impl PgExecutor<'_>,
|
||||
user: &User,
|
||||
id: Ulid,
|
||||
) -> Result<UserEmail, anyhow::Error> {
|
||||
) -> Result<UserEmail, sqlx::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
UserEmailLookup,
|
||||
r#"
|
||||
@ -946,7 +929,7 @@ pub async fn get_user_email(
|
||||
user_email.id,
|
||||
user_email.email = %email,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn add_user_email(
|
||||
executor: impl PgExecutor<'_>,
|
||||
@ -954,7 +937,7 @@ pub async fn add_user_email(
|
||||
clock: &Clock,
|
||||
user: &User,
|
||||
email: String,
|
||||
) -> Result<UserEmail, anyhow::Error> {
|
||||
) -> Result<UserEmail, sqlx::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("user_email.id", tracing::field::display(id));
|
||||
@ -971,8 +954,7 @@ pub async fn add_user_email(
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(info_span!("Add user email"))
|
||||
.await
|
||||
.context("could not insert user email")?;
|
||||
.await?;
|
||||
|
||||
Ok(UserEmail {
|
||||
id,
|
||||
@ -993,7 +975,7 @@ pub async fn add_user_email(
|
||||
pub async fn set_user_email_as_primary(
|
||||
executor: impl PgExecutor<'_>,
|
||||
user_email: &UserEmail,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<(), sqlx::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE users
|
||||
@ -1006,8 +988,7 @@ pub async fn set_user_email_as_primary(
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(info_span!("Add user email"))
|
||||
.await
|
||||
.context("could not set user email as primary")?;
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -1018,12 +999,12 @@ pub async fn set_user_email_as_primary(
|
||||
%user_email.id,
|
||||
%user_email.email,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn remove_user_email(
|
||||
executor: impl PgExecutor<'_>,
|
||||
user_email: UserEmail,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<(), sqlx::Error> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM user_emails
|
||||
@ -1033,8 +1014,7 @@ pub async fn remove_user_email(
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(info_span!("Remove user email"))
|
||||
.await
|
||||
.context("could not remove user email")?;
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -1045,13 +1025,13 @@ pub async fn remove_user_email(
|
||||
%user.id,
|
||||
user_email.email = email,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_user_email(
|
||||
executor: impl PgExecutor<'_>,
|
||||
user: &User,
|
||||
email: &str,
|
||||
) -> Result<UserEmail, anyhow::Error> {
|
||||
) -> Result<Option<UserEmail>, sqlx::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
UserEmailLookup,
|
||||
r#"
|
||||
@ -1071,9 +1051,11 @@ pub async fn lookup_user_email(
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Lookup user email"))
|
||||
.await
|
||||
.context("could not lookup user email")?;
|
||||
.to_option()?;
|
||||
|
||||
Ok(res.into())
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -1088,7 +1070,7 @@ pub async fn lookup_user_email_by_id(
|
||||
executor: impl PgExecutor<'_>,
|
||||
user: &User,
|
||||
id: Ulid,
|
||||
) -> Result<UserEmail, GenericLookupError> {
|
||||
) -> Result<Option<UserEmail>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
UserEmailLookup,
|
||||
r#"
|
||||
@ -1108,21 +1090,23 @@ pub async fn lookup_user_email_by_id(
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Lookup user email"))
|
||||
.await
|
||||
.map_err(GenericLookupError::what("user email"))?;
|
||||
.to_option()?;
|
||||
|
||||
Ok(res.into())
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.into()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(%user_email.id),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn mark_user_email_as_verified(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
mut user_email: UserEmail,
|
||||
) -> Result<UserEmail, anyhow::Error> {
|
||||
) -> Result<UserEmail, sqlx::Error> {
|
||||
let confirmed_at = clock.now();
|
||||
sqlx::query!(
|
||||
r#"
|
||||
@ -1135,8 +1119,7 @@ pub async fn mark_user_email_as_verified(
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(info_span!("Confirm user email"))
|
||||
.await
|
||||
.context("could not update user email")?;
|
||||
.await?;
|
||||
|
||||
user_email.confirmed_at = Some(confirmed_at);
|
||||
|
||||
@ -1154,14 +1137,14 @@ struct UserEmailConfirmationCodeLookup {
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(%user_email.id),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_user_email_verification_code(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
user_email: UserEmail,
|
||||
code: &str,
|
||||
) -> Result<UserEmailVerification, anyhow::Error> {
|
||||
) -> Result<Option<UserEmailVerification>, DatabaseError> {
|
||||
let now = clock.now();
|
||||
|
||||
let res = sqlx::query_as!(
|
||||
@ -1183,7 +1166,9 @@ pub async fn lookup_user_email_verification_code(
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Lookup user email verification"))
|
||||
.await
|
||||
.context("could not lookup user email verification")?;
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
let state = if let Some(when) = res.consumed_at {
|
||||
UserEmailVerificationState::AlreadyUsed { when }
|
||||
@ -1195,13 +1180,13 @@ pub async fn lookup_user_email_verification_code(
|
||||
UserEmailVerificationState::Valid
|
||||
};
|
||||
|
||||
Ok(UserEmailVerification {
|
||||
Ok(Some(UserEmailVerification {
|
||||
id: res.user_email_confirmation_code_id.into(),
|
||||
code: res.code,
|
||||
email: user_email,
|
||||
state,
|
||||
created_at: res.created_at,
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -1209,18 +1194,18 @@ pub async fn lookup_user_email_verification_code(
|
||||
fields(
|
||||
%user_email_verification.id,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn consume_email_verification(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
mut user_email_verification: UserEmailVerification,
|
||||
) -> Result<UserEmailVerification, anyhow::Error> {
|
||||
) -> Result<UserEmailVerification, DatabaseError> {
|
||||
if !matches!(
|
||||
user_email_verification.state,
|
||||
UserEmailVerificationState::Valid
|
||||
) {
|
||||
bail!("user email verification in wrong state");
|
||||
return Err(DatabaseError::InvalidOperation);
|
||||
}
|
||||
|
||||
let consumed_at = clock.now();
|
||||
@ -1236,8 +1221,7 @@ pub async fn consume_email_verification(
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(info_span!("Consume user email verification"))
|
||||
.await
|
||||
.context("could not update user email verification")?;
|
||||
.await?;
|
||||
|
||||
user_email_verification.state = UserEmailVerificationState::AlreadyUsed { when: consumed_at };
|
||||
|
||||
@ -1252,7 +1236,7 @@ pub async fn consume_email_verification(
|
||||
user_email_confirmation.id,
|
||||
user_email_confirmation.code = code,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn add_user_email_verification_code(
|
||||
executor: impl PgExecutor<'_>,
|
||||
@ -1261,7 +1245,7 @@ pub async fn add_user_email_verification_code(
|
||||
user_email: UserEmail,
|
||||
max_age: chrono::Duration,
|
||||
code: String,
|
||||
) -> Result<UserEmailVerification, anyhow::Error> {
|
||||
) -> Result<UserEmailVerification, sqlx::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id));
|
||||
@ -1281,8 +1265,7 @@ pub async fn add_user_email_verification_code(
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(info_span!("Add user email verification code"))
|
||||
.await
|
||||
.context("could not insert user email verification code")?;
|
||||
.await?;
|
||||
|
||||
let verification = UserEmailVerification {
|
||||
id,
|
||||
@ -1320,7 +1303,9 @@ mod tests {
|
||||
let session = login(&mut txn, &mut rng, &clock, "john", "hunter2").await?;
|
||||
assert_eq!(session.user.id, user.id);
|
||||
|
||||
let user2 = lookup_user_by_username(&mut txn, "john").await?;
|
||||
let user2 = lookup_user_by_username(&mut txn, "john")
|
||||
.await?
|
||||
.context("Could not find user")?;
|
||||
assert_eq!(user.id, user2.id);
|
||||
|
||||
txn.commit().await?;
|
||||
|
@ -535,20 +535,17 @@ where {
|
||||
/// Context used by the `account/index.html` template
|
||||
#[derive(Serialize)]
|
||||
pub struct AccountContext {
|
||||
active_sessions: usize,
|
||||
active_sessions: i64,
|
||||
emails: Vec<UserEmail>,
|
||||
}
|
||||
|
||||
impl AccountContext {
|
||||
/// Constructs a context for the "my account" page
|
||||
#[must_use]
|
||||
pub fn new<T>(active_sessions: usize, emails: Vec<T>) -> Self
|
||||
where
|
||||
T: Into<UserEmail>,
|
||||
{
|
||||
pub fn new(active_sessions: i64, emails: Vec<UserEmail>) -> Self {
|
||||
Self {
|
||||
active_sessions,
|
||||
emails: emails.into_iter().map(Into::into).collect(),
|
||||
emails,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user