1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Pass the rng and clock around

This commit is contained in:
Quentin Gliech
2022-10-21 18:50:06 +02:00
parent 5c7e66a9b2
commit 559181c2c3
40 changed files with 504 additions and 218 deletions

8
Cargo.lock generated
View File

@ -2468,6 +2468,8 @@ dependencies = [
"opentelemetry-semantic-conventions", "opentelemetry-semantic-conventions",
"opentelemetry-zipkin", "opentelemetry-zipkin",
"prometheus", "prometheus",
"rand",
"rand_chacha",
"rustls", "rustls",
"serde_json", "serde_json",
"serde_yaml", "serde_yaml",
@ -2497,6 +2499,7 @@ dependencies = [
"mas-keystore", "mas-keystore",
"pem-rfc7468", "pem-rfc7468",
"rand", "rand",
"rand_chacha",
"rustls-pemfile", "rustls-pemfile",
"schemars", "schemars",
"serde", "serde",
@ -2567,6 +2570,7 @@ dependencies = [
"mime", "mime",
"oauth2-types", "oauth2-types",
"rand", "rand",
"rand_chacha",
"serde", "serde",
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
@ -2779,6 +2783,7 @@ dependencies = [
"oauth2-types", "oauth2-types",
"password-hash", "password-hash",
"rand", "rand",
"rand_chacha",
"serde", "serde",
"serde_json", "serde_json",
"sqlx", "sqlx",
@ -3068,7 +3073,7 @@ checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860"
[[package]] [[package]]
name = "opa-wasm" name = "opa-wasm"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/matrix-org/rust-opa-wasm.git#325071ee8a2a7d18cc611365edd3235945a8cdf4" source = "git+https://github.com/matrix-org/rust-opa-wasm.git#f838595670747b0644b6bfd9829fca5d63bbee66"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"base64", "base64",
@ -3079,6 +3084,7 @@ dependencies = [
"md-5", "md-5",
"parse-size", "parse-size",
"rand", "rand",
"rayon-core",
"semver", "semver",
"serde", "serde",
"serde_json", "serde_json",

View File

@ -1,2 +1,11 @@
msrv = "1.61.0" msrv = "1.61.0"
doc-valid-idents = ["OpenID", "OAuth", ".."] doc-valid-idents = ["OpenID", "OAuth", ".."]
disallowed-methods = [
{ path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" },
{ path = "chrono::Utc::now", reason = "source the current time from the clock instead" },
]
disallowed-types = [
"rand::OsRng",
]

View File

@ -15,6 +15,7 @@
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use data_encoding::{DecodeError, BASE64URL_NOPAD}; use data_encoding::{DecodeError, BASE64URL_NOPAD};
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds}; use serde_with::{serde_as, TimestampSeconds};
use thiserror::Error; use thiserror::Error;
@ -56,20 +57,20 @@ pub struct CsrfToken {
impl CsrfToken { impl CsrfToken {
/// Create a new token from a defined value valid for a specified duration /// Create a new token from a defined value valid for a specified duration
fn new(token: [u8; 32], ttl: Duration) -> Self { fn new(token: [u8; 32], now: DateTime<Utc>, ttl: Duration) -> Self {
let expiration = Utc::now() + ttl; let expiration = now + ttl;
Self { expiration, token } Self { expiration, token }
} }
/// Generate a new random token valid for a specified duration /// Generate a new random token valid for a specified duration
fn generate(ttl: Duration) -> Self { fn generate(now: DateTime<Utc>, mut rng: impl Rng, ttl: Duration) -> Self {
let token = rand::random(); let token = rng.gen();
Self::new(token, ttl) Self::new(token, now, ttl)
} }
/// Generate a new token with the same value but an up to date expiration /// Generate a new token with the same value but an up to date expiration
fn refresh(self, ttl: Duration) -> Self { fn refresh(self, now: DateTime<Utc>, ttl: Duration) -> Self {
Self::new(self.token, ttl) Self::new(self.token, now, ttl)
} }
/// Get the value to include in HTML forms /// Get the value to include in HTML forms
@ -88,8 +89,8 @@ impl CsrfToken {
} }
} }
fn verify_expiration(self) -> Result<Self, CsrfError> { fn verify_expiration(self, now: DateTime<Utc>) -> Result<Self, CsrfError> {
if Utc::now() < self.expiration { if now < self.expiration {
Ok(self) Ok(self)
} else { } else {
Err(CsrfError::Expired) Err(CsrfError::Expired)
@ -118,12 +119,18 @@ impl<K> CsrfExt for PrivateCookieJar<K> {
cookie.set_path("/"); cookie.set_path("/");
cookie.set_http_only(true); cookie.set_http_only(true);
// XXX: the rng source and clock should come from somewhere else
#[allow(clippy::disallowed_methods)]
let now = Utc::now();
#[allow(clippy::disallowed_methods)]
let rng = thread_rng();
let new_token = cookie let new_token = cookie
.decode() .decode()
.ok() .ok()
.and_then(|token: CsrfToken| token.verify_expiration().ok()) .and_then(|token: CsrfToken| token.verify_expiration(now).ok())
.unwrap_or_else(|| CsrfToken::generate(Duration::hours(1))) .unwrap_or_else(|| CsrfToken::generate(now, rng, Duration::hours(1)))
.refresh(Duration::hours(1)); .refresh(now, Duration::hours(1));
let cookie = cookie.encode(&new_token); let cookie = cookie.encode(&new_token);
let jar = jar.add(cookie); let jar = jar.add(cookie);
@ -131,9 +138,13 @@ impl<K> CsrfExt for PrivateCookieJar<K> {
} }
fn verify_form<T>(&self, form: ProtectedForm<T>) -> Result<T, CsrfError> { fn verify_form<T>(&self, form: ProtectedForm<T>) -> Result<T, CsrfError> {
// XXX: the clock should come from somewhere else
#[allow(clippy::disallowed_methods)]
let now = Utc::now();
let cookie = self.get("csrf").ok_or(CsrfError::Missing)?; let cookie = self.get("csrf").ok_or(CsrfError::Missing)?;
let token: CsrfToken = cookie.decode()?; let token: CsrfToken = cookie.decode()?;
let token = token.verify_expiration()?; let token = token.verify_expiration(now)?;
token.verify_form_value(&form.csrf)?; token.verify_form_value(&form.csrf)?;
Ok(form.inner) Ok(form.inner)
} }

View File

@ -6,23 +6,25 @@ edition = "2021"
license = "Apache-2.0" license = "Apache-2.0"
[dependencies] [dependencies]
axum = "0.6.0-rc.2"
tokio = { version = "1.21.2", features = ["full"] }
futures-util = "0.3.25"
anyhow = "1.0.66" anyhow = "1.0.66"
argon2 = { version = "0.4.1", features = ["password-hash"] }
atty = "0.2.14"
axum = "0.6.0-rc.2"
clap = { version = "4.0.18", features = ["derive"] } clap = { version = "4.0.18", features = ["derive"] }
dotenv = "0.15.0" dotenv = "0.15.0"
tower = { version = "0.4.13", features = ["full"] } futures-util = "0.3.25"
hyper = { version = "0.14.22", features = ["full"] } hyper = { version = "0.14.22", features = ["full"] }
serde_yaml = "0.9.14"
serde_json = "1.0.87"
url = "2.3.1"
argon2 = { version = "0.4.1", features = ["password-hash"] }
watchman_client = "0.8.0"
atty = "0.2.14"
listenfd = "1.0.0"
rustls = "0.20.7"
itertools = "0.10.5" itertools = "0.10.5"
listenfd = "1.0.0"
rand = "0.8.5"
rand_chacha = "0.3.1"
rustls = "0.20.7"
serde_json = "1.0.87"
serde_yaml = "0.9.14"
tokio = { version = "1.21.2", features = ["full"] }
tower = { version = "0.4.13", features = ["full"] }
url = "2.3.1"
watchman_client = "0.8.0"
tracing = "0.1.37" tracing = "0.1.37"
tracing-appender = "0.2.2" tracing-appender = "0.2.2"

View File

@ -20,7 +20,9 @@ use mas_storage::{
user::{ user::{
lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user, lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user,
}, },
Clock,
}; };
use rand::SeedableRng;
use tracing::{info, warn}; use tracing::{info, warn};
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -51,14 +53,17 @@ enum Subcommand {
impl Options { impl Options {
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
use Subcommand as SC; use Subcommand as SC;
let clock = Clock::default();
match &self.subcommand { match &self.subcommand {
SC::Register { username, password } => { SC::Register { username, password } => {
let config: DatabaseConfig = root.load_config()?; let config: DatabaseConfig = root.load_config()?;
let pool = config.connect().await?; let pool = config.connect().await?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let hasher = Argon2::default(); let hasher = Argon2::default();
let rng = rand_chacha::ChaChaRng::from_entropy();
let user = register_user(&mut txn, hasher, username, password).await?; let user = register_user(&mut txn, rng, &clock, hasher, username, password).await?;
txn.commit().await?; txn.commit().await?;
info!(?user, "User registered"); info!(?user, "User registered");
@ -76,7 +81,7 @@ impl Options {
let user = lookup_user_by_username(&mut txn, username).await?; let user = lookup_user_by_username(&mut txn, username).await?;
let email = lookup_user_email(&mut txn, &user, email).await?; let email = lookup_user_email(&mut txn, &user, email).await?;
let email = mark_user_email_as_verified(&mut txn, email).await?; let email = mark_user_email_as_verified(&mut txn, &clock, email).await?;
txn.commit().await?; txn.commit().await?;
info!(?email, "Email marked as verified"); info!(?email, "Email marked as verified");

View File

@ -28,6 +28,7 @@ lettre = { version = "0.10.1", default-features = false, features = ["serde", "b
pem-rfc7468 = "0.6.0" pem-rfc7468 = "0.6.0"
rustls-pemfile = "1.0.1" rustls-pemfile = "1.0.1"
rand = "0.8.5" rand = "0.8.5"
rand_chacha = "0.3.1"
indoc = "1.0.7" indoc = "1.0.7"

View File

@ -20,7 +20,7 @@ use mas_jose::jwk::{JsonWebKey, JsonWebKeySet};
use mas_keystore::{Encrypter, Keystore, PrivateKey}; use mas_keystore::{Encrypter, Keystore, PrivateKey};
use rand::{ use rand::{
distributions::{Alphanumeric, DistString}, distributions::{Alphanumeric, DistString},
thread_rng, thread_rng, SeedableRng,
}; };
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -139,64 +139,72 @@ impl ConfigurationSection<'_> for SecretsConfig {
#[tracing::instrument] #[tracing::instrument]
async fn generate() -> anyhow::Result<Self> { async fn generate() -> anyhow::Result<Self> {
// XXX: that RNG should come from somewhere else
#[allow(clippy::disallowed_methods)]
let mut rng = rand_chacha::ChaChaRng::from_rng(thread_rng())?;
info!("Generating keys..."); info!("Generating keys...");
let span = tracing::info_span!("rsa"); let span = tracing::info_span!("rsa");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let rsa_key = task::spawn_blocking(move || { let rsa_key = task::spawn_blocking(move || {
let _entered = span.enter(); let _entered = span.enter();
let ret = PrivateKey::generate_rsa(thread_rng()).unwrap(); let ret = PrivateKey::generate_rsa(key_rng).unwrap();
info!("Done generating RSA key"); info!("Done generating RSA key");
ret ret
}) })
.await .await
.context("could not join blocking task")?; .context("could not join blocking task")?;
let rsa_key = KeyConfig { let rsa_key = KeyConfig {
kid: Alphanumeric.sample_string(&mut thread_rng(), 10), kid: Alphanumeric.sample_string(&mut rng, 10),
password: None, password: None,
key: KeyOrFile::Key(rsa_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()), key: KeyOrFile::Key(rsa_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
}; };
let span = tracing::info_span!("ec_p256"); let span = tracing::info_span!("ec_p256");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let ec_p256_key = task::spawn_blocking(move || { let ec_p256_key = task::spawn_blocking(move || {
let _entered = span.enter(); let _entered = span.enter();
let ret = PrivateKey::generate_ec_p256(thread_rng()); let ret = PrivateKey::generate_ec_p256(key_rng);
info!("Done generating EC P-256 key"); info!("Done generating EC P-256 key");
ret ret
}) })
.await .await
.context("could not join blocking task")?; .context("could not join blocking task")?;
let ec_p256_key = KeyConfig { let ec_p256_key = KeyConfig {
kid: Alphanumeric.sample_string(&mut thread_rng(), 10), kid: Alphanumeric.sample_string(&mut rng, 10),
password: None, password: None,
key: KeyOrFile::Key(ec_p256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()), key: KeyOrFile::Key(ec_p256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
}; };
let span = tracing::info_span!("ec_p384"); let span = tracing::info_span!("ec_p384");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let ec_p384_key = task::spawn_blocking(move || { let ec_p384_key = task::spawn_blocking(move || {
let _entered = span.enter(); let _entered = span.enter();
let ret = PrivateKey::generate_ec_p384(thread_rng()); let ret = PrivateKey::generate_ec_p384(key_rng);
info!("Done generating EC P-256 key"); info!("Done generating EC P-256 key");
ret ret
}) })
.await .await
.context("could not join blocking task")?; .context("could not join blocking task")?;
let ec_p384_key = KeyConfig { let ec_p384_key = KeyConfig {
kid: Alphanumeric.sample_string(&mut thread_rng(), 10), kid: Alphanumeric.sample_string(&mut rng, 10),
password: None, password: None,
key: KeyOrFile::Key(ec_p384_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()), key: KeyOrFile::Key(ec_p384_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
}; };
let span = tracing::info_span!("ec_k256"); let span = tracing::info_span!("ec_k256");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let ec_k256_key = task::spawn_blocking(move || { let ec_k256_key = task::spawn_blocking(move || {
let _entered = span.enter(); let _entered = span.enter();
let ret = PrivateKey::generate_ec_k256(thread_rng()); let ret = PrivateKey::generate_ec_k256(key_rng);
info!("Done generating EC secp256k1 key"); info!("Done generating EC secp256k1 key");
ret ret
}) })
.await .await
.context("could not join blocking task")?; .context("could not join blocking task")?;
let ec_k256_key = KeyConfig { let ec_k256_key = KeyConfig {
kid: Alphanumeric.sample_string(&mut thread_rng(), 10), kid: Alphanumeric.sample_string(&mut rng, 10),
password: None, password: None,
key: KeyOrFile::Key(ec_k256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()), key: KeyOrFile::Key(ec_k256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
}; };

View File

@ -263,6 +263,8 @@ mod tests {
#[test] #[test]
fn test_generate_and_check() { fn test_generate_and_check() {
const COUNT: usize = 500; // Generate 500 of each token type const COUNT: usize = 500; // Generate 500 of each token type
#[allow(clippy::disallowed_methods)]
let mut rng = thread_rng(); let mut rng = thread_rng();
for t in [ for t in [

View File

@ -44,6 +44,7 @@ chrono = { version = "0.4.22", features = ["serde"] }
url = { version = "2.3.1", features = ["serde"] } url = { version = "2.3.1", features = ["serde"] }
mime = "0.3.16" mime = "0.3.16"
rand = "0.8.5" rand = "0.8.5"
rand_chacha = "0.3.1"
headers = "0.3.8" headers = "0.3.8"
ulid = "1.0.0" ulid = "1.0.0"

View File

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
use axum::{extract::State, response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json};
use chrono::{Duration, Utc}; use chrono::Duration;
use hyper::StatusCode; use hyper::StatusCode;
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType}; use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType};
use mas_storage::{ use mas_storage::{
@ -22,9 +22,8 @@ use mas_storage::{
get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged, get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged,
CompatSsoLoginLookupError, CompatSsoLoginLookupError,
}, },
PostgresqlBackend, Clock, PostgresqlBackend,
}; };
use rand::thread_rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds}; use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
use sqlx::{PgPool, Postgres, Transaction}; use sqlx::{PgPool, Postgres, Transaction};
@ -201,6 +200,7 @@ pub(crate) async fn post(
State(homeserver): State<MatrixHomeserver>, State(homeserver): State<MatrixHomeserver>,
Json(input): Json<RequestBody>, Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let session = match input.credentials { let session = match input.credentials {
Credentials::Password { Credentials::Password {
@ -208,7 +208,7 @@ pub(crate) async fn post(
password, password,
} => user_password_login(&mut txn, user, password).await?, } => user_password_login(&mut txn, user, password).await?,
Credentials::Token { token } => token_login(&mut txn, &token).await?, Credentials::Token { token } => token_login(&mut txn, &clock, &token).await?,
_ => { _ => {
return Err(RouteError::Unsupported); return Err(RouteError::Unsupported);
@ -225,14 +225,28 @@ pub(crate) async fn post(
None None
}; };
let access_token = TokenType::CompatAccessToken.generate(&mut thread_rng()); let access_token = TokenType::CompatAccessToken.generate(&mut rng);
let access_token = let access_token = add_compat_access_token(
add_compat_access_token(&mut txn, &session, access_token, expires_in).await?; &mut txn,
&mut rng,
&clock,
&session,
access_token,
expires_in,
)
.await?;
let refresh_token = if input.refresh_token { let refresh_token = if input.refresh_token {
let refresh_token = TokenType::CompatRefreshToken.generate(&mut thread_rng()); let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng);
let refresh_token = let refresh_token = add_compat_refresh_token(
add_compat_refresh_token(&mut txn, &session, &access_token, refresh_token).await?; &mut txn,
&mut rng,
&clock,
&session,
&access_token,
refresh_token,
)
.await?;
Some(refresh_token.token) Some(refresh_token.token)
} else { } else {
None None
@ -251,11 +265,12 @@ pub(crate) async fn post(
async fn token_login( async fn token_login(
txn: &mut Transaction<'_, Postgres>, txn: &mut Transaction<'_, Postgres>,
clock: &Clock,
token: &str, token: &str,
) -> Result<CompatSession<PostgresqlBackend>, RouteError> { ) -> Result<CompatSession<PostgresqlBackend>, RouteError> {
let login = get_compat_sso_login_by_token(&mut *txn, token).await?; let login = get_compat_sso_login_by_token(&mut *txn, token).await?;
let now = Utc::now(); let now = clock.now();
match login.state { match login.state {
CompatSsoLoginState::Pending => { CompatSsoLoginState::Pending => {
tracing::error!( tracing::error!(
@ -285,7 +300,7 @@ async fn token_login(
} }
} }
let login = mark_compat_sso_login_as_exchanged(&mut *txn, login).await?; let login = mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?;
match login.state { match login.state {
CompatSsoLoginState::Exchanged { session, .. } => Ok(session), CompatSsoLoginState::Exchanged { session, .. } => Ok(session),
@ -298,8 +313,10 @@ async fn user_password_login(
username: String, username: String,
password: String, password: String,
) -> Result<CompatSession<PostgresqlBackend>, RouteError> { ) -> Result<CompatSession<PostgresqlBackend>, RouteError> {
let device = Device::generate(&mut thread_rng()); let (clock, mut rng) = crate::rng_and_clock()?;
let session = compat_login(txn, &username, &password, device)
let device = Device::generate(&mut rng);
let session = compat_login(txn, &mut rng, &clock, &username, &password, device)
.await .await
.map_err(|_| RouteError::LoginFailed)?; .map_err(|_| RouteError::LoginFailed)?;

View File

@ -20,7 +20,7 @@ use axum::{
response::{Html, IntoResponse, Redirect, Response}, response::{Html, IntoResponse, Redirect, Response},
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use chrono::{Duration, Utc}; use chrono::Duration;
use mas_axum_utils::{ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
FancyError, SessionInfoExt, FancyError, SessionInfoExt,
@ -28,9 +28,11 @@ use mas_axum_utils::{
use mas_data_model::Device; use mas_data_model::Device;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route}; use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id}; use mas_storage::{
compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id},
Clock,
};
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use rand::thread_rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::PgPool; use sqlx::PgPool;
use ulid::Ulid; use ulid::Ulid;
@ -56,6 +58,7 @@ pub async fn get(
Path(id): Path<Ulid>, Path(id): Path<Ulid>,
Query(params): Query<Params>, Query(params): Query<Params>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let clock = Clock::default();
let mut conn = pool.acquire().await?; let mut conn = pool.acquire().await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
@ -95,7 +98,7 @@ pub async fn get(
let login = get_compat_sso_login_by_id(&mut conn, id).await?; let login = get_compat_sso_login_by_id(&mut conn, id).await?;
// Bail out if that login session is more than 30min old // Bail out if that login session is more than 30min old
if Utc::now() > login.created_at + Duration::minutes(30) { if clock.now() > login.created_at + Duration::minutes(30) {
let ctx = ErrorContext::new() let ctx = ErrorContext::new()
.with_code("compat_sso_login_expired") .with_code("compat_sso_login_expired")
.with_description("This login session expired.".to_owned()); .with_description("This login session expired.".to_owned());
@ -121,6 +124,7 @@ pub async fn post(
Query(params): Query<Params>, Query(params): Query<Params>,
Form(form): Form<ProtectedForm<()>>, Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
@ -160,7 +164,7 @@ pub async fn post(
let login = get_compat_sso_login_by_id(&mut txn, id).await?; let login = get_compat_sso_login_by_id(&mut txn, id).await?;
// Bail out if that login session is more than 30min old // Bail out if that login session is more than 30min old
if Utc::now() > login.created_at + Duration::minutes(30) { if clock.now() > login.created_at + Duration::minutes(30) {
let ctx = ErrorContext::new() let ctx = ErrorContext::new()
.with_code("compat_sso_login_expired") .with_code("compat_sso_login_expired")
.with_description("This login session expired.".to_owned()); .with_description("This login session expired.".to_owned());
@ -186,8 +190,9 @@ pub async fn post(
redirect_uri redirect_uri
}; };
let device = Device::generate(&mut thread_rng()); let device = Device::generate(&mut rng);
let _login = fullfill_compat_sso_login(&mut txn, session.user, login, device).await?; let _login =
fullfill_compat_sso_login(&mut txn, &mut rng, &clock, session.user, login, device).await?;
txn.commit().await?; txn.commit().await?;

View File

@ -20,10 +20,7 @@ use axum::{
use hyper::StatusCode; use hyper::StatusCode;
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::compat::insert_compat_sso_login; use mas_storage::compat::insert_compat_sso_login;
use rand::{ use rand::distributions::{Alphanumeric, DistString};
distributions::{Alphanumeric, DistString},
thread_rng,
};
use serde::Deserialize; use serde::Deserialize;
use serde_with::serde; use serde_with::serde;
use sqlx::PgPool; use sqlx::PgPool;
@ -70,6 +67,8 @@ pub async fn get(
State(url_builder): State<UrlBuilder>, State(url_builder): State<UrlBuilder>,
Query(params): Query<Params>, Query(params): Query<Params>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
// Check the redirectUrl parameter // Check the redirectUrl parameter
let redirect_url = params.redirect_url.ok_or(RouteError::MissingRedirectUrl)?; let redirect_url = params.redirect_url.ok_or(RouteError::MissingRedirectUrl)?;
let redirect_url = Url::parse(&redirect_url).map_err(|_| RouteError::InvalidRedirectUrl)?; let redirect_url = Url::parse(&redirect_url).map_err(|_| RouteError::InvalidRedirectUrl)?;
@ -84,9 +83,9 @@ pub async fn get(
return Err(RouteError::InvalidRedirectUrl); return Err(RouteError::InvalidRedirectUrl);
} }
let token = Alphanumeric.sample_string(&mut thread_rng(), 32); let token = Alphanumeric.sample_string(&mut rng, 32);
let mut conn = pool.acquire().await?; let mut conn = pool.acquire().await?;
let login = insert_compat_sso_login(&mut conn, token, redirect_url).await?; let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?;
Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.data, params.action))) Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.data, params.action)))
} }

View File

@ -16,7 +16,7 @@ use axum::{extract::State, response::IntoResponse, Json, TypedHeader};
use headers::{authorization::Bearer, Authorization}; use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode; use hyper::StatusCode;
use mas_data_model::{TokenFormatError, TokenType}; use mas_data_model::{TokenFormatError, TokenType};
use mas_storage::compat::compat_logout; use mas_storage::{compat::compat_logout, Clock};
use sqlx::PgPool; use sqlx::PgPool;
use super::MatrixError; use super::MatrixError;
@ -67,6 +67,7 @@ pub(crate) async fn post(
State(pool): State<PgPool>, State(pool): State<PgPool>,
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>, maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let clock = Clock::default();
let mut conn = pool.acquire().await?; let mut conn = pool.acquire().await?;
let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?;
@ -78,7 +79,7 @@ pub(crate) async fn post(
return Err(RouteError::InvalidAuthorization); return Err(RouteError::InvalidAuthorization);
} }
compat_logout(&mut conn, token) compat_logout(&mut conn, &clock, token)
.await .await
.map_err(|_| RouteError::LogoutFailed)?; .map_err(|_| RouteError::LogoutFailed)?;

View File

@ -20,7 +20,6 @@ use mas_storage::compat::{
add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token, add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token,
expire_compat_access_token, lookup_active_compat_refresh_token, CompatRefreshTokenLookupError, expire_compat_access_token, lookup_active_compat_refresh_token, CompatRefreshTokenLookupError,
}; };
use rand::thread_rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DurationMilliSeconds}; use serde_with::{serde_as, DurationMilliSeconds};
use sqlx::PgPool; use sqlx::PgPool;
@ -98,6 +97,7 @@ pub(crate) async fn post(
State(pool): State<PgPool>, State(pool): State<PgPool>,
Json(input): Json<RequestBody>, Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let token_type = TokenType::check(&input.refresh_token)?; let token_type = TokenType::check(&input.refresh_token)?;
@ -109,23 +109,31 @@ pub(crate) async fn post(
let (refresh_token, access_token, session) = let (refresh_token, access_token, session) =
lookup_active_compat_refresh_token(&mut txn, &input.refresh_token).await?; lookup_active_compat_refresh_token(&mut txn, &input.refresh_token).await?;
let (new_refresh_token_str, new_access_token_str) = { let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng);
let mut rng = thread_rng(); let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
(
TokenType::CompatRefreshToken.generate(&mut rng),
TokenType::CompatAccessToken.generate(&mut rng),
)
};
let expires_in = Duration::minutes(5); let expires_in = Duration::minutes(5);
let new_access_token = let new_access_token = add_compat_access_token(
add_compat_access_token(&mut txn, &session, new_access_token_str, Some(expires_in)).await?; &mut txn,
let new_refresh_token = &mut rng,
add_compat_refresh_token(&mut txn, &session, &new_access_token, new_refresh_token_str) &clock,
.await?; &session,
new_access_token_str,
Some(expires_in),
)
.await?;
let new_refresh_token = add_compat_refresh_token(
&mut txn,
&mut rng,
&clock,
&session,
&new_access_token,
new_refresh_token_str,
)
.await?;
consume_compat_refresh_token(&mut txn, refresh_token).await?; consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?;
expire_compat_access_token(&mut txn, access_token).await?; expire_compat_access_token(&mut txn, &clock, access_token).await?;
txn.commit().await?; txn.commit().await?;

View File

@ -21,6 +21,7 @@
use std::{convert::Infallible, sync::Arc, time::Duration}; use std::{convert::Infallible, sync::Arc, time::Duration};
use anyhow::Context;
use axum::{ use axum::{
body::HttpBody, body::HttpBody,
extract::FromRef, extract::FromRef,
@ -36,6 +37,7 @@ use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
use mas_router::{Route, UrlBuilder}; use mas_router::{Route, UrlBuilder};
use mas_templates::{ErrorContext, Templates}; use mas_templates::{ErrorContext, Templates};
use rand::SeedableRng;
use sqlx::PgPool; use sqlx::PgPool;
use tower::util::AndThenLayer; use tower::util::AndThenLayer;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
@ -356,3 +358,15 @@ async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
policy_factory, policy_factory,
})) }))
} }
// XXX: that should be moved somewhere else
fn rng_and_clock() -> Result<(mas_storage::Clock, rand_chacha::ChaChaRng), anyhow::Error> {
let clock = mas_storage::Clock::default();
// This rng is used to source the local rng
#[allow(clippy::disallowed_methods)]
let rng = rand::thread_rng();
let rng = rand_chacha::ChaChaRng::from_rng(rng).context("Failed to seed RNG")?;
Ok((clock, rng))
}

View File

@ -190,6 +190,8 @@ pub(crate) async fn complete(
policy_factory: &PolicyFactory, policy_factory: &PolicyFactory,
mut txn: Transaction<'_, Postgres>, mut txn: Transaction<'_, Postgres>,
) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> { ) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> {
let (clock, mut rng) = crate::rng_and_clock()?;
// Verify that the grant is in a pending stage // Verify that the grant is in a pending stage
if !grant.stage.is_pending() { if !grant.stage.is_pending() {
return Err(GrantCompletionError::NotPending); return Err(GrantCompletionError::NotPending);
@ -226,7 +228,7 @@ pub(crate) async fn complete(
} }
// All good, let's start the session // All good, let's start the session
let session = derive_session(&mut txn, &grant, browser_session).await?; let session = derive_session(&mut txn, &mut rng, &clock, &grant, browser_session).await?;
let grant = fulfill_grant(&mut txn, grant, session.clone()).await?; let grant = fulfill_grant(&mut txn, grant, session.clone()).await?;

View File

@ -37,7 +37,7 @@ use oauth2_types::{
requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode}, requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode},
response_type::ResponseType, response_type::ResponseType,
}; };
use rand::{distributions::Alphanumeric, thread_rng, Rng}; use rand::{distributions::Alphanumeric, Rng};
use serde::Deserialize; use serde::Deserialize;
use sqlx::PgPool; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
@ -159,6 +159,7 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(params): Form<Params>, Form(params): Form<Params>,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
// First, figure out what client it is // First, figure out what client it is
@ -265,7 +266,7 @@ pub(crate) async fn get(
} }
// 32 random alphanumeric characters, about 190bit of entropy // 32 random alphanumeric characters, about 190bit of entropy
let code: String = thread_rng() let code: String = (&mut rng)
.sample_iter(&Alphanumeric) .sample_iter(&Alphanumeric)
.take(32) .take(32)
.map(char::from) .map(char::from)
@ -296,6 +297,8 @@ pub(crate) async fn get(
let grant = new_authorization_grant( let grant = new_authorization_grant(
&mut txn, &mut txn,
&mut rng,
&clock,
client, client,
redirect_uri.clone(), redirect_uri.clone(),
params.auth.scope, params.auth.scope,

View File

@ -119,6 +119,7 @@ pub(crate) async fn post(
Path(grant_id): Path<Ulid>, Path(grant_id): Path<Ulid>,
Form(form): Form<ProtectedForm<()>>, Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, RouteError> { ) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool let mut txn = pool
.begin() .begin()
.await .await
@ -163,6 +164,8 @@ pub(crate) async fn post(
.collect(); .collect();
insert_client_consent( insert_client_consent(
&mut txn, &mut txn,
&mut rng,
&clock,
&session.user, &session.user,
&grant.client, &grant.client,
&scope_without_device, &scope_without_device,

View File

@ -28,6 +28,7 @@ use mas_storage::{
client::ClientFetchError, client::ClientFetchError,
refresh_token::{lookup_active_refresh_token, RefreshTokenLookupError}, refresh_token::{lookup_active_refresh_token, RefreshTokenLookupError},
}, },
Clock,
}; };
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse}; use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
use sqlx::PgPool; use sqlx::PgPool;
@ -158,6 +159,7 @@ pub(crate) async fn post(
State(encrypter): State<Encrypter>, State(encrypter): State<Encrypter>,
client_authorization: ClientAuthorization<IntrospectionRequest>, client_authorization: ClientAuthorization<IntrospectionRequest>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let clock = Clock::default();
let mut conn = pool.acquire().await?; let mut conn = pool.acquire().await?;
let client = client_authorization.credentials.fetch(&mut conn).await?; let client = client_authorization.credentials.fetch(&mut conn).await?;
@ -227,7 +229,8 @@ pub(crate) async fn post(
} }
} }
TokenType::CompatAccessToken => { TokenType::CompatAccessToken => {
let (token, session) = lookup_active_compat_access_token(&mut conn, token).await?; let (token, session) =
lookup_active_compat_access_token(&mut conn, &clock, token).await?;
let device_scope = session.device.to_scope_token(); let device_scope = session.device.to_scope_token();
let scope = [device_scope].into_iter().collect(); let scope = [device_scope].into_iter().collect();

View File

@ -50,7 +50,6 @@ use oauth2_types::{
}, },
scope, scope,
}; };
use rand::thread_rng;
use serde::Serialize; use serde::Serialize;
use serde_with::{serde_as, skip_serializing_none}; use serde_with::{serde_as, skip_serializing_none};
use sqlx::{PgPool, Postgres, Transaction}; use sqlx::{PgPool, Postgres, Transaction};
@ -235,12 +234,13 @@ async fn authorization_code_grant(
url_builder: &UrlBuilder, url_builder: &UrlBuilder,
mut txn: Transaction<'_, Postgres>, mut txn: Transaction<'_, Postgres>,
) -> Result<AccessTokenResponse, RouteError> { ) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
// TODO: there is a bunch of unnecessary cloning here // TODO: there is a bunch of unnecessary cloning here
// TODO: handle "not found" cases // TODO: handle "not found" cases
let authz_grant = lookup_grant_by_code(&mut txn, &grant.code).await?; let authz_grant = lookup_grant_by_code(&mut txn, &grant.code).await?;
// TODO: that's not a timestamp from the DB. Let's assume they are in sync let now = clock.now();
let now = Utc::now();
let session = match authz_grant.stage { let session = match authz_grant.stage {
AuthorizationGrantStage::Cancelled { cancelled_at } => { AuthorizationGrantStage::Cancelled { cancelled_at } => {
@ -257,7 +257,7 @@ async fn authorization_code_grant(
// Ending the session if the token was already exchanged more than 20s ago // Ending the session if the token was already exchanged more than 20s ago
if now - exchanged_at > Duration::seconds(20) { if now - exchanged_at > Duration::seconds(20) {
debug!("Ending potentially compromised session"); debug!("Ending potentially compromised session");
end_oauth_session(&mut txn, session).await?; end_oauth_session(&mut txn, &clock, session).await?;
txn.commit().await?; txn.commit().await?;
} }
@ -303,22 +303,32 @@ async fn authorization_code_grant(
let browser_session = &session.browser_session; let browser_session = &session.browser_session;
let ttl = Duration::minutes(5); let ttl = Duration::minutes(5);
let (access_token_str, refresh_token_str) = { let access_token_str = TokenType::AccessToken.generate(&mut rng);
let mut rng = thread_rng(); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
(
TokenType::AccessToken.generate(&mut rng),
TokenType::RefreshToken.generate(&mut rng),
)
};
let access_token = add_access_token(&mut txn, session, access_token_str.clone(), ttl).await?; let access_token = add_access_token(
&mut txn,
&mut rng,
&clock,
session,
access_token_str.clone(),
ttl,
)
.await?;
let _refresh_token = let _refresh_token = add_refresh_token(
add_refresh_token(&mut txn, session, access_token, refresh_token_str.clone()).await?; &mut txn,
&mut rng,
&clock,
session,
access_token,
refresh_token_str.clone(),
)
.await?;
let id_token = if session.scope.contains(&scope::OPENID) { let id_token = if session.scope.contains(&scope::OPENID) {
let mut claims = HashMap::new(); let mut claims = HashMap::new();
let now = Utc::now(); let now = clock.now();
claims::ISS.insert(&mut claims, url_builder.oidc_issuer().to_string())?; claims::ISS.insert(&mut claims, url_builder.oidc_issuer().to_string())?;
claims::SUB.insert(&mut claims, &browser_session.user.sub)?; claims::SUB.insert(&mut claims, &browser_session.user.sub)?;
claims::AUD.insert(&mut claims, client.client_id.clone())?; claims::AUD.insert(&mut claims, client.client_id.clone())?;
@ -346,7 +356,7 @@ async fn authorization_code_grant(
let signer = key.params().signing_key_for_alg(&alg)?; let signer = key.params().signing_key_for_alg(&alg)?;
let header = JsonWebSignatureHeader::new(alg) let header = JsonWebSignatureHeader::new(alg)
.with_kid(key.kid().context("key has no `kid` for some reason")?); .with_kid(key.kid().context("key has no `kid` for some reason")?);
let id_token = Jwt::sign(header, claims, &signer)?; let id_token = Jwt::sign_with_rng(&mut rng, header, claims, &signer)?;
Some(id_token.as_str().to_owned()) Some(id_token.as_str().to_owned())
} else { } else {
@ -362,7 +372,7 @@ async fn authorization_code_grant(
params = params.with_id_token(id_token); params = params.with_id_token(id_token);
} }
exchange_grant(&mut txn, authz_grant).await?; exchange_grant(&mut txn, &clock, authz_grant).await?;
txn.commit().await?; txn.commit().await?;
@ -374,6 +384,8 @@ async fn refresh_token_grant(
client: &Client<PostgresqlBackend>, client: &Client<PostgresqlBackend>,
mut txn: Transaction<'_, Postgres>, mut txn: Transaction<'_, Postgres>,
) -> Result<AccessTokenResponse, RouteError> { ) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let (refresh_token, session) = let (refresh_token, session) =
lookup_active_refresh_token(&mut txn, &grant.refresh_token).await?; lookup_active_refresh_token(&mut txn, &grant.refresh_token).await?;
@ -383,24 +395,33 @@ async fn refresh_token_grant(
} }
let ttl = Duration::minutes(5); let ttl = Duration::minutes(5);
let (access_token_str, refresh_token_str) = { let access_token_str = TokenType::AccessToken.generate(&mut rng);
let mut rng = thread_rng(); let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
(
TokenType::AccessToken.generate(&mut rng),
TokenType::RefreshToken.generate(&mut rng),
)
};
let new_access_token = let new_access_token = add_access_token(
add_access_token(&mut txn, &session, access_token_str.clone(), ttl).await?; &mut txn,
&mut rng,
&clock,
&session,
access_token_str.clone(),
ttl,
)
.await?;
let new_refresh_token = let new_refresh_token = add_refresh_token(
add_refresh_token(&mut txn, &session, new_access_token, refresh_token_str).await?; &mut txn,
&mut rng,
&clock,
&session,
new_access_token,
refresh_token_str,
)
.await?;
consume_refresh_token(&mut txn, &refresh_token).await?; consume_refresh_token(&mut txn, &clock, &refresh_token).await?;
if let Some(access_token) = refresh_token.access_token { if let Some(access_token) = refresh_token.access_token {
revoke_access_token(&mut txn, access_token).await?; revoke_access_token(&mut txn, &clock, access_token).await?;
} }
let params = AccessTokenResponse::new(access_token_str) let params = AccessTokenResponse::new(access_token_str)

View File

@ -54,6 +54,7 @@ pub async fn get(
user_authorization: UserAuthorization, user_authorization: UserAuthorization,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
// TODO: error handling // TODO: error handling
let (_clock, mut rng) = crate::rng_and_clock()?;
let mut conn = pool.acquire().await?; let mut conn = pool.acquire().await?;
let session = user_authorization.protected(&mut conn).await?; let session = user_authorization.protected(&mut conn).await?;
@ -88,7 +89,7 @@ pub async fn get(
user_info, user_info,
}; };
let token = Jwt::sign(header, user_info, &signer)?; let token = Jwt::sign_with_rng(&mut rng, header, user_info, &signer)?;
Ok(JwtResponse(token).into_response()) Ok(JwtResponse(token).into_response())
} else { } else {
Ok(Json(user_info).into_response()) Ok(Json(user_info).into_response())

View File

@ -72,6 +72,7 @@ pub(crate) async fn post(
Query(query): Query<OptionalPostAuthAction>, Query(query): Query<OptionalPostAuthAction>,
Form(form): Form<ProtectedForm<EmailForm>>, Form(form): Form<ProtectedForm<EmailForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?; let form = cookie_jar.verify_form(form)?;
@ -86,14 +87,22 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let user_email = add_user_email(&mut txn, &session.user, form.email).await?; let user_email = add_user_email(&mut txn, &mut rng, &clock, &session.user, form.email).await?;
let next = mas_router::AccountVerifyEmail::new(user_email.data); let next = mas_router::AccountVerifyEmail::new(user_email.data);
let next = if let Some(action) = query.post_auth_action { let next = if let Some(action) = query.post_auth_action {
next.and_then(action) next.and_then(action)
} else { } else {
next next
}; };
start_email_verification(&mailer, &mut txn, &session.user, user_email).await?; start_email_verification(
&mailer,
&mut txn,
&mut rng,
&clock,
&session.user,
user_email,
)
.await?;
txn.commit().await?; txn.commit().await?;

View File

@ -32,10 +32,10 @@ use mas_storage::{
add_user_email, add_user_email_verification_code, get_user_email, get_user_emails, add_user_email, add_user_email_verification_code, get_user_email, get_user_emails,
remove_user_email, set_user_email_as_primary, remove_user_email, set_user_email_as_primary,
}, },
PostgresqlBackend, Clock, PostgresqlBackend,
}; };
use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates}; use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates};
use rand::{distributions::Uniform, thread_rng, Rng}; use rand::{distributions::Uniform, Rng};
use serde::Deserialize; use serde::Deserialize;
use sqlx::{PgExecutor, PgPool}; use sqlx::{PgExecutor, PgPool};
use tracing::info; use tracing::info;
@ -93,17 +93,26 @@ async fn render(
async fn start_email_verification( async fn start_email_verification(
mailer: &Mailer, mailer: &Mailer,
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User<PostgresqlBackend>, user: &User<PostgresqlBackend>,
user_email: UserEmail<PostgresqlBackend>, user_email: UserEmail<PostgresqlBackend>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// First, generate a code // First, generate a code
let range = Uniform::<u32>::from(0..1_000_000); let range = Uniform::<u32>::from(0..1_000_000);
let code = thread_rng().sample(range).to_string(); let code = rng.sample(range).to_string();
let address: Address = user_email.email.parse()?; let address: Address = user_email.email.parse()?;
let verification = let verification = add_user_email_verification_code(
add_user_email_verification_code(executor, user_email, Duration::hours(8), code).await?; executor,
&mut rng,
clock,
user_email,
Duration::hours(8),
code,
)
.await?;
// And send the verification email // And send the verification email
let mailbox = Mailbox::new(Some(user.username.clone()), address); let mailbox = Mailbox::new(Some(user.username.clone()), address);
@ -126,6 +135,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ManagementForm>>, Form(form): Form<ProtectedForm<ManagementForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
@ -143,9 +153,18 @@ pub(crate) async fn post(
match form { match form {
ManagementForm::Add { email } => { ManagementForm::Add { email } => {
let user_email = add_user_email(&mut txn, &session.user, email).await?; let user_email =
add_user_email(&mut txn, &mut rng, &clock, &session.user, email).await?;
let next = mas_router::AccountVerifyEmail::new(user_email.data); let next = mas_router::AccountVerifyEmail::new(user_email.data);
start_email_verification(&mailer, &mut txn, &session.user, user_email).await?; start_email_verification(
&mailer,
&mut txn,
&mut rng,
&clock,
&session.user,
user_email,
)
.await?;
txn.commit().await?; txn.commit().await?;
return Ok((cookie_jar, next.go()).into_response()); return Ok((cookie_jar, next.go()).into_response());
} }
@ -154,7 +173,15 @@ pub(crate) async fn post(
let user_email = get_user_email(&mut txn, &session.user, id).await?; let user_email = get_user_email(&mut txn, &session.user, id).await?;
let next = mas_router::AccountVerifyEmail::new(user_email.data); let next = mas_router::AccountVerifyEmail::new(user_email.data);
start_email_verification(&mailer, &mut txn, &session.user, user_email).await?; start_email_verification(
&mailer,
&mut txn,
&mut rng,
&clock,
&session.user,
user_email,
)
.await?;
txn.commit().await?; txn.commit().await?;
return Ok((cookie_jar, next.go()).into_response()); return Ok((cookie_jar, next.go()).into_response());
} }

View File

@ -23,9 +23,12 @@ use mas_axum_utils::{
}; };
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::Route; use mas_router::Route;
use mas_storage::user::{ use mas_storage::{
consume_email_verification, lookup_user_email_by_id, lookup_user_email_verification_code, user::{
mark_user_email_as_verified, set_user_email_as_primary, consume_email_verification, lookup_user_email_by_id, lookup_user_email_verification_code,
mark_user_email_as_verified, set_user_email_as_primary,
},
Clock,
}; };
use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates}; use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
@ -84,6 +87,7 @@ pub(crate) async fn post(
Path(id): Path<Ulid>, Path(id): Path<Ulid>,
Form(form): Form<ProtectedForm<CodeForm>>, Form(form): Form<ProtectedForm<CodeForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let clock = Clock::default();
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?; let form = cookie_jar.verify_form(form)?;
@ -105,12 +109,13 @@ pub(crate) async fn post(
} }
// TODO: make those 8 hours configurable // TODO: make those 8 hours configurable
let verification = lookup_user_email_verification_code(&mut txn, email, &form.code).await?; let verification =
lookup_user_email_verification_code(&mut txn, &clock, email, &form.code).await?;
// TODO: display nice errors if the code was already consumed or expired // TODO: display nice errors if the code was already consumed or expired
let verification = consume_email_verification(&mut txn, verification).await?; let verification = consume_email_verification(&mut txn, &clock, verification).await?;
let _email = mark_user_email_as_verified(&mut txn, verification.email).await?; let _email = mark_user_email_as_verified(&mut txn, &clock, verification.email).await?;
txn.commit().await?; txn.commit().await?;

View File

@ -81,6 +81,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ChangeForm>>, Form(form): Form<ProtectedForm<ChangeForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?; let form = cookie_jar.verify_form(form)?;
@ -96,7 +97,14 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
authenticate_session(&mut txn, &mut session, &form.current_password).await?; authenticate_session(
&mut txn,
&mut rng,
&clock,
&mut session,
&form.current_password,
)
.await?;
// TODO: display nice form errors // TODO: display nice form errors
if form.new_password != form.new_password_confirm { if form.new_password != form.new_password_confirm {
@ -104,7 +112,15 @@ pub(crate) async fn post(
} }
let phf = Argon2::default(); let phf = Argon2::default();
set_password(&mut txn, phf, &session.user, &form.new_password).await?; set_password(
&mut txn,
&mut rng,
&clock,
phf,
&session.user,
&form.new_password,
)
.await?;
let reply = render(templates.clone(), session, cookie_jar).await?; let reply = render(templates.clone(), session, cookie_jar).await?;

View File

@ -80,6 +80,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<LoginForm>>, Form(form): Form<ProtectedForm<LoginForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut conn = pool.acquire().await?; let mut conn = pool.acquire().await?;
let form = cookie_jar.verify_form(form)?; let form = cookie_jar.verify_form(form)?;
@ -114,7 +115,7 @@ pub(crate) async fn post(
return Ok((cookie_jar, Html(content)).into_response()); return Ok((cookie_jar, Html(content)).into_response());
} }
match login(&mut conn, &form.username, &form.password).await { match login(&mut conn, &mut rng, &clock, &form.username, &form.password).await {
Ok(session_info) => { Ok(session_info) => {
let cookie_jar = cookie_jar.set_session(&session_info); let cookie_jar = cookie_jar.set_session(&session_info);
let reply = query.go_next(); let reply = query.go_next();

View File

@ -23,7 +23,7 @@ use mas_axum_utils::{
}; };
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::user::end_session; use mas_storage::{user::end_session, Clock};
use sqlx::PgPool; use sqlx::PgPool;
pub(crate) async fn post( pub(crate) async fn post(
@ -31,6 +31,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<Option<PostAuthAction>>>, Form(form): Form<ProtectedForm<Option<PostAuthAction>>>,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let clock = Clock::default();
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?; let form = cookie_jar.verify_form(form)?;
@ -40,7 +41,7 @@ pub(crate) async fn post(
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut txn).await?;
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
end_session(&mut txn, &session).await?; end_session(&mut txn, &clock, &session).await?;
cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended()); cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended());
} }

View File

@ -80,6 +80,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ReauthForm>>, Form(form): Form<ProtectedForm<ReauthForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?; let form = cookie_jar.verify_form(form)?;
@ -98,7 +99,7 @@ pub(crate) async fn post(
}; };
// TODO: recover from errors here // TODO: recover from errors here
authenticate_session(&mut txn, &mut session, &form.password).await?; authenticate_session(&mut txn, &mut rng, &clock, &mut session, &form.password).await?;
let cookie_jar = cookie_jar.set_session(&session); let cookie_jar = cookie_jar.set_session(&session);
txn.commit().await?; txn.commit().await?;

View File

@ -39,7 +39,7 @@ use mas_templates::{
EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField, EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField,
TemplateContext, Templates, ToFormState, TemplateContext, Templates, ToFormState,
}; };
use rand::{distributions::Uniform, thread_rng, Rng}; use rand::{distributions::Uniform, Rng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{PgConnection, PgPool}; use sqlx::{PgConnection, PgPool};
@ -87,6 +87,7 @@ pub(crate) async fn get(
} }
} }
#[allow(clippy::too_many_lines)]
pub(crate) async fn post( pub(crate) async fn post(
State(mailer): State<Mailer>, State(mailer): State<Mailer>,
State(policy_factory): State<Arc<PolicyFactory>>, State(policy_factory): State<Arc<PolicyFactory>>,
@ -96,6 +97,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<RegisterForm>>, Form(form): Form<ProtectedForm<RegisterForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?; let form = cookie_jar.verify_form(form)?;
@ -180,18 +182,34 @@ pub(crate) async fn post(
} }
let pfh = Argon2::default(); let pfh = Argon2::default();
let user = register_user(&mut txn, pfh, &form.username, &form.password).await?; let user = register_user(
&mut txn,
&mut rng,
&clock,
pfh,
&form.username,
&form.password,
)
.await?;
let user_email = add_user_email(&mut txn, &user, form.email).await?; let user_email = add_user_email(&mut txn, &mut rng, &clock, &user, form.email).await?;
// First, generate a code // First, generate a code
let range = Uniform::<u32>::from(0..1_000_000); let range = Uniform::<u32>::from(0..1_000_000);
let code = thread_rng().sample(range).to_string(); let code = rng.sample(range);
let code = format!("{code:06}");
let address: Address = user_email.email.parse()?; let address: Address = user_email.email.parse()?;
let verification = let verification = add_user_email_verification_code(
add_user_email_verification_code(&mut txn, user_email, Duration::hours(8), code).await?; &mut txn,
&mut rng,
&clock,
user_email,
Duration::hours(8),
code,
)
.await?;
// And send the verification email // And send the verification email
let mailbox = Mailbox::new(Some(user.username.clone()), address); let mailbox = Mailbox::new(Some(user.username.clone()), address);
@ -203,7 +221,7 @@ pub(crate) async fn post(
let next = mas_router::AccountVerifyEmail::new(verification.email.data) let next = mas_router::AccountVerifyEmail::new(verification.email.data)
.and_maybe(query.post_auth_action); .and_maybe(query.post_auth_action);
let session = start_session(&mut txn, user).await?; let session = start_session(&mut txn, &mut rng, &clock, user).await?;
txn.commit().await?; txn.commit().await?;

View File

@ -309,6 +309,7 @@ impl<T> Jwt<'static, T> {
S: Signature, S: Signature,
T: Serialize, T: Serialize,
{ {
#[allow(clippy::disallowed_methods)]
Self::sign_with_rng(thread_rng(), header, payload, key) Self::sign_with_rng(thread_rng(), header, payload, key)
} }
@ -357,6 +358,7 @@ impl<T> Jwt<'static, T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#![allow(clippy::disallowed_methods)]
use mas_iana::jose::JsonWebSignatureAlg; use mas_iana::jose::JsonWebSignatureAlg;
use rand::thread_rng; use rand::thread_rng;

View File

@ -19,6 +19,7 @@ tracing = "0.1.37"
argon2 = { version = "0.4.1", features = ["password-hash"] } argon2 = { version = "0.4.1", features = ["password-hash"] }
password-hash = { version = "0.4.2", features = ["std"] } password-hash = { version = "0.4.2", features = ["std"] }
rand = "0.8.5" rand = "0.8.5"
rand_chacha = "0.3.1"
url = { version = "2.3.1", features = ["serde"] } url = { version = "2.3.1", features = ["serde"] }
uuid = "1.2.1" uuid = "1.2.1"
ulid = { version = "1.0.0", features = ["uuid", "serde"] } ulid = { version = "1.0.0", features = ["uuid", "serde"] }

View File

@ -19,6 +19,7 @@ use mas_data_model::{
CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState, CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState,
Device, User, UserEmail, Device, User, UserEmail,
}; };
use rand::Rng;
use sqlx::{Acquire, PgExecutor, Postgres}; use sqlx::{Acquire, PgExecutor, Postgres};
use thiserror::Error; use thiserror::Error;
use tokio::task; use tokio::task;
@ -27,7 +28,7 @@ use ulid::Ulid;
use url::Url; use url::Url;
use uuid::Uuid; use uuid::Uuid;
use crate::{user::lookup_user_by_username, DatabaseInconsistencyError, PostgresqlBackend}; use crate::{user::lookup_user_by_username, Clock, DatabaseInconsistencyError, PostgresqlBackend};
struct CompatAccessTokenLookup { struct CompatAccessTokenLookup {
compat_access_token_id: Uuid, compat_access_token_id: Uuid,
@ -67,6 +68,7 @@ impl CompatAccessTokenLookupError {
#[tracing::instrument(skip_all, err)] #[tracing::instrument(skip_all, err)]
pub async fn lookup_active_compat_access_token( pub async fn lookup_active_compat_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
token: &str, token: &str,
) -> Result< ) -> Result<
( (
@ -112,7 +114,7 @@ pub async fn lookup_active_compat_access_token(
// Check for token expiration // Check for token expiration
if let Some(expires_at) = res.compat_access_token_expires_at { if let Some(expires_at) = res.compat_access_token_expires_at {
if expires_at < Utc::now() { if expires_at < clock.now() {
return Err(CompatAccessTokenLookupError::Expired { when: expires_at }); return Err(CompatAccessTokenLookupError::Expired { when: expires_at });
} }
} }
@ -311,7 +313,9 @@ pub async fn lookup_active_compat_refresh_token(
err(Display), err(Display),
)] )]
pub async fn compat_login( pub async fn compat_login(
conn: impl Acquire<'_, Database = Postgres>, conn: impl Acquire<'_, Database = Postgres> + Send,
mut rng: impl Rng + Send,
clock: &Clock,
username: &str, username: &str,
password: &str, password: &str,
device: Device, device: Device,
@ -348,8 +352,8 @@ pub async fn compat_login(
.instrument(tracing::info_span!("Verify hashed password")) .instrument(tracing::info_span!("Verify hashed password"))
.await??; .await??;
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_session.id", tracing::field::display(id)); tracing::Span::current().record("compat_session.id", tracing::field::display(id));
sqlx::query!( sqlx::query!(
@ -392,12 +396,14 @@ pub async fn compat_login(
)] )]
pub async fn add_compat_access_token( pub async fn add_compat_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &CompatSession<PostgresqlBackend>, session: &CompatSession<PostgresqlBackend>,
token: String, token: String,
expires_after: Option<Duration>, expires_after: Option<Duration>,
) -> Result<CompatAccessToken<PostgresqlBackend>, anyhow::Error> { ) -> Result<CompatAccessToken<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id)); tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
let expires_at = expires_after.map(|expires_after| created_at + expires_after); let expires_at = expires_after.map(|expires_after| created_at + expires_after);
@ -436,9 +442,10 @@ pub async fn add_compat_access_token(
)] )]
pub async fn expire_compat_access_token( pub async fn expire_compat_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
access_token: CompatAccessToken<PostgresqlBackend>, access_token: CompatAccessToken<PostgresqlBackend>,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let expires_at = Utc::now(); let expires_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
UPDATE compat_access_tokens UPDATE compat_access_tokens
@ -474,12 +481,14 @@ pub async fn expire_compat_access_token(
)] )]
pub async fn add_compat_refresh_token( pub async fn add_compat_refresh_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &CompatSession<PostgresqlBackend>, session: &CompatSession<PostgresqlBackend>,
access_token: &CompatAccessToken<PostgresqlBackend>, access_token: &CompatAccessToken<PostgresqlBackend>,
token: String, token: String,
) -> Result<CompatRefreshToken<PostgresqlBackend>, anyhow::Error> { ) -> Result<CompatRefreshToken<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id)); tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
sqlx::query!( sqlx::query!(
@ -514,9 +523,10 @@ pub async fn add_compat_refresh_token(
)] )]
pub async fn compat_logout( pub async fn compat_logout(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
token: &str, token: &str,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let finished_at = Utc::now(); let finished_at = clock.now();
// TODO: this does not check for token expiration // TODO: this does not check for token expiration
let compat_session_id = sqlx::query_scalar!( let compat_session_id = sqlx::query_scalar!(
r#" r#"
@ -552,9 +562,10 @@ pub async fn compat_logout(
)] )]
pub async fn consume_compat_refresh_token( pub async fn consume_compat_refresh_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
refresh_token: CompatRefreshToken<PostgresqlBackend>, refresh_token: CompatRefreshToken<PostgresqlBackend>,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let consumed_at = Utc::now(); let consumed_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
UPDATE compat_refresh_tokens UPDATE compat_refresh_tokens
@ -587,11 +598,13 @@ pub async fn consume_compat_refresh_token(
)] )]
pub async fn insert_compat_sso_login( pub async fn insert_compat_sso_login(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
login_token: String, login_token: String,
redirect_uri: Url, redirect_uri: Url,
) -> Result<CompatSsoLogin<PostgresqlBackend>, anyhow::Error> { ) -> Result<CompatSsoLogin<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id)); tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
sqlx::query!( sqlx::query!(
@ -845,7 +858,9 @@ pub async fn get_compat_sso_login_by_token(
err(Display), err(Display),
)] )]
pub async fn fullfill_compat_sso_login( pub async fn fullfill_compat_sso_login(
conn: impl Acquire<'_, Database = Postgres>, conn: impl Acquire<'_, Database = Postgres> + Send,
mut rng: impl Rng + Send,
clock: &Clock,
user: User<PostgresqlBackend>, user: User<PostgresqlBackend>,
mut login: CompatSsoLogin<PostgresqlBackend>, mut login: CompatSsoLogin<PostgresqlBackend>,
device: Device, device: Device,
@ -856,8 +871,8 @@ pub async fn fullfill_compat_sso_login(
let mut txn = conn.begin().await.context("could not start transaction")?; let mut txn = conn.begin().await.context("could not start transaction")?;
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user.id", tracing::field::display(user.data)); tracing::Span::current().record("user.id", tracing::field::display(user.data));
sqlx::query!( sqlx::query!(
@ -883,7 +898,7 @@ pub async fn fullfill_compat_sso_login(
finished_at: None, finished_at: None,
}; };
let fulfilled_at = Utc::now(); let fulfilled_at = clock.now();
sqlx::query!( sqlx::query!(
r#" r#"
UPDATE compat_sso_logins UPDATE compat_sso_logins
@ -924,6 +939,7 @@ pub async fn fullfill_compat_sso_login(
)] )]
pub async fn mark_compat_sso_login_as_exchanged( pub async fn mark_compat_sso_login_as_exchanged(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
mut login: CompatSsoLogin<PostgresqlBackend>, mut login: CompatSsoLogin<PostgresqlBackend>,
) -> Result<CompatSsoLogin<PostgresqlBackend>, anyhow::Error> { ) -> Result<CompatSsoLogin<PostgresqlBackend>, anyhow::Error> {
let (fulfilled_at, session) = match login.state { let (fulfilled_at, session) = match login.state {
@ -934,7 +950,7 @@ pub async fn mark_compat_sso_login_as_exchanged(
_ => bail!("sso login in wrong state"), _ => bail!("sso login in wrong state"),
}; };
let exchanged_at = Utc::now(); let exchanged_at = clock.now();
sqlx::query!( sqlx::query!(
r#" r#"
UPDATE compat_sso_logins UPDATE compat_sso_logins

View File

@ -15,7 +15,12 @@
//! Interactions with the database //! Interactions with the database
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)] #![deny(
clippy::all,
clippy::str_to_string,
clippy::future_not_send,
rustdoc::broken_intra_doc_links
)]
#![warn(clippy::pedantic)] #![warn(clippy::pedantic)]
#![allow( #![allow(
clippy::missing_errors_doc, clippy::missing_errors_doc,
@ -23,12 +28,27 @@
clippy::module_name_repetitions clippy::module_name_repetitions
)] )]
use chrono::{DateTime, Utc};
use mas_data_model::{StorageBackend, StorageBackendMarker}; use mas_data_model::{StorageBackend, StorageBackendMarker};
use serde::Serialize; use serde::Serialize;
use sqlx::migrate::Migrator; use sqlx::migrate::Migrator;
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
#[derive(Default, Debug, Clone, Copy)]
pub struct Clock {
_private: (),
}
impl Clock {
#[must_use]
pub fn now(&self) -> DateTime<Utc> {
// This is the clock used elsewhere, it's fine to call Utc::now here
#[allow(clippy::disallowed_methods)]
Utc::now()
}
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
#[error("database query returned an inconsistent state")] #[error("database query returned an inconsistent state")]
pub struct DatabaseInconsistencyError; pub struct DatabaseInconsistencyError;

View File

@ -15,13 +15,14 @@
use anyhow::Context; use anyhow::Context;
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail}; use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
use rand::Rng;
use sqlx::{Acquire, PgExecutor, Postgres}; use sqlx::{Acquire, PgExecutor, Postgres};
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use super::client::{lookup_client, ClientFetchError}; use super::client::{lookup_client, ClientFetchError};
use crate::{DatabaseInconsistencyError, PostgresqlBackend}; use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
@ -35,13 +36,15 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend};
)] )]
pub async fn add_access_token( pub async fn add_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &Session<PostgresqlBackend>, session: &Session<PostgresqlBackend>,
access_token: String, access_token: String,
expires_after: Duration, expires_after: Duration,
) -> Result<AccessToken<PostgresqlBackend>, anyhow::Error> { ) -> Result<AccessToken<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now(); let created_at = clock.now();
let expires_at = created_at + expires_after; let expires_at = created_at + expires_after;
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("access_token.id", tracing::field::display(id)); tracing::Span::current().record("access_token.id", tracing::field::display(id));
@ -243,9 +246,10 @@ where
)] )]
pub async fn revoke_access_token( pub async fn revoke_access_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
access_token: AccessToken<PostgresqlBackend>, access_token: AccessToken<PostgresqlBackend>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let revoked_at = Utc::now(); let revoked_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
UPDATE oauth2_access_tokens UPDATE oauth2_access_tokens
@ -266,9 +270,9 @@ pub async fn revoke_access_token(
} }
} }
pub async fn cleanup_expired(executor: impl PgExecutor<'_>) -> anyhow::Result<u64> { pub async fn cleanup_expired(executor: impl PgExecutor<'_>, clock: &Clock) -> anyhow::Result<u64> {
// Cleanup token which expired more than 15 minutes ago // Cleanup token which expired more than 15 minutes ago
let threshold = Utc::now() - Duration::minutes(15); let threshold = clock.now() - Duration::minutes(15);
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
DELETE FROM oauth2_access_tokens DELETE FROM oauth2_access_tokens

View File

@ -24,13 +24,14 @@ use mas_data_model::{
}; };
use mas_iana::oauth::PkceCodeChallengeMethod; use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{requests::ResponseMode, scope::Scope}; use oauth2_types::{requests::ResponseMode, scope::Scope};
use rand::Rng;
use sqlx::{PgConnection, PgExecutor}; use sqlx::{PgConnection, PgExecutor};
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
use uuid::Uuid; use uuid::Uuid;
use super::client::lookup_client; use super::client::lookup_client;
use crate::{DatabaseInconsistencyError, PostgresqlBackend}; use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
@ -43,6 +44,8 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend};
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn new_authorization_grant( pub async fn new_authorization_grant(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
client: Client<PostgresqlBackend>, client: Client<PostgresqlBackend>,
redirect_uri: Url, redirect_uri: Url,
scope: Scope, scope: Scope,
@ -67,8 +70,8 @@ pub async fn new_authorization_grant(
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)); let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
let code_str = code.as_ref().map(|c| &c.code); let code_str = code.as_ref().map(|c| &c.code);
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("grant.id", tracing::field::display(id)); tracing::Span::current().record("grant.id", tracing::field::display(id));
sqlx::query!( sqlx::query!(
@ -504,11 +507,13 @@ pub async fn lookup_grant_by_code(
)] )]
pub async fn derive_session( pub async fn derive_session(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
grant: &AuthorizationGrant<PostgresqlBackend>, grant: &AuthorizationGrant<PostgresqlBackend>,
browser_session: BrowserSession<PostgresqlBackend>, browser_session: BrowserSession<PostgresqlBackend>,
) -> Result<Session<PostgresqlBackend>, anyhow::Error> { ) -> Result<Session<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("session.id", tracing::field::display(id)); tracing::Span::current().record("session.id", tracing::field::display(id));
sqlx::query!( sqlx::query!(
@ -623,9 +628,10 @@ pub async fn give_consent_to_grant(
)] )]
pub async fn exchange_grant( pub async fn exchange_grant(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
mut grant: AuthorizationGrant<PostgresqlBackend>, mut grant: AuthorizationGrant<PostgresqlBackend>,
) -> Result<AuthorizationGrant<PostgresqlBackend>, anyhow::Error> { ) -> Result<AuthorizationGrant<PostgresqlBackend>, anyhow::Error> {
let exchanged_at = Utc::now(); let exchanged_at = clock.now();
sqlx::query!( sqlx::query!(
r#" r#"
UPDATE oauth2_authorization_grants UPDATE oauth2_authorization_grants

View File

@ -14,14 +14,14 @@
use std::str::FromStr; use std::str::FromStr;
use chrono::Utc;
use mas_data_model::{Client, User}; use mas_data_model::{Client, User};
use oauth2_types::scope::{Scope, ScopeToken}; use oauth2_types::scope::{Scope, ScopeToken};
use rand::Rng;
use sqlx::PgExecutor; use sqlx::PgExecutor;
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::PostgresqlBackend; use crate::{Clock, PostgresqlBackend};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
@ -67,17 +67,19 @@ pub async fn fetch_client_consent(
)] )]
pub async fn insert_client_consent( pub async fn insert_client_consent(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User<PostgresqlBackend>, user: &User<PostgresqlBackend>,
client: &Client<PostgresqlBackend>, client: &Client<PostgresqlBackend>,
scope: &Scope, scope: &Scope,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let now = Utc::now(); let now = clock.now();
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
.iter() .iter()
.map(|token| { .map(|token| {
( (
token.to_string(), token.to_string(),
Uuid::from(Ulid::from_datetime(now.into())), Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
) )
}) })
.unzip(); .unzip();

View File

@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use chrono::Utc;
use mas_data_model::Session; use mas_data_model::Session;
use sqlx::PgExecutor; use sqlx::PgExecutor;
use uuid::Uuid; use uuid::Uuid;
use crate::PostgresqlBackend; use crate::{Clock, PostgresqlBackend};
pub mod access_token; pub mod access_token;
pub mod authorization_grant; pub mod authorization_grant;
@ -37,9 +36,10 @@ pub mod refresh_token;
)] )]
pub async fn end_oauth_session( pub async fn end_oauth_session(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
session: Session<PostgresqlBackend>, session: Session<PostgresqlBackend>,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let finished_at = Utc::now(); let finished_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
UPDATE oauth2_sessions UPDATE oauth2_sessions

View File

@ -17,13 +17,14 @@ use chrono::{DateTime, Utc};
use mas_data_model::{ use mas_data_model::{
AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail, AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail,
}; };
use rand::Rng;
use sqlx::{PgConnection, PgExecutor}; use sqlx::{PgConnection, PgExecutor};
use thiserror::Error; use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use super::client::{lookup_client, ClientFetchError}; use super::client::{lookup_client, ClientFetchError};
use crate::{DatabaseInconsistencyError, PostgresqlBackend}; use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend};
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
@ -38,12 +39,14 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend};
)] )]
pub async fn add_refresh_token( pub async fn add_refresh_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &Session<PostgresqlBackend>, session: &Session<PostgresqlBackend>,
access_token: AccessToken<PostgresqlBackend>, access_token: AccessToken<PostgresqlBackend>,
refresh_token: String, refresh_token: String,
) -> anyhow::Result<RefreshToken<PostgresqlBackend>> { ) -> anyhow::Result<RefreshToken<PostgresqlBackend>> {
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("refresh_token.id", tracing::field::display(id)); tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
sqlx::query!( sqlx::query!(
@ -263,9 +266,10 @@ pub async fn lookup_active_refresh_token(
)] )]
pub async fn consume_refresh_token( pub async fn consume_refresh_token(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
refresh_token: &RefreshToken<PostgresqlBackend>, refresh_token: &RefreshToken<PostgresqlBackend>,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let consumed_at = Utc::now(); let consumed_at = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
UPDATE oauth2_refresh_tokens UPDATE oauth2_refresh_tokens

View File

@ -22,7 +22,7 @@ use mas_data_model::{
UserEmailVerificationState, UserEmailVerificationState,
}; };
use password_hash::{PasswordHash, PasswordHasher, SaltString}; use password_hash::{PasswordHash, PasswordHasher, SaltString};
use rand::thread_rng; use rand::{CryptoRng, Rng};
use sqlx::{Acquire, PgExecutor, Postgres, Transaction}; use sqlx::{Acquire, PgExecutor, Postgres, Transaction};
use thiserror::Error; use thiserror::Error;
use tokio::task; use tokio::task;
@ -31,6 +31,7 @@ use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use super::{DatabaseInconsistencyError, PostgresqlBackend}; use super::{DatabaseInconsistencyError, PostgresqlBackend};
use crate::Clock;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct UserLookup { struct UserLookup {
@ -68,7 +69,9 @@ pub enum LoginError {
err, err,
)] )]
pub async fn login( pub async fn login(
conn: impl Acquire<'_, Database = Postgres>, conn: impl Acquire<'_, Database = Postgres> + Send,
mut rng: impl Rng + Send,
clock: &Clock,
username: &str, username: &str,
password: &str, password: &str,
) -> Result<BrowserSession<PostgresqlBackend>, LoginError> { ) -> Result<BrowserSession<PostgresqlBackend>, LoginError> {
@ -86,8 +89,8 @@ pub async fn login(
} }
})?; })?;
let mut session = start_session(&mut txn, user).await?; let mut session = start_session(&mut txn, &mut rng, clock, user).await?;
authenticate_session(&mut txn, &mut session, password) authenticate_session(&mut txn, &mut rng, clock, &mut session, password)
.await .await
.map_err(|source| { .map_err(|source| {
if matches!(source, AuthenticationError::Password { .. }) { if matches!(source, AuthenticationError::Password { .. }) {
@ -230,10 +233,12 @@ pub async fn lookup_active_session(
)] )]
pub async fn start_session( pub async fn start_session(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: User<PostgresqlBackend>, user: User<PostgresqlBackend>,
) -> Result<BrowserSession<PostgresqlBackend>, anyhow::Error> { ) -> Result<BrowserSession<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_session.id", tracing::field::display(id)); tracing::Span::current().record("user_session.id", tracing::field::display(id));
sqlx::query!( sqlx::query!(
@ -301,13 +306,16 @@ pub enum AuthenticationError {
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields( fields(
session.id = %session.data, user.id = %session.user.data,
user.id = %session.user.data user_session.id = %session.data,
user_session_authentication.id,
), ),
err, err,
)] )]
pub async fn authenticate_session( pub async fn authenticate_session(
txn: &mut Transaction<'_, Postgres>, txn: &mut Transaction<'_, Postgres>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &mut BrowserSession<PostgresqlBackend>, session: &mut BrowserSession<PostgresqlBackend>,
password: &str, password: &str,
) -> Result<(), AuthenticationError> { ) -> Result<(), AuthenticationError> {
@ -341,8 +349,13 @@ pub async fn authenticate_session(
.await??; .await??;
// That went well, let's insert the auth info // That went well, let's insert the auth info
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record(
"user_session_authentication.id",
tracing::field::display(id),
);
sqlx::query!( sqlx::query!(
r#" r#"
INSERT INTO user_session_authentications INSERT INTO user_session_authentications
@ -376,12 +389,14 @@ pub async fn authenticate_session(
)] )]
pub async fn register_user( pub async fn register_user(
txn: &mut Transaction<'_, Postgres>, txn: &mut Transaction<'_, Postgres>,
phf: impl PasswordHasher, mut rng: impl CryptoRng + Rng + Send,
clock: &Clock,
phf: impl PasswordHasher + Send,
username: &str, username: &str,
password: &str, password: &str,
) -> Result<User<PostgresqlBackend>, anyhow::Error> { ) -> Result<User<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user.id", tracing::field::display(id)); tracing::Span::current().record("user.id", tracing::field::display(id));
sqlx::query!( sqlx::query!(
@ -405,7 +420,7 @@ pub async fn register_user(
primary_email: None, primary_email: None,
}; };
set_password(txn.borrow_mut(), phf, &user, password).await?; set_password(txn.borrow_mut(), &mut rng, clock, phf, &user, password).await?;
Ok(user) Ok(user)
} }
@ -420,15 +435,17 @@ pub async fn register_user(
)] )]
pub async fn set_password( pub async fn set_password(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
phf: impl PasswordHasher, mut rng: impl CryptoRng + Rng + Send,
clock: &Clock,
phf: impl PasswordHasher + Send,
user: &User<PostgresqlBackend>, user: &User<PostgresqlBackend>,
password: &str, password: &str,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime(created_at.into());
tracing::Span::current().record("user_password.id", tracing::field::display(id)); tracing::Span::current().record("user_password.id", tracing::field::display(id));
let salt = SaltString::generate(thread_rng()); let salt = SaltString::generate(&mut rng);
let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?; let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?;
sqlx::query_scalar!( sqlx::query_scalar!(
@ -456,9 +473,10 @@ pub async fn set_password(
)] )]
pub async fn end_session( pub async fn end_session(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
session: &BrowserSession<PostgresqlBackend>, session: &BrowserSession<PostgresqlBackend>,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let now = Utc::now(); let now = clock.now();
let res = sqlx::query!( let res = sqlx::query!(
r#" r#"
UPDATE user_sessions UPDATE user_sessions
@ -672,11 +690,13 @@ pub async fn get_user_email(
)] )]
pub async fn add_user_email( pub async fn add_user_email(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User<PostgresqlBackend>, user: &User<PostgresqlBackend>,
email: String, email: String,
) -> Result<UserEmail<PostgresqlBackend>, anyhow::Error> { ) -> Result<UserEmail<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_email.id", tracing::field::display(id)); tracing::Span::current().record("user_email.id", tracing::field::display(id));
sqlx::query!( sqlx::query!(
@ -842,9 +862,10 @@ pub async fn lookup_user_email_by_id(
)] )]
pub async fn mark_user_email_as_verified( pub async fn mark_user_email_as_verified(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
mut email: UserEmail<PostgresqlBackend>, mut email: UserEmail<PostgresqlBackend>,
) -> Result<UserEmail<PostgresqlBackend>, anyhow::Error> { ) -> Result<UserEmail<PostgresqlBackend>, anyhow::Error> {
let confirmed_at = Utc::now(); let confirmed_at = clock.now();
sqlx::query!( sqlx::query!(
r#" r#"
UPDATE user_emails UPDATE user_emails
@ -881,10 +902,11 @@ struct UserEmailConfirmationCodeLookup {
)] )]
pub async fn lookup_user_email_verification_code( pub async fn lookup_user_email_verification_code(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
email: UserEmail<PostgresqlBackend>, email: UserEmail<PostgresqlBackend>,
code: &str, code: &str,
) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> { ) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> {
let now = Utc::now(); let now = clock.now();
let res = sqlx::query_as!( let res = sqlx::query_as!(
UserEmailConfirmationCodeLookup, UserEmailConfirmationCodeLookup,
@ -935,13 +957,14 @@ pub async fn lookup_user_email_verification_code(
)] )]
pub async fn consume_email_verification( pub async fn consume_email_verification(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
clock: &Clock,
mut verification: UserEmailVerification<PostgresqlBackend>, mut verification: UserEmailVerification<PostgresqlBackend>,
) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> { ) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> {
if !matches!(verification.state, UserEmailVerificationState::Valid) { if !matches!(verification.state, UserEmailVerificationState::Valid) {
bail!("user email verification in wrong state"); bail!("user email verification in wrong state");
} }
let consumed_at = Utc::now(); let consumed_at = clock.now();
sqlx::query!( sqlx::query!(
r#" r#"
@ -974,12 +997,14 @@ pub async fn consume_email_verification(
)] )]
pub async fn add_user_email_verification_code( pub async fn add_user_email_verification_code(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
email: UserEmail<PostgresqlBackend>, email: UserEmail<PostgresqlBackend>,
max_age: chrono::Duration, max_age: chrono::Duration,
code: String, code: String,
) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> { ) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now(); let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into()); let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id)); tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id));
let expires_at = created_at + max_age; let expires_at = created_at + max_age;
@ -1013,23 +1038,27 @@ pub async fn add_user_email_verification_code(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use rand::SeedableRng;
use super::*; use super::*;
#[sqlx::test(migrator = "crate::MIGRATOR")] #[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_user_registration_and_login(pool: sqlx::PgPool) -> anyhow::Result<()> { async fn test_user_registration_and_login(pool: sqlx::PgPool) -> anyhow::Result<()> {
let clock = Clock::default();
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let exists = username_exists(&mut txn, "john").await?; let exists = username_exists(&mut txn, "john").await?;
assert!(!exists); assert!(!exists);
let hasher = Argon2::default(); let hasher = Argon2::default();
let user = register_user(&mut txn, hasher, "john", "hunter2").await?; let user = register_user(&mut txn, &mut rng, &clock, hasher, "john", "hunter2").await?;
assert_eq!(user.username, "john"); assert_eq!(user.username, "john");
let exists = username_exists(&mut txn, "john").await?; let exists = username_exists(&mut txn, "john").await?;
assert!(exists); assert!(exists);
let session = login(&mut txn, "john", "hunter2").await?; let session = login(&mut txn, &mut rng, &clock, "john", "hunter2").await?;
assert_eq!(session.user.data, user.data); assert_eq!(session.user.data, user.data);
let user2 = lookup_user_by_username(&mut txn, "john").await?; let user2 = lookup_user_by_username(&mut txn, "john").await?;

View File

@ -14,13 +14,14 @@
//! Database-related tasks //! Database-related tasks
use mas_storage::Clock;
use sqlx::{Pool, Postgres}; use sqlx::{Pool, Postgres};
use tracing::{debug, error, info}; use tracing::{debug, error, info};
use super::Task; use super::Task;
#[derive(Clone)] #[derive(Clone)]
struct CleanupExpired(Pool<Postgres>); struct CleanupExpired(Pool<Postgres>, Clock);
impl std::fmt::Debug for CleanupExpired { impl std::fmt::Debug for CleanupExpired {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@ -31,7 +32,7 @@ impl std::fmt::Debug for CleanupExpired {
#[async_trait::async_trait] #[async_trait::async_trait]
impl Task for CleanupExpired { impl Task for CleanupExpired {
async fn run(&self) { async fn run(&self) {
let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0).await; let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0, &self.1).await;
match res { match res {
Ok(0) => { Ok(0) => {
debug!("no token to clean up"); debug!("no token to clean up");
@ -49,5 +50,6 @@ impl Task for CleanupExpired {
/// Cleanup expired tokens /// Cleanup expired tokens
#[must_use] #[must_use]
pub fn cleanup_expired(pool: &Pool<Postgres>) -> impl Task + Clone { pub fn cleanup_expired(pool: &Pool<Postgres>) -> impl Task + Clone {
CleanupExpired(pool.clone()) // XXX: the clock should come from somewhere else
CleanupExpired(pool.clone(), Clock::default())
} }