You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-21 23:00:50 +03:00
WIP: use sea-query for dynamic paginated queries
This commit is contained in:
@@ -217,6 +217,7 @@ pub mod user;
|
||||
mod errors;
|
||||
pub(crate) mod pagination;
|
||||
pub(crate) mod repository;
|
||||
mod sea_query_sqlx;
|
||||
pub(crate) mod tracing;
|
||||
|
||||
pub(crate) use self::errors::DatabaseInconsistencyError;
|
||||
|
||||
@@ -21,9 +21,11 @@ use uuid::Uuid;
|
||||
/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination
|
||||
/// to a query
|
||||
pub trait QueryBuilderExt {
|
||||
type Iden;
|
||||
|
||||
/// Add cursor-based pagination to a query, as used in paginated GraphQL
|
||||
/// connections
|
||||
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self;
|
||||
fn generate_pagination(&mut self, id_field: Self::Iden, pagination: Pagination) -> &mut Self;
|
||||
}
|
||||
|
||||
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
|
||||
@@ -32,6 +34,8 @@ where
|
||||
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
{
|
||||
type Iden = &'static str;
|
||||
|
||||
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self {
|
||||
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
|
||||
// 1. Start from the greedy query: SELECT * FROM table
|
||||
@@ -76,3 +80,40 @@ where
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryBuilderExt for sea_query::SelectStatement {
|
||||
type Iden = sea_query::ColumnRef;
|
||||
fn generate_pagination(&mut self, id_field: Self::Iden, pagination: Pagination) -> &mut Self {
|
||||
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
|
||||
// 1. Start from the greedy query: SELECT * FROM table
|
||||
|
||||
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
|
||||
// clause
|
||||
if let Some(after) = pagination.after {
|
||||
self.and_where(sea_query::Expr::col(id_field.clone()).gt(Uuid::from(after)));
|
||||
}
|
||||
|
||||
// 3. If the before argument is provided, add `id < parsed_cursor` to the
|
||||
// `WHERE` clause
|
||||
if let Some(before) = pagination.before {
|
||||
self.and_where(sea_query::Expr::col(id_field.clone()).lt(Uuid::from(before)));
|
||||
}
|
||||
|
||||
match pagination.direction {
|
||||
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the
|
||||
// query
|
||||
PaginationDirection::Forward => {
|
||||
self.order_by(id_field, sea_query::Order::Asc)
|
||||
.limit((pagination.count + 1) as u64);
|
||||
}
|
||||
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the
|
||||
// query
|
||||
PaginationDirection::Backward => {
|
||||
self.order_by(id_field, sea_query::Order::Desc)
|
||||
.limit((pagination.count + 1) as u64);
|
||||
}
|
||||
};
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
49
crates/storage-pg/src/sea_query_sqlx.rs
Normal file
49
crates/storage-pg/src/sea_query_sqlx.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use sea_query::Value;
|
||||
use sqlx::Arguments;
|
||||
|
||||
pub(crate) fn map_values(values: sea_query::Values) -> sqlx::postgres::PgArguments {
|
||||
let mut arguments = sqlx::postgres::PgArguments::default();
|
||||
|
||||
for value in values {
|
||||
match value {
|
||||
Value::Bool(b) => arguments.add(b),
|
||||
Value::TinyInt(i) => arguments.add(i),
|
||||
Value::SmallInt(i) => arguments.add(i),
|
||||
Value::Int(i) => arguments.add(i),
|
||||
Value::BigInt(i) => arguments.add(i),
|
||||
Value::TinyUnsigned(u) => arguments.add(u.map(|u| u as i16)),
|
||||
Value::SmallUnsigned(u) => arguments.add(u.map(|u| u as i32)),
|
||||
Value::Unsigned(u) => arguments.add(u.map(|u| u as i64)),
|
||||
Value::BigUnsigned(u) => arguments.add(u.map(|u| i64::try_from(u).unwrap_or(i64::MAX))),
|
||||
Value::Float(f) => arguments.add(f),
|
||||
Value::Double(d) => arguments.add(d),
|
||||
Value::String(s) => arguments.add(s.as_deref()),
|
||||
Value::Char(c) => arguments.add(c.map(|c| c.to_string())),
|
||||
Value::Bytes(b) => arguments.add(b.as_deref()),
|
||||
Value::ChronoDate(d) => arguments.add(d.as_deref()),
|
||||
Value::ChronoTime(t) => arguments.add(t.as_deref()),
|
||||
Value::ChronoDateTime(dt) => arguments.add(dt.as_deref()),
|
||||
Value::ChronoDateTimeUtc(dt) => arguments.add(dt.as_deref()),
|
||||
Value::ChronoDateTimeLocal(dt) => arguments.add(dt.as_deref()),
|
||||
Value::ChronoDateTimeWithTimeZone(dt) => arguments.add(dt.as_deref()),
|
||||
Value::Uuid(u) => arguments.add(u.as_deref()),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
arguments
|
||||
}
|
||||
@@ -17,13 +17,14 @@ use chrono::{DateTime, Utc};
|
||||
use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User};
|
||||
use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination};
|
||||
use rand::RngCore;
|
||||
use sea_query::{Expr, IntoColumnRef, PostgresQueryBuilder};
|
||||
use sqlx::{PgConnection, QueryBuilder};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
|
||||
LookupResultExt,
|
||||
pagination::QueryBuilderExt, sea_query_sqlx::map_values, tracing::ExecuteExt, DatabaseError,
|
||||
DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
|
||||
/// An implementation of [`BrowserSessionRepository`] for a PostgreSQL
|
||||
@@ -41,6 +42,7 @@ impl<'c> PgBrowserSessionRepository<'c> {
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
#[sea_query::enum_def]
|
||||
struct SessionLookup {
|
||||
user_session_id: Uuid,
|
||||
user_session_created_at: DateTime<Utc>,
|
||||
@@ -52,6 +54,31 @@ struct SessionLookup {
|
||||
last_authd_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(sea_query::Iden)]
|
||||
enum UserSessions {
|
||||
Table,
|
||||
UserSessionId,
|
||||
CreatedAt,
|
||||
FinishedAt,
|
||||
UserId,
|
||||
}
|
||||
|
||||
#[derive(sea_query::Iden)]
|
||||
enum Users {
|
||||
Table,
|
||||
UserId,
|
||||
Username,
|
||||
PrimaryUserEmailId,
|
||||
}
|
||||
|
||||
#[derive(sea_query::Iden)]
|
||||
enum SessionAuthentication {
|
||||
Table,
|
||||
UserSessionAuthenticationId,
|
||||
UserSessionId,
|
||||
CreatedAt,
|
||||
}
|
||||
|
||||
impl TryFrom<SessionLookup> for BrowserSession {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
@@ -214,46 +241,78 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.browser_session.list_active_paginated",
|
||||
name = "db.browser_session.list",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn list_active_paginated(
|
||||
async fn list(
|
||||
&mut self,
|
||||
user: &User,
|
||||
filter: mas_storage::user::BrowserSessionFilter<'_>,
|
||||
pagination: Pagination,
|
||||
) -> Result<Page<BrowserSession>, Self::Error> {
|
||||
// TODO: ordering of last authentication is wrong
|
||||
let mut query = QueryBuilder::new(
|
||||
r#"
|
||||
SELECT DISTINCT ON (s.user_session_id)
|
||||
s.user_session_id,
|
||||
s.created_at AS "user_session_created_at",
|
||||
s.finished_at AS "user_session_finished_at",
|
||||
u.user_id,
|
||||
u.username AS "user_username",
|
||||
u.primary_user_email_id AS "user_primary_user_email_id",
|
||||
a.user_session_authentication_id AS "last_authentication_id",
|
||||
a.created_at AS "last_authd_at"
|
||||
FROM user_sessions s
|
||||
INNER JOIN users u
|
||||
USING (user_id)
|
||||
LEFT JOIN user_session_authentications a
|
||||
USING (user_session_id)
|
||||
"#,
|
||||
);
|
||||
let (sql, values) = sea_query::Query::select()
|
||||
.expr_as(
|
||||
Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
|
||||
SessionLookupIden::UserSessionId,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
|
||||
SessionLookupIden::UserSessionCreatedAt,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
|
||||
SessionLookupIden::UserSessionFinishedAt,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((Users::Table, Users::UserId)),
|
||||
SessionLookupIden::UserId,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((Users::Table, Users::Username)),
|
||||
SessionLookupIden::UserUsername,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((Users::Table, Users::PrimaryUserEmailId)),
|
||||
SessionLookupIden::UserPrimaryUserEmailId,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::value(None::<Uuid>),
|
||||
SessionLookupIden::LastAuthenticationId,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::value(None::<DateTime<Utc>>),
|
||||
SessionLookupIden::LastAuthdAt,
|
||||
)
|
||||
.from(UserSessions::Table)
|
||||
.inner_join(
|
||||
Users::Table,
|
||||
Expr::col((UserSessions::Table, UserSessions::UserId))
|
||||
.equals((Users::Table, Users::UserId)),
|
||||
)
|
||||
.and_where_option(
|
||||
filter
|
||||
.user()
|
||||
.map(|user| Expr::col((Users::Table, Users::UserId)).eq(Uuid::from(user.id))),
|
||||
)
|
||||
.and_where_option(filter.state().map(|state| {
|
||||
if state.is_active() {
|
||||
Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
|
||||
} else {
|
||||
Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
|
||||
}
|
||||
}))
|
||||
.generate_pagination(
|
||||
(UserSessions::Table, UserSessions::UserSessionId).into_column_ref(),
|
||||
pagination,
|
||||
)
|
||||
.build(PostgresQueryBuilder);
|
||||
|
||||
query
|
||||
.push(" WHERE s.finished_at IS NULL AND s.user_id = ")
|
||||
.push_bind(Uuid::from(user.id))
|
||||
.generate_pagination("s.user_session_id", pagination);
|
||||
let arguments = map_values(values);
|
||||
|
||||
let edges: Vec<SessionLookup> = query
|
||||
.build_query_as()
|
||||
let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
@@ -261,34 +320,10 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
||||
let page = pagination
|
||||
.process(edges)
|
||||
.try_map(BrowserSession::try_from)?;
|
||||
|
||||
Ok(page)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.browser_session.count_active",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
%user.id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn count_active(&mut self, user: &User) -> Result<usize, Self::Error> {
|
||||
let res = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT COUNT(*) as "count!"
|
||||
FROM user_sessions s
|
||||
WHERE s.user_id = $1 AND s.finished_at IS NULL
|
||||
"#,
|
||||
Uuid::from(user.id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
res.try_into().map_err(DatabaseError::to_invalid_operation)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.browser_session.authenticate_with_password",
|
||||
skip_all,
|
||||
|
||||
Reference in New Issue
Block a user