diff --git a/crates/graphql/src/query/session.rs b/crates/graphql/src/query/session.rs index 4eee094a..1a4d6ed7 100644 --- a/crates/graphql/src/query/session.rs +++ b/crates/graphql/src/query/session.rs @@ -15,7 +15,9 @@ use async_graphql::{Context, Object, Union, ID}; use mas_data_model::Device; use mas_storage::{ - compat::CompatSessionRepository, oauth2::OAuth2SessionFilter, Pagination, RepositoryAccess, + compat::{CompatSessionFilter, CompatSessionRepository}, + oauth2::OAuth2SessionFilter, + Pagination, RepositoryAccess, }; use oauth2_types::scope::Scope; @@ -62,13 +64,27 @@ impl SessionQuery { }; // First, try to find a compat session - let compat_session = repo.compat_session().find_by_device(&user, &device).await?; - if let Some(compat_session) = compat_session { + let filter = CompatSessionFilter::new() + .for_user(&user) + .active_only() + .for_device(&device); + // We only want most recent session + let pagination = Pagination::last(1); + let compat_sessions = repo.compat_session().list(filter, pagination).await?; + + if compat_sessions.has_previous_page { + // XXX: should we bail out? + tracing::warn!( + "Found more than one active session with device {device} for user {user_id}" + ); + } + + if let Some((compat_session, sso_login)) = compat_sessions.edges.into_iter().next() { repo.cancel().await?; - return Ok(Some(Session::CompatSession(Box::new(CompatSession::new( - compat_session, - ))))); + return Ok(Some(Session::CompatSession(Box::new( + CompatSession::new(compat_session).with_loaded_sso_login(sso_login), + )))); } // Then, try to find an OAuth 2.0 session. Because we don't have any dedicated @@ -78,13 +94,11 @@ impl SessionQuery { .for_user(&user) .active_only() .with_scope(&scope); - // We only want most recent session - let pagination = Pagination::last(1); let sessions = repo.oauth2_session().list(filter, pagination).await?; - // It's technically possible to have multiple active OAuth 2.0 sessions. For - // now, we just log it if it is the case - if sessions.has_next_page { + // It's possible to have multiple active OAuth 2.0 sessions. For now, we just + // log it if it is the case + if sessions.has_previous_page { // XXX: should we bail out? tracing::warn!( "Found more than one active session with device {device} for user {user_id}" diff --git a/crates/storage-pg/.sqlx/query-662ff8972c0cbccb9ba45b1d724c7b6e87656beabe702603cfd4b5a92263b5ab.json b/crates/storage-pg/.sqlx/query-662ff8972c0cbccb9ba45b1d724c7b6e87656beabe702603cfd4b5a92263b5ab.json deleted file mode 100644 index e43a8937..00000000 --- a/crates/storage-pg/.sqlx/query-662ff8972c0cbccb9ba45b1d724c7b6e87656beabe702603cfd4b5a92263b5ab.json +++ /dev/null @@ -1,65 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT compat_session_id\n , device_id\n , user_id\n , created_at\n , finished_at\n , is_synapse_admin\n , last_active_at\n , last_active_ip as \"last_active_ip: IpAddr\"\n FROM compat_sessions\n WHERE user_id = $1\n AND device_id = $2\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "compat_session_id", - "type_info": "Uuid" - }, - { - "ordinal": 1, - "name": "device_id", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "user_id", - "type_info": "Uuid" - }, - { - "ordinal": 3, - "name": "created_at", - "type_info": "Timestamptz" - }, - { - "ordinal": 4, - "name": "finished_at", - "type_info": "Timestamptz" - }, - { - "ordinal": 5, - "name": "is_synapse_admin", - "type_info": "Bool" - }, - { - "ordinal": 6, - "name": "last_active_at", - "type_info": "Timestamptz" - }, - { - "ordinal": 7, - "name": "last_active_ip: IpAddr", - "type_info": "Inet" - } - ], - "parameters": { - "Left": [ - "Uuid", - "Text" - ] - }, - "nullable": [ - false, - false, - false, - false, - true, - false, - true, - true - ] - }, - "hash": "662ff8972c0cbccb9ba45b1d724c7b6e87656beabe702603cfd4b5a92263b5ab" -} diff --git a/crates/storage-pg/migrations/20240220141353_nonunique_compat_device_id.sql b/crates/storage-pg/migrations/20240220141353_nonunique_compat_device_id.sql new file mode 100644 index 00000000..3c65c673 --- /dev/null +++ b/crates/storage-pg/migrations/20240220141353_nonunique_compat_device_id.sql @@ -0,0 +1,17 @@ +-- Copyright 2024 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- Drops the unique constraint on the device_id column in the compat_sessions table +ALTER TABLE compat_sessions + DROP CONSTRAINT compat_sessions_device_id_unique; diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index 96d97955..2ee6b638 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -87,7 +87,7 @@ mod tests { let device_str = device.as_str().to_owned(); let session = repo .compat_session() - .add(&mut rng, &clock, &user, device, false) + .add(&mut rng, &clock, &user, device.clone(), false) .await .unwrap(); assert_eq!(session.user_id, user.id); @@ -130,12 +130,18 @@ mod tests { assert!(!session_lookup.is_finished()); // Look up the session by device - let session_lookup = repo + let list = repo .compat_session() - .find_by_device(&user, &session.device) + .list( + CompatSessionFilter::new() + .for_user(&user) + .for_device(&device), + pagination, + ) .await - .unwrap() - .expect("compat session not found"); + .unwrap(); + assert_eq!(list.edges.len(), 1); + let session_lookup = &list.edges[0].0; assert_eq!(session_lookup.id, session.id); assert_eq!(session_lookup.user_id, user.id); assert_eq!(session_lookup.device.as_str(), device_str); diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs index 1e8ebef8..35e2ed55 100644 --- a/crates/storage-pg/src/compat/session.rs +++ b/crates/storage-pg/src/compat/session.rs @@ -233,48 +233,6 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { Ok(Some(res.try_into()?)) } - #[tracing::instrument( - name = "db.compat_session.find_by_device", - skip_all, - fields( - db.statement, - %user.id, - %user.username, - compat_session.device.id = device.as_str(), - ), - )] - async fn find_by_device( - &mut self, - user: &User, - device: &Device, - ) -> Result, Self::Error> { - let res = sqlx::query_as!( - CompatSessionLookup, - r#" - SELECT compat_session_id - , device_id - , user_id - , created_at - , finished_at - , is_synapse_admin - , last_active_at - , last_active_ip as "last_active_ip: IpAddr" - FROM compat_sessions - WHERE user_id = $1 - AND device_id = $2 - "#, - Uuid::from(user.id), - device.as_str(), - ) - .traced() - .fetch_optional(&mut *self.conn) - .await?; - - let Some(res) = res else { return Ok(None) }; - - Ok(Some(res.try_into()?)) - } - #[tracing::instrument( name = "db.compat_session.add", skip_all, @@ -460,6 +418,9 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)).is_null() } })) + .and_where_option(filter.device().map(|device| { + Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str()) + })) .generate_pagination( (CompatSessions::Table, CompatSessions::CompatSessionId), pagination, diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index 7627e539..508cdc4e 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -68,6 +68,7 @@ pub struct CompatSessionFilter<'a> { user: Option<&'a User>, state: Option, auth_type: Option, + device: Option<&'a Device>, } impl<'a> CompatSessionFilter<'a> { @@ -90,6 +91,19 @@ impl<'a> CompatSessionFilter<'a> { self.user } + /// Set the device filter + #[must_use] + pub fn for_device(mut self, device: &'a Device) -> Self { + self.device = Some(device); + self + } + + /// Get the device filter + #[must_use] + pub fn device(&self) -> Option<&Device> { + self.device + } + /// Only return active compatibility sessions #[must_use] pub fn active_only(mut self) -> Self { @@ -151,24 +165,6 @@ pub trait CompatSessionRepository: Send + Sync { /// Returns [`Self::Error`] if the underlying repository fails async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - /// Find a compatibility session by its device ID - /// - /// Returns the compat session if it exists, `None` otherwise - /// - /// # Parameters - /// - /// * `user`: The user to lookup the compat session for - /// * `device`: The device ID of the compat session to lookup - /// - /// # Errors - /// - /// Returns [`Self::Error`] if the underlying repository fails - async fn find_by_device( - &mut self, - user: &User, - device: &Device, - ) -> Result, Self::Error>; - /// Start a new compat session /// /// Returns the newly created compat session @@ -259,12 +255,6 @@ pub trait CompatSessionRepository: Send + Sync { repository_impl!(CompatSessionRepository: async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; - async fn find_by_device( - &mut self, - user: &User, - device: &Device, - ) -> Result, Self::Error>; - async fn add( &mut self, rng: &mut (dyn RngCore + Send),