1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-20 12:02:22 +03:00

storage: simplify the paginated queries

This commit is contained in:
Quentin Gliech
2023-01-17 15:09:53 +01:00
parent 62be962c4e
commit 0d02864589
11 changed files with 207 additions and 178 deletions

View File

@@ -1,4 +1,4 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,74 +12,166 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! Utilities to manage paginated queries.
use sqlx::{Database, QueryBuilder};
use thiserror::Error;
use ulid::Ulid;
use uuid::Uuid;
/// An error returned when invalid pagination parameters are provided
#[derive(Debug, Error)]
#[error("Either 'first' or 'last' must be specified")]
pub struct InvalidPagination;
/// Add cursor-based pagination to a query, as used in paginated GraphQL
/// connections
pub fn generate_pagination<'a, DB>(
query: &mut QueryBuilder<'a, DB>,
id_field: &'static str,
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Pagination {
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(), InvalidPagination>
where
DB: Database,
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
{
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
// 1. Start from the greedy query: SELECT * FROM table
count: usize,
direction: PaginationDirection,
}
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
// clause
if let Some(after) = after {
query
.push(" AND ")
.push(id_field)
.push(" > ")
.push_bind(Uuid::from(after));
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PaginationDirection {
Forward,
Backward,
}
impl Pagination {
/// Creates a new [`Pagination`] from user-provided parameters.
///
/// # Errors
///
/// Either `first` or `last` must be provided, else this function will
/// return an [`InvalidPagination`] error.
pub const fn try_new(
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<Self, InvalidPagination> {
let (direction, count) = match (first, last) {
(Some(first), _) => (PaginationDirection::Forward, first),
(_, Some(last)) => (PaginationDirection::Backward, last),
(None, None) => return Err(InvalidPagination),
};
Ok(Self {
before,
after,
count,
direction,
})
}
// 3. If the before argument is provided, add `id < parsed_cursor` to the
// `WHERE` clause
if let Some(before) = before {
query
.push(" AND ")
.push(id_field)
.push(" < ")
.push_bind(Uuid::from(before));
/// Creates a [`Pagination`] which gets the first N items
pub const fn first(first: usize) -> Self {
Self {
before: None,
after: None,
count: first,
direction: PaginationDirection::Forward,
}
}
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to
// the query
if let Some(count) = first {
query
.push(" ORDER BY ")
.push(id_field)
.push(" ASC LIMIT ")
.push_bind((count + 1) as i64);
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1`
// to the query
} else if let Some(count) = last {
query
.push(" ORDER BY ")
.push(id_field)
.push(" DESC LIMIT ")
.push_bind((count + 1) as i64);
} else {
return Err(InvalidPagination);
/// Creates a [`Pagination`] which gets the last N items
pub const fn last(last: usize) -> Self {
Self {
before: None,
after: None,
count: last,
direction: PaginationDirection::Backward,
}
}
Ok(())
/// Get items before the given cursor
pub const fn before(mut self, id: Ulid) -> Self {
self.before = Some(id);
self
}
/// Get items after the given cursor
pub const fn after(mut self, id: Ulid) -> Self {
self.after = Some(id);
self
}
/// Add cursor-based pagination to a query, as used in paginated GraphQL
/// connections
fn generate_pagination<'a, DB>(&self, query: &mut QueryBuilder<'a, DB>, id_field: &'static str)
where
DB: Database,
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
{
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
// 1. Start from the greedy query: SELECT * FROM table
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
// clause
if let Some(after) = self.after {
query
.push(" AND ")
.push(id_field)
.push(" > ")
.push_bind(Uuid::from(after));
}
// 3. If the before argument is provided, add `id < parsed_cursor` to the
// `WHERE` clause
if let Some(before) = self.before {
query
.push(" AND ")
.push(id_field)
.push(" < ")
.push_bind(Uuid::from(before));
}
match self.direction {
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the
// query
PaginationDirection::Forward => {
query
.push(" ORDER BY ")
.push(id_field)
.push(" ASC LIMIT ")
.push_bind((self.count + 1) as i64);
}
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the
// query
PaginationDirection::Backward => {
query
.push(" ORDER BY ")
.push(id_field)
.push(" DESC LIMIT ")
.push_bind((self.count + 1) as i64);
}
};
}
/// Process a page returned by a paginated query
pub fn process<T>(&self, mut edges: Vec<T>) -> Page<T> {
let is_full = edges.len() == (self.count + 1);
if is_full {
edges.pop();
}
let (has_previous_page, has_next_page) = match self.direction {
PaginationDirection::Forward => (false, is_full),
PaginationDirection::Backward => {
// 6. If the last argument is provided, I reverse the order of the results
edges.reverse();
(is_full, false)
}
};
Page {
has_next_page,
has_previous_page,
edges,
}
}
}
pub struct Page<T> {
@@ -89,39 +181,6 @@ pub struct Page<T> {
}
impl<T> Page<T> {
/// Process a page returned by a paginated query
pub fn process(
mut edges: Vec<T>,
first: Option<usize>,
last: Option<usize>,
) -> Result<Self, InvalidPagination> {
let limit = match (first, last) {
(Some(count), _) | (_, Some(count)) => count,
_ => return Err(InvalidPagination),
};
let is_full = edges.len() == (limit + 1);
if is_full {
edges.pop();
}
let (has_previous_page, has_next_page) = if first.is_some() {
(false, is_full)
} else if last.is_some() {
// 6. If the last argument is provided, I reverse the order of the results
edges.reverse();
(is_full, false)
} else {
unreachable!()
};
Ok(Page {
has_next_page,
has_previous_page,
edges,
})
}
pub fn map<F, T2>(self, f: F) -> Page<T2>
where
F: FnMut(T) -> T2,
@@ -147,17 +206,13 @@ impl<T> Page<T> {
}
}
impl<T> Page<T> {}
/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination
/// to a query
pub trait QueryBuilderExt {
fn generate_pagination(
&mut self,
id_field: &'static str,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<&mut Self, InvalidPagination>;
/// Add cursor-based pagination to a query, as used in paginated GraphQL
/// connections
fn generate_pagination(&mut self, id_field: &'static str, pagination: &Pagination)
-> &mut Self;
}
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
@@ -169,12 +224,9 @@ where
fn generate_pagination(
&mut self,
id_field: &'static str,
before: Option<Ulid>,
after: Option<Ulid>,
first: Option<usize>,
last: Option<usize>,
) -> Result<&mut Self, InvalidPagination> {
generate_pagination(self, id_field, before, after, first, last)?;
Ok(self)
pagination: &Pagination,
) -> &mut Self {
pagination.generate_pagination(self, id_field);
self
}
}