diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index 9eb5f1a4..42ae269e 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -30,7 +30,8 @@ use ulid::Ulid; use uuid::Uuid; use crate::{ - iden::{OAuth2Sessions, UserSessions}, + filter::{Filter, StatementExt}, + iden::OAuth2Sessions, pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, @@ -101,6 +102,30 @@ impl TryFrom for Session { } } +impl Filter for OAuth2SessionFilter<'_> { + fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition { + sea_query::Condition::all() + .add_option(self.user().map(|user| { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id)) + })) + .add_option(self.client().map(|client| { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)) + .eq(Uuid::from(client.id)) + })) + .add_option(self.state().map(|state| { + if state.is_active() { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null() + } else { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() + } + })) + .add_option(self.scope().map(|scope| { + let scope: Vec = scope.iter().map(|s| s.as_str().to_owned()).collect(); + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope) + })) + } +} + #[async_trait] impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { type Error = DatabaseError; @@ -223,24 +248,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { let (sql, arguments) = Query::update() .table(OAuth2Sessions::Table) .value(OAuth2Sessions::FinishedAt, finished_at) - .and_where_option(filter.user().map(|user| { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id)) - })) - .and_where_option(filter.client().map(|client| { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)) - .eq(Uuid::from(client.id)) - })) - .and_where_option(filter.state().map(|state| { - if state.is_active() { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null() - } else { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() - } - })) - .and_where_option(filter.scope().map(|scope| { - let scope: Vec = scope.iter().map(|s| s.as_str().to_owned()).collect(); - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope) - })) + .apply_filter(filter) .build_sqlx(PostgresQueryBuilder); let res = sqlx::query_with(&sql, arguments) @@ -343,24 +351,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { OAuthSessionLookupIden::LastActiveIp, ) .from(OAuth2Sessions::Table) - .and_where_option(filter.user().map(|user| { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id)) - })) - .and_where_option(filter.client().map(|client| { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)) - .eq(Uuid::from(client.id)) - })) - .and_where_option(filter.state().map(|state| { - if state.is_active() { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null() - } else { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() - } - })) - .and_where_option(filter.scope().map(|scope| { - let scope: Vec = scope.iter().map(|s| s.as_str().to_owned()).collect(); - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope) - })) + .apply_filter(filter) .generate_pagination( (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId), pagination, @@ -389,39 +380,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { let (sql, arguments) = Query::select() .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count()) .from(OAuth2Sessions::Table) - .and_where_option(filter.user().map(|user| { - // Check for user ownership by querying the user_sessions table - // The query plan is the same as if we were joining the tables instead - Expr::exists( - Query::select() - .expr(Expr::cust("1")) - .from(UserSessions::Table) - .and_where( - Expr::col((UserSessions::Table, UserSessions::UserId)) - .eq(Uuid::from(user.id)), - ) - .and_where( - Expr::col((UserSessions::Table, UserSessions::UserSessionId)) - .equals((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)), - ) - .take(), - ) - })) - .and_where_option(filter.client().map(|client| { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)) - .eq(Uuid::from(client.id)) - })) - .and_where_option(filter.state().map(|state| { - if state.is_active() { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null() - } else { - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() - } - })) - .and_where_option(filter.scope().map(|scope| { - let scope: Vec = scope.iter().map(|s| s.as_str().to_owned()).collect(); - Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope) - })) + .apply_filter(filter) .build_sqlx(PostgresQueryBuilder); let count: i64 = sqlx::query_scalar_with(&sql, arguments) @@ -463,7 +422,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip) FROM ( SELECT * - FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) + FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) AS t(oauth2_session_id, last_active_at, last_active_ip) ) AS t WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id