You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Rate-limit password-based login attempts
This commit is contained in:
@ -22,7 +22,7 @@ use ipnetwork::IpNetwork;
|
||||
use mas_data_model::SiteConfig;
|
||||
use mas_handlers::{
|
||||
passwords::PasswordManager, ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper,
|
||||
GraphQLSchema, HttpClientFactory, MetadataCache,
|
||||
GraphQLSchema, HttpClientFactory, Limiter, MetadataCache, RequesterFingerprint,
|
||||
};
|
||||
use mas_i18n::Translator;
|
||||
use mas_keystore::{Encrypter, Keystore};
|
||||
@ -57,6 +57,7 @@ pub struct AppState {
|
||||
pub site_config: SiteConfig,
|
||||
pub activity_tracker: ActivityTracker,
|
||||
pub trusted_proxies: Vec<IpNetwork>,
|
||||
pub limiter: Limiter,
|
||||
pub conn_acquisition_histogram: Option<Histogram<u64>>,
|
||||
}
|
||||
|
||||
@ -210,6 +211,12 @@ impl FromRef<AppState> for SiteConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for Limiter {
|
||||
fn from_ref(input: &AppState) -> Self {
|
||||
input.limiter.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for BoxHomeserverConnection {
|
||||
fn from_ref(input: &AppState) -> Self {
|
||||
Box::new(input.homeserver_connection.clone())
|
||||
@ -326,12 +333,35 @@ impl FromRequestParts<AppState> for BoundActivityTracker {
|
||||
parts: &mut axum::http::request::Parts,
|
||||
state: &AppState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
// TODO: we may infer the IP twice, for the activity tracker and the limiter
|
||||
let ip = infer_client_ip(parts, &state.trusted_proxies);
|
||||
tracing::debug!(ip = ?ip, "Inferred client IP address");
|
||||
Ok(state.activity_tracker.clone().bind(ip))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<AppState> for RequesterFingerprint {
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut axum::http::request::Parts,
|
||||
state: &AppState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
// TODO: we may infer the IP twice, for the activity tracker and the limiter
|
||||
let ip = infer_client_ip(parts, &state.trusted_proxies);
|
||||
|
||||
if let Some(ip) = ip {
|
||||
Ok(RequesterFingerprint::new(ip))
|
||||
} else {
|
||||
// If we can't infer the IP address, we'll just use an empty fingerprint and
|
||||
// warn about it
|
||||
tracing::warn!("Could not infer client IP address for an operation which rate-limits based on IP addresses");
|
||||
Ok(RequesterFingerprint::EMPTY)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<AppState> for BoxRepository {
|
||||
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;
|
||||
|
@ -19,7 +19,7 @@ use clap::Parser;
|
||||
use figment::Figment;
|
||||
use itertools::Itertools;
|
||||
use mas_config::{AppConfig, ClientsConfig, ConfigurationSection, UpstreamOAuth2Config};
|
||||
use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCache};
|
||||
use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, Limiter, MetadataCache};
|
||||
use mas_listener::{server::Server, shutdown::ShutdownStream};
|
||||
use mas_matrix_synapse::SynapseConnection;
|
||||
use mas_router::UrlBuilder;
|
||||
@ -200,6 +200,8 @@ impl Options {
|
||||
// Listen for SIGHUP
|
||||
register_sighup(&templates, &activity_tracker)?;
|
||||
|
||||
let limiter = Limiter::default();
|
||||
|
||||
let graphql_schema = mas_handlers::graphql_schema(
|
||||
&pool,
|
||||
&policy_factory,
|
||||
@ -213,7 +215,6 @@ impl Options {
|
||||
pool,
|
||||
templates,
|
||||
key_store,
|
||||
metadata_cache,
|
||||
cookie_manager,
|
||||
encrypter,
|
||||
url_builder,
|
||||
@ -222,9 +223,11 @@ impl Options {
|
||||
graphql_schema,
|
||||
http_client_factory,
|
||||
password_manager,
|
||||
metadata_cache,
|
||||
site_config,
|
||||
activity_tracker,
|
||||
trusted_proxies,
|
||||
limiter,
|
||||
conn_acquisition_histogram: None,
|
||||
};
|
||||
s.init_metrics()?;
|
||||
|
@ -67,12 +67,14 @@ zeroize = "1.8.1"
|
||||
base64ct = "1.6.0"
|
||||
camino.workspace = true
|
||||
chrono.workspace = true
|
||||
governor.workspace = true
|
||||
indexmap = "2.2.6"
|
||||
psl = "2.1.55"
|
||||
time = "0.3.36"
|
||||
url.workspace = true
|
||||
mime = "0.3.17"
|
||||
minijinja.workspace = true
|
||||
nonzero_ext.workspace = true
|
||||
rand.workspace = true
|
||||
rand_chacha = "0.3.1"
|
||||
headers.workspace = true
|
||||
|
@ -36,7 +36,10 @@ use thiserror::Error;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use super::MatrixError;
|
||||
use crate::{impl_from_error_for_route, passwords::PasswordManager, BoundActivityTracker};
|
||||
use crate::{
|
||||
impl_from_error_for_route, passwords::PasswordManager, rate_limit::PasswordCheckLimitedError,
|
||||
BoundActivityTracker, Limiter, RequesterFingerprint,
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
@ -162,6 +165,9 @@ pub enum RouteError {
|
||||
#[error("password verification failed")]
|
||||
PasswordVerificationFailed(#[source] anyhow::Error),
|
||||
|
||||
#[error("request rate limited")]
|
||||
RateLimited(#[from] PasswordCheckLimitedError),
|
||||
|
||||
#[error("login took too long")]
|
||||
LoginTookTooLong,
|
||||
|
||||
@ -185,6 +191,11 @@ impl IntoResponse for RouteError {
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
Self::RateLimited(_) => MatrixError {
|
||||
errcode: "M_LIMIT_EXCEEDED",
|
||||
error: "Too many login attempts",
|
||||
status: StatusCode::TOO_MANY_REQUESTS,
|
||||
},
|
||||
Self::Unsupported => MatrixError {
|
||||
errcode: "M_UNRECOGNIZED",
|
||||
error: "Invalid login type",
|
||||
@ -192,18 +203,18 @@ impl IntoResponse for RouteError {
|
||||
},
|
||||
Self::UserNotFound | Self::NoPassword | Self::PasswordVerificationFailed(_) => {
|
||||
MatrixError {
|
||||
errcode: "M_UNAUTHORIZED",
|
||||
errcode: "M_FORBIDDEN",
|
||||
error: "Invalid username/password",
|
||||
status: StatusCode::FORBIDDEN,
|
||||
}
|
||||
}
|
||||
Self::LoginTookTooLong => MatrixError {
|
||||
errcode: "M_UNAUTHORIZED",
|
||||
errcode: "M_FORBIDDEN",
|
||||
error: "Login token expired",
|
||||
status: StatusCode::FORBIDDEN,
|
||||
},
|
||||
Self::InvalidLoginToken => MatrixError {
|
||||
errcode: "M_UNAUTHORIZED",
|
||||
errcode: "M_FORBIDDEN",
|
||||
error: "Invalid login token",
|
||||
status: StatusCode::FORBIDDEN,
|
||||
},
|
||||
@ -222,6 +233,8 @@ pub(crate) async fn post(
|
||||
activity_tracker: BoundActivityTracker,
|
||||
State(homeserver): State<BoxHomeserverConnection>,
|
||||
State(site_config): State<SiteConfig>,
|
||||
State(limiter): State<Limiter>,
|
||||
requester: RequesterFingerprint,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
Json(input): Json<RequestBody>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
@ -238,6 +251,8 @@ pub(crate) async fn post(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&password_manager,
|
||||
&limiter,
|
||||
requester,
|
||||
&mut repo,
|
||||
&homeserver,
|
||||
user,
|
||||
@ -372,6 +387,8 @@ async fn user_password_login(
|
||||
mut rng: &mut (impl RngCore + CryptoRng + Send),
|
||||
clock: &impl Clock,
|
||||
password_manager: &PasswordManager,
|
||||
limiter: &Limiter,
|
||||
requester: RequesterFingerprint,
|
||||
repo: &mut BoxRepository,
|
||||
homeserver: &BoxHomeserverConnection,
|
||||
username: String,
|
||||
@ -385,6 +402,9 @@ async fn user_password_login(
|
||||
.filter(mas_data_model::User::is_valid)
|
||||
.ok_or(RouteError::UserNotFound)?;
|
||||
|
||||
// Check the rate limit
|
||||
limiter.check_password(requester, &user)?;
|
||||
|
||||
// Lookup its password
|
||||
let user_password = repo
|
||||
.user_password()
|
||||
@ -628,7 +648,7 @@ mod tests {
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::FORBIDDEN);
|
||||
let body: serde_json::Value = response.json();
|
||||
assert_eq!(body["errcode"], "M_UNAUTHORIZED");
|
||||
assert_eq!(body["errcode"], "M_FORBIDDEN");
|
||||
|
||||
// Try to login with a wrong username.
|
||||
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
|
||||
@ -650,6 +670,57 @@ mod tests {
|
||||
assert_eq!(body, old_body);
|
||||
}
|
||||
|
||||
/// Test that password logins are rate limited.
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_password_login_rate_limit(pool: PgPool) {
|
||||
setup();
|
||||
let state = TestState::from_pool(pool).await.unwrap();
|
||||
|
||||
// Let's provision a user without a password. This should be enough to trigger
|
||||
// the rate limit.
|
||||
let mut repo = state.repository().await.unwrap();
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.add(&mut state.rng(), &state.clock, "alice".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mxid = state.homeserver_connection.mxid(&user.username);
|
||||
state
|
||||
.homeserver_connection
|
||||
.provision_user(&ProvisionRequest::new(mxid, &user.sub))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
repo.save().await.unwrap();
|
||||
|
||||
// Now let's try to login with the password, without asking for a refresh token.
|
||||
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
|
||||
"type": "m.login.password",
|
||||
"identifier": {
|
||||
"type": "m.id.user",
|
||||
"user": "alice",
|
||||
},
|
||||
"password": "password",
|
||||
}));
|
||||
|
||||
// First three attempts should just tell about the invalid credentials
|
||||
let response = state.request(request.clone()).await;
|
||||
response.assert_status(StatusCode::FORBIDDEN);
|
||||
let response = state.request(request.clone()).await;
|
||||
response.assert_status(StatusCode::FORBIDDEN);
|
||||
let response = state.request(request.clone()).await;
|
||||
response.assert_status(StatusCode::FORBIDDEN);
|
||||
|
||||
// The fourth attempt should be rate limited
|
||||
let response = state.request(request.clone()).await;
|
||||
response.assert_status(StatusCode::TOO_MANY_REQUESTS);
|
||||
let body: serde_json::Value = response.json();
|
||||
assert_eq!(body["errcode"], "M_LIMIT_EXCEEDED");
|
||||
assert_eq!(body["error"], "Too many login attempts");
|
||||
}
|
||||
|
||||
/// Test the response of an unsupported login flow.
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_unsupported_login(pool: PgPool) {
|
||||
@ -699,7 +770,7 @@ mod tests {
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::FORBIDDEN);
|
||||
let body: serde_json::Value = response.json();
|
||||
assert_eq!(body["errcode"], "M_UNAUTHORIZED");
|
||||
assert_eq!(body["errcode"], "M_FORBIDDEN");
|
||||
|
||||
let (device, token) = get_login_token(&state, &user).await;
|
||||
|
||||
@ -726,7 +797,7 @@ mod tests {
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::FORBIDDEN);
|
||||
let body: serde_json::Value = response.json();
|
||||
assert_eq!(body["errcode"], "M_UNAUTHORIZED");
|
||||
assert_eq!(body["errcode"], "M_FORBIDDEN");
|
||||
|
||||
// Try to login, but wait too long before sending the request.
|
||||
let (_device, token) = get_login_token(&state, &user).await;
|
||||
@ -743,7 +814,7 @@ mod tests {
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::FORBIDDEN);
|
||||
let body: serde_json::Value = response.json();
|
||||
assert_eq!(body["errcode"], "M_UNAUTHORIZED");
|
||||
assert_eq!(body["errcode"], "M_FORBIDDEN");
|
||||
}
|
||||
|
||||
/// Get a login token for a user.
|
||||
|
@ -65,6 +65,7 @@ mod views;
|
||||
mod activity_tracker;
|
||||
mod captcha;
|
||||
mod preferred_language;
|
||||
mod rate_limit;
|
||||
#[cfg(test)]
|
||||
mod test_utils;
|
||||
|
||||
@ -95,6 +96,7 @@ pub use self::{
|
||||
schema as graphql_schema, schema_builder as graphql_schema_builder, Schema as GraphQLSchema,
|
||||
},
|
||||
preferred_language::PreferredLanguage,
|
||||
rate_limit::{Limiter, RequesterFingerprint},
|
||||
upstream_oauth2::cache::MetadataCache,
|
||||
};
|
||||
|
||||
@ -246,7 +248,9 @@ where
|
||||
SiteConfig: FromRef<S>,
|
||||
BoxHomeserverConnection: FromRef<S>,
|
||||
PasswordManager: FromRef<S>,
|
||||
Limiter: FromRef<S>,
|
||||
BoundActivityTracker: FromRequestParts<S>,
|
||||
RequesterFingerprint: FromRequestParts<S>,
|
||||
BoxRepository: FromRequestParts<S>,
|
||||
BoxClock: FromRequestParts<S>,
|
||||
BoxRng: FromRequestParts<S>,
|
||||
@ -301,6 +305,7 @@ where
|
||||
BoxRepository: FromRequestParts<S>,
|
||||
CookieJar: FromRequestParts<S>,
|
||||
BoundActivityTracker: FromRequestParts<S>,
|
||||
RequesterFingerprint: FromRequestParts<S>,
|
||||
Encrypter: FromRef<S>,
|
||||
Templates: FromRef<S>,
|
||||
Keystore: FromRef<S>,
|
||||
@ -308,6 +313,7 @@ where
|
||||
PasswordManager: FromRef<S>,
|
||||
MetadataCache: FromRef<S>,
|
||||
SiteConfig: FromRef<S>,
|
||||
Limiter: FromRef<S>,
|
||||
BoxHomeserverConnection: FromRef<S>,
|
||||
BoxClock: FromRequestParts<S>,
|
||||
BoxRng: FromRequestParts<S>,
|
||||
|
183
crates/handlers/src/rate_limit.rs
Normal file
183
crates/handlers/src/rate_limit.rs
Normal file
@ -0,0 +1,183 @@
|
||||
// Copyright 2024 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{net::IpAddr, sync::Arc};
|
||||
|
||||
use governor::{clock::QuantaClock, state::keyed::DashMapStateStore, Quota, RateLimiter};
|
||||
use mas_data_model::User;
|
||||
use nonzero_ext::nonzero;
|
||||
use ulid::Ulid;
|
||||
|
||||
const PASSWORD_CHECK_FOR_REQUESTER_QUOTA: Quota = Quota::per_minute(nonzero!(3u32));
|
||||
const PASSWORD_CHECK_FOR_USER_QUOTA: Quota = Quota::per_hour(nonzero!(1800u32));
|
||||
|
||||
#[derive(Debug, Clone, Copy, thiserror::Error)]
|
||||
pub enum PasswordCheckLimitedError {
|
||||
#[error("Too many password checks for requester {0}")]
|
||||
Requester(RequesterFingerprint),
|
||||
|
||||
#[error("Too many password checks for user {0}")]
|
||||
User(Ulid),
|
||||
}
|
||||
|
||||
/// Key used to rate limit requests per requester
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct RequesterFingerprint {
|
||||
ip: Option<IpAddr>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RequesterFingerprint {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if let Some(ip) = self.ip {
|
||||
write!(f, "{ip}")
|
||||
} else {
|
||||
write!(f, "(NO CLIENT IP)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RequesterFingerprint {
|
||||
/// An anonymous key with no IP address set. This should not be used in
|
||||
/// production, and we should warn users if we can't find their client IPs.
|
||||
pub const EMPTY: Self = Self { ip: None };
|
||||
|
||||
/// Create a new anonymous key with the given IP address
|
||||
#[must_use]
|
||||
pub const fn new(ip: IpAddr) -> Self {
|
||||
Self { ip: Some(ip) }
|
||||
}
|
||||
}
|
||||
|
||||
/// Rate limiters for the different operations
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Limiter {
|
||||
inner: Arc<LimiterInner>,
|
||||
}
|
||||
|
||||
type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LimiterInner {
|
||||
password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
|
||||
password_check_for_user: KeyedRateLimiter<Ulid>,
|
||||
}
|
||||
|
||||
impl Default for LimiterInner {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
password_check_for_requester: RateLimiter::keyed(PASSWORD_CHECK_FOR_REQUESTER_QUOTA),
|
||||
password_check_for_user: RateLimiter::keyed(PASSWORD_CHECK_FOR_USER_QUOTA),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Limiter {
|
||||
/// Check if a password check can be performed
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the operation is rate limited
|
||||
pub fn check_password(
|
||||
&self,
|
||||
key: RequesterFingerprint,
|
||||
user: &User,
|
||||
) -> Result<(), PasswordCheckLimitedError> {
|
||||
self.inner
|
||||
.password_check_for_requester
|
||||
.check_key(&key)
|
||||
.map_err(|_| PasswordCheckLimitedError::Requester(key))?;
|
||||
|
||||
self.inner
|
||||
.password_check_for_user
|
||||
.check_key(&user.id)
|
||||
.map_err(|_| PasswordCheckLimitedError::User(user.id))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use mas_data_model::User;
|
||||
use mas_storage::{clock::MockClock, Clock};
|
||||
use rand::SeedableRng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_password_check_limiter() {
|
||||
let now = MockClock::default().now();
|
||||
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
|
||||
|
||||
let limiter = Limiter::default();
|
||||
|
||||
// Let's create a lot of requesters to test account-level rate limiting
|
||||
let requesters: [_; 768] = (0..=255)
|
||||
.flat_map(|a| (0..3).map(move |b| RequesterFingerprint::new([a, a, b, b].into())))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let alice = User {
|
||||
id: Ulid::from_datetime_with_source(now.into(), &mut rng),
|
||||
username: "alice".to_owned(),
|
||||
sub: "123-456".to_owned(),
|
||||
primary_user_email_id: None,
|
||||
created_at: now,
|
||||
locked_at: None,
|
||||
can_request_admin: false,
|
||||
};
|
||||
|
||||
let bob = User {
|
||||
id: Ulid::from_datetime_with_source(now.into(), &mut rng),
|
||||
username: "bob".to_owned(),
|
||||
sub: "123-456".to_owned(),
|
||||
primary_user_email_id: None,
|
||||
created_at: now,
|
||||
locked_at: None,
|
||||
can_request_admin: false,
|
||||
};
|
||||
|
||||
// Three times the same IP address should be allowed
|
||||
assert!(limiter.check_password(requesters[0], &alice).is_ok());
|
||||
assert!(limiter.check_password(requesters[0], &alice).is_ok());
|
||||
assert!(limiter.check_password(requesters[0], &alice).is_ok());
|
||||
|
||||
// But the fourth time should be rejected
|
||||
assert!(limiter.check_password(requesters[0], &alice).is_err());
|
||||
// Using another user should also be rejected
|
||||
assert!(limiter.check_password(requesters[0], &bob).is_err());
|
||||
|
||||
// Using a different IP address should be allowed, the account isn't locked yet
|
||||
assert!(limiter.check_password(requesters[1], &alice).is_ok());
|
||||
|
||||
// At this point, we consumed 4 cells out of 1800 on alice, let's distribute the
|
||||
// requests with other IPs so that we get rate-limited on the account-level
|
||||
for requester in requesters.iter().skip(2).take(598) {
|
||||
assert!(limiter.check_password(*requester, &alice).is_ok());
|
||||
assert!(limiter.check_password(*requester, &alice).is_ok());
|
||||
assert!(limiter.check_password(*requester, &alice).is_ok());
|
||||
assert!(limiter.check_password(*requester, &alice).is_err());
|
||||
}
|
||||
|
||||
// We now have consumed 4+598*3 = 1798 cells on the account, so we should be
|
||||
// rejected soon
|
||||
assert!(limiter.check_password(requesters[600], &alice).is_ok());
|
||||
assert!(limiter.check_password(requesters[601], &alice).is_ok());
|
||||
assert!(limiter.check_password(requesters[602], &alice).is_err());
|
||||
|
||||
// The other account isn't rate-limited
|
||||
assert!(limiter.check_password(requesters[603], &bob).is_ok());
|
||||
}
|
||||
}
|
@ -57,7 +57,7 @@ use crate::{
|
||||
graphql,
|
||||
passwords::{Hasher, PasswordManager},
|
||||
upstream_oauth2::cache::MetadataCache,
|
||||
ActivityTracker, BoundActivityTracker,
|
||||
ActivityTracker, BoundActivityTracker, Limiter, RequesterFingerprint,
|
||||
};
|
||||
|
||||
/// Setup rustcrypto and tracing for tests.
|
||||
@ -108,6 +108,7 @@ pub(crate) struct TestState {
|
||||
pub password_manager: PasswordManager,
|
||||
pub site_config: SiteConfig,
|
||||
pub activity_tracker: ActivityTracker,
|
||||
pub limiter: Limiter,
|
||||
pub clock: Arc<MockClock>,
|
||||
pub rng: Arc<Mutex<ChaChaRng>>,
|
||||
}
|
||||
@ -212,6 +213,8 @@ impl TestState {
|
||||
let activity_tracker =
|
||||
ActivityTracker::new(pool.clone(), std::time::Duration::from_secs(1));
|
||||
|
||||
let limiter = Limiter::default();
|
||||
|
||||
Ok(Self {
|
||||
pool,
|
||||
templates,
|
||||
@ -227,6 +230,7 @@ impl TestState {
|
||||
password_manager,
|
||||
site_config,
|
||||
activity_tracker,
|
||||
limiter,
|
||||
clock,
|
||||
rng,
|
||||
})
|
||||
@ -436,6 +440,12 @@ impl FromRef<TestState> for BoxHomeserverConnection {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<TestState> for Limiter {
|
||||
fn from_ref(input: &TestState) -> Self {
|
||||
input.limiter.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<TestState> for ActivityTracker {
|
||||
type Rejection = Infallible;
|
||||
@ -461,6 +471,18 @@ impl FromRequestParts<TestState> for BoundActivityTracker {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<TestState> for RequesterFingerprint {
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request_parts(
|
||||
_parts: &mut axum::http::request::Parts,
|
||||
_state: &TestState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
Ok(RequesterFingerprint::EMPTY)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequestParts<TestState> for BoxClock {
|
||||
type Rejection = Infallible;
|
||||
|
@ -39,7 +39,10 @@ use serde::{Deserialize, Serialize};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use super::shared::OptionalPostAuthAction;
|
||||
use crate::{passwords::PasswordManager, BoundActivityTracker, PreferredLanguage, SiteConfig};
|
||||
use crate::{
|
||||
passwords::PasswordManager, BoundActivityTracker, Limiter, PreferredLanguage,
|
||||
RequesterFingerprint, SiteConfig,
|
||||
};
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub(crate) struct LoginForm {
|
||||
@ -116,8 +119,10 @@ pub(crate) async fn post(
|
||||
State(site_config): State<SiteConfig>,
|
||||
State(templates): State<Templates>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(limiter): State<Limiter>,
|
||||
mut repo: BoxRepository,
|
||||
activity_tracker: BoundActivityTracker,
|
||||
requester: RequesterFingerprint,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
cookie_jar: CookieJar,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
@ -170,6 +175,8 @@ pub(crate) async fn post(
|
||||
&mut repo,
|
||||
rng,
|
||||
&clock,
|
||||
limiter,
|
||||
requester,
|
||||
&form.username,
|
||||
&form.password,
|
||||
user_agent,
|
||||
@ -211,6 +218,8 @@ async fn login(
|
||||
repo: &mut impl RepositoryAccess,
|
||||
mut rng: impl Rng + CryptoRng + Send,
|
||||
clock: &impl Clock,
|
||||
limiter: Limiter,
|
||||
requester: RequesterFingerprint,
|
||||
username: &str,
|
||||
password: &str,
|
||||
user_agent: Option<UserAgent>,
|
||||
@ -225,6 +234,12 @@ async fn login(
|
||||
.filter(mas_data_model::User::is_valid)
|
||||
.ok_or(FormError::InvalidCredentials)?;
|
||||
|
||||
// Check the rate limit
|
||||
limiter.check_password(requester, &user).map_err(|e| {
|
||||
tracing::warn!(error = &e as &dyn std::error::Error);
|
||||
FormError::RateLimitExceeded
|
||||
})?;
|
||||
|
||||
// And its password
|
||||
let user_password = repo
|
||||
.user_password()
|
||||
@ -491,4 +506,73 @@ mod test {
|
||||
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
|
||||
assert!(response.body().contains("john"));
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_password_login_rate_limit(pool: PgPool) {
|
||||
setup();
|
||||
let state = TestState::from_pool(pool).await.unwrap();
|
||||
let mut rng = state.rng();
|
||||
let cookies = CookieHelper::new();
|
||||
|
||||
// Provision a user without a password.
|
||||
// We don't give that user a password, so that we skip hashing it in this test.
|
||||
// It will still be rate-limited
|
||||
let mut repo = state.repository().await.unwrap();
|
||||
repo.user()
|
||||
.add(&mut rng, &state.clock, "john".to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
repo.save().await.unwrap();
|
||||
|
||||
// Render the login page to get a CSRF token
|
||||
let request = Request::get("/login").empty();
|
||||
let request = cookies.with_cookies(request);
|
||||
let response = state.request(request).await;
|
||||
cookies.save_cookies(&response);
|
||||
response.assert_status(StatusCode::OK);
|
||||
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
|
||||
// Extract the CSRF token from the response body
|
||||
let csrf_token = response
|
||||
.body()
|
||||
.split("name=\"csrf\" value=\"")
|
||||
.nth(1)
|
||||
.unwrap()
|
||||
.split('\"')
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
// Submit the login form
|
||||
let request = Request::post("/login").form(serde_json::json!({
|
||||
"csrf": csrf_token,
|
||||
"username": "john",
|
||||
"password": "hunter2",
|
||||
}));
|
||||
let request = cookies.with_cookies(request);
|
||||
|
||||
// First three attempts should just tell about the invalid credentials
|
||||
let response = state.request(request.clone()).await;
|
||||
response.assert_status(StatusCode::OK);
|
||||
let body = response.body();
|
||||
assert!(body.contains("Invalid credentials"));
|
||||
assert!(!body.contains("too many requests"));
|
||||
|
||||
let response = state.request(request.clone()).await;
|
||||
response.assert_status(StatusCode::OK);
|
||||
let body = response.body();
|
||||
assert!(body.contains("Invalid credentials"));
|
||||
assert!(!body.contains("too many requests"));
|
||||
|
||||
let response = state.request(request.clone()).await;
|
||||
response.assert_status(StatusCode::OK);
|
||||
let body = response.body();
|
||||
assert!(body.contains("Invalid credentials"));
|
||||
assert!(!body.contains("too many requests"));
|
||||
|
||||
// The fourth attempt should be rate-limited
|
||||
let response = state.request(request.clone()).await;
|
||||
response.assert_status(StatusCode::OK);
|
||||
let body = response.body();
|
||||
assert!(!body.contains("Invalid credentials"));
|
||||
assert!(body.contains("too many requests"));
|
||||
}
|
||||
}
|
||||
|
@ -62,6 +62,9 @@ pub enum FormError {
|
||||
/// There was an internal error
|
||||
Internal,
|
||||
|
||||
/// Rate limit exceeded
|
||||
RateLimitExceeded,
|
||||
|
||||
/// Denied by the policy
|
||||
Policy {
|
||||
/// Message for this policy violation
|
||||
|
Reference in New Issue
Block a user