1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

WIP: use sea-query for dynamic paginated queries

This commit is contained in:
Quentin Gliech
2023-07-19 13:34:39 +02:00
parent 5f8cd98052
commit 7e82ae845c
15 changed files with 360 additions and 86 deletions

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use async_graphql::{Description, Object, ID};
use async_graphql::{Description, Enum, Object, ID};
use chrono::{DateTime, Utc};
use super::{NodeType, User};
@ -21,6 +21,16 @@ use super::{NodeType, User};
#[derive(Description)]
pub struct BrowserSession(pub mas_data_model::BrowserSession);
/// The state of a browser session.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum BrowserSessionState {
/// The session is active.
Active,
/// The session is no longer active.
Finished,
}
impl From<mas_data_model::BrowserSession> for BrowserSession {
fn from(v: mas_data_model::BrowserSession) -> Self {
Self(v)

View File

@ -21,7 +21,7 @@ use mas_storage::{
compat::CompatSsoLoginRepository,
oauth2::OAuth2SessionRepository,
upstream_oauth2::UpstreamOAuthLinkRepository,
user::{BrowserSessionRepository, UserEmailRepository},
user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailRepository},
Pagination, RepositoryAccess,
};
@ -30,7 +30,7 @@ use super::{
UpstreamOAuth2Link,
};
use crate::{
model::{matrix::MatrixUser, CompatSession},
model::{browser_sessions::BrowserSessionState, matrix::MatrixUser, CompatSession},
state::ContextExt,
};
@ -189,6 +189,9 @@ impl User {
&self,
ctx: &Context<'_>,
#[graphql(name = "state", desc = "List only sessions in the given state.")]
state_param: Option<BrowserSessionState>,
#[graphql(desc = "Returns the elements in the list that come after the cursor.")]
after: Option<String>,
#[graphql(desc = "Returns the elements in the list that come before the cursor.")]
@ -213,10 +216,14 @@ impl User {
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let page = repo
.browser_session()
.list_active_paginated(&self.0, pagination)
.await?;
let filter = BrowserSessionFilter::new().for_user(&self.0);
let filter = match state_param {
Some(BrowserSessionState::Active) => filter.active_only(),
Some(BrowserSessionState::Finished) => filter.finished_only(),
None => filter,
};
let page = repo.browser_session().list(filter, pagination).await?;
repo.cancel().await?;

View File

@ -8,6 +8,7 @@ license = "Apache-2.0"
[dependencies]
async-trait = "0.1.71"
sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "json", "uuid"] }
sea-query = { version = "0.28.5", features = ["derive", "attr", "with-uuid", "with-chrono"] }
chrono = { version = "0.4.26", features = ["serde"] }
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.103"

View File

@ -217,6 +217,7 @@ pub mod user;
mod errors;
pub(crate) mod pagination;
pub(crate) mod repository;
mod sea_query_sqlx;
pub(crate) mod tracing;
pub(crate) use self::errors::DatabaseInconsistencyError;

View File

@ -21,9 +21,11 @@ use uuid::Uuid;
/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination
/// to a query
pub trait QueryBuilderExt {
type Iden;
/// Add cursor-based pagination to a query, as used in paginated GraphQL
/// connections
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self;
fn generate_pagination(&mut self, id_field: Self::Iden, pagination: Pagination) -> &mut Self;
}
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
@ -32,6 +34,8 @@ where
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
{
type Iden = &'static str;
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self {
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
// 1. Start from the greedy query: SELECT * FROM table
@ -76,3 +80,40 @@ where
self
}
}
impl QueryBuilderExt for sea_query::SelectStatement {
type Iden = sea_query::ColumnRef;
fn generate_pagination(&mut self, id_field: Self::Iden, pagination: Pagination) -> &mut Self {
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
// 1. Start from the greedy query: SELECT * FROM table
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
// clause
if let Some(after) = pagination.after {
self.and_where(sea_query::Expr::col(id_field.clone()).gt(Uuid::from(after)));
}
// 3. If the before argument is provided, add `id < parsed_cursor` to the
// `WHERE` clause
if let Some(before) = pagination.before {
self.and_where(sea_query::Expr::col(id_field.clone()).lt(Uuid::from(before)));
}
match pagination.direction {
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the
// query
PaginationDirection::Forward => {
self.order_by(id_field, sea_query::Order::Asc)
.limit((pagination.count + 1) as u64);
}
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the
// query
PaginationDirection::Backward => {
self.order_by(id_field, sea_query::Order::Desc)
.limit((pagination.count + 1) as u64);
}
};
self
}
}

View File

@ -0,0 +1,49 @@
// Copyright 2021-2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use sea_query::Value;
use sqlx::Arguments;
pub(crate) fn map_values(values: sea_query::Values) -> sqlx::postgres::PgArguments {
let mut arguments = sqlx::postgres::PgArguments::default();
for value in values {
match value {
Value::Bool(b) => arguments.add(b),
Value::TinyInt(i) => arguments.add(i),
Value::SmallInt(i) => arguments.add(i),
Value::Int(i) => arguments.add(i),
Value::BigInt(i) => arguments.add(i),
Value::TinyUnsigned(u) => arguments.add(u.map(|u| u as i16)),
Value::SmallUnsigned(u) => arguments.add(u.map(|u| u as i32)),
Value::Unsigned(u) => arguments.add(u.map(|u| u as i64)),
Value::BigUnsigned(u) => arguments.add(u.map(|u| i64::try_from(u).unwrap_or(i64::MAX))),
Value::Float(f) => arguments.add(f),
Value::Double(d) => arguments.add(d),
Value::String(s) => arguments.add(s.as_deref()),
Value::Char(c) => arguments.add(c.map(|c| c.to_string())),
Value::Bytes(b) => arguments.add(b.as_deref()),
Value::ChronoDate(d) => arguments.add(d.as_deref()),
Value::ChronoTime(t) => arguments.add(t.as_deref()),
Value::ChronoDateTime(dt) => arguments.add(dt.as_deref()),
Value::ChronoDateTimeUtc(dt) => arguments.add(dt.as_deref()),
Value::ChronoDateTimeLocal(dt) => arguments.add(dt.as_deref()),
Value::ChronoDateTimeWithTimeZone(dt) => arguments.add(dt.as_deref()),
Value::Uuid(u) => arguments.add(u.as_deref()),
_ => unimplemented!(),
}
}
arguments
}

View File

@ -17,13 +17,14 @@ use chrono::{DateTime, Utc};
use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User};
use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination};
use rand::RngCore;
use sea_query::{Expr, IntoColumnRef, PostgresQueryBuilder};
use sqlx::{PgConnection, QueryBuilder};
use ulid::Ulid;
use uuid::Uuid;
use crate::{
pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError,
LookupResultExt,
pagination::QueryBuilderExt, sea_query_sqlx::map_values, tracing::ExecuteExt, DatabaseError,
DatabaseInconsistencyError, LookupResultExt,
};
/// An implementation of [`BrowserSessionRepository`] for a PostgreSQL
@ -41,6 +42,7 @@ impl<'c> PgBrowserSessionRepository<'c> {
}
#[derive(sqlx::FromRow)]
#[sea_query::enum_def]
struct SessionLookup {
user_session_id: Uuid,
user_session_created_at: DateTime<Utc>,
@ -52,6 +54,31 @@ struct SessionLookup {
last_authd_at: Option<DateTime<Utc>>,
}
#[derive(sea_query::Iden)]
enum UserSessions {
Table,
UserSessionId,
CreatedAt,
FinishedAt,
UserId,
}
#[derive(sea_query::Iden)]
enum Users {
Table,
UserId,
Username,
PrimaryUserEmailId,
}
#[derive(sea_query::Iden)]
enum SessionAuthentication {
Table,
UserSessionAuthenticationId,
UserSessionId,
CreatedAt,
}
impl TryFrom<SessionLookup> for BrowserSession {
type Error = DatabaseInconsistencyError;
@ -214,46 +241,78 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
}
#[tracing::instrument(
name = "db.browser_session.list_active_paginated",
name = "db.browser_session.list",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn list_active_paginated(
async fn list(
&mut self,
user: &User,
filter: mas_storage::user::BrowserSessionFilter<'_>,
pagination: Pagination,
) -> Result<Page<BrowserSession>, Self::Error> {
// TODO: ordering of last authentication is wrong
let mut query = QueryBuilder::new(
r#"
SELECT DISTINCT ON (s.user_session_id)
s.user_session_id,
s.created_at AS "user_session_created_at",
s.finished_at AS "user_session_finished_at",
u.user_id,
u.username AS "user_username",
u.primary_user_email_id AS "user_primary_user_email_id",
a.user_session_authentication_id AS "last_authentication_id",
a.created_at AS "last_authd_at"
FROM user_sessions s
INNER JOIN users u
USING (user_id)
LEFT JOIN user_session_authentications a
USING (user_session_id)
"#,
);
let (sql, values) = sea_query::Query::select()
.expr_as(
Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
SessionLookupIden::UserSessionId,
)
.expr_as(
Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
SessionLookupIden::UserSessionCreatedAt,
)
.expr_as(
Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
SessionLookupIden::UserSessionFinishedAt,
)
.expr_as(
Expr::col((Users::Table, Users::UserId)),
SessionLookupIden::UserId,
)
.expr_as(
Expr::col((Users::Table, Users::Username)),
SessionLookupIden::UserUsername,
)
.expr_as(
Expr::col((Users::Table, Users::PrimaryUserEmailId)),
SessionLookupIden::UserPrimaryUserEmailId,
)
.expr_as(
Expr::value(None::<Uuid>),
SessionLookupIden::LastAuthenticationId,
)
.expr_as(
Expr::value(None::<DateTime<Utc>>),
SessionLookupIden::LastAuthdAt,
)
.from(UserSessions::Table)
.inner_join(
Users::Table,
Expr::col((UserSessions::Table, UserSessions::UserId))
.equals((Users::Table, Users::UserId)),
)
.and_where_option(
filter
.user()
.map(|user| Expr::col((Users::Table, Users::UserId)).eq(Uuid::from(user.id))),
)
.and_where_option(filter.state().map(|state| {
if state.is_active() {
Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
} else {
Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
}
}))
.generate_pagination(
(UserSessions::Table, UserSessions::UserSessionId).into_column_ref(),
pagination,
)
.build(PostgresQueryBuilder);
query
.push(" WHERE s.finished_at IS NULL AND s.user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("s.user_session_id", pagination);
let arguments = map_values(values);
let edges: Vec<SessionLookup> = query
.build_query_as()
let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
.traced()
.fetch_all(&mut *self.conn)
.await?;
@ -261,34 +320,10 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
let page = pagination
.process(edges)
.try_map(BrowserSession::try_from)?;
Ok(page)
}
#[tracing::instrument(
name = "db.browser_session.count_active",
skip_all,
fields(
db.statement,
%user.id,
),
err,
)]
async fn count_active(&mut self, user: &User) -> Result<usize, Self::Error> {
let res = sqlx::query_scalar!(
r#"
SELECT COUNT(*) as "count!"
FROM user_sessions s
WHERE s.user_id = $1 AND s.finished_at IS NULL
"#,
Uuid::from(user.id),
)
.traced()
.fetch_one(&mut *self.conn)
.await?;
res.try_into().map_err(DatabaseError::to_invalid_operation)
}
#[tracing::instrument(
name = "db.browser_session.authenticate_with_password",
skip_all,

View File

@ -26,7 +26,9 @@ mod password;
mod session;
pub use self::{
email::UserEmailRepository, password::UserPasswordRepository, session::BrowserSessionRepository,
email::UserEmailRepository,
password::UserPasswordRepository,
session::{BrowserSessionFilter, BrowserSessionRepository},
};
/// A [`UserRepository`] helps interacting with [`User`] saved in the storage

View File

@ -19,6 +19,70 @@ use ulid::Ulid;
use crate::{pagination::Page, repository_impl, Clock, Pagination};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BrowserSessionState {
Active,
Finished,
}
impl BrowserSessionState {
pub fn is_active(self) -> bool {
matches!(self, Self::Active)
}
pub fn is_finished(self) -> bool {
matches!(self, Self::Finished)
}
}
/// Filter parameters for listing browser sessions
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct BrowserSessionFilter<'a> {
user: Option<&'a User>,
state: Option<BrowserSessionState>,
}
impl<'a> BrowserSessionFilter<'a> {
/// Create a new [`BrowserSessionFilter`] with default values
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Set the user who owns the browser sessions
#[must_use]
pub fn for_user(mut self, user: &'a User) -> Self {
self.user = Some(user);
self
}
/// Get the user filter
#[must_use]
pub fn user(&self) -> Option<&User> {
self.user
}
/// Only return active browser sessions
#[must_use]
pub fn active_only(mut self) -> Self {
self.state = Some(BrowserSessionState::Active);
self
}
/// Only return finished browser sessions
#[must_use]
pub fn finished_only(mut self) -> Self {
self.state = Some(BrowserSessionState::Finished);
self
}
/// Get the state filter
#[must_use]
pub fn state(&self) -> Option<BrowserSessionState> {
self.state
}
}
/// A [`BrowserSessionRepository`] helps interacting with [`BrowserSession`]
/// saved in the storage backend
#[async_trait]
@ -77,33 +141,22 @@ pub trait BrowserSessionRepository: Send + Sync {
user_session: BrowserSession,
) -> Result<BrowserSession, Self::Error>;
/// List active [`BrowserSession`] for a [`User`] with the given pagination
/// List [`BrowserSession`] with the given filter and pagination
///
/// # Parameters
///
/// * `user`: The user to list the sessions for
/// * `filter`: The filter to apply
/// * `pagination`: The pagination parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn list_active_paginated(
async fn list(
&mut self,
user: &User,
filter: BrowserSessionFilter<'_>,
pagination: Pagination,
) -> Result<Page<BrowserSession>, Self::Error>;
/// Count active [`BrowserSession`] for a [`User`]
///
/// # Parameters
///
/// * `user`: The user to count the sessions for
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn count_active(&mut self, user: &User) -> Result<usize, Self::Error>;
/// Authenticate a [`BrowserSession`] with the given [`Password`]
///
/// Returns the updated [`BrowserSession`]
@ -163,12 +216,12 @@ repository_impl!(BrowserSessionRepository:
clock: &dyn Clock,
user_session: BrowserSession,
) -> Result<BrowserSession, Self::Error>;
async fn list_active_paginated(
async fn list(
&mut self,
user: &User,
filter: BrowserSessionFilter<'_>,
pagination: Pagination,
) -> Result<Page<BrowserSession>, Self::Error>;
async fn count_active(&mut self, user: &User) -> Result<usize, Self::Error>;
async fn authenticate_with_password(
&mut self,