1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Add rate-limiting for account recovery and registration (#3093)

* Add rate-limiting for account recovery and registration

* Rename login ratelimiter `per_address` to `per_ip` for consistency

Co-authored-by: Quentin Gliech <quenting@element.io>
This commit is contained in:
reivilibre
2024-08-07 18:57:36 +01:00
committed by GitHub
parent 244f8f5e5e
commit 5d4a4a6fb8
10 changed files with 320 additions and 35 deletions

View File

@ -23,21 +23,28 @@ use crate::ConfigurationSection;
/// Configuration related to sending emails
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct RateLimitingConfig {
/// Account Recovery-specific rate limits
#[serde(default)]
pub account_recovery: AccountRecoveryRateLimitingConfig,
/// Login-specific rate limits
#[serde(default)]
pub login: LoginRateLimitingConfig,
/// Controls how many registrations attempts are permitted
/// based on source address.
#[serde(default = "default_registration")]
pub registration: RateLimiterConfiguration,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct LoginRateLimitingConfig {
/// Controls how many login attempts are permitted
/// based on source address.
/// based on source IP address.
/// This can protect against brute force login attempts.
///
/// Note: this limit also applies to password checks when a user attempts to
/// change their own password.
#[serde(default = "default_login_per_address")]
pub per_address: RateLimiterConfiguration,
#[serde(default = "default_login_per_ip")]
pub per_ip: RateLimiterConfiguration,
/// Controls how many login attempts are permitted
/// based on the account that is being attempted to be logged into.
/// This can protect against a distributed brute force attack
@ -50,6 +57,24 @@ pub struct LoginRateLimitingConfig {
pub per_account: RateLimiterConfiguration,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct AccountRecoveryRateLimitingConfig {
/// Controls how many account recovery attempts are permitted
/// based on source IP address.
/// This can protect against causing e-mail spam to many targets.
///
/// Note: this limit also applies to re-sends.
#[serde(default = "default_account_recovery_per_ip")]
pub per_ip: RateLimiterConfiguration,
/// Controls how many account recovery attempts are permitted
/// based on the e-mail address entered into the recovery form.
/// This can protect against causing e-mail spam to one target.
///
/// Note: this limit also applies to re-sends.
#[serde(default = "default_account_recovery_per_address")]
pub per_address: RateLimiterConfiguration,
}
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct RateLimiterConfiguration {
/// A one-off burst of actions that the user can perform
@ -66,6 +91,13 @@ impl ConfigurationSection for RateLimitingConfig {
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
let metadata = figment.find_metadata(Self::PATH.unwrap());
let error_on_field = |mut error: figment::error::Error, field: &'static str| {
error.metadata = metadata.cloned();
error.profile = Some(figment::Profile::Default);
error.path = vec![Self::PATH.unwrap().to_owned(), field.to_owned()];
error
};
let error_on_nested_field =
|mut error: figment::error::Error, container: &'static str, field: &'static str| {
error.metadata = metadata.cloned();
@ -92,8 +124,23 @@ impl ConfigurationSection for RateLimitingConfig {
None
};
if let Some(error) = error_on_limiter(&self.login.per_address) {
return Err(error_on_nested_field(error, "login", "per_address"));
if let Some(error) = error_on_limiter(&self.account_recovery.per_ip) {
return Err(error_on_nested_field(error, "account_recovery", "per_ip"));
}
if let Some(error) = error_on_limiter(&self.account_recovery.per_address) {
return Err(error_on_nested_field(
error,
"account_recovery",
"per_address",
));
}
if let Some(error) = error_on_limiter(&self.registration) {
return Err(error_on_field(error, "registration"));
}
if let Some(error) = error_on_limiter(&self.login.per_ip) {
return Err(error_on_nested_field(error, "login", "per_ip"));
}
if let Some(error) = error_on_limiter(&self.login.per_account) {
return Err(error_on_nested_field(error, "login", "per_account"));
@ -119,7 +166,7 @@ impl RateLimiterConfiguration {
}
}
fn default_login_per_address() -> RateLimiterConfiguration {
fn default_login_per_ip() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 3.0 / 60.0,
@ -133,11 +180,33 @@ fn default_login_per_account() -> RateLimiterConfiguration {
}
}
#[allow(clippy::derivable_impls)] // when we add some top-level ratelimiters this will not be derivable anymore
fn default_registration() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 3.0 / 3600.0,
}
}
fn default_account_recovery_per_ip() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 3.0 / 3600.0,
}
}
fn default_account_recovery_per_address() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 1.0 / 3600.0,
}
}
impl Default for RateLimitingConfig {
fn default() -> Self {
RateLimitingConfig {
login: LoginRateLimitingConfig::default(),
registration: default_registration(),
account_recovery: AccountRecoveryRateLimitingConfig::default(),
}
}
}
@ -145,8 +214,17 @@ impl Default for RateLimitingConfig {
impl Default for LoginRateLimitingConfig {
fn default() -> Self {
LoginRateLimitingConfig {
per_address: default_login_per_address(),
per_ip: default_login_per_ip(),
per_account: default_login_per_account(),
}
}
}
impl Default for AccountRecoveryRateLimitingConfig {
fn default() -> Self {
AccountRecoveryRateLimitingConfig {
per_ip: default_account_recovery_per_ip(),
per_address: default_account_recovery_per_address(),
}
}
}

View File

@ -19,6 +19,15 @@ use mas_config::RateLimitingConfig;
use mas_data_model::User;
use ulid::Ulid;
#[derive(Debug, Clone, thiserror::Error)]
pub enum AccountRecoveryLimitedError {
#[error("Too many account recovery requests for requester {0}")]
Requester(RequesterFingerprint),
#[error("Too many account recovery requests for e-mail {0}")]
Email(String),
}
#[derive(Debug, Clone, Copy, thiserror::Error)]
pub enum PasswordCheckLimitedError {
#[error("Too many password checks for requester {0}")]
@ -28,6 +37,12 @@ pub enum PasswordCheckLimitedError {
User(Ulid),
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum RegistrationLimitedError {
#[error("Too many account registration requests for requester {0}")]
Requester(RequesterFingerprint),
}
/// Key used to rate limit requests per requester
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequesterFingerprint {
@ -66,15 +81,25 @@ type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;
#[derive(Debug)]
struct LimiterInner {
account_recovery_per_requester: KeyedRateLimiter<RequesterFingerprint>,
account_recovery_per_email: KeyedRateLimiter<String>,
password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
password_check_for_user: KeyedRateLimiter<Ulid>,
registration_per_requester: KeyedRateLimiter<RequesterFingerprint>,
}
impl LimiterInner {
fn new(config: &RateLimitingConfig) -> Option<Self> {
Some(Self {
password_check_for_requester: RateLimiter::keyed(config.login.per_address.to_quota()?),
account_recovery_per_requester: RateLimiter::keyed(
config.account_recovery.per_ip.to_quota()?,
),
account_recovery_per_email: RateLimiter::keyed(
config.account_recovery.per_address.to_quota()?,
),
password_check_for_requester: RateLimiter::keyed(config.login.per_ip.to_quota()?),
password_check_for_user: RateLimiter::keyed(config.login.per_account.to_quota()?),
registration_per_requester: RateLimiter::keyed(config.registration.to_quota()?),
})
}
}
@ -105,14 +130,44 @@ impl Limiter {
loop {
// Call the retain_recent method on each rate limiter
this.inner.account_recovery_per_email.retain_recent();
this.inner.account_recovery_per_requester.retain_recent();
this.inner.password_check_for_requester.retain_recent();
this.inner.password_check_for_user.retain_recent();
this.inner.registration_per_requester.retain_recent();
interval.tick().await;
}
});
}
/// Check if an account recovery can be performed
///
/// # Errors
///
/// Returns an error if the operation is rate limited.
pub fn check_account_recovery(
&self,
requester: RequesterFingerprint,
email_address: &str,
) -> Result<(), AccountRecoveryLimitedError> {
self.inner
.account_recovery_per_requester
.check_key(&requester)
.map_err(|_| AccountRecoveryLimitedError::Requester(requester))?;
// Convert to lowercase to prevent bypassing the limit by enumerating different
// case variations.
// A case-folding transformation may be more proper.
let canonical_email = email_address.to_lowercase();
self.inner
.account_recovery_per_email
.check_key(&canonical_email)
.map_err(|_| AccountRecoveryLimitedError::Email(canonical_email))?;
Ok(())
}
/// Check if a password check can be performed
///
/// # Errors
@ -135,6 +190,23 @@ impl Limiter {
Ok(())
}
/// Check if an account registration can be performed
///
/// # Errors
///
/// Returns an error if the operation is rate limited.
pub fn check_registration(
&self,
requester: RequesterFingerprint,
) -> Result<(), RegistrationLimitedError> {
self.inner
.registration_per_requester
.check_key(&requester)
.map_err(|_| RegistrationLimitedError::Requester(requester))?;
Ok(())
}
}
#[cfg(test)]

View File

@ -17,6 +17,7 @@ use axum::{
response::{Html, IntoResponse, Response},
Form,
};
use hyper::StatusCode;
use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
@ -31,7 +32,7 @@ use mas_storage::{
use mas_templates::{EmptyContext, RecoveryProgressContext, TemplateContext, Templates};
use ulid::Ulid;
use crate::PreferredLanguage;
use crate::{Limiter, PreferredLanguage, RequesterFingerprint};
pub(crate) async fn get(
mut rng: BoxRng,
@ -74,7 +75,7 @@ pub(crate) async fn get(
return Ok((cookie_jar, Html(rendered)).into_response());
}
let context = RecoveryProgressContext::new(recovery_session)
let context = RecoveryProgressContext::new(recovery_session, false)
.with_csrf(csrf_token.form_value())
.with_language(locale);
@ -92,6 +93,7 @@ pub(crate) async fn post(
State(site_config): State<SiteConfig>,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
(State(limiter), requester): (State<Limiter>, RequesterFingerprint),
PreferredLanguage(locale): PreferredLanguage,
cookie_jar: CookieJar,
Path(id): Path<Ulid>,
@ -130,6 +132,17 @@ pub(crate) async fn post(
// Verify the CSRF token
let () = cookie_jar.verify_form(&clock, form)?;
// Check the rate limit if we are about to process the form
if let Err(e) = limiter.check_account_recovery(requester, &recovery_session.email) {
tracing::warn!(error = &e as &dyn std::error::Error);
let context = RecoveryProgressContext::new(recovery_session, true)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let rendered = templates.render_recovery_progress(&context)?;
return Ok((StatusCode::TOO_MANY_REQUESTS, (cookie_jar, Html(rendered))).into_response());
}
// Schedule a new batch of emails
repo.job()
.schedule_job(SendAccountRecoveryEmailsJob::new(&recovery_session))
@ -137,7 +150,7 @@ pub(crate) async fn post(
repo.save().await?;
let context = RecoveryProgressContext::new(recovery_session)
let context = RecoveryProgressContext::new(recovery_session, false)
.with_csrf(csrf_token.form_value())
.with_language(locale);

View File

@ -33,12 +33,12 @@ use mas_storage::{
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::{
EmptyContext, FieldError, FormState, RecoveryStartContext, RecoveryStartFormField,
EmptyContext, FieldError, FormError, FormState, RecoveryStartContext, RecoveryStartFormField,
TemplateContext, Templates,
};
use serde::{Deserialize, Serialize};
use crate::{BoundActivityTracker, PreferredLanguage};
use crate::{BoundActivityTracker, Limiter, PreferredLanguage, RequesterFingerprint};
#[derive(Deserialize, Serialize)]
pub(crate) struct StartRecoveryForm {
@ -90,6 +90,7 @@ pub(crate) async fn post(
State(site_config): State<SiteConfig>,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
(State(limiter), requester): (State<Limiter>, RequesterFingerprint),
PreferredLanguage(locale): PreferredLanguage,
cookie_jar: CookieJar,
Form(form): Form<ProtectedForm<StartRecoveryForm>>,
@ -120,6 +121,14 @@ pub(crate) async fn post(
form_state.with_error_on_field(RecoveryStartFormField::Email, FieldError::Invalid);
}
if form_state.is_valid() {
// Check the rate limit if we are about to process the form
if let Err(e) = limiter.check_account_recovery(requester, &form.email) {
tracing::warn!(error = &e as &dyn std::error::Error);
form_state.add_error_on_form(FormError::RateLimitExceeded);
}
}
if !form_state.is_valid() {
repo.save().await?;
let context = RecoveryStartContext::new()

View File

@ -46,8 +46,8 @@ use zeroize::Zeroizing;
use super::shared::OptionalPostAuthAction;
use crate::{
captcha::Form as CaptchaForm, passwords::PasswordManager, BoundActivityTracker,
PreferredLanguage, SiteConfig,
captcha::Form as CaptchaForm, passwords::PasswordManager, BoundActivityTracker, Limiter,
PreferredLanguage, RequesterFingerprint, SiteConfig,
};
#[derive(Debug, Deserialize, Serialize)]
@ -122,12 +122,15 @@ pub(crate) async fn post(
State(site_config): State<SiteConfig>,
State(homeserver): State<BoxHomeserverConnection>,
State(http_client_factory): State<HttpClientFactory>,
(State(limiter), requester): (State<Limiter>, RequesterFingerprint),
mut policy: Policy,
mut repo: BoxRepository,
activity_tracker: BoundActivityTracker,
(user_agent, activity_tracker): (
Option<TypedHeader<headers::UserAgent>>,
BoundActivityTracker,
),
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: CookieJar,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Form(form): Form<ProtectedForm<RegisterForm>>,
) -> Result<Response, FancyError> {
let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
@ -243,6 +246,14 @@ pub(crate) async fn post(
}
}
if state.is_valid() {
// Check the rate limit if we are about to process the form
if let Err(e) = limiter.check_registration(requester) {
tracing::warn!(error = &e as &dyn std::error::Error);
state.add_error_on_form(FormError::RateLimitExceeded);
}
}
state
};

View File

@ -1056,13 +1056,18 @@ impl TemplateContext for RecoveryStartContext {
#[derive(Serialize)]
pub struct RecoveryProgressContext {
session: UserRecoverySession,
/// Whether resending the e-mail was denied because of rate limits
resend_failed_due_to_rate_limit: bool,
}
impl RecoveryProgressContext {
/// Constructs a context for the recovery progress page
#[must_use]
pub fn new(session: UserRecoverySession) -> Self {
Self { session }
pub fn new(session: UserRecoverySession, resend_failed_due_to_rate_limit: bool) -> Self {
Self {
session,
resend_failed_due_to_rate_limit,
}
}
}
@ -1081,7 +1086,16 @@ impl TemplateContext for RecoveryProgressContext {
consumed_at: None,
};
vec![Self { session }]
vec![
Self {
session: session.clone(),
resend_failed_due_to_rate_limit: false,
},
Self {
session,
resend_failed_due_to_rate_limit: true,
},
]
}
}