You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
Use dynamic filters on OAuth 2.0 sessions
This commit is contained in:
@@ -30,7 +30,8 @@ use ulid::Ulid;
|
|||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
iden::{OAuth2Sessions, UserSessions},
|
filter::{Filter, StatementExt},
|
||||||
|
iden::OAuth2Sessions,
|
||||||
pagination::QueryBuilderExt,
|
pagination::QueryBuilderExt,
|
||||||
tracing::ExecuteExt,
|
tracing::ExecuteExt,
|
||||||
DatabaseError, DatabaseInconsistencyError,
|
DatabaseError, DatabaseInconsistencyError,
|
||||||
@@ -101,6 +102,30 @@ impl TryFrom<OAuthSessionLookup> for Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Filter for OAuth2SessionFilter<'_> {
|
||||||
|
fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
|
||||||
|
sea_query::Condition::all()
|
||||||
|
.add_option(self.user().map(|user| {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
|
||||||
|
}))
|
||||||
|
.add_option(self.client().map(|client| {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
|
||||||
|
.eq(Uuid::from(client.id))
|
||||||
|
}))
|
||||||
|
.add_option(self.state().map(|state| {
|
||||||
|
if state.is_active() {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
||||||
|
} else {
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.add_option(self.scope().map(|scope| {
|
||||||
|
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
|
||||||
|
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
||||||
type Error = DatabaseError;
|
type Error = DatabaseError;
|
||||||
@@ -223,24 +248,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
|||||||
let (sql, arguments) = Query::update()
|
let (sql, arguments) = Query::update()
|
||||||
.table(OAuth2Sessions::Table)
|
.table(OAuth2Sessions::Table)
|
||||||
.value(OAuth2Sessions::FinishedAt, finished_at)
|
.value(OAuth2Sessions::FinishedAt, finished_at)
|
||||||
.and_where_option(filter.user().map(|user| {
|
.apply_filter(filter)
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.client().map(|client| {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
|
|
||||||
.eq(Uuid::from(client.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.state().map(|state| {
|
|
||||||
if state.is_active() {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
|
||||||
} else {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.scope().map(|scope| {
|
|
||||||
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
|
|
||||||
}))
|
|
||||||
.build_sqlx(PostgresQueryBuilder);
|
.build_sqlx(PostgresQueryBuilder);
|
||||||
|
|
||||||
let res = sqlx::query_with(&sql, arguments)
|
let res = sqlx::query_with(&sql, arguments)
|
||||||
@@ -343,24 +351,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
|||||||
OAuthSessionLookupIden::LastActiveIp,
|
OAuthSessionLookupIden::LastActiveIp,
|
||||||
)
|
)
|
||||||
.from(OAuth2Sessions::Table)
|
.from(OAuth2Sessions::Table)
|
||||||
.and_where_option(filter.user().map(|user| {
|
.apply_filter(filter)
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.client().map(|client| {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
|
|
||||||
.eq(Uuid::from(client.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.state().map(|state| {
|
|
||||||
if state.is_active() {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
|
||||||
} else {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.scope().map(|scope| {
|
|
||||||
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
|
|
||||||
}))
|
|
||||||
.generate_pagination(
|
.generate_pagination(
|
||||||
(OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
|
(OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
|
||||||
pagination,
|
pagination,
|
||||||
@@ -389,39 +380,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
|||||||
let (sql, arguments) = Query::select()
|
let (sql, arguments) = Query::select()
|
||||||
.expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
|
.expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
|
||||||
.from(OAuth2Sessions::Table)
|
.from(OAuth2Sessions::Table)
|
||||||
.and_where_option(filter.user().map(|user| {
|
.apply_filter(filter)
|
||||||
// Check for user ownership by querying the user_sessions table
|
|
||||||
// The query plan is the same as if we were joining the tables instead
|
|
||||||
Expr::exists(
|
|
||||||
Query::select()
|
|
||||||
.expr(Expr::cust("1"))
|
|
||||||
.from(UserSessions::Table)
|
|
||||||
.and_where(
|
|
||||||
Expr::col((UserSessions::Table, UserSessions::UserId))
|
|
||||||
.eq(Uuid::from(user.id)),
|
|
||||||
)
|
|
||||||
.and_where(
|
|
||||||
Expr::col((UserSessions::Table, UserSessions::UserSessionId))
|
|
||||||
.equals((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
|
|
||||||
)
|
|
||||||
.take(),
|
|
||||||
)
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.client().map(|client| {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
|
|
||||||
.eq(Uuid::from(client.id))
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.state().map(|state| {
|
|
||||||
if state.is_active() {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
|
|
||||||
} else {
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
.and_where_option(filter.scope().map(|scope| {
|
|
||||||
let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
|
|
||||||
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
|
|
||||||
}))
|
|
||||||
.build_sqlx(PostgresQueryBuilder);
|
.build_sqlx(PostgresQueryBuilder);
|
||||||
|
|
||||||
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
|
||||||
@@ -463,7 +422,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
|||||||
, last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
|
, last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
|
||||||
FROM (
|
FROM (
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
|
FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
|
||||||
AS t(oauth2_session_id, last_active_at, last_active_ip)
|
AS t(oauth2_session_id, last_active_at, last_active_ip)
|
||||||
) AS t
|
) AS t
|
||||||
WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
|
WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
|
||||||
|
Reference in New Issue
Block a user