diff --git a/Cargo.lock b/Cargo.lock index a46d94a5..64f1c8bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2399,6 +2399,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ipnetwork" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e" +dependencies = [ + "serde", +] + [[package]] name = "iri-string" version = "0.7.0" @@ -5177,6 +5186,7 @@ dependencies = [ "hashlink", "hex", "indexmap 2.0.0", + "ipnetwork", "log", "memchr", "once_cell", @@ -5303,6 +5313,7 @@ dependencies = [ "hkdf", "hmac", "home", + "ipnetwork", "itoa", "log", "md-5", diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index b79199b5..55e3cc24 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -141,7 +141,9 @@ impl Options { compat_token_ttl: config.experimental.compat_token_ttl, }; - let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60 * 5)); + // Initialize the activity tracker + // Activity is flushed every minute + let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60)); // Explicitly the config to properly zeroize secret keys drop(config); diff --git a/crates/handlers/src/activity_tracker/worker.rs b/crates/handlers/src/activity_tracker/worker.rs index ed37636e..7fde6b44 100644 --- a/crates/handlers/src/activity_tracker/worker.rs +++ b/crates/handlers/src/activity_tracker/worker.rs @@ -15,7 +15,7 @@ use std::{collections::HashMap, net::IpAddr}; use chrono::{DateTime, Utc}; -use mas_storage::Repository; +use mas_storage::{user::BrowserSessionRepository, Repository, RepositoryAccess}; use opentelemetry::{ metrics::{Counter, Histogram}, Key, @@ -38,6 +38,8 @@ const RESULT: Key = Key::from_static_str("result"); #[derive(Clone, Copy, Debug)] struct ActivityRecord { + // XXX: We don't actually use the start time for now + #[allow(dead_code)] start_time: DateTime, end_time: DateTime, ip: Option, @@ -195,18 +197,47 @@ impl Worker { } /// Fallible part of [`Self::flush`]. + #[tracing::instrument(name = "activity_tracker.flush", skip(self))] async fn try_flush(&mut self) -> Result<(), anyhow::Error> { let pending_records = &self.pending_records; - let repo = mas_storage_pg::PgRepository::from_pool(&self.pool) + let mut repo = mas_storage_pg::PgRepository::from_pool(&self.pool) .await? .boxed(); + let mut browser_sessions = Vec::new(); + let mut oauth2_sessions = Vec::new(); + let mut compat_sessions = Vec::new(); + + for ((kind, id), record) in pending_records { + match kind { + SessionKind::Browser => { + browser_sessions.push((*id, record.end_time, record.ip)); + } + SessionKind::OAuth2 => { + oauth2_sessions.push((*id, record.end_time, record.ip)); + } + SessionKind::Compat => { + compat_sessions.push((*id, record.end_time, record.ip)); + } + } + } + tracing::info!( "Flushing {} activity records to the database", pending_records.len() ); - // TODO: actually save the records + + repo.browser_session() + .record_batch_activity(browser_sessions) + .await?; + repo.oauth2_session() + .record_batch_activity(oauth2_sessions) + .await?; + repo.compat_session() + .record_batch_activity(compat_sessions) + .await?; + repo.save().await?; self.pending_records.clear(); diff --git a/crates/storage-pg/.sqlx/query-8a7461d5f633b7b441b9bc83f78900ccee5a76328a7fad97b650f7e4c921bd7b.json b/crates/storage-pg/.sqlx/query-8a7461d5f633b7b441b9bc83f78900ccee5a76328a7fad97b650f7e4c921bd7b.json new file mode 100644 index 00000000..db3f56f3 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-8a7461d5f633b7b441b9bc83f78900ccee5a76328a7fad97b650f7e4c921bd7b.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE user_sessions\n SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)\n , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)\n FROM (\n SELECT *\n FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) \n AS t(user_session_id, last_active_at, last_active_ip)\n ) AS t\n WHERE user_sessions.user_session_id = t.user_session_id\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "UuidArray", + "TimestamptzArray", + "InetArray" + ] + }, + "nullable": [] + }, + "hash": "8a7461d5f633b7b441b9bc83f78900ccee5a76328a7fad97b650f7e4c921bd7b" +} diff --git a/crates/storage-pg/.sqlx/query-d0c02576b1550fe2eb877d24f7cdfc819307ee0c47af9fbbf1a3b484290b321d.json b/crates/storage-pg/.sqlx/query-d0c02576b1550fe2eb877d24f7cdfc819307ee0c47af9fbbf1a3b484290b321d.json new file mode 100644 index 00000000..550b6ece --- /dev/null +++ b/crates/storage-pg/.sqlx/query-d0c02576b1550fe2eb877d24f7cdfc819307ee0c47af9fbbf1a3b484290b321d.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE oauth2_sessions\n SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)\n , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)\n FROM (\n SELECT *\n FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) \n AS t(oauth2_session_id, last_active_at, last_active_ip)\n ) AS t\n WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "UuidArray", + "TimestamptzArray", + "InetArray" + ] + }, + "nullable": [] + }, + "hash": "d0c02576b1550fe2eb877d24f7cdfc819307ee0c47af9fbbf1a3b484290b321d" +} diff --git a/crates/storage-pg/.sqlx/query-d4d25682c10be7a3e3ee989fb9dae92e19023b8ecb6fe6e1d7cabe2cf0efd930.json b/crates/storage-pg/.sqlx/query-d4d25682c10be7a3e3ee989fb9dae92e19023b8ecb6fe6e1d7cabe2cf0efd930.json new file mode 100644 index 00000000..14d44776 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-d4d25682c10be7a3e3ee989fb9dae92e19023b8ecb6fe6e1d7cabe2cf0efd930.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE compat_sessions\n SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)\n , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)\n FROM (\n SELECT *\n FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) \n AS t(compat_session_id, last_active_at, last_active_ip)\n ) AS t\n WHERE compat_sessions.compat_session_id = t.compat_session_id\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "UuidArray", + "TimestamptzArray", + "InetArray" + ] + }, + "nullable": [] + }, + "hash": "d4d25682c10be7a3e3ee989fb9dae92e19023b8ecb6fe6e1d7cabe2cf0efd930" +} diff --git a/crates/storage-pg/Cargo.toml b/crates/storage-pg/Cargo.toml index 05c43dd6..ee8f3e37 100644 --- a/crates/storage-pg/Cargo.toml +++ b/crates/storage-pg/Cargo.toml @@ -9,7 +9,7 @@ repository.workspace = true [dependencies] async-trait = "0.1.73" -sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "json", "uuid"] } +sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "json", "uuid", "ipnetwork"] } sea-query = { version = "0.30.1", features = ["derive", "attr", "with-uuid", "with-chrono", "postgres-array"] } sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-uuid", "with-chrono", "postgres-array"] } chrono.workspace = true diff --git a/crates/storage-pg/migrations/20230919155444_record_session_last_activity.sql b/crates/storage-pg/migrations/20230919155444_record_session_last_activity.sql new file mode 100644 index 00000000..0336c9df --- /dev/null +++ b/crates/storage-pg/migrations/20230919155444_record_session_last_activity.sql @@ -0,0 +1,39 @@ +-- Copyright 2023 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- This adds a `last_active_at` timestamp and a `last_active_ip` column +-- to the `oauth2_sessions`, `user_sessions` and `compat_sessions` tables. +-- The timestamp is indexed with the `user_id`, as they are likely to be queried together. +ALTER TABLE "oauth2_sessions" + ADD COLUMN "last_active_at" TIMESTAMP WITH TIME ZONE, + ADD COLUMN "last_active_ip" INET; + +CREATE INDEX "oauth2_sessions_user_id_last_active_at" + ON "oauth2_sessions" ("user_id", "last_active_at"); + + +ALTER TABLE "user_sessions" + ADD COLUMN "last_active_at" TIMESTAMP WITH TIME ZONE, + ADD COLUMN "last_active_ip" INET; + +CREATE INDEX "user_sessions_user_id_last_active_at" + ON "user_sessions" ("user_id", "last_active_at"); + + +ALTER TABLE "compat_sessions" + ADD COLUMN "last_active_at" TIMESTAMP WITH TIME ZONE, + ADD COLUMN "last_active_ip" INET; + +CREATE INDEX "compat_sessions_user_id_last_active_at" + ON "compat_sessions" ("user_id", "last_active_at"); \ No newline at end of file diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs index cbfeb710..3695d617 100644 --- a/crates/storage-pg/src/compat/session.rs +++ b/crates/storage-pg/src/compat/session.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::net::IpAddr; + use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{ @@ -505,4 +507,51 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { .try_into() .map_err(DatabaseError::to_invalid_operation) } + + #[tracing::instrument( + name = "db.compat_session.record_batch_activity", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn record_batch_activity( + &mut self, + activity: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error> { + let mut ids = Vec::with_capacity(activity.len()); + let mut last_activities = Vec::with_capacity(activity.len()); + let mut ips = Vec::with_capacity(activity.len()); + + for (id, last_activity, ip) in activity { + ids.push(Uuid::from(id)); + last_activities.push(last_activity); + ips.push(ip); + } + + let res = sqlx::query!( + r#" + UPDATE compat_sessions + SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at) + , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip) + FROM ( + SELECT * + FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) + AS t(compat_session_id, last_active_at, last_active_ip) + ) AS t + WHERE compat_sessions.compat_session_id = t.compat_session_id + "#, + &ids, + &last_activities, + &ips as &[Option], + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?; + + Ok(()) + } } diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index 98e6563a..83b44b17 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::net::IpAddr; + use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{BrowserSession, Client, Session, SessionState, User}; @@ -362,4 +364,51 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { .try_into() .map_err(DatabaseError::to_invalid_operation) } + + #[tracing::instrument( + name = "db.oauth2_session.record_batch_activity", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn record_batch_activity( + &mut self, + activity: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error> { + let mut ids = Vec::with_capacity(activity.len()); + let mut last_activities = Vec::with_capacity(activity.len()); + let mut ips = Vec::with_capacity(activity.len()); + + for (id, last_activity, ip) in activity { + ids.push(Uuid::from(id)); + last_activities.push(last_activity); + ips.push(ip); + } + + let res = sqlx::query!( + r#" + UPDATE oauth2_sessions + SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at) + , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip) + FROM ( + SELECT * + FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) + AS t(oauth2_session_id, last_active_at, last_active_ip) + ) AS t + WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id + "#, + &ids, + &last_activities, + &ips as &[Option], + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?; + + Ok(()) + } } diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs index 80f09962..4fa7d603 100644 --- a/crates/storage-pg/src/user/session.rs +++ b/crates/storage-pg/src/user/session.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::net::IpAddr; + use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{ @@ -504,4 +506,51 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { let authentication = Authentication::try_from(authentication)?; Ok(Some(authentication)) } + + #[tracing::instrument( + name = "db.browser_session.record_batch_activity", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn record_batch_activity( + &mut self, + activity: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error> { + let mut ids = Vec::with_capacity(activity.len()); + let mut last_activities = Vec::with_capacity(activity.len()); + let mut ips = Vec::with_capacity(activity.len()); + + for (id, last_activity, ip) in activity { + ids.push(Uuid::from(id)); + last_activities.push(last_activity); + ips.push(ip); + } + + let res = sqlx::query!( + r#" + UPDATE user_sessions + SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at) + , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip) + FROM ( + SELECT * + FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) + AS t(user_session_id, last_active_at, last_active_ip) + ) AS t + WHERE user_sessions.user_session_id = t.user_session_id + "#, + &ids, + &last_activities, + &ips as &[Option], + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?; + + Ok(()) + } } diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index ec28ebcf..aa370c6a 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::net::IpAddr; + use async_trait::async_trait; +use chrono::{DateTime, Utc}; use mas_data_model::{CompatSession, CompatSsoLogin, Device, User}; use rand_core::RngCore; use ulid::Ulid; @@ -236,6 +239,21 @@ pub trait CompatSessionRepository: Send + Sync { /// /// Returns [`Self::Error`] if the underlying repository fails async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result; + + /// Record a batch of [`Session`] activity + /// + /// # Parameters + /// + /// * `activity`: A list of tuples containing the session ID, the last + /// activity timestamp and the IP address of the client + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn record_batch_activity( + &mut self, + activity: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error>; } repository_impl!(CompatSessionRepository: @@ -269,4 +287,9 @@ repository_impl!(CompatSessionRepository: ) -> Result)>, Self::Error>; async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result; + + async fn record_batch_activity( + &mut self, + activity: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error>; ); diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 56fad3f4..4ccc0d08 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::net::IpAddr; + use async_trait::async_trait; +use chrono::{DateTime, Utc}; use mas_data_model::{BrowserSession, Client, Session, User}; use oauth2_types::scope::Scope; use rand_core::RngCore; @@ -268,6 +271,21 @@ pub trait OAuth2SessionRepository: Send + Sync { /// /// Returns [`Self::Error`] if the underlying repository fails async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result; + + /// Record a batch of [`Session`] activity + /// + /// # Parameters + /// + /// * `activity`: A list of tuples containing the session ID, the last + /// activity timestamp and the IP address of the client + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn record_batch_activity( + &mut self, + activity: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error>; } repository_impl!(OAuth2SessionRepository: @@ -310,4 +328,9 @@ repository_impl!(OAuth2SessionRepository: ) -> Result, Self::Error>; async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result; + + async fn record_batch_activity( + &mut self, + activity: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error>; ); diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index e4fc6814..fa1fd8b6 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::net::IpAddr; + use async_trait::async_trait; +use chrono::{DateTime, Utc}; use mas_data_model::{ Authentication, BrowserSession, Password, UpstreamOAuthAuthorizationSession, User, }; @@ -227,6 +230,21 @@ pub trait BrowserSessionRepository: Send + Sync { &mut self, user_session: &BrowserSession, ) -> Result, Self::Error>; + + /// Record a batch of [`Session`] activity + /// + /// # Parameters + /// + /// * `activity`: A list of tuples containing the session ID, the last + /// activity timestamp and the IP address of the client + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn record_batch_activity( + &mut self, + activity: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error>; } repository_impl!(BrowserSessionRepository: @@ -272,4 +290,9 @@ repository_impl!(BrowserSessionRepository: &mut self, user_session: &BrowserSession, ) -> Result, Self::Error>; + + async fn record_batch_activity( + &mut self, + activity: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error>; );