diff --git a/Cargo.lock b/Cargo.lock index ab05b293..883d1104 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1550,6 +1550,19 @@ dependencies = [ "syn 2.0.68", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.6.0" @@ -2077,6 +2090,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.30" @@ -2173,6 +2192,26 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "governor" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b" +dependencies = [ + "cfg-if", + "dashmap", + "futures", + "futures-timer", + "no-std-compat", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand", + "smallvec", + "spinning_top", +] + [[package]] name = "graceful-shutdown" version = "0.2.0" @@ -3332,6 +3371,7 @@ dependencies = [ "chrono", "cookie_store", "futures-util", + "governor", "headers", "hyper", "indexmap 2.2.6", @@ -3354,6 +3394,7 @@ dependencies = [ "mas-templates", "mime", "minijinja", + "nonzero_ext", "oauth2-types", "opentelemetry", "opentelemetry-semantic-conventions", @@ -3925,6 +3966,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + [[package]] name = "nom" version = "7.1.3" @@ -3935,6 +3982,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -4633,6 +4686,12 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "portable-atomic" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" + [[package]] name = "postcard" version = "1.0.8" @@ -4784,6 +4843,21 @@ dependencies = [ "psl-types", ] +[[package]] +name = "quanta" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quinn" version = "0.11.2" @@ -4876,6 +4950,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "raw-cpuid" +version = "11.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb9ee317cfe3fbd54b36a511efc1edd42e216903c9cd575e686dd68a2ba90d8d" +dependencies = [ + "bitflags 2.6.0", +] + [[package]] name = "rayon" version = "1.10.0" @@ -5787,6 +5870,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" diff --git a/Cargo.toml b/Cargo.toml index f7227724..e0ecdebf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,6 +106,10 @@ features = ["derive"] version = "0.10.19" features = ["env", "yaml", "test"] +# Rate-limiting +[workspace.dependencies.governor] +version = "0.6.3" + # HTTP headers [workspace.dependencies.headers] version = "0.4.0" @@ -164,6 +168,10 @@ features = [ [workspace.dependencies.minijinja] version = "2.1.0" +# Utilities to deal with non-zero values +[workspace.dependencies.nonzero_ext] +version = "0.3.0" + # Random values [workspace.dependencies.rand] version = "0.8.5" diff --git a/crates/cli/src/app_state.rs b/crates/cli/src/app_state.rs index 59aeb0a9..edc5a459 100644 --- a/crates/cli/src/app_state.rs +++ b/crates/cli/src/app_state.rs @@ -22,7 +22,7 @@ use ipnetwork::IpNetwork; use mas_data_model::SiteConfig; use mas_handlers::{ passwords::PasswordManager, ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper, - GraphQLSchema, HttpClientFactory, MetadataCache, + GraphQLSchema, HttpClientFactory, Limiter, MetadataCache, RequesterFingerprint, }; use mas_i18n::Translator; use mas_keystore::{Encrypter, Keystore}; @@ -57,6 +57,7 @@ pub struct AppState { pub site_config: SiteConfig, pub activity_tracker: ActivityTracker, pub trusted_proxies: Vec, + pub limiter: Limiter, pub conn_acquisition_histogram: Option>, } @@ -210,6 +211,12 @@ impl FromRef for SiteConfig { } } +impl FromRef for Limiter { + fn from_ref(input: &AppState) -> Self { + input.limiter.clone() + } +} + impl FromRef for BoxHomeserverConnection { fn from_ref(input: &AppState) -> Self { Box::new(input.homeserver_connection.clone()) @@ -326,12 +333,35 @@ impl FromRequestParts for BoundActivityTracker { parts: &mut axum::http::request::Parts, state: &AppState, ) -> Result { + // TODO: we may infer the IP twice, for the activity tracker and the limiter let ip = infer_client_ip(parts, &state.trusted_proxies); tracing::debug!(ip = ?ip, "Inferred client IP address"); Ok(state.activity_tracker.clone().bind(ip)) } } +#[async_trait] +impl FromRequestParts for RequesterFingerprint { + type Rejection = Infallible; + + async fn from_request_parts( + parts: &mut axum::http::request::Parts, + state: &AppState, + ) -> Result { + // TODO: we may infer the IP twice, for the activity tracker and the limiter + let ip = infer_client_ip(parts, &state.trusted_proxies); + + if let Some(ip) = ip { + Ok(RequesterFingerprint::new(ip)) + } else { + // If we can't infer the IP address, we'll just use an empty fingerprint and + // warn about it + tracing::warn!("Could not infer client IP address for an operation which rate-limits based on IP addresses"); + Ok(RequesterFingerprint::EMPTY) + } + } +} + #[async_trait] impl FromRequestParts for BoxRepository { type Rejection = ErrorWrapper; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index f3f9f409..c97c55a7 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -19,7 +19,7 @@ use clap::Parser; use figment::Figment; use itertools::Itertools; use mas_config::{AppConfig, ClientsConfig, ConfigurationSection, UpstreamOAuth2Config}; -use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCache}; +use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, Limiter, MetadataCache}; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; @@ -200,6 +200,8 @@ impl Options { // Listen for SIGHUP register_sighup(&templates, &activity_tracker)?; + let limiter = Limiter::default(); + let graphql_schema = mas_handlers::graphql_schema( &pool, &policy_factory, @@ -213,7 +215,6 @@ impl Options { pool, templates, key_store, - metadata_cache, cookie_manager, encrypter, url_builder, @@ -222,9 +223,11 @@ impl Options { graphql_schema, http_client_factory, password_manager, + metadata_cache, site_config, activity_tracker, trusted_proxies, + limiter, conn_acquisition_histogram: None, }; s.init_metrics()?; diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index a0623672..0780047f 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -67,12 +67,14 @@ zeroize = "1.8.1" base64ct = "1.6.0" camino.workspace = true chrono.workspace = true +governor.workspace = true indexmap = "2.2.6" psl = "2.1.55" time = "0.3.36" url.workspace = true mime = "0.3.17" minijinja.workspace = true +nonzero_ext.workspace = true rand.workspace = true rand_chacha = "0.3.1" headers.workspace = true diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 3e83a260..d1e7bfce 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -36,7 +36,10 @@ use thiserror::Error; use zeroize::Zeroizing; use super::MatrixError; -use crate::{impl_from_error_for_route, passwords::PasswordManager, BoundActivityTracker}; +use crate::{ + impl_from_error_for_route, passwords::PasswordManager, rate_limit::PasswordCheckLimitedError, + BoundActivityTracker, Limiter, RequesterFingerprint, +}; #[derive(Debug, Serialize)] #[serde(tag = "type")] @@ -162,6 +165,9 @@ pub enum RouteError { #[error("password verification failed")] PasswordVerificationFailed(#[source] anyhow::Error), + #[error("request rate limited")] + RateLimited(#[from] PasswordCheckLimitedError), + #[error("login took too long")] LoginTookTooLong, @@ -185,6 +191,11 @@ impl IntoResponse for RouteError { status: StatusCode::INTERNAL_SERVER_ERROR, } } + Self::RateLimited(_) => MatrixError { + errcode: "M_LIMIT_EXCEEDED", + error: "Too many login attempts", + status: StatusCode::TOO_MANY_REQUESTS, + }, Self::Unsupported => MatrixError { errcode: "M_UNRECOGNIZED", error: "Invalid login type", @@ -192,18 +203,18 @@ impl IntoResponse for RouteError { }, Self::UserNotFound | Self::NoPassword | Self::PasswordVerificationFailed(_) => { MatrixError { - errcode: "M_UNAUTHORIZED", + errcode: "M_FORBIDDEN", error: "Invalid username/password", status: StatusCode::FORBIDDEN, } } Self::LoginTookTooLong => MatrixError { - errcode: "M_UNAUTHORIZED", + errcode: "M_FORBIDDEN", error: "Login token expired", status: StatusCode::FORBIDDEN, }, Self::InvalidLoginToken => MatrixError { - errcode: "M_UNAUTHORIZED", + errcode: "M_FORBIDDEN", error: "Invalid login token", status: StatusCode::FORBIDDEN, }, @@ -222,6 +233,8 @@ pub(crate) async fn post( activity_tracker: BoundActivityTracker, State(homeserver): State, State(site_config): State, + State(limiter): State, + requester: RequesterFingerprint, user_agent: Option>, Json(input): Json, ) -> Result { @@ -238,6 +251,8 @@ pub(crate) async fn post( &mut rng, &clock, &password_manager, + &limiter, + requester, &mut repo, &homeserver, user, @@ -372,6 +387,8 @@ async fn user_password_login( mut rng: &mut (impl RngCore + CryptoRng + Send), clock: &impl Clock, password_manager: &PasswordManager, + limiter: &Limiter, + requester: RequesterFingerprint, repo: &mut BoxRepository, homeserver: &BoxHomeserverConnection, username: String, @@ -385,6 +402,9 @@ async fn user_password_login( .filter(mas_data_model::User::is_valid) .ok_or(RouteError::UserNotFound)?; + // Check the rate limit + limiter.check_password(requester, &user)?; + // Lookup its password let user_password = repo .user_password() @@ -628,7 +648,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::FORBIDDEN); let body: serde_json::Value = response.json(); - assert_eq!(body["errcode"], "M_UNAUTHORIZED"); + assert_eq!(body["errcode"], "M_FORBIDDEN"); // Try to login with a wrong username. let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ @@ -650,6 +670,57 @@ mod tests { assert_eq!(body, old_body); } + /// Test that password logins are rate limited. + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_password_login_rate_limit(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + + // Let's provision a user without a password. This should be enough to trigger + // the rate limit. + let mut repo = state.repository().await.unwrap(); + + let user = repo + .user() + .add(&mut state.rng(), &state.clock, "alice".to_owned()) + .await + .unwrap(); + + let mxid = state.homeserver_connection.mxid(&user.username); + state + .homeserver_connection + .provision_user(&ProvisionRequest::new(mxid, &user.sub)) + .await + .unwrap(); + + repo.save().await.unwrap(); + + // Now let's try to login with the password, without asking for a refresh token. + let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "alice", + }, + "password": "password", + })); + + // First three attempts should just tell about the invalid credentials + let response = state.request(request.clone()).await; + response.assert_status(StatusCode::FORBIDDEN); + let response = state.request(request.clone()).await; + response.assert_status(StatusCode::FORBIDDEN); + let response = state.request(request.clone()).await; + response.assert_status(StatusCode::FORBIDDEN); + + // The fourth attempt should be rate limited + let response = state.request(request.clone()).await; + response.assert_status(StatusCode::TOO_MANY_REQUESTS); + let body: serde_json::Value = response.json(); + assert_eq!(body["errcode"], "M_LIMIT_EXCEEDED"); + assert_eq!(body["error"], "Too many login attempts"); + } + /// Test the response of an unsupported login flow. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_unsupported_login(pool: PgPool) { @@ -699,7 +770,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::FORBIDDEN); let body: serde_json::Value = response.json(); - assert_eq!(body["errcode"], "M_UNAUTHORIZED"); + assert_eq!(body["errcode"], "M_FORBIDDEN"); let (device, token) = get_login_token(&state, &user).await; @@ -726,7 +797,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::FORBIDDEN); let body: serde_json::Value = response.json(); - assert_eq!(body["errcode"], "M_UNAUTHORIZED"); + assert_eq!(body["errcode"], "M_FORBIDDEN"); // Try to login, but wait too long before sending the request. let (_device, token) = get_login_token(&state, &user).await; @@ -743,7 +814,7 @@ mod tests { let response = state.request(request).await; response.assert_status(StatusCode::FORBIDDEN); let body: serde_json::Value = response.json(); - assert_eq!(body["errcode"], "M_UNAUTHORIZED"); + assert_eq!(body["errcode"], "M_FORBIDDEN"); } /// Get a login token for a user. diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index d87e4f61..28375cbe 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -65,6 +65,7 @@ mod views; mod activity_tracker; mod captcha; mod preferred_language; +mod rate_limit; #[cfg(test)] mod test_utils; @@ -95,6 +96,7 @@ pub use self::{ schema as graphql_schema, schema_builder as graphql_schema_builder, Schema as GraphQLSchema, }, preferred_language::PreferredLanguage, + rate_limit::{Limiter, RequesterFingerprint}, upstream_oauth2::cache::MetadataCache, }; @@ -246,7 +248,9 @@ where SiteConfig: FromRef, BoxHomeserverConnection: FromRef, PasswordManager: FromRef, + Limiter: FromRef, BoundActivityTracker: FromRequestParts, + RequesterFingerprint: FromRequestParts, BoxRepository: FromRequestParts, BoxClock: FromRequestParts, BoxRng: FromRequestParts, @@ -301,6 +305,7 @@ where BoxRepository: FromRequestParts, CookieJar: FromRequestParts, BoundActivityTracker: FromRequestParts, + RequesterFingerprint: FromRequestParts, Encrypter: FromRef, Templates: FromRef, Keystore: FromRef, @@ -308,6 +313,7 @@ where PasswordManager: FromRef, MetadataCache: FromRef, SiteConfig: FromRef, + Limiter: FromRef, BoxHomeserverConnection: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, diff --git a/crates/handlers/src/rate_limit.rs b/crates/handlers/src/rate_limit.rs new file mode 100644 index 00000000..7d33b857 --- /dev/null +++ b/crates/handlers/src/rate_limit.rs @@ -0,0 +1,183 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{net::IpAddr, sync::Arc}; + +use governor::{clock::QuantaClock, state::keyed::DashMapStateStore, Quota, RateLimiter}; +use mas_data_model::User; +use nonzero_ext::nonzero; +use ulid::Ulid; + +const PASSWORD_CHECK_FOR_REQUESTER_QUOTA: Quota = Quota::per_minute(nonzero!(3u32)); +const PASSWORD_CHECK_FOR_USER_QUOTA: Quota = Quota::per_hour(nonzero!(1800u32)); + +#[derive(Debug, Clone, Copy, thiserror::Error)] +pub enum PasswordCheckLimitedError { + #[error("Too many password checks for requester {0}")] + Requester(RequesterFingerprint), + + #[error("Too many password checks for user {0}")] + User(Ulid), +} + +/// Key used to rate limit requests per requester +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct RequesterFingerprint { + ip: Option, +} + +impl std::fmt::Display for RequesterFingerprint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(ip) = self.ip { + write!(f, "{ip}") + } else { + write!(f, "(NO CLIENT IP)") + } + } +} + +impl RequesterFingerprint { + /// An anonymous key with no IP address set. This should not be used in + /// production, and we should warn users if we can't find their client IPs. + pub const EMPTY: Self = Self { ip: None }; + + /// Create a new anonymous key with the given IP address + #[must_use] + pub const fn new(ip: IpAddr) -> Self { + Self { ip: Some(ip) } + } +} + +/// Rate limiters for the different operations +#[derive(Debug, Clone, Default)] +pub struct Limiter { + inner: Arc, +} + +type KeyedRateLimiter = RateLimiter, QuantaClock>; + +#[derive(Debug)] +struct LimiterInner { + password_check_for_requester: KeyedRateLimiter, + password_check_for_user: KeyedRateLimiter, +} + +impl Default for LimiterInner { + fn default() -> Self { + Self { + password_check_for_requester: RateLimiter::keyed(PASSWORD_CHECK_FOR_REQUESTER_QUOTA), + password_check_for_user: RateLimiter::keyed(PASSWORD_CHECK_FOR_USER_QUOTA), + } + } +} + +impl Limiter { + /// Check if a password check can be performed + /// + /// # Errors + /// + /// Returns an error if the operation is rate limited + pub fn check_password( + &self, + key: RequesterFingerprint, + user: &User, + ) -> Result<(), PasswordCheckLimitedError> { + self.inner + .password_check_for_requester + .check_key(&key) + .map_err(|_| PasswordCheckLimitedError::Requester(key))?; + + self.inner + .password_check_for_user + .check_key(&user.id) + .map_err(|_| PasswordCheckLimitedError::User(user.id))?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use mas_data_model::User; + use mas_storage::{clock::MockClock, Clock}; + use rand::SeedableRng; + + use super::*; + + #[test] + fn test_password_check_limiter() { + let now = MockClock::default().now(); + let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42); + + let limiter = Limiter::default(); + + // Let's create a lot of requesters to test account-level rate limiting + let requesters: [_; 768] = (0..=255) + .flat_map(|a| (0..3).map(move |b| RequesterFingerprint::new([a, a, b, b].into()))) + .collect::>() + .try_into() + .unwrap(); + + let alice = User { + id: Ulid::from_datetime_with_source(now.into(), &mut rng), + username: "alice".to_owned(), + sub: "123-456".to_owned(), + primary_user_email_id: None, + created_at: now, + locked_at: None, + can_request_admin: false, + }; + + let bob = User { + id: Ulid::from_datetime_with_source(now.into(), &mut rng), + username: "bob".to_owned(), + sub: "123-456".to_owned(), + primary_user_email_id: None, + created_at: now, + locked_at: None, + can_request_admin: false, + }; + + // Three times the same IP address should be allowed + assert!(limiter.check_password(requesters[0], &alice).is_ok()); + assert!(limiter.check_password(requesters[0], &alice).is_ok()); + assert!(limiter.check_password(requesters[0], &alice).is_ok()); + + // But the fourth time should be rejected + assert!(limiter.check_password(requesters[0], &alice).is_err()); + // Using another user should also be rejected + assert!(limiter.check_password(requesters[0], &bob).is_err()); + + // Using a different IP address should be allowed, the account isn't locked yet + assert!(limiter.check_password(requesters[1], &alice).is_ok()); + + // At this point, we consumed 4 cells out of 1800 on alice, let's distribute the + // requests with other IPs so that we get rate-limited on the account-level + for requester in requesters.iter().skip(2).take(598) { + assert!(limiter.check_password(*requester, &alice).is_ok()); + assert!(limiter.check_password(*requester, &alice).is_ok()); + assert!(limiter.check_password(*requester, &alice).is_ok()); + assert!(limiter.check_password(*requester, &alice).is_err()); + } + + // We now have consumed 4+598*3 = 1798 cells on the account, so we should be + // rejected soon + assert!(limiter.check_password(requesters[600], &alice).is_ok()); + assert!(limiter.check_password(requesters[601], &alice).is_ok()); + assert!(limiter.check_password(requesters[602], &alice).is_err()); + + // The other account isn't rate-limited + assert!(limiter.check_password(requesters[603], &bob).is_ok()); + } +} diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index dd8eaa1a..49e58e29 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -57,7 +57,7 @@ use crate::{ graphql, passwords::{Hasher, PasswordManager}, upstream_oauth2::cache::MetadataCache, - ActivityTracker, BoundActivityTracker, + ActivityTracker, BoundActivityTracker, Limiter, RequesterFingerprint, }; /// Setup rustcrypto and tracing for tests. @@ -108,6 +108,7 @@ pub(crate) struct TestState { pub password_manager: PasswordManager, pub site_config: SiteConfig, pub activity_tracker: ActivityTracker, + pub limiter: Limiter, pub clock: Arc, pub rng: Arc>, } @@ -212,6 +213,8 @@ impl TestState { let activity_tracker = ActivityTracker::new(pool.clone(), std::time::Duration::from_secs(1)); + let limiter = Limiter::default(); + Ok(Self { pool, templates, @@ -227,6 +230,7 @@ impl TestState { password_manager, site_config, activity_tracker, + limiter, clock, rng, }) @@ -436,6 +440,12 @@ impl FromRef for BoxHomeserverConnection { } } +impl FromRef for Limiter { + fn from_ref(input: &TestState) -> Self { + input.limiter.clone() + } +} + #[async_trait] impl FromRequestParts for ActivityTracker { type Rejection = Infallible; @@ -461,6 +471,18 @@ impl FromRequestParts for BoundActivityTracker { } } +#[async_trait] +impl FromRequestParts for RequesterFingerprint { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + _state: &TestState, + ) -> Result { + Ok(RequesterFingerprint::EMPTY) + } +} + #[async_trait] impl FromRequestParts for BoxClock { type Rejection = Infallible; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 177253d7..99203a7f 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -39,7 +39,10 @@ use serde::{Deserialize, Serialize}; use zeroize::Zeroizing; use super::shared::OptionalPostAuthAction; -use crate::{passwords::PasswordManager, BoundActivityTracker, PreferredLanguage, SiteConfig}; +use crate::{ + passwords::PasswordManager, BoundActivityTracker, Limiter, PreferredLanguage, + RequesterFingerprint, SiteConfig, +}; #[derive(Debug, Deserialize, Serialize)] pub(crate) struct LoginForm { @@ -116,8 +119,10 @@ pub(crate) async fn post( State(site_config): State, State(templates): State, State(url_builder): State, + State(limiter): State, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, + requester: RequesterFingerprint, Query(query): Query, cookie_jar: CookieJar, user_agent: Option>, @@ -170,6 +175,8 @@ pub(crate) async fn post( &mut repo, rng, &clock, + limiter, + requester, &form.username, &form.password, user_agent, @@ -211,6 +218,8 @@ async fn login( repo: &mut impl RepositoryAccess, mut rng: impl Rng + CryptoRng + Send, clock: &impl Clock, + limiter: Limiter, + requester: RequesterFingerprint, username: &str, password: &str, user_agent: Option, @@ -225,6 +234,12 @@ async fn login( .filter(mas_data_model::User::is_valid) .ok_or(FormError::InvalidCredentials)?; + // Check the rate limit + limiter.check_password(requester, &user).map_err(|e| { + tracing::warn!(error = &e as &dyn std::error::Error); + FormError::RateLimitExceeded + })?; + // And its password let user_password = repo .user_password() @@ -491,4 +506,73 @@ mod test { response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); assert!(response.body().contains("john")); } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_password_login_rate_limit(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + let mut rng = state.rng(); + let cookies = CookieHelper::new(); + + // Provision a user without a password. + // We don't give that user a password, so that we skip hashing it in this test. + // It will still be rate-limited + let mut repo = state.repository().await.unwrap(); + repo.user() + .add(&mut rng, &state.clock, "john".to_owned()) + .await + .unwrap(); + repo.save().await.unwrap(); + + // Render the login page to get a CSRF token + let request = Request::get("/login").empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + // Extract the CSRF token from the response body + let csrf_token = response + .body() + .split("name=\"csrf\" value=\"") + .nth(1) + .unwrap() + .split('\"') + .next() + .unwrap(); + + // Submit the login form + let request = Request::post("/login").form(serde_json::json!({ + "csrf": csrf_token, + "username": "john", + "password": "hunter2", + })); + let request = cookies.with_cookies(request); + + // First three attempts should just tell about the invalid credentials + let response = state.request(request.clone()).await; + response.assert_status(StatusCode::OK); + let body = response.body(); + assert!(body.contains("Invalid credentials")); + assert!(!body.contains("too many requests")); + + let response = state.request(request.clone()).await; + response.assert_status(StatusCode::OK); + let body = response.body(); + assert!(body.contains("Invalid credentials")); + assert!(!body.contains("too many requests")); + + let response = state.request(request.clone()).await; + response.assert_status(StatusCode::OK); + let body = response.body(); + assert!(body.contains("Invalid credentials")); + assert!(!body.contains("too many requests")); + + // The fourth attempt should be rate-limited + let response = state.request(request.clone()).await; + response.assert_status(StatusCode::OK); + let body = response.body(); + assert!(!body.contains("Invalid credentials")); + assert!(body.contains("too many requests")); + } } diff --git a/crates/templates/src/forms.rs b/crates/templates/src/forms.rs index cf236dda..66333251 100644 --- a/crates/templates/src/forms.rs +++ b/crates/templates/src/forms.rs @@ -62,6 +62,9 @@ pub enum FormError { /// There was an internal error Internal, + /// Rate limit exceeded + RateLimitExceeded, + /// Denied by the policy Policy { /// Message for this policy violation diff --git a/templates/components/errors.html b/templates/components/errors.html index 9c9619aa..aa23ef5c 100644 --- a/templates/components/errors.html +++ b/templates/components/errors.html @@ -19,6 +19,8 @@ limitations under the License. {{ _("mas.errors.invalid_credentials") }} {% elif error.kind == "password_mismatch" %} {{ _("mas.errors.password_mismatch") }} + {% elif error.kind == "rate_limit_exceeded" %} + {{ _("mas.errors.rate_limit_exceeded") }} {% elif error.kind == "policy" %} {{ _("mas.errors.denied_policy", policy=error.message) }} {% elif error.kind == "captcha" %} diff --git a/translations/en.json b/translations/en.json index 2f57c73e..7385cd4a 100644 --- a/translations/en.json +++ b/translations/en.json @@ -276,11 +276,11 @@ "errors": { "captcha": "CAPTCHA verification failed, please try again", "@captcha": { - "context": "components/errors.html:25:7-30" + "context": "components/errors.html:27:7-30" }, "denied_policy": "Denied by policy: %(policy)s", "@denied_policy": { - "context": "components/errors.html:23:7-58, components/field.html:72:17-68" + "context": "components/errors.html:25:7-58, components/field.html:72:17-68" }, "field_required": "This field is required", "@field_required": { @@ -294,6 +294,10 @@ "@password_mismatch": { "context": "components/errors.html:21:7-40, components/field.html:74:17-50" }, + "rate_limit_exceeded": "You've made too many requests in a short period. Please wait a few minutes and try again.", + "@rate_limit_exceeded": { + "context": "components/errors.html:23:7-42" + }, "username_taken": "This username is already taken", "@username_taken": { "context": "components/field.html:70:17-47"