From cf5510a1a28c00c6df1302bd55a65d985f6982e1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 19 Sep 2023 17:10:09 +0200 Subject: [PATCH] Add an ActivityTracker which tracks session activity and regularly flush them to the database --- crates/cli/src/commands/server.rs | 10 +- crates/cli/src/util.rs | 12 +- crates/handlers/src/activity_tracker/bound.rs | 56 +++++ crates/handlers/src/activity_tracker/mod.rs | 197 ++++++++++++++++ .../handlers/src/activity_tracker/worker.rs | 215 ++++++++++++++++++ crates/handlers/src/app_state.rs | 29 ++- crates/handlers/src/lib.rs | 10 +- crates/handlers/src/test_utils.rs | 32 ++- crates/handlers/src/views/app.rs | 14 +- 9 files changed, 563 insertions(+), 12 deletions(-) create mode 100644 crates/handlers/src/activity_tracker/bound.rs create mode 100644 crates/handlers/src/activity_tracker/mod.rs create mode 100644 crates/handlers/src/activity_tracker/worker.rs diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 4766d7e9..b79199b5 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -19,7 +19,8 @@ use clap::Parser; use itertools::Itertools; use mas_config::AppConfig; use mas_handlers::{ - AppState, CookieManager, HttpClientFactory, MatrixHomeserver, MetadataCache, SiteConfig, + ActivityTracker, AppState, CookieManager, HttpClientFactory, MatrixHomeserver, MetadataCache, + SiteConfig, }; use mas_listener::{server::Server, shutdown::ShutdownStream}; use mas_matrix_synapse::SynapseConnection; @@ -140,11 +141,13 @@ impl Options { compat_token_ttl: config.experimental.compat_token_ttl, }; + let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60 * 5)); + // Explicitly the config to properly zeroize secret keys drop(config); // Listen for SIGHUP - register_sighup(&templates)?; + register_sighup(&templates, &activity_tracker)?; let graphql_schema = mas_handlers::graphql_schema(&pool, &policy_factory, conn); @@ -163,6 +166,7 @@ impl Options { http_client_factory, password_manager, site_config, + activity_tracker, conn_acquisition_histogram: None, }; s.init_metrics()?; @@ -242,6 +246,8 @@ impl Options { mas_listener::server::run_servers(servers, shutdown).await; + state.activity_tracker.shutdown().await; + Ok(()) } } diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index 2ebe2efe..4ca4a126 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -20,7 +20,7 @@ use mas_config::{ PasswordsConfig, PolicyConfig, TemplatesConfig, }; use mas_email::{MailTransport, Mailer}; -use mas_handlers::passwords::PasswordManager; +use mas_handlers::{passwords::PasswordManager, ActivityTracker}; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; use mas_templates::{TemplateLoadingError, Templates}; @@ -206,11 +206,16 @@ pub async fn database_connection_from_config( } /// Reload templates on SIGHUP -pub fn register_sighup(templates: &Templates) -> anyhow::Result<()> { +pub fn register_sighup( + templates: &Templates, + activity_tracker: &ActivityTracker, +) -> anyhow::Result<()> { #[cfg(unix)] { let mut signal = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())?; let templates = templates.clone(); + let activity_tracker = activity_tracker.clone(); + tokio::spawn(async move { loop { if signal.recv().await.is_none() { @@ -218,8 +223,9 @@ pub fn register_sighup(templates: &Templates) -> anyhow::Result<()> { break; }; - info!("SIGHUP received, reloading templates"); + info!("SIGHUP received, reloading templates & flushing activity tracker"); + activity_tracker.flush().await; templates.clone().reload().await.unwrap_or_else(|err| { error!(?err, "Error while reloading templates"); }); diff --git a/crates/handlers/src/activity_tracker/bound.rs b/crates/handlers/src/activity_tracker/bound.rs new file mode 100644 index 00000000..419ba3a1 --- /dev/null +++ b/crates/handlers/src/activity_tracker/bound.rs @@ -0,0 +1,56 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::net::IpAddr; + +use mas_data_model::{BrowserSession, CompatSession, Session}; +use mas_storage::Clock; + +use crate::activity_tracker::ActivityTracker; + +/// An activity tracker with an IP address bound to it. +#[derive(Clone)] +pub struct Bound { + tracker: ActivityTracker, + ip: Option, +} + +impl Bound { + /// Create a new bound activity tracker. + #[must_use] + pub fn new(tracker: ActivityTracker, ip: Option) -> Self { + Self { tracker, ip } + } + + /// Record activity in an OAuth 2.0 session. + pub async fn record_oauth2_session(&self, clock: &dyn Clock, session: &Session) { + self.tracker + .record_oauth2_session(clock, session, self.ip) + .await; + } + + /// Record activity in a compatibility session. + pub async fn record_compat_session(&self, clock: &dyn Clock, session: &CompatSession) { + self.tracker + .record_compat_session(clock, session, self.ip) + .await; + } + + /// Record activity in a browser session. + pub async fn record_browser_session(&self, clock: &dyn Clock, session: &BrowserSession) { + self.tracker + .record_browser_session(clock, session, self.ip) + .await; + } +} diff --git a/crates/handlers/src/activity_tracker/mod.rs b/crates/handlers/src/activity_tracker/mod.rs new file mode 100644 index 00000000..f867720c --- /dev/null +++ b/crates/handlers/src/activity_tracker/mod.rs @@ -0,0 +1,197 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod bound; +mod worker; + +use std::net::IpAddr; + +use chrono::{DateTime, Utc}; +use mas_data_model::{BrowserSession, CompatSession, Session}; +use mas_storage::Clock; +use sqlx::PgPool; +use ulid::Ulid; + +pub use self::bound::Bound; +use self::worker::Worker; + +static MESSAGE_QUEUE_SIZE: usize = 1000; + +#[derive(Clone, Copy, Debug, PartialOrd, PartialEq, Eq, Hash)] +enum SessionKind { + OAuth2, + Compat, + Browser, +} + +impl SessionKind { + const fn as_str(self) -> &'static str { + match self { + SessionKind::OAuth2 => "oauth2", + SessionKind::Compat => "compat", + SessionKind::Browser => "browser", + } + } +} + +enum Message { + Record { + kind: SessionKind, + id: Ulid, + date_time: DateTime, + ip: Option, + }, + Flush, + Shutdown(tokio::sync::oneshot::Sender<()>), +} + +#[derive(Clone)] +pub struct ActivityTracker { + channel: tokio::sync::mpsc::Sender, +} + +impl ActivityTracker { + /// Create a new activity tracker, spawning the worker. + #[must_use] + pub fn new(pool: PgPool, flush_interval: std::time::Duration) -> Self { + let worker = Worker::new(pool); + let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE); + let tracker = ActivityTracker { channel: sender }; + + // Spawn the flush loop and the worker + tokio::spawn(tracker.clone().flush_loop(flush_interval)); + tokio::spawn(worker.run(receiver)); + + tracker + } + + /// Bind the activity tracker to an IP address. + #[must_use] + pub fn bind(self, ip: Option) -> Bound { + Bound::new(self, ip) + } + + /// Record activity in an OAuth 2.0 session. + pub async fn record_oauth2_session( + &self, + clock: &dyn Clock, + session: &Session, + ip: Option, + ) { + let res = self + .channel + .send(Message::Record { + kind: SessionKind::OAuth2, + id: session.id, + date_time: clock.now(), + ip, + }) + .await; + + if let Err(e) = res { + tracing::error!("Failed to record OAuth2 session: {}", e); + } + } + + /// Record activity in a compat session. + pub async fn record_compat_session( + &self, + clock: &dyn Clock, + compat_session: &CompatSession, + ip: Option, + ) { + let res = self + .channel + .send(Message::Record { + kind: SessionKind::Compat, + id: compat_session.id, + date_time: clock.now(), + ip, + }) + .await; + + if let Err(e) = res { + tracing::error!("Failed to record compat session: {}", e); + } + } + + /// Record activity in a browser session. + pub async fn record_browser_session( + &self, + clock: &dyn Clock, + browser_session: &BrowserSession, + ip: Option, + ) { + let res = self + .channel + .send(Message::Record { + kind: SessionKind::Browser, + id: browser_session.id, + date_time: clock.now(), + ip, + }) + .await; + + if let Err(e) = res { + tracing::error!("Failed to record browser session: {}", e); + } + } + + /// Manually flush the activity tracker. + pub async fn flush(&self) { + let res = self.channel.send(Message::Flush).await; + + if let Err(e) = res { + tracing::error!("Failed to flush activity tracker: {}", e); + } + } + + /// Regularly flush the activity tracker. + async fn flush_loop(self, interval: std::time::Duration) { + loop { + tokio::select! { + biased; + + // First check if the channel is closed, then check if the timer expired + _ = self.channel.closed() => { + // The channel was closed, so we should exit + break; + } + + _ = tokio::time::sleep(interval) => { + self.flush().await; + } + } + } + } + + /// Shutdown the activity tracker. + /// + /// This will wait for all pending messages to be processed. + pub async fn shutdown(&self) { + let (tx, rx) = tokio::sync::oneshot::channel(); + let res = self.channel.send(Message::Shutdown(tx)).await; + + match res { + Ok(_) => { + if let Err(e) = rx.await { + tracing::error!("Failed to shutdown activity tracker: {}", e); + } + } + Err(e) => { + tracing::error!("Failed to shutdown activity tracker: {}", e); + } + } + } +} diff --git a/crates/handlers/src/activity_tracker/worker.rs b/crates/handlers/src/activity_tracker/worker.rs new file mode 100644 index 00000000..ed37636e --- /dev/null +++ b/crates/handlers/src/activity_tracker/worker.rs @@ -0,0 +1,215 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{collections::HashMap, net::IpAddr}; + +use chrono::{DateTime, Utc}; +use mas_storage::Repository; +use opentelemetry::{ + metrics::{Counter, Histogram}, + Key, +}; +use sqlx::PgPool; +use ulid::Ulid; + +use crate::activity_tracker::{Message, SessionKind}; + +/// The maximum number of pending activity records before we flush them to the +/// database automatically. +/// +/// The [`ActivityRecord`] structure plus the key in the [`HashMap`] takes less +/// than 100 bytes, so this should allocate around a megabyte of memory. +static MAX_PENDING_RECORDS: usize = 10_000; + +const TYPE: Key = Key::from_static_str("type"); +const SESSION_KIND: Key = Key::from_static_str("session_kind"); +const RESULT: Key = Key::from_static_str("result"); + +#[derive(Clone, Copy, Debug)] +struct ActivityRecord { + start_time: DateTime, + end_time: DateTime, + ip: Option, +} + +/// Handles writing activity records to the database. +pub struct Worker { + pool: PgPool, + pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>, + message_counter: Counter, + flush_time_histogram: Histogram, +} + +impl Worker { + pub(crate) fn new(pool: PgPool) -> Self { + let meter = opentelemetry::global::meter_with_version( + env!("CARGO_PKG_NAME"), + Some(env!("CARGO_PKG_VERSION")), + Some(opentelemetry_semantic_conventions::SCHEMA_URL), + None, + ); + + let message_counter = meter + .u64_counter("mas.activity_tracker.messages") + .with_description("The number of messages received by the activity tracker") + .with_unit(opentelemetry::metrics::Unit::new("{messages}")) + .init(); + + // Record stuff on the counter so that the metrics are initialized + for kind in &[ + SessionKind::OAuth2, + SessionKind::Compat, + SessionKind::Browser, + ] { + message_counter.add( + 0, + &[TYPE.string("record"), SESSION_KIND.string(kind.as_str())], + ); + } + message_counter.add(0, &[TYPE.string("flush")]); + message_counter.add(0, &[TYPE.string("shutdown")]); + + let flush_time_histogram = meter + .u64_histogram("mas.activity_tracker.flush_time") + .with_description("The time it took to flush the activity tracker") + .with_unit(opentelemetry::metrics::Unit::new("ms")) + .init(); + + Self { + pool, + pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS), + message_counter, + flush_time_histogram, + } + } + + pub(super) async fn run(mut self, mut receiver: tokio::sync::mpsc::Receiver) { + let mut shutdown_notifier = None; + while let Some(message) = receiver.recv().await { + match message { + Message::Record { + kind, + id, + date_time, + ip, + } => { + if self.pending_records.len() >= MAX_PENDING_RECORDS { + tracing::warn!("Too many pending activity records, flushing"); + self.flush().await; + } + + if self.pending_records.len() >= MAX_PENDING_RECORDS { + tracing::error!( + kind = kind.as_str(), + %id, + %date_time, + "Still too many pending activity records, dropping" + ); + continue; + } + + self.message_counter.add( + 1, + &[TYPE.string("record"), SESSION_KIND.string(kind.as_str())], + ); + + let record = + self.pending_records + .entry((kind, id)) + .or_insert_with(|| ActivityRecord { + start_time: date_time, + end_time: date_time, + ip, + }); + + record.end_time = date_time.max(record.end_time); + } + Message::Flush => { + self.message_counter.add(1, &[TYPE.string("flush")]); + + self.flush().await; + } + Message::Shutdown(tx) => { + self.message_counter.add(1, &[TYPE.string("shutdown")]); + + let old_tx = shutdown_notifier.replace(tx); + if let Some(old_tx) = old_tx { + tracing::warn!("Activity tracker shutdown requested while another shutdown was already in progress"); + // Still send the shutdown signal to the previous notifier. This means we + // send the shutdown signal before we flush the activity tracker, but that + // should be fine, since there should not be multiple shutdown requests. + let _ = old_tx.send(()); + } + receiver.close(); + } + } + } + + self.flush().await; + + if let Some(shutdown_notifier) = shutdown_notifier { + let _ = shutdown_notifier.send(()); + } else { + // This should never happen, since we set the shutdown notifier when we receive + // the first shutdown message + tracing::warn!("Activity tracker shutdown requested but no shutdown notifier was set"); + } + } + + /// Flush the activity tracker. + async fn flush(&mut self) { + // Short path: if there are no pending records, we don't need to flush + if self.pending_records.is_empty() { + return; + } + + let start = std::time::Instant::now(); + let res = self.try_flush().await; + + // Measure the time it took to flush the activity tracker + let duration = start.elapsed(); + let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX); + + match res { + Ok(_) => { + self.flush_time_histogram + .record(duration_ms, &[RESULT.string("success")]); + } + Err(e) => { + self.flush_time_histogram + .record(duration_ms, &[RESULT.string("failure")]); + tracing::error!("Failed to flush activity tracker: {}", e); + } + } + } + + /// Fallible part of [`Self::flush`]. + async fn try_flush(&mut self) -> Result<(), anyhow::Error> { + let pending_records = &self.pending_records; + + let repo = mas_storage_pg::PgRepository::from_pool(&self.pool) + .await? + .boxed(); + + tracing::info!( + "Flushing {} activity records to the database", + pending_records.len() + ); + // TODO: actually save the records + repo.save().await?; + self.pending_records.clear(); + + Ok(()) + } +} diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index d53e3188..2df24db8 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -36,7 +36,7 @@ use sqlx::PgPool; use crate::{ passwords::PasswordManager, site_config::SiteConfig, upstream_oauth2::cache::MetadataCache, - MatrixHomeserver, + ActivityTracker, BoundActivityTracker, MatrixHomeserver, }; #[derive(Clone)] @@ -54,6 +54,7 @@ pub struct AppState { pub password_manager: PasswordManager, pub metadata_cache: MetadataCache, pub site_config: SiteConfig, + pub activity_tracker: ActivityTracker, pub conn_acquisition_histogram: Option>, } @@ -269,6 +270,32 @@ impl FromRequestParts for Policy { } } +#[async_trait] +impl FromRequestParts for ActivityTracker { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &AppState, + ) -> Result { + Ok(state.activity_tracker.clone()) + } +} + +#[async_trait] +impl FromRequestParts for BoundActivityTracker { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &AppState, + ) -> Result { + // TODO: grab the IP address from the request + let ip = None; + Ok(state.activity_tracker.clone().bind(ip)) + } +} + #[async_trait] impl FromRequestParts for BoxRepository { type Rejection = ErrorWrapper; diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index b22cc971..c4c68467 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -68,6 +68,7 @@ pub mod passwords; pub mod upstream_oauth2; mod views; +mod activity_tracker; mod site_config; #[cfg(test)] mod test_utils; @@ -91,8 +92,12 @@ macro_rules! impl_from_error_for_route { pub use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory}; pub use self::{ - app_state::AppState, compat::MatrixHomeserver, graphql::schema as graphql_schema, - site_config::SiteConfig, upstream_oauth2::cache::MetadataCache, + activity_tracker::{ActivityTracker, Bound as BoundActivityTracker}, + app_state::AppState, + compat::MatrixHomeserver, + graphql::schema as graphql_schema, + site_config::SiteConfig, + upstream_oauth2::cache::MetadataCache, }; pub fn healthcheck_router() -> Router @@ -288,6 +293,7 @@ where UrlBuilder: FromRef, BoxRepository: FromRequestParts, CookieJar: FromRequestParts, + BoundActivityTracker: FromRequestParts, Encrypter: FromRef, Templates: FromRef, Keystore: FromRef, diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index ec75e290..78aee642 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -50,7 +50,7 @@ use crate::{ passwords::{Hasher, PasswordManager}, site_config::SiteConfig, upstream_oauth2::cache::MetadataCache, - MatrixHomeserver, + ActivityTracker, BoundActivityTracker, MatrixHomeserver, }; // This might fail if it's not the first time it's being called, which is fine, @@ -100,6 +100,7 @@ pub(crate) struct TestState { pub http_client_factory: HttpClientFactory, pub password_manager: PasswordManager, pub site_config: SiteConfig, + pub activity_tracker: ActivityTracker, pub clock: Arc, pub rng: Arc>, } @@ -160,6 +161,9 @@ impl TestState { let graphql_schema = mas_graphql::schema_builder().data(state).finish(); + let activity_tracker = + ActivityTracker::new(pool.clone(), std::time::Duration::from_secs(1)); + Ok(Self { pool, templates, @@ -174,6 +178,7 @@ impl TestState { http_client_factory, password_manager, site_config, + activity_tracker, clock, rng, }) @@ -366,6 +371,31 @@ impl FromRef for SiteConfig { } } +#[async_trait] +impl FromRequestParts for ActivityTracker { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &TestState, + ) -> Result { + Ok(state.activity_tracker.clone()) + } +} + +#[async_trait] +impl FromRequestParts for BoundActivityTracker { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &TestState, + ) -> Result { + let ip = None; + Ok(state.activity_tracker.clone().bind(ip)) + } +} + #[async_trait] impl FromRequestParts for BoxClock { type Rejection = Infallible; diff --git a/crates/handlers/src/views/app.rs b/crates/handlers/src/views/app.rs index 2c6af220..f171c204 100644 --- a/crates/handlers/src/views/app.rs +++ b/crates/handlers/src/views/app.rs @@ -18,14 +18,18 @@ use axum::{ }; use mas_axum_utils::{cookies::CookieJar, FancyError, SessionInfoExt}; use mas_router::{PostAuthAction, Route}; -use mas_storage::BoxRepository; +use mas_storage::{BoxClock, BoxRepository}; use mas_templates::{AppContext, Templates}; +use crate::BoundActivityTracker; + #[tracing::instrument(name = "handlers.views.app.get", skip_all, err)] pub async fn get( State(templates): State, + activity_tracker: BoundActivityTracker, action: Option>, mut repo: BoxRepository, + clock: BoxClock, cookie_jar: CookieJar, ) -> Result { let (session_info, cookie_jar) = cookie_jar.session_info(); @@ -33,13 +37,17 @@ pub async fn get( let action = action.map(|Query(a)| a); // TODO: keep the full path, not just the action - if session.is_none() { + let Some(session) = session else { return Ok(( cookie_jar, mas_router::Login::and_then(PostAuthAction::manage_account(action)).go(), ) .into_response()); - } + }; + + activity_tracker + .record_browser_session(&clock, &session) + .await; let ctx = AppContext::default(); let content = templates.render_app(&ctx).await?;