diff --git a/crates/storage-pg/src/app_session.rs b/crates/storage-pg/src/app_session.rs new file mode 100644 index 00000000..ae10495e --- /dev/null +++ b/crates/storage-pg/src/app_session.rs @@ -0,0 +1,629 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! A module containing PostgreSQL implementation of repositories for sessions + +use async_trait::async_trait; +use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState}; +use mas_storage::{ + app_session::{AppSession, AppSessionFilter, AppSessionRepository}, + Page, Pagination, +}; +use oauth2_types::scope::{Scope, ScopeToken}; +use sea_query::{ + Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType, +}; +use sea_query_binder::SqlxBinder; +use sqlx::PgConnection; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{ + errors::DatabaseInconsistencyError, + iden::{CompatSessions, OAuth2Sessions}, + pagination::QueryBuilderExt, + DatabaseError, ExecuteExt, +}; + +/// An implementation of [`AppSessionRepository`] for a PostgreSQL connection +pub struct PgAppSessionRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgAppSessionRepository<'c> { + /// Create a new [`PgAppSessionRepository`] from an active PostgreSQL + /// connection + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + +mod priv_ { + // The enum_def macro generates a public enum, which we don't want, because it + // triggers the missing docs warning + + use std::net::IpAddr; + + use chrono::{DateTime, Utc}; + use sea_query::enum_def; + use uuid::Uuid; + + #[derive(sqlx::FromRow)] + #[enum_def] + pub(super) struct AppSessionLookup { + pub(super) cursor: Uuid, + pub(super) compat_session_id: Option, + pub(super) oauth2_session_id: Option, + pub(super) oauth2_client_id: Option, + pub(super) user_session_id: Option, + pub(super) user_id: Option, + pub(super) scope_list: Option>, + pub(super) device_id: Option, + pub(super) created_at: DateTime, + pub(super) finished_at: Option>, + pub(super) is_synapse_admin: Option, + pub(super) last_active_at: Option>, + pub(super) last_active_ip: Option, + } +} + +use priv_::{AppSessionLookup, AppSessionLookupIden}; + +impl TryFrom for AppSession { + type Error = DatabaseError; + + fn try_from(value: AppSessionLookup) -> Result { + // This is annoying to do, but we have to match on all the fields to determine + // whether it's a compat session or an oauth2 session + let AppSessionLookup { + cursor, + compat_session_id, + oauth2_session_id, + oauth2_client_id, + user_session_id, + user_id, + scope_list, + device_id, + created_at, + finished_at, + is_synapse_admin, + last_active_at, + last_active_ip, + } = value; + + match ( + compat_session_id, + oauth2_session_id, + oauth2_client_id, + user_session_id, + user_id, + scope_list, + device_id, + is_synapse_admin, + ) { + ( + Some(compat_session_id), + None, + None, + None, + Some(user_id), + None, + Some(device_id), + Some(is_synapse_admin), + ) => { + let id = compat_session_id.into(); + let device = Device::try_from(device_id).map_err(|e| { + DatabaseInconsistencyError::on("compat_sessions") + .column("device_id") + .row(id) + .source(e) + })?; + + let state = match finished_at { + None => CompatSessionState::Valid, + Some(finished_at) => CompatSessionState::Finished { finished_at }, + }; + + let session = CompatSession { + id, + state, + user_id: user_id.into(), + device, + created_at, + is_synapse_admin, + last_active_at, + last_active_ip, + }; + + Ok(AppSession::Compat(Box::new(session))) + } + + ( + None, + Some(oauth2_session_id), + Some(oauth2_client_id), + user_session_id, + user_id, + Some(scope_list), + None, + None, + ) => { + let id = oauth2_session_id.into(); + let scope: Result = + scope_list.iter().map(|s| s.parse::()).collect(); + let scope = scope.map_err(|e| { + DatabaseInconsistencyError::on("oauth2_sessions") + .column("scope") + .row(id) + .source(e) + })?; + + let state = match value.finished_at { + None => SessionState::Valid, + Some(finished_at) => SessionState::Finished { finished_at }, + }; + + let session = Session { + id, + state, + created_at, + client_id: oauth2_client_id.into(), + user_id: user_id.map(Ulid::from), + user_session_id: user_session_id.map(Ulid::from), + scope, + last_active_at, + last_active_ip, + }; + + Ok(AppSession::OAuth2(Box::new(session))) + } + + _ => Err(DatabaseInconsistencyError::on("sessions") + .row(cursor.into()) + .into()), + } + } +} + +#[async_trait] +impl<'c> AppSessionRepository for PgAppSessionRepository<'c> { + type Error = DatabaseError; + + #[allow(clippy::too_many_lines)] + async fn list( + &mut self, + filter: AppSessionFilter<'_>, + pagination: Pagination, + ) -> Result, Self::Error> { + let mut oauth2_session_select = Query::select() + .expr_as( + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)), + AppSessionLookupIden::Cursor, + ) + .expr_as(Expr::cust("NULL"), AppSessionLookupIden::CompatSessionId) + .expr_as( + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)), + AppSessionLookupIden::Oauth2SessionId, + ) + .expr_as( + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)), + AppSessionLookupIden::Oauth2ClientId, + ) + .expr_as( + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)), + AppSessionLookupIden::UserSessionId, + ) + .expr_as( + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)), + AppSessionLookupIden::UserId, + ) + .expr_as( + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)), + AppSessionLookupIden::ScopeList, + ) + .expr_as(Expr::cust("NULL"), AppSessionLookupIden::DeviceId) + .expr_as( + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)), + AppSessionLookupIden::CreatedAt, + ) + .expr_as( + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)), + AppSessionLookupIden::FinishedAt, + ) + .expr_as(Expr::cust("NULL"), AppSessionLookupIden::IsSynapseAdmin) + .expr_as( + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)), + AppSessionLookupIden::LastActiveAt, + ) + .expr_as( + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)), + AppSessionLookupIden::LastActiveIp, + ) + .from(OAuth2Sessions::Table) + .and_where_option(filter.user().map(|user| { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id)) + })) + .and_where_option(filter.state().map(|state| { + if state.is_active() { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null() + } else { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() + } + })) + .clone(); + + let compat_session_select = Query::select() + .expr_as( + Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)), + AppSessionLookupIden::Cursor, + ) + .expr_as( + Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)), + AppSessionLookupIden::CompatSessionId, + ) + .expr_as(Expr::cust("NULL"), AppSessionLookupIden::Oauth2SessionId) + .expr_as(Expr::cust("NULL"), AppSessionLookupIden::Oauth2ClientId) + .expr_as(Expr::cust("NULL"), AppSessionLookupIden::UserSessionId) + .expr_as( + Expr::col((CompatSessions::Table, CompatSessions::UserId)), + AppSessionLookupIden::UserId, + ) + .expr_as(Expr::cust("NULL"), AppSessionLookupIden::ScopeList) + .expr_as( + Expr::col((CompatSessions::Table, CompatSessions::DeviceId)), + AppSessionLookupIden::DeviceId, + ) + .expr_as( + Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)), + AppSessionLookupIden::CreatedAt, + ) + .expr_as( + Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)), + AppSessionLookupIden::FinishedAt, + ) + .expr_as( + Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)), + AppSessionLookupIden::IsSynapseAdmin, + ) + .expr_as( + Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)), + AppSessionLookupIden::LastActiveAt, + ) + .expr_as( + Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)), + AppSessionLookupIden::LastActiveIp, + ) + .from(CompatSessions::Table) + .and_where_option(filter.user().map(|user| { + Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id)) + })) + .and_where_option(filter.state().map(|state| { + if state.is_active() { + Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null() + } else { + Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null() + } + })) + .clone(); + + let common_table_expression = CommonTableExpression::new() + .query( + oauth2_session_select + .union(UnionType::All, compat_session_select) + .clone(), + ) + .table_name(Alias::new("sessions")) + .clone(); + + let with_clause = Query::with().cte(common_table_expression).clone(); + + let select = Query::select() + .column(ColumnRef::Asterisk) + .from(Alias::new("sessions")) + .generate_pagination(AppSessionLookupIden::Cursor, pagination) + .clone(); + + let (sql, arguments) = with_clause.query(select).build_sqlx(PostgresQueryBuilder); + + let edges: Vec = sqlx::query_as_with(&sql, arguments) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let page = pagination.process(edges).try_map(TryFrom::try_from)?; + + Ok(page) + } + + async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result { + let mut oauth2_session_select = Query::select() + .expr(Expr::cust("1")) + .from(OAuth2Sessions::Table) + .and_where_option(filter.user().map(|user| { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id)) + })) + .and_where_option(filter.state().map(|state| { + if state.is_active() { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null() + } else { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() + } + })) + .clone(); + + let compat_session_select = Query::select() + .expr(Expr::cust("1")) + .from(CompatSessions::Table) + .and_where_option(filter.user().map(|user| { + Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id)) + })) + .and_where_option(filter.state().map(|state| { + if state.is_active() { + Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null() + } else { + Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null() + } + })) + .clone(); + + let common_table_expression = CommonTableExpression::new() + .query( + oauth2_session_select + .union(UnionType::All, compat_session_select) + .clone(), + ) + .table_name(Alias::new("sessions")) + .clone(); + + let with_clause = Query::with().cte(common_table_expression).clone(); + + let select = Query::select() + .expr(Expr::cust("COUNT(*)")) + .from(Alias::new("sessions")) + .clone(); + + let (sql, arguments) = with_clause.query(select).build_sqlx(PostgresQueryBuilder); + + let count: i64 = sqlx::query_scalar_with(&sql, arguments) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + count + .try_into() + .map_err(DatabaseError::to_invalid_operation) + } +} + +#[cfg(test)] +mod tests { + use chrono::Duration; + use mas_data_model::Device; + use mas_storage::{ + app_session::{AppSession, AppSessionFilter}, + clock::MockClock, + oauth2::OAuth2SessionRepository, + Pagination, RepositoryAccess, + }; + use oauth2_types::{ + requests::GrantType, + scope::{Scope, OPENID}, + }; + use rand::SeedableRng; + use rand_chacha::ChaChaRng; + use sqlx::PgPool; + + use crate::PgRepository; + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_app_repo(pool: PgPool) { + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + + // Create a user + let user = repo + .user() + .add(&mut rng, &clock, "john".to_owned()) + .await + .unwrap(); + + let all = AppSessionFilter::new().for_user(&user); + let active = all.active_only(); + let finished = all.finished_only(); + let pagination = Pagination::first(10); + + assert_eq!(repo.app_session().count(all).await.unwrap(), 0); + assert_eq!(repo.app_session().count(active).await.unwrap(), 0); + assert_eq!(repo.app_session().count(finished).await.unwrap(), 0); + + let full_list = repo.app_session().list(all, pagination).await.unwrap(); + assert!(full_list.edges.is_empty()); + let active_list = repo.app_session().list(active, pagination).await.unwrap(); + assert!(active_list.edges.is_empty()); + let finished_list = repo.app_session().list(finished, pagination).await.unwrap(); + assert!(finished_list.edges.is_empty()); + + // Start a compat session for that user + let device = Device::generate(&mut rng); + let compat_session = repo + .compat_session() + .add(&mut rng, &clock, &user, device, false) + .await + .unwrap(); + + assert_eq!(repo.app_session().count(all).await.unwrap(), 1); + assert_eq!(repo.app_session().count(active).await.unwrap(), 1); + assert_eq!(repo.app_session().count(finished).await.unwrap(), 0); + + let full_list = repo.app_session().list(all, pagination).await.unwrap(); + assert_eq!(full_list.edges.len(), 1); + assert_eq!( + full_list.edges[0], + AppSession::Compat(Box::new(compat_session.clone())) + ); + let active_list = repo.app_session().list(active, pagination).await.unwrap(); + assert_eq!(active_list.edges.len(), 1); + assert_eq!( + active_list.edges[0], + AppSession::Compat(Box::new(compat_session.clone())) + ); + let finished_list = repo.app_session().list(finished, pagination).await.unwrap(); + assert!(finished_list.edges.is_empty()); + + // Finish the session + let compat_session = repo + .compat_session() + .finish(&clock, compat_session) + .await + .unwrap(); + + assert_eq!(repo.app_session().count(all).await.unwrap(), 1); + assert_eq!(repo.app_session().count(active).await.unwrap(), 0); + assert_eq!(repo.app_session().count(finished).await.unwrap(), 1); + + let full_list = repo.app_session().list(all, pagination).await.unwrap(); + assert_eq!(full_list.edges.len(), 1); + assert_eq!( + full_list.edges[0], + AppSession::Compat(Box::new(compat_session.clone())) + ); + let active_list = repo.app_session().list(active, pagination).await.unwrap(); + assert!(active_list.edges.is_empty()); + let finished_list = repo.app_session().list(finished, pagination).await.unwrap(); + assert_eq!(finished_list.edges.len(), 1); + assert_eq!( + finished_list.edges[0], + AppSession::Compat(Box::new(compat_session.clone())) + ); + + // Start an OAuth2 session + let client = repo + .oauth2_client() + .add( + &mut rng, + &clock, + vec!["https://example.com/redirect".parse().unwrap()], + None, + None, + vec![GrantType::AuthorizationCode], + Vec::new(), // TODO: contacts are not yet saved + // vec!["contact@example.com".to_owned()], + Some("First client".to_owned()), + Some("https://example.com/logo.png".parse().unwrap()), + Some("https://example.com/".parse().unwrap()), + Some("https://example.com/policy".parse().unwrap()), + Some("https://example.com/tos".parse().unwrap()), + Some("https://example.com/jwks.json".parse().unwrap()), + None, + None, + None, + None, + None, + Some("https://example.com/login".parse().unwrap()), + ) + .await + .unwrap(); + + let scope = Scope::from_iter([OPENID]); + + // We're moving the clock forward by 1 minute between each session to ensure + // we're getting consistent ordering in lists. + clock.advance(Duration::minutes(1)); + + let oauth_session = repo + .oauth2_session() + .add(&mut rng, &clock, &client, Some(&user), None, scope) + .await + .unwrap(); + + assert_eq!(repo.app_session().count(all).await.unwrap(), 2); + assert_eq!(repo.app_session().count(active).await.unwrap(), 1); + assert_eq!(repo.app_session().count(finished).await.unwrap(), 1); + + let full_list = repo.app_session().list(all, pagination).await.unwrap(); + assert_eq!(full_list.edges.len(), 2); + assert_eq!( + full_list.edges[0], + AppSession::Compat(Box::new(compat_session.clone())) + ); + assert_eq!( + full_list.edges[1], + AppSession::OAuth2(Box::new(oauth_session.clone())) + ); + + let active_list = repo.app_session().list(active, pagination).await.unwrap(); + assert_eq!(active_list.edges.len(), 1); + assert_eq!( + active_list.edges[0], + AppSession::OAuth2(Box::new(oauth_session.clone())) + ); + + let finished_list = repo.app_session().list(finished, pagination).await.unwrap(); + assert_eq!(finished_list.edges.len(), 1); + assert_eq!( + finished_list.edges[0], + AppSession::Compat(Box::new(compat_session.clone())) + ); + + // Finish the session + let oauth_session = repo + .oauth2_session() + .finish(&clock, oauth_session) + .await + .unwrap(); + + assert_eq!(repo.app_session().count(all).await.unwrap(), 2); + assert_eq!(repo.app_session().count(active).await.unwrap(), 0); + assert_eq!(repo.app_session().count(finished).await.unwrap(), 2); + + let full_list = repo.app_session().list(all, pagination).await.unwrap(); + assert_eq!(full_list.edges.len(), 2); + assert_eq!( + full_list.edges[0], + AppSession::Compat(Box::new(compat_session.clone())) + ); + assert_eq!( + full_list.edges[1], + AppSession::OAuth2(Box::new(oauth_session.clone())) + ); + + let active_list = repo.app_session().list(active, pagination).await.unwrap(); + assert!(active_list.edges.is_empty()); + + let finished_list = repo.app_session().list(finished, pagination).await.unwrap(); + assert_eq!(finished_list.edges.len(), 2); + assert_eq!( + finished_list.edges[0], + AppSession::Compat(Box::new(compat_session.clone())) + ); + assert_eq!( + full_list.edges[1], + AppSession::OAuth2(Box::new(oauth_session.clone())) + ); + + // Create a second user + let user2 = repo + .user() + .add(&mut rng, &clock, "alice".to_owned()) + .await + .unwrap(); + + // If we list/count for this user, we should get nothing + let filter = AppSessionFilter::new().for_user(&user2); + assert_eq!(repo.app_session().count(filter).await.unwrap(), 0); + let list = repo.app_session().list(filter, pagination).await.unwrap(); + assert!(list.edges.is_empty()); + } +} diff --git a/crates/storage-pg/src/lib.rs b/crates/storage-pg/src/lib.rs index 151d8d2d..825fb2da 100644 --- a/crates/storage-pg/src/lib.rs +++ b/crates/storage-pg/src/lib.rs @@ -177,6 +177,7 @@ use sqlx::migrate::Migrator; +pub mod app_session; pub mod compat; pub mod job; pub mod oauth2; diff --git a/crates/storage-pg/src/repository.rs b/crates/storage-pg/src/repository.rs index c53723e7..5b4926d5 100644 --- a/crates/storage-pg/src/repository.rs +++ b/crates/storage-pg/src/repository.rs @@ -16,6 +16,7 @@ use std::ops::{Deref, DerefMut}; use futures_util::{future::BoxFuture, FutureExt, TryFutureExt}; use mas_storage::{ + app_session::AppSessionRepository, compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatSsoLoginRepository, @@ -36,6 +37,7 @@ use sqlx::{PgConnection, PgPool, Postgres, Transaction}; use tracing::Instrument; use crate::{ + app_session::PgAppSessionRepository, compat::{ PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository, PgCompatSsoLoginRepository, @@ -182,6 +184,10 @@ where Box::new(PgBrowserSessionRepository::new(self.conn.as_mut())) } + fn app_session<'c>(&'c mut self) -> Box + 'c> { + Box::new(PgAppSessionRepository::new(self.conn.as_mut())) + } + fn oauth2_client<'c>( &'c mut self, ) -> Box + 'c> { diff --git a/crates/storage/src/app_session.rs b/crates/storage/src/app_session.rs new file mode 100644 index 00000000..8c2987ac --- /dev/null +++ b/crates/storage/src/app_session.rs @@ -0,0 +1,148 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Repositories to interact with all kinds of sessions + +use async_trait::async_trait; +use mas_data_model::{CompatSession, Session, User}; + +use crate::{repository_impl, Page, Pagination}; + +/// The state of a session +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum AppSessionState { + /// The session is active + Active, + /// The session is finished + Finished, +} + +impl AppSessionState { + /// Returns [`true`] if we're looking for active sessions + #[must_use] + pub fn is_active(self) -> bool { + matches!(self, Self::Active) + } + + /// Returns [`true`] if we're looking for finished sessions + #[must_use] + pub fn is_finished(self) -> bool { + matches!(self, Self::Finished) + } +} + +/// An [`AppSession`] is either a [`CompatSession`] or an OAuth 2.0 [`Session`] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AppSession { + /// A compatibility layer session + Compat(Box), + + /// An OAuth 2.0 session + OAuth2(Box), +} + +/// Filtering parameters for application sessions +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] +pub struct AppSessionFilter<'a> { + user: Option<&'a User>, + state: Option, +} + +impl<'a> AppSessionFilter<'a> { + /// Create a new [`AppSessionFilter`] with default values + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Set the user who owns the compatibility sessions + #[must_use] + pub fn for_user(mut self, user: &'a User) -> Self { + self.user = Some(user); + self + } + + /// Get the user filter + #[must_use] + pub fn user(&self) -> Option<&User> { + self.user + } + + /// Only return active compatibility sessions + #[must_use] + pub fn active_only(mut self) -> Self { + self.state = Some(AppSessionState::Active); + self + } + + /// Only return finished compatibility sessions + #[must_use] + pub fn finished_only(mut self) -> Self { + self.state = Some(AppSessionState::Finished); + self + } + + /// Get the state filter + #[must_use] + pub fn state(&self) -> Option { + self.state + } +} + +/// A [`AppSessionRepository`] helps interacting with both [`CompatSession`] and +/// OAuth 2.0 [`Session`] at the same time saved in the storage backend +#[async_trait] +pub trait AppSessionRepository: Send + Sync { + /// The error type returned by the repository + type Error; + + /// List [`AppSession`] with the given filter and pagination + /// + /// Returns a page of [`AppSession`] matching the given filter + /// + /// # Parameters + /// + /// * `filter`: The filter to apply + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn list( + &mut self, + filter: AppSessionFilter<'_>, + pagination: Pagination, + ) -> Result, Self::Error>; + + /// Count the number of [`AppSession`] with the given filter + /// + /// # Parameters + /// + /// * `filter`: The filter to apply + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result; +} + +repository_impl!(AppSessionRepository: + async fn list( + &mut self, + filter: AppSessionFilter<'_>, + pagination: Pagination, + ) -> Result, Self::Error>; + + async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result; +); diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index f859b93a..7627e539 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -132,7 +132,7 @@ impl<'a> CompatSessionFilter<'a> { } /// A [`CompatSessionRepository`] helps interacting with -/// [`CompatSessionRepository`] saved in the storage backend +/// [`CompatSession`] saved in the storage backend #[async_trait] pub trait CompatSessionRepository: Send + Sync { /// The error type returned by the repository diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index a1a1b0e6..53ab42e7 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -149,6 +149,7 @@ pub mod pagination; pub(crate) mod repository; mod utils; +pub mod app_session; pub mod compat; pub mod job; pub mod oauth2; diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 4461df0c..7f71ab96 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -16,6 +16,7 @@ use futures_util::future::BoxFuture; use thiserror::Error; use crate::{ + app_session::AppSessionRepository, compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatSsoLoginRepository, @@ -150,6 +151,9 @@ pub trait RepositoryAccess: Send { &'c mut self, ) -> Box + 'c>; + /// Get a [`AppSessionRepository`] + fn app_session<'c>(&'c mut self) -> Box + 'c>; + /// Get an [`OAuth2ClientRepository`] fn oauth2_client<'c>(&'c mut self) -> Box + 'c>; @@ -205,6 +209,7 @@ mod impls { use super::RepositoryAccess; use crate::{ + app_session::AppSessionRepository, compat::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatSsoLoginRepository, @@ -310,6 +315,12 @@ mod impls { Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper)) } + fn app_session<'c>( + &'c mut self, + ) -> Box + 'c> { + Box::new(MapErr::new(self.inner.app_session(), &mut self.mapper)) + } + fn oauth2_client<'c>( &'c mut self, ) -> Box + 'c> { @@ -425,6 +436,12 @@ mod impls { (**self).browser_session() } + fn app_session<'c>( + &'c mut self, + ) -> Box + 'c> { + (**self).app_session() + } + fn oauth2_client<'c>( &'c mut self, ) -> Box + 'c> {