You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
storage: simplify the paginated queries
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -12,74 +12,166 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Utilities to manage paginated queries.
|
||||
|
||||
use sqlx::{Database, QueryBuilder};
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// An error returned when invalid pagination parameters are provided
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Either 'first' or 'last' must be specified")]
|
||||
pub struct InvalidPagination;
|
||||
|
||||
/// Add cursor-based pagination to a query, as used in paginated GraphQL
|
||||
/// connections
|
||||
pub fn generate_pagination<'a, DB>(
|
||||
query: &mut QueryBuilder<'a, DB>,
|
||||
id_field: &'static str,
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct Pagination {
|
||||
before: Option<Ulid>,
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<(), InvalidPagination>
|
||||
where
|
||||
DB: Database,
|
||||
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
{
|
||||
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
|
||||
// 1. Start from the greedy query: SELECT * FROM table
|
||||
count: usize,
|
||||
direction: PaginationDirection,
|
||||
}
|
||||
|
||||
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
|
||||
// clause
|
||||
if let Some(after) = after {
|
||||
query
|
||||
.push(" AND ")
|
||||
.push(id_field)
|
||||
.push(" > ")
|
||||
.push_bind(Uuid::from(after));
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum PaginationDirection {
|
||||
Forward,
|
||||
Backward,
|
||||
}
|
||||
|
||||
impl Pagination {
|
||||
/// Creates a new [`Pagination`] from user-provided parameters.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Either `first` or `last` must be provided, else this function will
|
||||
/// return an [`InvalidPagination`] error.
|
||||
pub const fn try_new(
|
||||
before: Option<Ulid>,
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<Self, InvalidPagination> {
|
||||
let (direction, count) = match (first, last) {
|
||||
(Some(first), _) => (PaginationDirection::Forward, first),
|
||||
(_, Some(last)) => (PaginationDirection::Backward, last),
|
||||
(None, None) => return Err(InvalidPagination),
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
before,
|
||||
after,
|
||||
count,
|
||||
direction,
|
||||
})
|
||||
}
|
||||
|
||||
// 3. If the before argument is provided, add `id < parsed_cursor` to the
|
||||
// `WHERE` clause
|
||||
if let Some(before) = before {
|
||||
query
|
||||
.push(" AND ")
|
||||
.push(id_field)
|
||||
.push(" < ")
|
||||
.push_bind(Uuid::from(before));
|
||||
/// Creates a [`Pagination`] which gets the first N items
|
||||
pub const fn first(first: usize) -> Self {
|
||||
Self {
|
||||
before: None,
|
||||
after: None,
|
||||
count: first,
|
||||
direction: PaginationDirection::Forward,
|
||||
}
|
||||
}
|
||||
|
||||
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to
|
||||
// the query
|
||||
if let Some(count) = first {
|
||||
query
|
||||
.push(" ORDER BY ")
|
||||
.push(id_field)
|
||||
.push(" ASC LIMIT ")
|
||||
.push_bind((count + 1) as i64);
|
||||
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1`
|
||||
// to the query
|
||||
} else if let Some(count) = last {
|
||||
query
|
||||
.push(" ORDER BY ")
|
||||
.push(id_field)
|
||||
.push(" DESC LIMIT ")
|
||||
.push_bind((count + 1) as i64);
|
||||
} else {
|
||||
return Err(InvalidPagination);
|
||||
/// Creates a [`Pagination`] which gets the last N items
|
||||
pub const fn last(last: usize) -> Self {
|
||||
Self {
|
||||
before: None,
|
||||
after: None,
|
||||
count: last,
|
||||
direction: PaginationDirection::Backward,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
/// Get items before the given cursor
|
||||
pub const fn before(mut self, id: Ulid) -> Self {
|
||||
self.before = Some(id);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get items after the given cursor
|
||||
pub const fn after(mut self, id: Ulid) -> Self {
|
||||
self.after = Some(id);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add cursor-based pagination to a query, as used in paginated GraphQL
|
||||
/// connections
|
||||
fn generate_pagination<'a, DB>(&self, query: &mut QueryBuilder<'a, DB>, id_field: &'static str)
|
||||
where
|
||||
DB: Database,
|
||||
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||
{
|
||||
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
|
||||
// 1. Start from the greedy query: SELECT * FROM table
|
||||
|
||||
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
|
||||
// clause
|
||||
if let Some(after) = self.after {
|
||||
query
|
||||
.push(" AND ")
|
||||
.push(id_field)
|
||||
.push(" > ")
|
||||
.push_bind(Uuid::from(after));
|
||||
}
|
||||
|
||||
// 3. If the before argument is provided, add `id < parsed_cursor` to the
|
||||
// `WHERE` clause
|
||||
if let Some(before) = self.before {
|
||||
query
|
||||
.push(" AND ")
|
||||
.push(id_field)
|
||||
.push(" < ")
|
||||
.push_bind(Uuid::from(before));
|
||||
}
|
||||
|
||||
match self.direction {
|
||||
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the
|
||||
// query
|
||||
PaginationDirection::Forward => {
|
||||
query
|
||||
.push(" ORDER BY ")
|
||||
.push(id_field)
|
||||
.push(" ASC LIMIT ")
|
||||
.push_bind((self.count + 1) as i64);
|
||||
}
|
||||
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the
|
||||
// query
|
||||
PaginationDirection::Backward => {
|
||||
query
|
||||
.push(" ORDER BY ")
|
||||
.push(id_field)
|
||||
.push(" DESC LIMIT ")
|
||||
.push_bind((self.count + 1) as i64);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Process a page returned by a paginated query
|
||||
pub fn process<T>(&self, mut edges: Vec<T>) -> Page<T> {
|
||||
let is_full = edges.len() == (self.count + 1);
|
||||
if is_full {
|
||||
edges.pop();
|
||||
}
|
||||
|
||||
let (has_previous_page, has_next_page) = match self.direction {
|
||||
PaginationDirection::Forward => (false, is_full),
|
||||
PaginationDirection::Backward => {
|
||||
// 6. If the last argument is provided, I reverse the order of the results
|
||||
edges.reverse();
|
||||
(is_full, false)
|
||||
}
|
||||
};
|
||||
|
||||
Page {
|
||||
has_next_page,
|
||||
has_previous_page,
|
||||
edges,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Page<T> {
|
||||
@@ -89,39 +181,6 @@ pub struct Page<T> {
|
||||
}
|
||||
|
||||
impl<T> Page<T> {
|
||||
/// Process a page returned by a paginated query
|
||||
pub fn process(
|
||||
mut edges: Vec<T>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<Self, InvalidPagination> {
|
||||
let limit = match (first, last) {
|
||||
(Some(count), _) | (_, Some(count)) => count,
|
||||
_ => return Err(InvalidPagination),
|
||||
};
|
||||
|
||||
let is_full = edges.len() == (limit + 1);
|
||||
if is_full {
|
||||
edges.pop();
|
||||
}
|
||||
|
||||
let (has_previous_page, has_next_page) = if first.is_some() {
|
||||
(false, is_full)
|
||||
} else if last.is_some() {
|
||||
// 6. If the last argument is provided, I reverse the order of the results
|
||||
edges.reverse();
|
||||
(is_full, false)
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
Ok(Page {
|
||||
has_next_page,
|
||||
has_previous_page,
|
||||
edges,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn map<F, T2>(self, f: F) -> Page<T2>
|
||||
where
|
||||
F: FnMut(T) -> T2,
|
||||
@@ -147,17 +206,13 @@ impl<T> Page<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Page<T> {}
|
||||
|
||||
/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination
|
||||
/// to a query
|
||||
pub trait QueryBuilderExt {
|
||||
fn generate_pagination(
|
||||
&mut self,
|
||||
id_field: &'static str,
|
||||
before: Option<Ulid>,
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<&mut Self, InvalidPagination>;
|
||||
/// Add cursor-based pagination to a query, as used in paginated GraphQL
|
||||
/// connections
|
||||
fn generate_pagination(&mut self, id_field: &'static str, pagination: &Pagination)
|
||||
-> &mut Self;
|
||||
}
|
||||
|
||||
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
|
||||
@@ -169,12 +224,9 @@ where
|
||||
fn generate_pagination(
|
||||
&mut self,
|
||||
id_field: &'static str,
|
||||
before: Option<Ulid>,
|
||||
after: Option<Ulid>,
|
||||
first: Option<usize>,
|
||||
last: Option<usize>,
|
||||
) -> Result<&mut Self, InvalidPagination> {
|
||||
generate_pagination(self, id_field, before, after, first, last)?;
|
||||
Ok(self)
|
||||
pagination: &Pagination,
|
||||
) -> &mut Self {
|
||||
pagination.generate_pagination(self, id_field);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user