// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{collections::HashMap, net::IpAddr};

use chrono::{DateTime, Utc};
use mas_storage::{user::BrowserSessionRepository, Repository, RepositoryAccess};
use opentelemetry::{
    metrics::{Counter, Histogram},
    Key,
};
use sqlx::PgPool;
use ulid::Ulid;

use crate::activity_tracker::{Message, SessionKind};

/// The maximum number of pending activity records before we flush them to the
/// database automatically.
///
/// The [`ActivityRecord`] structure plus the key in the [`HashMap`] takes less
/// than 100 bytes, so this should allocate around a megabyte of memory.
static MAX_PENDING_RECORDS: usize = 10_000;

const TYPE: Key = Key::from_static_str("type");
const SESSION_KIND: Key = Key::from_static_str("session_kind");
const RESULT: Key = Key::from_static_str("result");

/// A pending activity record for a single session.
#[derive(Clone, Copy, Debug)]
struct ActivityRecord {
    // XXX: We don't actually use the start time for now
    #[allow(dead_code)]
    start_time: DateTime<Utc>,
    end_time: DateTime<Utc>,
    ip: Option<IpAddr>,
}

/// Handles writing activity records to the database.
pub struct Worker {
    pool: PgPool,
    pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>,
    message_counter: Counter<u64>,
    flush_time_histogram: Histogram<u64>,
}

impl Worker {
    /// Create a new worker and initialize its metrics.
    pub(crate) fn new(pool: PgPool) -> Self {
        let meter = opentelemetry::global::meter_with_version(
            env!("CARGO_PKG_NAME"),
            Some(env!("CARGO_PKG_VERSION")),
            Some(opentelemetry_semantic_conventions::SCHEMA_URL),
            None,
        );

        let message_counter = meter
            .u64_counter("mas.activity_tracker.messages")
            .with_description("The number of messages received by the activity tracker")
            .with_unit(opentelemetry::metrics::Unit::new("{messages}"))
            .init();

        // Record stuff on the counter so that the metrics are initialized
        for kind in &[
            SessionKind::OAuth2,
            SessionKind::Compat,
            SessionKind::Browser,
        ] {
            message_counter.add(
                0,
                &[TYPE.string("record"), SESSION_KIND.string(kind.as_str())],
            );
        }
        message_counter.add(0, &[TYPE.string("flush")]);
        message_counter.add(0, &[TYPE.string("shutdown")]);

        let flush_time_histogram = meter
            .u64_histogram("mas.activity_tracker.flush_time")
            .with_description("The time it took to flush the activity tracker")
            .with_unit(opentelemetry::metrics::Unit::new("ms"))
            .init();

        Self {
            pool,
            pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS),
            message_counter,
            flush_time_histogram,
        }
    }

    /// Run the worker, processing messages until the channel is closed.
    pub(super) async fn run(mut self, mut receiver: tokio::sync::mpsc::Receiver<Message>) {
        let mut shutdown_notifier = None;
        while let Some(message) = receiver.recv().await {
            match message {
                Message::Record {
                    kind,
                    id,
                    date_time,
                    ip,
                } => {
                    if self.pending_records.len() >= MAX_PENDING_RECORDS {
                        tracing::warn!("Too many pending activity records, flushing");
                        self.flush().await;
                    }

                    if self.pending_records.len() >= MAX_PENDING_RECORDS {
                        tracing::error!(
                            kind = kind.as_str(),
                            %id,
                            %date_time,
                            "Still too many pending activity records, dropping"
                        );
                        continue;
                    }

                    self.message_counter.add(
                        1,
                        &[TYPE.string("record"), SESSION_KIND.string(kind.as_str())],
                    );

                    let record = self
                        .pending_records
                        .entry((kind, id))
                        .or_insert_with(|| ActivityRecord {
                            start_time: date_time,
                            end_time: date_time,
                            ip,
                        });

                    record.end_time = date_time.max(record.end_time);
                }
                Message::Flush(tx) => {
                    self.message_counter.add(1, &[TYPE.string("flush")]);

                    self.flush().await;
                    let _ = tx.send(());
                }
                Message::Shutdown(tx) => {
                    self.message_counter.add(1, &[TYPE.string("shutdown")]);

                    let old_tx = shutdown_notifier.replace(tx);
                    if let Some(old_tx) = old_tx {
                        tracing::warn!("Activity tracker shutdown requested while another shutdown was already in progress");

                        // Still send the shutdown signal to the previous notifier. This means we
                        // send the shutdown signal before we flush the activity tracker, but that
                        // should be fine, since there should not be multiple shutdown requests.
                        let _ = old_tx.send(());
                    }

                    receiver.close();
                }
            }
        }

        self.flush().await;

        if let Some(shutdown_notifier) = shutdown_notifier {
            let _ = shutdown_notifier.send(());
        } else {
            // This should never happen, since we set the shutdown notifier when we receive
            // the first shutdown message
            tracing::warn!("Activity tracker shutdown requested but no shutdown notifier was set");
        }
    }

    /// Flush the activity tracker.
    async fn flush(&mut self) {
        // Short path: if there are no pending records, we don't need to flush
        if self.pending_records.is_empty() {
            return;
        }

        let start = std::time::Instant::now();
        let res = self.try_flush().await;

        // Measure the time it took to flush the activity tracker
        let duration = start.elapsed();
        let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);

        match res {
            Ok(()) => {
                self.flush_time_histogram
                    .record(duration_ms, &[RESULT.string("success")]);
            }
            Err(e) => {
                self.flush_time_histogram
                    .record(duration_ms, &[RESULT.string("failure")]);
                tracing::error!("Failed to flush activity tracker: {}", e);
            }
        }
    }

    /// Fallible part of [`Self::flush`].
    #[tracing::instrument(name = "activity_tracker.flush", skip(self))]
    async fn try_flush(&mut self) -> Result<(), anyhow::Error> {
        let pending_records = &self.pending_records;

        let mut repo = mas_storage_pg::PgRepository::from_pool(&self.pool)
            .await?
            .boxed();

        let mut browser_sessions = Vec::new();
        let mut oauth2_sessions = Vec::new();
        let mut compat_sessions = Vec::new();

        for ((kind, id), record) in pending_records {
            match kind {
                SessionKind::Browser => {
                    browser_sessions.push((*id, record.end_time, record.ip));
                }
                SessionKind::OAuth2 => {
                    oauth2_sessions.push((*id, record.end_time, record.ip));
                }
                SessionKind::Compat => {
                    compat_sessions.push((*id, record.end_time, record.ip));
                }
            }
        }

        tracing::info!(
            "Flushing {} activity records to the database",
            pending_records.len()
        );

        repo.browser_session()
            .record_batch_activity(browser_sessions)
            .await?;
        repo.oauth2_session()
            .record_batch_activity(oauth2_sessions)
            .await?;
        repo.compat_session()
            .record_batch_activity(compat_sessions)
            .await?;

        repo.save().await?;
        self.pending_records.clear();

        Ok(())
    }
}