1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Use dynamic filters on OAuth 2.0 sessions

This commit is contained in:
Quentin Gliech
2024-07-16 16:48:40 +02:00
parent df7bc53826
commit 7c54c5f2e6

View File

@@ -30,7 +30,8 @@ use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
iden::{OAuth2Sessions, UserSessions}, filter::{Filter, StatementExt},
iden::OAuth2Sessions,
pagination::QueryBuilderExt, pagination::QueryBuilderExt,
tracing::ExecuteExt, tracing::ExecuteExt,
DatabaseError, DatabaseInconsistencyError, DatabaseError, DatabaseInconsistencyError,
@@ -101,6 +102,30 @@ impl TryFrom<OAuthSessionLookup> for Session {
} }
} }
impl Filter for OAuth2SessionFilter<'_> {
fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
sea_query::Condition::all()
.add_option(self.user().map(|user| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
}))
.add_option(self.client().map(|client| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
.eq(Uuid::from(client.id))
}))
.add_option(self.state().map(|state| {
if state.is_active() {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
} else {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
}
}))
.add_option(self.scope().map(|scope| {
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
}))
}
}
#[async_trait] #[async_trait]
impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
type Error = DatabaseError; type Error = DatabaseError;
@@ -223,24 +248,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
let (sql, arguments) = Query::update() let (sql, arguments) = Query::update()
.table(OAuth2Sessions::Table) .table(OAuth2Sessions::Table)
.value(OAuth2Sessions::FinishedAt, finished_at) .value(OAuth2Sessions::FinishedAt, finished_at)
.and_where_option(filter.user().map(|user| { .apply_filter(filter)
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.client().map(|client| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
.eq(Uuid::from(client.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
} else {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.scope().map(|scope| {
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
}))
.build_sqlx(PostgresQueryBuilder); .build_sqlx(PostgresQueryBuilder);
let res = sqlx::query_with(&sql, arguments) let res = sqlx::query_with(&sql, arguments)
@@ -343,24 +351,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
OAuthSessionLookupIden::LastActiveIp, OAuthSessionLookupIden::LastActiveIp,
) )
.from(OAuth2Sessions::Table) .from(OAuth2Sessions::Table)
.and_where_option(filter.user().map(|user| { .apply_filter(filter)
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
}))
.and_where_option(filter.client().map(|client| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
.eq(Uuid::from(client.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
} else {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.scope().map(|scope| {
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
}))
.generate_pagination( .generate_pagination(
(OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId), (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
pagination, pagination,
@@ -389,39 +380,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
let (sql, arguments) = Query::select() let (sql, arguments) = Query::select()
.expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count()) .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
.from(OAuth2Sessions::Table) .from(OAuth2Sessions::Table)
.and_where_option(filter.user().map(|user| { .apply_filter(filter)
// Check for user ownership by querying the user_sessions table
// The query plan is the same as if we were joining the tables instead
Expr::exists(
Query::select()
.expr(Expr::cust("1"))
.from(UserSessions::Table)
.and_where(
Expr::col((UserSessions::Table, UserSessions::UserId))
.eq(Uuid::from(user.id)),
)
.and_where(
Expr::col((UserSessions::Table, UserSessions::UserSessionId))
.equals((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
)
.take(),
)
}))
.and_where_option(filter.client().map(|client| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
.eq(Uuid::from(client.id))
}))
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
} else {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
}
}))
.and_where_option(filter.scope().map(|scope| {
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
}))
.build_sqlx(PostgresQueryBuilder); .build_sqlx(PostgresQueryBuilder);
let count: i64 = sqlx::query_scalar_with(&sql, arguments) let count: i64 = sqlx::query_scalar_with(&sql, arguments)