From be765fe04f55b9bdb9983210f1daf8d8559449d9 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 20 Apr 2023 18:03:08 +0200 Subject: [PATCH] Setup GraphQL mutations to add and verify email addresses This refactors a bit how the connection to the repository is done in the graphql handler, so that we can properly commit transactions. --- Cargo.lock | 2 +- crates/cli/src/commands/server.rs | 2 +- crates/graphql/Cargo.toml | 4 +- crates/graphql/src/lib.rs | 271 +---------------- crates/graphql/src/model/compat_sessions.rs | 13 +- crates/graphql/src/model/oauth.rs | 21 +- crates/graphql/src/model/upstream_oauth.rs | 20 +- crates/graphql/src/model/users.rs | 41 ++- crates/graphql/src/mutations.rs | 169 +++++++++-- crates/graphql/src/query.rs | 281 ++++++++++++++++++ crates/graphql/src/state.rs | 41 +++ crates/handlers/src/graphql.rs | 50 +++- .../src/views/account/emails/verify.rs | 3 + frontend/schema.graphql | 19 ++ frontend/src/gql/graphql.ts | 31 ++ frontend/src/gql/schema.ts | 100 ++++++- 16 files changed, 746 insertions(+), 322 deletions(-) create mode 100644 crates/graphql/src/query.rs create mode 100644 crates/graphql/src/state.rs diff --git a/Cargo.lock b/Cargo.lock index 0aa63eab..df649085 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3252,11 +3252,11 @@ version = "0.1.0" dependencies = [ "anyhow", "async-graphql", + "async-trait", "chrono", "mas-data-model", "mas-storage", "oauth2-types", - "rand_chacha 0.3.1", "serde", "thiserror", "tokio", diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 90c08407..48db22c7 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -122,7 +122,7 @@ impl Options { watch_templates(&templates).await?; } - let graphql_schema = mas_handlers::graphql_schema(); + let graphql_schema = mas_handlers::graphql_schema(&pool); // Maximum 50 outgoing HTTP requests at a time let http_client_factory = HttpClientFactory::new(50); diff --git a/crates/graphql/Cargo.toml b/crates/graphql/Cargo.toml index 6823e104..cea731f9 100644 --- a/crates/graphql/Cargo.toml +++ b/crates/graphql/Cargo.toml @@ -8,14 +8,14 @@ license = "Apache-2.0" [dependencies] anyhow = "1.0.70" async-graphql = { version = "5.0.7", features = ["chrono", "url"] } +async-trait = "0.1.51" chrono = "0.4.24" serde = { version = "1.0.160", features = ["derive"] } -tokio = { version = "1.27.0", features = ["sync"] } thiserror = "1.0.40" +tokio = { version = "1.27.0", features = ["sync"] } tracing = "0.1.37" ulid = "1.0.0" url = "2.3.1" -rand_chacha = "0.3.1" oauth2-types = { path = "../oauth2-types" } mas-data-model = { path = "../data-model" } diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 90ce84e9..6699b95f 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,269 +26,26 @@ clippy::unused_async )] -use async_graphql::{ - connection::{query, Connection, Edge, OpaqueCursor}, - Context, Description, EmptyMutation, EmptySubscription, ID, -}; -use mas_storage::{ - oauth2::OAuth2ClientRepository, - upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository}, - user::{BrowserSessionRepository, UserEmailRepository}, - BoxRepository, Pagination, -}; -use model::CreationEvent; -use tokio::sync::Mutex; - -use self::model::{ - BrowserSession, Cursor, Node, NodeCursor, NodeType, OAuth2Client, UpstreamOAuth2Link, - UpstreamOAuth2Provider, User, UserEmail, -}; +use async_graphql::EmptySubscription; mod model; mod mutations; +mod query; +mod state; -pub type Schema = async_graphql::Schema; -pub type SchemaBuilder = async_graphql::SchemaBuilder; +pub use self::{ + model::{CreationEvent, Node}, + mutations::RootMutations, + query::RootQuery, + state::{BoxState, State}, +}; + +pub type Schema = async_graphql::Schema; +pub type SchemaBuilder = async_graphql::SchemaBuilder; #[must_use] pub fn schema_builder() -> SchemaBuilder { - async_graphql::Schema::build(RootQuery::new(), EmptyMutation, EmptySubscription) + async_graphql::Schema::build(RootQuery::new(), RootMutations::new(), EmptySubscription) .register_output_type::() .register_output_type::() } - -/// The query root of the GraphQL interface. -#[derive(Default, Description)] -pub struct RootQuery { - _private: (), -} - -impl RootQuery { - #[must_use] - pub fn new() -> Self { - Self::default() - } -} - -#[async_graphql::Object(use_type_description)] -impl RootQuery { - /// Get the current logged in browser session - async fn current_browser_session( - &self, - ctx: &Context<'_>, - ) -> Result, async_graphql::Error> { - let session = ctx.data_opt::().cloned(); - Ok(session.map(BrowserSession::from)) - } - - /// Get the current logged in user - async fn current_user(&self, ctx: &Context<'_>) -> Result, async_graphql::Error> { - let session = ctx.data_opt::().cloned(); - Ok(session.map(User::from)) - } - - /// Fetch an OAuth 2.0 client by its ID. - async fn oauth2_client( - &self, - ctx: &Context<'_>, - id: ID, - ) -> Result, async_graphql::Error> { - let id = NodeType::OAuth2Client.extract_ulid(&id)?; - let mut repo = ctx.data::>()?.lock().await; - - let client = repo.oauth2_client().lookup(id).await?; - - Ok(client.map(OAuth2Client)) - } - - /// Fetch a user by its ID. - async fn user(&self, ctx: &Context<'_>, id: ID) -> Result, async_graphql::Error> { - let id = NodeType::User.extract_ulid(&id)?; - let session = ctx.data_opt::().cloned(); - - let Some(session) = session else { return Ok(None) }; - let current_user = session.user; - - if current_user.id == id { - Ok(Some(User(current_user))) - } else { - Ok(None) - } - } - - /// Fetch a browser session by its ID. - async fn browser_session( - &self, - ctx: &Context<'_>, - id: ID, - ) -> Result, async_graphql::Error> { - let id = NodeType::BrowserSession.extract_ulid(&id)?; - let session = ctx.data_opt::().cloned(); - let mut repo = ctx.data::>()?.lock().await; - - let Some(session) = session else { return Ok(None) }; - let current_user = session.user; - - let browser_session = repo.browser_session().lookup(id).await?; - - let ret = browser_session.and_then(|browser_session| { - if browser_session.user.id == current_user.id { - Some(BrowserSession(browser_session)) - } else { - None - } - }); - - Ok(ret) - } - - /// Fetch a user email by its ID. - async fn user_email( - &self, - ctx: &Context<'_>, - id: ID, - ) -> Result, async_graphql::Error> { - let id = NodeType::UserEmail.extract_ulid(&id)?; - let session = ctx.data_opt::().cloned(); - let mut repo = ctx.data::>()?.lock().await; - - let Some(session) = session else { return Ok(None) }; - let current_user = session.user; - - let user_email = repo - .user_email() - .lookup(id) - .await? - .filter(|e| e.user_id == current_user.id); - - Ok(user_email.map(UserEmail)) - } - - /// Fetch an upstream OAuth 2.0 link by its ID. - async fn upstream_oauth2_link( - &self, - ctx: &Context<'_>, - id: ID, - ) -> Result, async_graphql::Error> { - let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; - let session = ctx.data_opt::().cloned(); - let mut repo = ctx.data::>()?.lock().await; - - let Some(session) = session else { return Ok(None) }; - let current_user = session.user; - - let link = repo.upstream_oauth_link().lookup(id).await?; - - // Ensure that the link belongs to the current user - let link = link.filter(|link| link.user_id == Some(current_user.id)); - - Ok(link.map(UpstreamOAuth2Link::new)) - } - - /// Fetch an upstream OAuth 2.0 provider by its ID. - async fn upstream_oauth2_provider( - &self, - ctx: &Context<'_>, - id: ID, - ) -> Result, async_graphql::Error> { - let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?; - let mut repo = ctx.data::>()?.lock().await; - - let provider = repo.upstream_oauth_provider().lookup(id).await?; - - Ok(provider.map(UpstreamOAuth2Provider::new)) - } - - /// Get a list of upstream OAuth 2.0 providers. - async fn upstream_oauth2_providers( - &self, - ctx: &Context<'_>, - - #[graphql(desc = "Returns the elements in the list that come after the cursor.")] - after: Option, - #[graphql(desc = "Returns the elements in the list that come before the cursor.")] - before: Option, - #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, - #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, - ) -> Result, async_graphql::Error> { - let mut repo = ctx.data::>()?.lock().await; - - query( - after, - before, - first, - last, - |after, before, first, last| async move { - let after_id = after - .map(|x: OpaqueCursor| { - x.extract_for_type(NodeType::UpstreamOAuth2Provider) - }) - .transpose()?; - let before_id = before - .map(|x: OpaqueCursor| { - x.extract_for_type(NodeType::UpstreamOAuth2Provider) - }) - .transpose()?; - let pagination = Pagination::try_new(before_id, after_id, first, last)?; - - let page = repo - .upstream_oauth_provider() - .list_paginated(pagination) - .await?; - - let mut connection = Connection::new(page.has_previous_page, page.has_next_page); - connection.edges.extend(page.edges.into_iter().map(|p| { - Edge::new( - OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)), - UpstreamOAuth2Provider::new(p), - ) - })); - - Ok::<_, async_graphql::Error>(connection) - }, - ) - .await - } - - /// Fetches an object given its ID. - async fn node(&self, ctx: &Context<'_>, id: ID) -> Result, async_graphql::Error> { - let (node_type, _id) = NodeType::from_id(&id)?; - - let ret = match node_type { - // TODO - NodeType::Authentication - | NodeType::CompatSession - | NodeType::CompatSsoLogin - | NodeType::OAuth2Session => None, - - NodeType::UpstreamOAuth2Provider => self - .upstream_oauth2_provider(ctx, id) - .await? - .map(|c| Node::UpstreamOAuth2Provider(Box::new(c))), - - NodeType::UpstreamOAuth2Link => self - .upstream_oauth2_link(ctx, id) - .await? - .map(|c| Node::UpstreamOAuth2Link(Box::new(c))), - - NodeType::OAuth2Client => self - .oauth2_client(ctx, id) - .await? - .map(|c| Node::OAuth2Client(Box::new(c))), - - NodeType::UserEmail => self - .user_email(ctx, id) - .await? - .map(|e| Node::UserEmail(Box::new(e))), - - NodeType::BrowserSession => self - .browser_session(ctx, id) - .await? - .map(|s| Node::BrowserSession(Box::new(s))), - - NodeType::User => self.user(ctx, id).await?.map(|u| Node::User(Box::new(u))), - }; - - Ok(ret) - } -} diff --git a/crates/graphql/src/model/compat_sessions.rs b/crates/graphql/src/model/compat_sessions.rs index 38fdd4ba..f021cf1f 100644 --- a/crates/graphql/src/model/compat_sessions.rs +++ b/crates/graphql/src/model/compat_sessions.rs @@ -15,11 +15,11 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{compat::CompatSessionRepository, user::UserRepository, BoxRepository}; -use tokio::sync::Mutex; +use mas_storage::{compat::CompatSessionRepository, user::UserRepository}; use url::Url; use super::{NodeType, User}; +use crate::state::ContextExt; /// A compat session represents a client session which used the legacy Matrix /// login API. @@ -35,12 +35,15 @@ impl CompatSession { /// The user authorized for this session. async fn user(&self, ctx: &Context<'_>) -> Result { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; let user = repo .user() .lookup(self.0.user_id) .await? .context("Could not load user")?; + repo.cancel().await?; + Ok(User(user)) } @@ -100,12 +103,14 @@ impl CompatSsoLogin { ) -> Result, async_graphql::Error> { let Some(session_id) = self.0.session_id() else { return Ok(None) }; - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; let session = repo .compat_session() .lookup(session_id) .await? .context("Could not load compat session")?; + repo.cancel().await?; Ok(Some(CompatSession(session))) } diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 19612f6d..b297263b 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,13 +14,13 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository, BoxRepository}; +use mas_storage::{oauth2::OAuth2ClientRepository, user::BrowserSessionRepository}; use oauth2_types::scope::Scope; -use tokio::sync::Mutex; use ulid::Ulid; use url::Url; use super::{BrowserSession, NodeType, User}; +use crate::state::ContextExt; /// An OAuth 2.0 session represents a client session which used the OAuth APIs /// to login. @@ -36,12 +36,14 @@ impl OAuth2Session { /// OAuth 2.0 client used by this session. pub async fn client(&self, ctx: &Context<'_>) -> Result { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; let client = repo .oauth2_client() .lookup(self.0.client_id) .await? .context("Could not load client")?; + repo.cancel().await?; Ok(OAuth2Client(client)) } @@ -56,24 +58,28 @@ impl OAuth2Session { &self, ctx: &Context<'_>, ) -> Result { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; let browser_session = repo .browser_session() .lookup(self.0.user_session_id) .await? .context("Could not load browser session")?; + repo.cancel().await?; Ok(BrowserSession(browser_session)) } /// User authorized for this session. pub async fn user(&self, ctx: &Context<'_>) -> Result { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; let browser_session = repo .browser_session() .lookup(self.0.user_session_id) .await? .context("Could not load browser session")?; + repo.cancel().await?; Ok(User(browser_session.user)) } @@ -138,12 +144,15 @@ impl OAuth2Consent { /// OAuth 2.0 client for which the user granted access. pub async fn client(&self, ctx: &Context<'_>) -> Result { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; let client = repo .oauth2_client() .lookup(self.client_id) .await? .context("Could not load client")?; + repo.cancel().await?; + Ok(OAuth2Client(client)) } } diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index c2daaf9f..cc6771d7 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -15,12 +15,10 @@ use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; -use mas_storage::{ - upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository, BoxRepository, -}; -use tokio::sync::Mutex; +use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, user::UserRepository}; use super::{NodeType, User}; +use crate::state::ContextExt; #[derive(Debug, Clone)] pub struct UpstreamOAuth2Provider { @@ -97,20 +95,21 @@ impl UpstreamOAuth2Link { &self, ctx: &Context<'_>, ) -> Result { + let state = ctx.state(); let provider = if let Some(provider) = &self.provider { // Cached provider.clone() } else { // Fetch on-the-fly - let mut repo = ctx.data::>()?.lock().await; + let mut repo = state.repository().await?; - // This is a false positive, since it would have a lifetime error - #[allow(clippy::let_and_return)] let provider = repo .upstream_oauth_provider() .lookup(self.link.provider_id) .await? .context("Upstream OAuth 2.0 provider not found")?; + repo.cancel().await?; + provider }; @@ -119,20 +118,21 @@ impl UpstreamOAuth2Link { /// The user to which this link is associated. pub async fn user(&self, ctx: &Context<'_>) -> Result, async_graphql::Error> { + let state = ctx.state(); let user = if let Some(user) = &self.user { // Cached user.clone() } else if let Some(user_id) = &self.link.user_id { // Fetch on-the-fly - let mut repo = ctx.data::>()?.lock().await; + let mut repo = state.repository().await?; - // This is a false positive, since it would have a lifetime error - #[allow(clippy::let_and_return)] let user = repo .user() .lookup(*user_id) .await? .context("User not found")?; + repo.cancel().await?; + user } else { return Ok(None); diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 35c2cae4..b8b96d88 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -22,14 +22,14 @@ use mas_storage::{ oauth2::OAuth2SessionRepository, upstream_oauth2::UpstreamOAuthLinkRepository, user::{BrowserSessionRepository, UserEmailRepository}, - BoxRepository, Pagination, + Pagination, }; -use tokio::sync::Mutex; use super::{ compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session, UpstreamOAuth2Link, }; +use crate::state::ContextExt; #[derive(Description)] /// A user is an individual's account. @@ -64,10 +64,12 @@ impl User { &self, ctx: &Context<'_>, ) -> Result, async_graphql::Error> { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; - let mut user_email_repo = repo.user_email(); - Ok(user_email_repo.get_primary(&self.0).await?.map(UserEmail)) + let user_email = repo.user_email().get_primary(&self.0).await?.map(UserEmail); + repo.cancel().await?; + Ok(user_email) } /// Get the list of compatibility SSO logins, chronologically sorted @@ -82,7 +84,8 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; query( after, @@ -103,6 +106,8 @@ impl User { .list_paginated(&self.0, pagination) .await?; + repo.cancel().await?; + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); connection.edges.extend(page.edges.into_iter().map(|u| { Edge::new( @@ -129,7 +134,8 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; query( after, @@ -150,6 +156,8 @@ impl User { .list_active_paginated(&self.0, pagination) .await?; + repo.cancel().await?; + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); connection.edges.extend(page.edges.into_iter().map(|u| { Edge::new( @@ -176,7 +184,8 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; query( after, @@ -197,6 +206,8 @@ impl User { .list_paginated(&self.0, pagination) .await?; + repo.cancel().await?; + let mut connection = Connection::with_additional_fields( page.has_previous_page, page.has_next_page, @@ -227,7 +238,8 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; query( after, @@ -248,6 +260,8 @@ impl User { .list_paginated(&self.0, pagination) .await?; + repo.cancel().await?; + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); connection.edges.extend(page.edges.into_iter().map(|s| { Edge::new( @@ -274,7 +288,8 @@ impl User { #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, ) -> Result, async_graphql::Error> { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; query( after, @@ -299,6 +314,8 @@ impl User { .list_paginated(&self.0, pagination) .await?; + repo.cancel().await?; + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); connection.edges.extend(page.edges.into_iter().map(|s| { Edge::new( @@ -348,8 +365,10 @@ pub struct UserEmailsPagination(mas_data_model::User); impl UserEmailsPagination { /// Identifies the total count of items in the connection. async fn total_count(&self, ctx: &Context<'_>) -> Result { - let mut repo = ctx.data::>()?.lock().await; + let state = ctx.state(); + let mut repo = state.repository().await?; let count = repo.user_email().count(&self.0).await?; + repo.cancel().await?; Ok(count) } } diff --git a/crates/graphql/src/mutations.rs b/crates/graphql/src/mutations.rs index b84e3a48..24041502 100644 --- a/crates/graphql/src/mutations.rs +++ b/crates/graphql/src/mutations.rs @@ -12,38 +12,45 @@ // See the License for the specific language governing permissions and // limitations under the License. -use async_graphql::{Context, Object, ID}; +use anyhow::Context as _; +use async_graphql::{Context, Description, Object, ID}; use mas_storage::{ - job::{JobRepositoryExt, VerifyEmailJob}, + job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob}, user::UserEmailRepository, - BoxClock, BoxRepository, BoxRng, RepositoryAccess, SystemClock, + RepositoryAccess, }; -use rand_chacha::{rand_core::SeedableRng, ChaChaRng}; -use tokio::sync::Mutex; -use crate::model::{NodeType, UserEmail}; +use crate::{ + model::{NodeType, UserEmail}, + state::ContextExt, +}; -struct RootMutations; - -fn clock_and_rng() -> (BoxClock, BoxRng) { - // XXX: this should be moved somewhere else - let clock = SystemClock::default(); - let rng = ChaChaRng::from_entropy(); - (Box::new(clock), Box::new(rng)) +/// The mutations root of the GraphQL interface. +#[derive(Default, Description)] +pub struct RootMutations { + _private: (), } -#[Object] impl RootMutations { + #[must_use] + pub fn new() -> Self { + Self::default() + } +} + +#[Object(use_type_description)] +impl RootMutations { + /// Add an email address to the specified user async fn add_email( &self, ctx: &Context<'_>, - email: String, - user_id: ID, + + #[graphql(desc = "The email address to add")] email: String, + #[graphql(desc = "The ID of the user to add the email address to")] user_id: ID, ) -> Result { + let state = ctx.state(); let id = NodeType::User.extract_ulid(&user_id)?; - let session = ctx.data_opt::().cloned(); - let (clock, mut rng) = clock_and_rng(); - let mut repo = ctx.data::>()?.lock().await; + let session = ctx.session(); let Some(session) = session else { return Err(async_graphql::Error::new("Unauthorized")); @@ -53,15 +60,133 @@ impl RootMutations { return Err(async_graphql::Error::new("Unauthorized")); } + let mut repo = state.repository().await?; + + // XXX: this logic should be extracted somewhere else, since most of it is + // duplicated in mas_handlers + // Find an existing email address + let existing_user_email = repo.user_email().find(&session.user, &email).await?; + let user_email = if let Some(user_email) = existing_user_email { + user_email + } else { + let clock = state.clock(); + let mut rng = state.rng(); + + repo.user_email() + .add(&mut rng, &clock, &session.user, email) + .await? + }; + + // Schedule a job to verify the email address if needed + if user_email.confirmed_at.is_none() { + repo.job() + .schedule_job(VerifyEmailJob::new(&user_email)) + .await?; + } + + repo.save().await?; + + Ok(UserEmail(user_email)) + } + + /// Send a verification code for an email address + async fn send_verification_email( + &self, + ctx: &Context<'_>, + + #[graphql(desc = "The ID of the email address to verify")] user_email_id: ID, + ) -> Result { + let state = ctx.state(); + let user_email_id = NodeType::UserEmail.extract_ulid(&user_email_id)?; + let session = ctx.session(); + + let Some(session) = session else { + return Err(async_graphql::Error::new("Unauthorized")); + }; + + let mut repo = state.repository().await?; + let user_email = repo .user_email() - .add(&mut rng, &clock, &session.user, email) + .lookup(user_email_id) + .await? + .context("User email not found")?; + + if user_email.user_id != session.user.id { + return Err(async_graphql::Error::new("Unauthorized")); + } + + // Schedule a job to verify the email address if needed + if user_email.confirmed_at.is_none() { + repo.job() + .schedule_job(VerifyEmailJob::new(&user_email)) + .await?; + } + + repo.save().await?; + + Ok(UserEmail(user_email)) + } + + /// Submit a verification code for an email address + async fn verify_email( + &self, + ctx: &Context<'_>, + + #[graphql(desc = "The ID of the email address to verify")] user_email_id: ID, + #[graphql(desc = "The verification code to submit")] code: String, + ) -> Result { + let state = ctx.state(); + let user_email_id = NodeType::UserEmail.extract_ulid(&user_email_id)?; + let session = ctx.session(); + + let Some(session) = session else { + return Err(async_graphql::Error::new("Unauthorized")); + }; + + let clock = state.clock(); + let mut repo = state.repository().await?; + + let user_email = repo + .user_email() + .lookup(user_email_id) + .await? + .context("User email not found")?; + + if user_email.user_id != session.user.id { + return Err(async_graphql::Error::new("Unauthorized")); + } + + // XXX: this logic should be extracted somewhere else, since most of it is + // duplicated in mas_handlers + + // Find the verification code + let verification = repo + .user_email() + .find_verification_code(&clock, &user_email, &code) + .await? + .context("Invalid verification code")?; + + // TODO: display nice errors if the code was already consumed or expired + repo.user_email() + .consume_verification_code(&clock, verification) + .await?; + + // XXX: is this the right place to do this? + if session.user.primary_user_email_id.is_none() { + repo.user_email().set_as_primary(&user_email).await?; + } + + let user_email = repo + .user_email() + .mark_as_verified(&clock, user_email) .await?; repo.job() - .schedule_job(VerifyEmailJob::new(&user_email)) + .schedule_job(ProvisionUserJob::new(&session.user)) .await?; - // TODO: how do we save the transaction here? + + repo.save().await?; Ok(UserEmail(user_email)) } diff --git a/crates/graphql/src/query.rs b/crates/graphql/src/query.rs new file mode 100644 index 00000000..5b225920 --- /dev/null +++ b/crates/graphql/src/query.rs @@ -0,0 +1,281 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_graphql::{ + connection::{query, Connection, Edge, OpaqueCursor}, + Context, Description, Object, ID, +}; +use mas_storage::Pagination; + +use crate::{ + model::{ + BrowserSession, Cursor, Node, NodeCursor, NodeType, OAuth2Client, UpstreamOAuth2Link, + UpstreamOAuth2Provider, User, UserEmail, + }, + state::ContextExt, +}; + +/// The query root of the GraphQL interface. +#[derive(Default, Description)] +pub struct RootQuery { + _private: (), +} + +impl RootQuery { + #[must_use] + pub fn new() -> Self { + Self::default() + } +} + +#[Object(use_type_description)] +impl RootQuery { + /// Get the current logged in browser session + async fn current_browser_session( + &self, + ctx: &Context<'_>, + ) -> Result, async_graphql::Error> { + let session = ctx.session().cloned(); + Ok(session.map(BrowserSession::from)) + } + + /// Get the current logged in user + async fn current_user(&self, ctx: &Context<'_>) -> Result, async_graphql::Error> { + let session = ctx.session().cloned(); + Ok(session.map(User::from)) + } + + /// Fetch an OAuth 2.0 client by its ID. + async fn oauth2_client( + &self, + ctx: &Context<'_>, + id: ID, + ) -> Result, async_graphql::Error> { + let state = ctx.state(); + let id = NodeType::OAuth2Client.extract_ulid(&id)?; + + let mut repo = state.repository().await?; + let client = repo.oauth2_client().lookup(id).await?; + repo.cancel().await?; + + Ok(client.map(OAuth2Client)) + } + + /// Fetch a user by its ID. + async fn user(&self, ctx: &Context<'_>, id: ID) -> Result, async_graphql::Error> { + let id = NodeType::User.extract_ulid(&id)?; + + let session = ctx.session().cloned(); + + let Some(session) = session else { return Ok(None) }; + let current_user = session.user; + + if current_user.id == id { + Ok(Some(User(current_user))) + } else { + Ok(None) + } + } + + /// Fetch a browser session by its ID. + async fn browser_session( + &self, + ctx: &Context<'_>, + id: ID, + ) -> Result, async_graphql::Error> { + let state = ctx.state(); + let id = NodeType::BrowserSession.extract_ulid(&id)?; + + let session = ctx.session().cloned(); + let mut repo = state.repository().await?; + + let Some(session) = session else { return Ok(None) }; + let current_user = session.user; + + let browser_session = repo.browser_session().lookup(id).await?; + + repo.cancel().await?; + + let ret = browser_session.and_then(|browser_session| { + if browser_session.user.id == current_user.id { + Some(BrowserSession(browser_session)) + } else { + None + } + }); + + Ok(ret) + } + + /// Fetch a user email by its ID. + async fn user_email( + &self, + ctx: &Context<'_>, + id: ID, + ) -> Result, async_graphql::Error> { + let state = ctx.state(); + let id = NodeType::UserEmail.extract_ulid(&id)?; + + let session = ctx.session().cloned(); + let mut repo = state.repository().await?; + + let Some(session) = session else { return Ok(None) }; + let current_user = session.user; + + let user_email = repo + .user_email() + .lookup(id) + .await? + .filter(|e| e.user_id == current_user.id); + + repo.cancel().await?; + + Ok(user_email.map(UserEmail)) + } + + /// Fetch an upstream OAuth 2.0 link by its ID. + async fn upstream_oauth2_link( + &self, + ctx: &Context<'_>, + id: ID, + ) -> Result, async_graphql::Error> { + let state = ctx.state(); + let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; + + let session = ctx.session().cloned(); + let mut repo = state.repository().await?; + + let Some(session) = session else { return Ok(None) }; + let current_user = session.user; + + let link = repo.upstream_oauth_link().lookup(id).await?; + + // Ensure that the link belongs to the current user + let link = link.filter(|link| link.user_id == Some(current_user.id)); + + Ok(link.map(UpstreamOAuth2Link::new)) + } + + /// Fetch an upstream OAuth 2.0 provider by its ID. + async fn upstream_oauth2_provider( + &self, + ctx: &Context<'_>, + id: ID, + ) -> Result, async_graphql::Error> { + let state = ctx.state(); + let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?; + + let mut repo = state.repository().await?; + let provider = repo.upstream_oauth_provider().lookup(id).await?; + repo.cancel().await?; + + Ok(provider.map(UpstreamOAuth2Provider::new)) + } + + /// Get a list of upstream OAuth 2.0 providers. + async fn upstream_oauth2_providers( + &self, + ctx: &Context<'_>, + + #[graphql(desc = "Returns the elements in the list that come after the cursor.")] + after: Option, + #[graphql(desc = "Returns the elements in the list that come before the cursor.")] + before: Option, + #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, + #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, + ) -> Result, async_graphql::Error> { + let state = ctx.state(); + let mut repo = state.repository().await?; + + query( + after, + before, + first, + last, + |after, before, first, last| async move { + let after_id = after + .map(|x: OpaqueCursor| { + x.extract_for_type(NodeType::UpstreamOAuth2Provider) + }) + .transpose()?; + let before_id = before + .map(|x: OpaqueCursor| { + x.extract_for_type(NodeType::UpstreamOAuth2Provider) + }) + .transpose()?; + let pagination = Pagination::try_new(before_id, after_id, first, last)?; + + let page = repo + .upstream_oauth_provider() + .list_paginated(pagination) + .await?; + + repo.cancel().await?; + + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|p| { + Edge::new( + OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)), + UpstreamOAuth2Provider::new(p), + ) + })); + + Ok::<_, async_graphql::Error>(connection) + }, + ) + .await + } + + /// Fetches an object given its ID. + async fn node(&self, ctx: &Context<'_>, id: ID) -> Result, async_graphql::Error> { + let (node_type, _id) = NodeType::from_id(&id)?; + + let ret = match node_type { + // TODO + NodeType::Authentication + | NodeType::CompatSession + | NodeType::CompatSsoLogin + | NodeType::OAuth2Session => None, + + NodeType::UpstreamOAuth2Provider => self + .upstream_oauth2_provider(ctx, id) + .await? + .map(|c| Node::UpstreamOAuth2Provider(Box::new(c))), + + NodeType::UpstreamOAuth2Link => self + .upstream_oauth2_link(ctx, id) + .await? + .map(|c| Node::UpstreamOAuth2Link(Box::new(c))), + + NodeType::OAuth2Client => self + .oauth2_client(ctx, id) + .await? + .map(|c| Node::OAuth2Client(Box::new(c))), + + NodeType::UserEmail => self + .user_email(ctx, id) + .await? + .map(|e| Node::UserEmail(Box::new(e))), + + NodeType::BrowserSession => self + .browser_session(ctx, id) + .await? + .map(|s| Node::BrowserSession(Box::new(s))), + + NodeType::User => self.user(ctx, id).await?.map(|u| Node::User(Box::new(u))), + }; + + Ok(ret) + } +} diff --git a/crates/graphql/src/state.rs b/crates/graphql/src/state.rs new file mode 100644 index 00000000..e9775fbb --- /dev/null +++ b/crates/graphql/src/state.rs @@ -0,0 +1,41 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use mas_data_model::BrowserSession; +use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError}; + +#[async_trait::async_trait] +pub trait State { + async fn repository(&self) -> Result; + fn clock(&self) -> BoxClock; + fn rng(&self) -> BoxRng; +} + +pub type BoxState = Box; + +pub trait ContextExt { + fn state(&self) -> &BoxState; + + fn session(&self) -> Option<&BrowserSession>; +} + +impl ContextExt for async_graphql::Context<'_> { + fn state(&self) -> &BoxState { + self.data_unchecked::() + } + + fn session(&self) -> Option<&BrowserSession> { + self.data_opt() + } +} diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index 233c4690..1f76f157 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -17,6 +17,7 @@ use async_graphql::{ http::{playground_source, GraphQLPlaygroundConfig, MultipartOptions}, }; use axum::{ + async_trait, extract::{BodyStream, RawQuery, State}, response::{Html, IntoResponse}, Json, TypedHeader, @@ -28,15 +29,50 @@ use hyper::header::CACHE_CONTROL; use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_graphql::Schema; use mas_keystore::Encrypter; -use mas_storage::BoxRepository; -use tokio::sync::Mutex; +use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, RepositoryError, SystemClock}; +use mas_storage_pg::PgRepository; +use rand::{thread_rng, SeedableRng}; +use rand_chacha::ChaChaRng; +use sqlx::PgPool; use tracing::{info_span, Instrument}; +struct GraphQLState { + pool: PgPool, +} + +#[async_trait] +impl mas_graphql::State for GraphQLState { + async fn repository(&self) -> Result { + let repo = PgRepository::from_pool(&self.pool) + .await + .map_err(RepositoryError::from_error)?; + + Ok(repo.map_err(RepositoryError::from_error).boxed()) + } + + fn clock(&self) -> BoxClock { + let clock = SystemClock::default(); + Box::new(clock) + } + + fn rng(&self) -> BoxRng { + #[allow(clippy::disallowed_methods)] + let rng = thread_rng(); + + let rng = ChaChaRng::from_rng(rng).expect("Failed to seed rng"); + Box::new(rng) + } +} + #[must_use] -pub fn schema() -> Schema { +pub fn schema(pool: &PgPool) -> Schema { + let state = GraphQLState { pool: pool.clone() }; + let state: mas_graphql::BoxState = Box::new(state); + mas_graphql::schema_builder() .extension(Tracing) .extension(ApolloTracing) + .data(state) .finish() } @@ -68,6 +104,7 @@ pub async fn post( let (session_info, _cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; + repo.cancel().await?; let mut request = async_graphql::http::receive_body( content_type, @@ -75,8 +112,7 @@ pub async fn post( .into_async_read(), MultipartOptions::default(), ) - .await? // XXX: this should probably return another error response? - .data(Mutex::new(repo)); + .await?; // XXX: this should probably return another error response? if let Some(session) = maybe_session { request = request.data(session); @@ -104,9 +140,9 @@ pub async fn get( ) -> Result { let (session_info, _cookie_jar) = cookie_jar.session_info(); let maybe_session = session_info.load_session(&mut repo).await?; + repo.cancel().await?; - let mut request = - async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(Mutex::new(repo)); + let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; if let Some(session) = maybe_session { request = request.data(session); diff --git a/crates/handlers/src/views/account/emails/verify.rs b/crates/handlers/src/views/account/emails/verify.rs index 51765e92..401d6c8f 100644 --- a/crates/handlers/src/views/account/emails/verify.rs +++ b/crates/handlers/src/views/account/emails/verify.rs @@ -118,6 +118,9 @@ pub(crate) async fn post( .filter(|u| u.user_id == session.user.id) .context("Could not find user email")?; + // XXX: this logic should be extracted somewhere else, since most of it is + // duplicated in mas_graphql + let verification = repo .user_email() .find_verification_code(&clock, &user_email, &form.code) diff --git a/frontend/schema.graphql b/frontend/schema.graphql index 8f98b93d..0eb16954 100644 --- a/frontend/schema.graphql +++ b/frontend/schema.graphql @@ -291,6 +291,24 @@ type PageInfo { endCursor: String } +""" +The mutations root of the GraphQL interface. +""" +type RootMutations { + """ + Add an email address to the specified user + """ + addEmail(email: String!, userId: ID!): UserEmail! + """ + Send a verification code for an email address + """ + sendVerificationEmail(userEmailId: ID!): UserEmail! + """ + Submit a verification code for an email address + """ + verifyEmail(userEmailId: ID!, code: String!): UserEmail! +} + """ The query root of the GraphQL interface. """ @@ -568,4 +586,5 @@ type UserEmailEdge { schema { query: RootQuery + mutation: RootMutations } diff --git a/frontend/src/gql/graphql.ts b/frontend/src/gql/graphql.ts index 4a01808e..e0202267 100644 --- a/frontend/src/gql/graphql.ts +++ b/frontend/src/gql/graphql.ts @@ -207,6 +207,37 @@ export type PageInfo = { startCursor?: Maybe; }; +/** The mutations root of the GraphQL interface. */ +export type RootMutations = { + __typename?: 'RootMutations'; + /** Add an email address to the specified user */ + addEmail: UserEmail; + /** Send a verification code for an email address */ + sendVerificationEmail: UserEmail; + /** Submit a verification code for an email address */ + verifyEmail: UserEmail; +}; + + +/** The mutations root of the GraphQL interface. */ +export type RootMutationsAddEmailArgs = { + email: Scalars['String']; + userId: Scalars['ID']; +}; + + +/** The mutations root of the GraphQL interface. */ +export type RootMutationsSendVerificationEmailArgs = { + userEmailId: Scalars['ID']; +}; + + +/** The mutations root of the GraphQL interface. */ +export type RootMutationsVerifyEmailArgs = { + code: Scalars['String']; + userEmailId: Scalars['ID']; +}; + /** The query root of the GraphQL interface. */ export type RootQuery = { __typename?: 'RootQuery'; diff --git a/frontend/src/gql/schema.ts b/frontend/src/gql/schema.ts index 40dfafee..1012d501 100644 --- a/frontend/src/gql/schema.ts +++ b/frontend/src/gql/schema.ts @@ -4,7 +4,9 @@ export default { queryType: { name: "RootQuery", }, - mutationType: null, + mutationType: { + name: "RootMutations", + }, subscriptionType: null, types: [ { @@ -800,6 +802,102 @@ export default { ], interfaces: [], }, + { + kind: "OBJECT", + name: "RootMutations", + fields: [ + { + name: "addEmail", + type: { + kind: "NON_NULL", + ofType: { + kind: "OBJECT", + name: "UserEmail", + ofType: null, + }, + }, + args: [ + { + name: "email", + type: { + kind: "NON_NULL", + ofType: { + kind: "SCALAR", + name: "Any", + }, + }, + }, + { + name: "userId", + type: { + kind: "NON_NULL", + ofType: { + kind: "SCALAR", + name: "Any", + }, + }, + }, + ], + }, + { + name: "sendVerificationEmail", + type: { + kind: "NON_NULL", + ofType: { + kind: "OBJECT", + name: "UserEmail", + ofType: null, + }, + }, + args: [ + { + name: "userEmailId", + type: { + kind: "NON_NULL", + ofType: { + kind: "SCALAR", + name: "Any", + }, + }, + }, + ], + }, + { + name: "verifyEmail", + type: { + kind: "NON_NULL", + ofType: { + kind: "OBJECT", + name: "UserEmail", + ofType: null, + }, + }, + args: [ + { + name: "code", + type: { + kind: "NON_NULL", + ofType: { + kind: "SCALAR", + name: "Any", + }, + }, + }, + { + name: "userEmailId", + type: { + kind: "NON_NULL", + ofType: { + kind: "SCALAR", + name: "Any", + }, + }, + }, + ], + }, + ], + interfaces: [], + }, { kind: "OBJECT", name: "RootQuery",