1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Pass the rng and clock around

This commit is contained in:
Quentin Gliech
2022-10-21 18:50:06 +02:00
parent 5c7e66a9b2
commit 559181c2c3
40 changed files with 504 additions and 218 deletions

8
Cargo.lock generated
View File

@ -2468,6 +2468,8 @@ dependencies = [
"opentelemetry-semantic-conventions",
"opentelemetry-zipkin",
"prometheus",
"rand",
"rand_chacha",
"rustls",
"serde_json",
"serde_yaml",
@ -2497,6 +2499,7 @@ dependencies = [
"mas-keystore",
"pem-rfc7468",
"rand",
"rand_chacha",
"rustls-pemfile",
"schemars",
"serde",
@ -2567,6 +2570,7 @@ dependencies = [
"mime",
"oauth2-types",
"rand",
"rand_chacha",
"serde",
"serde_json",
"serde_urlencoded",
@ -2779,6 +2783,7 @@ dependencies = [
"oauth2-types",
"password-hash",
"rand",
"rand_chacha",
"serde",
"serde_json",
"sqlx",
@ -3068,7 +3073,7 @@ checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860"
[[package]]
name = "opa-wasm"
version = "0.1.0"
source = "git+https://github.com/matrix-org/rust-opa-wasm.git#325071ee8a2a7d18cc611365edd3235945a8cdf4"
source = "git+https://github.com/matrix-org/rust-opa-wasm.git#f838595670747b0644b6bfd9829fca5d63bbee66"
dependencies = [
"anyhow",
"base64",
@ -3079,6 +3084,7 @@ dependencies = [
"md-5",
"parse-size",
"rand",
"rayon-core",
"semver",
"serde",
"serde_json",

View File

@ -1,2 +1,11 @@
msrv = "1.61.0"
doc-valid-idents = ["OpenID", "OAuth", ".."]
disallowed-methods = [
{ path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" },
{ path = "chrono::Utc::now", reason = "source the current time from the clock instead" },
]
disallowed-types = [
"rand::OsRng",
]

View File

@ -15,6 +15,7 @@
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use chrono::{DateTime, Duration, Utc};
use data_encoding::{DecodeError, BASE64URL_NOPAD};
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds};
use thiserror::Error;
@ -56,20 +57,20 @@ pub struct CsrfToken {
impl CsrfToken {
/// Create a new token from a defined value valid for a specified duration
fn new(token: [u8; 32], ttl: Duration) -> Self {
let expiration = Utc::now() + ttl;
fn new(token: [u8; 32], now: DateTime<Utc>, ttl: Duration) -> Self {
let expiration = now + ttl;
Self { expiration, token }
}
/// Generate a new random token valid for a specified duration
fn generate(ttl: Duration) -> Self {
let token = rand::random();
Self::new(token, ttl)
fn generate(now: DateTime<Utc>, mut rng: impl Rng, ttl: Duration) -> Self {
let token = rng.gen();
Self::new(token, now, ttl)
}
/// Generate a new token with the same value but an up to date expiration
fn refresh(self, ttl: Duration) -> Self {
Self::new(self.token, ttl)
fn refresh(self, now: DateTime<Utc>, ttl: Duration) -> Self {
Self::new(self.token, now, ttl)
}
/// Get the value to include in HTML forms
@ -88,8 +89,8 @@ impl CsrfToken {
}
}
fn verify_expiration(self) -> Result<Self, CsrfError> {
if Utc::now() < self.expiration {
fn verify_expiration(self, now: DateTime<Utc>) -> Result<Self, CsrfError> {
if now < self.expiration {
Ok(self)
} else {
Err(CsrfError::Expired)
@ -118,12 +119,18 @@ impl<K> CsrfExt for PrivateCookieJar<K> {
cookie.set_path("/");
cookie.set_http_only(true);
// XXX: the rng source and clock should come from somewhere else
#[allow(clippy::disallowed_methods)]
let now = Utc::now();
#[allow(clippy::disallowed_methods)]
let rng = thread_rng();
let new_token = cookie
.decode()
.ok()
.and_then(|token: CsrfToken| token.verify_expiration().ok())
.unwrap_or_else(|| CsrfToken::generate(Duration::hours(1)))
.refresh(Duration::hours(1));
.and_then(|token: CsrfToken| token.verify_expiration(now).ok())
.unwrap_or_else(|| CsrfToken::generate(now, rng, Duration::hours(1)))
.refresh(now, Duration::hours(1));
let cookie = cookie.encode(&new_token);
let jar = jar.add(cookie);
@ -131,9 +138,13 @@ impl<K> CsrfExt for PrivateCookieJar<K> {
}
fn verify_form<T>(&self, form: ProtectedForm<T>) -> Result<T, CsrfError> {
// XXX: the clock should come from somewhere else
#[allow(clippy::disallowed_methods)]
let now = Utc::now();
let cookie = self.get("csrf").ok_or(CsrfError::Missing)?;
let token: CsrfToken = cookie.decode()?;
let token = token.verify_expiration()?;
let token = token.verify_expiration(now)?;
token.verify_form_value(&form.csrf)?;
Ok(form.inner)
}

View File

@ -6,23 +6,25 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
axum = "0.6.0-rc.2"
tokio = { version = "1.21.2", features = ["full"] }
futures-util = "0.3.25"
anyhow = "1.0.66"
argon2 = { version = "0.4.1", features = ["password-hash"] }
atty = "0.2.14"
axum = "0.6.0-rc.2"
clap = { version = "4.0.18", features = ["derive"] }
dotenv = "0.15.0"
tower = { version = "0.4.13", features = ["full"] }
futures-util = "0.3.25"
hyper = { version = "0.14.22", features = ["full"] }
serde_yaml = "0.9.14"
serde_json = "1.0.87"
url = "2.3.1"
argon2 = { version = "0.4.1", features = ["password-hash"] }
watchman_client = "0.8.0"
atty = "0.2.14"
listenfd = "1.0.0"
rustls = "0.20.7"
itertools = "0.10.5"
listenfd = "1.0.0"
rand = "0.8.5"
rand_chacha = "0.3.1"
rustls = "0.20.7"
serde_json = "1.0.87"
serde_yaml = "0.9.14"
tokio = { version = "1.21.2", features = ["full"] }
tower = { version = "0.4.13", features = ["full"] }
url = "2.3.1"
watchman_client = "0.8.0"
tracing = "0.1.37"
tracing-appender = "0.2.2"

View File

@ -20,7 +20,9 @@ use mas_storage::{
user::{
lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user,
},
Clock,
};
use rand::SeedableRng;
use tracing::{info, warn};
#[derive(Parser, Debug)]
@ -51,14 +53,17 @@ enum Subcommand {
impl Options {
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
use Subcommand as SC;
let clock = Clock::default();
match &self.subcommand {
SC::Register { username, password } => {
let config: DatabaseConfig = root.load_config()?;
let pool = config.connect().await?;
let mut txn = pool.begin().await?;
let hasher = Argon2::default();
let rng = rand_chacha::ChaChaRng::from_entropy();
let user = register_user(&mut txn, hasher, username, password).await?;
let user = register_user(&mut txn, rng, &clock, hasher, username, password).await?;
txn.commit().await?;
info!(?user, "User registered");
@ -76,7 +81,7 @@ impl Options {
let user = lookup_user_by_username(&mut txn, username).await?;
let email = lookup_user_email(&mut txn, &user, email).await?;
let email = mark_user_email_as_verified(&mut txn, email).await?;
let email = mark_user_email_as_verified(&mut txn, &clock, email).await?;
txn.commit().await?;
info!(?email, "Email marked as verified");

View File

@ -28,6 +28,7 @@ lettre = { version = "0.10.1", default-features = false, features = ["serde", "b
pem-rfc7468 = "0.6.0"
rustls-pemfile = "1.0.1"
rand = "0.8.5"
rand_chacha = "0.3.1"
indoc = "1.0.7"

View File

@ -20,7 +20,7 @@ use mas_jose::jwk::{JsonWebKey, JsonWebKeySet};
use mas_keystore::{Encrypter, Keystore, PrivateKey};
use rand::{
distributions::{Alphanumeric, DistString},
thread_rng,
thread_rng, SeedableRng,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -139,64 +139,72 @@ impl ConfigurationSection<'_> for SecretsConfig {
#[tracing::instrument]
async fn generate() -> anyhow::Result<Self> {
// XXX: that RNG should come from somewhere else
#[allow(clippy::disallowed_methods)]
let mut rng = rand_chacha::ChaChaRng::from_rng(thread_rng())?;
info!("Generating keys...");
let span = tracing::info_span!("rsa");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let rsa_key = task::spawn_blocking(move || {
let _entered = span.enter();
let ret = PrivateKey::generate_rsa(thread_rng()).unwrap();
let ret = PrivateKey::generate_rsa(key_rng).unwrap();
info!("Done generating RSA key");
ret
})
.await
.context("could not join blocking task")?;
let rsa_key = KeyConfig {
kid: Alphanumeric.sample_string(&mut thread_rng(), 10),
kid: Alphanumeric.sample_string(&mut rng, 10),
password: None,
key: KeyOrFile::Key(rsa_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
};
let span = tracing::info_span!("ec_p256");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let ec_p256_key = task::spawn_blocking(move || {
let _entered = span.enter();
let ret = PrivateKey::generate_ec_p256(thread_rng());
let ret = PrivateKey::generate_ec_p256(key_rng);
info!("Done generating EC P-256 key");
ret
})
.await
.context("could not join blocking task")?;
let ec_p256_key = KeyConfig {
kid: Alphanumeric.sample_string(&mut thread_rng(), 10),
kid: Alphanumeric.sample_string(&mut rng, 10),
password: None,
key: KeyOrFile::Key(ec_p256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
};
let span = tracing::info_span!("ec_p384");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let ec_p384_key = task::spawn_blocking(move || {
let _entered = span.enter();
let ret = PrivateKey::generate_ec_p384(thread_rng());
let ret = PrivateKey::generate_ec_p384(key_rng);
info!("Done generating EC P-256 key");
ret
})
.await
.context("could not join blocking task")?;
let ec_p384_key = KeyConfig {
kid: Alphanumeric.sample_string(&mut thread_rng(), 10),
kid: Alphanumeric.sample_string(&mut rng, 10),
password: None,
key: KeyOrFile::Key(ec_p384_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
};
let span = tracing::info_span!("ec_k256");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let ec_k256_key = task::spawn_blocking(move || {
let _entered = span.enter();
let ret = PrivateKey::generate_ec_k256(thread_rng());
let ret = PrivateKey::generate_ec_k256(key_rng);
info!("Done generating EC secp256k1 key");
ret
})
.await
.context("could not join blocking task")?;
let ec_k256_key = KeyConfig {
kid: Alphanumeric.sample_string(&mut thread_rng(), 10),
kid: Alphanumeric.sample_string(&mut rng, 10),
password: None,
key: KeyOrFile::Key(ec_k256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
};

View File

@ -263,6 +263,8 @@ mod tests {
#[test]
fn test_generate_and_check() {
const COUNT: usize = 500; // Generate 500 of each token type
#[allow(clippy::disallowed_methods)]
let mut rng = thread_rng();
for t in [

View File

@ -44,6 +44,7 @@ chrono = { version = "0.4.22", features = ["serde"] }
url = { version = "2.3.1", features = ["serde"] }
mime = "0.3.16"
rand = "0.8.5"
rand_chacha = "0.3.1"
headers = "0.3.8"
ulid = "1.0.0"

View File

@ -13,7 +13,7 @@
// limitations under the License.
use axum::{extract::State, response::IntoResponse, Json};
use chrono::{Duration, Utc};
use chrono::Duration;
use hyper::StatusCode;
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType};
use mas_storage::{
@ -22,9 +22,8 @@ use mas_storage::{
get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged,
CompatSsoLoginLookupError,
},
PostgresqlBackend,
Clock, PostgresqlBackend,
};
use rand::thread_rng;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
use sqlx::{PgPool, Postgres, Transaction};
@ -201,6 +200,7 @@ pub(crate) async fn post(
State(homeserver): State<MatrixHomeserver>,
Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?;
let session = match input.credentials {
Credentials::Password {
@ -208,7 +208,7 @@ pub(crate) async fn post(
password,
} => user_password_login(&mut txn, user, password).await?,
Credentials::Token { token } => token_login(&mut txn, &token).await?,
Credentials::Token { token } => token_login(&mut txn, &clock, &token).await?,
_ => {
return Err(RouteError::Unsupported);
@ -225,14 +225,28 @@ pub(crate) async fn post(
None
};
let access_token = TokenType::CompatAccessToken.generate(&mut thread_rng());
let access_token =
add_compat_access_token(&mut txn, &session, access_token, expires_in).await?;
let access_token = TokenType::CompatAccessToken.generate(&mut rng);
let access_token = add_compat_access_token(
&mut txn,
&mut rng,
&clock,
&session,
access_token,
expires_in,
)
.await?;
let refresh_token = if input.refresh_token {
let refresh_token = TokenType::CompatRefreshToken.generate(&mut thread_rng());
let refresh_token =
add_compat_refresh_token(&mut txn, &session, &access_token, refresh_token).await?;
let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng);
let refresh_token = add_compat_refresh_token(
&mut txn,
&mut rng,
&clock,
&session,
&access_token,
refresh_token,
)
.await?;
Some(refresh_token.token)
} else {
None
@ -251,11 +265,12 @@ pub(crate) async fn post(
async fn token_login(
txn: &mut Transaction<'_, Postgres>,
clock: &Clock,
token: &str,
) -> Result<CompatSession<PostgresqlBackend>, RouteError> {
let login = get_compat_sso_login_by_token(&mut *txn, token).await?;
let now = Utc::now();
let now = clock.now();
match login.state {
CompatSsoLoginState::Pending => {
tracing::error!(
@ -285,7 +300,7 @@ async fn token_login(
}
}
let login = mark_compat_sso_login_as_exchanged(&mut *txn, login).await?;
let login = mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?;
match login.state {
CompatSsoLoginState::Exchanged { session, .. } => Ok(session),
@ -298,8 +313,10 @@ async fn user_password_login(
username: String,
password: String,
) -> Result<CompatSession<PostgresqlBackend>, RouteError> {
let device = Device::generate(&mut thread_rng());
let session = compat_login(txn, &username, &password, device)
let (clock, mut rng) = crate::rng_and_clock()?;
let device = Device::generate(&mut rng);
let session = compat_login(txn, &mut rng, &clock, &username, &password, device)
.await
.map_err(|_| RouteError::LoginFailed)?;

View File

@ -20,7 +20,7 @@ use axum::{
response::{Html, IntoResponse, Redirect, Response},
};
use axum_extra::extract::PrivateCookieJar;
use chrono::{Duration, Utc};
use chrono::Duration;
use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm},
FancyError, SessionInfoExt,
@ -28,9 +28,11 @@ use mas_axum_utils::{
use mas_data_model::Device;
use mas_keystore::Encrypter;
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id};
use mas_storage::{
compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id},
Clock,
};
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use rand::thread_rng;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use ulid::Ulid;
@ -56,6 +58,7 @@ pub async fn get(
Path(id): Path<Ulid>,
Query(params): Query<Params>,
) -> Result<Response, FancyError> {
let clock = Clock::default();
let mut conn = pool.acquire().await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
@ -95,7 +98,7 @@ pub async fn get(
let login = get_compat_sso_login_by_id(&mut conn, id).await?;
// Bail out if that login session is more than 30min old
if Utc::now() > login.created_at + Duration::minutes(30) {
if clock.now() > login.created_at + Duration::minutes(30) {
let ctx = ErrorContext::new()
.with_code("compat_sso_login_expired")
.with_description("This login session expired.".to_owned());
@ -121,6 +124,7 @@ pub async fn post(
Query(params): Query<Params>,
Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
@ -160,7 +164,7 @@ pub async fn post(
let login = get_compat_sso_login_by_id(&mut txn, id).await?;
// Bail out if that login session is more than 30min old
if Utc::now() > login.created_at + Duration::minutes(30) {
if clock.now() > login.created_at + Duration::minutes(30) {
let ctx = ErrorContext::new()
.with_code("compat_sso_login_expired")
.with_description("This login session expired.".to_owned());
@ -186,8 +190,9 @@ pub async fn post(
redirect_uri
};
let device = Device::generate(&mut thread_rng());
let _login = fullfill_compat_sso_login(&mut txn, session.user, login, device).await?;
let device = Device::generate(&mut rng);
let _login =
fullfill_compat_sso_login(&mut txn, &mut rng, &clock, session.user, login, device).await?;
txn.commit().await?;

View File

@ -20,10 +20,7 @@ use axum::{
use hyper::StatusCode;
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::compat::insert_compat_sso_login;
use rand::{
distributions::{Alphanumeric, DistString},
thread_rng,
};
use rand::distributions::{Alphanumeric, DistString};
use serde::Deserialize;
use serde_with::serde;
use sqlx::PgPool;
@ -70,6 +67,8 @@ pub async fn get(
State(url_builder): State<UrlBuilder>,
Query(params): Query<Params>,
) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
// Check the redirectUrl parameter
let redirect_url = params.redirect_url.ok_or(RouteError::MissingRedirectUrl)?;
let redirect_url = Url::parse(&redirect_url).map_err(|_| RouteError::InvalidRedirectUrl)?;
@ -84,9 +83,9 @@ pub async fn get(
return Err(RouteError::InvalidRedirectUrl);
}
let token = Alphanumeric.sample_string(&mut thread_rng(), 32);
let token = Alphanumeric.sample_string(&mut rng, 32);
let mut conn = pool.acquire().await?;
let login = insert_compat_sso_login(&mut conn, token, redirect_url).await?;
let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?;
Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.data, params.action)))
}

View File

@ -16,7 +16,7 @@ use axum::{extract::State, response::IntoResponse, Json, TypedHeader};
use headers::{authorization::Bearer, Authorization};
use hyper::StatusCode;
use mas_data_model::{TokenFormatError, TokenType};
use mas_storage::compat::compat_logout;
use mas_storage::{compat::compat_logout, Clock};
use sqlx::PgPool;
use super::MatrixError;
@ -67,6 +67,7 @@ pub(crate) async fn post(
State(pool): State<PgPool>,
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
) -> Result<impl IntoResponse, RouteError> {
let clock = Clock::default();
let mut conn = pool.acquire().await?;
let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?;
@ -78,7 +79,7 @@ pub(crate) async fn post(
return Err(RouteError::InvalidAuthorization);
}
compat_logout(&mut conn, token)
compat_logout(&mut conn, &clock, token)
.await
.map_err(|_| RouteError::LogoutFailed)?;

View File

@ -20,7 +20,6 @@ use mas_storage::compat::{
add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token,
expire_compat_access_token, lookup_active_compat_refresh_token, CompatRefreshTokenLookupError,
};
use rand::thread_rng;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DurationMilliSeconds};
use sqlx::PgPool;
@ -98,6 +97,7 @@ pub(crate) async fn post(
State(pool): State<PgPool>,
Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?;
let token_type = TokenType::check(&input.refresh_token)?;
@ -109,23 +109,31 @@ pub(crate) async fn post(
let (refresh_token, access_token, session) =
lookup_active_compat_refresh_token(&mut txn, &input.refresh_token).await?;
let (new_refresh_token_str, new_access_token_str) = {
let mut rng = thread_rng();
(
TokenType::CompatRefreshToken.generate(&mut rng),
TokenType::CompatAccessToken.generate(&mut rng),
)
};
let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng);
let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
let expires_in = Duration::minutes(5);
let new_access_token =
add_compat_access_token(&mut txn, &session, new_access_token_str, Some(expires_in)).await?;
let new_refresh_token =
add_compat_refresh_token(&mut txn, &session, &new_access_token, new_refresh_token_str)
let new_access_token = add_compat_access_token(
&mut txn,
&mut rng,
&clock,
&session,
new_access_token_str,
Some(expires_in),
)
.await?;
let new_refresh_token = add_compat_refresh_token(
&mut txn,
&mut rng,
&clock,
&session,
&new_access_token,
new_refresh_token_str,
)
.await?;
consume_compat_refresh_token(&mut txn, refresh_token).await?;
expire_compat_access_token(&mut txn, access_token).await?;
consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?;
expire_compat_access_token(&mut txn, &clock, access_token).await?;
txn.commit().await?;

View File

@ -21,6 +21,7 @@
use std::{convert::Infallible, sync::Arc, time::Duration};
use anyhow::Context;
use axum::{
body::HttpBody,
extract::FromRef,
@ -36,6 +37,7 @@ use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_router::{Route, UrlBuilder};
use mas_templates::{ErrorContext, Templates};
use rand::SeedableRng;
use sqlx::PgPool;
use tower::util::AndThenLayer;
use tower_http::cors::{Any, CorsLayer};
@ -356,3 +358,15 @@ async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
policy_factory,
}))
}
// XXX: that should be moved somewhere else
fn rng_and_clock() -> Result<(mas_storage::Clock, rand_chacha::ChaChaRng), anyhow::Error> {
let clock = mas_storage::Clock::default();
// This rng is used to source the local rng
#[allow(clippy::disallowed_methods)]
let rng = rand::thread_rng();
let rng = rand_chacha::ChaChaRng::from_rng(rng).context("Failed to seed RNG")?;
Ok((clock, rng))
}

View File

@ -190,6 +190,8 @@ pub(crate) async fn complete(
policy_factory: &PolicyFactory,
mut txn: Transaction<'_, Postgres>,
) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> {
let (clock, mut rng) = crate::rng_and_clock()?;
// Verify that the grant is in a pending stage
if !grant.stage.is_pending() {
return Err(GrantCompletionError::NotPending);
@ -226,7 +228,7 @@ pub(crate) async fn complete(
}
// All good, let's start the session
let session = derive_session(&mut txn, &grant, browser_session).await?;
let session = derive_session(&mut txn, &mut rng, &clock, &grant, browser_session).await?;
let grant = fulfill_grant(&mut txn, grant, session.clone()).await?;

View File

@ -37,7 +37,7 @@ use oauth2_types::{
requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode},
response_type::ResponseType,
};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use rand::{distributions::Alphanumeric, Rng};
use serde::Deserialize;
use sqlx::PgPool;
use thiserror::Error;
@ -159,6 +159,7 @@ pub(crate) async fn get(
cookie_jar: PrivateCookieJar<Encrypter>,
Form(params): Form<Params>,
) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?;
// First, figure out what client it is
@ -265,7 +266,7 @@ pub(crate) async fn get(
}
// 32 random alphanumeric characters, about 190bit of entropy
let code: String = thread_rng()
let code: String = (&mut rng)
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
@ -296,6 +297,8 @@ pub(crate) async fn get(
let grant = new_authorization_grant(
&mut txn,
&mut rng,
&clock,
client,
redirect_uri.clone(),
params.auth.scope,

View File

@ -119,6 +119,7 @@ pub(crate) async fn post(
Path(grant_id): Path<Ulid>,
Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool
.begin()
.await
@ -163,6 +164,8 @@ pub(crate) async fn post(
.collect();
insert_client_consent(
&mut txn,
&mut rng,
&clock,
&session.user,
&grant.client,
&scope_without_device,

View File

@ -28,6 +28,7 @@ use mas_storage::{
client::ClientFetchError,
refresh_token::{lookup_active_refresh_token, RefreshTokenLookupError},
},
Clock,
};
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
use sqlx::PgPool;
@ -158,6 +159,7 @@ pub(crate) async fn post(
State(encrypter): State<Encrypter>,
client_authorization: ClientAuthorization<IntrospectionRequest>,
) -> Result<impl IntoResponse, RouteError> {
let clock = Clock::default();
let mut conn = pool.acquire().await?;
let client = client_authorization.credentials.fetch(&mut conn).await?;
@ -227,7 +229,8 @@ pub(crate) async fn post(
}
}
TokenType::CompatAccessToken => {
let (token, session) = lookup_active_compat_access_token(&mut conn, token).await?;
let (token, session) =
lookup_active_compat_access_token(&mut conn, &clock, token).await?;
let device_scope = session.device.to_scope_token();
let scope = [device_scope].into_iter().collect();

View File

@ -50,7 +50,6 @@ use oauth2_types::{
},
scope,
};
use rand::thread_rng;
use serde::Serialize;
use serde_with::{serde_as, skip_serializing_none};
use sqlx::{PgPool, Postgres, Transaction};
@ -235,12 +234,13 @@ async fn authorization_code_grant(
url_builder: &UrlBuilder,
mut txn: Transaction<'_, Postgres>,
) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
// TODO: there is a bunch of unnecessary cloning here
// TODO: handle "not found" cases
let authz_grant = lookup_grant_by_code(&mut txn, &grant.code).await?;
// TODO: that's not a timestamp from the DB. Let's assume they are in sync
let now = Utc::now();
let now = clock.now();
let session = match authz_grant.stage {
AuthorizationGrantStage::Cancelled { cancelled_at } => {
@ -257,7 +257,7 @@ async fn authorization_code_grant(
// Ending the session if the token was already exchanged more than 20s ago
if now - exchanged_at > Duration::seconds(20) {
debug!("Ending potentially compromised session");
end_oauth_session(&mut txn, session).await?;
end_oauth_session(&mut txn, &clock, session).await?;
txn.commit().await?;
}
@ -303,22 +303,32 @@ async fn authorization_code_grant(
let browser_session = &session.browser_session;
let ttl = Duration::minutes(5);
let (access_token_str, refresh_token_str) = {
let mut rng = thread_rng();
(
TokenType::AccessToken.generate(&mut rng),
TokenType::RefreshToken.generate(&mut rng),
let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let access_token = add_access_token(
&mut txn,
&mut rng,
&clock,
session,
access_token_str.clone(),
ttl,
)
};
.await?;
let access_token = add_access_token(&mut txn, session, access_token_str.clone(), ttl).await?;
let _refresh_token =
add_refresh_token(&mut txn, session, access_token, refresh_token_str.clone()).await?;
let _refresh_token = add_refresh_token(
&mut txn,
&mut rng,
&clock,
session,
access_token,
refresh_token_str.clone(),
)
.await?;
let id_token = if session.scope.contains(&scope::OPENID) {
let mut claims = HashMap::new();
let now = Utc::now();
let now = clock.now();
claims::ISS.insert(&mut claims, url_builder.oidc_issuer().to_string())?;
claims::SUB.insert(&mut claims, &browser_session.user.sub)?;
claims::AUD.insert(&mut claims, client.client_id.clone())?;
@ -346,7 +356,7 @@ async fn authorization_code_grant(
let signer = key.params().signing_key_for_alg(&alg)?;
let header = JsonWebSignatureHeader::new(alg)
.with_kid(key.kid().context("key has no `kid` for some reason")?);
let id_token = Jwt::sign(header, claims, &signer)?;
let id_token = Jwt::sign_with_rng(&mut rng, header, claims, &signer)?;
Some(id_token.as_str().to_owned())
} else {
@ -362,7 +372,7 @@ async fn authorization_code_grant(
params = params.with_id_token(id_token);
}
exchange_grant(&mut txn, authz_grant).await?;
exchange_grant(&mut txn, &clock, authz_grant).await?;
txn.commit().await?;
@ -374,6 +384,8 @@ async fn refresh_token_grant(
client: &Client<PostgresqlBackend>,
mut txn: Transaction<'_, Postgres>,
) -> Result<AccessTokenResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let (refresh_token, session) =
lookup_active_refresh_token(&mut txn, &grant.refresh_token).await?;
@ -383,24 +395,33 @@ async fn refresh_token_grant(
}
let ttl = Duration::minutes(5);
let (access_token_str, refresh_token_str) = {
let mut rng = thread_rng();
(
TokenType::AccessToken.generate(&mut rng),
TokenType::RefreshToken.generate(&mut rng),
let access_token_str = TokenType::AccessToken.generate(&mut rng);
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
let new_access_token = add_access_token(
&mut txn,
&mut rng,
&clock,
&session,
access_token_str.clone(),
ttl,
)
};
.await?;
let new_access_token =
add_access_token(&mut txn, &session, access_token_str.clone(), ttl).await?;
let new_refresh_token = add_refresh_token(
&mut txn,
&mut rng,
&clock,
&session,
new_access_token,
refresh_token_str,
)
.await?;
let new_refresh_token =
add_refresh_token(&mut txn, &session, new_access_token, refresh_token_str).await?;
consume_refresh_token(&mut txn, &refresh_token).await?;
consume_refresh_token(&mut txn, &clock, &refresh_token).await?;
if let Some(access_token) = refresh_token.access_token {
revoke_access_token(&mut txn, access_token).await?;
revoke_access_token(&mut txn, &clock, access_token).await?;
}
let params = AccessTokenResponse::new(access_token_str)

View File

@ -54,6 +54,7 @@ pub async fn get(
user_authorization: UserAuthorization,
) -> Result<Response, FancyError> {
// TODO: error handling
let (_clock, mut rng) = crate::rng_and_clock()?;
let mut conn = pool.acquire().await?;
let session = user_authorization.protected(&mut conn).await?;
@ -88,7 +89,7 @@ pub async fn get(
user_info,
};
let token = Jwt::sign(header, user_info, &signer)?;
let token = Jwt::sign_with_rng(&mut rng, header, user_info, &signer)?;
Ok(JwtResponse(token).into_response())
} else {
Ok(Json(user_info).into_response())

View File

@ -72,6 +72,7 @@ pub(crate) async fn post(
Query(query): Query<OptionalPostAuthAction>,
Form(form): Form<ProtectedForm<EmailForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?;
@ -86,14 +87,22 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response());
};
let user_email = add_user_email(&mut txn, &session.user, form.email).await?;
let user_email = add_user_email(&mut txn, &mut rng, &clock, &session.user, form.email).await?;
let next = mas_router::AccountVerifyEmail::new(user_email.data);
let next = if let Some(action) = query.post_auth_action {
next.and_then(action)
} else {
next
};
start_email_verification(&mailer, &mut txn, &session.user, user_email).await?;
start_email_verification(
&mailer,
&mut txn,
&mut rng,
&clock,
&session.user,
user_email,
)
.await?;
txn.commit().await?;

View File

@ -32,10 +32,10 @@ use mas_storage::{
add_user_email, add_user_email_verification_code, get_user_email, get_user_emails,
remove_user_email, set_user_email_as_primary,
},
PostgresqlBackend,
Clock, PostgresqlBackend,
};
use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates};
use rand::{distributions::Uniform, thread_rng, Rng};
use rand::{distributions::Uniform, Rng};
use serde::Deserialize;
use sqlx::{PgExecutor, PgPool};
use tracing::info;
@ -93,17 +93,26 @@ async fn render(
async fn start_email_verification(
mailer: &Mailer,
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User<PostgresqlBackend>,
user_email: UserEmail<PostgresqlBackend>,
) -> anyhow::Result<()> {
// First, generate a code
let range = Uniform::<u32>::from(0..1_000_000);
let code = thread_rng().sample(range).to_string();
let code = rng.sample(range).to_string();
let address: Address = user_email.email.parse()?;
let verification =
add_user_email_verification_code(executor, user_email, Duration::hours(8), code).await?;
let verification = add_user_email_verification_code(
executor,
&mut rng,
clock,
user_email,
Duration::hours(8),
code,
)
.await?;
// And send the verification email
let mailbox = Mailbox::new(Some(user.username.clone()), address);
@ -126,6 +135,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ManagementForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
@ -143,9 +153,18 @@ pub(crate) async fn post(
match form {
ManagementForm::Add { email } => {
let user_email = add_user_email(&mut txn, &session.user, email).await?;
let user_email =
add_user_email(&mut txn, &mut rng, &clock, &session.user, email).await?;
let next = mas_router::AccountVerifyEmail::new(user_email.data);
start_email_verification(&mailer, &mut txn, &session.user, user_email).await?;
start_email_verification(
&mailer,
&mut txn,
&mut rng,
&clock,
&session.user,
user_email,
)
.await?;
txn.commit().await?;
return Ok((cookie_jar, next.go()).into_response());
}
@ -154,7 +173,15 @@ pub(crate) async fn post(
let user_email = get_user_email(&mut txn, &session.user, id).await?;
let next = mas_router::AccountVerifyEmail::new(user_email.data);
start_email_verification(&mailer, &mut txn, &session.user, user_email).await?;
start_email_verification(
&mailer,
&mut txn,
&mut rng,
&clock,
&session.user,
user_email,
)
.await?;
txn.commit().await?;
return Ok((cookie_jar, next.go()).into_response());
}

View File

@ -23,9 +23,12 @@ use mas_axum_utils::{
};
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::user::{
use mas_storage::{
user::{
consume_email_verification, lookup_user_email_by_id, lookup_user_email_verification_code,
mark_user_email_as_verified, set_user_email_as_primary,
},
Clock,
};
use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates};
use serde::Deserialize;
@ -84,6 +87,7 @@ pub(crate) async fn post(
Path(id): Path<Ulid>,
Form(form): Form<ProtectedForm<CodeForm>>,
) -> Result<Response, FancyError> {
let clock = Clock::default();
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?;
@ -105,12 +109,13 @@ pub(crate) async fn post(
}
// TODO: make those 8 hours configurable
let verification = lookup_user_email_verification_code(&mut txn, email, &form.code).await?;
let verification =
lookup_user_email_verification_code(&mut txn, &clock, email, &form.code).await?;
// TODO: display nice errors if the code was already consumed or expired
let verification = consume_email_verification(&mut txn, verification).await?;
let verification = consume_email_verification(&mut txn, &clock, verification).await?;
let _email = mark_user_email_as_verified(&mut txn, verification.email).await?;
let _email = mark_user_email_as_verified(&mut txn, &clock, verification.email).await?;
txn.commit().await?;

View File

@ -81,6 +81,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ChangeForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?;
@ -96,7 +97,14 @@ pub(crate) async fn post(
return Ok((cookie_jar, login.go()).into_response());
};
authenticate_session(&mut txn, &mut session, &form.current_password).await?;
authenticate_session(
&mut txn,
&mut rng,
&clock,
&mut session,
&form.current_password,
)
.await?;
// TODO: display nice form errors
if form.new_password != form.new_password_confirm {
@ -104,7 +112,15 @@ pub(crate) async fn post(
}
let phf = Argon2::default();
set_password(&mut txn, phf, &session.user, &form.new_password).await?;
set_password(
&mut txn,
&mut rng,
&clock,
phf,
&session.user,
&form.new_password,
)
.await?;
let reply = render(templates.clone(), session, cookie_jar).await?;

View File

@ -80,6 +80,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<LoginForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut conn = pool.acquire().await?;
let form = cookie_jar.verify_form(form)?;
@ -114,7 +115,7 @@ pub(crate) async fn post(
return Ok((cookie_jar, Html(content)).into_response());
}
match login(&mut conn, &form.username, &form.password).await {
match login(&mut conn, &mut rng, &clock, &form.username, &form.password).await {
Ok(session_info) => {
let cookie_jar = cookie_jar.set_session(&session_info);
let reply = query.go_next();

View File

@ -23,7 +23,7 @@ use mas_axum_utils::{
};
use mas_keystore::Encrypter;
use mas_router::{PostAuthAction, Route};
use mas_storage::user::end_session;
use mas_storage::{user::end_session, Clock};
use sqlx::PgPool;
pub(crate) async fn post(
@ -31,6 +31,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<Option<PostAuthAction>>>,
) -> Result<impl IntoResponse, FancyError> {
let clock = Clock::default();
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?;
@ -40,7 +41,7 @@ pub(crate) async fn post(
let maybe_session = session_info.load_session(&mut txn).await?;
if let Some(session) = maybe_session {
end_session(&mut txn, &session).await?;
end_session(&mut txn, &clock, &session).await?;
cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended());
}

View File

@ -80,6 +80,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ReauthForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?;
@ -98,7 +99,7 @@ pub(crate) async fn post(
};
// TODO: recover from errors here
authenticate_session(&mut txn, &mut session, &form.password).await?;
authenticate_session(&mut txn, &mut rng, &clock, &mut session, &form.password).await?;
let cookie_jar = cookie_jar.set_session(&session);
txn.commit().await?;

View File

@ -39,7 +39,7 @@ use mas_templates::{
EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField,
TemplateContext, Templates, ToFormState,
};
use rand::{distributions::Uniform, thread_rng, Rng};
use rand::{distributions::Uniform, Rng};
use serde::{Deserialize, Serialize};
use sqlx::{PgConnection, PgPool};
@ -87,6 +87,7 @@ pub(crate) async fn get(
}
}
#[allow(clippy::too_many_lines)]
pub(crate) async fn post(
State(mailer): State<Mailer>,
State(policy_factory): State<Arc<PolicyFactory>>,
@ -96,6 +97,7 @@ pub(crate) async fn post(
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<RegisterForm>>,
) -> Result<Response, FancyError> {
let (clock, mut rng) = crate::rng_and_clock()?;
let mut txn = pool.begin().await?;
let form = cookie_jar.verify_form(form)?;
@ -180,18 +182,34 @@ pub(crate) async fn post(
}
let pfh = Argon2::default();
let user = register_user(&mut txn, pfh, &form.username, &form.password).await?;
let user = register_user(
&mut txn,
&mut rng,
&clock,
pfh,
&form.username,
&form.password,
)
.await?;
let user_email = add_user_email(&mut txn, &user, form.email).await?;
let user_email = add_user_email(&mut txn, &mut rng, &clock, &user, form.email).await?;
// First, generate a code
let range = Uniform::<u32>::from(0..1_000_000);
let code = thread_rng().sample(range).to_string();
let code = rng.sample(range);
let code = format!("{code:06}");
let address: Address = user_email.email.parse()?;
let verification =
add_user_email_verification_code(&mut txn, user_email, Duration::hours(8), code).await?;
let verification = add_user_email_verification_code(
&mut txn,
&mut rng,
&clock,
user_email,
Duration::hours(8),
code,
)
.await?;
// And send the verification email
let mailbox = Mailbox::new(Some(user.username.clone()), address);
@ -203,7 +221,7 @@ pub(crate) async fn post(
let next = mas_router::AccountVerifyEmail::new(verification.email.data)
.and_maybe(query.post_auth_action);
let session = start_session(&mut txn, user).await?;
let session = start_session(&mut txn, &mut rng, &clock, user).await?;
txn.commit().await?;

View File

@ -309,6 +309,7 @@ impl<T> Jwt<'static, T> {
S: Signature,
T: Serialize,
{
#[allow(clippy::disallowed_methods)]
Self::sign_with_rng(thread_rng(), header, payload, key)
}
@ -357,6 +358,7 @@ impl<T> Jwt<'static, T> {
#[cfg(test)]
mod tests {
#![allow(clippy::disallowed_methods)]
use mas_iana::jose::JsonWebSignatureAlg;
use rand::thread_rng;

View File

@ -19,6 +19,7 @@ tracing = "0.1.37"
argon2 = { version = "0.4.1", features = ["password-hash"] }
password-hash = { version = "0.4.2", features = ["std"] }
rand = "0.8.5"
rand_chacha = "0.3.1"
url = { version = "2.3.1", features = ["serde"] }
uuid = "1.2.1"
ulid = { version = "1.0.0", features = ["uuid", "serde"] }

View File

@ -19,6 +19,7 @@ use mas_data_model::{
CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState,
Device, User, UserEmail,
};
use rand::Rng;
use sqlx::{Acquire, PgExecutor, Postgres};
use thiserror::Error;
use tokio::task;
@ -27,7 +28,7 @@ use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{user::lookup_user_by_username, DatabaseInconsistencyError, PostgresqlBackend};
use crate::{user::lookup_user_by_username, Clock, DatabaseInconsistencyError, PostgresqlBackend};
struct CompatAccessTokenLookup {
compat_access_token_id: Uuid,
@ -67,6 +68,7 @@ impl CompatAccessTokenLookupError {
#[tracing::instrument(skip_all, err)]
pub async fn lookup_active_compat_access_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
token: &str,
) -> Result<
(
@ -112,7 +114,7 @@ pub async fn lookup_active_compat_access_token(
// Check for token expiration
if let Some(expires_at) = res.compat_access_token_expires_at {
if expires_at < Utc::now() {
if expires_at < clock.now() {
return Err(CompatAccessTokenLookupError::Expired { when: expires_at });
}
}
@ -311,7 +313,9 @@ pub async fn lookup_active_compat_refresh_token(
err(Display),
)]
pub async fn compat_login(
conn: impl Acquire<'_, Database = Postgres>,
conn: impl Acquire<'_, Database = Postgres> + Send,
mut rng: impl Rng + Send,
clock: &Clock,
username: &str,
password: &str,
device: Device,
@ -348,8 +352,8 @@ pub async fn compat_login(
.instrument(tracing::info_span!("Verify hashed password"))
.await??;
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
sqlx::query!(
@ -392,12 +396,14 @@ pub async fn compat_login(
)]
pub async fn add_compat_access_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &CompatSession<PostgresqlBackend>,
token: String,
expires_after: Option<Duration>,
) -> Result<CompatAccessToken<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
@ -436,9 +442,10 @@ pub async fn add_compat_access_token(
)]
pub async fn expire_compat_access_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
access_token: CompatAccessToken<PostgresqlBackend>,
) -> Result<(), anyhow::Error> {
let expires_at = Utc::now();
let expires_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_access_tokens
@ -474,12 +481,14 @@ pub async fn expire_compat_access_token(
)]
pub async fn add_compat_refresh_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &CompatSession<PostgresqlBackend>,
access_token: &CompatAccessToken<PostgresqlBackend>,
token: String,
) -> Result<CompatRefreshToken<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
sqlx::query!(
@ -514,9 +523,10 @@ pub async fn add_compat_refresh_token(
)]
pub async fn compat_logout(
executor: impl PgExecutor<'_>,
clock: &Clock,
token: &str,
) -> Result<(), anyhow::Error> {
let finished_at = Utc::now();
let finished_at = clock.now();
// TODO: this does not check for token expiration
let compat_session_id = sqlx::query_scalar!(
r#"
@ -552,9 +562,10 @@ pub async fn compat_logout(
)]
pub async fn consume_compat_refresh_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
refresh_token: CompatRefreshToken<PostgresqlBackend>,
) -> Result<(), anyhow::Error> {
let consumed_at = Utc::now();
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE compat_refresh_tokens
@ -587,11 +598,13 @@ pub async fn consume_compat_refresh_token(
)]
pub async fn insert_compat_sso_login(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
login_token: String,
redirect_uri: Url,
) -> Result<CompatSsoLogin<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
sqlx::query!(
@ -845,7 +858,9 @@ pub async fn get_compat_sso_login_by_token(
err(Display),
)]
pub async fn fullfill_compat_sso_login(
conn: impl Acquire<'_, Database = Postgres>,
conn: impl Acquire<'_, Database = Postgres> + Send,
mut rng: impl Rng + Send,
clock: &Clock,
user: User<PostgresqlBackend>,
mut login: CompatSsoLogin<PostgresqlBackend>,
device: Device,
@ -856,8 +871,8 @@ pub async fn fullfill_compat_sso_login(
let mut txn = conn.begin().await.context("could not start transaction")?;
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user.id", tracing::field::display(user.data));
sqlx::query!(
@ -883,7 +898,7 @@ pub async fn fullfill_compat_sso_login(
finished_at: None,
};
let fulfilled_at = Utc::now();
let fulfilled_at = clock.now();
sqlx::query!(
r#"
UPDATE compat_sso_logins
@ -924,6 +939,7 @@ pub async fn fullfill_compat_sso_login(
)]
pub async fn mark_compat_sso_login_as_exchanged(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut login: CompatSsoLogin<PostgresqlBackend>,
) -> Result<CompatSsoLogin<PostgresqlBackend>, anyhow::Error> {
let (fulfilled_at, session) = match login.state {
@ -934,7 +950,7 @@ pub async fn mark_compat_sso_login_as_exchanged(
_ => bail!("sso login in wrong state"),
};
let exchanged_at = Utc::now();
let exchanged_at = clock.now();
sqlx::query!(
r#"
UPDATE compat_sso_logins

View File

@ -15,7 +15,12 @@
//! Interactions with the database
#![forbid(unsafe_code)]
#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)]
#![deny(
clippy::all,
clippy::str_to_string,
clippy::future_not_send,
rustdoc::broken_intra_doc_links
)]
#![warn(clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
@ -23,12 +28,27 @@
clippy::module_name_repetitions
)]
use chrono::{DateTime, Utc};
use mas_data_model::{StorageBackend, StorageBackendMarker};
use serde::Serialize;
use sqlx::migrate::Migrator;
use thiserror::Error;
use ulid::Ulid;
#[derive(Default, Debug, Clone, Copy)]
pub struct Clock {
_private: (),
}
impl Clock {
#[must_use]
pub fn now(&self) -> DateTime<Utc> {
// This is the clock used elsewhere, it's fine to call Utc::now here
#[allow(clippy::disallowed_methods)]
Utc::now()
}
}
#[derive(Debug, Error)]
#[error("database query returned an inconsistent state")]
pub struct DatabaseInconsistencyError;

View File

@ -15,13 +15,14 @@
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
use rand::Rng;
use sqlx::{Acquire, PgExecutor, Postgres};
use thiserror::Error;
use ulid::Ulid;
use uuid::Uuid;
use super::client::{lookup_client, ClientFetchError};
use crate::{DatabaseInconsistencyError, PostgresqlBackend};
use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend};
#[tracing::instrument(
skip_all,
@ -35,13 +36,15 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend};
)]
pub async fn add_access_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &Session<PostgresqlBackend>,
access_token: String,
expires_after: Duration,
) -> Result<AccessToken<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now();
let created_at = clock.now();
let expires_at = created_at + expires_after;
let id = Ulid::from_datetime(created_at.into());
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("access_token.id", tracing::field::display(id));
@ -243,9 +246,10 @@ where
)]
pub async fn revoke_access_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
access_token: AccessToken<PostgresqlBackend>,
) -> anyhow::Result<()> {
let revoked_at = Utc::now();
let revoked_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_access_tokens
@ -266,9 +270,9 @@ pub async fn revoke_access_token(
}
}
pub async fn cleanup_expired(executor: impl PgExecutor<'_>) -> anyhow::Result<u64> {
pub async fn cleanup_expired(executor: impl PgExecutor<'_>, clock: &Clock) -> anyhow::Result<u64> {
// Cleanup token which expired more than 15 minutes ago
let threshold = Utc::now() - Duration::minutes(15);
let threshold = clock.now() - Duration::minutes(15);
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens

View File

@ -24,13 +24,14 @@ use mas_data_model::{
};
use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{requests::ResponseMode, scope::Scope};
use rand::Rng;
use sqlx::{PgConnection, PgExecutor};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use super::client::lookup_client;
use crate::{DatabaseInconsistencyError, PostgresqlBackend};
use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend};
#[tracing::instrument(
skip_all,
@ -43,6 +44,8 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend};
#[allow(clippy::too_many_arguments)]
pub async fn new_authorization_grant(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
client: Client<PostgresqlBackend>,
redirect_uri: Url,
scope: Scope,
@ -67,8 +70,8 @@ pub async fn new_authorization_grant(
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
let code_str = code.as_ref().map(|c| &c.code);
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("grant.id", tracing::field::display(id));
sqlx::query!(
@ -504,11 +507,13 @@ pub async fn lookup_grant_by_code(
)]
pub async fn derive_session(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
grant: &AuthorizationGrant<PostgresqlBackend>,
browser_session: BrowserSession<PostgresqlBackend>,
) -> Result<Session<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("session.id", tracing::field::display(id));
sqlx::query!(
@ -623,9 +628,10 @@ pub async fn give_consent_to_grant(
)]
pub async fn exchange_grant(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut grant: AuthorizationGrant<PostgresqlBackend>,
) -> Result<AuthorizationGrant<PostgresqlBackend>, anyhow::Error> {
let exchanged_at = Utc::now();
let exchanged_at = clock.now();
sqlx::query!(
r#"
UPDATE oauth2_authorization_grants

View File

@ -14,14 +14,14 @@
use std::str::FromStr;
use chrono::Utc;
use mas_data_model::{Client, User};
use oauth2_types::scope::{Scope, ScopeToken};
use rand::Rng;
use sqlx::PgExecutor;
use ulid::Ulid;
use uuid::Uuid;
use crate::PostgresqlBackend;
use crate::{Clock, PostgresqlBackend};
#[tracing::instrument(
skip_all,
@ -67,17 +67,19 @@ pub async fn fetch_client_consent(
)]
pub async fn insert_client_consent(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User<PostgresqlBackend>,
client: &Client<PostgresqlBackend>,
scope: &Scope,
) -> Result<(), anyhow::Error> {
let now = Utc::now();
let now = clock.now();
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
.iter()
.map(|token| {
(
token.to_string(),
Uuid::from(Ulid::from_datetime(now.into())),
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
)
})
.unzip();

View File

@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::Utc;
use mas_data_model::Session;
use sqlx::PgExecutor;
use uuid::Uuid;
use crate::PostgresqlBackend;
use crate::{Clock, PostgresqlBackend};
pub mod access_token;
pub mod authorization_grant;
@ -37,9 +36,10 @@ pub mod refresh_token;
)]
pub async fn end_oauth_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
session: Session<PostgresqlBackend>,
) -> Result<(), anyhow::Error> {
let finished_at = Utc::now();
let finished_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_sessions

View File

@ -17,13 +17,14 @@ use chrono::{DateTime, Utc};
use mas_data_model::{
AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail,
};
use rand::Rng;
use sqlx::{PgConnection, PgExecutor};
use thiserror::Error;
use ulid::Ulid;
use uuid::Uuid;
use super::client::{lookup_client, ClientFetchError};
use crate::{DatabaseInconsistencyError, PostgresqlBackend};
use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend};
#[tracing::instrument(
skip_all,
@ -38,12 +39,14 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend};
)]
pub async fn add_refresh_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &Session<PostgresqlBackend>,
access_token: AccessToken<PostgresqlBackend>,
refresh_token: String,
) -> anyhow::Result<RefreshToken<PostgresqlBackend>> {
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
sqlx::query!(
@ -263,9 +266,10 @@ pub async fn lookup_active_refresh_token(
)]
pub async fn consume_refresh_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
refresh_token: &RefreshToken<PostgresqlBackend>,
) -> Result<(), anyhow::Error> {
let consumed_at = Utc::now();
let consumed_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens

View File

@ -22,7 +22,7 @@ use mas_data_model::{
UserEmailVerificationState,
};
use password_hash::{PasswordHash, PasswordHasher, SaltString};
use rand::thread_rng;
use rand::{CryptoRng, Rng};
use sqlx::{Acquire, PgExecutor, Postgres, Transaction};
use thiserror::Error;
use tokio::task;
@ -31,6 +31,7 @@ use ulid::Ulid;
use uuid::Uuid;
use super::{DatabaseInconsistencyError, PostgresqlBackend};
use crate::Clock;
#[derive(Debug, Clone)]
struct UserLookup {
@ -68,7 +69,9 @@ pub enum LoginError {
err,
)]
pub async fn login(
conn: impl Acquire<'_, Database = Postgres>,
conn: impl Acquire<'_, Database = Postgres> + Send,
mut rng: impl Rng + Send,
clock: &Clock,
username: &str,
password: &str,
) -> Result<BrowserSession<PostgresqlBackend>, LoginError> {
@ -86,8 +89,8 @@ pub async fn login(
}
})?;
let mut session = start_session(&mut txn, user).await?;
authenticate_session(&mut txn, &mut session, password)
let mut session = start_session(&mut txn, &mut rng, clock, user).await?;
authenticate_session(&mut txn, &mut rng, clock, &mut session, password)
.await
.map_err(|source| {
if matches!(source, AuthenticationError::Password { .. }) {
@ -230,10 +233,12 @@ pub async fn lookup_active_session(
)]
pub async fn start_session(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: User<PostgresqlBackend>,
) -> Result<BrowserSession<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_session.id", tracing::field::display(id));
sqlx::query!(
@ -301,13 +306,16 @@ pub enum AuthenticationError {
#[tracing::instrument(
skip_all,
fields(
session.id = %session.data,
user.id = %session.user.data
user.id = %session.user.data,
user_session.id = %session.data,
user_session_authentication.id,
),
err,
)]
pub async fn authenticate_session(
txn: &mut Transaction<'_, Postgres>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &mut BrowserSession<PostgresqlBackend>,
password: &str,
) -> Result<(), AuthenticationError> {
@ -341,8 +349,13 @@ pub async fn authenticate_session(
.await??;
// That went well, let's insert the auth info
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record(
"user_session_authentication.id",
tracing::field::display(id),
);
sqlx::query!(
r#"
INSERT INTO user_session_authentications
@ -376,12 +389,14 @@ pub async fn authenticate_session(
)]
pub async fn register_user(
txn: &mut Transaction<'_, Postgres>,
phf: impl PasswordHasher,
mut rng: impl CryptoRng + Rng + Send,
clock: &Clock,
phf: impl PasswordHasher + Send,
username: &str,
password: &str,
) -> Result<User<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user.id", tracing::field::display(id));
sqlx::query!(
@ -405,7 +420,7 @@ pub async fn register_user(
primary_email: None,
};
set_password(txn.borrow_mut(), phf, &user, password).await?;
set_password(txn.borrow_mut(), &mut rng, clock, phf, &user, password).await?;
Ok(user)
}
@ -420,15 +435,17 @@ pub async fn register_user(
)]
pub async fn set_password(
executor: impl PgExecutor<'_>,
phf: impl PasswordHasher,
mut rng: impl CryptoRng + Rng + Send,
clock: &Clock,
phf: impl PasswordHasher + Send,
user: &User<PostgresqlBackend>,
password: &str,
) -> Result<(), anyhow::Error> {
let created_at = Utc::now();
let created_at = clock.now();
let id = Ulid::from_datetime(created_at.into());
tracing::Span::current().record("user_password.id", tracing::field::display(id));
let salt = SaltString::generate(thread_rng());
let salt = SaltString::generate(&mut rng);
let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?;
sqlx::query_scalar!(
@ -456,9 +473,10 @@ pub async fn set_password(
)]
pub async fn end_session(
executor: impl PgExecutor<'_>,
clock: &Clock,
session: &BrowserSession<PostgresqlBackend>,
) -> Result<(), anyhow::Error> {
let now = Utc::now();
let now = clock.now();
let res = sqlx::query!(
r#"
UPDATE user_sessions
@ -672,11 +690,13 @@ pub async fn get_user_email(
)]
pub async fn add_user_email(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
user: &User<PostgresqlBackend>,
email: String,
) -> Result<UserEmail<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_email.id", tracing::field::display(id));
sqlx::query!(
@ -842,9 +862,10 @@ pub async fn lookup_user_email_by_id(
)]
pub async fn mark_user_email_as_verified(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut email: UserEmail<PostgresqlBackend>,
) -> Result<UserEmail<PostgresqlBackend>, anyhow::Error> {
let confirmed_at = Utc::now();
let confirmed_at = clock.now();
sqlx::query!(
r#"
UPDATE user_emails
@ -881,10 +902,11 @@ struct UserEmailConfirmationCodeLookup {
)]
pub async fn lookup_user_email_verification_code(
executor: impl PgExecutor<'_>,
clock: &Clock,
email: UserEmail<PostgresqlBackend>,
code: &str,
) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> {
let now = Utc::now();
let now = clock.now();
let res = sqlx::query_as!(
UserEmailConfirmationCodeLookup,
@ -935,13 +957,14 @@ pub async fn lookup_user_email_verification_code(
)]
pub async fn consume_email_verification(
executor: impl PgExecutor<'_>,
clock: &Clock,
mut verification: UserEmailVerification<PostgresqlBackend>,
) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> {
if !matches!(verification.state, UserEmailVerificationState::Valid) {
bail!("user email verification in wrong state");
}
let consumed_at = Utc::now();
let consumed_at = clock.now();
sqlx::query!(
r#"
@ -974,12 +997,14 @@ pub async fn consume_email_verification(
)]
pub async fn add_user_email_verification_code(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
email: UserEmail<PostgresqlBackend>,
max_age: chrono::Duration,
code: String,
) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> {
let created_at = Utc::now();
let id = Ulid::from_datetime(created_at.into());
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id));
let expires_at = created_at + max_age;
@ -1013,23 +1038,27 @@ pub async fn add_user_email_verification_code(
#[cfg(test)]
mod tests {
use rand::SeedableRng;
use super::*;
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_user_registration_and_login(pool: sqlx::PgPool) -> anyhow::Result<()> {
let clock = Clock::default();
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let mut txn = pool.begin().await?;
let exists = username_exists(&mut txn, "john").await?;
assert!(!exists);
let hasher = Argon2::default();
let user = register_user(&mut txn, hasher, "john", "hunter2").await?;
let user = register_user(&mut txn, &mut rng, &clock, hasher, "john", "hunter2").await?;
assert_eq!(user.username, "john");
let exists = username_exists(&mut txn, "john").await?;
assert!(exists);
let session = login(&mut txn, "john", "hunter2").await?;
let session = login(&mut txn, &mut rng, &clock, "john", "hunter2").await?;
assert_eq!(session.user.data, user.data);
let user2 = lookup_user_by_username(&mut txn, "john").await?;

View File

@ -14,13 +14,14 @@
//! Database-related tasks
use mas_storage::Clock;
use sqlx::{Pool, Postgres};
use tracing::{debug, error, info};
use super::Task;
#[derive(Clone)]
struct CleanupExpired(Pool<Postgres>);
struct CleanupExpired(Pool<Postgres>, Clock);
impl std::fmt::Debug for CleanupExpired {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@ -31,7 +32,7 @@ impl std::fmt::Debug for CleanupExpired {
#[async_trait::async_trait]
impl Task for CleanupExpired {
async fn run(&self) {
let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0).await;
let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0, &self.1).await;
match res {
Ok(0) => {
debug!("no token to clean up");
@ -49,5 +50,6 @@ impl Task for CleanupExpired {
/// Cleanup expired tokens
#[must_use]
pub fn cleanup_expired(pool: &Pool<Postgres>) -> impl Task + Clone {
CleanupExpired(pool.clone())
// XXX: the clock should come from somewhere else
CleanupExpired(pool.clone(), Clock::default())
}