1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Use dynamic filters on upstream OAuth 2.0 providers

This commit is contained in:
Quentin Gliech
2024-07-16 17:00:47 +02:00
parent 7c2c310cac
commit 112f673e22
2 changed files with 23 additions and 21 deletions

View File

@@ -34,5 +34,5 @@
false false
] ]
}, },
"hash": "94fd87e99088671b6a20bb7b9a3838ecce8df564257b348adf22f2e9356e6dae" "hash": "9aa8fa3a6277f67b2bf5a5ea5429a61e7997ff4f3e8d0dc772448a1f97e1e390"
} }

View File

@@ -31,8 +31,11 @@ use ulid::Ulid;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
iden::UpstreamOAuthProviders, pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, filter::{Filter, StatementExt},
DatabaseInconsistencyError, iden::UpstreamOAuthProviders,
pagination::QueryBuilderExt,
tracing::ExecuteExt,
DatabaseError, DatabaseInconsistencyError,
}; };
/// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL /// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL
@@ -174,6 +177,19 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
} }
} }
impl Filter for UpstreamOAuthProviderFilter<'_> {
fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::DisabledAt,
))
.is_null()
.eq(enabled)
}))
}
}
#[async_trait] #[async_trait]
impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> { impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> {
type Error = DatabaseError; type Error = DatabaseError;
@@ -676,14 +692,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
ProviderLookupIden::AdditionalParameters, ProviderLookupIden::AdditionalParameters,
) )
.from(UpstreamOAuthProviders::Table) .from(UpstreamOAuthProviders::Table)
.and_where_option(filter.enabled().map(|enabled| { .apply_filter(filter)
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::DisabledAt,
))
.is_null()
.eq(enabled)
}))
.generate_pagination( .generate_pagination(
( (
UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Table,
@@ -726,14 +735,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
.count(), .count(),
) )
.from(UpstreamOAuthProviders::Table) .from(UpstreamOAuthProviders::Table)
.and_where_option(filter.enabled().map(|enabled| { .apply_filter(filter)
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::DisabledAt,
))
.is_null()
.eq(enabled)
}))
.build_sqlx(PostgresQueryBuilder); .build_sqlx(PostgresQueryBuilder);
let count: i64 = sqlx::query_scalar_with(&sql, arguments) let count: i64 = sqlx::query_scalar_with(&sql, arguments)