You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
storage: unify the compat login errors
This commit is contained in:
@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use anyhow::Context;
|
||||
use argon2::{Argon2, PasswordHash};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{
|
||||
@ -21,7 +21,6 @@ use mas_data_model::{
|
||||
};
|
||||
use rand::Rng;
|
||||
use sqlx::{Acquire, PgExecutor, Postgres, QueryBuilder};
|
||||
use thiserror::Error;
|
||||
use tokio::task;
|
||||
use tracing::{info_span, Instrument};
|
||||
use ulid::Ulid;
|
||||
@ -31,7 +30,7 @@ use uuid::Uuid;
|
||||
use crate::{
|
||||
pagination::{process_page, QueryBuilderExt},
|
||||
user::lookup_user_by_username,
|
||||
Clock, DatabaseInconsistencyError, LookupError,
|
||||
Clock, DatabaseError, DatabaseInconsistencyError2, LookupResultExt,
|
||||
};
|
||||
|
||||
struct CompatAccessTokenLookup {
|
||||
@ -51,29 +50,12 @@ struct CompatAccessTokenLookup {
|
||||
user_email_confirmed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("failed to lookup compat access token")]
|
||||
pub enum CompatAccessTokenLookupError {
|
||||
Expired { when: DateTime<Utc> },
|
||||
Database(#[from] sqlx::Error),
|
||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl LookupError for CompatAccessTokenLookupError {
|
||||
fn not_found(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::Database(sqlx::Error::RowNotFound) | Self::Expired { .. }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub async fn lookup_active_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
token: &str,
|
||||
) -> Result<(CompatAccessToken, CompatSession), CompatAccessTokenLookupError> {
|
||||
) -> Result<Option<(CompatAccessToken, CompatSession)>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatAccessTokenLookup,
|
||||
r#"
|
||||
@ -101,20 +83,19 @@ pub async fn lookup_active_compat_access_token(
|
||||
LEFT JOIN user_emails ue
|
||||
ON ue.user_email_id = u.primary_user_email_id
|
||||
|
||||
WHERE ct.access_token = $1 AND cs.finished_at IS NULL
|
||||
WHERE ct.access_token = $1
|
||||
AND ct.expires_at < $2
|
||||
AND cs.finished_at IS NULL
|
||||
"#,
|
||||
token,
|
||||
clock.now(),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Fetch compat access token"))
|
||||
.await?;
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
// Check for token expiration
|
||||
if let Some(expires_at) = res.compat_access_token_expires_at {
|
||||
if expires_at < clock.now() {
|
||||
return Err(CompatAccessTokenLookupError::Expired { when: expires_at });
|
||||
}
|
||||
}
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
let token = CompatAccessToken {
|
||||
id: res.compat_access_token_id.into(),
|
||||
@ -123,6 +104,7 @@ pub async fn lookup_active_compat_access_token(
|
||||
expires_at: res.compat_access_token_expires_at,
|
||||
};
|
||||
|
||||
let user_id = Ulid::from(res.user_id);
|
||||
let primary_email = match (
|
||||
res.user_email_id,
|
||||
res.user_email,
|
||||
@ -136,28 +118,38 @@ pub async fn lookup_active_compat_access_token(
|
||||
confirmed_at,
|
||||
}),
|
||||
(None, None, None, None) => None,
|
||||
_ => return Err(DatabaseInconsistencyError.into()),
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError2::on("compat_sessions")
|
||||
.column("user_id")
|
||||
.row(user_id)
|
||||
.into())
|
||||
}
|
||||
};
|
||||
|
||||
let id = Ulid::from(res.user_id);
|
||||
let user = User {
|
||||
id,
|
||||
id: user_id,
|
||||
username: res.user_username,
|
||||
sub: id.to_string(),
|
||||
sub: user_id.to_string(),
|
||||
primary_email,
|
||||
};
|
||||
|
||||
let device = Device::try_from(res.compat_session_device_id).unwrap();
|
||||
let id = res.compat_session_id.into();
|
||||
let device = Device::try_from(res.compat_session_device_id).map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("compat_sessions")
|
||||
.column("device_id")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let session = CompatSession {
|
||||
id: res.compat_session_id.into(),
|
||||
id,
|
||||
user,
|
||||
device,
|
||||
created_at: res.compat_session_created_at,
|
||||
finished_at: res.compat_session_finished_at,
|
||||
};
|
||||
|
||||
Ok((token, session))
|
||||
Ok(Some((token, session)))
|
||||
}
|
||||
|
||||
pub struct CompatRefreshTokenLookup {
|
||||
@ -180,25 +172,12 @@ pub struct CompatRefreshTokenLookup {
|
||||
user_email_confirmed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("failed to lookup compat refresh token")]
|
||||
pub enum CompatRefreshTokenLookupError {
|
||||
Database(#[from] sqlx::Error),
|
||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl LookupError for CompatRefreshTokenLookupError {
|
||||
fn not_found(&self) -> bool {
|
||||
matches!(self, Self::Database(sqlx::Error::RowNotFound))
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub async fn lookup_active_compat_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
token: &str,
|
||||
) -> Result<(CompatRefreshToken, CompatAccessToken, CompatSession), CompatRefreshTokenLookupError> {
|
||||
) -> Result<Option<(CompatRefreshToken, CompatAccessToken, CompatSession)>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatRefreshTokenLookup,
|
||||
r#"
|
||||
@ -239,7 +218,10 @@ pub async fn lookup_active_compat_refresh_token(
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Fetch compat refresh token"))
|
||||
.await?;
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
let Some(res) = res else { return Ok(None); };
|
||||
|
||||
let refresh_token = CompatRefreshToken {
|
||||
id: res.compat_refresh_token_id.into(),
|
||||
@ -254,6 +236,7 @@ pub async fn lookup_active_compat_refresh_token(
|
||||
expires_at: res.compat_access_token_expires_at,
|
||||
};
|
||||
|
||||
let user_id = Ulid::from(res.user_id);
|
||||
let primary_email = match (
|
||||
res.user_email_id,
|
||||
res.user_email,
|
||||
@ -267,28 +250,38 @@ pub async fn lookup_active_compat_refresh_token(
|
||||
confirmed_at,
|
||||
}),
|
||||
(None, None, None, None) => None,
|
||||
_ => return Err(DatabaseInconsistencyError.into()),
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError2::on("users")
|
||||
.column("primary_user_email_id")
|
||||
.row(user_id)
|
||||
.into())
|
||||
}
|
||||
};
|
||||
|
||||
let id = Ulid::from(res.user_id);
|
||||
let user = User {
|
||||
id,
|
||||
id: user_id,
|
||||
username: res.user_username,
|
||||
sub: id.to_string(),
|
||||
sub: user_id.to_string(),
|
||||
primary_email,
|
||||
};
|
||||
|
||||
let device = Device::try_from(res.compat_session_device_id).unwrap();
|
||||
let session_id = res.compat_session_id.into();
|
||||
let device = Device::try_from(res.compat_session_device_id).map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("compat_sessions")
|
||||
.column("device_id")
|
||||
.row(session_id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let session = CompatSession {
|
||||
id: res.compat_session_id.into(),
|
||||
id: session_id,
|
||||
user,
|
||||
device,
|
||||
created_at: res.compat_session_created_at,
|
||||
finished_at: res.compat_session_finished_at,
|
||||
};
|
||||
|
||||
Ok((refresh_token, access_token, session))
|
||||
Ok(Some((refresh_token, access_token, session)))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -299,7 +292,7 @@ pub async fn lookup_active_compat_refresh_token(
|
||||
compat_session.id,
|
||||
compat_session.device.id = device.as_str(),
|
||||
),
|
||||
err(Display),
|
||||
err(Debug),
|
||||
)]
|
||||
pub async fn compat_login(
|
||||
conn: impl Acquire<'_, Database = Postgres> + Send,
|
||||
@ -309,6 +302,7 @@ pub async fn compat_login(
|
||||
password: &str,
|
||||
device: Device,
|
||||
) -> Result<CompatSession, anyhow::Error> {
|
||||
// TODO: that should be split and not verify the password hash here
|
||||
let mut txn = conn.begin().await.context("could not start transaction")?;
|
||||
|
||||
// First, lookup the user
|
||||
@ -381,7 +375,7 @@ pub async fn compat_login(
|
||||
compat_access_token.id,
|
||||
user.id = %session.user.id,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn add_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
@ -390,7 +384,7 @@ pub async fn add_compat_access_token(
|
||||
session: &CompatSession,
|
||||
token: String,
|
||||
expires_after: Option<Duration>,
|
||||
) -> Result<CompatAccessToken, anyhow::Error> {
|
||||
) -> Result<CompatAccessToken, sqlx::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
|
||||
@ -411,8 +405,7 @@ pub async fn add_compat_access_token(
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(tracing::info_span!("Insert compat access token"))
|
||||
.await
|
||||
.context("could not insert compat access token")?;
|
||||
.await?;
|
||||
|
||||
Ok(CompatAccessToken {
|
||||
id,
|
||||
@ -427,13 +420,13 @@ pub async fn add_compat_access_token(
|
||||
fields(
|
||||
compat_access_token.id = %access_token.id,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn expire_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
access_token: CompatAccessToken,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<(), DatabaseError> {
|
||||
let expires_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
@ -445,16 +438,9 @@ pub async fn expire_compat_access_token(
|
||||
expires_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await
|
||||
.context("failed to update compat access token")?;
|
||||
.await?;
|
||||
|
||||
if res.rows_affected() == 1 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!(
|
||||
"no row were affected when updating access token"
|
||||
))
|
||||
}
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -466,7 +452,7 @@ pub async fn expire_compat_access_token(
|
||||
compat_refresh_token.id,
|
||||
user.id = %session.user.id,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn add_compat_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
@ -475,7 +461,7 @@ pub async fn add_compat_refresh_token(
|
||||
session: &CompatSession,
|
||||
access_token: &CompatAccessToken,
|
||||
token: String,
|
||||
) -> Result<CompatRefreshToken, anyhow::Error> {
|
||||
) -> Result<CompatRefreshToken, sqlx::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
|
||||
@ -495,8 +481,7 @@ pub async fn add_compat_refresh_token(
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(tracing::info_span!("Insert compat refresh token"))
|
||||
.await
|
||||
.context("could not insert compat refresh token")?;
|
||||
.await?;
|
||||
|
||||
Ok(CompatRefreshToken {
|
||||
id,
|
||||
@ -508,13 +493,13 @@ pub async fn add_compat_refresh_token(
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(compat_session.id),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn compat_logout(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
token: &str,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<(), sqlx::Error> {
|
||||
let finished_at = clock.now();
|
||||
// TODO: this does not check for token expiration
|
||||
let compat_session_id = sqlx::query_scalar!(
|
||||
@ -531,8 +516,7 @@ pub async fn compat_logout(
|
||||
finished_at,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not update compat access token")?;
|
||||
.await?;
|
||||
|
||||
tracing::Span::current().record(
|
||||
"compat_session.id",
|
||||
@ -547,13 +531,13 @@ pub async fn compat_logout(
|
||||
fields(
|
||||
compat_refresh_token.id = %refresh_token.id,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn consume_compat_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
refresh_token: CompatRefreshToken,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<(), DatabaseError> {
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
@ -565,16 +549,9 @@ pub async fn consume_compat_refresh_token(
|
||||
consumed_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.await
|
||||
.context("failed to update compat refresh token")?;
|
||||
.await?;
|
||||
|
||||
if res.rows_affected() == 1 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!(
|
||||
"no row were affected when updating refresh token"
|
||||
))
|
||||
}
|
||||
DatabaseError::ensure_affected_rows(&res, 1)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -583,7 +560,7 @@ pub async fn consume_compat_refresh_token(
|
||||
compat_sso_login.id,
|
||||
compat_sso_login.redirect_uri = %redirect_uri,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn insert_compat_sso_login(
|
||||
executor: impl PgExecutor<'_>,
|
||||
@ -591,7 +568,7 @@ pub async fn insert_compat_sso_login(
|
||||
clock: &Clock,
|
||||
login_token: String,
|
||||
redirect_uri: Url,
|
||||
) -> Result<CompatSsoLogin, anyhow::Error> {
|
||||
) -> Result<CompatSsoLogin, sqlx::Error> {
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
|
||||
@ -609,8 +586,7 @@ pub async fn insert_compat_sso_login(
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(tracing::info_span!("Insert compat SSO login"))
|
||||
.await
|
||||
.context("could not insert compat SSO login")?;
|
||||
.await?;
|
||||
|
||||
Ok(CompatSsoLogin {
|
||||
id,
|
||||
@ -642,11 +618,16 @@ struct CompatSsoLoginLookup {
|
||||
}
|
||||
|
||||
impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
type Error = DatabaseInconsistencyError2;
|
||||
|
||||
fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
|
||||
let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri)
|
||||
.map_err(|_| DatabaseInconsistencyError)?;
|
||||
let id = res.compat_sso_login_id.into();
|
||||
let redirect_uri = Url::parse(&res.compat_sso_login_redirect_uri).map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("compat_sso_logins")
|
||||
.column("redirect_uri")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let primary_email = match (
|
||||
res.user_email_id,
|
||||
@ -661,7 +642,9 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
confirmed_at,
|
||||
}),
|
||||
(None, None, None, None) => None,
|
||||
_ => return Err(DatabaseInconsistencyError),
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError2::on("users").column("primary_user_email_id"))
|
||||
}
|
||||
};
|
||||
|
||||
let user = match (res.user_id, res.user_username, primary_email) {
|
||||
@ -676,7 +659,7 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
}
|
||||
|
||||
(None, None, None) => None,
|
||||
_ => return Err(DatabaseInconsistencyError),
|
||||
_ => return Err(DatabaseInconsistencyError2::on("compat_sessions").column("user_id")),
|
||||
};
|
||||
|
||||
let session = match (
|
||||
@ -687,9 +670,15 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
user,
|
||||
) {
|
||||
(Some(id), Some(device_id), Some(created_at), finished_at, Some(user)) => {
|
||||
let device = Device::try_from(device_id).map_err(|_| DatabaseInconsistencyError)?;
|
||||
let id = id.into();
|
||||
let device = Device::try_from(device_id).map_err(|e| {
|
||||
DatabaseInconsistencyError2::on("compat_sessions")
|
||||
.column("device")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
Some(CompatSession {
|
||||
id: id.into(),
|
||||
id,
|
||||
user,
|
||||
device,
|
||||
created_at,
|
||||
@ -697,7 +686,11 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
})
|
||||
}
|
||||
(None, None, None, None, None) => None,
|
||||
_ => return Err(DatabaseInconsistencyError),
|
||||
_ => {
|
||||
return Err(DatabaseInconsistencyError2::on("compat_sso_logins")
|
||||
.column("compat_session_id")
|
||||
.row(id))
|
||||
}
|
||||
};
|
||||
|
||||
let state = match (
|
||||
@ -717,11 +710,11 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
session,
|
||||
}
|
||||
}
|
||||
_ => return Err(DatabaseInconsistencyError),
|
||||
_ => return Err(DatabaseInconsistencyError2::on("compat_sso_logins").row(id)),
|
||||
};
|
||||
|
||||
Ok(CompatSsoLogin {
|
||||
id: res.compat_sso_login_id.into(),
|
||||
id,
|
||||
login_token: res.compat_sso_login_token,
|
||||
redirect_uri,
|
||||
created_at: res.compat_sso_login_created_at,
|
||||
@ -730,19 +723,6 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("failed to lookup compat SSO login")]
|
||||
pub enum CompatSsoLoginLookupError {
|
||||
Database(#[from] sqlx::Error),
|
||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
impl LookupError for CompatSsoLoginLookupError {
|
||||
fn not_found(&self) -> bool {
|
||||
matches!(self, Self::Database(sqlx::Error::RowNotFound))
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
@ -753,7 +733,7 @@ impl LookupError for CompatSsoLoginLookupError {
|
||||
pub async fn get_compat_sso_login_by_id(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: Ulid,
|
||||
) -> Result<CompatSsoLogin, CompatSsoLoginLookupError> {
|
||||
) -> Result<Option<CompatSsoLogin>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSsoLoginLookup,
|
||||
r#"
|
||||
@ -787,9 +767,12 @@ pub async fn get_compat_sso_login_by_id(
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(tracing::info_span!("Lookup compat SSO login"))
|
||||
.await?;
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
Ok(res.try_into()?)
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
@ -798,7 +781,7 @@ pub async fn get_compat_sso_login_by_id(
|
||||
%user.id,
|
||||
%user.username,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn get_paginated_user_compat_sso_logins(
|
||||
executor: impl PgExecutor<'_>,
|
||||
@ -807,7 +790,7 @@ pub async fn get_paginated_user_compat_sso_logins(
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<(bool, bool, Vec<CompatSsoLogin>), anyhow::Error> {
|
||||
) -> Result<(bool, bool, Vec<CompatSsoLogin>), DatabaseError> {
|
||||
// TODO: this queries too much (like user info) which we probably don't need
|
||||
// because we already have them
|
||||
let mut query = QueryBuilder::new(
|
||||
@ -864,7 +847,7 @@ pub async fn get_paginated_user_compat_sso_logins(
|
||||
pub async fn get_compat_sso_login_by_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
token: &str,
|
||||
) -> Result<CompatSsoLogin, CompatSsoLoginLookupError> {
|
||||
) -> Result<Option<CompatSsoLogin>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
CompatSsoLoginLookup,
|
||||
r#"
|
||||
@ -898,35 +881,38 @@ pub async fn get_compat_sso_login_by_token(
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(tracing::info_span!("Lookup compat SSO login"))
|
||||
.await?;
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
Ok(res.try_into()?)
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
%user.id,
|
||||
compat_sso_login.id = %login.id,
|
||||
compat_sso_login.redirect_uri = %login.redirect_uri,
|
||||
%compat_sso_login.id,
|
||||
%compat_sso_login.redirect_uri,
|
||||
compat_session.id,
|
||||
compat_session.device.id = device.as_str(),
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn fullfill_compat_sso_login(
|
||||
conn: impl Acquire<'_, Database = Postgres> + Send,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
user: User,
|
||||
mut login: CompatSsoLogin,
|
||||
mut compat_sso_login: CompatSsoLogin,
|
||||
device: Device,
|
||||
) -> Result<CompatSsoLogin, anyhow::Error> {
|
||||
if !matches!(login.state, CompatSsoLoginState::Pending) {
|
||||
bail!("sso login in wrong state");
|
||||
) -> Result<CompatSsoLogin, DatabaseError> {
|
||||
if !matches!(compat_sso_login.state, CompatSsoLoginState::Pending) {
|
||||
return Err(DatabaseError::InvalidOperation);
|
||||
};
|
||||
|
||||
let mut txn = conn.begin().await.context("could not start transaction")?;
|
||||
let mut txn = conn.begin().await?;
|
||||
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
@ -944,8 +930,7 @@ pub async fn fullfill_compat_sso_login(
|
||||
)
|
||||
.execute(&mut txn)
|
||||
.instrument(tracing::info_span!("Insert compat session"))
|
||||
.await
|
||||
.context("could not insert compat session")?;
|
||||
.await?;
|
||||
|
||||
let session = CompatSession {
|
||||
id,
|
||||
@ -965,46 +950,41 @@ pub async fn fullfill_compat_sso_login(
|
||||
WHERE
|
||||
compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(login.id),
|
||||
Uuid::from(compat_sso_login.id),
|
||||
Uuid::from(session.id),
|
||||
fulfilled_at,
|
||||
)
|
||||
.execute(&mut txn)
|
||||
.instrument(tracing::info_span!("Update compat SSO login"))
|
||||
.await
|
||||
.context("could not update compat SSO login")?;
|
||||
.await?;
|
||||
|
||||
let state = CompatSsoLoginState::Fulfilled {
|
||||
fulfilled_at,
|
||||
session,
|
||||
};
|
||||
|
||||
login.state = state;
|
||||
compat_sso_login.state = state;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
Ok(login)
|
||||
Ok(compat_sso_login)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
compat_sso_login.id = %login.id,
|
||||
compat_sso_login.redirect_uri = %login.redirect_uri,
|
||||
%compat_sso_login.id,
|
||||
%compat_sso_login.redirect_uri,
|
||||
),
|
||||
err(Display),
|
||||
err,
|
||||
)]
|
||||
pub async fn mark_compat_sso_login_as_exchanged(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
mut login: CompatSsoLogin,
|
||||
) -> Result<CompatSsoLogin, anyhow::Error> {
|
||||
let (fulfilled_at, session) = match login.state {
|
||||
CompatSsoLoginState::Fulfilled {
|
||||
fulfilled_at,
|
||||
session,
|
||||
} => (fulfilled_at, session),
|
||||
_ => bail!("sso login in wrong state"),
|
||||
mut compat_sso_login: CompatSsoLogin,
|
||||
) -> Result<CompatSsoLogin, DatabaseError> {
|
||||
let CompatSsoLoginState::Fulfilled { fulfilled_at, session } = compat_sso_login.state else {
|
||||
return Err(DatabaseError::InvalidOperation);
|
||||
};
|
||||
|
||||
let exchanged_at = clock.now();
|
||||
@ -1016,19 +996,18 @@ pub async fn mark_compat_sso_login_as_exchanged(
|
||||
WHERE
|
||||
compat_sso_login_id = $1
|
||||
"#,
|
||||
Uuid::from(login.id),
|
||||
Uuid::from(compat_sso_login.id),
|
||||
exchanged_at,
|
||||
)
|
||||
.execute(executor)
|
||||
.instrument(tracing::info_span!("Update compat SSO login"))
|
||||
.await
|
||||
.context("could not update compat SSO login")?;
|
||||
.await?;
|
||||
|
||||
let state = CompatSsoLoginState::Exchanged {
|
||||
fulfilled_at,
|
||||
exchanged_at,
|
||||
session,
|
||||
};
|
||||
login.state = state;
|
||||
Ok(login)
|
||||
compat_sso_login.state = state;
|
||||
Ok(compat_sso_login)
|
||||
}
|
||||
|
@ -30,7 +30,7 @@
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use pagination::InvalidPagination;
|
||||
use sqlx::migrate::Migrator;
|
||||
use sqlx::{migrate::Migrator, postgres::PgQueryResult};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
@ -100,6 +100,30 @@ pub enum DatabaseError {
|
||||
|
||||
/// An error which occured while generating the paginated query
|
||||
Pagination(#[from] InvalidPagination),
|
||||
|
||||
/// An error which happened because the requested database operation is
|
||||
/// invalid
|
||||
#[error("Invalid database operation")]
|
||||
InvalidOperation,
|
||||
|
||||
/// An error which happens when an operation affects not enough or too many
|
||||
/// rows
|
||||
#[error("Expected {expected} rows to be affected, but {actual} rows were affected")]
|
||||
RowsAffected { expected: u64, actual: u64 },
|
||||
}
|
||||
|
||||
impl DatabaseError {
|
||||
pub(crate) fn ensure_affected_rows(
|
||||
result: &PgQueryResult,
|
||||
expected: u64,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let actual = result.rows_affected();
|
||||
if actual == expected {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(DatabaseError::RowsAffected { expected, actual })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
Reference in New Issue
Block a user