diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs index 5f9a5da4..10efbcd7 100644 --- a/crates/storage-pg/src/user/session.rs +++ b/crates/storage-pg/src/user/session.rs @@ -20,7 +20,10 @@ use mas_data_model::{ Authentication, AuthenticationMethod, BrowserSession, Password, UpstreamOAuthAuthorizationSession, User, UserAgent, }; -use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination}; +use mas_storage::{ + user::{BrowserSessionFilter, BrowserSessionRepository}, + Clock, Page, Pagination, +}; use rand::RngCore; use sea_query::{Expr, PostgresQueryBuilder}; use sea_query_binder::SqlxBinder; @@ -29,6 +32,7 @@ use ulid::Ulid; use uuid::Uuid; use crate::{ + filter::StatementExt, iden::{UserSessions, Users}, pagination::QueryBuilderExt, tracing::ExecuteExt, @@ -130,6 +134,22 @@ impl TryFrom for Authentication { } } +impl crate::filter::Filter for BrowserSessionFilter<'_> { + fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition { + sea_query::Condition::all() + .add_option(self.user().map(|user| { + Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id)) + })) + .add_option(self.state().map(|state| { + if state.is_active() { + Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null() + } else { + Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null() + } + })) + } +} + #[async_trait] impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { type Error = DatabaseError; @@ -270,22 +290,13 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { async fn finish_bulk( &mut self, clock: &dyn Clock, - filter: mas_storage::user::BrowserSessionFilter<'_>, + filter: BrowserSessionFilter<'_>, ) -> Result { let finished_at = clock.now(); let (sql, arguments) = sea_query::Query::update() .table(UserSessions::Table) .value(UserSessions::FinishedAt, finished_at) - .and_where_option(filter.user().map(|user| { - Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id)) - })) - .and_where_option(filter.state().map(|state| { - if state.is_active() { - Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null() - } else { - Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null() - } - })) + .apply_filter(filter) .build_sqlx(PostgresQueryBuilder); let res = sqlx::query_with(&sql, arguments) @@ -306,7 +317,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { )] async fn list( &mut self, - filter: mas_storage::user::BrowserSessionFilter<'_>, + filter: BrowserSessionFilter<'_>, pagination: Pagination, ) -> Result, Self::Error> { let (sql, arguments) = sea_query::Query::select() @@ -364,18 +375,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { Expr::col((UserSessions::Table, UserSessions::UserId)) .equals((Users::Table, Users::UserId)), ) - .and_where_option( - filter - .user() - .map(|user| Expr::col((Users::Table, Users::UserId)).eq(Uuid::from(user.id))), - ) - .and_where_option(filter.state().map(|state| { - if state.is_active() { - Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null() - } else { - Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null() - } - })) + .apply_filter(filter) .generate_pagination( (UserSessions::Table, UserSessions::UserSessionId), pagination, @@ -402,23 +402,11 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { ), err, )] - async fn count( - &mut self, - filter: mas_storage::user::BrowserSessionFilter<'_>, - ) -> Result { + async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result { let (sql, arguments) = sea_query::Query::select() .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count()) .from(UserSessions::Table) - .and_where_option(filter.user().map(|user| { - Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id)) - })) - .and_where_option(filter.state().map(|state| { - if state.is_active() { - Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null() - } else { - Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null() - } - })) + .apply_filter(filter) .build_sqlx(PostgresQueryBuilder); let count: i64 = sqlx::query_scalar_with(&sql, arguments)