1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Record user agents on OAuth 2.0 and compat sessions (#2386)

* Record user agents on OAuth 2.0 and compat sessions

* Add tests for recording user agent in sessions
This commit is contained in:
Quentin Gliech
2024-02-22 10:01:32 +01:00
committed by GitHub
parent 7de4be219b
commit f171d76dc5
17 changed files with 303 additions and 13 deletions

View File

@@ -83,6 +83,7 @@ pub struct CompatSession {
pub user_session_id: Option<Ulid>, pub user_session_id: Option<Ulid>,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub is_synapse_admin: bool, pub is_synapse_admin: bool,
pub user_agent: Option<String>,
pub last_active_at: Option<DateTime<Utc>>, pub last_active_at: Option<DateTime<Utc>>,
pub last_active_ip: Option<IpAddr>, pub last_active_ip: Option<IpAddr>,
} }

View File

@@ -75,6 +75,7 @@ pub struct Session {
pub user_session_id: Option<Ulid>, pub user_session_id: Option<Ulid>,
pub client_id: Ulid, pub client_id: Ulid,
pub scope: Scope, pub scope: Scope,
pub user_agent: Option<String>,
pub last_active_at: Option<DateTime<Utc>>, pub last_active_at: Option<DateTime<Utc>>,
pub last_active_ip: Option<IpAddr>, pub last_active_ip: Option<IpAddr>,
} }

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use axum::{extract::State, response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json, TypedHeader};
use chrono::Duration; use chrono::Duration;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::sentry::SentryEventID; use mas_axum_utils::sentry::SentryEventID;
@@ -217,9 +217,11 @@ pub(crate) async fn post(
activity_tracker: BoundActivityTracker, activity_tracker: BoundActivityTracker,
State(homeserver): State<MatrixHomeserver>, State(homeserver): State<MatrixHomeserver>,
State(site_config): State<SiteConfig>, State(site_config): State<SiteConfig>,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Json(input): Json<RequestBody>, Json(input): Json<RequestBody>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let (session, user) = match (password_manager.is_enabled(), input.credentials) { let user_agent = user_agent.map(|ua| ua.to_string());
let (mut session, user) = match (password_manager.is_enabled(), input.credentials) {
( (
true, true,
Credentials::Password { Credentials::Password {
@@ -245,6 +247,13 @@ pub(crate) async fn post(
} }
}; };
if let Some(user_agent) = user_agent {
session = repo
.compat_session()
.record_user_agent(session, user_agent)
.await?;
}
let user_id = format!("@{username}:{homeserver}", username = user.username); let user_id = format!("@{username}:{homeserver}", username = user.username);
// If the client asked for a refreshable token, make it expire // If the client asked for a refreshable token, make it expire

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use axum::{extract::State, response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json, TypedHeader};
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma}; use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
use hyper::StatusCode; use hyper::StatusCode;
@@ -230,8 +230,10 @@ pub(crate) async fn post(
State(site_config): State<SiteConfig>, State(site_config): State<SiteConfig>,
State(encrypter): State<Encrypter>, State(encrypter): State<Encrypter>,
policy: Policy, policy: Policy,
user_agent: Option<TypedHeader<headers::UserAgent>>,
client_authorization: ClientAuthorization<AccessTokenRequest>, client_authorization: ClientAuthorization<AccessTokenRequest>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
let user_agent = user_agent.map(|ua| ua.to_string());
let client = client_authorization let client = client_authorization
.credentials .credentials
.fetch(&mut repo) .fetch(&mut repo)
@@ -262,6 +264,7 @@ pub(crate) async fn post(
&url_builder, &url_builder,
&site_config, &site_config,
repo, repo,
user_agent,
) )
.await? .await?
} }
@@ -274,6 +277,7 @@ pub(crate) async fn post(
&client, &client,
&site_config, &site_config,
repo, repo,
user_agent,
) )
.await? .await?
} }
@@ -287,6 +291,7 @@ pub(crate) async fn post(
&site_config, &site_config,
repo, repo,
policy, policy,
user_agent,
) )
.await? .await?
} }
@@ -301,6 +306,7 @@ pub(crate) async fn post(
&url_builder, &url_builder,
&site_config, &site_config,
repo, repo,
user_agent,
) )
.await? .await?
} }
@@ -329,6 +335,7 @@ async fn authorization_code_grant(
url_builder: &UrlBuilder, url_builder: &UrlBuilder,
site_config: &SiteConfig, site_config: &SiteConfig,
mut repo: BoxRepository, mut repo: BoxRepository,
user_agent: Option<String>,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { ) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
// Check that the client is allowed to use this grant type // Check that the client is allowed to use this grant type
if !client.grant_types.contains(&GrantType::AuthorizationCode) { if !client.grant_types.contains(&GrantType::AuthorizationCode) {
@@ -386,12 +393,19 @@ async fn authorization_code_grant(
} }
}; };
let session = repo let mut session = repo
.oauth2_session() .oauth2_session()
.lookup(session_id) .lookup(session_id)
.await? .await?
.ok_or(RouteError::NoSuchOAuthSession)?; .ok_or(RouteError::NoSuchOAuthSession)?;
if let Some(user_agent) = user_agent {
session = repo
.oauth2_session()
.record_user_agent(session, user_agent)
.await?;
}
// This should never happen, since we looked up in the database using the code // This should never happen, since we looked up in the database using the code
let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?; let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?;
@@ -490,6 +504,7 @@ async fn refresh_token_grant(
client: &Client, client: &Client,
site_config: &SiteConfig, site_config: &SiteConfig,
mut repo: BoxRepository, mut repo: BoxRepository,
user_agent: Option<String>,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { ) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
// Check that the client is allowed to use this grant type // Check that the client is allowed to use this grant type
if !client.grant_types.contains(&GrantType::RefreshToken) { if !client.grant_types.contains(&GrantType::RefreshToken) {
@@ -502,12 +517,21 @@ async fn refresh_token_grant(
.await? .await?
.ok_or(RouteError::RefreshTokenNotFound)?; .ok_or(RouteError::RefreshTokenNotFound)?;
let session = repo let mut session = repo
.oauth2_session() .oauth2_session()
.lookup(refresh_token.session_id) .lookup(refresh_token.session_id)
.await? .await?
.ok_or(RouteError::NoSuchOAuthSession)?; .ok_or(RouteError::NoSuchOAuthSession)?;
// Let's for now record the user agent on each refresh, that should be
// responsive enough and not too much of a burden on the database.
if let Some(user_agent) = user_agent {
session = repo
.oauth2_session()
.record_user_agent(session, user_agent)
.await?;
}
if !refresh_token.is_valid() { if !refresh_token.is_valid() {
return Err(RouteError::RefreshTokenInvalid(refresh_token.id)); return Err(RouteError::RefreshTokenInvalid(refresh_token.id));
} }
@@ -563,6 +587,7 @@ async fn client_credentials_grant(
site_config: &SiteConfig, site_config: &SiteConfig,
mut repo: BoxRepository, mut repo: BoxRepository,
mut policy: Policy, mut policy: Policy,
user_agent: Option<String>,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { ) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
// Check that the client is allowed to use this grant type // Check that the client is allowed to use this grant type
if !client.grant_types.contains(&GrantType::ClientCredentials) { if !client.grant_types.contains(&GrantType::ClientCredentials) {
@@ -584,11 +609,18 @@ async fn client_credentials_grant(
} }
// Start the session // Start the session
let session = repo let mut session = repo
.oauth2_session() .oauth2_session()
.add_from_client_credentials(rng, clock, client, scope) .add_from_client_credentials(rng, clock, client, scope)
.await?; .await?;
if let Some(user_agent) = user_agent {
session = repo
.oauth2_session()
.record_user_agent(session, user_agent)
.await?;
}
let ttl = site_config.access_token_ttl; let ttl = site_config.access_token_ttl;
let access_token_str = TokenType::AccessToken.generate(rng); let access_token_str = TokenType::AccessToken.generate(rng);
@@ -624,6 +656,7 @@ async fn device_code_grant(
url_builder: &UrlBuilder, url_builder: &UrlBuilder,
site_config: &SiteConfig, site_config: &SiteConfig,
mut repo: BoxRepository, mut repo: BoxRepository,
user_agent: Option<String>,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { ) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
// Check that the client is allowed to use this grant type // Check that the client is allowed to use this grant type
if !client.grant_types.contains(&GrantType::DeviceCode) { if !client.grant_types.contains(&GrantType::DeviceCode) {
@@ -670,11 +703,19 @@ async fn device_code_grant(
.ok_or(RouteError::NoSuchBrowserSession)?; .ok_or(RouteError::NoSuchBrowserSession)?;
// Start the session // Start the session
let session = repo let mut session = repo
.oauth2_session() .oauth2_session()
.add_from_browser_session(rng, clock, client, &browser_session, grant.scope) .add_from_browser_session(rng, clock, client, &browser_session, grant.scope)
.await?; .await?;
// XXX: should we get the user agent from the device code grant instead?
if let Some(user_agent) = user_agent {
session = repo
.oauth2_session()
.record_user_agent(session, user_agent)
.await?;
}
let ttl = site_config.access_token_ttl; let ttl = site_config.access_token_ttl;
let access_token_str = TokenType::AccessToken.generate(rng); let access_token_str = TokenType::AccessToken.generate(rng);

View File

@@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE oauth2_sessions\n SET user_agent = $2\n WHERE oauth2_session_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Text"
]
},
"nullable": []
},
"hash": "1919d402fd6f148d14417f633be3353004f458c85f7b4f361802f86651900fbc"
}

View File

@@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE compat_sessions\n SET user_agent = $2\n WHERE compat_session_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Text"
]
},
"nullable": []
},
"hash": "29148548d592046f7d711676911e3847e376e443ccd841f76b17a81f53fafc3a"
}

View File

@@ -1,6 +1,6 @@
{ {
"db_name": "PostgreSQL", "db_name": "PostgreSQL",
"query": "\n SELECT oauth2_session_id\n , user_id\n , user_session_id\n , oauth2_client_id\n , scope_list\n , created_at\n , finished_at\n , last_active_at\n , last_active_ip as \"last_active_ip: IpAddr\"\n FROM oauth2_sessions\n\n WHERE oauth2_session_id = $1\n ", "query": "\n SELECT oauth2_session_id\n , user_id\n , user_session_id\n , oauth2_client_id\n , scope_list\n , created_at\n , finished_at\n , user_agent\n , last_active_at\n , last_active_ip as \"last_active_ip: IpAddr\"\n FROM oauth2_sessions\n\n WHERE oauth2_session_id = $1\n ",
"describe": { "describe": {
"columns": [ "columns": [
{ {
@@ -40,11 +40,16 @@
}, },
{ {
"ordinal": 7, "ordinal": 7,
"name": "user_agent",
"type_info": "Text"
},
{
"ordinal": 8,
"name": "last_active_at", "name": "last_active_at",
"type_info": "Timestamptz" "type_info": "Timestamptz"
}, },
{ {
"ordinal": 8, "ordinal": 9,
"name": "last_active_ip: IpAddr", "name": "last_active_ip: IpAddr",
"type_info": "Inet" "type_info": "Inet"
} }
@@ -63,8 +68,9 @@
false, false,
true, true,
true, true,
true,
true true
] ]
}, },
"hash": "31aace373b20b5dbf65fa51d8663da7571d85b6a7d2d544d69e7d04260cdffc9" "hash": "5a2e9b5002c1927c0035c22e393172b36ab46a4377b46618205151ea041886d5"
} }

View File

@@ -1,6 +1,6 @@
{ {
"db_name": "PostgreSQL", "db_name": "PostgreSQL",
"query": "\n SELECT compat_session_id\n , device_id\n , user_id\n , user_session_id\n , created_at\n , finished_at\n , is_synapse_admin\n , last_active_at\n , last_active_ip as \"last_active_ip: IpAddr\"\n FROM compat_sessions\n WHERE compat_session_id = $1\n ", "query": "\n SELECT compat_session_id\n , device_id\n , user_id\n , user_session_id\n , created_at\n , finished_at\n , is_synapse_admin\n , user_agent\n , last_active_at\n , last_active_ip as \"last_active_ip: IpAddr\"\n FROM compat_sessions\n WHERE compat_session_id = $1\n ",
"describe": { "describe": {
"columns": [ "columns": [
{ {
@@ -40,11 +40,16 @@
}, },
{ {
"ordinal": 7, "ordinal": 7,
"name": "user_agent",
"type_info": "Text"
},
{
"ordinal": 8,
"name": "last_active_at", "name": "last_active_at",
"type_info": "Timestamptz" "type_info": "Timestamptz"
}, },
{ {
"ordinal": 8, "ordinal": 9,
"name": "last_active_ip: IpAddr", "name": "last_active_ip: IpAddr",
"type_info": "Inet" "type_info": "Inet"
} }
@@ -63,8 +68,9 @@
true, true,
false, false,
true, true,
true,
true true
] ]
}, },
"hash": "04e25c9267bf2eb143a6445345229081e7b386743a93b3833ef8ad9d09972f3b" "hash": "bb6f55a4cc10bec8ec0fc138485f6b4d308302bb1fa3accb12932d1e5ce457e9"
} }

View File

@@ -0,0 +1,17 @@
-- Copyright 2024 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- Adds user agent columns to oauth and compat sessions tables
ALTER TABLE oauth2_sessions ADD COLUMN user_agent TEXT;
ALTER TABLE compat_sessions ADD COLUMN user_agent TEXT;

View File

@@ -73,6 +73,7 @@ mod priv_ {
pub(super) created_at: DateTime<Utc>, pub(super) created_at: DateTime<Utc>,
pub(super) finished_at: Option<DateTime<Utc>>, pub(super) finished_at: Option<DateTime<Utc>>,
pub(super) is_synapse_admin: Option<bool>, pub(super) is_synapse_admin: Option<bool>,
pub(super) user_agent: Option<String>,
pub(super) last_active_at: Option<DateTime<Utc>>, pub(super) last_active_at: Option<DateTime<Utc>>,
pub(super) last_active_ip: Option<IpAddr>, pub(super) last_active_ip: Option<IpAddr>,
} }
@@ -98,6 +99,7 @@ impl TryFrom<AppSessionLookup> for AppSession {
created_at, created_at,
finished_at, finished_at,
is_synapse_admin, is_synapse_admin,
user_agent,
last_active_at, last_active_at,
last_active_ip, last_active_ip,
} = value; } = value;
@@ -143,6 +145,7 @@ impl TryFrom<AppSessionLookup> for AppSession {
user_session_id, user_session_id,
created_at, created_at,
is_synapse_admin, is_synapse_admin,
user_agent,
last_active_at, last_active_at,
last_active_ip, last_active_ip,
}; };
@@ -182,6 +185,7 @@ impl TryFrom<AppSessionLookup> for AppSession {
user_id: user_id.map(Ulid::from), user_id: user_id.map(Ulid::from),
user_session_id, user_session_id,
scope, scope,
user_agent,
last_active_at, last_active_at,
last_active_ip, last_active_ip,
}; };
@@ -250,6 +254,10 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
AppSessionLookupIden::FinishedAt, AppSessionLookupIden::FinishedAt,
) )
.expr_as(Expr::cust("NULL"), AppSessionLookupIden::IsSynapseAdmin) .expr_as(Expr::cust("NULL"), AppSessionLookupIden::IsSynapseAdmin)
.expr_as(
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
AppSessionLookupIden::UserAgent,
)
.expr_as( .expr_as(
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)), Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
AppSessionLookupIden::LastActiveAt, AppSessionLookupIden::LastActiveAt,
@@ -317,6 +325,10 @@ impl<'c> AppSessionRepository for PgAppSessionRepository<'c> {
Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)), Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
AppSessionLookupIden::IsSynapseAdmin, AppSessionLookupIden::IsSynapseAdmin,
) )
.expr_as(
Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
AppSessionLookupIden::UserAgent,
)
.expr_as( .expr_as(
Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)), Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
AppSessionLookupIden::LastActiveAt, AppSessionLookupIden::LastActiveAt,

View File

@@ -129,6 +129,24 @@ mod tests {
assert!(session_lookup.is_valid()); assert!(session_lookup.is_valid());
assert!(!session_lookup.is_finished()); assert!(!session_lookup.is_finished());
// Record a user-agent for the session
assert!(session_lookup.user_agent.is_none());
let session = repo
.compat_session()
.record_user_agent(session_lookup, "Mozilla/5.0".to_owned())
.await
.unwrap();
assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));
// Reload the session and check again
let session_lookup = repo
.compat_session()
.lookup(session.id)
.await
.unwrap()
.expect("compat session not found");
assert_eq!(session_lookup.user_agent.as_deref(), Some("Mozilla/5.0"));
// Look up the session by device // Look up the session by device
let list = repo let list = repo
.compat_session() .compat_session()

View File

@@ -60,6 +60,7 @@ struct CompatSessionLookup {
created_at: DateTime<Utc>, created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>, finished_at: Option<DateTime<Utc>>,
is_synapse_admin: bool, is_synapse_admin: bool,
user_agent: Option<String>,
last_active_at: Option<DateTime<Utc>>, last_active_at: Option<DateTime<Utc>>,
last_active_ip: Option<IpAddr>, last_active_ip: Option<IpAddr>,
} }
@@ -89,6 +90,7 @@ impl TryFrom<CompatSessionLookup> for CompatSession {
device, device,
created_at: value.created_at, created_at: value.created_at,
is_synapse_admin: value.is_synapse_admin, is_synapse_admin: value.is_synapse_admin,
user_agent: value.user_agent,
last_active_at: value.last_active_at, last_active_at: value.last_active_at,
last_active_ip: value.last_active_ip, last_active_ip: value.last_active_ip,
}; };
@@ -107,6 +109,7 @@ struct CompatSessionAndSsoLoginLookup {
created_at: DateTime<Utc>, created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>, finished_at: Option<DateTime<Utc>>,
is_synapse_admin: bool, is_synapse_admin: bool,
user_agent: Option<String>,
last_active_at: Option<DateTime<Utc>>, last_active_at: Option<DateTime<Utc>>,
last_active_ip: Option<IpAddr>, last_active_ip: Option<IpAddr>,
compat_sso_login_id: Option<Uuid>, compat_sso_login_id: Option<Uuid>,
@@ -142,6 +145,7 @@ impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSs
user_session_id: value.user_session_id.map(Ulid::from), user_session_id: value.user_session_id.map(Ulid::from),
created_at: value.created_at, created_at: value.created_at,
is_synapse_admin: value.is_synapse_admin, is_synapse_admin: value.is_synapse_admin,
user_agent: value.user_agent,
last_active_at: value.last_active_at, last_active_at: value.last_active_at,
last_active_ip: value.last_active_ip, last_active_ip: value.last_active_ip,
}; };
@@ -223,6 +227,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
, created_at , created_at
, finished_at , finished_at
, is_synapse_admin , is_synapse_admin
, user_agent
, last_active_at , last_active_at
, last_active_ip as "last_active_ip: IpAddr" , last_active_ip as "last_active_ip: IpAddr"
FROM compat_sessions FROM compat_sessions
@@ -290,6 +295,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
user_session_id: browser_session.map(|s| s.id), user_session_id: browser_session.map(|s| s.id),
created_at, created_at,
is_synapse_admin, is_synapse_admin,
user_agent: None,
last_active_at: None, last_active_at: None,
last_active_ip: None, last_active_ip: None,
}) })
@@ -377,6 +383,10 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)), Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin, CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
) )
.expr_as(
Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
CompatSessionAndSsoLoginLookupIden::UserAgent,
)
.expr_as( .expr_as(
Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)), Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
CompatSessionAndSsoLoginLookupIden::LastActiveAt, CompatSessionAndSsoLoginLookupIden::LastActiveAt,
@@ -552,4 +562,38 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
Ok(()) Ok(())
} }
#[tracing::instrument(
name = "db.compat_session.record_user_agent",
skip_all,
fields(
db.statement,
%compat_session.id,
),
err,
)]
async fn record_user_agent(
&mut self,
mut compat_session: CompatSession,
user_agent: String,
) -> Result<CompatSession, Self::Error> {
let res = sqlx::query!(
r#"
UPDATE compat_sessions
SET user_agent = $2
WHERE compat_session_id = $1
"#,
Uuid::from(compat_session.id),
user_agent,
)
.traced()
.execute(&mut *self.conn)
.await?;
compat_session.user_agent = Some(user_agent);
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(compat_session)
}
} }

View File

@@ -57,6 +57,7 @@ pub enum CompatSessions {
CreatedAt, CreatedAt,
FinishedAt, FinishedAt,
IsSynapseAdmin, IsSynapseAdmin,
UserAgent,
LastActiveAt, LastActiveAt,
LastActiveIp, LastActiveIp,
} }
@@ -86,6 +87,7 @@ pub enum OAuth2Sessions {
ScopeList, ScopeList,
CreatedAt, CreatedAt,
FinishedAt, FinishedAt,
UserAgent,
LastActiveAt, LastActiveAt,
LastActiveIp, LastActiveIp,
} }

View File

@@ -367,6 +367,24 @@ mod tests {
.unwrap(); .unwrap();
assert!(!refresh_token.is_valid()); assert!(!refresh_token.is_valid());
// Record the user-agent on the session
assert!(session.user_agent.is_none());
let session = repo
.oauth2_session()
.record_user_agent(session, "Mozilla/5.0".to_owned())
.await
.unwrap();
assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));
// Reload the session and check the user-agent
let session = repo
.oauth2_session()
.lookup(session.id)
.await
.unwrap()
.expect("session not found");
assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));
// Mark the session as finished // Mark the session as finished
assert!(session.is_valid()); assert!(session.is_valid());
let session = repo.oauth2_session().finish(&clock, session).await.unwrap(); let session = repo.oauth2_session().finish(&clock, session).await.unwrap();

View File

@@ -59,6 +59,7 @@ struct OAuthSessionLookup {
scope_list: Vec<String>, scope_list: Vec<String>,
created_at: DateTime<Utc>, created_at: DateTime<Utc>,
finished_at: Option<DateTime<Utc>>, finished_at: Option<DateTime<Utc>>,
user_agent: Option<String>,
last_active_at: Option<DateTime<Utc>>, last_active_at: Option<DateTime<Utc>>,
last_active_ip: Option<IpAddr>, last_active_ip: Option<IpAddr>,
} }
@@ -93,6 +94,7 @@ impl TryFrom<OAuthSessionLookup> for Session {
user_id: value.user_id.map(Ulid::from), user_id: value.user_id.map(Ulid::from),
user_session_id: value.user_session_id.map(Ulid::from), user_session_id: value.user_session_id.map(Ulid::from),
scope, scope,
user_agent: value.user_agent,
last_active_at: value.last_active_at, last_active_at: value.last_active_at,
last_active_ip: value.last_active_ip, last_active_ip: value.last_active_ip,
}) })
@@ -123,6 +125,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
, scope_list , scope_list
, created_at , created_at
, finished_at , finished_at
, user_agent
, last_active_at , last_active_at
, last_active_ip as "last_active_ip: IpAddr" , last_active_ip as "last_active_ip: IpAddr"
FROM oauth2_sessions FROM oauth2_sessions
@@ -197,6 +200,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
user_session_id: user_session.map(|s| s.id), user_session_id: user_session.map(|s| s.id),
client_id: client.id, client_id: client.id,
scope, scope,
user_agent: None,
last_active_at: None, last_active_at: None,
last_active_ip: None, last_active_ip: None,
}) })
@@ -281,6 +285,10 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)), Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
OAuthSessionLookupIden::FinishedAt, OAuthSessionLookupIden::FinishedAt,
) )
.expr_as(
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
OAuthSessionLookupIden::UserAgent,
)
.expr_as( .expr_as(
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)), Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
OAuthSessionLookupIden::LastActiveAt, OAuthSessionLookupIden::LastActiveAt,
@@ -427,4 +435,41 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
Ok(()) Ok(())
} }
#[tracing::instrument(
name = "db.oauth2_session.record_user_agent",
skip_all,
fields(
db.statement,
%session.id,
%session.scope,
client.id = %session.client_id,
session.user_agent = %user_agent,
),
err,
)]
async fn record_user_agent(
&mut self,
mut session: Session,
user_agent: String,
) -> Result<Session, Self::Error> {
let res = sqlx::query!(
r#"
UPDATE oauth2_sessions
SET user_agent = $2
WHERE oauth2_session_id = $1
"#,
Uuid::from(session.id),
user_agent,
)
.traced()
.execute(&mut *self.conn)
.await?;
session.user_agent = Some(user_agent);
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(session)
}
} }

View File

@@ -252,6 +252,22 @@ pub trait CompatSessionRepository: Send + Sync {
&mut self, &mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>, activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error>; ) -> Result<(), Self::Error>;
/// Record the user agent of a compat session
///
/// # Parameters
///
/// * `compat_session`: The compat session to record the user agent for
/// * `user_agent`: The user agent to record
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn record_user_agent(
&mut self,
compat_session: CompatSession,
user_agent: String,
) -> Result<CompatSession, Self::Error>;
} }
repository_impl!(CompatSessionRepository: repository_impl!(CompatSessionRepository:
@@ -285,4 +301,10 @@ repository_impl!(CompatSessionRepository:
&mut self, &mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>, activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error>; ) -> Result<(), Self::Error>;
async fn record_user_agent(
&mut self,
compat_session: CompatSession,
user_agent: String,
) -> Result<CompatSession, Self::Error>;
); );

View File

@@ -286,6 +286,18 @@ pub trait OAuth2SessionRepository: Send + Sync {
&mut self, &mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>, activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error>; ) -> Result<(), Self::Error>;
/// Record the user agent of a [`Session`]
///
/// # Parameters
///
/// * `session`: The [`Session`] to record the user agent for
/// * `user_agent`: The user agent to record
async fn record_user_agent(
&mut self,
session: Session,
user_agent: String,
) -> Result<Session, Self::Error>;
} }
repository_impl!(OAuth2SessionRepository: repository_impl!(OAuth2SessionRepository:
@@ -333,4 +345,10 @@ repository_impl!(OAuth2SessionRepository:
&mut self, &mut self,
activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>, activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
) -> Result<(), Self::Error>; ) -> Result<(), Self::Error>;
async fn record_user_agent(
&mut self,
session: Session,
user_agent: String,
) -> Result<Session, Self::Error>;
); );