1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Rate-limit password-based login attempts

This commit is contained in:
Quentin Gliech
2024-07-25 16:59:09 +02:00
parent f5b4caf520
commit e25c170403
13 changed files with 525 additions and 15 deletions

92
Cargo.lock generated
View File

@@ -1550,6 +1550,19 @@ dependencies = [
"syn 2.0.68",
]
[[package]]
name = "dashmap"
version = "5.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
dependencies = [
"cfg-if",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.6.0"
@@ -2077,6 +2090,12 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.30"
@@ -2173,6 +2192,26 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "governor"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
dependencies = [
"cfg-if",
"dashmap",
"futures",
"futures-timer",
"no-std-compat",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"quanta",
"rand",
"smallvec",
"spinning_top",
]
[[package]]
name = "graceful-shutdown"
version = "0.2.0"
@@ -3332,6 +3371,7 @@ dependencies = [
"chrono",
"cookie_store",
"futures-util",
"governor",
"headers",
"hyper",
"indexmap 2.2.6",
@@ -3354,6 +3394,7 @@ dependencies = [
"mas-templates",
"mime",
"minijinja",
"nonzero_ext",
"oauth2-types",
"opentelemetry",
"opentelemetry-semantic-conventions",
@@ -3925,6 +3966,12 @@ dependencies = [
"version_check",
]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]]
name = "nom"
version = "7.1.3"
@@ -3935,6 +3982,12 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
@@ -4633,6 +4686,12 @@ dependencies = [
"universal-hash",
]
[[package]]
name = "portable-atomic"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
[[package]]
name = "postcard"
version = "1.0.8"
@@ -4784,6 +4843,21 @@ dependencies = [
"psl-types",
]
[[package]]
name = "quanta"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi",
"web-sys",
"winapi",
]
[[package]]
name = "quinn"
version = "0.11.2"
@@ -4876,6 +4950,15 @@ dependencies = [
"getrandom",
]
[[package]]
name = "raw-cpuid"
version = "11.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb9ee317cfe3fbd54b36a511efc1edd42e216903c9cd575e686dd68a2ba90d8d"
dependencies = [
"bitflags 2.6.0",
]
[[package]]
name = "rayon"
version = "1.10.0"
@@ -5787,6 +5870,15 @@ dependencies = [
"lock_api",
]
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]]
name = "spki"
version = "0.7.3"

View File

@@ -106,6 +106,10 @@ features = ["derive"]
version = "0.10.19"
features = ["env", "yaml", "test"]
# Rate-limiting
[workspace.dependencies.governor]
version = "0.6.3"
# HTTP headers
[workspace.dependencies.headers]
version = "0.4.0"
@@ -164,6 +168,10 @@ features = [
[workspace.dependencies.minijinja]
version = "2.1.0"
# Utilities to deal with non-zero values
[workspace.dependencies.nonzero_ext]
version = "0.3.0"
# Random values
[workspace.dependencies.rand]
version = "0.8.5"

View File

@@ -22,7 +22,7 @@ use ipnetwork::IpNetwork;
use mas_data_model::SiteConfig;
use mas_handlers::{
passwords::PasswordManager, ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper,
GraphQLSchema, HttpClientFactory, MetadataCache,
GraphQLSchema, HttpClientFactory, Limiter, MetadataCache, RequesterFingerprint,
};
use mas_i18n::Translator;
use mas_keystore::{Encrypter, Keystore};
@@ -57,6 +57,7 @@ pub struct AppState {
pub site_config: SiteConfig,
pub activity_tracker: ActivityTracker,
pub trusted_proxies: Vec<IpNetwork>,
pub limiter: Limiter,
pub conn_acquisition_histogram: Option<Histogram<u64>>,
}
@@ -210,6 +211,12 @@ impl FromRef<AppState> for SiteConfig {
}
}
impl FromRef<AppState> for Limiter {
fn from_ref(input: &AppState) -> Self {
input.limiter.clone()
}
}
impl FromRef<AppState> for BoxHomeserverConnection {
fn from_ref(input: &AppState) -> Self {
Box::new(input.homeserver_connection.clone())
@@ -326,12 +333,35 @@ impl FromRequestParts<AppState> for BoundActivityTracker {
parts: &mut axum::http::request::Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
// TODO: we may infer the IP twice, for the activity tracker and the limiter
let ip = infer_client_ip(parts, &state.trusted_proxies);
tracing::debug!(ip = ?ip, "Inferred client IP address");
Ok(state.activity_tracker.clone().bind(ip))
}
}
#[async_trait]
impl FromRequestParts<AppState> for RequesterFingerprint {
type Rejection = Infallible;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
// TODO: we may infer the IP twice, for the activity tracker and the limiter
let ip = infer_client_ip(parts, &state.trusted_proxies);
if let Some(ip) = ip {
Ok(RequesterFingerprint::new(ip))
} else {
// If we can't infer the IP address, we'll just use an empty fingerprint and
// warn about it
tracing::warn!("Could not infer client IP address for an operation which rate-limits based on IP addresses");
Ok(RequesterFingerprint::EMPTY)
}
}
}
#[async_trait]
impl FromRequestParts<AppState> for BoxRepository {
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;

View File

@@ -19,7 +19,7 @@ use clap::Parser;
use figment::Figment;
use itertools::Itertools;
use mas_config::{AppConfig, ClientsConfig, ConfigurationSection, UpstreamOAuth2Config};
use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCache};
use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, Limiter, MetadataCache};
use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_matrix_synapse::SynapseConnection;
use mas_router::UrlBuilder;
@@ -200,6 +200,8 @@ impl Options {
// Listen for SIGHUP
register_sighup(&templates, &activity_tracker)?;
let limiter = Limiter::default();
let graphql_schema = mas_handlers::graphql_schema(
&pool,
&policy_factory,
@@ -213,7 +215,6 @@ impl Options {
pool,
templates,
key_store,
metadata_cache,
cookie_manager,
encrypter,
url_builder,
@@ -222,9 +223,11 @@ impl Options {
graphql_schema,
http_client_factory,
password_manager,
metadata_cache,
site_config,
activity_tracker,
trusted_proxies,
limiter,
conn_acquisition_histogram: None,
};
s.init_metrics()?;

View File

@@ -67,12 +67,14 @@ zeroize = "1.8.1"
base64ct = "1.6.0"
camino.workspace = true
chrono.workspace = true
governor.workspace = true
indexmap = "2.2.6"
psl = "2.1.55"
time = "0.3.36"
url.workspace = true
mime = "0.3.17"
minijinja.workspace = true
nonzero_ext.workspace = true
rand.workspace = true
rand_chacha = "0.3.1"
headers.workspace = true

View File

@@ -36,7 +36,10 @@ use thiserror::Error;
use zeroize::Zeroizing;
use super::MatrixError;
use crate::{impl_from_error_for_route, passwords::PasswordManager, BoundActivityTracker};
use crate::{
impl_from_error_for_route, passwords::PasswordManager, rate_limit::PasswordCheckLimitedError,
BoundActivityTracker, Limiter, RequesterFingerprint,
};
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
@@ -162,6 +165,9 @@ pub enum RouteError {
#[error("password verification failed")]
PasswordVerificationFailed(#[source] anyhow::Error),
#[error("request rate limited")]
RateLimited(#[from] PasswordCheckLimitedError),
#[error("login took too long")]
LoginTookTooLong,
@@ -185,6 +191,11 @@ impl IntoResponse for RouteError {
status: StatusCode::INTERNAL_SERVER_ERROR,
}
}
Self::RateLimited(_) => MatrixError {
errcode: "M_LIMIT_EXCEEDED",
error: "Too many login attempts",
status: StatusCode::TOO_MANY_REQUESTS,
},
Self::Unsupported => MatrixError {
errcode: "M_UNRECOGNIZED",
error: "Invalid login type",
@@ -192,18 +203,18 @@ impl IntoResponse for RouteError {
},
Self::UserNotFound | Self::NoPassword | Self::PasswordVerificationFailed(_) => {
MatrixError {
errcode: "M_UNAUTHORIZED",
errcode: "M_FORBIDDEN",
error: "Invalid username/password",
status: StatusCode::FORBIDDEN,
}
}
Self::LoginTookTooLong => MatrixError {
errcode: "M_UNAUTHORIZED",
errcode: "M_FORBIDDEN",
error: "Login token expired",
status: StatusCode::FORBIDDEN,
},
Self::InvalidLoginToken => MatrixError {
errcode: "M_UNAUTHORIZED",
errcode: "M_FORBIDDEN",
error: "Invalid login token",
status: StatusCode::FORBIDDEN,
},
@@ -222,6 +233,8 @@ pub(crate) async fn post(
activity_tracker: BoundActivityTracker,
State(homeserver): State<BoxHomeserverConnection>,
State(site_config): State<SiteConfig>,
State(limiter): State<Limiter>,
requester: RequesterFingerprint,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> {
@@ -238,6 +251,8 @@ pub(crate) async fn post(
&mut rng,
&clock,
&password_manager,
&limiter,
requester,
&mut repo,
&homeserver,
user,
@@ -372,6 +387,8 @@ async fn user_password_login(
mut rng: &mut (impl RngCore + CryptoRng + Send),
clock: &impl Clock,
password_manager: &PasswordManager,
limiter: &Limiter,
requester: RequesterFingerprint,
repo: &mut BoxRepository,
homeserver: &BoxHomeserverConnection,
username: String,
@@ -385,6 +402,9 @@ async fn user_password_login(
.filter(mas_data_model::User::is_valid)
.ok_or(RouteError::UserNotFound)?;
// Check the rate limit
limiter.check_password(requester, &user)?;
// Lookup its password
let user_password = repo
.user_password()
@@ -628,7 +648,7 @@ mod tests {
let response = state.request(request).await;
response.assert_status(StatusCode::FORBIDDEN);
let body: serde_json::Value = response.json();
assert_eq!(body["errcode"], "M_UNAUTHORIZED");
assert_eq!(body["errcode"], "M_FORBIDDEN");
// Try to login with a wrong username.
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
@@ -650,6 +670,57 @@ mod tests {
assert_eq!(body, old_body);
}
/// Test that password logins are rate limited.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_rate_limit(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
// Let's provision a user without a password. This should be enough to trigger
// the rate limit.
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut state.rng(), &state.clock, "alice".to_owned())
.await
.unwrap();
let mxid = state.homeserver_connection.mxid(&user.username);
state
.homeserver_connection
.provision_user(&ProvisionRequest::new(mxid, &user.sub))
.await
.unwrap();
repo.save().await.unwrap();
// Now let's try to login with the password, without asking for a refresh token.
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
"type": "m.login.password",
"identifier": {
"type": "m.id.user",
"user": "alice",
},
"password": "password",
}));
// First three attempts should just tell about the invalid credentials
let response = state.request(request.clone()).await;
response.assert_status(StatusCode::FORBIDDEN);
let response = state.request(request.clone()).await;
response.assert_status(StatusCode::FORBIDDEN);
let response = state.request(request.clone()).await;
response.assert_status(StatusCode::FORBIDDEN);
// The fourth attempt should be rate limited
let response = state.request(request.clone()).await;
response.assert_status(StatusCode::TOO_MANY_REQUESTS);
let body: serde_json::Value = response.json();
assert_eq!(body["errcode"], "M_LIMIT_EXCEEDED");
assert_eq!(body["error"], "Too many login attempts");
}
/// Test the response of an unsupported login flow.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_unsupported_login(pool: PgPool) {
@@ -699,7 +770,7 @@ mod tests {
let response = state.request(request).await;
response.assert_status(StatusCode::FORBIDDEN);
let body: serde_json::Value = response.json();
assert_eq!(body["errcode"], "M_UNAUTHORIZED");
assert_eq!(body["errcode"], "M_FORBIDDEN");
let (device, token) = get_login_token(&state, &user).await;
@@ -726,7 +797,7 @@ mod tests {
let response = state.request(request).await;
response.assert_status(StatusCode::FORBIDDEN);
let body: serde_json::Value = response.json();
assert_eq!(body["errcode"], "M_UNAUTHORIZED");
assert_eq!(body["errcode"], "M_FORBIDDEN");
// Try to login, but wait too long before sending the request.
let (_device, token) = get_login_token(&state, &user).await;
@@ -743,7 +814,7 @@ mod tests {
let response = state.request(request).await;
response.assert_status(StatusCode::FORBIDDEN);
let body: serde_json::Value = response.json();
assert_eq!(body["errcode"], "M_UNAUTHORIZED");
assert_eq!(body["errcode"], "M_FORBIDDEN");
}
/// Get a login token for a user.

View File

@@ -65,6 +65,7 @@ mod views;
mod activity_tracker;
mod captcha;
mod preferred_language;
mod rate_limit;
#[cfg(test)]
mod test_utils;
@@ -95,6 +96,7 @@ pub use self::{
schema as graphql_schema, schema_builder as graphql_schema_builder, Schema as GraphQLSchema,
},
preferred_language::PreferredLanguage,
rate_limit::{Limiter, RequesterFingerprint},
upstream_oauth2::cache::MetadataCache,
};
@@ -246,7 +248,9 @@ where
SiteConfig: FromRef<S>,
BoxHomeserverConnection: FromRef<S>,
PasswordManager: FromRef<S>,
Limiter: FromRef<S>,
BoundActivityTracker: FromRequestParts<S>,
RequesterFingerprint: FromRequestParts<S>,
BoxRepository: FromRequestParts<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
@@ -301,6 +305,7 @@ where
BoxRepository: FromRequestParts<S>,
CookieJar: FromRequestParts<S>,
BoundActivityTracker: FromRequestParts<S>,
RequesterFingerprint: FromRequestParts<S>,
Encrypter: FromRef<S>,
Templates: FromRef<S>,
Keystore: FromRef<S>,
@@ -308,6 +313,7 @@ where
PasswordManager: FromRef<S>,
MetadataCache: FromRef<S>,
SiteConfig: FromRef<S>,
Limiter: FromRef<S>,
BoxHomeserverConnection: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,

View File

@@ -0,0 +1,183 @@
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{net::IpAddr, sync::Arc};
use governor::{clock::QuantaClock, state::keyed::DashMapStateStore, Quota, RateLimiter};
use mas_data_model::User;
use nonzero_ext::nonzero;
use ulid::Ulid;
const PASSWORD_CHECK_FOR_REQUESTER_QUOTA: Quota = Quota::per_minute(nonzero!(3u32));
const PASSWORD_CHECK_FOR_USER_QUOTA: Quota = Quota::per_hour(nonzero!(1800u32));
#[derive(Debug, Clone, Copy, thiserror::Error)]
pub enum PasswordCheckLimitedError {
#[error("Too many password checks for requester {0}")]
Requester(RequesterFingerprint),
#[error("Too many password checks for user {0}")]
User(Ulid),
}
/// Key used to rate limit requests per requester
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequesterFingerprint {
ip: Option<IpAddr>,
}
impl std::fmt::Display for RequesterFingerprint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(ip) = self.ip {
write!(f, "{ip}")
} else {
write!(f, "(NO CLIENT IP)")
}
}
}
impl RequesterFingerprint {
/// An anonymous key with no IP address set. This should not be used in
/// production, and we should warn users if we can't find their client IPs.
pub const EMPTY: Self = Self { ip: None };
/// Create a new anonymous key with the given IP address
#[must_use]
pub const fn new(ip: IpAddr) -> Self {
Self { ip: Some(ip) }
}
}
/// Rate limiters for the different operations
#[derive(Debug, Clone, Default)]
pub struct Limiter {
inner: Arc<LimiterInner>,
}
type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;
#[derive(Debug)]
struct LimiterInner {
password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
password_check_for_user: KeyedRateLimiter<Ulid>,
}
impl Default for LimiterInner {
fn default() -> Self {
Self {
password_check_for_requester: RateLimiter::keyed(PASSWORD_CHECK_FOR_REQUESTER_QUOTA),
password_check_for_user: RateLimiter::keyed(PASSWORD_CHECK_FOR_USER_QUOTA),
}
}
}
impl Limiter {
/// Check if a password check can be performed
///
/// # Errors
///
/// Returns an error if the operation is rate limited
pub fn check_password(
&self,
key: RequesterFingerprint,
user: &User,
) -> Result<(), PasswordCheckLimitedError> {
self.inner
.password_check_for_requester
.check_key(&key)
.map_err(|_| PasswordCheckLimitedError::Requester(key))?;
self.inner
.password_check_for_user
.check_key(&user.id)
.map_err(|_| PasswordCheckLimitedError::User(user.id))?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use mas_data_model::User;
use mas_storage::{clock::MockClock, Clock};
use rand::SeedableRng;
use super::*;
#[test]
fn test_password_check_limiter() {
let now = MockClock::default().now();
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
let limiter = Limiter::default();
// Let's create a lot of requesters to test account-level rate limiting
let requesters: [_; 768] = (0..=255)
.flat_map(|a| (0..3).map(move |b| RequesterFingerprint::new([a, a, b, b].into())))
.collect::<Vec<_>>()
.try_into()
.unwrap();
let alice = User {
id: Ulid::from_datetime_with_source(now.into(), &mut rng),
username: "alice".to_owned(),
sub: "123-456".to_owned(),
primary_user_email_id: None,
created_at: now,
locked_at: None,
can_request_admin: false,
};
let bob = User {
id: Ulid::from_datetime_with_source(now.into(), &mut rng),
username: "bob".to_owned(),
sub: "123-456".to_owned(),
primary_user_email_id: None,
created_at: now,
locked_at: None,
can_request_admin: false,
};
// Three times the same IP address should be allowed
assert!(limiter.check_password(requesters[0], &alice).is_ok());
assert!(limiter.check_password(requesters[0], &alice).is_ok());
assert!(limiter.check_password(requesters[0], &alice).is_ok());
// But the fourth time should be rejected
assert!(limiter.check_password(requesters[0], &alice).is_err());
// Using another user should also be rejected
assert!(limiter.check_password(requesters[0], &bob).is_err());
// Using a different IP address should be allowed, the account isn't locked yet
assert!(limiter.check_password(requesters[1], &alice).is_ok());
// At this point, we consumed 4 cells out of 1800 on alice, let's distribute the
// requests with other IPs so that we get rate-limited on the account-level
for requester in requesters.iter().skip(2).take(598) {
assert!(limiter.check_password(*requester, &alice).is_ok());
assert!(limiter.check_password(*requester, &alice).is_ok());
assert!(limiter.check_password(*requester, &alice).is_ok());
assert!(limiter.check_password(*requester, &alice).is_err());
}
// We now have consumed 4+598*3 = 1798 cells on the account, so we should be
// rejected soon
assert!(limiter.check_password(requesters[600], &alice).is_ok());
assert!(limiter.check_password(requesters[601], &alice).is_ok());
assert!(limiter.check_password(requesters[602], &alice).is_err());
// The other account isn't rate-limited
assert!(limiter.check_password(requesters[603], &bob).is_ok());
}
}

View File

@@ -57,7 +57,7 @@ use crate::{
graphql,
passwords::{Hasher, PasswordManager},
upstream_oauth2::cache::MetadataCache,
ActivityTracker, BoundActivityTracker,
ActivityTracker, BoundActivityTracker, Limiter, RequesterFingerprint,
};
/// Setup rustcrypto and tracing for tests.
@@ -108,6 +108,7 @@ pub(crate) struct TestState {
pub password_manager: PasswordManager,
pub site_config: SiteConfig,
pub activity_tracker: ActivityTracker,
pub limiter: Limiter,
pub clock: Arc<MockClock>,
pub rng: Arc<Mutex<ChaChaRng>>,
}
@@ -212,6 +213,8 @@ impl TestState {
let activity_tracker =
ActivityTracker::new(pool.clone(), std::time::Duration::from_secs(1));
let limiter = Limiter::default();
Ok(Self {
pool,
templates,
@@ -227,6 +230,7 @@ impl TestState {
password_manager,
site_config,
activity_tracker,
limiter,
clock,
rng,
})
@@ -436,6 +440,12 @@ impl FromRef<TestState> for BoxHomeserverConnection {
}
}
impl FromRef<TestState> for Limiter {
fn from_ref(input: &TestState) -> Self {
input.limiter.clone()
}
}
#[async_trait]
impl FromRequestParts<TestState> for ActivityTracker {
type Rejection = Infallible;
@@ -461,6 +471,18 @@ impl FromRequestParts<TestState> for BoundActivityTracker {
}
}
#[async_trait]
impl FromRequestParts<TestState> for RequesterFingerprint {
type Rejection = Infallible;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
_state: &TestState,
) -> Result<Self, Self::Rejection> {
Ok(RequesterFingerprint::EMPTY)
}
}
#[async_trait]
impl FromRequestParts<TestState> for BoxClock {
type Rejection = Infallible;

View File

@@ -39,7 +39,10 @@ use serde::{Deserialize, Serialize};
use zeroize::Zeroizing;
use super::shared::OptionalPostAuthAction;
use crate::{passwords::PasswordManager, BoundActivityTracker, PreferredLanguage, SiteConfig};
use crate::{
passwords::PasswordManager, BoundActivityTracker, Limiter, PreferredLanguage,
RequesterFingerprint, SiteConfig,
};
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct LoginForm {
@@ -116,8 +119,10 @@ pub(crate) async fn post(
State(site_config): State<SiteConfig>,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
State(limiter): State<Limiter>,
mut repo: BoxRepository,
activity_tracker: BoundActivityTracker,
requester: RequesterFingerprint,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: CookieJar,
user_agent: Option<TypedHeader<headers::UserAgent>>,
@@ -170,6 +175,8 @@ pub(crate) async fn post(
&mut repo,
rng,
&clock,
limiter,
requester,
&form.username,
&form.password,
user_agent,
@@ -211,6 +218,8 @@ async fn login(
repo: &mut impl RepositoryAccess,
mut rng: impl Rng + CryptoRng + Send,
clock: &impl Clock,
limiter: Limiter,
requester: RequesterFingerprint,
username: &str,
password: &str,
user_agent: Option<UserAgent>,
@@ -225,6 +234,12 @@ async fn login(
.filter(mas_data_model::User::is_valid)
.ok_or(FormError::InvalidCredentials)?;
// Check the rate limit
limiter.check_password(requester, &user).map_err(|e| {
tracing::warn!(error = &e as &dyn std::error::Error);
FormError::RateLimitExceeded
})?;
// And its password
let user_password = repo
.user_password()
@@ -491,4 +506,73 @@ mod test {
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
assert!(response.body().contains("john"));
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_rate_limit(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let mut rng = state.rng();
let cookies = CookieHelper::new();
// Provision a user without a password.
// We don't give that user a password, so that we skip hashing it in this test.
// It will still be rate-limited
let mut repo = state.repository().await.unwrap();
repo.user()
.add(&mut rng, &state.clock, "john".to_owned())
.await
.unwrap();
repo.save().await.unwrap();
// Render the login page to get a CSRF token
let request = Request::get("/login").empty();
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
// Extract the CSRF token from the response body
let csrf_token = response
.body()
.split("name=\"csrf\" value=\"")
.nth(1)
.unwrap()
.split('\"')
.next()
.unwrap();
// Submit the login form
let request = Request::post("/login").form(serde_json::json!({
"csrf": csrf_token,
"username": "john",
"password": "hunter2",
}));
let request = cookies.with_cookies(request);
// First three attempts should just tell about the invalid credentials
let response = state.request(request.clone()).await;
response.assert_status(StatusCode::OK);
let body = response.body();
assert!(body.contains("Invalid credentials"));
assert!(!body.contains("too many requests"));
let response = state.request(request.clone()).await;
response.assert_status(StatusCode::OK);
let body = response.body();
assert!(body.contains("Invalid credentials"));
assert!(!body.contains("too many requests"));
let response = state.request(request.clone()).await;
response.assert_status(StatusCode::OK);
let body = response.body();
assert!(body.contains("Invalid credentials"));
assert!(!body.contains("too many requests"));
// The fourth attempt should be rate-limited
let response = state.request(request.clone()).await;
response.assert_status(StatusCode::OK);
let body = response.body();
assert!(!body.contains("Invalid credentials"));
assert!(body.contains("too many requests"));
}
}

View File

@@ -62,6 +62,9 @@ pub enum FormError {
/// There was an internal error
Internal,
/// Rate limit exceeded
RateLimitExceeded,
/// Denied by the policy
Policy {
/// Message for this policy violation

View File

@@ -19,6 +19,8 @@ limitations under the License.
{{ _("mas.errors.invalid_credentials") }}
{% elif error.kind == "password_mismatch" %}
{{ _("mas.errors.password_mismatch") }}
{% elif error.kind == "rate_limit_exceeded" %}
{{ _("mas.errors.rate_limit_exceeded") }}
{% elif error.kind == "policy" %}
{{ _("mas.errors.denied_policy", policy=error.message) }}
{% elif error.kind == "captcha" %}

View File

@@ -276,11 +276,11 @@
"errors": {
"captcha": "CAPTCHA verification failed, please try again",
"@captcha": {
"context": "components/errors.html:25:7-30"
"context": "components/errors.html:27:7-30"
},
"denied_policy": "Denied by policy: %(policy)s",
"@denied_policy": {
"context": "components/errors.html:23:7-58, components/field.html:72:17-68"
"context": "components/errors.html:25:7-58, components/field.html:72:17-68"
},
"field_required": "This field is required",
"@field_required": {
@@ -294,6 +294,10 @@
"@password_mismatch": {
"context": "components/errors.html:21:7-40, components/field.html:74:17-50"
},
"rate_limit_exceeded": "You've made too many requests in a short period. Please wait a few minutes and try again.",
"@rate_limit_exceeded": {
"context": "components/errors.html:23:7-42"
},
"username_taken": "This username is already taken",
"@username_taken": {
"context": "components/field.html:70:17-47"