1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Better compatibility sessions pagination and filtering

This commit is contained in:
Quentin Gliech
2023-07-20 17:17:05 +02:00
parent b60121346f
commit 24b29498a7
13 changed files with 657 additions and 79 deletions

View File

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
use anyhow::Context as _; use anyhow::Context as _;
use async_graphql::{Context, Description, Object, ID}; use async_graphql::{Context, Description, Enum, Object, ID};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_storage::{compat::CompatSessionRepository, user::UserRepository}; use mas_storage::{compat::CompatSessionRepository, user::UserRepository};
use url::Url; use url::Url;
@ -29,6 +29,26 @@ pub struct CompatSession(
pub Option<mas_data_model::CompatSsoLogin>, pub Option<mas_data_model::CompatSsoLogin>,
); );
/// The state of a compatibility session.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum CompatSessionState {
/// The session is active.
Active,
/// The session is no longer active.
Finished,
}
/// The type of a compatibility session.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum CompatSessionType {
/// The session was created by a SSO login.
SsoLogin,
/// The session was created by an unknown method.
Unknown,
}
#[Object(use_type_description)] #[Object(use_type_description)]
impl CompatSession { impl CompatSession {
/// ID of the object. /// ID of the object.

View File

@ -18,7 +18,7 @@ use async_graphql::{
}; };
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_storage::{ use mas_storage::{
compat::CompatSsoLoginRepository, compat::{CompatSessionFilter, CompatSsoLoginRepository},
oauth2::OAuth2SessionRepository, oauth2::OAuth2SessionRepository,
upstream_oauth2::UpstreamOAuthLinkRepository, upstream_oauth2::UpstreamOAuthLinkRepository,
user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailRepository}, user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailRepository},
@ -30,7 +30,12 @@ use super::{
UpstreamOAuth2Link, UpstreamOAuth2Link,
}; };
use crate::{ use crate::{
model::{browser_sessions::BrowserSessionState, matrix::MatrixUser, CompatSession}, model::{
browser_sessions::BrowserSessionState,
compat_sessions::{CompatSessionState, CompatSessionType},
matrix::MatrixUser,
CompatSession,
},
state::ContextExt, state::ContextExt,
}; };
@ -133,17 +138,24 @@ impl User {
} }
/// Get the list of compatibility sessions, chronologically sorted /// Get the list of compatibility sessions, chronologically sorted
#[allow(clippy::too_many_arguments)]
async fn compat_sessions( async fn compat_sessions(
&self, &self,
ctx: &Context<'_>, ctx: &Context<'_>,
#[graphql(name = "state", desc = "List only sessions with the given state.")]
state_param: Option<CompatSessionState>,
#[graphql(name = "type", desc = "List only sessions with the given type.")]
type_param: Option<CompatSessionType>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")] #[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>, after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")] #[graphql(desc = "Returns the elements in the list that come before the cursor.")]
before: Option<String>, before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, CompatSession>, async_graphql::Error> { ) -> Result<Connection<Cursor, CompatSession, PreloadedTotalCount>, async_graphql::Error> {
let state = ctx.state(); let state = ctx.state();
let mut repo = state.repository().await?; let mut repo = state.repository().await?;
@ -161,14 +173,35 @@ impl User {
.transpose()?; .transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?; let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let page = repo // Build the query filter
.compat_session() let filter = CompatSessionFilter::new().for_user(&self.0);
.list_paginated(&self.0, pagination) let filter = match state_param {
.await?; Some(CompatSessionState::Active) => filter.active_only(),
Some(CompatSessionState::Finished) => filter.finished_only(),
None => filter,
};
let filter = match type_param {
Some(CompatSessionType::SsoLogin) => filter.sso_login_only(),
Some(CompatSessionType::Unknown) => filter.unknown_only(),
None => filter,
};
let page = repo.compat_session().list(filter, pagination).await?;
// Preload the total count if requested
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.compat_session().count(filter).await?)
} else {
None
};
repo.cancel().await?; repo.cancel().await?;
let mut connection = Connection::new(page.has_previous_page, page.has_next_page); let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection connection
.edges .edges
.extend(page.edges.into_iter().map(|(session, sso_login)| { .extend(page.edges.into_iter().map(|(session, sso_login)| {

View File

@ -32,7 +32,8 @@ mod tests {
use mas_storage::{ use mas_storage::{
clock::MockClock, clock::MockClock,
compat::{ compat::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter,
CompatSessionRepository,
}, },
user::UserRepository, user::UserRepository,
Clock, Pagination, Repository, RepositoryAccess, Clock, Pagination, Repository, RepositoryAccess,
@ -57,6 +58,30 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let all = CompatSessionFilter::new().for_user(&user);
let active = all.active_only();
let finished = all.finished_only();
let pagination = Pagination::first(10);
assert_eq!(repo.compat_session().count(all).await.unwrap(), 0);
assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);
let full_list = repo.compat_session().list(all, pagination).await.unwrap();
assert!(full_list.edges.is_empty());
let active_list = repo
.compat_session()
.list(active, pagination)
.await
.unwrap();
assert!(active_list.edges.is_empty());
let finished_list = repo
.compat_session()
.list(finished, pagination)
.await
.unwrap();
assert!(finished_list.edges.is_empty());
// Start a compat session for that user // Start a compat session for that user
let device = Device::generate(&mut rng); let device = Device::generate(&mut rng);
let device_str = device.as_str().to_owned(); let device_str = device.as_str().to_owned();
@ -70,6 +95,27 @@ mod tests {
assert!(session.is_valid()); assert!(session.is_valid());
assert!(!session.is_finished()); assert!(!session.is_finished());
assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
assert_eq!(repo.compat_session().count(active).await.unwrap(), 1);
assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);
let full_list = repo.compat_session().list(all, pagination).await.unwrap();
assert_eq!(full_list.edges.len(), 1);
assert_eq!(full_list.edges[0].0.id, session.id);
let active_list = repo
.compat_session()
.list(active, pagination)
.await
.unwrap();
assert_eq!(active_list.edges.len(), 1);
assert_eq!(active_list.edges[0].0.id, session.id);
let finished_list = repo
.compat_session()
.list(finished, pagination)
.await
.unwrap();
assert!(finished_list.edges.is_empty());
// Lookup the session and check it didn't change // Lookup the session and check it didn't change
let session_lookup = repo let session_lookup = repo
.compat_session() .compat_session()
@ -88,6 +134,27 @@ mod tests {
assert!(!session.is_valid()); assert!(!session.is_valid());
assert!(session.is_finished()); assert!(session.is_finished());
assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
assert_eq!(repo.compat_session().count(finished).await.unwrap(), 1);
let full_list = repo.compat_session().list(all, pagination).await.unwrap();
assert_eq!(full_list.edges.len(), 1);
assert_eq!(full_list.edges[0].0.id, session.id);
let active_list = repo
.compat_session()
.list(active, pagination)
.await
.unwrap();
assert!(active_list.edges.is_empty());
let finished_list = repo
.compat_session()
.list(finished, pagination)
.await
.unwrap();
assert_eq!(finished_list.edges.len(), 1);
assert_eq!(finished_list.edges[0].0.id, session.id);
// Reload the session and check again // Reload the session and check again
let session_lookup = repo let session_lookup = repo
.compat_session() .compat_session()
@ -97,6 +164,93 @@ mod tests {
.expect("compat session not found"); .expect("compat session not found");
assert!(!session_lookup.is_valid()); assert!(!session_lookup.is_valid());
assert!(session_lookup.is_finished()); assert!(session_lookup.is_finished());
// Now add another session, with an SSO login this time
let unknown_session = session;
// Start a new SSO login
let login = repo
.compat_sso_login()
.add(
&mut rng,
&clock,
"login-token".to_owned(),
"https://example.com/callback".parse().unwrap(),
)
.await
.unwrap();
assert!(login.is_pending());
// Start a compat session for that user
let device = Device::generate(&mut rng);
let sso_login_session = repo
.compat_session()
.add(&mut rng, &clock, &user, device, false)
.await
.unwrap();
// Associate the login with the session
let login = repo
.compat_sso_login()
.fulfill(&clock, login, &sso_login_session)
.await
.unwrap();
assert!(login.is_fulfilled());
// Now query the session list with both the unknown and SSO login session type
// filter
let all = CompatSessionFilter::new().for_user(&user);
let sso_login = all.sso_login_only();
let unknown = all.unknown_only();
assert_eq!(repo.compat_session().count(all).await.unwrap(), 2);
assert_eq!(repo.compat_session().count(sso_login).await.unwrap(), 1);
assert_eq!(repo.compat_session().count(unknown).await.unwrap(), 1);
let list = repo
.compat_session()
.list(sso_login, pagination)
.await
.unwrap();
assert_eq!(list.edges.len(), 1);
assert_eq!(list.edges[0].0.id, sso_login_session.id);
let list = repo
.compat_session()
.list(unknown, pagination)
.await
.unwrap();
assert_eq!(list.edges.len(), 1);
assert_eq!(list.edges[0].0.id, unknown_session.id);
// Check that combining the two filters works
// At this point, there is one active SSO login session and one finished unknown
// session
assert_eq!(
repo.compat_session()
.count(all.sso_login_only().active_only())
.await
.unwrap(),
1
);
assert_eq!(
repo.compat_session()
.count(all.sso_login_only().finished_only())
.await
.unwrap(),
0
);
assert_eq!(
repo.compat_session()
.count(all.unknown_only().active_only())
.await
.unwrap(),
0
);
assert_eq!(
repo.compat_session()
.count(all.unknown_only().finished_only())
.await
.unwrap(),
1
);
} }
#[sqlx::test(migrator = "crate::MIGRATOR")] #[sqlx::test(migrator = "crate::MIGRATOR")]

View File

@ -17,16 +17,23 @@ use chrono::{DateTime, Utc};
use mas_data_model::{ use mas_data_model::{
CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, User, CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, User,
}; };
use mas_storage::{compat::CompatSessionRepository, Clock, Page, Pagination}; use mas_storage::{
compat::{CompatSessionFilter, CompatSessionRepository},
Clock, Page, Pagination,
};
use rand::RngCore; use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder}; use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query};
use sqlx::PgConnection;
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, iden::{CompatSessions, CompatSsoLogins},
LookupResultExt, pagination::QueryBuilderExt,
sea_query_sqlx::map_values,
tracing::ExecuteExt,
DatabaseError, DatabaseInconsistencyError,
}; };
/// An implementation of [`CompatSessionRepository`] for a PostgreSQL connection /// An implementation of [`CompatSessionRepository`] for a PostgreSQL connection
@ -82,6 +89,7 @@ impl TryFrom<CompatSessionLookup> for CompatSession {
} }
#[derive(sqlx::FromRow)] #[derive(sqlx::FromRow)]
#[enum_def]
struct CompatSessionAndSsoLoginLookup { struct CompatSessionAndSsoLoginLookup {
compat_session_id: Uuid, compat_session_id: Uuid,
device_id: String, device_id: String,
@ -303,51 +311,162 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
} }
#[tracing::instrument( #[tracing::instrument(
name = "db.compat_session.list_paginated", name = "db.compat_session.list",
skip_all, skip_all,
fields( fields(
db.statement, db.statement,
%user.id,
), ),
err, err,
)] )]
async fn list_paginated( async fn list(
&mut self, &mut self,
user: &User, filter: CompatSessionFilter<'_>,
pagination: Pagination, pagination: Pagination,
) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> { ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
let mut query = QueryBuilder::new( let (sql, values) = sea_query::Query::select()
r#" .expr_as(
SELECT cs.compat_session_id Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
, cs.device_id CompatSessionAndSsoLoginLookupIden::CompatSessionId,
, cs.user_id )
, cs.created_at .expr_as(
, cs.finished_at Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
, cs.is_synapse_admin CompatSessionAndSsoLoginLookupIden::DeviceId,
, cl.compat_sso_login_id )
, cl.login_token as compat_sso_login_token .expr_as(
, cl.redirect_uri as compat_sso_login_redirect_uri Expr::col((CompatSessions::Table, CompatSessions::UserId)),
, cl.created_at as compat_sso_login_created_at CompatSessionAndSsoLoginLookupIden::UserId,
, cl.fulfilled_at as compat_sso_login_fulfilled_at )
, cl.exchanged_at as compat_sso_login_exchanged_at .expr_as(
Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
CompatSessionAndSsoLoginLookupIden::CreatedAt,
)
.expr_as(
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
CompatSessionAndSsoLoginLookupIden::FinishedAt,
)
.expr_as(
Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
)
.expr_as(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
CompatSessionAndSsoLoginLookupIden::CompatSsoLoginId,
)
.expr_as(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
CompatSessionAndSsoLoginLookupIden::CompatSsoLoginToken,
)
.expr_as(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
CompatSessionAndSsoLoginLookupIden::CompatSsoLoginRedirectUri,
)
.expr_as(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
CompatSessionAndSsoLoginLookupIden::CompatSsoLoginCreatedAt,
)
.expr_as(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
CompatSessionAndSsoLoginLookupIden::CompatSsoLoginFulfilledAt,
)
.expr_as(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
CompatSessionAndSsoLoginLookupIden::CompatSsoLoginExchangedAt,
)
.from(CompatSessions::Table)
.left_join(
CompatSsoLogins::Table,
Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
.equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
)
.and_where_option(filter.user().map(|user| {
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
} else {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.auth_type().map(|auth_type| {
if auth_type.is_sso_login() {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
.is_not_null()
} else {
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)).is_null()
}
}))
.generate_pagination(
(CompatSessions::Table, CompatSessions::CompatSessionId).into_column_ref(),
pagination,
)
.build(PostgresQueryBuilder);
FROM compat_sessions cs let arguments = map_values(values);
LEFT JOIN compat_sso_logins cl USING (compat_session_id)
"#,
);
query let edges: Vec<CompatSessionAndSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
.push(" WHERE cs.user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("cs.compat_session_id", pagination);
let edges: Vec<CompatSessionAndSsoLoginLookup> = query
.build_query_as()
.traced() .traced()
.fetch_all(&mut *self.conn) .fetch_all(&mut *self.conn)
.await?; .await?;
let page = pagination.process(edges).try_map(TryFrom::try_from)?; let page = pagination.process(edges).try_map(TryFrom::try_from)?;
Ok(page) Ok(page)
} }
#[tracing::instrument(
name = "db.compat_session.count",
skip_all,
fields(
db.statement,
),
err,
)]
async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error> {
let (sql, values) = sea_query::Query::select()
.expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
.from(CompatSessions::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
} else {
Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.auth_type().map(|auth_type| {
// Check if it is an SSO login by checking if there is a SSO login for the
// session.
let exists = Expr::exists(
Query::select()
.expr(Expr::val(1))
.from(CompatSsoLogins::Table)
.and_where(
Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId))
.equals((CompatSessions::Table, CompatSessions::CompatSessionId)),
)
.take(),
);
if auth_type.is_sso_login() {
exists
} else {
exists.not()
}
}))
.build(PostgresQueryBuilder);
let arguments = map_values(values);
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
.traced()
.fetch_one(&mut *self.conn)
.await?;
count
.try_into()
.map_err(DatabaseError::to_invalid_operation)
}
} }

View File

@ -0,0 +1,55 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Table and column identifiers used by [`sea_query`]
#[derive(sea_query::Iden)]
pub enum UserSessions {
Table,
UserSessionId,
CreatedAt,
FinishedAt,
UserId,
}
#[derive(sea_query::Iden)]
pub enum Users {
Table,
UserId,
Username,
PrimaryUserEmailId,
}
#[derive(sea_query::Iden)]
pub enum CompatSessions {
Table,
CompatSessionId,
UserId,
DeviceId,
CreatedAt,
FinishedAt,
IsSynapseAdmin,
}
#[derive(sea_query::Iden)]
pub enum CompatSsoLogins {
Table,
CompatSsoLoginId,
RedirectUri,
LoginToken,
CompatSessionId,
CreatedAt,
FulfilledAt,
ExchangedAt,
}

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//! A [`sea_query::Values`] to [`sqlx::Arguments`] mapper
use sea_query::Value; use sea_query::Value;
use sqlx::Arguments; use sqlx::Arguments;

View File

@ -23,8 +23,11 @@ use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
pagination::QueryBuilderExt, sea_query_sqlx::map_values, tracing::ExecuteExt, DatabaseError, iden::{UserSessions, Users},
DatabaseInconsistencyError, LookupResultExt, pagination::QueryBuilderExt,
sea_query_sqlx::map_values,
tracing::ExecuteExt,
DatabaseError, DatabaseInconsistencyError,
}; };
/// An implementation of [`BrowserSessionRepository`] for a PostgreSQL /// An implementation of [`BrowserSessionRepository`] for a PostgreSQL
@ -52,23 +55,6 @@ struct SessionLookup {
user_primary_user_email_id: Option<Uuid>, user_primary_user_email_id: Option<Uuid>,
} }
#[derive(sea_query::Iden)]
enum UserSessions {
Table,
UserSessionId,
CreatedAt,
FinishedAt,
UserId,
}
#[derive(sea_query::Iden)]
enum Users {
Table,
UserId,
Username,
PrimaryUserEmailId,
}
impl TryFrom<SessionLookup> for BrowserSession { impl TryFrom<SessionLookup> for BrowserSession {
type Error = DatabaseInconsistencyError; type Error = DatabaseInconsistencyError;

View File

@ -363,11 +363,13 @@ async fn test_user_session(pool: PgPool) {
.await .await
.unwrap(); .unwrap();
let filter = BrowserSessionFilter::default() let all = BrowserSessionFilter::default().for_user(&user);
.for_user(&user) let active = all.active_only();
.active_only(); let finished = all.finished_only();
assert_eq!(repo.browser_session().count(filter).await.unwrap(), 0); assert_eq!(repo.browser_session().count(all).await.unwrap(), 0);
assert_eq!(repo.browser_session().count(active).await.unwrap(), 0);
assert_eq!(repo.browser_session().count(finished).await.unwrap(), 0);
let session = repo let session = repo
.browser_session() .browser_session()
@ -377,12 +379,14 @@ async fn test_user_session(pool: PgPool) {
assert_eq!(session.user.id, user.id); assert_eq!(session.user.id, user.id);
assert!(session.finished_at.is_none()); assert!(session.finished_at.is_none());
assert_eq!(repo.browser_session().count(filter).await.unwrap(), 1); assert_eq!(repo.browser_session().count(all).await.unwrap(), 1);
assert_eq!(repo.browser_session().count(active).await.unwrap(), 1);
assert_eq!(repo.browser_session().count(finished).await.unwrap(), 0);
// The session should be in the list of active sessions // The session should be in the list of active sessions
let session_list = repo let session_list = repo
.browser_session() .browser_session()
.list(filter, Pagination::first(10)) .list(active, Pagination::first(10))
.await .await
.unwrap(); .unwrap();
assert!(!session_list.has_next_page); assert!(!session_list.has_next_page);
@ -406,13 +410,15 @@ async fn test_user_session(pool: PgPool) {
.await .await
.unwrap(); .unwrap();
// The active session counter is back to 0 // The active session counter should be 0, and the finished one should be 1
assert_eq!(repo.browser_session().count(filter).await.unwrap(), 0); assert_eq!(repo.browser_session().count(all).await.unwrap(), 1);
assert_eq!(repo.browser_session().count(active).await.unwrap(), 0);
assert_eq!(repo.browser_session().count(finished).await.unwrap(), 1);
// The session should not be in the list of active sessions anymore // The session should not be in the list of active sessions anymore
let session_list = repo let session_list = repo
.browser_session() .browser_session()
.list(filter, Pagination::first(10)) .list(active, Pagination::first(10))
.await .await
.unwrap(); .unwrap();
assert!(!session_list.has_next_page); assert!(!session_list.has_next_page);

View File

@ -20,6 +20,8 @@ mod session;
mod sso_login; mod sso_login;
pub use self::{ pub use self::{
access_token::CompatAccessTokenRepository, refresh_token::CompatRefreshTokenRepository, access_token::CompatAccessTokenRepository,
session::CompatSessionRepository, sso_login::CompatSsoLoginRepository, refresh_token::CompatRefreshTokenRepository,
session::{CompatSessionFilter, CompatSessionRepository},
sso_login::CompatSsoLoginRepository,
}; };

View File

@ -19,6 +19,115 @@ use ulid::Ulid;
use crate::{repository_impl, Clock, Page, Pagination}; use crate::{repository_impl, Clock, Page, Pagination};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompatSessionState {
Active,
Finished,
}
impl CompatSessionState {
/// Returns [`true`] if we're looking for active sessions
#[must_use]
pub fn is_active(self) -> bool {
matches!(self, Self::Active)
}
/// Returns [`true`] if we're looking for finished sessions
#[must_use]
pub fn is_finished(self) -> bool {
matches!(self, Self::Finished)
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompatSessionType {
SsoLogin,
Unknown,
}
impl CompatSessionType {
/// Returns [`true`] if we're looking for SSO logins
#[must_use]
pub fn is_sso_login(self) -> bool {
matches!(self, Self::SsoLogin)
}
/// Returns [`true`] if we're looking for unknown sessions
#[must_use]
pub fn is_unknown(self) -> bool {
matches!(self, Self::Unknown)
}
}
/// Filter parameters for listing browser sessions
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct CompatSessionFilter<'a> {
user: Option<&'a User>,
state: Option<CompatSessionState>,
auth_type: Option<CompatSessionType>,
}
impl<'a> CompatSessionFilter<'a> {
/// Create a new [`CompatSessionFilter`] with default values
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Set the user who owns the compatibility sessions
#[must_use]
pub fn for_user(mut self, user: &'a User) -> Self {
self.user = Some(user);
self
}
/// Get the user filter
#[must_use]
pub fn user(&self) -> Option<&User> {
self.user
}
/// Only return active compatibility sessions
#[must_use]
pub fn active_only(mut self) -> Self {
self.state = Some(CompatSessionState::Active);
self
}
/// Only return finished compatibility sessions
#[must_use]
pub fn finished_only(mut self) -> Self {
self.state = Some(CompatSessionState::Finished);
self
}
/// Get the state filter
#[must_use]
pub fn state(&self) -> Option<CompatSessionState> {
self.state
}
/// Only return SSO login compatibility sessions
#[must_use]
pub fn sso_login_only(mut self) -> Self {
self.auth_type = Some(CompatSessionType::SsoLogin);
self
}
/// Only return unknown compatibility sessions
#[must_use]
pub fn unknown_only(mut self) -> Self {
self.auth_type = Some(CompatSessionType::Unknown);
self
}
/// Get the auth type filter
#[must_use]
pub fn auth_type(&self) -> Option<CompatSessionType> {
self.auth_type
}
}
/// A [`CompatSessionRepository`] helps interacting with /// A [`CompatSessionRepository`] helps interacting with
/// [`CompatSessionRepository`] saved in the storage backend /// [`CompatSessionRepository`] saved in the storage backend
#[async_trait] #[async_trait]
@ -81,23 +190,34 @@ pub trait CompatSessionRepository: Send + Sync {
compat_session: CompatSession, compat_session: CompatSession,
) -> Result<CompatSession, Self::Error>; ) -> Result<CompatSession, Self::Error>;
/// Get a paginated list of compat sessions for a user /// List [`CompatSession`] with the given filter and pagination
/// ///
/// Returns a page of compat sessions, with the associated SSO logins if any /// Returns a page of compat sessions, with the associated SSO logins if any
/// ///
/// # Parameters /// # Parameters
/// ///
/// * `user`: The user to get the compat sessions for /// * `filter`: The filter to apply
/// * `pagination`: The pagination parameters /// * `pagination`: The pagination parameters
/// ///
/// # Errors /// # Errors
/// ///
/// Returns [`Self::Error`] if the underlying repository fails /// Returns [`Self::Error`] if the underlying repository fails
async fn list_paginated( async fn list(
&mut self, &mut self,
user: &User, filter: CompatSessionFilter<'_>,
pagination: Pagination, pagination: Pagination,
) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error>; ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error>;
/// Count the number of [`CompatSession`] with the given filter
///
/// # Parameters
///
/// * `filter`: The filter to apply
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error>;
} }
repository_impl!(CompatSessionRepository: repository_impl!(CompatSessionRepository:
@ -118,9 +238,11 @@ repository_impl!(CompatSessionRepository:
compat_session: CompatSession, compat_session: CompatSession,
) -> Result<CompatSession, Self::Error>; ) -> Result<CompatSession, Self::Error>;
async fn list_paginated( async fn list(
&mut self, &mut self,
user: &User, filter: CompatSessionFilter<'_>,
pagination: Pagination, pagination: Pagination,
) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error>; ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error>;
async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error>;
); );

View File

@ -180,6 +180,10 @@ type CompatSessionConnection {
A list of nodes. A list of nodes.
""" """
nodes: [CompatSession!]! nodes: [CompatSession!]!
"""
Identifies the total count of items in the connection.
"""
totalCount: Int!
} }
""" """
@ -196,6 +200,34 @@ type CompatSessionEdge {
cursor: String! cursor: String!
} }
"""
The state of a compatibility session.
"""
enum CompatSessionState {
"""
The session is active.
"""
ACTIVE
"""
The session is no longer active.
"""
FINISHED
}
"""
The type of a compatibility session.
"""
enum CompatSessionType {
"""
The session was created by a SSO login.
"""
SSO_LOGIN
"""
The session was created by an unknown method.
"""
UNKNOWN
}
""" """
A compat SSO login represents a login done through the legacy Matrix login A compat SSO login represents a login done through the legacy Matrix login
API, via the `m.login.sso` login method. API, via the `m.login.sso` login method.
@ -880,6 +912,8 @@ type User implements Node {
Get the list of compatibility sessions, chronologically sorted Get the list of compatibility sessions, chronologically sorted
""" """
compatSessions( compatSessions(
state: CompatSessionState
type: CompatSessionType
after: String after: String
before: String before: String
first: Int first: Int

View File

@ -156,6 +156,8 @@ export type CompatSessionConnection = {
nodes: Array<CompatSession>; nodes: Array<CompatSession>;
/** Information to aid in pagination. */ /** Information to aid in pagination. */
pageInfo: PageInfo; pageInfo: PageInfo;
/** Identifies the total count of items in the connection. */
totalCount: Scalars["Int"]["output"];
}; };
/** An edge in a connection. */ /** An edge in a connection. */
@ -167,6 +169,22 @@ export type CompatSessionEdge = {
node: CompatSession; node: CompatSession;
}; };
/** The state of a compatibility session. */
export enum CompatSessionState {
/** The session is active. */
Active = "ACTIVE",
/** The session is no longer active. */
Finished = "FINISHED",
}
/** The type of a compatibility session. */
export enum CompatSessionType {
/** The session was created by a SSO login. */
SsoLogin = "SSO_LOGIN",
/** The session was created by an unknown method. */
Unknown = "UNKNOWN",
}
/** /**
* A compat SSO login represents a login done through the legacy Matrix login * A compat SSO login represents a login done through the legacy Matrix login
* API, via the `m.login.sso` login method. * API, via the `m.login.sso` login method.
@ -689,6 +707,8 @@ export type UserCompatSessionsArgs = {
before?: InputMaybe<Scalars["String"]["input"]>; before?: InputMaybe<Scalars["String"]["input"]>;
first?: InputMaybe<Scalars["Int"]["input"]>; first?: InputMaybe<Scalars["Int"]["input"]>;
last?: InputMaybe<Scalars["Int"]["input"]>; last?: InputMaybe<Scalars["Int"]["input"]>;
state?: InputMaybe<CompatSessionState>;
type?: InputMaybe<CompatSessionType>;
}; };
/** A user is an individual's account. */ /** A user is an individual's account. */

View File

@ -391,6 +391,17 @@ export default {
}, },
args: [], args: [],
}, },
{
name: "totalCount",
type: {
kind: "NON_NULL",
ofType: {
kind: "SCALAR",
name: "Any",
},
},
args: [],
},
], ],
interfaces: [], interfaces: [],
}, },
@ -2029,6 +2040,20 @@ export default {
name: "Any", name: "Any",
}, },
}, },
{
name: "state",
type: {
kind: "SCALAR",
name: "Any",
},
},
{
name: "type",
type: {
kind: "SCALAR",
name: "Any",
},
},
], ],
}, },
{ {