1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Add an ActivityTracker which tracks session activity and regularly flushes it to the database

This commit is contained in:
Quentin Gliech
2023-09-19 17:10:09 +02:00
parent 16962b451b
commit cf5510a1a2
9 changed files with 563 additions and 12 deletions

View File

@ -0,0 +1,56 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::net::IpAddr;
use mas_data_model::{BrowserSession, CompatSession, Session};
use mas_storage::Clock;
use crate::activity_tracker::ActivityTracker;
/// An activity tracker with an IP address bound to it.
///
/// Pairs an [`ActivityTracker`] with an optional IP address, so that the
/// address does not have to be supplied on every `record_*` call.
#[derive(Clone)]
pub struct Bound {
    /// The underlying tracker every call is forwarded to.
    tracker: ActivityTracker,
    /// The IP address attached to every recorded activity, if any.
    ip: Option<IpAddr>,
}

impl Bound {
    /// Create a new bound activity tracker.
    #[must_use]
    pub fn new(tracker: ActivityTracker, ip: Option<IpAddr>) -> Self {
        Self { tracker, ip }
    }

    /// Record activity in an OAuth 2.0 session.
    ///
    /// Forwards to [`ActivityTracker::record_oauth2_session`] with the bound
    /// IP address.
    pub async fn record_oauth2_session(&self, clock: &dyn Clock, session: &Session) {
        let Self { tracker, ip } = self;
        tracker.record_oauth2_session(clock, session, *ip).await;
    }

    /// Record activity in a compatibility session.
    ///
    /// Forwards to [`ActivityTracker::record_compat_session`] with the bound
    /// IP address.
    pub async fn record_compat_session(&self, clock: &dyn Clock, session: &CompatSession) {
        let Self { tracker, ip } = self;
        tracker.record_compat_session(clock, session, *ip).await;
    }

    /// Record activity in a browser session.
    ///
    /// Forwards to [`ActivityTracker::record_browser_session`] with the bound
    /// IP address.
    pub async fn record_browser_session(&self, clock: &dyn Clock, session: &BrowserSession) {
        let Self { tracker, ip } = self;
        tracker.record_browser_session(clock, session, *ip).await;
    }
}

View File

@ -0,0 +1,197 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod bound;
mod worker;
use std::net::IpAddr;
use chrono::{DateTime, Utc};
use mas_data_model::{BrowserSession, CompatSession, Session};
use mas_storage::Clock;
use sqlx::PgPool;
use ulid::Ulid;
pub use self::bound::Bound;
use self::worker::Worker;
/// Capacity of the bounded channel between the tracker handles and the worker.
///
/// Declared as a `const` rather than a `static`: it is a plain immutable
/// value that needs no fixed address, and a `const` can also be used in
/// constant contexts.
const MESSAGE_QUEUE_SIZE: usize = 1000;
/// The kind of session an activity message refers to.
#[derive(Clone, Copy, Debug, PartialOrd, PartialEq, Eq, Hash)]
enum SessionKind {
    /// An OAuth 2.0 session.
    OAuth2,
    /// A compatibility session.
    Compat,
    /// A browser session.
    Browser,
}

impl SessionKind {
    /// Returns the static, lowercase label of this session kind.
    const fn as_str(self) -> &'static str {
        match self {
            Self::OAuth2 => "oauth2",
            Self::Compat => "compat",
            Self::Browser => "browser",
        }
    }
}
/// A message sent over the channel from an [`ActivityTracker`] handle to the
/// worker task.
enum Message {
/// Report activity seen on a session.
Record {
/// The kind of session the activity belongs to.
kind: SessionKind,
/// The ID of the session.
id: Ulid,
/// When the activity was recorded, as given by the caller's clock.
date_time: DateTime<Utc>,
/// The IP address associated with the activity, if any.
ip: Option<IpAddr>,
},
/// Ask the worker to flush its pending records.
Flush,
/// Ask the worker to stop. The enclosed sender is used by the worker to
/// signal back once it is done.
Shutdown(tokio::sync::oneshot::Sender<()>),
}
#[derive(Clone)]
pub struct ActivityTracker {
channel: tokio::sync::mpsc::Sender<Message>,
}
impl ActivityTracker {
/// Create a new activity tracker, spawning the worker.
#[must_use]
pub fn new(pool: PgPool, flush_interval: std::time::Duration) -> Self {
let worker = Worker::new(pool);
let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
let tracker = ActivityTracker { channel: sender };
// Spawn the flush loop and the worker
tokio::spawn(tracker.clone().flush_loop(flush_interval));
tokio::spawn(worker.run(receiver));
tracker
}
/// Bind the activity tracker to an IP address.
#[must_use]
pub fn bind(self, ip: Option<IpAddr>) -> Bound {
Bound::new(self, ip)
}
/// Record activity in an OAuth 2.0 session.
pub async fn record_oauth2_session(
&self,
clock: &dyn Clock,
session: &Session,
ip: Option<IpAddr>,
) {
let res = self
.channel
.send(Message::Record {
kind: SessionKind::OAuth2,
id: session.id,
date_time: clock.now(),
ip,
})
.await;
if let Err(e) = res {
tracing::error!("Failed to record OAuth2 session: {}", e);
}
}
/// Record activity in a compat session.
pub async fn record_compat_session(
&self,
clock: &dyn Clock,
compat_session: &CompatSession,
ip: Option<IpAddr>,
) {
let res = self
.channel
.send(Message::Record {
kind: SessionKind::Compat,
id: compat_session.id,
date_time: clock.now(),
ip,
})
.await;
if let Err(e) = res {
tracing::error!("Failed to record compat session: {}", e);
}
}
/// Record activity in a browser session.
pub async fn record_browser_session(
&self,
clock: &dyn Clock,
browser_session: &BrowserSession,
ip: Option<IpAddr>,
) {
let res = self
.channel
.send(Message::Record {
kind: SessionKind::Browser,
id: browser_session.id,
date_time: clock.now(),
ip,
})
.await;
if let Err(e) = res {
tracing::error!("Failed to record browser session: {}", e);
}
}
/// Manually flush the activity tracker.
pub async fn flush(&self) {
let res = self.channel.send(Message::Flush).await;
if let Err(e) = res {
tracing::error!("Failed to flush activity tracker: {}", e);
}
}
/// Regularly flush the activity tracker.
async fn flush_loop(self, interval: std::time::Duration) {
loop {
tokio::select! {
biased;
// First check if the channel is closed, then check if the timer expired
_ = self.channel.closed() => {
// The channel was closed, so we should exit
break;
}
_ = tokio::time::sleep(interval) => {
self.flush().await;
}
}
}
}
/// Shutdown the activity tracker.
///
/// This will wait for all pending messages to be processed.
pub async fn shutdown(&self) {
let (tx, rx) = tokio::sync::oneshot::channel();
let res = self.channel.send(Message::Shutdown(tx)).await;
match res {
Ok(_) => {
if let Err(e) = rx.await {
tracing::error!("Failed to shutdown activity tracker: {}", e);
}
}
Err(e) => {
tracing::error!("Failed to shutdown activity tracker: {}", e);
}
}
}
}

View File

@ -0,0 +1,215 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashMap, net::IpAddr};
use chrono::{DateTime, Utc};
use mas_storage::Repository;
use opentelemetry::{
metrics::{Counter, Histogram},
Key,
};
use sqlx::PgPool;
use ulid::Ulid;
use crate::activity_tracker::{Message, SessionKind};
/// The maximum number of pending activity records before we flush them to the
/// database automatically.
///
/// The [`ActivityRecord`] structure plus the key in the [`HashMap`] takes less
/// than 100 bytes, so this should allocate around a megabyte of memory.
///
/// Declared as a `const`, like the other constants of this module: it is a
/// plain immutable value that needs no fixed address.
const MAX_PENDING_RECORDS: usize = 10_000;
/// Metric attribute key holding the message type (`record`, `flush` or
/// `shutdown`).
const TYPE: Key = Key::from_static_str("type");
/// Metric attribute key holding the session kind of a `record` message.
const SESSION_KIND: Key = Key::from_static_str("session_kind");
/// Metric attribute key holding the outcome of a flush (`success` or
/// `failure`).
const RESULT: Key = Key::from_static_str("result");
/// The activity accumulated in memory for a single session, until it is
/// successfully flushed.
#[derive(Clone, Copy, Debug)]
struct ActivityRecord {
/// Timestamp at which the recorded activity starts.
start_time: DateTime<Utc>,
/// Latest timestamp seen for this session.
end_time: DateTime<Utc>,
/// IP address carried by the message which created this record, if any.
ip: Option<IpAddr>,
}
/// Handles writing activity records to the database.
///
/// The worker receives [`Message`]s over a channel, aggregates the recorded
/// activity per `(session kind, session ID)` in memory, and flushes it when
/// asked to, when too many records are pending, and before stopping.
pub struct Worker {
    /// Connection pool used to open a repository on each flush.
    pool: PgPool,
    /// Activity aggregated since the last successful flush.
    pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>,
    /// Counts the messages handled, by type and session kind.
    message_counter: Counter<u64>,
    /// Records how long each flush took, in milliseconds, by result.
    flush_time_histogram: Histogram<u64>,
}

impl Worker {
    /// Create a worker flushing to the given pool.
    ///
    /// This also sets up the `mas.activity_tracker.messages` counter and the
    /// `mas.activity_tracker.flush_time` histogram.
    pub(crate) fn new(pool: PgPool) -> Self {
        let meter = opentelemetry::global::meter_with_version(
            env!("CARGO_PKG_NAME"),
            Some(env!("CARGO_PKG_VERSION")),
            Some(opentelemetry_semantic_conventions::SCHEMA_URL),
            None,
        );

        let message_counter = meter
            .u64_counter("mas.activity_tracker.messages")
            .with_description("The number of messages received by the activity tracker")
            .with_unit(opentelemetry::metrics::Unit::new("{messages}"))
            .init();

        // Record stuff on the counter so that the metrics are initialized
        for kind in &[
            SessionKind::OAuth2,
            SessionKind::Compat,
            SessionKind::Browser,
        ] {
            message_counter.add(
                0,
                &[TYPE.string("record"), SESSION_KIND.string(kind.as_str())],
            );
        }
        message_counter.add(0, &[TYPE.string("flush")]);
        message_counter.add(0, &[TYPE.string("shutdown")]);

        let flush_time_histogram = meter
            .u64_histogram("mas.activity_tracker.flush_time")
            .with_description("The time it took to flush the activity tracker")
            .with_unit(opentelemetry::metrics::Unit::new("ms"))
            .init();

        Self {
            pool,
            pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS),
            message_counter,
            flush_time_histogram,
        }
    }

    /// Process messages until the channel is closed and drained.
    ///
    /// - [`Message::Record`] merges the activity into the pending records,
    ///   flushing first if [`MAX_PENDING_RECORDS`] is reached, and dropping
    ///   the message if that flush did not make room.
    /// - [`Message::Flush`] flushes the pending records.
    /// - [`Message::Shutdown`] closes the receiver; the messages which are
    ///   already queued are still processed.
    ///
    /// Once the loop ends, a last flush is done and the shutdown requester,
    /// if any, is notified.
    pub(super) async fn run(mut self, mut receiver: tokio::sync::mpsc::Receiver<Message>) {
        let mut shutdown_notifier = None;

        while let Some(message) = receiver.recv().await {
            match message {
                Message::Record {
                    kind,
                    id,
                    date_time,
                    ip,
                } => {
                    if self.pending_records.len() >= MAX_PENDING_RECORDS {
                        tracing::warn!("Too many pending activity records, flushing");
                        self.flush().await;
                    }

                    // A failed flush keeps the records, so we may still be full
                    if self.pending_records.len() >= MAX_PENDING_RECORDS {
                        tracing::error!(
                            kind = kind.as_str(),
                            %id,
                            %date_time,
                            "Still too many pending activity records, dropping"
                        );
                        continue;
                    }

                    self.message_counter.add(
                        1,
                        &[TYPE.string("record"), SESSION_KIND.string(kind.as_str())],
                    );

                    let record =
                        self.pending_records
                            .entry((kind, id))
                            .or_insert_with(|| ActivityRecord {
                                start_time: date_time,
                                end_time: date_time,
                                ip,
                            });

                    // Messages are timestamped by the senders, so they may
                    // arrive out of order: widen the window in both directions
                    record.start_time = date_time.min(record.start_time);
                    record.end_time = date_time.max(record.end_time);
                }

                Message::Flush => {
                    self.message_counter.add(1, &[TYPE.string("flush")]);

                    self.flush().await;
                }

                Message::Shutdown(tx) => {
                    self.message_counter.add(1, &[TYPE.string("shutdown")]);

                    let old_tx = shutdown_notifier.replace(tx);
                    if let Some(old_tx) = old_tx {
                        tracing::warn!("Activity tracker shutdown requested while another shutdown was already in progress");

                        // Still send the shutdown signal to the previous notifier. This means we
                        // send the shutdown signal before we flush the activity tracker, but that
                        // should be fine, since there should not be multiple shutdown requests.
                        let _ = old_tx.send(());
                    }

                    // Stop accepting new messages; the loop keeps draining
                    // the ones which are already queued
                    receiver.close();
                }
            }
        }

        self.flush().await;

        if let Some(shutdown_notifier) = shutdown_notifier {
            // The requester may have stopped waiting; nothing to do then
            let _ = shutdown_notifier.send(());
        } else {
            // This should never happen, since we set the shutdown notifier when we receive
            // the first shutdown message
            tracing::warn!("Activity tracker shutdown requested but no shutdown notifier was set");
        }
    }

    /// Flush the activity tracker.
    ///
    /// Does nothing if there are no pending records. Otherwise, the time
    /// taken is recorded in the flush time histogram along with the result,
    /// and a failure is logged, not returned.
    async fn flush(&mut self) {
        // Short path: if there are no pending records, we don't need to flush
        if self.pending_records.is_empty() {
            return;
        }

        let start = std::time::Instant::now();
        let res = self.try_flush().await;

        // Measure the time it took to flush the activity tracker
        let duration = start.elapsed();
        // Saturate rather than fail if the duration does not fit in a u64
        let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);

        match res {
            Ok(()) => {
                self.flush_time_histogram
                    .record(duration_ms, &[RESULT.string("success")]);
            }
            Err(e) => {
                self.flush_time_histogram
                    .record(duration_ms, &[RESULT.string("failure")]);
                tracing::error!("Failed to flush activity tracker: {}", e);
            }
        }
    }

    /// Fallible part of [`Self::flush`].
    ///
    /// Opens a repository from the pool, logs how many records are pending
    /// and saves the repository. The pending records are only cleared if all
    /// of this succeeded. Writing the records themselves is still a TODO.
    ///
    /// # Errors
    ///
    /// Returns an error if the repository could not be opened or saved.
    async fn try_flush(&mut self) -> Result<(), anyhow::Error> {
        let pending_records = &self.pending_records;

        let repo = mas_storage_pg::PgRepository::from_pool(&self.pool)
            .await?
            .boxed();

        tracing::info!(
            "Flushing {} activity records to the database",
            pending_records.len()
        );

        // TODO: actually save the records

        repo.save().await?;
        self.pending_records.clear();

        Ok(())
    }
}

View File

@ -36,7 +36,7 @@ use sqlx::PgPool;
use crate::{
passwords::PasswordManager, site_config::SiteConfig, upstream_oauth2::cache::MetadataCache,
MatrixHomeserver,
ActivityTracker, BoundActivityTracker, MatrixHomeserver,
};
#[derive(Clone)]
@ -54,6 +54,7 @@ pub struct AppState {
pub password_manager: PasswordManager,
pub metadata_cache: MetadataCache,
pub site_config: SiteConfig,
pub activity_tracker: ActivityTracker,
pub conn_acquisition_histogram: Option<Histogram<u64>>,
}
@ -269,6 +270,32 @@ impl FromRequestParts<AppState> for Policy {
}
}
// Extracts a clone of the activity tracker held by the application state.
// The request itself is not looked at, and the extraction cannot fail.
#[async_trait]
impl FromRequestParts<AppState> for ActivityTracker {
    type Rejection = Infallible;

    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        let tracker = state.activity_tracker.clone();
        Ok(tracker)
    }
}
// Extracts an activity tracker bound to the IP address of the request. For
// now no address is bound, and the extraction cannot fail.
#[async_trait]
impl FromRequestParts<AppState> for BoundActivityTracker {
    type Rejection = Infallible;

    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        let tracker = state.activity_tracker.clone();

        // TODO: grab the IP address from the request
        let ip = None;

        Ok(tracker.bind(ip))
    }
}
#[async_trait]
impl FromRequestParts<AppState> for BoxRepository {
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;

View File

@ -68,6 +68,7 @@ pub mod passwords;
pub mod upstream_oauth2;
mod views;
mod activity_tracker;
mod site_config;
#[cfg(test)]
mod test_utils;
@ -91,8 +92,12 @@ macro_rules! impl_from_error_for_route {
pub use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory};
pub use self::{
app_state::AppState, compat::MatrixHomeserver, graphql::schema as graphql_schema,
site_config::SiteConfig, upstream_oauth2::cache::MetadataCache,
activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
app_state::AppState,
compat::MatrixHomeserver,
graphql::schema as graphql_schema,
site_config::SiteConfig,
upstream_oauth2::cache::MetadataCache,
};
pub fn healthcheck_router<S, B>() -> Router<S, B>
@ -288,6 +293,7 @@ where
UrlBuilder: FromRef<S>,
BoxRepository: FromRequestParts<S>,
CookieJar: FromRequestParts<S>,
BoundActivityTracker: FromRequestParts<S>,
Encrypter: FromRef<S>,
Templates: FromRef<S>,
Keystore: FromRef<S>,

View File

@ -50,7 +50,7 @@ use crate::{
passwords::{Hasher, PasswordManager},
site_config::SiteConfig,
upstream_oauth2::cache::MetadataCache,
MatrixHomeserver,
ActivityTracker, BoundActivityTracker, MatrixHomeserver,
};
// This might fail if it's not the first time it's being called, which is fine,
@ -100,6 +100,7 @@ pub(crate) struct TestState {
pub http_client_factory: HttpClientFactory,
pub password_manager: PasswordManager,
pub site_config: SiteConfig,
pub activity_tracker: ActivityTracker,
pub clock: Arc<MockClock>,
pub rng: Arc<Mutex<ChaChaRng>>,
}
@ -160,6 +161,9 @@ impl TestState {
let graphql_schema = mas_graphql::schema_builder().data(state).finish();
let activity_tracker =
ActivityTracker::new(pool.clone(), std::time::Duration::from_secs(1));
Ok(Self {
pool,
templates,
@ -174,6 +178,7 @@ impl TestState {
http_client_factory,
password_manager,
site_config,
activity_tracker,
clock,
rng,
})
@ -366,6 +371,31 @@ impl FromRef<TestState> for SiteConfig {
}
}
// Extracts a clone of the activity tracker held by the test state. The
// request itself is not looked at, and the extraction cannot fail.
#[async_trait]
impl FromRequestParts<TestState> for ActivityTracker {
    type Rejection = Infallible;

    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        state: &TestState,
    ) -> Result<Self, Self::Rejection> {
        let tracker = state.activity_tracker.clone();
        Ok(tracker)
    }
}
// Extracts an activity tracker from the test state, with no IP address bound
// to it. The extraction cannot fail.
#[async_trait]
impl FromRequestParts<TestState> for BoundActivityTracker {
    type Rejection = Infallible;

    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        state: &TestState,
    ) -> Result<Self, Self::Rejection> {
        let tracker = state.activity_tracker.clone();
        Ok(tracker.bind(None))
    }
}
#[async_trait]
impl FromRequestParts<TestState> for BoxClock {
type Rejection = Infallible;

View File

@ -18,14 +18,18 @@ use axum::{
};
use mas_axum_utils::{cookies::CookieJar, FancyError, SessionInfoExt};
use mas_router::{PostAuthAction, Route};
use mas_storage::BoxRepository;
use mas_storage::{BoxClock, BoxRepository};
use mas_templates::{AppContext, Templates};
use crate::BoundActivityTracker;
#[tracing::instrument(name = "handlers.views.app.get", skip_all, err)]
pub async fn get(
State(templates): State<Templates>,
activity_tracker: BoundActivityTracker,
action: Option<Query<mas_router::AccountAction>>,
mut repo: BoxRepository,
clock: BoxClock,
cookie_jar: CookieJar,
) -> Result<impl IntoResponse, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
@ -33,13 +37,17 @@ pub async fn get(
let action = action.map(|Query(a)| a);
// TODO: keep the full path, not just the action
if session.is_none() {
let Some(session) = session else {
return Ok((
cookie_jar,
mas_router::Login::and_then(PostAuthAction::manage_account(action)).go(),
)
.into_response());
}
};
activity_tracker
.record_browser_session(&clock, &session)
.await;
let ctx = AppContext::default();
let content = templates.render_app(&ctx).await?;