diff --git a/crates/data-model/src/upstream_oauth2/mod.rs b/crates/data-model/src/upstream_oauth2/mod.rs index 735577ef..08fbf6c0 100644 --- a/crates/data-model/src/upstream_oauth2/mod.rs +++ b/crates/data-model/src/upstream_oauth2/mod.rs @@ -33,6 +33,8 @@ pub struct UpstreamOAuthProvider { #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct UpstreamOAuthLink { pub id: Ulid, + pub provider_id: Ulid, + pub user_id: Option, pub subject: String, pub created_at: DateTime, } @@ -40,6 +42,8 @@ pub struct UpstreamOAuthLink { #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct UpstreamOAuthAuthorizationSession { pub id: Ulid, + pub provider_id: Ulid, + pub link_id: Option, pub state: String, pub code_challenge_verifier: Option, pub nonce: String, diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 7b34f4a1..4f1c77a0 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -22,12 +22,18 @@ #![warn(clippy::pedantic)] #![allow(clippy::module_name_repetitions, clippy::missing_errors_doc)] -use async_graphql::{Context, Description, EmptyMutation, EmptySubscription, ID}; +use async_graphql::{ + connection::{query, Connection, Edge, OpaqueCursor}, + Context, Description, EmptyMutation, EmptySubscription, ID, +}; use mas_axum_utils::SessionInfo; use mas_storage::LookupResultExt; use sqlx::PgPool; -use self::model::{BrowserSession, Node, NodeType, OAuth2Client, User, UserEmail}; +use self::model::{ + BrowserSession, Cursor, Node, NodeCursor, NodeType, OAuth2Client, UpstreamOAuth2Link, + UpstreamOAuth2Provider, User, UserEmail, +}; mod model; @@ -167,6 +173,100 @@ impl RootQuery { Ok(user_email.map(UserEmail)) } + /// Fetch an upstream OAuth 2.0 link by its ID. + async fn upstream_oauth2_link( + &self, + ctx: &Context<'_>, + id: ID, + ) -> Result, async_graphql::Error> { + let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; + let database = ctx.data::()?; + let session_info = ctx.data::()?; + let mut conn = database.acquire().await?; + let session = session_info.load_session(&mut conn).await?; + + let Some(session) = session else { return Ok(None) }; + let current_user = session.user; + + let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id) + .await + .to_option()?; + + // Ensure that the link belongs to the current user + let link = link.filter(|link| link.user_id == Some(current_user.data)); + + Ok(link.map(UpstreamOAuth2Link::new)) + } + + /// Fetch an upstream OAuth 2.0 provider by its ID. + async fn upstream_oauth2_provider( + &self, + ctx: &Context<'_>, + id: ID, + ) -> Result, async_graphql::Error> { + let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?; + let database = ctx.data::()?; + let mut conn = database.acquire().await?; + + let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id) + .await + .to_option()?; + + Ok(provider.map(UpstreamOAuth2Provider::new)) + } + + /// Get a list of upstream OAuth 2.0 providers. + async fn upstream_oauth2_providers( + &self, + ctx: &Context<'_>, + + #[graphql(desc = "Returns the elements in the list that come after the cursor.")] + after: Option, + #[graphql(desc = "Returns the elements in the list that come before the cursor.")] + before: Option, + #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, + #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, + ) -> Result, async_graphql::Error> { + let database = ctx.data::()?; + + query( + after, + before, + first, + last, + |after, before, first, last| async move { + let mut conn = database.acquire().await?; + let after_id = after + .map(|x: OpaqueCursor| { + x.extract_for_type(NodeType::UpstreamOAuth2Provider) + }) + .transpose()?; + let before_id = before + .map(|x: OpaqueCursor| { + x.extract_for_type(NodeType::UpstreamOAuth2Provider) + }) + .transpose()?; + + let (has_previous_page, has_next_page, edges) = + mas_storage::upstream_oauth2::get_paginated_providers( + &mut conn, before_id, after_id, first, last, + ) + .await?; + + let mut connection = Connection::new(has_previous_page, has_next_page); + connection.edges.extend(edges.into_iter().map(|p| { + Edge::new( + OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)), + UpstreamOAuth2Provider::new(p), + ) + })); + + Ok::<_, async_graphql::Error>(connection) + }, + ) + .await + } + /// Fetches an object given its ID. async fn node(&self, ctx: &Context<'_>, id: ID) -> Result, async_graphql::Error> { let (node_type, _id) = NodeType::from_id(&id)?; @@ -178,6 +278,16 @@ impl RootQuery { | NodeType::CompatSsoLogin | NodeType::OAuth2Session => None, + NodeType::UpstreamOAuth2Provider => self + .upstream_oauth2_provider(ctx, id) + .await? + .map(|c| Node::UpstreamOAuth2Provider(Box::new(c))), + + NodeType::UpstreamOAuth2Link => self + .upstream_oauth2_link(ctx, id) + .await? + .map(|c| Node::UpstreamOAuth2Link(Box::new(c))), + NodeType::OAuth2Client => self .oauth2_client(ctx, id) .await? diff --git a/crates/graphql/src/model/mod.rs b/crates/graphql/src/model/mod.rs index f17519aa..7b923b3c 100644 --- a/crates/graphql/src/model/mod.rs +++ b/crates/graphql/src/model/mod.rs @@ -20,6 +20,7 @@ mod compat_sessions; mod cursor; mod node; mod oauth; +mod upstream_oauth; mod users; pub use self::{ @@ -28,6 +29,7 @@ pub use self::{ cursor::{Cursor, NodeCursor}, node::{Node, NodeType}, oauth::{OAuth2Client, OAuth2Consent, OAuth2Session}, + upstream_oauth::{UpstreamOAuth2Link, UpstreamOAuth2Provider}, users::{User, UserEmail}, }; @@ -42,4 +44,6 @@ pub enum CreationEvent { CompatSession(Box), BrowserSession(Box), UserEmail(Box), + UpstreamOAuth2Provider(Box), + UpstreamOAuth2Link(Box), } diff --git a/crates/graphql/src/model/node.rs b/crates/graphql/src/model/node.rs index 124cfeca..92375b87 100644 --- a/crates/graphql/src/model/node.rs +++ b/crates/graphql/src/model/node.rs @@ -19,7 +19,7 @@ use ulid::Ulid; use super::{ Authentication, BrowserSession, CompatSession, CompatSsoLogin, OAuth2Client, OAuth2Session, - User, UserEmail, + UpstreamOAuth2Link, UpstreamOAuth2Provider, User, UserEmail, }; #[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -30,6 +30,8 @@ pub enum NodeType { CompatSsoLogin, OAuth2Client, OAuth2Session, + UpstreamOAuth2Provider, + UpstreamOAuth2Link, User, UserEmail, } @@ -52,6 +54,8 @@ impl NodeType { NodeType::CompatSsoLogin => "compat_sso_login", NodeType::OAuth2Client => "oauth2_client", NodeType::OAuth2Session => "oauth2_session", + NodeType::UpstreamOAuth2Provider => "upstream_oauth2_provider", + NodeType::UpstreamOAuth2Link => "upstream_oauth2_link", NodeType::User => "user", NodeType::UserEmail => "user_email", } @@ -65,6 +69,8 @@ impl NodeType { "compat_sso_login" => Some(NodeType::CompatSsoLogin), "oauth2_client" => Some(NodeType::OAuth2Client), "oauth2_session" => Some(NodeType::OAuth2Session), + "upstream_oauth2_provider" => Some(NodeType::UpstreamOAuth2Provider), + "upstream_oauth2_link" => Some(NodeType::UpstreamOAuth2Link), "user" => Some(NodeType::User), "user_email" => Some(NodeType::UserEmail), _ => None, @@ -116,6 +122,8 @@ pub enum Node { CompatSsoLogin(Box), OAuth2Client(Box), OAuth2Session(Box), + UpstreamOAuth2Provider(Box), + UpstreamOAuth2Link(Box), User(Box), UserEmail(Box), } diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs new file mode 100644 index 00000000..5d3489ec --- /dev/null +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -0,0 +1,121 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_graphql::{Context, Object, ID}; +use chrono::{DateTime, Utc}; +use mas_storage::PostgresqlBackend; +use sqlx::PgPool; + +use super::{NodeType, User}; + +#[derive(Debug, Clone)] +pub struct UpstreamOAuth2Provider { + provider: mas_data_model::UpstreamOAuthProvider, +} + +impl UpstreamOAuth2Provider { + #[must_use] + pub const fn new(provider: mas_data_model::UpstreamOAuthProvider) -> Self { + Self { provider } + } +} + +#[Object] +impl UpstreamOAuth2Provider { + /// ID of the object. + pub async fn id(&self) -> ID { + NodeType::UpstreamOAuth2Provider.id(self.provider.id) + } + + /// When the object was created. + pub async fn created_at(&self) -> DateTime { + self.provider.created_at + } + + /// OpenID Connect issuer URL. + pub async fn issuer(&self) -> &str { + &self.provider.issuer + } + + /// Client ID used for this provider. + pub async fn client_id(&self) -> &str { + &self.provider.client_id + } +} + +impl UpstreamOAuth2Link { + #[must_use] + pub const fn new(link: mas_data_model::UpstreamOAuthLink) -> Self { + Self { + link, + provider: None, + user: None, + } + } +} + +#[derive(Debug, Clone)] +pub struct UpstreamOAuth2Link { + link: mas_data_model::UpstreamOAuthLink, + provider: Option, + user: Option>, +} + +#[Object] +impl UpstreamOAuth2Link { + /// ID of the object. + pub async fn id(&self) -> ID { + NodeType::UpstreamOAuth2Link.id(self.link.id) + } + + /// When the object was created. + pub async fn created_at(&self) -> DateTime { + self.link.created_at + } + + /// The provider for which this link is. + pub async fn provider( + &self, + ctx: &Context<'_>, + ) -> Result { + let provider = if let Some(provider) = &self.provider { + // Cached + provider.clone() + } else { + // Fetch on-the-fly + let database = ctx.data::()?; + let mut conn = database.acquire().await?; + mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id).await? + }; + + Ok(UpstreamOAuth2Provider::new(provider)) + } + + /// The user to which this link is associated. + pub async fn user(&self, ctx: &Context<'_>) -> Result, async_graphql::Error> { + let user = if let Some(user) = &self.user { + // Cached + user.clone() + } else if let Some(user_id) = &self.link.user_id { + // Fetch on-the-fly + let database = ctx.data::()?; + let mut conn = database.acquire().await?; + mas_storage::user::lookup_user(&mut conn, *user_id).await? + } else { + return Ok(None); + }; + + Ok(Some(User(user))) + } +} diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 43d90a26..cc01a2e3 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -250,7 +250,7 @@ pub(crate) async fn get( .await .to_option()?; - let link = if let Some((link, _maybe_user_id)) = maybe_link { + let link = if let Some(link) = maybe_link { link } else { add_link(&mut txn, &mut rng, &clock, &provider, subject).await? diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 39238a24..ab360ecf 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -114,7 +114,7 @@ pub(crate) async fn get( let mut txn = pool.begin().await?; let (clock, mut rng) = crate::rng_and_clock()?; - let (link, _provider_id, maybe_user_id) = lookup_link(&mut txn, link_id) + let link = lookup_link(&mut txn, link_id) .await .to_option()? .ok_or(RouteError::LinkNotFound)?; @@ -141,7 +141,7 @@ pub(crate) async fn get( let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(clock.now(), &mut rng); let maybe_user_session = user_session_info.load_session(&mut txn).await?; - let render = match (maybe_user_session, maybe_user_id) { + let render = match (maybe_user_session, link.user_id) { (Some(mut session), Some(user_id)) if session.user.data == user_id => { // Session already linked, and link matches the currently logged // user. Mark the session as consumed and renew the authentication. @@ -215,7 +215,7 @@ pub(crate) async fn post( let (clock, mut rng) = crate::rng_and_clock()?; let form = cookie_jar.verify_form(clock.now(), form)?; - let (link, _provider_id, maybe_user_id) = lookup_link(&mut txn, link_id) + let link = lookup_link(&mut txn, link_id) .await .to_option()? .ok_or(RouteError::LinkNotFound)?; @@ -241,7 +241,7 @@ pub(crate) async fn post( let (user_session_info, cookie_jar) = cookie_jar.session_info(); let maybe_user_session = user_session_info.load_session(&mut txn).await?; - let mut session = match (maybe_user_session, maybe_user_id, form) { + let mut session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { associate_link_to_user(&mut txn, &link, &session.user).await?; session diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index c4054c2f..3e2e1130 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -57,11 +57,11 @@ impl OptionalPostAuthAction { Some(PostAuthAction::ChangePassword) => Ok(Some(PostAuthContext::ChangePassword)), Some(PostAuthAction::LinkUpstream { id }) => { - let (link, provider_id, _user_id) = - mas_storage::upstream_oauth2::lookup_link(&mut *conn, *id).await?; + let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, *id).await?; let provider = - mas_storage::upstream_oauth2::lookup_provider(&mut *conn, provider_id).await?; + mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id) + .await?; let provider = Box::new(provider); let link = Box::new(link); diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index dc8a4a67..e980858f 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -634,6 +634,81 @@ }, "query": "\n INSERT INTO users (user_id, username, created_at)\n VALUES ($1, $2, $3)\n " }, + "2ca7b990c11e84db62fb7887a2bc3410ec1eee2f6a0ec124db36575111970ca9": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_authorization_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_provider_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_link_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "state", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "code_challenge_verifier", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "nonce", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "id_token", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 7, + "type_info": "Timestamptz" + }, + { + "name": "completed_at", + "ordinal": 8, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 9, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + true, + false, + true, + false, + true, + true + ], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n AND upstream_oauth_link_id = $2\n " + }, "2e756fe7be50128c0acc5f79df3a084230e9ca13cd45bd0858f97e59da20006e": { "describe": { "columns": [], @@ -1284,116 +1359,6 @@ }, "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n\n ORDER BY ue.email ASC\n " }, - "605e9370d233169760dafd0ac5dea4d161b4ad1903c79ad35499732533a1b641": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_authorization_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_provider_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "state", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "code_challenge_verifier", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "nonce", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "id_token", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "completed_at", - "ordinal": 7, - "type_info": "Timestamptz" - }, - { - "name": "consumed_at", - "ordinal": 8, - "type_info": "Timestamptz" - }, - { - "name": "provider_issuer", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "provider_scope", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "provider_client_id", - "ordinal": 11, - "type_info": "Text" - }, - { - "name": "provider_encrypted_client_secret", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "provider_token_endpoint_auth_method", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "provider_token_endpoint_signing_alg", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "provider_created_at", - "ordinal": 15, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - true, - false, - true, - false, - true, - true, - false, - false, - false, - true, - false, - true, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n ua.upstream_oauth_authorization_session_id,\n ua.upstream_oauth_provider_id,\n ua.state,\n ua.code_challenge_verifier,\n ua.nonce,\n ua.id_token,\n ua.created_at,\n ua.completed_at,\n ua.consumed_at,\n up.issuer AS \"provider_issuer\",\n up.scope AS \"provider_scope\",\n up.client_id AS \"provider_client_id\",\n up.encrypted_client_secret AS \"provider_encrypted_client_secret\",\n up.token_endpoint_auth_method AS \"provider_token_endpoint_auth_method\",\n up.token_endpoint_signing_alg AS \"provider_token_endpoint_signing_alg\",\n up.created_at AS \"provider_created_at\"\n FROM upstream_oauth_authorization_sessions ua\n INNER JOIN upstream_oauth_providers up\n USING (upstream_oauth_provider_id)\n WHERE upstream_oauth_authorization_session_id = $1\n " - }, "60d039442cfa57e187602c0ff5e386e32fb774b5ad2d2f2c616040819b76873e": { "describe": { "columns": [], @@ -1457,6 +1422,122 @@ }, "query": "\n UPDATE user_sessions\n SET finished_at = $1\n WHERE user_session_id = $2\n " }, + "65c7600f1af07cb6ea49d89ae6fbca5374a57c5a866c8aadd7b75ed1d2d1d0cd": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_authorization_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_provider_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_link_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "state", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "code_challenge_verifier", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "nonce", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "id_token", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 7, + "type_info": "Timestamptz" + }, + { + "name": "completed_at", + "ordinal": 8, + "type_info": "Timestamptz" + }, + { + "name": "consumed_at", + "ordinal": 9, + "type_info": "Timestamptz" + }, + { + "name": "provider_issuer", + "ordinal": 10, + "type_info": "Text" + }, + { + "name": "provider_scope", + "ordinal": 11, + "type_info": "Text" + }, + { + "name": "provider_client_id", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "provider_encrypted_client_secret", + "ordinal": 13, + "type_info": "Text" + }, + { + "name": "provider_token_endpoint_auth_method", + "ordinal": 14, + "type_info": "Text" + }, + { + "name": "provider_token_endpoint_signing_alg", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "provider_created_at", + "ordinal": 16, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + true, + false, + true, + false, + true, + true, + false, + false, + false, + true, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n ua.upstream_oauth_authorization_session_id,\n ua.upstream_oauth_provider_id,\n ua.upstream_oauth_link_id,\n ua.state,\n ua.code_challenge_verifier,\n ua.nonce,\n ua.id_token,\n ua.created_at,\n ua.completed_at,\n ua.consumed_at,\n up.issuer AS \"provider_issuer\",\n up.scope AS \"provider_scope\",\n up.client_id AS \"provider_client_id\",\n up.encrypted_client_secret AS \"provider_encrypted_client_secret\",\n up.token_endpoint_auth_method AS \"provider_token_endpoint_auth_method\",\n up.token_endpoint_signing_alg AS \"provider_token_endpoint_signing_alg\",\n up.created_at AS \"provider_created_at\"\n FROM upstream_oauth_authorization_sessions ua\n INNER JOIN upstream_oauth_providers up\n USING (upstream_oauth_provider_id)\n WHERE upstream_oauth_authorization_session_id = $1\n " + }, "6bf0da5ba3dd07b499193a2e0ddeea6e712f9df8f7f28874ff56a952a9f10e54": { "describe": { "columns": [], @@ -1994,69 +2075,6 @@ }, "query": "\n UPDATE users\n SET primary_user_email_id = user_emails.user_email_id\n FROM user_emails\n WHERE user_emails.user_email_id = $1\n AND users.user_id = user_emails.user_id\n " }, - "83ae2f24b4e5029a2e28b5404b8f3ae635ad9ec19f4e92d8e1823156fd836b4c": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_authorization_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "state", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "code_challenge_verifier", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "nonce", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "id_token", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "completed_at", - "ordinal": 6, - "type_info": "Timestamptz" - }, - { - "name": "consumed_at", - "ordinal": 7, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false, - true, - false, - true, - true - ], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n AND upstream_oauth_link_id = $2\n " - }, "874e677f82c221c5bb621c12f293bcef4e70c68c87ec003fcd475bcb994b5a4c": { "describe": { "columns": [], diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 53aaeacc..47790916 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -37,7 +37,7 @@ struct LinkLookup { pub async fn lookup_link( executor: impl PgExecutor<'_>, id: Ulid, -) -> Result<(UpstreamOAuthLink, Ulid, Option), GenericLookupError> { +) -> Result { let res = sqlx::query_as!( LinkLookup, r#" @@ -56,15 +56,13 @@ pub async fn lookup_link( .await .map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?; - Ok(( - UpstreamOAuthLink { - id: Ulid::from(res.upstream_oauth_link_id), - subject: res.subject, - created_at: res.created_at, - }, - Ulid::from(res.upstream_oauth_provider_id), - res.user_id.map(Ulid::from), - )) + Ok(UpstreamOAuthLink { + id: Ulid::from(res.upstream_oauth_link_id), + provider_id: Ulid::from(res.upstream_oauth_provider_id), + user_id: res.user_id.map(Ulid::from), + subject: res.subject, + created_at: res.created_at, + }) } #[tracing::instrument( @@ -81,7 +79,7 @@ pub async fn lookup_link_by_subject( executor: impl PgExecutor<'_>, upstream_oauth_provider: &UpstreamOAuthProvider, subject: &str, -) -> Result<(UpstreamOAuthLink, Option), GenericLookupError> { +) -> Result { let res = sqlx::query_as!( LinkLookup, r#" @@ -102,14 +100,13 @@ pub async fn lookup_link_by_subject( .await .map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?; - Ok(( - UpstreamOAuthLink { - id: Ulid::from(res.upstream_oauth_link_id), - subject: res.subject, - created_at: res.created_at, - }, - res.user_id.map(Ulid::from), - )) + Ok(UpstreamOAuthLink { + id: Ulid::from(res.upstream_oauth_link_id), + provider_id: Ulid::from(res.upstream_oauth_provider_id), + user_id: res.user_id.map(Ulid::from), + subject: res.subject, + created_at: res.created_at, + }) } #[tracing::instrument( @@ -154,6 +151,8 @@ pub async fn add_link( Ok(UpstreamOAuthLink { id, + provider_id: upstream_oauth_provider.id, + user_id: None, subject, created_at, }) diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index d9e764f1..6376a16e 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -18,7 +18,7 @@ mod session; pub use self::{ link::{add_link, associate_link_to_user, lookup_link, lookup_link_by_subject}, - provider::{add_provider, lookup_provider, ProviderLookupError}, + provider::{add_provider, get_paginated_providers, lookup_provider, ProviderLookupError}, session::{ add_session, complete_session, consume_session, lookup_session, lookup_session_on_link, SessionLookupError, diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index f6da34fa..931f7870 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -17,12 +17,16 @@ use mas_data_model::UpstreamOAuthProvider; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use oauth2_types::scope::Scope; use rand::Rng; -use sqlx::PgExecutor; +use sqlx::{PgExecutor, QueryBuilder}; use thiserror::Error; +use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; -use crate::{Clock, DatabaseInconsistencyError, LookupError}; +use crate::{ + pagination::{process_page, QueryBuilderExt}, + Clock, DatabaseInconsistencyError, LookupError, +}; #[derive(Debug, Error)] #[error("Failed to lookup upstream OAuth 2.0 provider")] @@ -37,6 +41,7 @@ impl LookupError for ProviderLookupError { } } +#[derive(sqlx::FromRow)] struct ProviderLookup { upstream_oauth_provider_id: Uuid, issuer: String, @@ -48,6 +53,37 @@ struct ProviderLookup { created_at: DateTime, } +impl TryFrom for UpstreamOAuthProvider { + type Error = DatabaseInconsistencyError; + fn try_from(value: ProviderLookup) -> Result { + let id = value.upstream_oauth_provider_id.into(); + let scope = value + .scope + .parse() + .map_err(|_| DatabaseInconsistencyError)?; + let token_endpoint_auth_method = value + .token_endpoint_auth_method + .parse() + .map_err(|_| DatabaseInconsistencyError)?; + let token_endpoint_signing_alg = value + .token_endpoint_signing_alg + .map(|x| x.parse()) + .transpose() + .map_err(|_| DatabaseInconsistencyError)?; + + Ok(UpstreamOAuthProvider { + id, + issuer: value.issuer, + scope, + client_id: value.client_id, + encrypted_client_secret: value.encrypted_client_secret, + token_endpoint_auth_method, + token_endpoint_signing_alg, + created_at: value.created_at, + }) + } +} + #[tracing::instrument( skip_all, fields(upstream_oauth_provider.id = %id), @@ -77,23 +113,7 @@ pub async fn lookup_provider( .fetch_one(executor) .await?; - Ok(UpstreamOAuthProvider { - id: res.upstream_oauth_provider_id.into(), - issuer: res.issuer, - scope: res.scope.parse().map_err(|_| DatabaseInconsistencyError)?, - client_id: res.client_id, - encrypted_client_secret: res.encrypted_client_secret, - token_endpoint_auth_method: res - .token_endpoint_auth_method - .parse() - .map_err(|_| DatabaseInconsistencyError)?, - token_endpoint_signing_alg: res - .token_endpoint_signing_alg - .map(|x| x.parse()) - .transpose() - .map_err(|_| DatabaseInconsistencyError)?, - created_at: res.created_at, - }) + Ok(res.try_into()?) } #[tracing::instrument( @@ -157,3 +177,45 @@ pub async fn add_provider( created_at, }) } + +#[tracing::instrument(skip_all, err(Display))] +pub async fn get_paginated_providers( + executor: impl PgExecutor<'_>, + before: Option, + after: Option, + first: Option, + last: Option, +) -> Result<(bool, bool, Vec), anyhow::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE 1 = 1 + "#, + ); + + query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?; + + let span = info_span!( + "Fetch paginated upstream OAuth 2.0 providers", + db.statement = query.sql() + ); + let page: Vec = query + .build_query_as() + .fetch_all(executor) + .instrument(span) + .await?; + + let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; + + let page: Result, _> = page.into_iter().map(TryInto::try_into).collect(); + Ok((has_previous_page, has_next_page, page?)) +} diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index 07ced535..651f5a3a 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -38,6 +38,7 @@ impl LookupError for SessionLookupError { struct SessionAndProviderLookup { upstream_oauth_authorization_session_id: Uuid, upstream_oauth_provider_id: Uuid, + upstream_oauth_link_id: Option, state: String, code_challenge_verifier: Option, nonce: String, @@ -70,6 +71,7 @@ pub async fn lookup_session( SELECT ua.upstream_oauth_authorization_session_id, ua.upstream_oauth_provider_id, + ua.upstream_oauth_link_id, ua.state, ua.code_challenge_verifier, ua.nonce, @@ -120,6 +122,8 @@ pub async fn lookup_session( let session = UpstreamOAuthAuthorizationSession { id: res.upstream_oauth_authorization_session_id.into(), + provider_id: provider.id, + link_id: res.upstream_oauth_link_id.map(Ulid::from), state: res.state, code_challenge_verifier: res.code_challenge_verifier, nonce: res.nonce, @@ -185,6 +189,8 @@ pub async fn add_session( Ok(UpstreamOAuthAuthorizationSession { id, + provider_id: upstream_oauth_provider.id, + link_id: None, state, code_challenge_verifier, nonce, @@ -267,6 +273,8 @@ pub async fn consume_session( struct SessionLookup { upstream_oauth_authorization_session_id: Uuid, + upstream_oauth_provider_id: Uuid, + upstream_oauth_link_id: Option, state: String, code_challenge_verifier: Option, nonce: String, @@ -295,6 +303,8 @@ pub async fn lookup_session_on_link( r#" SELECT upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + upstream_oauth_link_id, state, code_challenge_verifier, nonce, @@ -317,6 +327,8 @@ pub async fn lookup_session_on_link( Ok(UpstreamOAuthAuthorizationSession { id: res.upstream_oauth_authorization_session_id.into(), + provider_id: res.upstream_oauth_provider_id.into(), + link_id: res.upstream_oauth_link_id.map(Ulid::from), state: res.state, code_challenge_verifier: res.code_challenge_verifier, nonce: res.nonce, diff --git a/frontend/schema.graphql b/frontend/schema.graphql index 4ef30da8..8d21fc09 100644 --- a/frontend/schema.graphql +++ b/frontend/schema.graphql @@ -310,11 +310,95 @@ type RootQuery { """ userEmail(id: ID!): UserEmail """ + Fetch an upstream OAuth 2.0 link by its ID. + """ + upstreamOauth2Link(id: ID!): UpstreamOAuth2Link + """ + Fetch an upstream OAuth 2.0 provider by its ID. + """ + upstreamOauth2Provider(id: ID!): UpstreamOAuth2Provider + """ + Get a list of upstream OAuth 2.0 providers. + """ + upstreamOauth2Providers( + after: String + before: String + first: Int + last: Int + ): UpstreamOAuth2ProviderConnection! + """ Fetches an object given its ID. """ node(id: ID!): Node } +type UpstreamOAuth2Link implements Node { + """ + ID of the object. + """ + id: ID! + """ + When the object was created. + """ + createdAt: DateTime! + """ + The provider for which this link is. + """ + provider: UpstreamOAuth2Provider! + """ + The user to which this link is associated. + """ + user: User +} + +type UpstreamOAuth2Provider implements Node { + """ + ID of the object. + """ + id: ID! + """ + When the object was created. + """ + createdAt: DateTime! + """ + OpenID Connect issuer URL. + """ + issuer: String! + """ + Client ID used for this provider. + """ + clientId: String! +} + +type UpstreamOAuth2ProviderConnection { + """ + Information to aid in pagination. + """ + pageInfo: PageInfo! + """ + A list of edges. + """ + edges: [UpstreamOAuth2ProviderEdge!]! + """ + A list of nodes. + """ + nodes: [UpstreamOAuth2Provider!]! +} + +""" +An edge in a connection. +""" +type UpstreamOAuth2ProviderEdge { + """ + A cursor for use in pagination + """ + cursor: String! + """ + The item at the end of the edge + """ + node: UpstreamOAuth2Provider! +} + """ URL is a String implementing the [URL Standard](http://url.spec.whatwg.org/) """