You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Pass the rng and clock around
This commit is contained in:
8
Cargo.lock
generated
8
Cargo.lock
generated
@ -2468,6 +2468,8 @@ dependencies = [
|
||||
"opentelemetry-semantic-conventions",
|
||||
"opentelemetry-zipkin",
|
||||
"prometheus",
|
||||
"rand",
|
||||
"rand_chacha",
|
||||
"rustls",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
@ -2497,6 +2499,7 @@ dependencies = [
|
||||
"mas-keystore",
|
||||
"pem-rfc7468",
|
||||
"rand",
|
||||
"rand_chacha",
|
||||
"rustls-pemfile",
|
||||
"schemars",
|
||||
"serde",
|
||||
@ -2567,6 +2570,7 @@ dependencies = [
|
||||
"mime",
|
||||
"oauth2-types",
|
||||
"rand",
|
||||
"rand_chacha",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
@ -2779,6 +2783,7 @@ dependencies = [
|
||||
"oauth2-types",
|
||||
"password-hash",
|
||||
"rand",
|
||||
"rand_chacha",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sqlx",
|
||||
@ -3068,7 +3073,7 @@ checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860"
|
||||
[[package]]
|
||||
name = "opa-wasm"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/matrix-org/rust-opa-wasm.git#325071ee8a2a7d18cc611365edd3235945a8cdf4"
|
||||
source = "git+https://github.com/matrix-org/rust-opa-wasm.git#f838595670747b0644b6bfd9829fca5d63bbee66"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64",
|
||||
@ -3079,6 +3084,7 @@ dependencies = [
|
||||
"md-5",
|
||||
"parse-size",
|
||||
"rand",
|
||||
"rayon-core",
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -1,2 +1,11 @@
|
||||
msrv = "1.61.0"
|
||||
doc-valid-idents = ["OpenID", "OAuth", ".."]
|
||||
|
||||
disallowed-methods = [
|
||||
{ path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" },
|
||||
{ path = "chrono::Utc::now", reason = "source the current time from the clock instead" },
|
||||
]
|
||||
|
||||
disallowed-types = [
|
||||
"rand::OsRng",
|
||||
]
|
||||
|
@ -15,6 +15,7 @@
|
||||
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use data_encoding::{DecodeError, BASE64URL_NOPAD};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, TimestampSeconds};
|
||||
use thiserror::Error;
|
||||
@ -56,20 +57,20 @@ pub struct CsrfToken {
|
||||
|
||||
impl CsrfToken {
|
||||
/// Create a new token from a defined value valid for a specified duration
|
||||
fn new(token: [u8; 32], ttl: Duration) -> Self {
|
||||
let expiration = Utc::now() + ttl;
|
||||
fn new(token: [u8; 32], now: DateTime<Utc>, ttl: Duration) -> Self {
|
||||
let expiration = now + ttl;
|
||||
Self { expiration, token }
|
||||
}
|
||||
|
||||
/// Generate a new random token valid for a specified duration
|
||||
fn generate(ttl: Duration) -> Self {
|
||||
let token = rand::random();
|
||||
Self::new(token, ttl)
|
||||
fn generate(now: DateTime<Utc>, mut rng: impl Rng, ttl: Duration) -> Self {
|
||||
let token = rng.gen();
|
||||
Self::new(token, now, ttl)
|
||||
}
|
||||
|
||||
/// Generate a new token with the same value but an up to date expiration
|
||||
fn refresh(self, ttl: Duration) -> Self {
|
||||
Self::new(self.token, ttl)
|
||||
fn refresh(self, now: DateTime<Utc>, ttl: Duration) -> Self {
|
||||
Self::new(self.token, now, ttl)
|
||||
}
|
||||
|
||||
/// Get the value to include in HTML forms
|
||||
@ -88,8 +89,8 @@ impl CsrfToken {
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_expiration(self) -> Result<Self, CsrfError> {
|
||||
if Utc::now() < self.expiration {
|
||||
fn verify_expiration(self, now: DateTime<Utc>) -> Result<Self, CsrfError> {
|
||||
if now < self.expiration {
|
||||
Ok(self)
|
||||
} else {
|
||||
Err(CsrfError::Expired)
|
||||
@ -118,12 +119,18 @@ impl<K> CsrfExt for PrivateCookieJar<K> {
|
||||
cookie.set_path("/");
|
||||
cookie.set_http_only(true);
|
||||
|
||||
// XXX: the rng source and clock should come from somewhere else
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let now = Utc::now();
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let rng = thread_rng();
|
||||
|
||||
let new_token = cookie
|
||||
.decode()
|
||||
.ok()
|
||||
.and_then(|token: CsrfToken| token.verify_expiration().ok())
|
||||
.unwrap_or_else(|| CsrfToken::generate(Duration::hours(1)))
|
||||
.refresh(Duration::hours(1));
|
||||
.and_then(|token: CsrfToken| token.verify_expiration(now).ok())
|
||||
.unwrap_or_else(|| CsrfToken::generate(now, rng, Duration::hours(1)))
|
||||
.refresh(now, Duration::hours(1));
|
||||
|
||||
let cookie = cookie.encode(&new_token);
|
||||
let jar = jar.add(cookie);
|
||||
@ -131,9 +138,13 @@ impl<K> CsrfExt for PrivateCookieJar<K> {
|
||||
}
|
||||
|
||||
fn verify_form<T>(&self, form: ProtectedForm<T>) -> Result<T, CsrfError> {
|
||||
// XXX: the clock should come from somewhere else
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let now = Utc::now();
|
||||
|
||||
let cookie = self.get("csrf").ok_or(CsrfError::Missing)?;
|
||||
let token: CsrfToken = cookie.decode()?;
|
||||
let token = token.verify_expiration()?;
|
||||
let token = token.verify_expiration(now)?;
|
||||
token.verify_form_value(&form.csrf)?;
|
||||
Ok(form.inner)
|
||||
}
|
||||
|
@ -6,23 +6,25 @@ edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
axum = "0.6.0-rc.2"
|
||||
tokio = { version = "1.21.2", features = ["full"] }
|
||||
futures-util = "0.3.25"
|
||||
anyhow = "1.0.66"
|
||||
argon2 = { version = "0.4.1", features = ["password-hash"] }
|
||||
atty = "0.2.14"
|
||||
axum = "0.6.0-rc.2"
|
||||
clap = { version = "4.0.18", features = ["derive"] }
|
||||
dotenv = "0.15.0"
|
||||
tower = { version = "0.4.13", features = ["full"] }
|
||||
futures-util = "0.3.25"
|
||||
hyper = { version = "0.14.22", features = ["full"] }
|
||||
serde_yaml = "0.9.14"
|
||||
serde_json = "1.0.87"
|
||||
url = "2.3.1"
|
||||
argon2 = { version = "0.4.1", features = ["password-hash"] }
|
||||
watchman_client = "0.8.0"
|
||||
atty = "0.2.14"
|
||||
listenfd = "1.0.0"
|
||||
rustls = "0.20.7"
|
||||
itertools = "0.10.5"
|
||||
listenfd = "1.0.0"
|
||||
rand = "0.8.5"
|
||||
rand_chacha = "0.3.1"
|
||||
rustls = "0.20.7"
|
||||
serde_json = "1.0.87"
|
||||
serde_yaml = "0.9.14"
|
||||
tokio = { version = "1.21.2", features = ["full"] }
|
||||
tower = { version = "0.4.13", features = ["full"] }
|
||||
url = "2.3.1"
|
||||
watchman_client = "0.8.0"
|
||||
|
||||
tracing = "0.1.37"
|
||||
tracing-appender = "0.2.2"
|
||||
|
@ -20,7 +20,9 @@ use mas_storage::{
|
||||
user::{
|
||||
lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user,
|
||||
},
|
||||
Clock,
|
||||
};
|
||||
use rand::SeedableRng;
|
||||
use tracing::{info, warn};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -51,14 +53,17 @@ enum Subcommand {
|
||||
impl Options {
|
||||
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
||||
use Subcommand as SC;
|
||||
let clock = Clock::default();
|
||||
|
||||
match &self.subcommand {
|
||||
SC::Register { username, password } => {
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = config.connect().await?;
|
||||
let mut txn = pool.begin().await?;
|
||||
let hasher = Argon2::default();
|
||||
let rng = rand_chacha::ChaChaRng::from_entropy();
|
||||
|
||||
let user = register_user(&mut txn, hasher, username, password).await?;
|
||||
let user = register_user(&mut txn, rng, &clock, hasher, username, password).await?;
|
||||
txn.commit().await?;
|
||||
info!(?user, "User registered");
|
||||
|
||||
@ -76,7 +81,7 @@ impl Options {
|
||||
|
||||
let user = lookup_user_by_username(&mut txn, username).await?;
|
||||
let email = lookup_user_email(&mut txn, &user, email).await?;
|
||||
let email = mark_user_email_as_verified(&mut txn, email).await?;
|
||||
let email = mark_user_email_as_verified(&mut txn, &clock, email).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
info!(?email, "Email marked as verified");
|
||||
|
@ -28,6 +28,7 @@ lettre = { version = "0.10.1", default-features = false, features = ["serde", "b
|
||||
pem-rfc7468 = "0.6.0"
|
||||
rustls-pemfile = "1.0.1"
|
||||
rand = "0.8.5"
|
||||
rand_chacha = "0.3.1"
|
||||
|
||||
indoc = "1.0.7"
|
||||
|
||||
|
@ -20,7 +20,7 @@ use mas_jose::jwk::{JsonWebKey, JsonWebKeySet};
|
||||
use mas_keystore::{Encrypter, Keystore, PrivateKey};
|
||||
use rand::{
|
||||
distributions::{Alphanumeric, DistString},
|
||||
thread_rng,
|
||||
thread_rng, SeedableRng,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -139,64 +139,72 @@ impl ConfigurationSection<'_> for SecretsConfig {
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
// XXX: that RNG should come from somewhere else
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let mut rng = rand_chacha::ChaChaRng::from_rng(thread_rng())?;
|
||||
|
||||
info!("Generating keys...");
|
||||
|
||||
let span = tracing::info_span!("rsa");
|
||||
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
|
||||
let rsa_key = task::spawn_blocking(move || {
|
||||
let _entered = span.enter();
|
||||
let ret = PrivateKey::generate_rsa(thread_rng()).unwrap();
|
||||
let ret = PrivateKey::generate_rsa(key_rng).unwrap();
|
||||
info!("Done generating RSA key");
|
||||
ret
|
||||
})
|
||||
.await
|
||||
.context("could not join blocking task")?;
|
||||
let rsa_key = KeyConfig {
|
||||
kid: Alphanumeric.sample_string(&mut thread_rng(), 10),
|
||||
kid: Alphanumeric.sample_string(&mut rng, 10),
|
||||
password: None,
|
||||
key: KeyOrFile::Key(rsa_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
|
||||
};
|
||||
|
||||
let span = tracing::info_span!("ec_p256");
|
||||
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
|
||||
let ec_p256_key = task::spawn_blocking(move || {
|
||||
let _entered = span.enter();
|
||||
let ret = PrivateKey::generate_ec_p256(thread_rng());
|
||||
let ret = PrivateKey::generate_ec_p256(key_rng);
|
||||
info!("Done generating EC P-256 key");
|
||||
ret
|
||||
})
|
||||
.await
|
||||
.context("could not join blocking task")?;
|
||||
let ec_p256_key = KeyConfig {
|
||||
kid: Alphanumeric.sample_string(&mut thread_rng(), 10),
|
||||
kid: Alphanumeric.sample_string(&mut rng, 10),
|
||||
password: None,
|
||||
key: KeyOrFile::Key(ec_p256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
|
||||
};
|
||||
|
||||
let span = tracing::info_span!("ec_p384");
|
||||
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
|
||||
let ec_p384_key = task::spawn_blocking(move || {
|
||||
let _entered = span.enter();
|
||||
let ret = PrivateKey::generate_ec_p384(thread_rng());
|
||||
let ret = PrivateKey::generate_ec_p384(key_rng);
|
||||
info!("Done generating EC P-256 key");
|
||||
ret
|
||||
})
|
||||
.await
|
||||
.context("could not join blocking task")?;
|
||||
let ec_p384_key = KeyConfig {
|
||||
kid: Alphanumeric.sample_string(&mut thread_rng(), 10),
|
||||
kid: Alphanumeric.sample_string(&mut rng, 10),
|
||||
password: None,
|
||||
key: KeyOrFile::Key(ec_p384_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
|
||||
};
|
||||
|
||||
let span = tracing::info_span!("ec_k256");
|
||||
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
|
||||
let ec_k256_key = task::spawn_blocking(move || {
|
||||
let _entered = span.enter();
|
||||
let ret = PrivateKey::generate_ec_k256(thread_rng());
|
||||
let ret = PrivateKey::generate_ec_k256(key_rng);
|
||||
info!("Done generating EC secp256k1 key");
|
||||
ret
|
||||
})
|
||||
.await
|
||||
.context("could not join blocking task")?;
|
||||
let ec_k256_key = KeyConfig {
|
||||
kid: Alphanumeric.sample_string(&mut thread_rng(), 10),
|
||||
kid: Alphanumeric.sample_string(&mut rng, 10),
|
||||
password: None,
|
||||
key: KeyOrFile::Key(ec_k256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
|
||||
};
|
||||
|
@ -263,6 +263,8 @@ mod tests {
|
||||
#[test]
|
||||
fn test_generate_and_check() {
|
||||
const COUNT: usize = 500; // Generate 500 of each token type
|
||||
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let mut rng = thread_rng();
|
||||
|
||||
for t in [
|
||||
|
@ -44,6 +44,7 @@ chrono = { version = "0.4.22", features = ["serde"] }
|
||||
url = { version = "2.3.1", features = ["serde"] }
|
||||
mime = "0.3.16"
|
||||
rand = "0.8.5"
|
||||
rand_chacha = "0.3.1"
|
||||
headers = "0.3.8"
|
||||
ulid = "1.0.0"
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use axum::{extract::State, response::IntoResponse, Json};
|
||||
use chrono::{Duration, Utc};
|
||||
use chrono::Duration;
|
||||
use hyper::StatusCode;
|
||||
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType};
|
||||
use mas_storage::{
|
||||
@ -22,9 +22,8 @@ use mas_storage::{
|
||||
get_compat_sso_login_by_token, mark_compat_sso_login_as_exchanged,
|
||||
CompatSsoLoginLookupError,
|
||||
},
|
||||
PostgresqlBackend,
|
||||
Clock, PostgresqlBackend,
|
||||
};
|
||||
use rand::thread_rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, skip_serializing_none, DurationMilliSeconds};
|
||||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
@ -201,6 +200,7 @@ pub(crate) async fn post(
|
||||
State(homeserver): State<MatrixHomeserver>,
|
||||
Json(input): Json<RequestBody>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut txn = pool.begin().await?;
|
||||
let session = match input.credentials {
|
||||
Credentials::Password {
|
||||
@ -208,7 +208,7 @@ pub(crate) async fn post(
|
||||
password,
|
||||
} => user_password_login(&mut txn, user, password).await?,
|
||||
|
||||
Credentials::Token { token } => token_login(&mut txn, &token).await?,
|
||||
Credentials::Token { token } => token_login(&mut txn, &clock, &token).await?,
|
||||
|
||||
_ => {
|
||||
return Err(RouteError::Unsupported);
|
||||
@ -225,14 +225,28 @@ pub(crate) async fn post(
|
||||
None
|
||||
};
|
||||
|
||||
let access_token = TokenType::CompatAccessToken.generate(&mut thread_rng());
|
||||
let access_token =
|
||||
add_compat_access_token(&mut txn, &session, access_token, expires_in).await?;
|
||||
let access_token = TokenType::CompatAccessToken.generate(&mut rng);
|
||||
let access_token = add_compat_access_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
access_token,
|
||||
expires_in,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let refresh_token = if input.refresh_token {
|
||||
let refresh_token = TokenType::CompatRefreshToken.generate(&mut thread_rng());
|
||||
let refresh_token =
|
||||
add_compat_refresh_token(&mut txn, &session, &access_token, refresh_token).await?;
|
||||
let refresh_token = TokenType::CompatRefreshToken.generate(&mut rng);
|
||||
let refresh_token = add_compat_refresh_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
&access_token,
|
||||
refresh_token,
|
||||
)
|
||||
.await?;
|
||||
Some(refresh_token.token)
|
||||
} else {
|
||||
None
|
||||
@ -251,11 +265,12 @@ pub(crate) async fn post(
|
||||
|
||||
async fn token_login(
|
||||
txn: &mut Transaction<'_, Postgres>,
|
||||
clock: &Clock,
|
||||
token: &str,
|
||||
) -> Result<CompatSession<PostgresqlBackend>, RouteError> {
|
||||
let login = get_compat_sso_login_by_token(&mut *txn, token).await?;
|
||||
|
||||
let now = Utc::now();
|
||||
let now = clock.now();
|
||||
match login.state {
|
||||
CompatSsoLoginState::Pending => {
|
||||
tracing::error!(
|
||||
@ -285,7 +300,7 @@ async fn token_login(
|
||||
}
|
||||
}
|
||||
|
||||
let login = mark_compat_sso_login_as_exchanged(&mut *txn, login).await?;
|
||||
let login = mark_compat_sso_login_as_exchanged(&mut *txn, clock, login).await?;
|
||||
|
||||
match login.state {
|
||||
CompatSsoLoginState::Exchanged { session, .. } => Ok(session),
|
||||
@ -298,8 +313,10 @@ async fn user_password_login(
|
||||
username: String,
|
||||
password: String,
|
||||
) -> Result<CompatSession<PostgresqlBackend>, RouteError> {
|
||||
let device = Device::generate(&mut thread_rng());
|
||||
let session = compat_login(txn, &username, &password, device)
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
|
||||
let device = Device::generate(&mut rng);
|
||||
let session = compat_login(txn, &mut rng, &clock, &username, &password, device)
|
||||
.await
|
||||
.map_err(|_| RouteError::LoginFailed)?;
|
||||
|
||||
|
@ -20,7 +20,7 @@ use axum::{
|
||||
response::{Html, IntoResponse, Redirect, Response},
|
||||
};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use chrono::{Duration, Utc};
|
||||
use chrono::Duration;
|
||||
use mas_axum_utils::{
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
FancyError, SessionInfoExt,
|
||||
@ -28,9 +28,11 @@ use mas_axum_utils::{
|
||||
use mas_data_model::Device;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::{CompatLoginSsoAction, PostAuthAction, Route};
|
||||
use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id};
|
||||
use mas_storage::{
|
||||
compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id},
|
||||
Clock,
|
||||
};
|
||||
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
|
||||
use rand::thread_rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use ulid::Ulid;
|
||||
@ -56,6 +58,7 @@ pub async fn get(
|
||||
Path(id): Path<Ulid>,
|
||||
Query(params): Query<Params>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let clock = Clock::default();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
@ -95,7 +98,7 @@ pub async fn get(
|
||||
let login = get_compat_sso_login_by_id(&mut conn, id).await?;
|
||||
|
||||
// Bail out if that login session is more than 30min old
|
||||
if Utc::now() > login.created_at + Duration::minutes(30) {
|
||||
if clock.now() > login.created_at + Duration::minutes(30) {
|
||||
let ctx = ErrorContext::new()
|
||||
.with_code("compat_sso_login_expired")
|
||||
.with_description("This login session expired.".to_owned());
|
||||
@ -121,6 +124,7 @@ pub async fn post(
|
||||
Query(params): Query<Params>,
|
||||
Form(form): Form<ProtectedForm<()>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
@ -160,7 +164,7 @@ pub async fn post(
|
||||
let login = get_compat_sso_login_by_id(&mut txn, id).await?;
|
||||
|
||||
// Bail out if that login session is more than 30min old
|
||||
if Utc::now() > login.created_at + Duration::minutes(30) {
|
||||
if clock.now() > login.created_at + Duration::minutes(30) {
|
||||
let ctx = ErrorContext::new()
|
||||
.with_code("compat_sso_login_expired")
|
||||
.with_description("This login session expired.".to_owned());
|
||||
@ -186,8 +190,9 @@ pub async fn post(
|
||||
redirect_uri
|
||||
};
|
||||
|
||||
let device = Device::generate(&mut thread_rng());
|
||||
let _login = fullfill_compat_sso_login(&mut txn, session.user, login, device).await?;
|
||||
let device = Device::generate(&mut rng);
|
||||
let _login =
|
||||
fullfill_compat_sso_login(&mut txn, &mut rng, &clock, session.user, login, device).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
|
@ -20,10 +20,7 @@ use axum::{
|
||||
use hyper::StatusCode;
|
||||
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
|
||||
use mas_storage::compat::insert_compat_sso_login;
|
||||
use rand::{
|
||||
distributions::{Alphanumeric, DistString},
|
||||
thread_rng,
|
||||
};
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
use serde::Deserialize;
|
||||
use serde_with::serde;
|
||||
use sqlx::PgPool;
|
||||
@ -70,6 +67,8 @@ pub async fn get(
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
Query(params): Query<Params>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
|
||||
// Check the redirectUrl parameter
|
||||
let redirect_url = params.redirect_url.ok_or(RouteError::MissingRedirectUrl)?;
|
||||
let redirect_url = Url::parse(&redirect_url).map_err(|_| RouteError::InvalidRedirectUrl)?;
|
||||
@ -84,9 +83,9 @@ pub async fn get(
|
||||
return Err(RouteError::InvalidRedirectUrl);
|
||||
}
|
||||
|
||||
let token = Alphanumeric.sample_string(&mut thread_rng(), 32);
|
||||
let token = Alphanumeric.sample_string(&mut rng, 32);
|
||||
let mut conn = pool.acquire().await?;
|
||||
let login = insert_compat_sso_login(&mut conn, token, redirect_url).await?;
|
||||
let login = insert_compat_sso_login(&mut conn, &mut rng, &clock, token, redirect_url).await?;
|
||||
|
||||
Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.data, params.action)))
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ use axum::{extract::State, response::IntoResponse, Json, TypedHeader};
|
||||
use headers::{authorization::Bearer, Authorization};
|
||||
use hyper::StatusCode;
|
||||
use mas_data_model::{TokenFormatError, TokenType};
|
||||
use mas_storage::compat::compat_logout;
|
||||
use mas_storage::{compat::compat_logout, Clock};
|
||||
use sqlx::PgPool;
|
||||
|
||||
use super::MatrixError;
|
||||
@ -67,6 +67,7 @@ pub(crate) async fn post(
|
||||
State(pool): State<PgPool>,
|
||||
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let clock = Clock::default();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?;
|
||||
@ -78,7 +79,7 @@ pub(crate) async fn post(
|
||||
return Err(RouteError::InvalidAuthorization);
|
||||
}
|
||||
|
||||
compat_logout(&mut conn, token)
|
||||
compat_logout(&mut conn, &clock, token)
|
||||
.await
|
||||
.map_err(|_| RouteError::LogoutFailed)?;
|
||||
|
||||
|
@ -20,7 +20,6 @@ use mas_storage::compat::{
|
||||
add_compat_access_token, add_compat_refresh_token, consume_compat_refresh_token,
|
||||
expire_compat_access_token, lookup_active_compat_refresh_token, CompatRefreshTokenLookupError,
|
||||
};
|
||||
use rand::thread_rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DurationMilliSeconds};
|
||||
use sqlx::PgPool;
|
||||
@ -98,6 +97,7 @@ pub(crate) async fn post(
|
||||
State(pool): State<PgPool>,
|
||||
Json(input): Json<RequestBody>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let token_type = TokenType::check(&input.refresh_token)?;
|
||||
@ -109,23 +109,31 @@ pub(crate) async fn post(
|
||||
let (refresh_token, access_token, session) =
|
||||
lookup_active_compat_refresh_token(&mut txn, &input.refresh_token).await?;
|
||||
|
||||
let (new_refresh_token_str, new_access_token_str) = {
|
||||
let mut rng = thread_rng();
|
||||
(
|
||||
TokenType::CompatRefreshToken.generate(&mut rng),
|
||||
TokenType::CompatAccessToken.generate(&mut rng),
|
||||
)
|
||||
};
|
||||
let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng);
|
||||
let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
|
||||
|
||||
let expires_in = Duration::minutes(5);
|
||||
let new_access_token =
|
||||
add_compat_access_token(&mut txn, &session, new_access_token_str, Some(expires_in)).await?;
|
||||
let new_refresh_token =
|
||||
add_compat_refresh_token(&mut txn, &session, &new_access_token, new_refresh_token_str)
|
||||
let new_access_token = add_compat_access_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
new_access_token_str,
|
||||
Some(expires_in),
|
||||
)
|
||||
.await?;
|
||||
let new_refresh_token = add_compat_refresh_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
&new_access_token,
|
||||
new_refresh_token_str,
|
||||
)
|
||||
.await?;
|
||||
|
||||
consume_compat_refresh_token(&mut txn, refresh_token).await?;
|
||||
expire_compat_access_token(&mut txn, access_token).await?;
|
||||
consume_compat_refresh_token(&mut txn, &clock, refresh_token).await?;
|
||||
expire_compat_access_token(&mut txn, &clock, access_token).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
|
@ -21,6 +21,7 @@
|
||||
|
||||
use std::{convert::Infallible, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use axum::{
|
||||
body::HttpBody,
|
||||
extract::FromRef,
|
||||
@ -36,6 +37,7 @@ use mas_keystore::{Encrypter, Keystore};
|
||||
use mas_policy::PolicyFactory;
|
||||
use mas_router::{Route, UrlBuilder};
|
||||
use mas_templates::{ErrorContext, Templates};
|
||||
use rand::SeedableRng;
|
||||
use sqlx::PgPool;
|
||||
use tower::util::AndThenLayer;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
@ -356,3 +358,15 @@ async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
|
||||
policy_factory,
|
||||
}))
|
||||
}
|
||||
|
||||
// XXX: that should be moved somewhere else
|
||||
fn rng_and_clock() -> Result<(mas_storage::Clock, rand_chacha::ChaChaRng), anyhow::Error> {
|
||||
let clock = mas_storage::Clock::default();
|
||||
|
||||
// This rng is used to source the local rng
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let rng = rand::thread_rng();
|
||||
|
||||
let rng = rand_chacha::ChaChaRng::from_rng(rng).context("Failed to seed RNG")?;
|
||||
Ok((clock, rng))
|
||||
}
|
||||
|
@ -190,6 +190,8 @@ pub(crate) async fn complete(
|
||||
policy_factory: &PolicyFactory,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
) -> Result<AuthorizationResponse<Option<AccessTokenResponse>>, GrantCompletionError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
|
||||
// Verify that the grant is in a pending stage
|
||||
if !grant.stage.is_pending() {
|
||||
return Err(GrantCompletionError::NotPending);
|
||||
@ -226,7 +228,7 @@ pub(crate) async fn complete(
|
||||
}
|
||||
|
||||
// All good, let's start the session
|
||||
let session = derive_session(&mut txn, &grant, browser_session).await?;
|
||||
let session = derive_session(&mut txn, &mut rng, &clock, &grant, browser_session).await?;
|
||||
|
||||
let grant = fulfill_grant(&mut txn, grant, session.clone()).await?;
|
||||
|
||||
|
@ -37,7 +37,7 @@ use oauth2_types::{
|
||||
requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode},
|
||||
response_type::ResponseType,
|
||||
};
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
@ -159,6 +159,7 @@ pub(crate) async fn get(
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(params): Form<Params>,
|
||||
) -> Result<Response, RouteError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
// First, figure out what client it is
|
||||
@ -265,7 +266,7 @@ pub(crate) async fn get(
|
||||
}
|
||||
|
||||
// 32 random alphanumeric characters, about 190bit of entropy
|
||||
let code: String = thread_rng()
|
||||
let code: String = (&mut rng)
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
@ -296,6 +297,8 @@ pub(crate) async fn get(
|
||||
|
||||
let grant = new_authorization_grant(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
client,
|
||||
redirect_uri.clone(),
|
||||
params.auth.scope,
|
||||
|
@ -119,6 +119,7 @@ pub(crate) async fn post(
|
||||
Path(grant_id): Path<Ulid>,
|
||||
Form(form): Form<ProtectedForm<()>>,
|
||||
) -> Result<Response, RouteError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut txn = pool
|
||||
.begin()
|
||||
.await
|
||||
@ -163,6 +164,8 @@ pub(crate) async fn post(
|
||||
.collect();
|
||||
insert_client_consent(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
&grant.client,
|
||||
&scope_without_device,
|
||||
|
@ -28,6 +28,7 @@ use mas_storage::{
|
||||
client::ClientFetchError,
|
||||
refresh_token::{lookup_active_refresh_token, RefreshTokenLookupError},
|
||||
},
|
||||
Clock,
|
||||
};
|
||||
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
|
||||
use sqlx::PgPool;
|
||||
@ -158,6 +159,7 @@ pub(crate) async fn post(
|
||||
State(encrypter): State<Encrypter>,
|
||||
client_authorization: ClientAuthorization<IntrospectionRequest>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let clock = Clock::default();
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let client = client_authorization.credentials.fetch(&mut conn).await?;
|
||||
@ -227,7 +229,8 @@ pub(crate) async fn post(
|
||||
}
|
||||
}
|
||||
TokenType::CompatAccessToken => {
|
||||
let (token, session) = lookup_active_compat_access_token(&mut conn, token).await?;
|
||||
let (token, session) =
|
||||
lookup_active_compat_access_token(&mut conn, &clock, token).await?;
|
||||
|
||||
let device_scope = session.device.to_scope_token();
|
||||
let scope = [device_scope].into_iter().collect();
|
||||
|
@ -50,7 +50,6 @@ use oauth2_types::{
|
||||
},
|
||||
scope,
|
||||
};
|
||||
use rand::thread_rng;
|
||||
use serde::Serialize;
|
||||
use serde_with::{serde_as, skip_serializing_none};
|
||||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
@ -235,12 +234,13 @@ async fn authorization_code_grant(
|
||||
url_builder: &UrlBuilder,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
) -> Result<AccessTokenResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
|
||||
// TODO: there is a bunch of unnecessary cloning here
|
||||
// TODO: handle "not found" cases
|
||||
let authz_grant = lookup_grant_by_code(&mut txn, &grant.code).await?;
|
||||
|
||||
// TODO: that's not a timestamp from the DB. Let's assume they are in sync
|
||||
let now = Utc::now();
|
||||
let now = clock.now();
|
||||
|
||||
let session = match authz_grant.stage {
|
||||
AuthorizationGrantStage::Cancelled { cancelled_at } => {
|
||||
@ -257,7 +257,7 @@ async fn authorization_code_grant(
|
||||
// Ending the session if the token was already exchanged more than 20s ago
|
||||
if now - exchanged_at > Duration::seconds(20) {
|
||||
debug!("Ending potentially compromised session");
|
||||
end_oauth_session(&mut txn, session).await?;
|
||||
end_oauth_session(&mut txn, &clock, session).await?;
|
||||
txn.commit().await?;
|
||||
}
|
||||
|
||||
@ -303,22 +303,32 @@ async fn authorization_code_grant(
|
||||
let browser_session = &session.browser_session;
|
||||
|
||||
let ttl = Duration::minutes(5);
|
||||
let (access_token_str, refresh_token_str) = {
|
||||
let mut rng = thread_rng();
|
||||
(
|
||||
TokenType::AccessToken.generate(&mut rng),
|
||||
TokenType::RefreshToken.generate(&mut rng),
|
||||
let access_token_str = TokenType::AccessToken.generate(&mut rng);
|
||||
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
|
||||
|
||||
let access_token = add_access_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
session,
|
||||
access_token_str.clone(),
|
||||
ttl,
|
||||
)
|
||||
};
|
||||
.await?;
|
||||
|
||||
let access_token = add_access_token(&mut txn, session, access_token_str.clone(), ttl).await?;
|
||||
|
||||
let _refresh_token =
|
||||
add_refresh_token(&mut txn, session, access_token, refresh_token_str.clone()).await?;
|
||||
let _refresh_token = add_refresh_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
session,
|
||||
access_token,
|
||||
refresh_token_str.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let id_token = if session.scope.contains(&scope::OPENID) {
|
||||
let mut claims = HashMap::new();
|
||||
let now = Utc::now();
|
||||
let now = clock.now();
|
||||
claims::ISS.insert(&mut claims, url_builder.oidc_issuer().to_string())?;
|
||||
claims::SUB.insert(&mut claims, &browser_session.user.sub)?;
|
||||
claims::AUD.insert(&mut claims, client.client_id.clone())?;
|
||||
@ -346,7 +356,7 @@ async fn authorization_code_grant(
|
||||
let signer = key.params().signing_key_for_alg(&alg)?;
|
||||
let header = JsonWebSignatureHeader::new(alg)
|
||||
.with_kid(key.kid().context("key has no `kid` for some reason")?);
|
||||
let id_token = Jwt::sign(header, claims, &signer)?;
|
||||
let id_token = Jwt::sign_with_rng(&mut rng, header, claims, &signer)?;
|
||||
|
||||
Some(id_token.as_str().to_owned())
|
||||
} else {
|
||||
@ -362,7 +372,7 @@ async fn authorization_code_grant(
|
||||
params = params.with_id_token(id_token);
|
||||
}
|
||||
|
||||
exchange_grant(&mut txn, authz_grant).await?;
|
||||
exchange_grant(&mut txn, &clock, authz_grant).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
@ -374,6 +384,8 @@ async fn refresh_token_grant(
|
||||
client: &Client<PostgresqlBackend>,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
) -> Result<AccessTokenResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
|
||||
let (refresh_token, session) =
|
||||
lookup_active_refresh_token(&mut txn, &grant.refresh_token).await?;
|
||||
|
||||
@ -383,24 +395,33 @@ async fn refresh_token_grant(
|
||||
}
|
||||
|
||||
let ttl = Duration::minutes(5);
|
||||
let (access_token_str, refresh_token_str) = {
|
||||
let mut rng = thread_rng();
|
||||
(
|
||||
TokenType::AccessToken.generate(&mut rng),
|
||||
TokenType::RefreshToken.generate(&mut rng),
|
||||
let access_token_str = TokenType::AccessToken.generate(&mut rng);
|
||||
let refresh_token_str = TokenType::RefreshToken.generate(&mut rng);
|
||||
|
||||
let new_access_token = add_access_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
access_token_str.clone(),
|
||||
ttl,
|
||||
)
|
||||
};
|
||||
.await?;
|
||||
|
||||
let new_access_token =
|
||||
add_access_token(&mut txn, &session, access_token_str.clone(), ttl).await?;
|
||||
let new_refresh_token = add_refresh_token(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session,
|
||||
new_access_token,
|
||||
refresh_token_str,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let new_refresh_token =
|
||||
add_refresh_token(&mut txn, &session, new_access_token, refresh_token_str).await?;
|
||||
|
||||
consume_refresh_token(&mut txn, &refresh_token).await?;
|
||||
consume_refresh_token(&mut txn, &clock, &refresh_token).await?;
|
||||
|
||||
if let Some(access_token) = refresh_token.access_token {
|
||||
revoke_access_token(&mut txn, access_token).await?;
|
||||
revoke_access_token(&mut txn, &clock, access_token).await?;
|
||||
}
|
||||
|
||||
let params = AccessTokenResponse::new(access_token_str)
|
||||
|
@ -54,6 +54,7 @@ pub async fn get(
|
||||
user_authorization: UserAuthorization,
|
||||
) -> Result<Response, FancyError> {
|
||||
// TODO: error handling
|
||||
let (_clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let session = user_authorization.protected(&mut conn).await?;
|
||||
@ -88,7 +89,7 @@ pub async fn get(
|
||||
user_info,
|
||||
};
|
||||
|
||||
let token = Jwt::sign(header, user_info, &signer)?;
|
||||
let token = Jwt::sign_with_rng(&mut rng, header, user_info, &signer)?;
|
||||
Ok(JwtResponse(token).into_response())
|
||||
} else {
|
||||
Ok(Json(user_info).into_response())
|
||||
|
@ -72,6 +72,7 @@ pub(crate) async fn post(
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
Form(form): Form<ProtectedForm<EmailForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(form)?;
|
||||
@ -86,14 +87,22 @@ pub(crate) async fn post(
|
||||
return Ok((cookie_jar, login.go()).into_response());
|
||||
};
|
||||
|
||||
let user_email = add_user_email(&mut txn, &session.user, form.email).await?;
|
||||
let user_email = add_user_email(&mut txn, &mut rng, &clock, &session.user, form.email).await?;
|
||||
let next = mas_router::AccountVerifyEmail::new(user_email.data);
|
||||
let next = if let Some(action) = query.post_auth_action {
|
||||
next.and_then(action)
|
||||
} else {
|
||||
next
|
||||
};
|
||||
start_email_verification(&mailer, &mut txn, &session.user, user_email).await?;
|
||||
start_email_verification(
|
||||
&mailer,
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
user_email,
|
||||
)
|
||||
.await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
|
@ -32,10 +32,10 @@ use mas_storage::{
|
||||
add_user_email, add_user_email_verification_code, get_user_email, get_user_emails,
|
||||
remove_user_email, set_user_email_as_primary,
|
||||
},
|
||||
PostgresqlBackend,
|
||||
Clock, PostgresqlBackend,
|
||||
};
|
||||
use mas_templates::{AccountEmailsContext, EmailVerificationContext, TemplateContext, Templates};
|
||||
use rand::{distributions::Uniform, thread_rng, Rng};
|
||||
use rand::{distributions::Uniform, Rng};
|
||||
use serde::Deserialize;
|
||||
use sqlx::{PgExecutor, PgPool};
|
||||
use tracing::info;
|
||||
@ -93,17 +93,26 @@ async fn render(
|
||||
async fn start_email_verification(
|
||||
mailer: &Mailer,
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
user: &User<PostgresqlBackend>,
|
||||
user_email: UserEmail<PostgresqlBackend>,
|
||||
) -> anyhow::Result<()> {
|
||||
// First, generate a code
|
||||
let range = Uniform::<u32>::from(0..1_000_000);
|
||||
let code = thread_rng().sample(range).to_string();
|
||||
let code = rng.sample(range).to_string();
|
||||
|
||||
let address: Address = user_email.email.parse()?;
|
||||
|
||||
let verification =
|
||||
add_user_email_verification_code(executor, user_email, Duration::hours(8), code).await?;
|
||||
let verification = add_user_email_verification_code(
|
||||
executor,
|
||||
&mut rng,
|
||||
clock,
|
||||
user_email,
|
||||
Duration::hours(8),
|
||||
code,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// And send the verification email
|
||||
let mailbox = Mailbox::new(Some(user.username.clone()), address);
|
||||
@ -126,6 +135,7 @@ pub(crate) async fn post(
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<ManagementForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let (session_info, cookie_jar) = cookie_jar.session_info();
|
||||
@ -143,9 +153,18 @@ pub(crate) async fn post(
|
||||
|
||||
match form {
|
||||
ManagementForm::Add { email } => {
|
||||
let user_email = add_user_email(&mut txn, &session.user, email).await?;
|
||||
let user_email =
|
||||
add_user_email(&mut txn, &mut rng, &clock, &session.user, email).await?;
|
||||
let next = mas_router::AccountVerifyEmail::new(user_email.data);
|
||||
start_email_verification(&mailer, &mut txn, &session.user, user_email).await?;
|
||||
start_email_verification(
|
||||
&mailer,
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
user_email,
|
||||
)
|
||||
.await?;
|
||||
txn.commit().await?;
|
||||
return Ok((cookie_jar, next.go()).into_response());
|
||||
}
|
||||
@ -154,7 +173,15 @@ pub(crate) async fn post(
|
||||
|
||||
let user_email = get_user_email(&mut txn, &session.user, id).await?;
|
||||
let next = mas_router::AccountVerifyEmail::new(user_email.data);
|
||||
start_email_verification(&mailer, &mut txn, &session.user, user_email).await?;
|
||||
start_email_verification(
|
||||
&mailer,
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&session.user,
|
||||
user_email,
|
||||
)
|
||||
.await?;
|
||||
txn.commit().await?;
|
||||
return Ok((cookie_jar, next.go()).into_response());
|
||||
}
|
||||
|
@ -23,9 +23,12 @@ use mas_axum_utils::{
|
||||
};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::Route;
|
||||
use mas_storage::user::{
|
||||
use mas_storage::{
|
||||
user::{
|
||||
consume_email_verification, lookup_user_email_by_id, lookup_user_email_verification_code,
|
||||
mark_user_email_as_verified, set_user_email_as_primary,
|
||||
},
|
||||
Clock,
|
||||
};
|
||||
use mas_templates::{EmailVerificationPageContext, TemplateContext, Templates};
|
||||
use serde::Deserialize;
|
||||
@ -84,6 +87,7 @@ pub(crate) async fn post(
|
||||
Path(id): Path<Ulid>,
|
||||
Form(form): Form<ProtectedForm<CodeForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let clock = Clock::default();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(form)?;
|
||||
@ -105,12 +109,13 @@ pub(crate) async fn post(
|
||||
}
|
||||
|
||||
// TODO: make those 8 hours configurable
|
||||
let verification = lookup_user_email_verification_code(&mut txn, email, &form.code).await?;
|
||||
let verification =
|
||||
lookup_user_email_verification_code(&mut txn, &clock, email, &form.code).await?;
|
||||
|
||||
// TODO: display nice errors if the code was already consumed or expired
|
||||
let verification = consume_email_verification(&mut txn, verification).await?;
|
||||
let verification = consume_email_verification(&mut txn, &clock, verification).await?;
|
||||
|
||||
let _email = mark_user_email_as_verified(&mut txn, verification.email).await?;
|
||||
let _email = mark_user_email_as_verified(&mut txn, &clock, verification.email).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
|
@ -81,6 +81,7 @@ pub(crate) async fn post(
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<ChangeForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(form)?;
|
||||
@ -96,7 +97,14 @@ pub(crate) async fn post(
|
||||
return Ok((cookie_jar, login.go()).into_response());
|
||||
};
|
||||
|
||||
authenticate_session(&mut txn, &mut session, &form.current_password).await?;
|
||||
authenticate_session(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&mut session,
|
||||
&form.current_password,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// TODO: display nice form errors
|
||||
if form.new_password != form.new_password_confirm {
|
||||
@ -104,7 +112,15 @@ pub(crate) async fn post(
|
||||
}
|
||||
|
||||
let phf = Argon2::default();
|
||||
set_password(&mut txn, phf, &session.user, &form.new_password).await?;
|
||||
set_password(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
phf,
|
||||
&session.user,
|
||||
&form.new_password,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let reply = render(templates.clone(), session, cookie_jar).await?;
|
||||
|
||||
|
@ -80,6 +80,7 @@ pub(crate) async fn post(
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<LoginForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut conn = pool.acquire().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(form)?;
|
||||
@ -114,7 +115,7 @@ pub(crate) async fn post(
|
||||
return Ok((cookie_jar, Html(content)).into_response());
|
||||
}
|
||||
|
||||
match login(&mut conn, &form.username, &form.password).await {
|
||||
match login(&mut conn, &mut rng, &clock, &form.username, &form.password).await {
|
||||
Ok(session_info) => {
|
||||
let cookie_jar = cookie_jar.set_session(&session_info);
|
||||
let reply = query.go_next();
|
||||
|
@ -23,7 +23,7 @@ use mas_axum_utils::{
|
||||
};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::{PostAuthAction, Route};
|
||||
use mas_storage::user::end_session;
|
||||
use mas_storage::{user::end_session, Clock};
|
||||
use sqlx::PgPool;
|
||||
|
||||
pub(crate) async fn post(
|
||||
@ -31,6 +31,7 @@ pub(crate) async fn post(
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<Option<PostAuthAction>>>,
|
||||
) -> Result<impl IntoResponse, FancyError> {
|
||||
let clock = Clock::default();
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(form)?;
|
||||
@ -40,7 +41,7 @@ pub(crate) async fn post(
|
||||
let maybe_session = session_info.load_session(&mut txn).await?;
|
||||
|
||||
if let Some(session) = maybe_session {
|
||||
end_session(&mut txn, &session).await?;
|
||||
end_session(&mut txn, &clock, &session).await?;
|
||||
cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended());
|
||||
}
|
||||
|
||||
|
@ -80,6 +80,7 @@ pub(crate) async fn post(
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<ReauthForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(form)?;
|
||||
@ -98,7 +99,7 @@ pub(crate) async fn post(
|
||||
};
|
||||
|
||||
// TODO: recover from errors here
|
||||
authenticate_session(&mut txn, &mut session, &form.password).await?;
|
||||
authenticate_session(&mut txn, &mut rng, &clock, &mut session, &form.password).await?;
|
||||
let cookie_jar = cookie_jar.set_session(&session);
|
||||
txn.commit().await?;
|
||||
|
||||
|
@ -39,7 +39,7 @@ use mas_templates::{
|
||||
EmailVerificationContext, FieldError, FormError, RegisterContext, RegisterFormField,
|
||||
TemplateContext, Templates, ToFormState,
|
||||
};
|
||||
use rand::{distributions::Uniform, thread_rng, Rng};
|
||||
use rand::{distributions::Uniform, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{PgConnection, PgPool};
|
||||
|
||||
@ -87,6 +87,7 @@ pub(crate) async fn get(
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(crate) async fn post(
|
||||
State(mailer): State<Mailer>,
|
||||
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||
@ -96,6 +97,7 @@ pub(crate) async fn post(
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Form(form): Form<ProtectedForm<RegisterForm>>,
|
||||
) -> Result<Response, FancyError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let form = cookie_jar.verify_form(form)?;
|
||||
@ -180,18 +182,34 @@ pub(crate) async fn post(
|
||||
}
|
||||
|
||||
let pfh = Argon2::default();
|
||||
let user = register_user(&mut txn, pfh, &form.username, &form.password).await?;
|
||||
let user = register_user(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
pfh,
|
||||
&form.username,
|
||||
&form.password,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let user_email = add_user_email(&mut txn, &user, form.email).await?;
|
||||
let user_email = add_user_email(&mut txn, &mut rng, &clock, &user, form.email).await?;
|
||||
|
||||
// First, generate a code
|
||||
let range = Uniform::<u32>::from(0..1_000_000);
|
||||
let code = thread_rng().sample(range).to_string();
|
||||
let code = rng.sample(range);
|
||||
let code = format!("{code:06}");
|
||||
|
||||
let address: Address = user_email.email.parse()?;
|
||||
|
||||
let verification =
|
||||
add_user_email_verification_code(&mut txn, user_email, Duration::hours(8), code).await?;
|
||||
let verification = add_user_email_verification_code(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
user_email,
|
||||
Duration::hours(8),
|
||||
code,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// And send the verification email
|
||||
let mailbox = Mailbox::new(Some(user.username.clone()), address);
|
||||
@ -203,7 +221,7 @@ pub(crate) async fn post(
|
||||
let next = mas_router::AccountVerifyEmail::new(verification.email.data)
|
||||
.and_maybe(query.post_auth_action);
|
||||
|
||||
let session = start_session(&mut txn, user).await?;
|
||||
let session = start_session(&mut txn, &mut rng, &clock, user).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
|
@ -309,6 +309,7 @@ impl<T> Jwt<'static, T> {
|
||||
S: Signature,
|
||||
T: Serialize,
|
||||
{
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
Self::sign_with_rng(thread_rng(), header, payload, key)
|
||||
}
|
||||
|
||||
@ -357,6 +358,7 @@ impl<T> Jwt<'static, T> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::disallowed_methods)]
|
||||
use mas_iana::jose::JsonWebSignatureAlg;
|
||||
use rand::thread_rng;
|
||||
|
||||
|
@ -19,6 +19,7 @@ tracing = "0.1.37"
|
||||
argon2 = { version = "0.4.1", features = ["password-hash"] }
|
||||
password-hash = { version = "0.4.2", features = ["std"] }
|
||||
rand = "0.8.5"
|
||||
rand_chacha = "0.3.1"
|
||||
url = { version = "2.3.1", features = ["serde"] }
|
||||
uuid = "1.2.1"
|
||||
ulid = { version = "1.0.0", features = ["uuid", "serde"] }
|
||||
|
@ -19,6 +19,7 @@ use mas_data_model::{
|
||||
CompatAccessToken, CompatRefreshToken, CompatSession, CompatSsoLogin, CompatSsoLoginState,
|
||||
Device, User, UserEmail,
|
||||
};
|
||||
use rand::Rng;
|
||||
use sqlx::{Acquire, PgExecutor, Postgres};
|
||||
use thiserror::Error;
|
||||
use tokio::task;
|
||||
@ -27,7 +28,7 @@ use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{user::lookup_user_by_username, DatabaseInconsistencyError, PostgresqlBackend};
|
||||
use crate::{user::lookup_user_by_username, Clock, DatabaseInconsistencyError, PostgresqlBackend};
|
||||
|
||||
struct CompatAccessTokenLookup {
|
||||
compat_access_token_id: Uuid,
|
||||
@ -67,6 +68,7 @@ impl CompatAccessTokenLookupError {
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub async fn lookup_active_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
token: &str,
|
||||
) -> Result<
|
||||
(
|
||||
@ -112,7 +114,7 @@ pub async fn lookup_active_compat_access_token(
|
||||
|
||||
// Check for token expiration
|
||||
if let Some(expires_at) = res.compat_access_token_expires_at {
|
||||
if expires_at < Utc::now() {
|
||||
if expires_at < clock.now() {
|
||||
return Err(CompatAccessTokenLookupError::Expired { when: expires_at });
|
||||
}
|
||||
}
|
||||
@ -311,7 +313,9 @@ pub async fn lookup_active_compat_refresh_token(
|
||||
err(Display),
|
||||
)]
|
||||
pub async fn compat_login(
|
||||
conn: impl Acquire<'_, Database = Postgres>,
|
||||
conn: impl Acquire<'_, Database = Postgres> + Send,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
username: &str,
|
||||
password: &str,
|
||||
device: Device,
|
||||
@ -348,8 +352,8 @@ pub async fn compat_login(
|
||||
.instrument(tracing::info_span!("Verify hashed password"))
|
||||
.await??;
|
||||
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("compat_session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
@ -392,12 +396,14 @@ pub async fn compat_login(
|
||||
)]
|
||||
pub async fn add_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
session: &CompatSession<PostgresqlBackend>,
|
||||
token: String,
|
||||
expires_after: Option<Duration>,
|
||||
) -> Result<CompatAccessToken<PostgresqlBackend>, anyhow::Error> {
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
|
||||
|
||||
let expires_at = expires_after.map(|expires_after| created_at + expires_after);
|
||||
@ -436,9 +442,10 @@ pub async fn add_compat_access_token(
|
||||
)]
|
||||
pub async fn expire_compat_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
access_token: CompatAccessToken<PostgresqlBackend>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let expires_at = Utc::now();
|
||||
let expires_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_access_tokens
|
||||
@ -474,12 +481,14 @@ pub async fn expire_compat_access_token(
|
||||
)]
|
||||
pub async fn add_compat_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
session: &CompatSession<PostgresqlBackend>,
|
||||
access_token: &CompatAccessToken<PostgresqlBackend>,
|
||||
token: String,
|
||||
) -> Result<CompatRefreshToken<PostgresqlBackend>, anyhow::Error> {
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
@ -514,9 +523,10 @@ pub async fn add_compat_refresh_token(
|
||||
)]
|
||||
pub async fn compat_logout(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
token: &str,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let finished_at = Utc::now();
|
||||
let finished_at = clock.now();
|
||||
// TODO: this does not check for token expiration
|
||||
let compat_session_id = sqlx::query_scalar!(
|
||||
r#"
|
||||
@ -552,9 +562,10 @@ pub async fn compat_logout(
|
||||
)]
|
||||
pub async fn consume_compat_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
refresh_token: CompatRefreshToken<PostgresqlBackend>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let consumed_at = Utc::now();
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_refresh_tokens
|
||||
@ -587,11 +598,13 @@ pub async fn consume_compat_refresh_token(
|
||||
)]
|
||||
pub async fn insert_compat_sso_login(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
login_token: String,
|
||||
redirect_uri: Url,
|
||||
) -> Result<CompatSsoLogin<PostgresqlBackend>, anyhow::Error> {
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
@ -845,7 +858,9 @@ pub async fn get_compat_sso_login_by_token(
|
||||
err(Display),
|
||||
)]
|
||||
pub async fn fullfill_compat_sso_login(
|
||||
conn: impl Acquire<'_, Database = Postgres>,
|
||||
conn: impl Acquire<'_, Database = Postgres> + Send,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
user: User<PostgresqlBackend>,
|
||||
mut login: CompatSsoLogin<PostgresqlBackend>,
|
||||
device: Device,
|
||||
@ -856,8 +871,8 @@ pub async fn fullfill_compat_sso_login(
|
||||
|
||||
let mut txn = conn.begin().await.context("could not start transaction")?;
|
||||
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("user.id", tracing::field::display(user.data));
|
||||
|
||||
sqlx::query!(
|
||||
@ -883,7 +898,7 @@ pub async fn fullfill_compat_sso_login(
|
||||
finished_at: None,
|
||||
};
|
||||
|
||||
let fulfilled_at = Utc::now();
|
||||
let fulfilled_at = clock.now();
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
@ -924,6 +939,7 @@ pub async fn fullfill_compat_sso_login(
|
||||
)]
|
||||
pub async fn mark_compat_sso_login_as_exchanged(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
mut login: CompatSsoLogin<PostgresqlBackend>,
|
||||
) -> Result<CompatSsoLogin<PostgresqlBackend>, anyhow::Error> {
|
||||
let (fulfilled_at, session) = match login.state {
|
||||
@ -934,7 +950,7 @@ pub async fn mark_compat_sso_login_as_exchanged(
|
||||
_ => bail!("sso login in wrong state"),
|
||||
};
|
||||
|
||||
let exchanged_at = Utc::now();
|
||||
let exchanged_at = clock.now();
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE compat_sso_logins
|
||||
|
@ -15,7 +15,12 @@
|
||||
//! Interactions with the database
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(clippy::all, clippy::str_to_string, rustdoc::broken_intra_doc_links)]
|
||||
#![deny(
|
||||
clippy::all,
|
||||
clippy::str_to_string,
|
||||
clippy::future_not_send,
|
||||
rustdoc::broken_intra_doc_links
|
||||
)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(
|
||||
clippy::missing_errors_doc,
|
||||
@ -23,12 +28,27 @@
|
||||
clippy::module_name_repetitions
|
||||
)]
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{StorageBackend, StorageBackendMarker};
|
||||
use serde::Serialize;
|
||||
use sqlx::migrate::Migrator;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
#[derive(Default, Debug, Clone, Copy)]
|
||||
pub struct Clock {
|
||||
_private: (),
|
||||
}
|
||||
|
||||
impl Clock {
|
||||
#[must_use]
|
||||
pub fn now(&self) -> DateTime<Utc> {
|
||||
// This is the clock used elsewhere, it's fine to call Utc::now here
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
Utc::now()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("database query returned an inconsistent state")]
|
||||
pub struct DatabaseInconsistencyError;
|
||||
|
@ -15,13 +15,14 @@
|
||||
use anyhow::Context;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
|
||||
use rand::Rng;
|
||||
use sqlx::{Acquire, PgExecutor, Postgres};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::client::{lookup_client, ClientFetchError};
|
||||
use crate::{DatabaseInconsistencyError, PostgresqlBackend};
|
||||
use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@ -35,13 +36,15 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend};
|
||||
)]
|
||||
pub async fn add_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
session: &Session<PostgresqlBackend>,
|
||||
access_token: String,
|
||||
expires_after: Duration,
|
||||
) -> Result<AccessToken<PostgresqlBackend>, anyhow::Error> {
|
||||
let created_at = Utc::now();
|
||||
let created_at = clock.now();
|
||||
let expires_at = created_at + expires_after;
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
|
||||
tracing::Span::current().record("access_token.id", tracing::field::display(id));
|
||||
|
||||
@ -243,9 +246,10 @@ where
|
||||
)]
|
||||
pub async fn revoke_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
access_token: AccessToken<PostgresqlBackend>,
|
||||
) -> anyhow::Result<()> {
|
||||
let revoked_at = Utc::now();
|
||||
let revoked_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_access_tokens
|
||||
@ -266,9 +270,9 @@ pub async fn revoke_access_token(
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired(executor: impl PgExecutor<'_>) -> anyhow::Result<u64> {
|
||||
pub async fn cleanup_expired(executor: impl PgExecutor<'_>, clock: &Clock) -> anyhow::Result<u64> {
|
||||
// Cleanup token which expired more than 15 minutes ago
|
||||
let threshold = Utc::now() - Duration::minutes(15);
|
||||
let threshold = clock.now() - Duration::minutes(15);
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM oauth2_access_tokens
|
||||
|
@ -24,13 +24,14 @@ use mas_data_model::{
|
||||
};
|
||||
use mas_iana::oauth::PkceCodeChallengeMethod;
|
||||
use oauth2_types::{requests::ResponseMode, scope::Scope};
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::client::lookup_client;
|
||||
use crate::{DatabaseInconsistencyError, PostgresqlBackend};
|
||||
use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@ -43,6 +44,8 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend};
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn new_authorization_grant(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
client: Client<PostgresqlBackend>,
|
||||
redirect_uri: Url,
|
||||
scope: Scope,
|
||||
@ -67,8 +70,8 @@ pub async fn new_authorization_grant(
|
||||
let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
|
||||
let code_str = code.as_ref().map(|c| &c.code);
|
||||
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("grant.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
@ -504,11 +507,13 @@ pub async fn lookup_grant_by_code(
|
||||
)]
|
||||
pub async fn derive_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
grant: &AuthorizationGrant<PostgresqlBackend>,
|
||||
browser_session: BrowserSession<PostgresqlBackend>,
|
||||
) -> Result<Session<PostgresqlBackend>, anyhow::Error> {
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
@ -623,9 +628,10 @@ pub async fn give_consent_to_grant(
|
||||
)]
|
||||
pub async fn exchange_grant(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
mut grant: AuthorizationGrant<PostgresqlBackend>,
|
||||
) -> Result<AuthorizationGrant<PostgresqlBackend>, anyhow::Error> {
|
||||
let exchanged_at = Utc::now();
|
||||
let exchanged_at = clock.now();
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_authorization_grants
|
||||
|
@ -14,14 +14,14 @@
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use chrono::Utc;
|
||||
use mas_data_model::{Client, User};
|
||||
use oauth2_types::scope::{Scope, ScopeToken};
|
||||
use rand::Rng;
|
||||
use sqlx::PgExecutor;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::PostgresqlBackend;
|
||||
use crate::{Clock, PostgresqlBackend};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@ -67,17 +67,19 @@ pub async fn fetch_client_consent(
|
||||
)]
|
||||
pub async fn insert_client_consent(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
user: &User<PostgresqlBackend>,
|
||||
client: &Client<PostgresqlBackend>,
|
||||
scope: &Scope,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let now = Utc::now();
|
||||
let now = clock.now();
|
||||
let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
|
||||
.iter()
|
||||
.map(|token| {
|
||||
(
|
||||
token.to_string(),
|
||||
Uuid::from(Ulid::from_datetime(now.into())),
|
||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
|
||||
)
|
||||
})
|
||||
.unzip();
|
||||
|
@ -12,12 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::Utc;
|
||||
use mas_data_model::Session;
|
||||
use sqlx::PgExecutor;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::PostgresqlBackend;
|
||||
use crate::{Clock, PostgresqlBackend};
|
||||
|
||||
pub mod access_token;
|
||||
pub mod authorization_grant;
|
||||
@ -37,9 +36,10 @@ pub mod refresh_token;
|
||||
)]
|
||||
pub async fn end_oauth_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
session: Session<PostgresqlBackend>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let finished_at = Utc::now();
|
||||
let finished_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_sessions
|
||||
|
@ -17,13 +17,14 @@ use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{
|
||||
AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail,
|
||||
};
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::client::{lookup_client, ClientFetchError};
|
||||
use crate::{DatabaseInconsistencyError, PostgresqlBackend};
|
||||
use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@ -38,12 +39,14 @@ use crate::{DatabaseInconsistencyError, PostgresqlBackend};
|
||||
)]
|
||||
pub async fn add_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
session: &Session<PostgresqlBackend>,
|
||||
access_token: AccessToken<PostgresqlBackend>,
|
||||
refresh_token: String,
|
||||
) -> anyhow::Result<RefreshToken<PostgresqlBackend>> {
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
@ -263,9 +266,10 @@ pub async fn lookup_active_refresh_token(
|
||||
)]
|
||||
pub async fn consume_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
refresh_token: &RefreshToken<PostgresqlBackend>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let consumed_at = Utc::now();
|
||||
let consumed_at = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_refresh_tokens
|
||||
|
@ -22,7 +22,7 @@ use mas_data_model::{
|
||||
UserEmailVerificationState,
|
||||
};
|
||||
use password_hash::{PasswordHash, PasswordHasher, SaltString};
|
||||
use rand::thread_rng;
|
||||
use rand::{CryptoRng, Rng};
|
||||
use sqlx::{Acquire, PgExecutor, Postgres, Transaction};
|
||||
use thiserror::Error;
|
||||
use tokio::task;
|
||||
@ -31,6 +31,7 @@ use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{DatabaseInconsistencyError, PostgresqlBackend};
|
||||
use crate::Clock;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct UserLookup {
|
||||
@ -68,7 +69,9 @@ pub enum LoginError {
|
||||
err,
|
||||
)]
|
||||
pub async fn login(
|
||||
conn: impl Acquire<'_, Database = Postgres>,
|
||||
conn: impl Acquire<'_, Database = Postgres> + Send,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> Result<BrowserSession<PostgresqlBackend>, LoginError> {
|
||||
@ -86,8 +89,8 @@ pub async fn login(
|
||||
}
|
||||
})?;
|
||||
|
||||
let mut session = start_session(&mut txn, user).await?;
|
||||
authenticate_session(&mut txn, &mut session, password)
|
||||
let mut session = start_session(&mut txn, &mut rng, clock, user).await?;
|
||||
authenticate_session(&mut txn, &mut rng, clock, &mut session, password)
|
||||
.await
|
||||
.map_err(|source| {
|
||||
if matches!(source, AuthenticationError::Password { .. }) {
|
||||
@ -230,10 +233,12 @@ pub async fn lookup_active_session(
|
||||
)]
|
||||
pub async fn start_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
user: User<PostgresqlBackend>,
|
||||
) -> Result<BrowserSession<PostgresqlBackend>, anyhow::Error> {
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("user_session.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
@ -301,13 +306,16 @@ pub enum AuthenticationError {
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
session.id = %session.data,
|
||||
user.id = %session.user.data
|
||||
user.id = %session.user.data,
|
||||
user_session.id = %session.data,
|
||||
user_session_authentication.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn authenticate_session(
|
||||
txn: &mut Transaction<'_, Postgres>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
session: &mut BrowserSession<PostgresqlBackend>,
|
||||
password: &str,
|
||||
) -> Result<(), AuthenticationError> {
|
||||
@ -341,8 +349,13 @@ pub async fn authenticate_session(
|
||||
.await??;
|
||||
|
||||
// That went well, let's insert the auth info
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record(
|
||||
"user_session_authentication.id",
|
||||
tracing::field::display(id),
|
||||
);
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO user_session_authentications
|
||||
@ -376,12 +389,14 @@ pub async fn authenticate_session(
|
||||
)]
|
||||
pub async fn register_user(
|
||||
txn: &mut Transaction<'_, Postgres>,
|
||||
phf: impl PasswordHasher,
|
||||
mut rng: impl CryptoRng + Rng + Send,
|
||||
clock: &Clock,
|
||||
phf: impl PasswordHasher + Send,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> Result<User<PostgresqlBackend>, anyhow::Error> {
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("user.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
@ -405,7 +420,7 @@ pub async fn register_user(
|
||||
primary_email: None,
|
||||
};
|
||||
|
||||
set_password(txn.borrow_mut(), phf, &user, password).await?;
|
||||
set_password(txn.borrow_mut(), &mut rng, clock, phf, &user, password).await?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
@ -420,15 +435,17 @@ pub async fn register_user(
|
||||
)]
|
||||
pub async fn set_password(
|
||||
executor: impl PgExecutor<'_>,
|
||||
phf: impl PasswordHasher,
|
||||
mut rng: impl CryptoRng + Rng + Send,
|
||||
clock: &Clock,
|
||||
phf: impl PasswordHasher + Send,
|
||||
user: &User<PostgresqlBackend>,
|
||||
password: &str,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let created_at = Utc::now();
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
tracing::Span::current().record("user_password.id", tracing::field::display(id));
|
||||
|
||||
let salt = SaltString::generate(thread_rng());
|
||||
let salt = SaltString::generate(&mut rng);
|
||||
let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?;
|
||||
|
||||
sqlx::query_scalar!(
|
||||
@ -456,9 +473,10 @@ pub async fn set_password(
|
||||
)]
|
||||
pub async fn end_session(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
session: &BrowserSession<PostgresqlBackend>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let now = Utc::now();
|
||||
let now = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE user_sessions
|
||||
@ -672,11 +690,13 @@ pub async fn get_user_email(
|
||||
)]
|
||||
pub async fn add_user_email(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
user: &User<PostgresqlBackend>,
|
||||
email: String,
|
||||
) -> Result<UserEmail<PostgresqlBackend>, anyhow::Error> {
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("user_email.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
@ -842,9 +862,10 @@ pub async fn lookup_user_email_by_id(
|
||||
)]
|
||||
pub async fn mark_user_email_as_verified(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
mut email: UserEmail<PostgresqlBackend>,
|
||||
) -> Result<UserEmail<PostgresqlBackend>, anyhow::Error> {
|
||||
let confirmed_at = Utc::now();
|
||||
let confirmed_at = clock.now();
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE user_emails
|
||||
@ -881,10 +902,11 @@ struct UserEmailConfirmationCodeLookup {
|
||||
)]
|
||||
pub async fn lookup_user_email_verification_code(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
email: UserEmail<PostgresqlBackend>,
|
||||
code: &str,
|
||||
) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> {
|
||||
let now = Utc::now();
|
||||
let now = clock.now();
|
||||
|
||||
let res = sqlx::query_as!(
|
||||
UserEmailConfirmationCodeLookup,
|
||||
@ -935,13 +957,14 @@ pub async fn lookup_user_email_verification_code(
|
||||
)]
|
||||
pub async fn consume_email_verification(
|
||||
executor: impl PgExecutor<'_>,
|
||||
clock: &Clock,
|
||||
mut verification: UserEmailVerification<PostgresqlBackend>,
|
||||
) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> {
|
||||
if !matches!(verification.state, UserEmailVerificationState::Valid) {
|
||||
bail!("user email verification in wrong state");
|
||||
}
|
||||
|
||||
let consumed_at = Utc::now();
|
||||
let consumed_at = clock.now();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
@ -974,12 +997,14 @@ pub async fn consume_email_verification(
|
||||
)]
|
||||
pub async fn add_user_email_verification_code(
|
||||
executor: impl PgExecutor<'_>,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
email: UserEmail<PostgresqlBackend>,
|
||||
max_age: chrono::Duration,
|
||||
code: String,
|
||||
) -> Result<UserEmailVerification<PostgresqlBackend>, anyhow::Error> {
|
||||
let created_at = Utc::now();
|
||||
let id = Ulid::from_datetime(created_at.into());
|
||||
let created_at = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||
tracing::Span::current().record("user_email_confirmation.id", tracing::field::display(id));
|
||||
let expires_at = created_at + max_age;
|
||||
|
||||
@ -1013,23 +1038,27 @@ pub async fn add_user_email_verification_code(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::SeedableRng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||
async fn test_user_registration_and_login(pool: sqlx::PgPool) -> anyhow::Result<()> {
|
||||
let clock = Clock::default();
|
||||
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let exists = username_exists(&mut txn, "john").await?;
|
||||
assert!(!exists);
|
||||
|
||||
let hasher = Argon2::default();
|
||||
let user = register_user(&mut txn, hasher, "john", "hunter2").await?;
|
||||
let user = register_user(&mut txn, &mut rng, &clock, hasher, "john", "hunter2").await?;
|
||||
assert_eq!(user.username, "john");
|
||||
|
||||
let exists = username_exists(&mut txn, "john").await?;
|
||||
assert!(exists);
|
||||
|
||||
let session = login(&mut txn, "john", "hunter2").await?;
|
||||
let session = login(&mut txn, &mut rng, &clock, "john", "hunter2").await?;
|
||||
assert_eq!(session.user.data, user.data);
|
||||
|
||||
let user2 = lookup_user_by_username(&mut txn, "john").await?;
|
||||
|
@ -14,13 +14,14 @@
|
||||
|
||||
//! Database-related tasks
|
||||
|
||||
use mas_storage::Clock;
|
||||
use sqlx::{Pool, Postgres};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use super::Task;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CleanupExpired(Pool<Postgres>);
|
||||
struct CleanupExpired(Pool<Postgres>, Clock);
|
||||
|
||||
impl std::fmt::Debug for CleanupExpired {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
@ -31,7 +32,7 @@ impl std::fmt::Debug for CleanupExpired {
|
||||
#[async_trait::async_trait]
|
||||
impl Task for CleanupExpired {
|
||||
async fn run(&self) {
|
||||
let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0).await;
|
||||
let res = mas_storage::oauth2::access_token::cleanup_expired(&self.0, &self.1).await;
|
||||
match res {
|
||||
Ok(0) => {
|
||||
debug!("no token to clean up");
|
||||
@ -49,5 +50,6 @@ impl Task for CleanupExpired {
|
||||
/// Cleanup expired tokens
|
||||
#[must_use]
|
||||
pub fn cleanup_expired(pool: &Pool<Postgres>) -> impl Task + Clone {
|
||||
CleanupExpired(pool.clone())
|
||||
// XXX: the clock should come from somewhere else
|
||||
CleanupExpired(pool.clone(), Clock::default())
|
||||
}
|
||||
|
Reference in New Issue
Block a user