From 5969b574e25bf3921bcab78c2caffae6de13f884 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 30 Dec 2022 10:16:22 +0100 Subject: [PATCH] WIP: repository pattern for upstream oauth2 links --- Cargo.lock | 4 +- Cargo.toml | 6 + crates/graphql/src/lib.rs | 3 +- crates/graphql/src/model/users.rs | 12 +- .../handlers/src/upstream_oauth2/callback.rs | 14 +- crates/handlers/src/upstream_oauth2/link.rs | 22 +- crates/handlers/src/views/shared.rs | 7 +- crates/storage/Cargo.toml | 1 + crates/storage/sqlx-data.json | 148 +++--- crates/storage/src/lib.rs | 3 + crates/storage/src/pagination.rs | 8 + crates/storage/src/repository.rs | 41 ++ crates/storage/src/upstream_oauth2/link.rs | 432 ++++++++++-------- crates/storage/src/upstream_oauth2/mod.rs | 5 +- 14 files changed, 419 insertions(+), 287 deletions(-) create mode 100644 crates/storage/src/repository.rs diff --git a/Cargo.lock b/Cargo.lock index 059e91c0..bc0b37ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3097,6 +3097,7 @@ dependencies = [ name = "mas-storage" version = "0.1.0" dependencies = [ + "async-trait", "chrono", "mas-data-model", "mas-iana", @@ -5575,8 +5576,7 @@ checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" [[package]] name = "ulid" version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13a3aaa69b04e5b66cc27309710a569ea23593612387d67daaf102e73aa974fd" +source = "git+https://github.com/sandhose/ulid-rs.git?rev=f1ef6fd736c4d3cbc7cf314fad707f0803de46ed#f1ef6fd736c4d3cbc7cf314fad707f0803de46ed" dependencies = [ "rand 0.8.5", "serde", diff --git a/Cargo.toml b/Cargo.toml index f621be0c..9799f34b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,9 @@ opt-level = 3 [profile.dev.package.sqlx-macros] opt-level = 3 + +# Until https://github.com/dylanhart/ulid-rs/pull/56 gets merged and released +[patch.crates-io.ulid] +git = "https://github.com/sandhose/ulid-rs.git" +#branch = "relax-sized-on-rng" +rev = "f1ef6fd736c4d3cbc7cf314fad707f0803de46ed" diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 1e691a96..9a86ecbe 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -30,6 +30,7 @@ use async_graphql::{ connection::{query, Connection, Edge, OpaqueCursor}, Context, Description, EmptyMutation, EmptySubscription, ID, }; +use mas_storage::{Repository, UpstreamOAuthLinkRepository}; use model::CreationEvent; use sqlx::PgPool; @@ -171,7 +172,7 @@ impl RootQuery { let Some(session) = session else { return Ok(None) }; let current_user = session.user; - let link = mas_storage::upstream_oauth2::lookup_link(&mut conn, id).await?; + let link = conn.upstream_oauth_link().lookup(id).await?; // Ensure that the link belongs to the current user let link = link.filter(|link| link.user_id == Some(current_user.id)); diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index ad8bfa43..01fcfb0e 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -17,6 +17,7 @@ use async_graphql::{ Context, Description, Object, ID, }; use chrono::{DateTime, Utc}; +use mas_storage::{Repository, UpstreamOAuthLinkRepository}; use sqlx::PgPool; use super::{ @@ -285,14 +286,13 @@ impl User { }) .transpose()?; - let (has_previous_page, has_next_page, edges) = - mas_storage::upstream_oauth2::get_paginated_user_links( - &mut conn, &self.0, before_id, after_id, first, last, - ) + let page = conn + .upstream_oauth_link() + .list_paginated(&self.0, before_id, after_id, first, last) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|s| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|s| { Edge::new( OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Link, s.id)), UpstreamOAuth2Link::new(s), diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index ab31641c..6158f941 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -25,8 +25,9 @@ use mas_oidc_client::requests::{ authorization_code::AuthorizationValidationData, jose::JwtVerificationData, }; use mas_router::{Route, UrlBuilder}; -use mas_storage::upstream_oauth2::{ - add_link, complete_session, lookup_link_by_subject, lookup_session, +use mas_storage::{ + upstream_oauth2::{complete_session, lookup_session}, + Repository, UpstreamOAuthLinkRepository, }; use oauth2_types::errors::ClientErrorCode; use serde::Deserialize; @@ -231,12 +232,17 @@ pub(crate) async fn get( let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?; // Look for an existing link - let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject).await?; + let maybe_link = txn + .upstream_oauth_link() + .find_by_subject(&provider, &subject) + .await?; let link = if let Some(link) = maybe_link { link } else { - add_link(&mut txn, &mut rng, &clock, &provider, subject).await? + txn.upstream_oauth_link() + .add(&mut rng, &clock, &provider, subject) + .await? }; let session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?; diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 15c5ac93..4a109ba6 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -25,10 +25,9 @@ use mas_axum_utils::{ }; use mas_keystore::Encrypter; use mas_storage::{ - upstream_oauth2::{ - associate_link_to_user, consume_session, lookup_link, lookup_session_on_link, - }, + upstream_oauth2::{consume_session, lookup_session_on_link}, user::{add_user, authenticate_session_with_upstream, lookup_user, start_session}, + Repository, UpstreamOAuthLinkRepository, }; use mas_templates::{ EmptyContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, @@ -104,7 +103,9 @@ pub(crate) async fn get( .lookup_link(link_id) .map_err(|_| RouteError::MissingCookie)?; - let link = lookup_link(&mut txn, link_id) + let link = txn + .upstream_oauth_link() + .lookup(link_id) .await? .ok_or(RouteError::LinkNotFound)?; @@ -205,7 +206,9 @@ pub(crate) async fn post( post_auth_action: post_auth_action.cloned(), }; - let link = lookup_link(&mut txn, link_id) + let link = txn + .upstream_oauth_link() + .lookup(link_id) .await? .ok_or(RouteError::LinkNotFound)?; @@ -224,7 +227,10 @@ pub(crate) async fn post( let mut session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { - associate_link_to_user(&mut txn, &link, &session.user).await?; + txn.upstream_oauth_link() + .associate_to_user(&link, &session.user) + .await?; + session } @@ -235,7 +241,9 @@ pub(crate) async fn post( (None, None, FormData::Register { username }) => { let user = add_user(&mut txn, &mut rng, &clock, &username).await?; - associate_link_to_user(&mut txn, &link, &user).await?; + txn.upstream_oauth_link() + .associate_to_user(&link, &user) + .await?; start_session(&mut txn, &mut rng, &clock, user).await? } diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index fcdef3b4..d4b19002 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -15,7 +15,8 @@ use anyhow::Context; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, + compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, Repository, + UpstreamOAuthLinkRepository, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; use serde::{Deserialize, Serialize}; @@ -63,7 +64,9 @@ impl OptionalPostAuthAction { PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword, PostAuthAction::LinkUpstream { id } => { - let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id) + let link = conn + .upstream_oauth_link() + .lookup(id) .await? .context("Failed to load upstream OAuth 2.0 link")?; diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index b0ed4c5e..71240129 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" license = "Apache-2.0" [dependencies] +async-trait = "0.1.60" sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] } chrono = { version = "0.4.23", features = ["serde"] } serde = { version = "1.0.152", features = ["derive"] } diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 1ce99d79..63368ec0 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -241,19 +241,6 @@ }, "query": "\n UPDATE user_emails\n SET confirmed_at = $2\n WHERE user_email_id = $1\n " }, - "1e7b1b7e06b5d97d81dc4a8524bb223c3dc7ddbbcce7cc2a142dbfbdd6a2902e": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " - }, "1eb6d13e75d8f526c2785749a020731c18012f03e07995213acd38ab560ce497": { "describe": { "columns": [], @@ -882,6 +869,50 @@ }, "query": "\n INSERT INTO user_emails (user_email_id, user_id, email, created_at)\n VALUES ($1, $2, $3, $4)\n " }, + "4187907bfc770b2c76f741671d5e672f5c35eed7c9a9e57ff52888b1768a5ed6": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_link_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_provider_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "subject", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_link_id = $1\n " + }, "42bfb0de5bbea2d580f1ff2322255731a4a5655ba80fc2dba0b55a0add8c55c0": { "describe": { "columns": [ @@ -1043,50 +1074,6 @@ }, "query": "\n UPDATE oauth2_authorization_grants AS og\n SET\n oauth2_session_id = os.oauth2_session_id,\n fulfilled_at = os.created_at\n FROM oauth2_sessions os\n WHERE\n og.oauth2_authorization_grant_id = $1\n AND os.oauth2_session_id = $2\n RETURNING fulfilled_at AS \"fulfilled_at!: DateTime\"\n " }, - "47d4048365144c7bfc14790dfb8fa7f862d2952075a68cd5e90ac76d9e6d1388": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_link_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "upstream_oauth_provider_id", - "ordinal": 1, - "type_info": "Uuid" - }, - { - "name": "user_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "subject", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 4, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_link_id = $1\n " - }, "4f8ec19f3f1bfe0268fe102a24e5a9fa542e77eccbebdce65e6deb1c197adf36": { "describe": { "columns": [ @@ -1345,6 +1332,21 @@ }, "query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n\n ORDER BY ue.email ASC\n " }, + "5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " + }, "60d039442cfa57e187602c0ff5e386e32fb774b5ad2d2f2c616040819b76873e": { "describe": { "columns": [], @@ -1837,6 +1839,19 @@ }, "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = ANY($1::uuid[])\n " }, + "7ce387b1b0aaf10e72adde667b19521b66eaafa51f73bf2f95e38b8f3b64a229": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + } + }, + "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " + }, "7cf5ae665b15ba78b01bb1dfa304150a89fd7203f4ee15b0753cb2143049a3dc": { "describe": { "columns": [ @@ -2687,21 +2702,6 @@ }, "query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n " }, - "e1dc9dd2bf26a341050a53151bf51f7638448ccc2bd458bbdfe87cc22f086313": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Timestamptz" - ] - } - }, - "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " - }, "e30562e9637d3a723a91adca6336a8d083657ce6d7fe9551fcd6a9d672453d3c": { "describe": { "columns": [], @@ -2729,7 +2729,7 @@ }, "query": "\n INSERT INTO user_sessions (user_session_id, user_id, created_at)\n VALUES ($1, $2, $3)\n " }, - "f71cb5761bfc15d8bc3ba7ee49b63fb3c3ea9691745688eb5fd91f4f6e1ec018": { + "e6dc63984aced9e19c20e90e9cd75d6f6d7ade64f782697715ac4da077b2e1fc": { "describe": { "columns": [ { @@ -2772,7 +2772,7 @@ ] } }, - "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " + "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " }, "fb71ac6539039313fd90b29ac943330e54c7b62b2778727726e2f60a554f9c5a": { "describe": { diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index f059e376..26865201 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -178,8 +178,11 @@ impl Clock { pub mod compat; pub mod oauth2; pub(crate) mod pagination; +pub(crate) mod repository; pub mod upstream_oauth2; pub mod user; +pub use self::{repository::Repository, upstream_oauth2::UpstreamOAuthLinkRepository}; + /// Embedded migrations, allowing them to run on startup pub static MIGRATOR: Migrator = sqlx::migrate!(); diff --git a/crates/storage/src/pagination.rs b/crates/storage/src/pagination.rs index 95655675..a240c554 100644 --- a/crates/storage/src/pagination.rs +++ b/crates/storage/src/pagination.rs @@ -111,6 +111,14 @@ pub fn process_page( Ok((has_previous_page, has_next_page, page)) } +pub struct Page { + pub has_next_page: bool, + pub has_previous_page: bool, + pub edges: Vec, +} + +impl Page {} + pub trait QueryBuilderExt { fn generate_pagination( &mut self, diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs new file mode 100644 index 00000000..0bfc2521 --- /dev/null +++ b/crates/storage/src/repository.rs @@ -0,0 +1,41 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use sqlx::{PgConnection, Postgres, Transaction}; + +use crate::upstream_oauth2::PgUpstreamOAuthLinkRepository; + +pub trait Repository { + type UpstreamOAuthLinkRepository<'c> + where + Self: 'c; + + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; +} + +impl Repository for PgConnection { + type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; + + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { + PgUpstreamOAuthLinkRepository::new(self) + } +} + +impl<'t> Repository for Transaction<'t, Postgres> { + type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; + + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { + PgUpstreamOAuthLinkRepository::new(self) + } +} diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 931b2b7d..3849af3c 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -12,19 +12,71 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; -use rand::Rng; -use sqlx::{PgExecutor, QueryBuilder}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; use crate::{ - pagination::{process_page, QueryBuilderExt}, + pagination::{process_page, Page, QueryBuilderExt}, Clock, DatabaseError, LookupResultExt, }; +#[async_trait] +pub trait UpstreamOAuthLinkRepository: Send + Sync { + type Error; + + /// Lookup an upstream OAuth link by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Find an upstream OAuth link for a provider by its subject + async fn find_by_subject( + &mut self, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, + ) -> Result, Self::Error>; + + /// Add a new upstream OAuth link + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, + ) -> Result; + + /// Associate an upstream OAuth link to a user + async fn associate_to_user( + &mut self, + upstream_oauth_link: &UpstreamOAuthLink, + user: &User, + ) -> Result<(), Self::Error>; + + /// Get a paginated list of upstream OAuth links + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error>; +} + +pub struct PgUpstreamOAuthLinkRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthLinkRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + #[derive(sqlx::FromRow)] struct LinkLookup { upstream_oauth_link_id: Uuid, @@ -46,197 +98,203 @@ impl From for UpstreamOAuthLink { } } -#[tracing::instrument( - skip_all, - fields(upstream_oauth_link.id = %id), - err, -)] -pub async fn lookup_link( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - LinkLookup, - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - WHERE upstream_oauth_link_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()? - .map(Into::into); +#[async_trait] +impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { + type Error = DatabaseError; - Ok(res) -} + #[tracing::instrument( + skip_all, + fields(upstream_oauth_link.id = %id), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_link_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()? + .map(Into::into); -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_link.subject = subject, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - ), - err, -)] -pub async fn lookup_link_by_subject( - executor: impl PgExecutor<'_>, - upstream_oauth_provider: &UpstreamOAuthProvider, - subject: &str, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - LinkLookup, - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - WHERE upstream_oauth_provider_id = $1 - AND subject = $2 - "#, - Uuid::from(upstream_oauth_provider.id), - subject, - ) - .fetch_one(executor) - .await - .to_option()? - .map(Into::into); + Ok(res) + } - Ok(res) -} + #[tracing::instrument( + skip_all, + fields( + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, + )] + async fn find_by_subject( + &mut self, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: &str, + ) -> Result, Self::Error> { + let res = sqlx::query_as!( + LinkLookup, + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + WHERE upstream_oauth_provider_id = $1 + AND subject = $2 + "#, + Uuid::from(upstream_oauth_provider.id), + subject, + ) + .fetch_one(&mut *self.conn) + .await + .to_option()? + .map(Into::into); -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_link.id, - upstream_oauth_link.subject = subject, - %upstream_oauth_provider.id, - %upstream_oauth_provider.issuer, - %upstream_oauth_provider.client_id, - ), - err, -)] -pub async fn add_link( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - upstream_oauth_provider: &UpstreamOAuthProvider, - subject: String, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); + Ok(res) + } - sqlx::query!( - r#" - INSERT INTO upstream_oauth_links ( - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - ) VALUES ($1, $2, NULL, $3, $4) - "#, - Uuid::from(id), - Uuid::from(upstream_oauth_provider.id), - &subject, - created_at, - ) - .execute(executor) - .await?; + #[tracing::instrument( + skip_all, + fields( + upstream_oauth_link.id, + upstream_oauth_link.subject = subject, + %upstream_oauth_provider.id, + %upstream_oauth_provider.issuer, + %upstream_oauth_provider.client_id, + ), + err, + )] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + upstream_oauth_provider: &UpstreamOAuthProvider, + subject: String, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id)); - Ok(UpstreamOAuthLink { - id, - provider_id: upstream_oauth_provider.id, - user_id: None, - subject, - created_at, - }) -} - -#[tracing::instrument( - skip_all, - fields( - %upstream_oauth_link.id, - %upstream_oauth_link.subject, - %user.id, - %user.username, - ), - err, -)] -pub async fn associate_link_to_user( - executor: impl PgExecutor<'_>, - upstream_oauth_link: &UpstreamOAuthLink, - user: &User, -) -> Result<(), sqlx::Error> { - sqlx::query!( - r#" - UPDATE upstream_oauth_links - SET user_id = $1 - WHERE upstream_oauth_link_id = $2 - "#, - Uuid::from(user.id), - Uuid::from(upstream_oauth_link.id), - ) - .execute(executor) - .await?; - - Ok(()) -} - -#[tracing::instrument( - skip_all, - fields(%user.id, %user.username), - err -)] -pub async fn get_paginated_user_links( - executor: impl PgExecutor<'_>, - user: &User, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - "#, - ); - - query - .push(" WHERE user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("upstream_oauth_link_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated upstream OAuth 2.0 user links", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) + sqlx::query!( + r#" + INSERT INTO upstream_oauth_links ( + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + ) VALUES ($1, $2, NULL, $3, $4) + "#, + Uuid::from(id), + Uuid::from(upstream_oauth_provider.id), + &subject, + created_at, + ) + .execute(&mut *self.conn) .await?; - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; + Ok(UpstreamOAuthLink { + id, + provider_id: upstream_oauth_provider.id, + user_id: None, + subject, + created_at, + }) + } - let page: Vec<_> = page.into_iter().map(Into::into).collect(); - Ok((has_previous_page, has_next_page, page)) + #[tracing::instrument( + skip_all, + fields( + %upstream_oauth_link.id, + %upstream_oauth_link.subject, + %user.id, + %user.username, + ), + err, + )] + async fn associate_to_user( + &mut self, + upstream_oauth_link: &UpstreamOAuthLink, + user: &User, + ) -> Result<(), Self::Error> { + sqlx::query!( + r#" + UPDATE upstream_oauth_links + SET user_id = $1 + WHERE upstream_oauth_link_id = $2 + "#, + Uuid::from(user.id), + Uuid::from(upstream_oauth_link.id), + ) + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + + #[tracing::instrument( + skip_all, + fields(%user.id, %user.username), + err + )] + async fn list_paginated( + &mut self, + user: &User, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT + upstream_oauth_link_id, + upstream_oauth_provider_id, + user_id, + subject, + created_at + FROM upstream_oauth_links + "#, + ); + + query + .push(" WHERE user_id = ") + .push_bind(Uuid::from(user.id)) + .generate_pagination("upstream_oauth_link_id", before, after, first, last)?; + + let span = info_span!( + "Fetch paginated upstream OAuth 2.0 user links", + db.statement = query.sql() + ); + let page: Vec = query + .build_query_as() + .fetch_all(&mut *self.conn) + .instrument(span) + .await?; + + let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; + + let edges: Vec<_> = edges.into_iter().map(Into::into).collect(); + Ok(Page { + has_next_page, + has_previous_page, + edges, + }) + } } diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 4b1d517a..4842fb47 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -17,10 +17,7 @@ mod provider; mod session; pub use self::{ - link::{ - add_link, associate_link_to_user, get_paginated_user_links, lookup_link, - lookup_link_by_subject, - }, + link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository}, provider::{add_provider, get_paginated_providers, get_providers, lookup_provider}, session::{ add_session, complete_session, consume_session, lookup_session, lookup_session_on_link,