1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Save the session activity in the database

This commit is contained in:
Quentin Gliech
2023-09-19 19:02:59 +02:00
parent 407c78a7be
commit b85655b944
14 changed files with 352 additions and 5 deletions

11
Cargo.lock generated
View File

@ -2399,6 +2399,15 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "ipnetwork"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e"
dependencies = [
"serde",
]
[[package]]
name = "iri-string"
version = "0.7.0"
@ -5177,6 +5186,7 @@ dependencies = [
"hashlink",
"hex",
"indexmap 2.0.0",
"ipnetwork",
"log",
"memchr",
"once_cell",
@ -5303,6 +5313,7 @@ dependencies = [
"hkdf",
"hmac",
"home",
"ipnetwork",
"itoa",
"log",
"md-5",

View File

@ -141,7 +141,9 @@ impl Options {
compat_token_ttl: config.experimental.compat_token_ttl,
};
let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60 * 5));
// Initialize the activity tracker
// Activity is flushed every minute
let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60));
// Explicitly the config to properly zeroize secret keys
drop(config);

View File

@ -15,7 +15,7 @@
use std::{collections::HashMap, net::IpAddr};
use chrono::{DateTime, Utc};
use mas_storage::Repository;
use mas_storage::{user::BrowserSessionRepository, Repository, RepositoryAccess};
use opentelemetry::{
metrics::{Counter, Histogram},
Key,
@ -38,6 +38,8 @@ const RESULT: Key = Key::from_static_str("result");
#[derive(Clone, Copy, Debug)]
struct ActivityRecord {
// XXX: We don't actually use the start time for now
#[allow(dead_code)]
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
ip: Option<IpAddr>,
@ -195,18 +197,47 @@ impl Worker {
}
/// Fallible part of [`Self::flush`].
#[tracing::instrument(name = "activity_tracker.flush", skip(self))]
async fn try_flush(&mut self) -> Result<(), anyhow::Error> {
let pending_records = &self.pending_records;
let repo = mas_storage_pg::PgRepository::from_pool(&self.pool)
let mut repo = mas_storage_pg::PgRepository::from_pool(&self.pool)
.await?
.boxed();
let mut browser_sessions = Vec::new();
let mut oauth2_sessions = Vec::new();
let mut compat_sessions = Vec::new();
for ((kind, id), record) in pending_records {
match kind {
SessionKind::Browser => {
browser_sessions.push((*id, record.end_time, record.ip));
}
SessionKind::OAuth2 => {
oauth2_sessions.push((*id, record.end_time, record.ip));
}
SessionKind::Compat => {
compat_sessions.push((*id, record.end_time, record.ip));
}
}
}
tracing::info!(
"Flushing {} activity records to the database",
pending_records.len()
);
// TODO: actually save the records
repo.browser_session()
.record_batch_activity(browser_sessions)
.await?;
repo.oauth2_session()
.record_batch_activity(oauth2_sessions)
.await?;
repo.compat_session()
.record_batch_activity(compat_sessions)
.await?;
repo.save().await?;
self.pending_records.clear();

View File

@ -0,0 +1,16 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE user_sessions\n SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)\n , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)\n FROM (\n SELECT *\n FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) \n AS t(user_session_id, last_active_at, last_active_ip)\n ) AS t\n WHERE user_sessions.user_session_id = t.user_session_id\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"UuidArray",
"TimestamptzArray",
"InetArray"
]
},
"nullable": []
},
"hash": "8a7461d5f633b7b441b9bc83f78900ccee5a76328a7fad97b650f7e4c921bd7b"
}

View File

@ -0,0 +1,16 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE oauth2_sessions\n SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)\n , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)\n FROM (\n SELECT *\n FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) \n AS t(oauth2_session_id, last_active_at, last_active_ip)\n ) AS t\n WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"UuidArray",
"TimestamptzArray",
"InetArray"
]
},
"nullable": []
},
"hash": "d0c02576b1550fe2eb877d24f7cdfc819307ee0c47af9fbbf1a3b484290b321d"
}

View File

@ -0,0 +1,16 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE compat_sessions\n SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)\n , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)\n FROM (\n SELECT *\n FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) \n AS t(compat_session_id, last_active_at, last_active_ip)\n ) AS t\n WHERE compat_sessions.compat_session_id = t.compat_session_id\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"UuidArray",
"TimestamptzArray",
"InetArray"
]
},
"nullable": []
},
"hash": "d4d25682c10be7a3e3ee989fb9dae92e19023b8ecb6fe6e1d7cabe2cf0efd930"
}

View File

@ -9,7 +9,7 @@ repository.workspace = true
[dependencies]
async-trait = "0.1.73"
sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "json", "uuid"] }
sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "json", "uuid", "ipnetwork"] }
sea-query = { version = "0.30.1", features = ["derive", "attr", "with-uuid", "with-chrono", "postgres-array"] }
sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-uuid", "with-chrono", "postgres-array"] }
chrono.workspace = true

View File

@ -0,0 +1,39 @@
-- Copyright 2023 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- This adds a `last_active_at` timestamp and a `last_active_ip` column
-- to the `oauth2_sessions`, `user_sessions` and `compat_sessions` tables.
-- The timestamp is indexed with the `user_id`, as they are likely to be queried together.
ALTER TABLE "oauth2_sessions"
ADD COLUMN "last_active_at" TIMESTAMP WITH TIME ZONE,
ADD COLUMN "last_active_ip" INET;
CREATE INDEX "oauth2_sessions_user_id_last_active_at"
ON "oauth2_sessions" ("user_id", "last_active_at");
ALTER TABLE "user_sessions"
ADD COLUMN "last_active_at" TIMESTAMP WITH TIME ZONE,
ADD COLUMN "last_active_ip" INET;
CREATE INDEX "user_sessions_user_id_last_active_at"
ON "user_sessions" ("user_id", "last_active_at");
ALTER TABLE "compat_sessions"
ADD COLUMN "last_active_at" TIMESTAMP WITH TIME ZONE,
ADD COLUMN "last_active_ip" INET;
CREATE INDEX "compat_sessions_user_id_last_active_at"
ON "compat_sessions" ("user_id", "last_active_at");

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::net::IpAddr;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
@ -505,4 +507,51 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
.try_into()
.map_err(DatabaseError::to_invalid_operation)
}
#[tracing::instrument(
name = "db.compat_session.record_batch_activity",
skip_all,
fields(
db.statement,
),
err,
)]
async fn record_batch_activity(
&mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error> {
let mut ids = Vec::with_capacity(activity.len());
let mut last_activities = Vec::with_capacity(activity.len());
let mut ips = Vec::with_capacity(activity.len());
for (id, last_activity, ip) in activity {
ids.push(Uuid::from(id));
last_activities.push(last_activity);
ips.push(ip);
}
let res = sqlx::query!(
r#"
UPDATE compat_sessions
SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)
, last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
FROM (
SELECT *
FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
AS t(compat_session_id, last_active_at, last_active_ip)
) AS t
WHERE compat_sessions.compat_session_id = t.compat_session_id
"#,
&ids,
&last_activities,
&ips as &[Option<IpAddr>],
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
Ok(())
}
}

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::net::IpAddr;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{BrowserSession, Client, Session, SessionState, User};
@ -362,4 +364,51 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
.try_into()
.map_err(DatabaseError::to_invalid_operation)
}
#[tracing::instrument(
name = "db.oauth2_session.record_batch_activity",
skip_all,
fields(
db.statement,
),
err,
)]
async fn record_batch_activity(
&mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error> {
let mut ids = Vec::with_capacity(activity.len());
let mut last_activities = Vec::with_capacity(activity.len());
let mut ips = Vec::with_capacity(activity.len());
for (id, last_activity, ip) in activity {
ids.push(Uuid::from(id));
last_activities.push(last_activity);
ips.push(ip);
}
let res = sqlx::query!(
r#"
UPDATE oauth2_sessions
SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
, last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
FROM (
SELECT *
FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
AS t(oauth2_session_id, last_active_at, last_active_ip)
) AS t
WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
"#,
&ids,
&last_activities,
&ips as &[Option<IpAddr>],
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
Ok(())
}
}

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::net::IpAddr;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
@ -504,4 +506,51 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
let authentication = Authentication::try_from(authentication)?;
Ok(Some(authentication))
}
#[tracing::instrument(
name = "db.browser_session.record_batch_activity",
skip_all,
fields(
db.statement,
),
err,
)]
async fn record_batch_activity(
&mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error> {
let mut ids = Vec::with_capacity(activity.len());
let mut last_activities = Vec::with_capacity(activity.len());
let mut ips = Vec::with_capacity(activity.len());
for (id, last_activity, ip) in activity {
ids.push(Uuid::from(id));
last_activities.push(last_activity);
ips.push(ip);
}
let res = sqlx::query!(
r#"
UPDATE user_sessions
SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
, last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
FROM (
SELECT *
FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
AS t(user_session_id, last_active_at, last_active_ip)
) AS t
WHERE user_sessions.user_session_id = t.user_session_id
"#,
&ids,
&last_activities,
&ips as &[Option<IpAddr>],
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
Ok(())
}
}

View File

@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::net::IpAddr;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{CompatSession, CompatSsoLogin, Device, User};
use rand_core::RngCore;
use ulid::Ulid;
@ -236,6 +239,21 @@ pub trait CompatSessionRepository: Send + Sync {
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error>;
/// Record a batch of [`Session`] activity
///
/// # Parameters
///
/// * `activity`: A list of tuples containing the session ID, the last
/// activity timestamp and the IP address of the client
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn record_batch_activity(
&mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error>;
}
repository_impl!(CompatSessionRepository:
@ -269,4 +287,9 @@ repository_impl!(CompatSessionRepository:
) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error>;
async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error>;
async fn record_batch_activity(
&mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error>;
);

View File

@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::net::IpAddr;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{BrowserSession, Client, Session, User};
use oauth2_types::scope::Scope;
use rand_core::RngCore;
@ -268,6 +271,21 @@ pub trait OAuth2SessionRepository: Send + Sync {
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
/// Record a batch of [`Session`] activity
///
/// # Parameters
///
/// * `activity`: A list of tuples containing the session ID, the last
/// activity timestamp and the IP address of the client
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn record_batch_activity(
&mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error>;
}
repository_impl!(OAuth2SessionRepository:
@ -310,4 +328,9 @@ repository_impl!(OAuth2SessionRepository:
) -> Result<Page<Session>, Self::Error>;
async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
async fn record_batch_activity(
&mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error>;
);

View File

@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::net::IpAddr;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
Authentication, BrowserSession, Password, UpstreamOAuthAuthorizationSession, User,
};
@ -227,6 +230,21 @@ pub trait BrowserSessionRepository: Send + Sync {
&mut self,
user_session: &BrowserSession,
) -> Result<Option<Authentication>, Self::Error>;
/// Record a batch of [`Session`] activity
///
/// # Parameters
///
/// * `activity`: A list of tuples containing the session ID, the last
/// activity timestamp and the IP address of the client
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn record_batch_activity(
&mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error>;
}
repository_impl!(BrowserSessionRepository:
@ -272,4 +290,9 @@ repository_impl!(BrowserSessionRepository:
&mut self,
user_session: &BrowserSession,
) -> Result<Option<Authentication>, Self::Error>;
async fn record_batch_activity(
&mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error>;
);