diff --git a/Cargo.lock b/Cargo.lock index ceef08df..67ab0d4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2597,10 +2597,15 @@ name = "mas-graphql" version = "0.1.0" dependencies = [ "async-graphql", + "chrono", "mas-axum-utils", + "mas-data-model", + "mas-storage", + "serde", "sqlx", "tokio", "tokio-stream", + "ulid", ] [[package]] diff --git a/crates/graphql/Cargo.toml b/crates/graphql/Cargo.toml index ddab414b..55b4cf22 100644 --- a/crates/graphql/Cargo.toml +++ b/crates/graphql/Cargo.toml @@ -6,12 +6,17 @@ edition = "2021" license = "Apache-2.0" [dependencies] -async-graphql = "4.0.16" +async-graphql = { version = "4.0.16", features = ["chrono"] } +chrono = "0.4.22" +serde = { version = "1.0.147", features = ["derive"] } sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } tokio = { version = "1.21.2", features = ["time"] } tokio-stream = "0.1.11" +ulid = "1.0.0" mas-axum-utils = { path = "../axum-utils" } +mas-data-model = { path = "../data-model" } +mas-storage = { path = "../storage" } [features] native-roots = ["mas-axum-utils/native-roots"] diff --git a/crates/graphql/schema.graphql b/crates/graphql/schema.graphql index 20acef86..ab9dbe80 100644 --- a/crates/graphql/schema.graphql +++ b/crates/graphql/schema.graphql @@ -1,3 +1,22 @@ +type Authentication { + id: ID! + createdAt: DateTime! +} + + +type BrowserSession { + id: ID! + user: User! + lastAuthentication: Authentication + createdAt: DateTime! +} + +""" +Implement the DateTime scalar + +The input/output is a string in RFC3339 format. +""" +scalar DateTime @@ -9,11 +28,31 @@ type Mutation { hello: Boolean! } +""" +Information about pagination in a connection +""" +type PageInfo { + """ + When paginating backwards, are there more items? + """ + hasPreviousPage: Boolean! + """ + When paginating forwards, are there more items? + """ + hasNextPage: Boolean! + """ + When paginating backwards, the cursor to continue. + """ + startCursor: String + """ + When paginating forwards, the cursor to continue. + """ + endCursor: String +} + type Query { - """ - A simple property which uses the DB pool and the current session - """ - username: String + currentSession: BrowserSession + currentUser: User } @@ -24,6 +63,50 @@ type Subscription { integers(step: Int! = 1): Int! } +type User { + id: ID! + username: String! + primaryEmail: UserEmail + emails(after: String, before: String, first: Int, last: Int): UserEmailConnection! +} + +type UserEmail { + id: ID! + email: String! + createdAt: DateTime! + confirmedAt: DateTime +} + +type UserEmailConnection { + """ + Information to aid in pagination. + """ + pageInfo: PageInfo! + """ + A list of edges. + """ + edges: [UserEmailEdge!]! + """ + A list of nodes. + """ + nodes: [UserEmail!]! + totalCount: Int! +} + +""" +An edge in a connection. +""" +type UserEmailEdge { + """ + A cursor for use in pagination + """ + cursor: String! + """ + The item at the end of the edge + """ + node: UserEmail! +} + schema { query: Query mutation: Mutation diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 5831b891..dab25bfe 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -29,6 +29,10 @@ use mas_axum_utils::SessionInfo; use sqlx::PgPool; use tokio_stream::{Stream, StreamExt}; +use self::model::{BrowserSession, User}; + +mod model; + pub type Schema = async_graphql::Schema; pub type SchemaBuilder = async_graphql::SchemaBuilder; @@ -51,14 +55,25 @@ impl Query { #[async_graphql::Object] impl Query { - /// A simple property which uses the DB pool and the current session - async fn username(&self, ctx: &Context<'_>) -> Result, async_graphql::Error> { + async fn current_session( + &self, + ctx: &Context<'_>, + ) -> Result, async_graphql::Error> { let database = ctx.data::()?; let session_info = ctx.data::()?; let mut conn = database.acquire().await?; let session = session_info.load_session(&mut conn).await?; - Ok(session.map(|s| s.user.username)) + Ok(session.map(BrowserSession::from)) + } + + async fn current_user(&self, ctx: &Context<'_>) -> Result, async_graphql::Error> { + let database = ctx.data::()?; + let session_info = ctx.data::()?; + let mut conn = database.acquire().await?; + let session = session_info.load_session(&mut conn).await?; + + Ok(session.map(User::from)) } } diff --git a/crates/graphql/src/model.rs b/crates/graphql/src/model.rs new file mode 100644 index 00000000..cf464693 --- /dev/null +++ b/crates/graphql/src/model.rs @@ -0,0 +1,195 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_graphql::{ + connection::{query, Connection, Edge, OpaqueCursor}, + Context, Object, ID, +}; +use chrono::{DateTime, Utc}; +use mas_storage::PostgresqlBackend; +use serde::{Deserialize, Serialize}; +use sqlx::PgPool; +use ulid::Ulid; + +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)] +#[serde(rename = "snake_case")] +enum NodeType { + User, + UserEmail, + BrowserSession, +} + +#[derive(Serialize, Deserialize, PartialEq, Eq)] +struct NodeCursor(NodeType, Ulid); + +impl NodeCursor { + fn extract_for_type(&self, node_type: NodeType) -> Result { + if self.0 == node_type { + Ok(self.1) + } else { + Err(async_graphql::Error::new("invalid cursor")) + } + } +} + +type Cursor = OpaqueCursor; + +pub struct BrowserSession(mas_data_model::BrowserSession); + +impl From> for BrowserSession { + fn from(v: mas_data_model::BrowserSession) -> Self { + Self(v) + } +} + +#[Object] +impl BrowserSession { + async fn id(&self) -> ID { + ID(self.0.data.to_string()) + } + + async fn user(&self) -> User { + User(self.0.user.clone()) + } + + async fn last_authentication(&self) -> Option { + self.0.last_authentication.clone().map(Authentication) + } + + async fn created_at(&self) -> DateTime { + self.0.created_at + } +} + +pub struct User(mas_data_model::User); + +impl From> for User { + fn from(v: mas_data_model::User) -> Self { + Self(v) + } +} + +impl From> for User { + fn from(v: mas_data_model::BrowserSession) -> Self { + Self(v.user) + } +} + +#[Object] +impl User { + async fn id(&self) -> ID { + ID(self.0.data.to_string()) + } + + async fn username(&self) -> &str { + &self.0.username + } + + async fn primary_email(&self) -> Option { + self.0.primary_email.clone().map(UserEmail) + } + + async fn emails( + &self, + ctx: &Context<'_>, + after: Option, + before: Option, + first: Option, + last: Option, + ) -> Result, async_graphql::Error> { + let database = ctx.data::()?; + + query( + after, + before, + first, + last, + |after, before, first, last| async move { + let mut conn = database.acquire().await?; + let after_id = after + .map(|x: OpaqueCursor| x.extract_for_type(NodeType::UserEmail)) + .transpose()?; + let before_id = before + .map(|x: OpaqueCursor| x.extract_for_type(NodeType::UserEmail)) + .transpose()?; + + let (has_previous_page, has_next_page, edges) = + mas_storage::user::get_paginated_user_emails( + &mut conn, &self.0, before_id, after_id, first, last, + ) + .await?; + + let mut connection = Connection::with_additional_fields( + has_previous_page, + has_next_page, + UserEmailsPagination(self.0.clone()), + ); + connection.edges.extend(edges.into_iter().map(|u| { + Edge::new( + OpaqueCursor(NodeCursor(NodeType::UserEmail, u.data)), + UserEmail(u), + ) + })); + + Ok::<_, async_graphql::Error>(connection) + }, + ) + .await + } +} + +pub struct Authentication(mas_data_model::Authentication); + +#[Object] +impl Authentication { + async fn id(&self) -> ID { + ID(self.0.data.to_string()) + } + + async fn created_at(&self) -> DateTime { + self.0.created_at + } +} + +pub struct UserEmail(mas_data_model::UserEmail); + +#[Object] +impl UserEmail { + async fn id(&self) -> ID { + ID(self.0.data.to_string()) + } + + async fn email(&self) -> &str { + &self.0.email + } + + async fn created_at(&self) -> DateTime { + self.0.created_at + } + + async fn confirmed_at(&self) -> Option> { + self.0.confirmed_at + } +} + +pub struct UserEmailsPagination(mas_data_model::User); + +#[Object] +impl UserEmailsPagination { + async fn total_count(&self, ctx: &Context<'_>) -> Result { + let mut conn = ctx.data::()?.acquire().await?; + let count = mas_storage::user::count_user_emails(&mut conn, &self.0).await?; + Ok(count) + } +} diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 335387b5..cd53c73b 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -1578,6 +1578,26 @@ }, "query": "\n UPDATE oauth2_refresh_tokens\n SET consumed_at = $2\n WHERE oauth2_refresh_token_id = $1\n " }, + "89e0d338348588831a7a810763a1901073f7a7cb81d51c18bb987a5be10c1202": { + "describe": { + "columns": [ + { + "name": "count", + "ordinal": 0, + "type_info": "Int8" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT COUNT(*)\n FROM user_emails ue\n WHERE ue.user_id = $1\n " + }, "99f5f9eb0adc5ec120ed8194cbf6a8545155bef09e6d94d92fb67fd1b14d4f28": { "describe": { "columns": [], diff --git a/crates/storage/src/user.rs b/crates/storage/src/user.rs index 63282ca5..e65a3970 100644 --- a/crates/storage/src/user.rs +++ b/crates/storage/src/user.rs @@ -23,7 +23,7 @@ use mas_data_model::{ }; use password_hash::{PasswordHash, PasswordHasher, SaltString}; use rand::{CryptoRng, Rng}; -use sqlx::{Acquire, PgExecutor, Postgres, Transaction}; +use sqlx::{postgres::PgArguments, Acquire, Arguments, PgExecutor, Postgres, Transaction}; use thiserror::Error; use tokio::task; use tracing::{info_span, Instrument}; @@ -590,7 +590,7 @@ pub async fn username_exists( .await } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, sqlx::FromRow)] struct UserEmailLookup { user_email_id: Uuid, user_email: String, @@ -641,6 +641,126 @@ pub async fn get_user_emails( Ok(res.into_iter().map(Into::into).collect()) } +#[tracing::instrument( + skip_all, + fields(user.id = %user.data, user.username = user.username), + err(Display), +)] +pub async fn count_user_emails( + executor: impl PgExecutor<'_>, + user: &User, +) -> Result { + let res = sqlx::query_scalar!( + r#" + SELECT COUNT(*) + FROM user_emails ue + WHERE ue.user_id = $1 + "#, + Uuid::from(user.data), + ) + .fetch_one(executor) + .instrument(info_span!("Count user emails")) + .await?; + + Ok(res.unwrap_or_default()) +} + +#[tracing::instrument( + skip_all, + fields( + user.id = %user.data, + user.username = user.username, + ), + err(Display), +)] +pub async fn get_paginated_user_emails( + executor: impl PgExecutor<'_>, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, +) -> Result<(bool, bool, Vec>), anyhow::Error> { + // ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564 + // 1. Start from the greedy query: SELECT * FROM table + let mut query = String::from( + r#" + SELECT + ue.user_email_id, + ue.email AS "user_email", + ue.created_at AS "user_email_created_at", + ue.confirmed_at AS "user_email_confirmed_at" + FROM user_emails ue + "#, + ); + + let mut arguments = PgArguments::default(); + + query += " WHERE ue.user_id = "; + arguments.add(Uuid::from(user.data)); + arguments.format_placeholder(&mut query)?; + + // 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE` + // clause + if let Some(after) = after { + query += " AND ue.user_email_id > "; + arguments.add(Uuid::from(after)); + arguments.format_placeholder(&mut query)?; + } + + // 3. If the before argument is provided, add `id < parsed_cursor` to the + // `WHERE` clause + if let Some(before) = before { + query += " AND ue.user_email_id < "; + arguments.add(Uuid::from(before)); + arguments.format_placeholder(&mut query)?; + } + + // 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to + // the query + let limit = if let Some(count) = first { + query += " ORDER BY ue.user_email_id ASC LIMIT "; + arguments.add((count + 1) as i64); + arguments.format_placeholder(&mut query)?; + count + // 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` + // to the query + } else if let Some(count) = last { + query += " ORDER BY ue.user_email_id DESC LIMIT "; + arguments.add((count + 1) as i64); + arguments.format_placeholder(&mut query)?; + count + } else { + bail!("Either 'first' or 'last' must be specified"); + }; + + let mut res: Vec = sqlx::query_as_with(&query, arguments) + .fetch_all(executor) + .instrument(info_span!("Fetch paginated user emails", query = query)) + .await?; + + let is_full = res.len() == (limit + 1); + if is_full { + res.pop(); + } + + let (has_previous_page, has_next_page) = if first.is_some() { + (false, is_full) + } else if last.is_some() { + // 5. If the last argument is provided, I reverse the order of the results + res.reverse(); + (is_full, false) + } else { + unreachable!() + }; + + Ok(( + has_previous_page, + has_next_page, + res.into_iter().map(Into::into).collect(), + )) +} + #[tracing::instrument( skip_all, fields(