You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
storage: simplify the paginated queries
This commit is contained in:
@ -34,7 +34,7 @@ use mas_storage::{
|
|||||||
oauth2::OAuth2ClientRepository,
|
oauth2::OAuth2ClientRepository,
|
||||||
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
|
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
|
||||||
user::{BrowserSessionRepository, UserEmailRepository},
|
user::{BrowserSessionRepository, UserEmailRepository},
|
||||||
PgRepository, Repository,
|
Pagination, PgRepository, Repository,
|
||||||
};
|
};
|
||||||
use model::CreationEvent;
|
use model::CreationEvent;
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
@ -228,10 +228,11 @@ impl RootQuery {
|
|||||||
x.extract_for_type(NodeType::UpstreamOAuth2Provider)
|
x.extract_for_type(NodeType::UpstreamOAuth2Provider)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||||
|
|
||||||
let page = repo
|
let page = repo
|
||||||
.upstream_oauth_provider()
|
.upstream_oauth_provider()
|
||||||
.list_paginated(before_id, after_id, first, last)
|
.list_paginated(&pagination)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
||||||
|
@ -22,7 +22,7 @@ use mas_storage::{
|
|||||||
oauth2::OAuth2SessionRepository,
|
oauth2::OAuth2SessionRepository,
|
||||||
upstream_oauth2::UpstreamOAuthLinkRepository,
|
upstream_oauth2::UpstreamOAuthLinkRepository,
|
||||||
user::{BrowserSessionRepository, UserEmailRepository},
|
user::{BrowserSessionRepository, UserEmailRepository},
|
||||||
PgRepository, Repository,
|
Pagination, PgRepository, Repository,
|
||||||
};
|
};
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
|
||||||
@ -95,10 +95,11 @@ impl User {
|
|||||||
let before_id = before
|
let before_id = before
|
||||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin))
|
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::CompatSsoLogin))
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||||
|
|
||||||
let page = repo
|
let page = repo
|
||||||
.compat_sso_login()
|
.compat_sso_login()
|
||||||
.list_paginated(&self.0, before_id, after_id, first, last)
|
.list_paginated(&self.0, &pagination)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
||||||
@ -141,10 +142,11 @@ impl User {
|
|||||||
let before_id = before
|
let before_id = before
|
||||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession))
|
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::BrowserSession))
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||||
|
|
||||||
let page = repo
|
let page = repo
|
||||||
.browser_session()
|
.browser_session()
|
||||||
.list_active_paginated(&self.0, before_id, after_id, first, last)
|
.list_active_paginated(&self.0, &pagination)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
||||||
@ -187,10 +189,11 @@ impl User {
|
|||||||
let before_id = before
|
let before_id = before
|
||||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail))
|
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::UserEmail))
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||||
|
|
||||||
let page = repo
|
let page = repo
|
||||||
.user_email()
|
.user_email()
|
||||||
.list_paginated(&self.0, before_id, after_id, first, last)
|
.list_paginated(&self.0, &pagination)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut connection = Connection::with_additional_fields(
|
let mut connection = Connection::with_additional_fields(
|
||||||
@ -237,10 +240,11 @@ impl User {
|
|||||||
let before_id = before
|
let before_id = before
|
||||||
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session))
|
.map(|x: OpaqueCursor<NodeCursor>| x.extract_for_type(NodeType::OAuth2Session))
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||||
|
|
||||||
let page = repo
|
let page = repo
|
||||||
.oauth2_session()
|
.oauth2_session()
|
||||||
.list_paginated(&self.0, before_id, after_id, first, last)
|
.list_paginated(&self.0, &pagination)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
||||||
@ -287,10 +291,11 @@ impl User {
|
|||||||
x.extract_for_type(NodeType::UpstreamOAuth2Link)
|
x.extract_for_type(NodeType::UpstreamOAuth2Link)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
|
||||||
|
|
||||||
let page = repo
|
let page = repo
|
||||||
.upstream_oauth_link()
|
.upstream_oauth_link()
|
||||||
.list_paginated(&self.0, before_id, after_id, first, last)
|
.list_paginated(&self.0, &pagination)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
|
||||||
|
@ -24,7 +24,7 @@ use uuid::Uuid;
|
|||||||
use crate::{
|
use crate::{
|
||||||
pagination::{Page, QueryBuilderExt},
|
pagination::{Page, QueryBuilderExt},
|
||||||
tracing::ExecuteExt,
|
tracing::ExecuteExt,
|
||||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@ -68,10 +68,7 @@ pub trait CompatSsoLoginRepository: Send + Sync {
|
|||||||
async fn list_paginated(
|
async fn list_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
user: &User,
|
user: &User,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<CompatSsoLogin>, Self::Error>;
|
) -> Result<Page<CompatSsoLogin>, Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -354,10 +351,7 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
|
|||||||
async fn list_paginated(
|
async fn list_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
user: &User,
|
user: &User,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<CompatSsoLogin>, Self::Error> {
|
) -> Result<Page<CompatSsoLogin>, Self::Error> {
|
||||||
let mut query = QueryBuilder::new(
|
let mut query = QueryBuilder::new(
|
||||||
r#"
|
r#"
|
||||||
@ -377,7 +371,7 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
|
|||||||
query
|
query
|
||||||
.push(" WHERE user_id = ")
|
.push(" WHERE user_id = ")
|
||||||
.push_bind(Uuid::from(user.id))
|
.push_bind(Uuid::from(user.id))
|
||||||
.generate_pagination("cl.compat_sso_login_id", before, after, first, last)?;
|
.generate_pagination("cl.compat_sso_login_id", &pagination);
|
||||||
|
|
||||||
let edges: Vec<CompatSsoLoginLookup> = query
|
let edges: Vec<CompatSsoLoginLookup> = query
|
||||||
.build_query_as()
|
.build_query_as()
|
||||||
@ -385,7 +379,9 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
|
|||||||
.fetch_all(&mut *self.conn)
|
.fetch_all(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let page = Page::process(edges, first, last)?.try_map(CompatSsoLogin::try_from)?;
|
let page = pagination
|
||||||
|
.process(edges)
|
||||||
|
.try_map(CompatSsoLogin::try_from)?;
|
||||||
Ok(page)
|
Ok(page)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -253,7 +253,10 @@ pub(crate) mod tracing;
|
|||||||
pub mod upstream_oauth2;
|
pub mod upstream_oauth2;
|
||||||
pub mod user;
|
pub mod user;
|
||||||
|
|
||||||
pub use self::repository::{PgRepository, Repository};
|
pub use self::{
|
||||||
|
pagination::Pagination,
|
||||||
|
repository::{PgRepository, Repository},
|
||||||
|
};
|
||||||
|
|
||||||
/// Embedded migrations, allowing them to run on startup
|
/// Embedded migrations, allowing them to run on startup
|
||||||
pub static MIGRATOR: Migrator = sqlx::migrate!();
|
pub static MIGRATOR: Migrator = sqlx::migrate!();
|
||||||
|
@ -23,7 +23,7 @@ use uuid::Uuid;
|
|||||||
use crate::{
|
use crate::{
|
||||||
pagination::{Page, QueryBuilderExt},
|
pagination::{Page, QueryBuilderExt},
|
||||||
tracing::ExecuteExt,
|
tracing::ExecuteExt,
|
||||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@ -45,10 +45,7 @@ pub trait OAuth2SessionRepository: Send + Sync {
|
|||||||
async fn list_paginated(
|
async fn list_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
user: &User,
|
user: &User,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<Session>, Self::Error>;
|
) -> Result<Page<Session>, Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,10 +240,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
|||||||
async fn list_paginated(
|
async fn list_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
user: &User,
|
user: &User,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<Session>, Self::Error> {
|
) -> Result<Page<Session>, Self::Error> {
|
||||||
let mut query = QueryBuilder::new(
|
let mut query = QueryBuilder::new(
|
||||||
r#"
|
r#"
|
||||||
@ -263,7 +257,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
|||||||
query
|
query
|
||||||
.push(" WHERE us.user_id = ")
|
.push(" WHERE us.user_id = ")
|
||||||
.push_bind(Uuid::from(user.id))
|
.push_bind(Uuid::from(user.id))
|
||||||
.generate_pagination("oauth2_session_id", before, after, first, last)?;
|
.generate_pagination("oauth2_session_id", pagination);
|
||||||
|
|
||||||
let edges: Vec<OAuthSessionLookup> = query
|
let edges: Vec<OAuthSessionLookup> = query
|
||||||
.build_query_as()
|
.build_query_as()
|
||||||
@ -271,7 +265,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
|
|||||||
.fetch_all(&mut *self.conn)
|
.fetch_all(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let page = Page::process(edges, first, last)?.try_map(Session::try_from)?;
|
let page = pagination.process(edges).try_map(Session::try_from)?;
|
||||||
Ok(page)
|
Ok(page)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
// Copyright 2022, 2023 The Matrix.org Foundation C.I.C.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@ -12,74 +12,166 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
//! Utilities to manage paginated queries.
|
||||||
|
|
||||||
use sqlx::{Database, QueryBuilder};
|
use sqlx::{Database, QueryBuilder};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
/// An error returned when invalid pagination parameters are provided
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
#[error("Either 'first' or 'last' must be specified")]
|
#[error("Either 'first' or 'last' must be specified")]
|
||||||
pub struct InvalidPagination;
|
pub struct InvalidPagination;
|
||||||
|
|
||||||
/// Add cursor-based pagination to a query, as used in paginated GraphQL
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
/// connections
|
pub struct Pagination {
|
||||||
pub fn generate_pagination<'a, DB>(
|
|
||||||
query: &mut QueryBuilder<'a, DB>,
|
|
||||||
id_field: &'static str,
|
|
||||||
before: Option<Ulid>,
|
before: Option<Ulid>,
|
||||||
after: Option<Ulid>,
|
after: Option<Ulid>,
|
||||||
first: Option<usize>,
|
count: usize,
|
||||||
last: Option<usize>,
|
direction: PaginationDirection,
|
||||||
) -> Result<(), InvalidPagination>
|
}
|
||||||
where
|
|
||||||
DB: Database,
|
|
||||||
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
|
||||||
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
|
||||||
{
|
|
||||||
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
|
|
||||||
// 1. Start from the greedy query: SELECT * FROM table
|
|
||||||
|
|
||||||
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
// clause
|
enum PaginationDirection {
|
||||||
if let Some(after) = after {
|
Forward,
|
||||||
query
|
Backward,
|
||||||
.push(" AND ")
|
}
|
||||||
.push(id_field)
|
|
||||||
.push(" > ")
|
impl Pagination {
|
||||||
.push_bind(Uuid::from(after));
|
/// Creates a new [`Pagination`] from user-provided parameters.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Either `first` or `last` must be provided, else this function will
|
||||||
|
/// return an [`InvalidPagination`] error.
|
||||||
|
pub const fn try_new(
|
||||||
|
before: Option<Ulid>,
|
||||||
|
after: Option<Ulid>,
|
||||||
|
first: Option<usize>,
|
||||||
|
last: Option<usize>,
|
||||||
|
) -> Result<Self, InvalidPagination> {
|
||||||
|
let (direction, count) = match (first, last) {
|
||||||
|
(Some(first), _) => (PaginationDirection::Forward, first),
|
||||||
|
(_, Some(last)) => (PaginationDirection::Backward, last),
|
||||||
|
(None, None) => return Err(InvalidPagination),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
before,
|
||||||
|
after,
|
||||||
|
count,
|
||||||
|
direction,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. If the before argument is provided, add `id < parsed_cursor` to the
|
/// Creates a [`Pagination`] which gets the first N items
|
||||||
// `WHERE` clause
|
pub const fn first(first: usize) -> Self {
|
||||||
if let Some(before) = before {
|
Self {
|
||||||
query
|
before: None,
|
||||||
.push(" AND ")
|
after: None,
|
||||||
.push(id_field)
|
count: first,
|
||||||
.push(" < ")
|
direction: PaginationDirection::Forward,
|
||||||
.push_bind(Uuid::from(before));
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to
|
/// Creates a [`Pagination`] which gets the last N items
|
||||||
// the query
|
pub const fn last(last: usize) -> Self {
|
||||||
if let Some(count) = first {
|
Self {
|
||||||
query
|
before: None,
|
||||||
.push(" ORDER BY ")
|
after: None,
|
||||||
.push(id_field)
|
count: last,
|
||||||
.push(" ASC LIMIT ")
|
direction: PaginationDirection::Backward,
|
||||||
.push_bind((count + 1) as i64);
|
}
|
||||||
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1`
|
|
||||||
// to the query
|
|
||||||
} else if let Some(count) = last {
|
|
||||||
query
|
|
||||||
.push(" ORDER BY ")
|
|
||||||
.push(id_field)
|
|
||||||
.push(" DESC LIMIT ")
|
|
||||||
.push_bind((count + 1) as i64);
|
|
||||||
} else {
|
|
||||||
return Err(InvalidPagination);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
/// Get items before the given cursor
|
||||||
|
pub const fn before(mut self, id: Ulid) -> Self {
|
||||||
|
self.before = Some(id);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get items after the given cursor
|
||||||
|
pub const fn after(mut self, id: Ulid) -> Self {
|
||||||
|
self.after = Some(id);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add cursor-based pagination to a query, as used in paginated GraphQL
|
||||||
|
/// connections
|
||||||
|
fn generate_pagination<'a, DB>(&self, query: &mut QueryBuilder<'a, DB>, id_field: &'static str)
|
||||||
|
where
|
||||||
|
DB: Database,
|
||||||
|
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||||
|
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
|
||||||
|
{
|
||||||
|
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
|
||||||
|
// 1. Start from the greedy query: SELECT * FROM table
|
||||||
|
|
||||||
|
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
|
||||||
|
// clause
|
||||||
|
if let Some(after) = self.after {
|
||||||
|
query
|
||||||
|
.push(" AND ")
|
||||||
|
.push(id_field)
|
||||||
|
.push(" > ")
|
||||||
|
.push_bind(Uuid::from(after));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. If the before argument is provided, add `id < parsed_cursor` to the
|
||||||
|
// `WHERE` clause
|
||||||
|
if let Some(before) = self.before {
|
||||||
|
query
|
||||||
|
.push(" AND ")
|
||||||
|
.push(id_field)
|
||||||
|
.push(" < ")
|
||||||
|
.push_bind(Uuid::from(before));
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.direction {
|
||||||
|
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the
|
||||||
|
// query
|
||||||
|
PaginationDirection::Forward => {
|
||||||
|
query
|
||||||
|
.push(" ORDER BY ")
|
||||||
|
.push(id_field)
|
||||||
|
.push(" ASC LIMIT ")
|
||||||
|
.push_bind((self.count + 1) as i64);
|
||||||
|
}
|
||||||
|
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the
|
||||||
|
// query
|
||||||
|
PaginationDirection::Backward => {
|
||||||
|
query
|
||||||
|
.push(" ORDER BY ")
|
||||||
|
.push(id_field)
|
||||||
|
.push(" DESC LIMIT ")
|
||||||
|
.push_bind((self.count + 1) as i64);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a page returned by a paginated query
|
||||||
|
pub fn process<T>(&self, mut edges: Vec<T>) -> Page<T> {
|
||||||
|
let is_full = edges.len() == (self.count + 1);
|
||||||
|
if is_full {
|
||||||
|
edges.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
let (has_previous_page, has_next_page) = match self.direction {
|
||||||
|
PaginationDirection::Forward => (false, is_full),
|
||||||
|
PaginationDirection::Backward => {
|
||||||
|
// 6. If the last argument is provided, I reverse the order of the results
|
||||||
|
edges.reverse();
|
||||||
|
(is_full, false)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Page {
|
||||||
|
has_next_page,
|
||||||
|
has_previous_page,
|
||||||
|
edges,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Page<T> {
|
pub struct Page<T> {
|
||||||
@ -89,39 +181,6 @@ pub struct Page<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Page<T> {
|
impl<T> Page<T> {
|
||||||
/// Process a page returned by a paginated query
|
|
||||||
pub fn process(
|
|
||||||
mut edges: Vec<T>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Self, InvalidPagination> {
|
|
||||||
let limit = match (first, last) {
|
|
||||||
(Some(count), _) | (_, Some(count)) => count,
|
|
||||||
_ => return Err(InvalidPagination),
|
|
||||||
};
|
|
||||||
|
|
||||||
let is_full = edges.len() == (limit + 1);
|
|
||||||
if is_full {
|
|
||||||
edges.pop();
|
|
||||||
}
|
|
||||||
|
|
||||||
let (has_previous_page, has_next_page) = if first.is_some() {
|
|
||||||
(false, is_full)
|
|
||||||
} else if last.is_some() {
|
|
||||||
// 6. If the last argument is provided, I reverse the order of the results
|
|
||||||
edges.reverse();
|
|
||||||
(is_full, false)
|
|
||||||
} else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Page {
|
|
||||||
has_next_page,
|
|
||||||
has_previous_page,
|
|
||||||
edges,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn map<F, T2>(self, f: F) -> Page<T2>
|
pub fn map<F, T2>(self, f: F) -> Page<T2>
|
||||||
where
|
where
|
||||||
F: FnMut(T) -> T2,
|
F: FnMut(T) -> T2,
|
||||||
@ -147,17 +206,13 @@ impl<T> Page<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Page<T> {}
|
/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination
|
||||||
|
/// to a query
|
||||||
pub trait QueryBuilderExt {
|
pub trait QueryBuilderExt {
|
||||||
fn generate_pagination(
|
/// Add cursor-based pagination to a query, as used in paginated GraphQL
|
||||||
&mut self,
|
/// connections
|
||||||
id_field: &'static str,
|
fn generate_pagination(&mut self, id_field: &'static str, pagination: &Pagination)
|
||||||
before: Option<Ulid>,
|
-> &mut Self;
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<&mut Self, InvalidPagination>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
|
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
|
||||||
@ -169,12 +224,9 @@ where
|
|||||||
fn generate_pagination(
|
fn generate_pagination(
|
||||||
&mut self,
|
&mut self,
|
||||||
id_field: &'static str,
|
id_field: &'static str,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
) -> &mut Self {
|
||||||
first: Option<usize>,
|
pagination.generate_pagination(self, id_field);
|
||||||
last: Option<usize>,
|
self
|
||||||
) -> Result<&mut Self, InvalidPagination> {
|
|
||||||
generate_pagination(self, id_field, before, after, first, last)?;
|
|
||||||
Ok(self)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ use uuid::Uuid;
|
|||||||
use crate::{
|
use crate::{
|
||||||
pagination::{Page, QueryBuilderExt},
|
pagination::{Page, QueryBuilderExt},
|
||||||
tracing::ExecuteExt,
|
tracing::ExecuteExt,
|
||||||
Clock, DatabaseError, LookupResultExt,
|
Clock, DatabaseError, LookupResultExt, Pagination,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@ -60,10 +60,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
|
|||||||
async fn list_paginated(
|
async fn list_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
user: &User,
|
user: &User,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<UpstreamOAuthLink>, Self::Error>;
|
) -> Result<Page<UpstreamOAuthLink>, Self::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -275,10 +272,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
|
|||||||
async fn list_paginated(
|
async fn list_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
user: &User,
|
user: &User,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<UpstreamOAuthLink>, Self::Error> {
|
) -> Result<Page<UpstreamOAuthLink>, Self::Error> {
|
||||||
let mut query = QueryBuilder::new(
|
let mut query = QueryBuilder::new(
|
||||||
r#"
|
r#"
|
||||||
@ -295,7 +289,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
|
|||||||
query
|
query
|
||||||
.push(" WHERE user_id = ")
|
.push(" WHERE user_id = ")
|
||||||
.push_bind(Uuid::from(user.id))
|
.push_bind(Uuid::from(user.id))
|
||||||
.generate_pagination("upstream_oauth_link_id", before, after, first, last)?;
|
.generate_pagination("upstream_oauth_link_id", pagination);
|
||||||
|
|
||||||
let edges: Vec<LinkLookup> = query
|
let edges: Vec<LinkLookup> = query
|
||||||
.build_query_as()
|
.build_query_as()
|
||||||
@ -303,7 +297,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
|
|||||||
.fetch_all(&mut *self.conn)
|
.fetch_all(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let page = Page::process(edges, first, last)?.map(UpstreamOAuthLink::from);
|
let page = pagination.process(edges).map(UpstreamOAuthLink::from);
|
||||||
Ok(page)
|
Ok(page)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,7 @@ mod tests {
|
|||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{user::UserRepository, Clock, PgRepository, Repository};
|
use crate::{user::UserRepository, Clock, Pagination, PgRepository, Repository};
|
||||||
|
|
||||||
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
#[sqlx::test(migrator = "crate::MIGRATOR")]
|
||||||
async fn test_repository(pool: PgPool) -> Result<(), Box<dyn std::error::Error>> {
|
async fn test_repository(pool: PgPool) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
@ -144,7 +144,7 @@ mod tests {
|
|||||||
|
|
||||||
let links = repo
|
let links = repo
|
||||||
.upstream_oauth_link()
|
.upstream_oauth_link()
|
||||||
.list_paginated(&user, None, None, Some(10), None)
|
.list_paginated(&user, &Pagination::first(10))
|
||||||
.await?;
|
.await?;
|
||||||
assert!(!links.has_previous_page);
|
assert!(!links.has_previous_page);
|
||||||
assert!(!links.has_next_page);
|
assert!(!links.has_next_page);
|
||||||
|
@ -25,7 +25,7 @@ use uuid::Uuid;
|
|||||||
use crate::{
|
use crate::{
|
||||||
pagination::{Page, QueryBuilderExt},
|
pagination::{Page, QueryBuilderExt},
|
||||||
tracing::ExecuteExt,
|
tracing::ExecuteExt,
|
||||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@ -52,10 +52,7 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
|
|||||||
/// Get a paginated list of upstream OAuth providers
|
/// Get a paginated list of upstream OAuth providers
|
||||||
async fn list_paginated(
|
async fn list_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
|
) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
|
||||||
|
|
||||||
/// Get all upstream OAuth providers
|
/// Get all upstream OAuth providers
|
||||||
@ -243,10 +240,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
)]
|
)]
|
||||||
async fn list_paginated(
|
async fn list_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
|
) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
|
||||||
let mut query = QueryBuilder::new(
|
let mut query = QueryBuilder::new(
|
||||||
r#"
|
r#"
|
||||||
@ -264,7 +258,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
"#,
|
"#,
|
||||||
);
|
);
|
||||||
|
|
||||||
query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?;
|
query.generate_pagination("upstream_oauth_provider_id", pagination);
|
||||||
|
|
||||||
let edges: Vec<ProviderLookup> = query
|
let edges: Vec<ProviderLookup> = query
|
||||||
.build_query_as()
|
.build_query_as()
|
||||||
@ -272,7 +266,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
|||||||
.fetch_all(&mut *self.conn)
|
.fetch_all(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let page = Page::process(edges, first, last)?.try_map(TryInto::try_into)?;
|
let page = pagination.process(edges).try_map(TryInto::try_into)?;
|
||||||
Ok(page)
|
Ok(page)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ use uuid::Uuid;
|
|||||||
use crate::{
|
use crate::{
|
||||||
pagination::{Page, QueryBuilderExt},
|
pagination::{Page, QueryBuilderExt},
|
||||||
tracing::ExecuteExt,
|
tracing::ExecuteExt,
|
||||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@ -39,10 +39,7 @@ pub trait UserEmailRepository: Send + Sync {
|
|||||||
async fn list_paginated(
|
async fn list_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
user: &User,
|
user: &User,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<UserEmail>, Self::Error>;
|
) -> Result<Page<UserEmail>, Self::Error>;
|
||||||
async fn count(&mut self, user: &User) -> Result<usize, Self::Error>;
|
async fn count(&mut self, user: &User) -> Result<usize, Self::Error>;
|
||||||
|
|
||||||
@ -289,10 +286,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> {
|
|||||||
async fn list_paginated(
|
async fn list_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
user: &User,
|
user: &User,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<UserEmail>, DatabaseError> {
|
) -> Result<Page<UserEmail>, DatabaseError> {
|
||||||
let mut query = QueryBuilder::new(
|
let mut query = QueryBuilder::new(
|
||||||
r#"
|
r#"
|
||||||
@ -308,7 +302,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> {
|
|||||||
query
|
query
|
||||||
.push(" WHERE user_id = ")
|
.push(" WHERE user_id = ")
|
||||||
.push_bind(Uuid::from(user.id))
|
.push_bind(Uuid::from(user.id))
|
||||||
.generate_pagination("ue.user_email_id", before, after, first, last)?;
|
.generate_pagination("ue.user_email_id", &pagination);
|
||||||
|
|
||||||
let edges: Vec<UserEmailLookup> = query
|
let edges: Vec<UserEmailLookup> = query
|
||||||
.build_query_as()
|
.build_query_as()
|
||||||
@ -316,7 +310,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> {
|
|||||||
.fetch_all(&mut *self.conn)
|
.fetch_all(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let page = Page::process(edges, first, last)?.map(UserEmail::from);
|
let page = pagination.process(edges).map(UserEmail::from);
|
||||||
Ok(page)
|
Ok(page)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ use uuid::Uuid;
|
|||||||
use crate::{
|
use crate::{
|
||||||
pagination::{Page, QueryBuilderExt},
|
pagination::{Page, QueryBuilderExt},
|
||||||
tracing::ExecuteExt,
|
tracing::ExecuteExt,
|
||||||
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Pagination,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@ -45,10 +45,7 @@ pub trait BrowserSessionRepository: Send + Sync {
|
|||||||
async fn list_active_paginated(
|
async fn list_active_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
user: &User,
|
user: &User,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<BrowserSession>, Self::Error>;
|
) -> Result<Page<BrowserSession>, Self::Error>;
|
||||||
async fn count_active(&mut self, user: &User) -> Result<usize, Self::Error>;
|
async fn count_active(&mut self, user: &User) -> Result<usize, Self::Error>;
|
||||||
|
|
||||||
@ -264,10 +261,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
|||||||
async fn list_active_paginated(
|
async fn list_active_paginated(
|
||||||
&mut self,
|
&mut self,
|
||||||
user: &User,
|
user: &User,
|
||||||
before: Option<Ulid>,
|
pagination: &Pagination,
|
||||||
after: Option<Ulid>,
|
|
||||||
first: Option<usize>,
|
|
||||||
last: Option<usize>,
|
|
||||||
) -> Result<Page<BrowserSession>, Self::Error> {
|
) -> Result<Page<BrowserSession>, Self::Error> {
|
||||||
// TODO: ordering of last authentication is wrong
|
// TODO: ordering of last authentication is wrong
|
||||||
let mut query = QueryBuilder::new(
|
let mut query = QueryBuilder::new(
|
||||||
@ -290,7 +284,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
|||||||
query
|
query
|
||||||
.push(" WHERE s.finished_at IS NULL AND s.user_id = ")
|
.push(" WHERE s.finished_at IS NULL AND s.user_id = ")
|
||||||
.push_bind(Uuid::from(user.id))
|
.push_bind(Uuid::from(user.id))
|
||||||
.generate_pagination("s.user_session_id", before, after, first, last)?;
|
.generate_pagination("s.user_session_id", pagination);
|
||||||
|
|
||||||
let edges: Vec<SessionLookup> = query
|
let edges: Vec<SessionLookup> = query
|
||||||
.build_query_as()
|
.build_query_as()
|
||||||
@ -298,7 +292,9 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
|
|||||||
.fetch_all(&mut *self.conn)
|
.fetch_all(&mut *self.conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let page = Page::process(edges, first, last)?.try_map(BrowserSession::try_from)?;
|
let page = pagination
|
||||||
|
.process(edges)
|
||||||
|
.try_map(BrowserSession::try_from)?;
|
||||||
Ok(page)
|
Ok(page)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user