1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-28 11:02:02 +03:00

Database refactoring

This commit is contained in:
Quentin Gliech
2022-10-21 11:25:38 +02:00
parent 0571c36da9
commit e2142f9cd4
79 changed files with 3070 additions and 3833 deletions

View File

@ -19,28 +19,28 @@ use mas_data_model::{
CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState,
Device, User, UserEmail,
};
use sqlx::{postgres::types::PgInterval, Acquire, PgExecutor, Postgres};
use sqlx::{Acquire, PgExecutor, Postgres};
use thiserror::Error;
use tokio::task;
use tracing::{info_span, Instrument};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{
user::lookup_user_by_username, DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend,
};
use crate::{user::lookup_user_by_username, DatabaseInconsistencyError, PostgresqlBackend};
struct CompatAccessTokenLookup {
compat_access_token_id: i64,
compat_access_token_id: Uuid,
compat_access_token: String,
compat_access_token_created_at: DateTime<Utc>,
compat_access_token_expires_at: Option<DateTime<Utc>>,
compat_session_id: i64,
compat_session_id: Uuid,
compat_session_created_at: DateTime<Utc>,
compat_session_deleted_at: Option<DateTime<Utc>>,
compat_session_finished_at: Option<DateTime<Utc>>,
compat_session_device_id: String,
user_id: i64,
user_id: Uuid,
user_username: String,
user_email_id: Option<i64>,
user_email_id: Option<Uuid>,
user_email: Option<String>,
user_email_created_at: Option<DateTime<Utc>>,
user_email_confirmed_at: Option<DateTime<Utc>>,
@ -49,6 +49,7 @@ struct CompatAccessTokenLookup {
#[derive(Debug, Error)]
#[error("failed to lookup compat access token")]
pub enum CompatAccessTokenLookupError {
Expired { when: DateTime<Utc> },
Database(#[from] sqlx::Error),
Inconsistency(#[from] DatabaseInconsistencyError),
}
@ -56,7 +57,10 @@ pub enum CompatAccessTokenLookupError {
impl CompatAccessTokenLookupError {
#[must_use]
pub fn not_found(&self) -> bool {
matches!(self, Self::Database(sqlx::Error::RowNotFound))
matches!(
self,
Self::Database(sqlx::Error::RowNotFound) | Self::Expired { .. }
)
}
}
@ -75,41 +79,48 @@ pub async fn lookup_active_compat_access_token(
CompatAccessTokenLookup,
r#"
SELECT
ct.id AS "compat_access_token_id",
ct.token AS "compat_access_token",
ct.compat_access_token_id,
ct.access_token AS "compat_access_token",
ct.created_at AS "compat_access_token_created_at",
ct.expires_at AS "compat_access_token_expires_at",
cs.id AS "compat_session_id",
cs.compat_session_id,
cs.created_at AS "compat_session_created_at",
cs.deleted_at AS "compat_session_deleted_at",
cs.finished_at AS "compat_session_finished_at",
cs.device_id AS "compat_session_device_id",
u.id AS "user_id!",
u.user_id AS "user_id!",
u.username AS "user_username!",
ue.id AS "user_email_id?",
ue.user_email_id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
FROM compat_access_tokens ct
INNER JOIN compat_sessions cs
ON cs.id = ct.compat_session_id
USING (compat_session_id)
INNER JOIN users u
ON u.id = cs.user_id
USING (user_id)
LEFT JOIN user_emails ue
ON ue.id = u.primary_email_id
ON ue.user_email_id = u.primary_user_email_id
WHERE ct.token = $1
WHERE ct.access_token = $1
AND (ct.expires_at IS NULL OR ct.expires_at > NOW())
AND cs.deleted_at IS NULL
"#,
AND cs.finished_at IS NULL
"#,
token,
)
.fetch_one(executor)
.instrument(info_span!("Fetch compat access token"))
.await?;
// Check for token expiration
if let Some(expires_at) = res.compat_access_token_expires_at {
if expires_at < Utc::now() {
return Err(CompatAccessTokenLookupError::Expired { when: expires_at });
}
}
let token = CompatAccessToken {
data: res.compat_access_token_id,
data: res.compat_access_token_id.into(),
token: res.compat_access_token,
created_at: res.compat_access_token_created_at,
expires_at: res.compat_access_token_expires_at,
@ -122,7 +133,7 @@ pub async fn lookup_active_compat_access_token(
res.user_email_confirmed_at,
) {
(Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail {
data: id,
data: id.into(),
email,
created_at,
confirmed_at,
@ -131,41 +142,42 @@ pub async fn lookup_active_compat_access_token(
_ => return Err(DatabaseInconsistencyError.into()),
};
let id = Ulid::from(res.user_id);
let user = User {
data: res.user_id,
data: id,
username: res.user_username,
sub: format!("fake-sub-{}", res.user_id),
sub: id.to_string(),
primary_email,
};
let device = Device::try_from(res.compat_session_device_id).unwrap();
let session = CompatSession {
data: res.compat_session_id,
data: res.compat_session_id.into(),
user,
device,
created_at: res.compat_session_created_at,
deleted_at: res.compat_session_deleted_at,
finished_at: res.compat_session_finished_at,
};
Ok((token, session))
}
pub struct CompatRefreshTokenLookup {
compat_refresh_token_id: i64,
compat_refresh_token_id: Uuid,
compat_refresh_token: String,
compat_refresh_token_created_at: DateTime<Utc>,
compat_access_token_id: i64,
compat_access_token_id: Uuid,
compat_access_token: String,
compat_access_token_created_at: DateTime<Utc>,
compat_access_token_expires_at: Option<DateTime<Utc>>,
compat_session_id: i64,
compat_session_id: Uuid,
compat_session_created_at: DateTime<Utc>,
compat_session_deleted_at: Option<DateTime<Utc>>,
compat_session_finished_at: Option<DateTime<Utc>>,
compat_session_device_id: String,
user_id: i64,
user_id: Uuid,
user_username: String,
user_email_id: Option<i64>,
user_email_id: Option<Uuid>,
user_email: Option<String>,
user_email_created_at: Option<DateTime<Utc>>,
user_email_confirmed_at: Option<DateTime<Utc>>,
@ -202,37 +214,37 @@ pub async fn lookup_active_compat_refresh_token(
CompatRefreshTokenLookup,
r#"
SELECT
cr.id AS "compat_refresh_token_id",
cr.token AS "compat_refresh_token",
cr.compat_refresh_token_id,
cr.refresh_token AS "compat_refresh_token",
cr.created_at AS "compat_refresh_token_created_at",
ct.id AS "compat_access_token_id",
ct.token AS "compat_access_token",
ct.compat_access_token_id,
ct.access_token AS "compat_access_token",
ct.created_at AS "compat_access_token_created_at",
ct.expires_at AS "compat_access_token_expires_at",
cs.id AS "compat_session_id",
cs.compat_session_id,
cs.created_at AS "compat_session_created_at",
cs.deleted_at AS "compat_session_deleted_at",
cs.finished_at AS "compat_session_finished_at",
cs.device_id AS "compat_session_device_id",
u.id AS "user_id!",
u.user_id,
u.username AS "user_username!",
ue.id AS "user_email_id?",
ue.user_email_id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
FROM compat_refresh_tokens cr
INNER JOIN compat_access_tokens ct
ON ct.id = cr.compat_access_token_id
INNER JOIN compat_sessions cs
ON cs.id = cr.compat_session_id
USING (compat_session_id)
INNER JOIN compat_access_tokens ct
USING (compat_access_token_id)
INNER JOIN users u
ON u.id = cs.user_id
USING (user_id)
LEFT JOIN user_emails ue
ON ue.id = u.primary_email_id
ON ue.user_email_id = u.primary_user_email_id
WHERE cr.token = $1
AND cr.next_token_id IS NULL
AND cs.deleted_at IS NULL
WHERE cr.refresh_token = $1
AND cr.consumed_at IS NULL
AND cs.finished_at IS NULL
"#,
token,
)
@ -241,13 +253,13 @@ pub async fn lookup_active_compat_refresh_token(
.await?;
let refresh_token = CompatRefreshToken {
data: res.compat_refresh_token_id,
data: res.compat_refresh_token_id.into(),
token: res.compat_refresh_token,
created_at: res.compat_refresh_token_created_at,
};
let access_token = CompatAccessToken {
data: res.compat_access_token_id,
data: res.compat_access_token_id.into(),
token: res.compat_access_token,
created_at: res.compat_access_token_created_at,
expires_at: res.compat_access_token_expires_at,
@ -260,7 +272,7 @@ pub async fn lookup_active_compat_refresh_token(
res.user_email_confirmed_at,
) {
(Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail {
data: id,
data: id.into(),
email,
created_at,
confirmed_at,
@ -269,21 +281,22 @@ pub async fn lookup_active_compat_refresh_token(
_ => return Err(DatabaseInconsistencyError.into()),
};
let id = Ulid::from(res.user_id);
let user = User {
data: res.user_id,
data: id,
username: res.user_username,
sub: format!("fake-sub-{}", res.user_id),
sub: id.to_string(),
primary_email,
};
let device = Device::try_from(res.compat_session_device_id).unwrap();
let session = CompatSession {
data: res.compat_session_id,
data: res.compat_session_id.into(),
user,
device,
created_at: res.compat_session_created_at,
deleted_at: res.compat_session_deleted_at,
finished_at: res.compat_session_finished_at,
};
Ok((refresh_token, access_token, session))
@ -310,7 +323,7 @@ pub async fn compat_login(
ORDER BY up.created_at DESC
LIMIT 1
"#,
user.data,
Uuid::from(user.data),
)
.fetch_one(&mut txn)
.instrument(tracing::info_span!("Lookup hashed password"))
@ -327,27 +340,30 @@ pub async fn compat_login(
.instrument(tracing::info_span!("Verify hashed password"))
.await??;
let res = sqlx::query_as!(
IdAndCreationTime,
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
sqlx::query!(
r#"
INSERT INTO compat_sessions (user_id, device_id)
VALUES ($1, $2)
RETURNING id, created_at
INSERT INTO compat_sessions
(compat_session_id, user_id, device_id, created_at)
VALUES ($1, $2, $3, $4)
"#,
user.data,
Uuid::from(id),
Uuid::from(user.data),
device.as_str(),
created_at,
)
.fetch_one(&mut txn)
.execute(&mut txn)
.instrument(tracing::info_span!("Insert compat session"))
.await
.context("could not insert compat session")?;
let session = CompatSession {
data: res.id,
data: id,
user,
device,
created_at: res.created_at,
deleted_at: None,
created_at,
finished_at: None,
};
txn.commit().await.context("could not commit transaction")?;
@ -361,70 +377,48 @@ pub async fn add_compat_access_token(
token: String,
expires_after: Option<Duration>,
) -> Result<CompatAccessToken<PostgresqlBackend>, anyhow::Error> {
if let Some(expires_after) = expires_after {
// For some reason, we need to convert the type first
let pg_expires_after = PgInterval::try_from(expires_after)
// For some reason, this error type does not let me to just bubble up the error here
.map_err(|e| anyhow::anyhow!("failed to encode duration: {}", e))?;
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
let res = sqlx::query_as!(
IdAndCreationTime,
r#"
INSERT INTO compat_access_tokens (compat_session_id, token, created_at, expires_at)
VALUES ($1, $2, NOW(), NOW() + $3)
RETURNING id, created_at
"#,
session.data,
token,
pg_expires_after,
)
.fetch_one(executor)
.instrument(tracing::info_span!("Insert compat access token"))
.await
.context("could not insert compat access token")?;
sqlx::query!(
r#"
INSERT INTO compat_access_tokens
(compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.data),
token,
created_at,
expires_at,
)
.execute(executor)
.instrument(tracing::info_span!("Insert compat access token"))
.await
.context("could not insert compat access token")?;
Ok(CompatAccessToken {
data: res.id,
token,
created_at: res.created_at,
expires_at: Some(res.created_at + expires_after),
})
} else {
let res = sqlx::query_as!(
IdAndCreationTime,
r#"
INSERT INTO compat_access_tokens (compat_session_id, token)
VALUES ($1, $2)
RETURNING id, created_at
"#,
session.data,
token,
)
.fetch_one(executor)
.instrument(tracing::info_span!("Insert compat access token"))
.await
.context("could not insert compat access token")?;
Ok(CompatAccessToken {
data: res.id,
token,
created_at: res.created_at,
expires_at: None,
})
}
Ok(CompatAccessToken {
data: id,
token,
created_at,
expires_at,
})
}
pub async fn expire_compat_access_token(
executor: impl PgExecutor<'_>,
access_token: CompatAccessToken<PostgresqlBackend>,
) -> anyhow::Result<()> {
let expires_at = Utc::now();
let res = sqlx::query!(
r#"
UPDATE compat_access_tokens
SET expires_at = NOW()
WHERE id = $1
SET expires_at = $2
WHERE compat_access_token_id = $1
"#,
access_token.data,
Uuid::from(access_token.data),
expires_at,
)
.execute(executor)
.await
@ -445,26 +439,30 @@ pub async fn add_compat_refresh_token(
access_token: &CompatAccessToken<PostgresqlBackend>,
token: String,
) -> Result<CompatRefreshToken<PostgresqlBackend>, anyhow::Error> {
let res = sqlx::query_as!(
IdAndCreationTime,
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
sqlx::query!(
r#"
INSERT INTO compat_refresh_tokens (compat_session_id, compat_access_token_id, token)
VALUES ($1, $2, $3)
RETURNING id, created_at
INSERT INTO compat_refresh_tokens
(compat_refresh_token_id, compat_session_id,
compat_access_token_id, refresh_token, created_at)
VALUES ($1, $2, $3, $4, $5)
"#,
session.data,
access_token.data,
Uuid::from(id),
Uuid::from(session.data),
Uuid::from(access_token.data),
token,
created_at,
)
.fetch_one(executor)
.execute(executor)
.instrument(tracing::info_span!("Insert compat refresh token"))
.await
.context("could not insert compat refresh token")?;
Ok(CompatRefreshToken {
data: res.id,
data: id,
token,
created_at: res.created_at,
created_at,
})
}
@ -473,16 +471,19 @@ pub async fn compat_logout(
executor: impl PgExecutor<'_>,
token: &str,
) -> Result<(), anyhow::Error> {
let finished_at = Utc::now();
// TODO: this does not check for token expiration
let res = sqlx::query!(
r#"
UPDATE compat_sessions
SET deleted_at = NOW()
FROM compat_access_tokens
WHERE compat_access_tokens.token = $1
AND compat_sessions.id = compat_access_tokens.id
AND compat_sessions.deleted_at IS NULL
UPDATE compat_sessions cs
SET finished_at = $2
FROM compat_access_tokens ca
WHERE ca.access_token = $1
AND ca.compat_session_id = cs.compat_session_id
AND cs.finished_at IS NULL
"#,
token,
finished_at,
)
.execute(executor)
.await
@ -495,19 +496,19 @@ pub async fn compat_logout(
}
}
pub async fn replace_compat_refresh_token(
pub async fn consume_compat_refresh_token(
executor: impl PgExecutor<'_>,
refresh_token: &CompatRefreshToken<PostgresqlBackend>,
next_refresh_token: &CompatRefreshToken<PostgresqlBackend>,
refresh_token: CompatRefreshToken<PostgresqlBackend>,
) -> anyhow::Result<()> {
let consumed_at = Utc::now();
let res = sqlx::query!(
r#"
UPDATE compat_refresh_tokens
SET next_token_id = $2
WHERE id = $1
SET consumed_at = $2
WHERE compat_refresh_token_id = $1
"#,
refresh_token.data,
next_refresh_token.data
Uuid::from(refresh_token.data),
consumed_at,
)
.execute(executor)
.await
@ -524,47 +525,50 @@ pub async fn replace_compat_refresh_token(
pub async fn insert_compat_sso_login(
executor: impl PgExecutor<'_>,
token: String,
login_token: String,
redirect_uri: Url,
) -> anyhow::Result<CompatSsoLogin<PostgresqlBackend>> {
let res = sqlx::query_as!(
IdAndCreationTime,
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
sqlx::query!(
r#"
INSERT INTO compat_sso_logins (token, redirect_uri)
VALUES ($1, $2)
RETURNING id, created_at
INSERT INTO compat_sso_logins
(compat_sso_login_id, login_token, redirect_uri, created_at)
VALUES ($1, $2, $3, $4)
"#,
&token,
Uuid::from(id),
&login_token,
redirect_uri.as_str(),
created_at,
)
.fetch_one(executor)
.execute(executor)
.instrument(tracing::info_span!("Insert compat SSO login"))
.await
.context("could not insert compat SSO login")?;
Ok(CompatSsoLogin {
data: res.id,
token,
data: id,
login_token,
redirect_uri,
created_at: res.created_at,
created_at,
state: CompatSsoLoginState::Pending,
})
}
struct CompatSsoLoginLookup {
compat_sso_login_id: i64,
compat_sso_login_id: Uuid,
compat_sso_login_token: String,
compat_sso_login_redirect_uri: String,
compat_sso_login_created_at: DateTime<Utc>,
compat_sso_login_fullfilled_at: Option<DateTime<Utc>>,
compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
compat_session_id: Option<i64>,
compat_session_id: Option<Uuid>,
compat_session_created_at: Option<DateTime<Utc>>,
compat_session_deleted_at: Option<DateTime<Utc>>,
compat_session_finished_at: Option<DateTime<Utc>>,
compat_session_device_id: Option<String>,
user_id: Option<i64>,
user_id: Option<Uuid>,
user_username: Option<String>,
user_email_id: Option<i64>,
user_email_id: Option<Uuid>,
user_email: Option<String>,
user_email_created_at: Option<DateTime<Utc>>,
user_email_confirmed_at: Option<DateTime<Utc>>,
@ -584,7 +588,7 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin<PostgresqlBackend> {
res.user_email_confirmed_at,
) {
(Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail {
data: id,
data: id.into(),
email,
created_at,
confirmed_at,
@ -594,12 +598,16 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin<PostgresqlBackend> {
};
let user = match (res.user_id, res.user_username, primary_email) {
(Some(id), Some(username), primary_email) => Some(User {
data: id,
username,
sub: format!("fake-sub-{}", id),
primary_email,
}),
(Some(id), Some(username), primary_email) => {
let id = Ulid::from(id);
Some(User {
data: id,
username,
sub: id.to_string(),
primary_email,
})
}
(None, None, None) => None,
_ => return Err(DatabaseInconsistencyError),
};
@ -608,17 +616,17 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin<PostgresqlBackend> {
res.compat_session_id,
res.compat_session_device_id,
res.compat_session_created_at,
res.compat_session_deleted_at,
res.compat_session_finished_at,
user,
) {
(Some(id), Some(device_id), Some(created_at), deleted_at, Some(user)) => {
(Some(id), Some(device_id), Some(created_at), finished_at, Some(user)) => {
let device = Device::try_from(device_id).map_err(|_| DatabaseInconsistencyError)?;
Some(CompatSession {
data: id,
data: id.into(),
user,
device,
created_at,
deleted_at,
finished_at,
})
}
(None, None, None, None, None) => None,
@ -626,18 +634,18 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin<PostgresqlBackend> {
};
let state = match (
res.compat_sso_login_fullfilled_at,
res.compat_sso_login_fulfilled_at,
res.compat_sso_login_exchanged_at,
session,
) {
(None, None, None) => CompatSsoLoginState::Pending,
(Some(fullfilled_at), None, Some(session)) => CompatSsoLoginState::Fullfilled {
fullfilled_at,
(Some(fulfilled_at), None, Some(session)) => CompatSsoLoginState::Fulfilled {
fulfilled_at,
session,
},
(Some(fullfilled_at), Some(exchanged_at), Some(session)) => {
(Some(fulfilled_at), Some(exchanged_at), Some(session)) => {
CompatSsoLoginState::Exchanged {
fullfilled_at,
fulfilled_at,
exchanged_at,
session,
}
@ -646,8 +654,8 @@ impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin<PostgresqlBackend> {
};
Ok(CompatSsoLogin {
data: res.compat_sso_login_id,
token: res.compat_sso_login_token,
data: res.compat_sso_login_id.into(),
login_token: res.compat_sso_login_token,
redirect_uri,
created_at: res.compat_sso_login_created_at,
state,
@ -673,38 +681,38 @@ impl CompatSsoLoginLookupError {
#[tracing::instrument(skip(executor), err)]
pub async fn get_compat_sso_login_by_id(
executor: impl PgExecutor<'_>,
id: i64,
id: Ulid,
) -> Result<CompatSsoLogin<PostgresqlBackend>, CompatSsoLoginLookupError> {
let res = sqlx::query_as!(
CompatSsoLoginLookup,
r#"
SELECT
cl.id AS "compat_sso_login_id",
cl.token AS "compat_sso_login_token",
cl.compat_sso_login_id,
cl.login_token AS "compat_sso_login_token",
cl.redirect_uri AS "compat_sso_login_redirect_uri",
cl.created_at AS "compat_sso_login_created_at",
cl.fullfilled_at AS "compat_sso_login_fullfilled_at",
cl.fulfilled_at AS "compat_sso_login_fulfilled_at",
cl.exchanged_at AS "compat_sso_login_exchanged_at",
cs.id AS "compat_session_id?",
cs.compat_session_id AS "compat_session_id?",
cs.created_at AS "compat_session_created_at?",
cs.deleted_at AS "compat_session_deleted_at?",
cs.finished_at AS "compat_session_finished_at?",
cs.device_id AS "compat_session_device_id?",
u.id AS "user_id?",
u.user_id AS "user_id?",
u.username AS "user_username?",
ue.id AS "user_email_id?",
ue.user_email_id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
FROM compat_sso_logins cl
LEFT JOIN compat_sessions cs
ON cs.id = cl.compat_session_id
USING (compat_session_id)
LEFT JOIN users u
ON u.id = cs.user_id
USING (user_id)
LEFT JOIN user_emails ue
ON ue.id = u.primary_email_id
WHERE cl.id = $1
ON ue.user_email_id = u.primary_user_email_id
WHERE cl.compat_sso_login_id = $1
"#,
id,
Uuid::from(id),
)
.fetch_one(executor)
.instrument(tracing::info_span!("Lookup compat SSO login"))
@ -723,30 +731,30 @@ pub async fn get_compat_sso_login_by_token(
CompatSsoLoginLookup,
r#"
SELECT
cl.id AS "compat_sso_login_id",
cl.token AS "compat_sso_login_token",
cl.compat_sso_login_id,
cl.login_token AS "compat_sso_login_token",
cl.redirect_uri AS "compat_sso_login_redirect_uri",
cl.created_at AS "compat_sso_login_created_at",
cl.fullfilled_at AS "compat_sso_login_fullfilled_at",
cl.fulfilled_at AS "compat_sso_login_fulfilled_at",
cl.exchanged_at AS "compat_sso_login_exchanged_at",
cs.id AS "compat_session_id?",
cs.compat_session_id AS "compat_session_id?",
cs.created_at AS "compat_session_created_at?",
cs.deleted_at AS "compat_session_deleted_at?",
cs.finished_at AS "compat_session_finished_at?",
cs.device_id AS "compat_session_device_id?",
u.id AS "user_id?",
u.user_id AS "user_id?",
u.username AS "user_username?",
ue.id AS "user_email_id?",
ue.user_email_id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
FROM compat_sso_logins cl
LEFT JOIN compat_sessions cs
ON cs.id = cl.compat_session_id
USING (compat_session_id)
LEFT JOIN users u
ON u.id = cs.user_id
USING (user_id)
LEFT JOIN user_emails ue
ON ue.id = u.primary_email_id
WHERE cl.token = $1
ON ue.user_email_id = u.primary_user_email_id
WHERE cl.login_token = $1
"#,
token,
)
@ -769,49 +777,52 @@ pub async fn fullfill_compat_sso_login(
let mut txn = conn.begin().await.context("could not start transaction")?;
let res = sqlx::query_as!(
IdAndCreationTime,
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
sqlx::query!(
r#"
INSERT INTO compat_sessions (user_id, device_id)
VALUES ($1, $2)
RETURNING id, created_at
INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)
VALUES ($1, $2, $3, $4)
"#,
user.data,
Uuid::from(id),
Uuid::from(user.data),
device.as_str(),
created_at,
)
.fetch_one(&mut txn)
.execute(&mut txn)
.instrument(tracing::info_span!("Insert compat session"))
.await
.context("could not insert compat session")?;
let session = CompatSession {
data: res.id,
data: id,
user,
device,
created_at: res.created_at,
deleted_at: None,
created_at,
finished_at: None,
};
let res = sqlx::query_scalar!(
let fulfilled_at = Utc::now();
sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
fullfilled_at = NOW(),
compat_session_id = $2
compat_session_id = $2,
fulfilled_at = $3
WHERE
id = $1
RETURNING fullfilled_at AS "fullfilled_at!"
compat_sso_login_id = $1
"#,
login.data,
session.data,
Uuid::from(login.data),
Uuid::from(session.data),
fulfilled_at,
)
.fetch_one(&mut txn)
.execute(&mut txn)
.instrument(tracing::info_span!("Update compat SSO login"))
.await
.context("could not update compat SSO login")?;
let state = CompatSsoLoginState::Fullfilled {
fullfilled_at: res,
let state = CompatSsoLoginState::Fulfilled {
fulfilled_at,
session,
};
@ -826,33 +837,34 @@ pub async fn mark_compat_sso_login_as_exchanged(
executor: impl PgExecutor<'_>,
mut login: CompatSsoLogin<PostgresqlBackend>,
) -> anyhow::Result<CompatSsoLogin<PostgresqlBackend>> {
let (fullfilled_at, session) = match login.state {
CompatSsoLoginState::Fullfilled {
fullfilled_at,
let (fulfilled_at, session) = match login.state {
CompatSsoLoginState::Fulfilled {
fulfilled_at,
session,
} => (fullfilled_at, session),
} => (fulfilled_at, session),
_ => bail!("sso login in wrong state"),
};
let res = sqlx::query_scalar!(
let exchanged_at = Utc::now();
sqlx::query!(
r#"
UPDATE compat_sso_logins
SET
exchanged_at = NOW()
exchanged_at = $2
WHERE
id = $1
RETURNING exchanged_at AS "exchanged_at!"
compat_sso_login_id = $1
"#,
login.data,
Uuid::from(login.data),
exchanged_at,
)
.fetch_one(executor)
.execute(executor)
.instrument(tracing::info_span!("Update compat SSO login"))
.await
.context("could not update compat SSO login")?;
let state = CompatSsoLoginState::Exchanged {
fullfilled_at,
exchanged_at: res,
fulfilled_at,
exchanged_at,
session,
};
login.state = state;

View File

@ -23,11 +23,11 @@
clippy::module_name_repetitions
)]
use chrono::{DateTime, Utc};
use mas_data_model::{StorageBackend, StorageBackendMarker};
use serde::Serialize;
use sqlx::migrate::Migrator;
use thiserror::Error;
use ulid::Ulid;
#[derive(Debug, Error)]
#[error("database query returned an inconsistent state")]
@ -37,29 +37,24 @@ pub struct DatabaseInconsistencyError;
pub struct PostgresqlBackend;
impl StorageBackend for PostgresqlBackend {
type AccessTokenData = i64;
type AuthenticationData = i64;
type AuthorizationGrantData = i64;
type BrowserSessionData = i64;
type ClientData = i64;
type CompatAccessTokenData = i64;
type CompatRefreshTokenData = i64;
type CompatSessionData = i64;
type CompatSsoLoginData = i64;
type RefreshTokenData = i64;
type SessionData = i64;
type UserData = i64;
type UserEmailData = i64;
type UserEmailVerificationData = i64;
type AccessTokenData = Ulid;
type AuthenticationData = Ulid;
type AuthorizationGrantData = Ulid;
type BrowserSessionData = Ulid;
type ClientData = Ulid;
type CompatAccessTokenData = Ulid;
type CompatRefreshTokenData = Ulid;
type CompatSessionData = Ulid;
type CompatSsoLoginData = Ulid;
type RefreshTokenData = Ulid;
type SessionData = Ulid;
type UserData = Ulid;
type UserEmailData = Ulid;
type UserEmailVerificationData = Ulid;
}
impl StorageBackendMarker for PostgresqlBackend {}
struct IdAndCreationTime {
id: i64,
created_at: DateTime<Utc>,
}
pub mod compat;
pub mod oauth2;
pub mod user;

View File

@ -17,62 +17,76 @@ use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
use sqlx::{Acquire, PgExecutor, Postgres};
use thiserror::Error;
use ulid::Ulid;
use uuid::Uuid;
use super::client::{lookup_client, ClientFetchError};
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
use crate::{DatabaseInconsistencyError, PostgresqlBackend};
#[tracing::instrument(
skip_all,
fields(
session.id = %session.data,
client.id = %session.client.data,
user.id = %session.browser_session.user.data,
access_token.id,
),
err(Debug),
)]
pub async fn add_access_token(
executor: impl PgExecutor<'_>,
session: &Session<PostgresqlBackend>,
token: &str,
access_token: String,
expires_after: Duration,
) -> anyhow::Result<AccessToken<PostgresqlBackend>> {
// Checked convertion of duration to i32, maxing at i32::MAX
let expires_after_seconds = i32::try_from(expires_after.num_seconds()).unwrap_or(i32::MAX);
let created_at = Utc::now();
let expires_at = created_at + expires_after;
let id = Ulid::from_datetime(created_at.into());
let res = sqlx::query_as!(
IdAndCreationTime,
tracing::Span::current().record("access_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_access_tokens
(oauth2_session_id, token, expires_after)
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
VALUES
($1, $2, $3)
RETURNING
id, created_at
($1, $2, $3, $4, $5)
"#,
session.data,
token,
expires_after_seconds,
Uuid::from(id),
Uuid::from(session.data),
&access_token,
created_at,
expires_at,
)
.fetch_one(executor)
.execute(executor)
.await
.context("could not insert oauth2 access token")?;
Ok(AccessToken {
data: res.id,
expires_after,
token: token.to_owned(),
jti: format!("{}", res.id),
created_at: res.created_at,
data: id,
access_token,
jti: id.to_string(),
created_at,
expires_at,
})
}
#[derive(Debug)]
pub struct OAuth2AccessTokenLookup {
access_token_id: i64,
access_token: String,
access_token_expires_after: i32,
access_token_created_at: DateTime<Utc>,
session_id: i64,
oauth2_client_id: i64,
oauth2_access_token_id: Uuid,
oauth2_access_token: String,
oauth2_access_token_created_at: DateTime<Utc>,
oauth2_access_token_expires_at: DateTime<Utc>,
oauth2_session_id: Uuid,
oauth2_client_id: Uuid,
scope: String,
user_session_id: i64,
user_session_id: Uuid,
user_session_created_at: DateTime<Utc>,
user_id: i64,
user_id: Uuid,
user_username: String,
user_session_last_authentication_id: Option<i64>,
user_session_last_authentication_id: Option<Uuid>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
user_email_id: Option<i64>,
user_email_id: Option<Uuid>,
user_email: Option<String>,
user_email_created_at: Option<DateTime<Utc>>,
user_email_confirmed_at: Option<DateTime<Utc>>,
@ -114,40 +128,39 @@ where
OAuth2AccessTokenLookup,
r#"
SELECT
at.id AS "access_token_id",
at.token AS "access_token",
at.expires_after AS "access_token_expires_after",
at.created_at AS "access_token_created_at",
os.id AS "session_id!",
at.oauth2_access_token_id,
at.access_token AS "oauth2_access_token",
at.created_at AS "oauth2_access_token_created_at",
at.expires_at AS "oauth2_access_token_expires_at",
os.oauth2_session_id AS "oauth2_session_id!",
os.oauth2_client_id AS "oauth2_client_id!",
os.scope AS "scope!",
us.id AS "user_session_id!",
us.user_session_id AS "user_session_id!",
us.created_at AS "user_session_created_at!",
u.id AS "user_id!",
u.user_id AS "user_id!",
u.username AS "user_username!",
usa.id AS "user_session_last_authentication_id?",
usa.user_session_authentication_id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?",
ue.id AS "user_email_id?",
ue.user_email_id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
FROM oauth2_access_tokens at
INNER JOIN oauth2_sessions os
ON os.id = at.oauth2_session_id
USING (oauth2_session_id)
INNER JOIN user_sessions us
ON us.id = os.user_session_id
USING (user_session_id)
INNER JOIN users u
ON u.id = us.user_id
USING (user_id)
LEFT JOIN user_session_authentications usa
ON usa.session_id = us.id
USING (user_session_id)
LEFT JOIN user_emails ue
ON ue.id = u.primary_email_id
ON ue.user_email_id = u.primary_user_email_id
WHERE at.token = $1
AND at.created_at + (at.expires_after * INTERVAL '1 second') >= now()
AND us.active
AND os.ended_at IS NULL
WHERE at.access_token = $1
AND at.revoked_at IS NULL
AND os.finished_at IS NULL
ORDER BY usa.created_at DESC
LIMIT 1
@ -158,14 +171,14 @@ where
.await?;
let access_token = AccessToken {
data: res.access_token_id,
jti: format!("{}", res.access_token_id),
token: res.access_token,
created_at: res.access_token_created_at,
expires_after: Duration::seconds(res.access_token_expires_after.into()),
data: res.oauth2_access_token_id.into(),
jti: res.oauth2_access_token_id.to_string(),
access_token: res.oauth2_access_token,
created_at: res.oauth2_access_token_created_at,
expires_at: res.oauth2_access_token_expires_at,
};
let client = lookup_client(&mut *conn, res.oauth2_client_id).await?;
let client = lookup_client(&mut *conn, res.oauth2_client_id.into()).await?;
let primary_email = match (
res.user_email_id,
@ -174,7 +187,7 @@ where
res.user_email_confirmed_at,
) {
(Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail {
data: id,
data: id.into(),
email,
created_at,
confirmed_at,
@ -183,10 +196,11 @@ where
_ => return Err(DatabaseInconsistencyError.into()),
};
let id = Ulid::from(res.user_id);
let user = User {
data: res.user_id,
data: id,
username: res.user_username,
sub: format!("fake-sub-{}", res.user_id),
sub: id.to_string(),
primary_email,
};
@ -196,14 +210,14 @@ where
) {
(None, None) => None,
(Some(id), Some(created_at)) => Some(Authentication {
data: id,
data: id.into(),
created_at,
}),
_ => return Err(DatabaseInconsistencyError.into()),
};
let browser_session = BrowserSession {
data: res.user_session_id,
data: res.user_session_id.into(),
created_at: res.user_session_created_at,
user,
last_authentication,
@ -212,7 +226,7 @@ where
let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?;
let session = Session {
data: res.session_id,
data: res.oauth2_session_id.into(),
client,
browser_session,
scope,
@ -222,16 +236,24 @@ where
}
}
#[tracing::instrument(
skip_all,
fields(access_token.id = %access_token.data),
err(Debug),
)]
pub async fn revoke_access_token(
executor: impl PgExecutor<'_>,
access_token: &AccessToken<PostgresqlBackend>,
access_token: AccessToken<PostgresqlBackend>,
) -> anyhow::Result<()> {
let revoked_at = Utc::now();
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE id = $1
UPDATE oauth2_access_tokens
SET revoked_at = $2
WHERE oauth2_access_token_id = $1
"#,
access_token.data,
Uuid::from(access_token.data),
revoked_at,
)
.execute(executor)
.await
@ -245,11 +267,14 @@ pub async fn revoke_access_token(
}
pub async fn cleanup_expired(executor: impl PgExecutor<'_>) -> anyhow::Result<u64> {
// Cleanup token which expired more than 15 minutes ago
let threshold = Utc::now() - Duration::minutes(15);
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE created_at + (expires_after * INTERVAL '1 second') + INTERVAL '15 minutes' < now()
WHERE expires_at < $1
"#,
threshold,
)
.execute(executor)
.await

View File

@ -25,11 +25,21 @@ use mas_data_model::{
use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{requests::ResponseMode, scope::Scope};
use sqlx::{PgConnection, PgExecutor};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use super::client::lookup_client;
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
use crate::{DatabaseInconsistencyError, PostgresqlBackend};
#[tracing::instrument(
skip_all,
fields(
client.id = %client.data,
grant.id,
),
err(Debug),
)]
#[allow(clippy::too_many_arguments)]
pub async fn new_authorization_grant(
executor: impl PgExecutor<'_>,
@ -40,7 +50,7 @@ pub async fn new_authorization_grant(
state: Option<String>,
nonce: Option<String>,
max_age: Option<NonZeroU32>,
acr_values: Option<String>,
_acr_values: Option<String>,
response_mode: ResponseMode,
response_type_id_token: bool,
requires_consent: bool,
@ -53,26 +63,43 @@ pub async fn new_authorization_grant(
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| p.challenge_method.to_string());
// TODO: this conversion is a bit ugly
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
let code_str = code.as_ref().map(|c| &c.code);
let res = sqlx::query_as!(
IdAndCreationTime,
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
tracing::Span::current().record("grant.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_authorization_grants
(oauth2_client_id, redirect_uri, scope, state, nonce, max_age,
acr_values, response_mode, code_challenge, code_challenge_method,
response_type_code, response_type_id_token, code, requires_consent)
INSERT INTO oauth2_authorization_grants (
oauth2_authorization_grant_id,
oauth2_client_id,
redirect_uri,
scope,
state,
nonce,
max_age,
response_mode,
code_challenge,
code_challenge_method,
response_type_code,
response_type_id_token,
authorization_code,
requires_consent,
created_at
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
RETURNING id, created_at
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
"#,
&client.data,
Uuid::from(id),
Uuid::from(client.data),
redirect_uri.to_string(),
scope.to_string(),
state,
nonce,
// TODO: this conversion is a bit ugly
max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)),
acr_values,
max_age_i32,
response_mode.to_string(),
code_challenge,
code_challenge_method,
@ -80,13 +107,14 @@ pub async fn new_authorization_grant(
response_type_id_token,
code_str,
requires_consent,
created_at,
)
.fetch_one(executor)
.execute(executor)
.await
.context("could not insert oauth2 authorization grant")?;
Ok(AuthorizationGrant {
data: res.id,
data: id,
stage: AuthorizationGrantStage::Pending,
code,
redirect_uri,
@ -95,9 +123,8 @@ pub async fn new_authorization_grant(
state,
nonce,
max_age,
acr_values,
response_mode,
created_at: res.created_at,
created_at,
response_type_id_token,
requires_consent,
})
@ -105,33 +132,32 @@ pub async fn new_authorization_grant(
#[allow(clippy::struct_excessive_bools)]
struct GrantLookup {
grant_id: i64,
grant_created_at: DateTime<Utc>,
grant_cancelled_at: Option<DateTime<Utc>>,
grant_fulfilled_at: Option<DateTime<Utc>>,
grant_exchanged_at: Option<DateTime<Utc>>,
grant_scope: String,
grant_state: Option<String>,
grant_redirect_uri: String,
grant_response_mode: String,
grant_nonce: Option<String>,
grant_max_age: Option<i32>,
grant_acr_values: Option<String>,
grant_response_type_code: bool,
grant_response_type_id_token: bool,
grant_code: Option<String>,
grant_code_challenge: Option<String>,
grant_code_challenge_method: Option<String>,
grant_requires_consent: bool,
oauth2_client_id: i64,
session_id: Option<i64>,
user_session_id: Option<i64>,
oauth2_authorization_grant_id: Uuid,
oauth2_authorization_grant_created_at: DateTime<Utc>,
oauth2_authorization_grant_cancelled_at: Option<DateTime<Utc>>,
oauth2_authorization_grant_fulfilled_at: Option<DateTime<Utc>>,
oauth2_authorization_grant_exchanged_at: Option<DateTime<Utc>>,
oauth2_authorization_grant_scope: String,
oauth2_authorization_grant_state: Option<String>,
oauth2_authorization_grant_nonce: Option<String>,
oauth2_authorization_grant_redirect_uri: String,
oauth2_authorization_grant_response_mode: String,
oauth2_authorization_grant_max_age: Option<i32>,
oauth2_authorization_grant_response_type_code: bool,
oauth2_authorization_grant_response_type_id_token: bool,
oauth2_authorization_grant_code: Option<String>,
oauth2_authorization_grant_code_challenge: Option<String>,
oauth2_authorization_grant_code_challenge_method: Option<String>,
oauth2_authorization_grant_requires_consent: bool,
oauth2_client_id: Uuid,
oauth2_session_id: Option<Uuid>,
user_session_id: Option<Uuid>,
user_session_created_at: Option<DateTime<Utc>>,
user_id: Option<i64>,
user_id: Option<Uuid>,
user_username: Option<String>,
user_session_last_authentication_id: Option<i64>,
user_session_last_authentication_id: Option<Uuid>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
user_email_id: Option<i64>,
user_email_id: Option<Uuid>,
user_email: Option<String>,
user_email_created_at: Option<DateTime<Utc>>,
user_email_confirmed_at: Option<DateTime<Utc>>,
@ -144,12 +170,12 @@ impl GrantLookup {
executor: impl PgExecutor<'_>,
) -> Result<AuthorizationGrant<PostgresqlBackend>, DatabaseInconsistencyError> {
let scope: Scope = self
.grant_scope
.oauth2_authorization_grant_scope
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
// TODO: don't unwrap
let client = lookup_client(executor, self.oauth2_client_id)
let client = lookup_client(executor, self.oauth2_client_id.into())
.await
.unwrap();
@ -158,7 +184,7 @@ impl GrantLookup {
self.user_session_last_authentication_created_at,
) {
(Some(id), Some(created_at)) => Some(Authentication {
data: id,
data: id.into(),
created_at,
}),
(None, None) => None,
@ -172,7 +198,7 @@ impl GrantLookup {
self.user_email_confirmed_at,
) {
(Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail {
data: id,
data: id.into(),
email,
created_at,
confirmed_at,
@ -182,7 +208,7 @@ impl GrantLookup {
};
let session = match (
self.session_id,
self.oauth2_session_id,
self.user_session_id,
self.user_session_created_at,
self.user_id,
@ -199,15 +225,16 @@ impl GrantLookup {
last_authentication,
primary_email,
) => {
let user_id = Ulid::from(user_id);
let user = User {
data: user_id,
username: user_username,
sub: format!("fake-sub-{}", user_id),
sub: user_id.to_string(),
primary_email,
};
let browser_session = BrowserSession {
data: user_session_id,
data: user_session_id.into(),
user,
created_at: user_session_created_at,
last_authentication,
@ -217,7 +244,7 @@ impl GrantLookup {
let scope = scope.clone();
let session = Session {
data: session_id,
data: session_id.into(),
client,
browser_session,
scope,
@ -230,9 +257,9 @@ impl GrantLookup {
};
let stage = match (
self.grant_fulfilled_at,
self.grant_exchanged_at,
self.grant_cancelled_at,
self.oauth2_authorization_grant_fulfilled_at,
self.oauth2_authorization_grant_exchanged_at,
self.oauth2_authorization_grant_cancelled_at,
session,
) {
(None, None, None, None) => AuthorizationGrantStage::Pending,
@ -255,7 +282,10 @@ impl GrantLookup {
}
};
let pkce = match (self.grant_code_challenge, self.grant_code_challenge_method) {
let pkce = match (
self.oauth2_authorization_grant_code_challenge,
self.oauth2_authorization_grant_code_challenge_method,
) {
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
Some(Pkce {
challenge_method: PkceCodeChallengeMethod::Plain,
@ -272,27 +302,30 @@ impl GrantLookup {
}
};
let code: Option<AuthorizationCode> =
match (self.grant_response_type_code, self.grant_code, pkce) {
(false, None, None) => None,
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
_ => {
return Err(DatabaseInconsistencyError);
}
};
let code: Option<AuthorizationCode> = match (
self.oauth2_authorization_grant_response_type_code,
self.oauth2_authorization_grant_code,
pkce,
) {
(false, None, None) => None,
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
_ => {
return Err(DatabaseInconsistencyError);
}
};
let redirect_uri = self
.grant_redirect_uri
.oauth2_authorization_grant_redirect_uri
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
let response_mode = self
.grant_response_mode
.oauth2_authorization_grant_response_mode
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
let max_age = self
.grant_max_age
.oauth2_authorization_grant_max_age
.map(u32::try_from)
.transpose()
.map_err(|_e| DatabaseInconsistencyError)?
@ -301,82 +334,85 @@ impl GrantLookup {
.map_err(|_e| DatabaseInconsistencyError)?;
Ok(AuthorizationGrant {
data: self.grant_id,
data: self.oauth2_authorization_grant_id.into(),
stage,
client,
code,
acr_values: self.grant_acr_values,
scope,
state: self.grant_state,
nonce: self.grant_nonce,
state: self.oauth2_authorization_grant_state,
nonce: self.oauth2_authorization_grant_nonce,
max_age, // TODO
response_mode,
redirect_uri,
created_at: self.grant_created_at,
response_type_id_token: self.grant_response_type_id_token,
requires_consent: self.grant_requires_consent,
created_at: self.oauth2_authorization_grant_created_at,
response_type_id_token: self.oauth2_authorization_grant_response_type_id_token,
requires_consent: self.oauth2_authorization_grant_requires_consent,
})
}
}
#[tracing::instrument(
skip_all,
fields(grant.id = %id),
err(Debug),
)]
pub async fn get_grant_by_id(
conn: &mut PgConnection,
id: i64,
id: Ulid,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
// TODO: handle "not found" cases
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT
og.id AS grant_id,
og.created_at AS grant_created_at,
og.cancelled_at AS grant_cancelled_at,
og.fulfilled_at AS grant_fulfilled_at,
og.exchanged_at AS grant_exchanged_at,
og.scope AS grant_scope,
og.state AS grant_state,
og.redirect_uri AS grant_redirect_uri,
og.response_mode AS grant_response_mode,
og.nonce AS grant_nonce,
og.max_age AS grant_max_age,
og.acr_values AS grant_acr_values,
og.oauth2_client_id AS oauth2_client_id,
og.code AS grant_code,
og.response_type_code AS grant_response_type_code,
og.response_type_id_token AS grant_response_type_id_token,
og.code_challenge AS grant_code_challenge,
og.code_challenge_method AS grant_code_challenge_method,
og.requires_consent AS grant_requires_consent,
os.id AS "session_id?",
us.id AS "user_session_id?",
us.created_at AS "user_session_created_at?",
u.id AS "user_id?",
u.username AS "user_username?",
usa.id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?",
ue.id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
og.oauth2_authorization_grant_id,
og.created_at AS oauth2_authorization_grant_created_at,
og.cancelled_at AS oauth2_authorization_grant_cancelled_at,
og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at,
og.exchanged_at AS oauth2_authorization_grant_exchanged_at,
og.scope AS oauth2_authorization_grant_scope,
og.state AS oauth2_authorization_grant_state,
og.redirect_uri AS oauth2_authorization_grant_redirect_uri,
og.response_mode AS oauth2_authorization_grant_response_mode,
og.nonce AS oauth2_authorization_grant_nonce,
og.max_age AS oauth2_authorization_grant_max_age,
og.oauth2_client_id AS oauth2_client_id,
og.authorization_code AS oauth2_authorization_grant_code,
og.response_type_code AS oauth2_authorization_grant_response_type_code,
og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token,
og.code_challenge AS oauth2_authorization_grant_code_challenge,
og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method,
og.requires_consent AS oauth2_authorization_grant_requires_consent,
os.oauth2_session_id AS "oauth2_session_id?",
us.user_session_id AS "user_session_id?",
us.created_at AS "user_session_created_at?",
u.user_id AS "user_id?",
u.username AS "user_username?",
usa.user_session_authentication_id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?",
ue.user_email_id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
FROM
oauth2_authorization_grants og
LEFT JOIN oauth2_sessions os
ON os.id = og.oauth2_session_id
USING (oauth2_session_id)
LEFT JOIN user_sessions us
ON us.id = os.user_session_id
USING (user_session_id)
LEFT JOIN users u
ON u.id = us.user_id
USING (user_id)
LEFT JOIN user_session_authentications usa
ON usa.session_id = us.id
USING (user_session_id)
LEFT JOIN user_emails ue
ON ue.id = u.primary_email_id
ON ue.user_email_id = u.primary_user_email_id
WHERE og.id = $1
WHERE og.oauth2_authorization_grant_id = $1
ORDER BY usa.created_at DESC
LIMIT 1
"#,
id,
Uuid::from(id),
)
.fetch_one(&mut *conn)
.await
@ -387,6 +423,7 @@ pub async fn get_grant_by_id(
Ok(grant)
}
#[tracing::instrument(skip_all, err(Debug))]
pub async fn lookup_grant_by_code(
conn: &mut PgConnection,
code: &str,
@ -396,50 +433,49 @@ pub async fn lookup_grant_by_code(
GrantLookup,
r#"
SELECT
og.id AS grant_id,
og.created_at AS grant_created_at,
og.cancelled_at AS grant_cancelled_at,
og.fulfilled_at AS grant_fulfilled_at,
og.exchanged_at AS grant_exchanged_at,
og.scope AS grant_scope,
og.state AS grant_state,
og.redirect_uri AS grant_redirect_uri,
og.response_mode AS grant_response_mode,
og.nonce AS grant_nonce,
og.max_age AS grant_max_age,
og.acr_values AS grant_acr_values,
og.oauth2_client_id AS oauth2_client_id,
og.code AS grant_code,
og.response_type_code AS grant_response_type_code,
og.response_type_id_token AS grant_response_type_id_token,
og.code_challenge AS grant_code_challenge,
og.code_challenge_method AS grant_code_challenge_method,
og.requires_consent AS grant_requires_consent,
os.id AS "session_id?",
us.id AS "user_session_id?",
us.created_at AS "user_session_created_at?",
u.id AS "user_id?",
u.username AS "user_username?",
usa.id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?",
ue.id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
og.oauth2_authorization_grant_id,
og.created_at AS oauth2_authorization_grant_created_at,
og.cancelled_at AS oauth2_authorization_grant_cancelled_at,
og.fulfilled_at AS oauth2_authorization_grant_fulfilled_at,
og.exchanged_at AS oauth2_authorization_grant_exchanged_at,
og.scope AS oauth2_authorization_grant_scope,
og.state AS oauth2_authorization_grant_state,
og.redirect_uri AS oauth2_authorization_grant_redirect_uri,
og.response_mode AS oauth2_authorization_grant_response_mode,
og.nonce AS oauth2_authorization_grant_nonce,
og.max_age AS oauth2_authorization_grant_max_age,
og.oauth2_client_id AS oauth2_client_id,
og.authorization_code AS oauth2_authorization_grant_code,
og.response_type_code AS oauth2_authorization_grant_response_type_code,
og.response_type_id_token AS oauth2_authorization_grant_response_type_id_token,
og.code_challenge AS oauth2_authorization_grant_code_challenge,
og.code_challenge_method AS oauth2_authorization_grant_code_challenge_method,
og.requires_consent AS oauth2_authorization_grant_requires_consent,
os.oauth2_session_id AS "oauth2_session_id?",
us.user_session_id AS "user_session_id?",
us.created_at AS "user_session_created_at?",
u.user_id AS "user_id?",
u.username AS "user_username?",
usa.user_session_authentication_id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?",
ue.user_email_id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
FROM
oauth2_authorization_grants og
LEFT JOIN oauth2_sessions os
ON os.id = og.oauth2_session_id
USING (oauth2_session_id)
LEFT JOIN user_sessions us
ON us.id = os.user_session_id
USING (user_session_id)
LEFT JOIN users u
ON u.id = us.user_id
USING (user_id)
LEFT JOIN user_session_authentications usa
ON usa.session_id = us.id
USING (user_session_id)
LEFT JOIN user_emails ue
ON ue.id = u.primary_email_id
ON ue.user_email_id = u.primary_user_email_id
WHERE og.code = $1
WHERE og.authorization_code = $1
ORDER BY usa.created_at DESC
LIMIT 1
@ -455,41 +491,69 @@ pub async fn lookup_grant_by_code(
Ok(grant)
}
#[tracing::instrument(
skip_all,
fields(
grant.id = %grant.data,
client.id = %grant.client.data,
session.id,
user_session.id = %browser_session.data,
user.id = %browser_session.user.data,
),
err(Debug),
)]
pub async fn derive_session(
executor: impl PgExecutor<'_>,
grant: &AuthorizationGrant<PostgresqlBackend>,
browser_session: BrowserSession<PostgresqlBackend>,
) -> anyhow::Result<Session<PostgresqlBackend>> {
let res = sqlx::query_as!(
IdAndCreationTime,
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
tracing::Span::current().record("session.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_sessions
(user_session_id, oauth2_client_id, scope)
(oauth2_session_id, user_session_id, oauth2_client_id, scope, created_at)
SELECT
$1,
$2,
og.oauth2_client_id,
og.scope
og.scope,
$3
FROM
oauth2_authorization_grants og
WHERE
og.id = $2
RETURNING id, created_at
og.oauth2_authorization_grant_id = $4
"#,
browser_session.data,
grant.data,
Uuid::from(id),
Uuid::from(browser_session.data),
created_at,
Uuid::from(grant.data),
)
.fetch_one(executor)
.execute(executor)
.await
.context("could not insert oauth2 session")?;
Ok(Session {
data: res.id,
data: id,
browser_session,
client: grant.client.clone(),
scope: grant.scope.clone(),
})
}
#[tracing::instrument(
skip_all,
fields(
grant.id = %grant.data,
client.id = %grant.client.data,
session.id = %session.data,
user_session.id = %session.browser_session.data,
user.id = %session.browser_session.user.data,
),
err(Debug),
)]
pub async fn fulfill_grant(
executor: impl PgExecutor<'_>,
mut grant: AuthorizationGrant<PostgresqlBackend>,
@ -499,15 +563,16 @@ pub async fn fulfill_grant(
r#"
UPDATE oauth2_authorization_grants AS og
SET
oauth2_session_id = os.id,
oauth2_session_id = os.oauth2_session_id,
fulfilled_at = os.created_at
FROM oauth2_sessions os
WHERE
og.id = $1 AND os.id = $2
og.oauth2_authorization_grant_id = $1
AND os.oauth2_session_id = $2
RETURNING fulfilled_at AS "fulfilled_at!: DateTime<Utc>"
"#,
grant.data,
session.data,
Uuid::from(grant.data),
Uuid::from(session.data),
)
.fetch_one(executor)
.await
@ -518,6 +583,14 @@ pub async fn fulfill_grant(
Ok(grant)
}
#[tracing::instrument(
skip_all,
fields(
grant.id = %grant.data,
client.id = %grant.client.data,
),
err(Debug),
)]
pub async fn give_consent_to_grant(
executor: impl PgExecutor<'_>,
mut grant: AuthorizationGrant<PostgresqlBackend>,
@ -528,9 +601,9 @@ pub async fn give_consent_to_grant(
SET
requires_consent = 'f'
WHERE
og.id = $1
og.oauth2_authorization_grant_id = $1
"#,
grant.data,
Uuid::from(grant.data),
)
.execute(executor)
.await?;
@ -540,22 +613,29 @@ pub async fn give_consent_to_grant(
Ok(grant)
}
#[tracing::instrument(
skip_all,
fields(
grant.id = %grant.data,
client.id = %grant.client.data,
),
err(Debug),
)]
pub async fn exchange_grant(
executor: impl PgExecutor<'_>,
mut grant: AuthorizationGrant<PostgresqlBackend>,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
let exchanged_at = sqlx::query_scalar!(
let exchanged_at = Utc::now();
sqlx::query!(
r#"
UPDATE oauth2_authorization_grants
SET
exchanged_at = NOW()
WHERE
id = $1
RETURNING exchanged_at AS "exchanged_at!: DateTime<Utc>"
SET exchanged_at = $2
WHERE oauth2_authorization_grant_id = $1
"#,
grant.data,
Uuid::from(grant.data),
exchanged_at,
)
.fetch_one(executor)
.execute(executor)
.await
.context("could not mark grant as exchanged")?;

View File

@ -20,23 +20,25 @@ use mas_iana::{
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
};
use mas_jose::jwk::PublicJsonWebKeySet;
use oauth2_types::{requests::GrantType, response_type::ResponseType};
use oauth2_types::requests::GrantType;
use sqlx::{PgConnection, PgExecutor};
use thiserror::Error;
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::PostgresqlBackend;
// XXX: response_types & contacts
#[derive(Debug)]
pub struct OAuth2ClientLookup {
id: i64,
client_id: String,
oauth2_client_id: Uuid,
encrypted_client_secret: Option<String>,
redirect_uris: Vec<String>,
response_types: Vec<String>,
// response_types: Vec<String>,
grant_type_authorization_code: bool,
grant_type_refresh_token: bool,
contacts: Vec<String>,
// contacts: Vec<String>,
client_name: Option<String>,
logo_uri: Option<String>,
client_uri: Option<String>,
@ -53,6 +55,9 @@ pub struct OAuth2ClientLookup {
#[derive(Debug, Error)]
pub enum ClientFetchError {
#[error("invalid client ID")]
InvalidClientId(#[from] ulid::DecodeError),
#[error("malformed jwks column")]
MalformedJwks(#[source] serde_json::Error),
@ -78,7 +83,10 @@ pub enum ClientFetchError {
impl ClientFetchError {
#[must_use]
pub fn not_found(&self) -> bool {
matches!(self, Self::Database(sqlx::Error::RowNotFound))
matches!(
self,
Self::Database(sqlx::Error::RowNotFound) | Self::InvalidClientId(_)
)
}
}
@ -94,12 +102,19 @@ impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
source,
})?;
let response_types = vec![
OAuthAuthorizationEndpointResponseType::Code,
OAuthAuthorizationEndpointResponseType::IdToken,
OAuthAuthorizationEndpointResponseType::None,
];
/* XXX
let response_types: Result<Vec<OAuthAuthorizationEndpointResponseType>, _> =
self.response_types.iter().map(|s| s.parse()).collect();
let response_types = response_types.map_err(|source| ClientFetchError::ParseField {
field: "response_types",
source,
})?;
*/
let mut grant_types = Vec::new();
if self.grant_type_authorization_code {
@ -210,13 +225,14 @@ impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
};
Ok(Client {
data: self.id,
client_id: self.client_id,
data: self.oauth2_client_id.into(),
client_id: self.oauth2_client_id.to_string(),
encrypted_client_secret: self.encrypted_client_secret,
redirect_uris,
response_types,
grant_types,
contacts: self.contacts,
// contacts: self.contacts,
contacts: vec![],
client_name: self.client_name,
logo_uri,
client_uri,
@ -234,20 +250,21 @@ impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
pub async fn lookup_client(
executor: impl PgExecutor<'_>,
id: i64,
id: Ulid,
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT
c.id,
c.client_id,
c.oauth2_client_id,
c.encrypted_client_secret,
ARRAY(SELECT redirect_uri FROM oauth2_client_redirect_uris r WHERE r.oauth2_client_id = c.id) AS "redirect_uris!",
c.response_types,
ARRAY(
SELECT redirect_uri
FROM oauth2_client_redirect_uris r
WHERE r.oauth2_client_id = c.oauth2_client_id
) AS "redirect_uris!",
c.grant_type_authorization_code,
c.grant_type_refresh_token,
c.contacts,
c.client_name,
c.logo_uri,
c.client_uri,
@ -262,9 +279,9 @@ pub async fn lookup_client(
c.initiate_login_uri
FROM oauth2_clients c
WHERE c.id = $1
WHERE c.oauth2_client_id = $1
"#,
id,
Uuid::from(id),
)
.fetch_one(executor)
.await?;
@ -278,53 +295,18 @@ pub async fn lookup_client_by_client_id(
executor: impl PgExecutor<'_>,
client_id: &str,
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT
c.id,
c.client_id,
c.encrypted_client_secret,
ARRAY(SELECT redirect_uri FROM oauth2_client_redirect_uris r WHERE r.oauth2_client_id = c.id) AS "redirect_uris!",
c.response_types,
c.grant_type_authorization_code,
c.grant_type_refresh_token,
c.contacts,
c.client_name,
c.logo_uri,
c.client_uri,
c.policy_uri,
c.tos_uri,
c.jwks_uri,
c.jwks,
c.id_token_signed_response_alg,
c.userinfo_signed_response_alg,
c.token_endpoint_auth_method,
c.token_endpoint_auth_signing_alg,
c.initiate_login_uri
FROM oauth2_clients c
WHERE c.client_id = $1
"#,
client_id,
)
.fetch_one(executor)
.await?;
let client = res.try_into()?;
Ok(client)
let id: Ulid = client_id.parse()?;
lookup_client(executor, id).await
}
#[allow(clippy::too_many_arguments)]
pub async fn insert_client(
conn: &mut PgConnection,
client_id: &str,
client_id: Ulid,
redirect_uris: &[Url],
encrypted_client_secret: Option<&str>,
response_types: &[ResponseType],
grant_types: &[GrantType],
contacts: &[String],
_contacts: &[String],
client_name: Option<&str>,
logo_uri: Option<&Url>,
client_uri: Option<&Url>,
@ -338,7 +320,6 @@ pub async fn insert_client(
token_endpoint_auth_signing_alg: Option<&JsonWebSignatureAlg>,
initiate_login_uri: Option<&Url>,
) -> Result<(), sqlx::Error> {
let response_types: Vec<String> = response_types.iter().map(ToString::to_string).collect();
let grant_type_authorization_code = grant_types.contains(&GrantType::AuthorizationCode);
let grant_type_refresh_token = grant_types.contains(&GrantType::RefreshToken);
let logo_uri = logo_uri.map(Url::as_str);
@ -353,15 +334,13 @@ pub async fn insert_client(
let token_endpoint_auth_signing_alg = token_endpoint_auth_signing_alg.map(ToString::to_string);
let initiate_login_uri = initiate_login_uri.map(Url::as_str);
let id = sqlx::query_scalar!(
sqlx::query!(
r#"
INSERT INTO oauth2_clients
(client_id,
(oauth2_client_id,
encrypted_client_secret,
response_types,
grant_type_authorization_code,
grant_type_refresh_token,
contacts,
client_name,
logo_uri,
client_uri,
@ -375,15 +354,12 @@ pub async fn insert_client(
token_endpoint_auth_signing_alg,
initiate_login_uri)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18)
RETURNING id
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
"#,
client_id,
Uuid::from(client_id),
encrypted_client_secret,
&response_types,
grant_type_authorization_code,
grant_type_refresh_token,
contacts,
client_name,
logo_uri,
client_uri,
@ -397,96 +373,87 @@ pub async fn insert_client(
token_endpoint_auth_signing_alg,
initiate_login_uri,
)
.fetch_one(&mut *conn)
.await?;
let redirect_uris: Vec<String> = redirect_uris.iter().map(ToString::to_string).collect();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris (oauth2_client_id, redirect_uri)
SELECT $1, uri FROM UNNEST($2::text[]) uri
"#,
id,
&redirect_uris,
)
.execute(&mut *conn)
.await?;
for redirect_uri in redirect_uris {
let id = Ulid::new();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
Uuid::from(client_id),
redirect_uri.as_str(),
)
.execute(&mut *conn)
.await?;
}
Ok(())
}
pub async fn insert_client_from_config(
conn: &mut PgConnection,
client_id: &str,
client_id: Ulid,
client_auth_method: OAuthClientAuthenticationMethod,
encrypted_client_secret: Option<&str>,
jwks: Option<&PublicJsonWebKeySet>,
jwks_uri: Option<&Url>,
redirect_uris: &[Url],
) -> anyhow::Result<()> {
let response_types = vec![
OAuthAuthorizationEndpointResponseType::Code.to_string(),
OAuthAuthorizationEndpointResponseType::CodeIdToken.to_string(),
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.to_string(),
OAuthAuthorizationEndpointResponseType::CodeToken.to_string(),
OAuthAuthorizationEndpointResponseType::IdToken.to_string(),
OAuthAuthorizationEndpointResponseType::IdTokenToken.to_string(),
OAuthAuthorizationEndpointResponseType::None.to_string(),
OAuthAuthorizationEndpointResponseType::Token.to_string(),
];
let jwks = jwks.map(serde_json::to_value).transpose()?;
let jwks_uri = jwks_uri.map(Url::as_str);
let client_auth_method = client_auth_method.to_string();
let id = sqlx::query_scalar!(
sqlx::query!(
r#"
INSERT INTO oauth2_clients
(client_id,
(oauth2_client_id,
encrypted_client_secret,
response_types,
grant_type_authorization_code,
grant_type_refresh_token,
token_endpoint_auth_method,
jwks,
jwks_uri,
contacts)
jwks_uri)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, '{}')
RETURNING id
($1, $2, $3, $4, $5, $6, $7)
"#,
client_id,
Uuid::from(client_id),
encrypted_client_secret,
&response_types,
true,
true,
client_auth_method,
jwks,
jwks_uri,
)
.fetch_one(&mut *conn)
.await?;
let redirect_uris: Vec<String> = redirect_uris.iter().map(ToString::to_string).collect();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris (oauth2_client_id, redirect_uri)
SELECT $1, uri FROM UNNEST($2::text[]) uri
"#,
id,
&redirect_uris,
)
.execute(&mut *conn)
.await?;
for redirect_uri in redirect_uris {
let id = Ulid::new();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
Uuid::from(client_id),
redirect_uri.as_str(),
)
.execute(&mut *conn)
.await?;
}
Ok(())
}
pub async fn truncate_clients(executor: impl PgExecutor<'_>) -> anyhow::Result<()> {
sqlx::query!("TRUNCATE oauth2_client_redirect_uris, oauth2_clients RESTART IDENTITY CASCADE")
sqlx::query!("TRUNCATE oauth2_client_redirect_uris, oauth2_clients CASCADE")
.execute(executor)
.await?;
Ok(())

View File

@ -14,9 +14,12 @@
use std::str::FromStr;
use chrono::Utc;
use mas_data_model::{Client, User};
use oauth2_types::scope::{Scope, ScopeToken};
use sqlx::PgExecutor;
use ulid::Ulid;
use uuid::Uuid;
use crate::PostgresqlBackend;
@ -31,8 +34,8 @@ pub async fn fetch_client_consent(
FROM oauth2_consents
WHERE user_id = $1 AND oauth2_client_id = $2
"#,
user.data,
client.data,
Uuid::from(user.data),
Uuid::from(client.data),
)
.fetch_all(executor)
.await?;
@ -51,17 +54,29 @@ pub async fn insert_client_consent(
client: &Client<PostgresqlBackend>,
scope: &Scope,
) -> anyhow::Result<()> {
let tokens: Vec<String> = scope.iter().map(ToString::to_string).collect();
let now = Utc::now();
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
.iter()
.map(|token| {
(
token.to_string(),
Uuid::from(Ulid::from_datetime(now.into())),
)
})
.unzip();
sqlx::query!(
r#"
INSERT INTO oauth2_consents (user_id, oauth2_client_id, scope_token)
SELECT $1, $2, scope_token FROM UNNEST($3::text[]) scope_token
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET updated_at = NOW()
INSERT INTO oauth2_consents
(oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
"#,
user.data,
client.data,
&ids,
Uuid::from(user.data),
Uuid::from(client.data),
&tokens,
now,
)
.execute(executor)
.await?;

View File

@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::Utc;
use mas_data_model::Session;
use sqlx::PgExecutor;
use uuid::Uuid;
use crate::PostgresqlBackend;
@ -27,13 +29,15 @@ pub async fn end_oauth_session(
executor: impl PgExecutor<'_>,
session: Session<PostgresqlBackend>,
) -> anyhow::Result<()> {
let finished_at = Utc::now();
let res = sqlx::query!(
r#"
UPDATE oauth2_sessions
SET ended_at = NOW()
WHERE id = $1
SET finished_at = $2
WHERE oauth2_session_id = $1
"#,
session.data,
Uuid::from(session.data),
finished_at,
)
.execute(executor)
.await?;

View File

@ -13,66 +13,71 @@
// limitations under the License.
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use chrono::{DateTime, Utc};
use mas_data_model::{
AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail,
};
use sqlx::{PgConnection, PgExecutor};
use thiserror::Error;
use ulid::Ulid;
use uuid::Uuid;
use super::client::{lookup_client, ClientFetchError};
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
use crate::{DatabaseInconsistencyError, PostgresqlBackend};
pub async fn add_refresh_token(
executor: impl PgExecutor<'_>,
session: &Session<PostgresqlBackend>,
access_token: AccessToken<PostgresqlBackend>,
token: &str,
refresh_token: String,
) -> anyhow::Result<RefreshToken<PostgresqlBackend>> {
let res = sqlx::query_as!(
IdAndCreationTime,
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
sqlx::query!(
r#"
INSERT INTO oauth2_refresh_tokens
(oauth2_session_id, oauth2_access_token_id, token)
(oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
refresh_token, created_at)
VALUES
($1, $2, $3)
RETURNING
id, created_at
($1, $2, $3, $4, $5)
"#,
session.data,
access_token.data,
token,
Uuid::from(id),
Uuid::from(session.data),
Uuid::from(access_token.data),
refresh_token,
created_at,
)
.fetch_one(executor)
.execute(executor)
.await
.context("could not insert oauth2 refresh token")?;
Ok(RefreshToken {
data: res.id,
token: token.to_owned(),
data: id,
refresh_token,
access_token: Some(access_token),
created_at: res.created_at,
created_at,
})
}
struct OAuth2RefreshTokenLookup {
refresh_token_id: i64,
refresh_token: String,
refresh_token_created_at: DateTime<Utc>,
access_token_id: Option<i64>,
access_token: Option<String>,
access_token_expires_after: Option<i32>,
access_token_created_at: Option<DateTime<Utc>>,
session_id: i64,
oauth2_client_id: i64,
scope: String,
user_session_id: i64,
oauth2_refresh_token_id: Uuid,
oauth2_refresh_token: String,
oauth2_refresh_token_created_at: DateTime<Utc>,
oauth2_access_token_id: Option<Uuid>,
oauth2_access_token: Option<String>,
oauth2_access_token_created_at: Option<DateTime<Utc>>,
oauth2_access_token_expires_at: Option<DateTime<Utc>>,
oauth2_session_id: Uuid,
oauth2_client_id: Uuid,
oauth2_session_scope: String,
user_session_id: Uuid,
user_session_created_at: DateTime<Utc>,
user_id: i64,
user_id: Uuid,
user_username: String,
user_session_last_authentication_id: Option<i64>,
user_session_last_authentication_id: Option<Uuid>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
user_email_id: Option<i64>,
user_email_id: Option<Uuid>,
user_email: Option<String>,
user_email_created_at: Option<DateTime<Utc>>,
user_email_confirmed_at: Option<DateTime<Utc>>,
@ -103,44 +108,45 @@ pub async fn lookup_active_refresh_token(
OAuth2RefreshTokenLookup,
r#"
SELECT
rt.id AS refresh_token_id,
rt.token AS refresh_token,
rt.created_at AS refresh_token_created_at,
at.id AS "access_token_id?",
at.token AS "access_token?",
at.expires_after AS "access_token_expires_after?",
at.created_at AS "access_token_created_at?",
os.id AS "session_id!",
os.oauth2_client_id AS "oauth2_client_id!",
os.scope AS "scope!",
us.id AS "user_session_id!",
us.created_at AS "user_session_created_at!",
u.id AS "user_id!",
u.username AS "user_username!",
usa.id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?",
ue.id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
rt.oauth2_refresh_token_id,
rt.refresh_token AS oauth2_refresh_token,
rt.created_at AS oauth2_refresh_token_created_at,
at.oauth2_access_token_id AS "oauth2_access_token_id?",
at.access_token AS "oauth2_access_token?",
at.created_at AS "oauth2_access_token_created_at?",
at.expires_at AS "oauth2_access_token_expires_at?",
os.oauth2_session_id AS "oauth2_session_id!",
os.oauth2_client_id AS "oauth2_client_id!",
os.scope AS "oauth2_session_scope!",
us.user_session_id AS "user_session_id!",
us.created_at AS "user_session_created_at!",
u.user_id AS "user_id!",
u.username AS "user_username!",
usa.user_session_authentication_id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?",
ue.user_email_id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
FROM oauth2_refresh_tokens rt
LEFT JOIN oauth2_access_tokens at
ON at.id = rt.oauth2_access_token_id
INNER JOIN oauth2_sessions os
ON os.id = rt.oauth2_session_id
USING (oauth2_session_id)
LEFT JOIN oauth2_access_tokens at
USING (oauth2_access_token_id)
INNER JOIN user_sessions us
ON us.id = os.user_session_id
USING (user_session_id)
INNER JOIN users u
ON u.id = us.user_id
USING (user_id)
LEFT JOIN user_session_authentications usa
ON usa.session_id = us.id
USING (user_session_id)
LEFT JOIN user_emails ue
ON ue.id = u.primary_email_id
ON ue.user_email_id = u.primary_user_email_id
WHERE rt.token = $1
AND rt.next_token_id IS NULL
AND us.active
AND os.ended_at IS NULL
WHERE rt.refresh_token = $1
AND rt.consumed_at IS NULL
AND rt.revoked_at IS NULL
AND us.finished_at IS NULL
AND os.finished_at IS NULL
ORDER BY usa.created_at DESC
LIMIT 1
@ -151,30 +157,31 @@ pub async fn lookup_active_refresh_token(
.await?;
let access_token = match (
res.access_token_id,
res.access_token,
res.access_token_created_at,
res.access_token_expires_after,
res.oauth2_access_token_id,
res.oauth2_access_token,
res.oauth2_access_token_created_at,
res.oauth2_access_token_expires_at,
) {
(None, None, None, None) => None,
(Some(id), Some(token), Some(created_at), Some(expires_after)) => Some(AccessToken {
data: id,
jti: format!("{}", id),
token,
(Some(id), Some(access_token), Some(created_at), Some(expires_at)) => Some(AccessToken {
data: id.into(),
// XXX: are we doing that everywhere?
jti: Ulid::from(id).to_string(),
access_token,
created_at,
expires_after: Duration::seconds(expires_after.into()),
expires_at,
}),
_ => return Err(DatabaseInconsistencyError.into()),
};
let refresh_token = RefreshToken {
data: res.refresh_token_id,
token: res.refresh_token,
created_at: res.refresh_token_created_at,
data: res.oauth2_refresh_token_id.into(),
refresh_token: res.oauth2_refresh_token,
created_at: res.oauth2_refresh_token_created_at,
access_token,
};
let client = lookup_client(&mut *conn, res.oauth2_client_id).await?;
let client = lookup_client(&mut *conn, res.oauth2_client_id.into()).await?;
let primary_email = match (
res.user_email_id,
@ -183,7 +190,7 @@ pub async fn lookup_active_refresh_token(
res.user_email_confirmed_at,
) {
(Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail {
data: id,
data: id.into(),
email,
created_at,
confirmed_at,
@ -192,10 +199,11 @@ pub async fn lookup_active_refresh_token(
_ => return Err(DatabaseInconsistencyError.into()),
};
let id = Ulid::from(res.user_id);
let user = User {
data: res.user_id,
data: id,
username: res.user_username,
sub: format!("fake-sub-{}", res.user_id),
sub: id.to_string(),
primary_email,
};
@ -205,23 +213,26 @@ pub async fn lookup_active_refresh_token(
) {
(None, None) => None,
(Some(id), Some(created_at)) => Some(Authentication {
data: id,
data: id.into(),
created_at,
}),
_ => return Err(DatabaseInconsistencyError.into()),
};
let browser_session = BrowserSession {
data: res.user_session_id,
data: res.user_session_id.into(),
created_at: res.user_session_created_at,
user,
last_authentication,
};
let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?;
let scope = res
.oauth2_session_scope
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
let session = Session {
data: res.session_id,
data: res.oauth2_session_id.into(),
client,
browser_session,
scope,
@ -230,19 +241,19 @@ pub async fn lookup_active_refresh_token(
Ok((refresh_token, session))
}
pub async fn replace_refresh_token(
pub async fn consume_refresh_token(
executor: impl PgExecutor<'_>,
refresh_token: &RefreshToken<PostgresqlBackend>,
next_refresh_token: &RefreshToken<PostgresqlBackend>,
) -> anyhow::Result<()> {
let consumed_at = Utc::now();
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET next_token_id = $2
WHERE id = $1
SET consumed_at = $2
WHERE oauth2_refresh_token_id = $1
"#,
refresh_token.data,
next_refresh_token.data
Uuid::from(refresh_token.data),
consumed_at,
)
.execute(executor)
.await

View File

@ -22,20 +22,21 @@ use mas_data_model::{
UserEmailVerificationState,
};
use password_hash::{PasswordHash, PasswordHasher, SaltString};
use rand::rngs::OsRng;
use sqlx::{postgres::types::PgInterval, Acquire, PgExecutor, Postgres, Transaction};
use rand::thread_rng;
use sqlx::{Acquire, PgExecutor, Postgres, Transaction};
use thiserror::Error;
use tokio::task;
use tracing::{info_span, Instrument};
use ulid::Ulid;
use uuid::Uuid;
use super::{DatabaseInconsistencyError, PostgresqlBackend};
use crate::IdAndCreationTime;
#[derive(Debug, Clone)]
struct UserLookup {
user_id: i64,
user_id: Uuid,
user_username: String,
user_email_id: Option<i64>,
user_email_id: Option<Uuid>,
user_email: Option<String>,
user_email_created_at: Option<DateTime<Utc>>,
user_email_confirmed_at: Option<DateTime<Utc>>,
@ -114,13 +115,13 @@ impl ActiveSessionLookupError {
}
struct SessionLookup {
id: i64,
user_id: i64,
user_session_id: Uuid,
user_id: Uuid,
username: String,
created_at: DateTime<Utc>,
last_authentication_id: Option<i64>,
last_authentication_id: Option<Uuid>,
last_authd_at: Option<DateTime<Utc>>,
user_email_id: Option<i64>,
user_email_id: Option<Uuid>,
user_email: Option<String>,
user_email_created_at: Option<DateTime<Utc>>,
user_email_confirmed_at: Option<DateTime<Utc>>,
@ -137,7 +138,7 @@ impl TryInto<BrowserSession<PostgresqlBackend>> for SessionLookup {
self.user_email_confirmed_at,
) {
(Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail {
data: id,
data: id.into(),
email,
created_at,
confirmed_at,
@ -146,16 +147,17 @@ impl TryInto<BrowserSession<PostgresqlBackend>> for SessionLookup {
_ => return Err(DatabaseInconsistencyError),
};
let id = Ulid::from(self.user_id);
let user = User {
data: self.user_id,
data: id,
username: self.username,
sub: format!("fake-sub-{}", self.user_id),
sub: id.to_string(),
primary_email,
};
let last_authentication = match (self.last_authentication_id, self.last_authd_at) {
(Some(id), Some(created_at)) => Some(Authentication {
data: id,
data: id.into(),
created_at,
}),
(None, None) => None,
@ -163,7 +165,7 @@ impl TryInto<BrowserSession<PostgresqlBackend>> for SessionLookup {
};
Ok(BrowserSession {
data: self.id,
data: self.user_session_id.into(),
user,
created_at: self.created_at,
last_authentication,
@ -171,37 +173,37 @@ impl TryInto<BrowserSession<PostgresqlBackend>> for SessionLookup {
}
}
#[tracing::instrument(skip_all, fields(session.id = id))]
#[tracing::instrument(skip_all, fields(session.id = %id))]
pub async fn lookup_active_session(
executor: impl PgExecutor<'_>,
id: i64,
id: Ulid,
) -> Result<BrowserSession<PostgresqlBackend>, ActiveSessionLookupError> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT
s.id,
u.id AS user_id,
s.user_session_id,
u.user_id,
u.username,
s.created_at,
a.id AS "last_authentication_id?",
a.created_at AS "last_authd_at?",
ue.id AS "user_email_id?",
a.user_session_authentication_id AS "last_authentication_id?",
a.created_at AS "last_authd_at?",
ue.user_email_id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
FROM user_sessions s
INNER JOIN users u
ON s.user_id = u.id
USING (user_id)
LEFT JOIN user_session_authentications a
ON a.session_id = s.id
USING (user_session_id)
LEFT JOIN user_emails ue
ON ue.id = u.primary_email_id
WHERE s.id = $1 AND s.active
ON ue.user_email_id = u.primary_user_email_id
WHERE s.user_session_id = $1 AND s.finished_at IS NULL
ORDER BY a.created_at DESC
LIMIT 1
"#,
id,
Uuid::from(id),
)
.fetch_one(executor)
.await?
@ -210,35 +212,37 @@ pub async fn lookup_active_session(
Ok(res)
}
#[tracing::instrument(skip_all, fields(user.id = user.data))]
#[tracing::instrument(skip_all, fields(user.id = %user.data))]
pub async fn start_session(
executor: impl PgExecutor<'_>,
user: User<PostgresqlBackend>,
) -> anyhow::Result<BrowserSession<PostgresqlBackend>> {
let res = sqlx::query_as!(
IdAndCreationTime,
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
sqlx::query!(
r#"
INSERT INTO user_sessions (user_id)
VALUES ($1)
RETURNING id, created_at
INSERT INTO user_sessions (user_session_id, user_id, created_at)
VALUES ($1, $2, $3)
"#,
user.data,
Uuid::from(id),
Uuid::from(user.data),
created_at,
)
.fetch_one(executor)
.execute(executor)
.await
.context("could not create session")?;
let session = BrowserSession {
data: res.id,
data: id,
user,
created_at: res.created_at,
created_at,
last_authentication: None,
};
Ok(session)
}
#[tracing::instrument(skip_all, fields(user.id = user.data))]
#[tracing::instrument(skip_all, fields(user.id = %user.data))]
pub async fn count_active_sessions(
executor: impl PgExecutor<'_>,
user: &User<PostgresqlBackend>,
@ -247,9 +251,9 @@ pub async fn count_active_sessions(
r#"
SELECT COUNT(*) as "count!"
FROM user_sessions s
WHERE s.user_id = $1 AND s.active
WHERE s.user_id = $1 AND s.finished_at IS NULL
"#,
user.data,
Uuid::from(user.data),
)
.fetch_one(executor)
.await?
@ -273,7 +277,7 @@ pub enum AuthenticationError {
Internal(#[from] tokio::task::JoinError),
}
#[tracing::instrument(skip_all, fields(session.id = session.data, user.id = session.user.data))]
#[tracing::instrument(skip_all, fields(session.id = %session.data, user.id = %session.user.data))]
pub async fn authenticate_session(
txn: &mut Transaction<'_, Postgres>,
session: &mut BrowserSession<PostgresqlBackend>,
@ -288,7 +292,7 @@ pub async fn authenticate_session(
ORDER BY up.created_at DESC
LIMIT 1
"#,
session.user.data,
Uuid::from(session.user.data),
)
.fetch_one(txn.borrow_mut())
.instrument(tracing::info_span!("Lookup hashed password"))
@ -309,44 +313,50 @@ pub async fn authenticate_session(
.await??;
// That went well, let's insert the auth info
let res = sqlx::query_as!(
IdAndCreationTime,
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
sqlx::query!(
r#"
INSERT INTO user_session_authentications (session_id)
VALUES ($1)
RETURNING id, created_at
INSERT INTO user_session_authentications
(user_session_authentication_id, user_session_id, created_at)
VALUES ($1, $2, $3)
"#,
session.data,
Uuid::from(id),
Uuid::from(session.data),
created_at,
)
.fetch_one(txn.borrow_mut())
.execute(txn.borrow_mut())
.instrument(tracing::info_span!("Save authentication"))
.await
.map_err(AuthenticationError::Save)?;
session.last_authentication = Some(Authentication {
data: res.id,
created_at: res.created_at,
data: id,
created_at,
});
Ok(())
}
#[tracing::instrument(skip(txn, phf, password))]
#[tracing::instrument(skip(txn, phf, password), err)]
pub async fn register_user(
txn: &mut Transaction<'_, Postgres>,
phf: impl PasswordHasher,
username: &str,
password: &str,
) -> anyhow::Result<User<PostgresqlBackend>> {
let id: i64 = sqlx::query_scalar!(
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
sqlx::query!(
r#"
INSERT INTO users (username)
VALUES ($1)
RETURNING id
INSERT INTO users (user_id, username, created_at)
VALUES ($1, $2, $3)
"#,
Uuid::from(id),
username,
created_at,
)
.fetch_one(txn.borrow_mut())
.execute(txn.borrow_mut())
.instrument(info_span!("Register user"))
.await
.context("could not insert user")?;
@ -354,7 +364,7 @@ pub async fn register_user(
let user = User {
data: id,
username: username.to_owned(),
sub: format!("fake-sub-{}", id),
sub: id.to_string(),
primary_email: None,
};
@ -363,23 +373,28 @@ pub async fn register_user(
Ok(user)
}
#[tracing::instrument(skip_all, fields(user.id = user.data))]
#[tracing::instrument(skip_all, fields(user.id = %user.data))]
pub async fn set_password(
executor: impl PgExecutor<'_>,
phf: impl PasswordHasher,
user: &User<PostgresqlBackend>,
password: &str,
) -> anyhow::Result<()> {
let salt = SaltString::generate(&mut OsRng);
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let salt = SaltString::generate(thread_rng());
let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?;
sqlx::query_scalar!(
r#"
INSERT INTO user_passwords (user_id, hashed_password)
VALUES ($1, $2)
INSERT INTO user_passwords (user_password_id, user_id, hashed_password, created_at)
VALUES ($1, $2, $3, $4)
"#,
user.data,
Uuid::from(id),
Uuid::from(user.data),
hashed_password.to_string(),
created_at,
)
.execute(executor)
.instrument(info_span!("Save user credentials"))
@ -389,14 +404,20 @@ pub async fn set_password(
Ok(())
}
#[tracing::instrument(skip_all, fields(session.id = session.data))]
#[tracing::instrument(skip_all, fields(session.id = %session.data))]
pub async fn end_session(
executor: impl PgExecutor<'_>,
session: &BrowserSession<PostgresqlBackend>,
) -> anyhow::Result<()> {
let now = Utc::now();
let res = sqlx::query!(
"UPDATE user_sessions SET active = FALSE WHERE id = $1",
session.data,
r#"
UPDATE user_sessions
SET finished_at = $1
WHERE user_session_id = $2
"#,
now,
Uuid::from(session.data),
)
.execute(executor)
.instrument(info_span!("End session"))
@ -433,16 +454,16 @@ pub async fn lookup_user_by_username(
UserLookup,
r#"
SELECT
u.id AS user_id,
u.username AS user_username,
ue.id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
u.user_id,
u.username AS user_username,
ue.user_email_id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
FROM users u
LEFT JOIN user_emails ue
ON ue.id = u.primary_email_id
USING (user_id)
WHERE u.username = $1
"#,
@ -459,7 +480,7 @@ pub async fn lookup_user_by_username(
res.user_email_confirmed_at,
) {
(Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail {
data: id,
data: id.into(),
email,
created_at,
confirmed_at,
@ -468,10 +489,11 @@ pub async fn lookup_user_by_username(
_ => return Err(DatabaseInconsistencyError.into()),
};
let id = Ulid::from(res.user_id);
Ok(User {
data: res.user_id,
data: id,
username: res.user_username,
sub: format!("fake-sub-{}", res.user_id),
sub: id.to_string(),
primary_email,
})
}
@ -494,7 +516,7 @@ pub async fn username_exists(
#[derive(Debug, Clone)]
struct UserEmailLookup {
user_email_id: i64,
user_email_id: Uuid,
user_email: String,
user_email_created_at: DateTime<Utc>,
user_email_confirmed_at: Option<DateTime<Utc>>,
@ -503,7 +525,7 @@ struct UserEmailLookup {
impl From<UserEmailLookup> for UserEmail<PostgresqlBackend> {
fn from(e: UserEmailLookup) -> UserEmail<PostgresqlBackend> {
UserEmail {
data: e.user_email_id,
data: e.user_email_id.into(),
email: e.user_email,
created_at: e.user_email_created_at,
confirmed_at: e.user_email_confirmed_at,
@ -511,7 +533,7 @@ impl From<UserEmailLookup> for UserEmail<PostgresqlBackend> {
}
}
#[tracing::instrument(skip_all, fields(user.id = user.data, %user.username))]
#[tracing::instrument(skip_all, fields(user.id = %user.data, %user.username))]
pub async fn get_user_emails(
executor: impl PgExecutor<'_>,
user: &User<PostgresqlBackend>,
@ -520,7 +542,7 @@ pub async fn get_user_emails(
UserEmailLookup,
r#"
SELECT
ue.id AS "user_email_id",
ue.user_email_id,
ue.email AS "user_email",
ue.created_at AS "user_email_created_at",
ue.confirmed_at AS "user_email_confirmed_at"
@ -530,7 +552,7 @@ pub async fn get_user_emails(
ORDER BY ue.email ASC
"#,
user.data,
Uuid::from(user.data),
)
.fetch_all(executor)
.instrument(info_span!("Fetch user emails"))
@ -539,27 +561,27 @@ pub async fn get_user_emails(
Ok(res.into_iter().map(Into::into).collect())
}
#[tracing::instrument(skip_all, fields(user.id = user.data, %user.username, email.id = id))]
#[tracing::instrument(skip_all, fields(user.id = %user.data, %user.username, email.id = %id))]
pub async fn get_user_email(
executor: impl PgExecutor<'_>,
user: &User<PostgresqlBackend>,
id: i64,
id: Ulid,
) -> Result<UserEmail<PostgresqlBackend>, anyhow::Error> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
SELECT
ue.id AS "user_email_id",
ue.user_email_id,
ue.email AS "user_email",
ue.created_at AS "user_email_created_at",
ue.confirmed_at AS "user_email_confirmed_at"
FROM user_emails ue
WHERE ue.user_id = $1
AND ue.id = $2
AND ue.user_email_id = $2
"#,
user.data,
id,
Uuid::from(user.data),
Uuid::from(id),
)
.fetch_one(executor)
.instrument(info_span!("Fetch user emails"))
@ -568,32 +590,35 @@ pub async fn get_user_email(
Ok(res.into())
}
#[tracing::instrument(skip(executor, user), fields(user.id = user.data, %user.username))]
#[tracing::instrument(skip(executor, user), fields(user.id = %user.data, %user.username))]
pub async fn add_user_email(
executor: impl PgExecutor<'_>,
user: &User<PostgresqlBackend>,
email: &str,
email: String,
) -> anyhow::Result<UserEmail<PostgresqlBackend>> {
let res = sqlx::query_as!(
UserEmailLookup,
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
sqlx::query!(
r#"
INSERT INTO user_emails (user_id, email)
VALUES ($1, $2)
RETURNING
id AS user_email_id,
email AS user_email,
created_at AS user_email_created_at,
confirmed_at AS user_email_confirmed_at
INSERT INTO user_emails (user_email_id, user_id, email, created_at)
VALUES ($1, $2, $3, $4)
"#,
user.data,
email,
Uuid::from(id),
Uuid::from(user.data),
&email,
created_at,
)
.fetch_one(executor)
.execute(executor)
.instrument(info_span!("Add user email"))
.await
.context("could not insert user email")?;
Ok(res.into())
Ok(UserEmail {
data: id,
email,
created_at,
confirmed_at: None,
})
}
#[tracing::instrument(skip(executor))]
@ -604,12 +629,12 @@ pub async fn set_user_email_as_primary(
sqlx::query!(
r#"
UPDATE users
SET primary_email_id = user_emails.id
SET primary_user_email_id = user_emails.user_email_id
FROM user_emails
WHERE user_emails.id = $1
AND users.id = user_emails.user_id
WHERE user_emails.user_email_id = $1
AND users.user_id = user_emails.user_id
"#,
email.data,
Uuid::from(email.data),
)
.execute(executor)
.instrument(info_span!("Add user email"))
@ -627,9 +652,9 @@ pub async fn remove_user_email(
sqlx::query!(
r#"
DELETE FROM user_emails
WHERE user_emails.id = $1
WHERE user_emails.user_email_id = $1
"#,
email.data,
Uuid::from(email.data),
)
.execute(executor)
.instrument(info_span!("Remove user email"))
@ -649,7 +674,7 @@ pub async fn lookup_user_email(
UserEmailLookup,
r#"
SELECT
ue.id AS "user_email_id",
ue.user_email_id,
ue.email AS "user_email",
ue.created_at AS "user_email_created_at",
ue.confirmed_at AS "user_email_confirmed_at"
@ -658,7 +683,7 @@ pub async fn lookup_user_email(
WHERE ue.user_id = $1
AND ue.email = $2
"#,
user.data,
Uuid::from(user.data),
email,
)
.fetch_one(executor)
@ -673,23 +698,23 @@ pub async fn lookup_user_email(
pub async fn lookup_user_email_by_id(
executor: impl PgExecutor<'_>,
user: &User<PostgresqlBackend>,
id: i64,
id: Ulid,
) -> anyhow::Result<UserEmail<PostgresqlBackend>> {
let res = sqlx::query_as!(
UserEmailLookup,
r#"
SELECT
ue.id AS "user_email_id",
ue.user_email_id,
ue.email AS "user_email",
ue.created_at AS "user_email_created_at",
ue.confirmed_at AS "user_email_confirmed_at"
FROM user_emails ue
WHERE ue.user_id = $1
AND ue.id = $2
AND ue.user_email_id = $2
"#,
user.data,
id,
Uuid::from(user.data),
Uuid::from(id),
)
.fetch_one(executor)
.instrument(info_span!("Lookup user email"))
@ -704,31 +729,32 @@ pub async fn mark_user_email_as_verified(
executor: impl PgExecutor<'_>,
mut email: UserEmail<PostgresqlBackend>,
) -> anyhow::Result<UserEmail<PostgresqlBackend>> {
let confirmed_at = sqlx::query_scalar!(
let confirmed_at = Utc::now();
sqlx::query!(
r#"
UPDATE user_emails
SET confirmed_at = NOW()
WHERE id = $1
RETURNING confirmed_at
SET confirmed_at = $2
WHERE user_email_id = $1
"#,
email.data,
Uuid::from(email.data),
confirmed_at,
)
.fetch_one(executor)
.execute(executor)
.instrument(info_span!("Confirm user email"))
.await
.context("could not update user email")?;
email.confirmed_at = confirmed_at;
email.confirmed_at = Some(confirmed_at);
Ok(email)
}
struct UserEmailVerificationLookup {
verification_id: i64,
verification_code: String,
verification_expired: bool,
verification_created_at: DateTime<Utc>,
verification_consumed_at: Option<DateTime<Utc>>,
struct UserEmailConfirmationCodeLookup {
user_email_confirmation_code_id: Uuid,
code: String,
created_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
consumed_at: Option<DateTime<Utc>>,
}
#[tracing::instrument(skip(executor))]
@ -736,49 +762,46 @@ pub async fn lookup_user_email_verification_code(
executor: impl PgExecutor<'_>,
email: UserEmail<PostgresqlBackend>,
code: &str,
max_age: chrono::Duration,
) -> anyhow::Result<UserEmailVerification<PostgresqlBackend>> {
// For some reason, we need to convert the type first
let max_age = PgInterval::try_from(max_age)
// For some reason, this error type does not let me to just bubble up the error here
.map_err(|e| anyhow::anyhow!("failed to encode duration: {}", e))?;
let now = Utc::now();
let res = sqlx::query_as!(
UserEmailVerificationLookup,
UserEmailConfirmationCodeLookup,
r#"
SELECT
ev.id AS "verification_id",
ev.code AS "verification_code",
(ev.created_at + $3 < NOW()) AS "verification_expired!",
ev.created_at AS "verification_created_at",
ev.consumed_at AS "verification_consumed_at"
FROM user_email_verifications ev
WHERE ev.code = $1
AND ev.user_email_id = $2
ec.user_email_confirmation_code_id,
ec.code,
ec.created_at,
ec.expires_at,
ec.consumed_at
FROM user_email_confirmation_codes ec
WHERE ec.code = $1
AND ec.user_email_id = $2
"#,
code,
email.data,
max_age,
Uuid::from(email.data),
)
.fetch_one(executor)
.instrument(info_span!("Lookup user email verification"))
.await
.context("could not lookup user email verification")?;
let state = if res.verification_expired {
UserEmailVerificationState::Expired
} else if let Some(when) = res.verification_consumed_at {
let state = if let Some(when) = res.consumed_at {
UserEmailVerificationState::AlreadyUsed { when }
} else if res.expires_at < now {
UserEmailVerificationState::Expired {
when: res.expires_at,
}
} else {
UserEmailVerificationState::Valid
};
Ok(UserEmailVerification {
data: res.verification_id,
code: res.verification_code,
data: res.user_email_confirmation_code_id.into(),
code: res.code,
email,
state,
created_at: res.verification_created_at,
created_at: res.created_at,
})
}
@ -791,16 +814,18 @@ pub async fn consume_email_verification(
bail!("user email verification in wrong state");
}
let consumed_at = sqlx::query_scalar!(
let consumed_at = Utc::now();
sqlx::query!(
r#"
UPDATE user_email_verifications
SET consumed_at = NOW()
WHERE id = $1
RETURNING consumed_at AS "consumed_at!"
UPDATE user_email_confirmation_codes
SET consumed_at = $2
WHERE user_email_confirmation_code_id = $1
"#,
verification.data,
Uuid::from(verification.data),
consumed_at
)
.fetch_one(executor)
.execute(executor)
.instrument(info_span!("Consume user email verification"))
.await
.context("could not update user email verification")?;
@ -810,32 +835,39 @@ pub async fn consume_email_verification(
Ok(verification)
}
#[tracing::instrument(skip(executor, email), fields(email.id = email.data, %email.email))]
#[tracing::instrument(skip(executor, email), fields(email.id = %email.data, %email.email))]
pub async fn add_user_email_verification_code(
executor: impl PgExecutor<'_>,
email: UserEmail<PostgresqlBackend>,
max_age: chrono::Duration,
code: String,
) -> anyhow::Result<UserEmailVerification<PostgresqlBackend>> {
let res = sqlx::query_as!(
IdAndCreationTime,
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let expires_at = created_at + max_age;
sqlx::query!(
r#"
INSERT INTO user_email_verifications (user_email_id, code)
VALUES ($1, $2)
RETURNING id, created_at
INSERT INTO user_email_confirmation_codes
(user_email_confirmation_code_id, user_email_id, code, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5)
"#,
email.data,
Uuid::from(id),
Uuid::from(email.data),
code,
created_at,
expires_at,
)
.fetch_one(executor)
.execute(executor)
.instrument(info_span!("Add user email verification code"))
.await
.context("could not insert user email verification code")?;
let verification = UserEmailVerification {
data: res.id,
data: id,
email,
code,
created_at: res.created_at,
created_at,
state: UserEmailVerificationState::Valid,
};