diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 5472b78c..a92f35af 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -19,10 +19,11 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod use mas_router::UrlBuilder; use mas_storage::{ oauth2::client::{insert_client_from_config, lookup_client, truncate_clients}, + upstream_oauth2::UpstreamOAuthProviderRepository, user::{ add_user_password, lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, }, - Clock, + Clock, Repository, }; use oauth2_types::scope::Scope; use rand::SeedableRng; @@ -329,18 +330,19 @@ impl Options { .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .transpose()?; - let provider = mas_storage::upstream_oauth2::add_provider( - &mut conn, - &mut rng, - &clock, - issuer.clone(), - scope.clone(), - token_endpoint_auth_method, - token_endpoint_signing_alg, - client_id.clone(), - encrypted_client_secret, - ) - .await?; + let provider = conn + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + issuer.clone(), + scope.clone(), + token_endpoint_auth_method, + token_endpoint_signing_alg, + client_id.clone(), + encrypted_client_secret, + ) + .await?; let redirect_uri = url_builder.upstream_oauth_callback(provider.id); let auth_uri = url_builder.upstream_oauth_authorize(provider.id); diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index 9a86ecbe..8f3ef321 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -30,7 +30,9 @@ use async_graphql::{ connection::{query, Connection, Edge, OpaqueCursor}, Context, Description, EmptyMutation, EmptySubscription, ID, }; -use mas_storage::{Repository, UpstreamOAuthLinkRepository}; +use mas_storage::{ + upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, +}; use model::CreationEvent; use sqlx::PgPool; @@ -190,7 +192,7 @@ impl RootQuery { let database = ctx.data::()?; let mut conn = database.acquire().await?; - let provider = mas_storage::upstream_oauth2::lookup_provider(&mut conn, id).await?; + let provider = conn.upstream_oauth_provider().lookup(id).await?; Ok(provider.map(UpstreamOAuth2Provider::new)) } @@ -227,14 +229,13 @@ impl RootQuery { }) .transpose()?; - let (has_previous_page, has_next_page, edges) = - mas_storage::upstream_oauth2::get_paginated_providers( - &mut conn, before_id, after_id, first, last, - ) + let page = conn + .upstream_oauth_provider() + .list_paginated(before_id, after_id, first, last) .await?; - let mut connection = Connection::new(has_previous_page, has_next_page); - connection.edges.extend(edges.into_iter().map(|p| { + let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + connection.edges.extend(page.edges.into_iter().map(|p| { Edge::new( OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)), UpstreamOAuth2Provider::new(p), diff --git a/crates/graphql/src/model/upstream_oauth.rs b/crates/graphql/src/model/upstream_oauth.rs index 87164dd4..2de6f2f7 100644 --- a/crates/graphql/src/model/upstream_oauth.rs +++ b/crates/graphql/src/model/upstream_oauth.rs @@ -15,6 +15,7 @@ use anyhow::Context as _; use async_graphql::{Context, Object, ID}; use chrono::{DateTime, Utc}; +use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository}; use sqlx::PgPool; use super::{NodeType, User}; @@ -101,7 +102,8 @@ impl UpstreamOAuth2Link { // Fetch on-the-fly let database = ctx.data::()?; let mut conn = database.acquire().await?; - mas_storage::upstream_oauth2::lookup_provider(&mut conn, self.link.provider_id) + conn.upstream_oauth_provider() + .lookup(self.link.provider_id) .await? .context("Upstream OAuth 2.0 provider not found")? }; diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 78712451..5e69f416 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -22,7 +22,7 @@ use mas_axum_utils::http_client_factory::HttpClientFactory; use mas_keystore::Encrypter; use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; -use mas_storage::upstream_oauth2::lookup_provider; +use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository}; use sqlx::PgPool; use thiserror::Error; use ulid::Ulid; @@ -66,7 +66,9 @@ pub(crate) async fn get( let mut txn = pool.begin().await?; - let provider = lookup_provider(&mut txn, provider_id) + let provider = txn + .upstream_oauth_provider() + .lookup(provider_id) .await? .ok_or(RouteError::ProviderNotFound)?; diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 24fc17b7..fd54175d 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -24,11 +24,12 @@ use mas_axum_utils::{ use mas_data_model::BrowserSession; use mas_keystore::Encrypter; use mas_storage::{ + upstream_oauth2::UpstreamOAuthProviderRepository, user::{ add_user_password, authenticate_session_with_password, lookup_user_by_username, lookup_user_password, start_session, }, - Clock, + Clock, Repository, }; use mas_templates::{ FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, @@ -69,7 +70,7 @@ pub(crate) async fn get( let reply = query.go_next(); Ok((cookie_jar, reply).into_response()) } else { - let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?; + let providers = conn.upstream_oauth_provider().all().await?; let content = render( LoginContext::default().with_upstrem_providers(providers), query, @@ -114,7 +115,7 @@ pub(crate) async fn post( }; if !state.is_valid() { - let providers = mas_storage::upstream_oauth2::get_providers(&mut conn).await?; + let providers = conn.upstream_oauth_provider().all().await?; let content = render( LoginContext::default() .with_form_state(state) diff --git a/crates/handlers/src/views/shared.rs b/crates/handlers/src/views/shared.rs index d4b19002..6035c74d 100644 --- a/crates/handlers/src/views/shared.rs +++ b/crates/handlers/src/views/shared.rs @@ -15,8 +15,8 @@ use anyhow::Context; use mas_router::{PostAuthAction, Route}; use mas_storage::{ - compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, Repository, - UpstreamOAuthLinkRepository, + compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id, + upstream_oauth2::UpstreamOAuthProviderRepository, Repository, UpstreamOAuthLinkRepository, }; use mas_templates::{PostAuthContext, PostAuthContextInner}; use serde::{Deserialize, Serialize}; @@ -70,10 +70,11 @@ impl OptionalPostAuthAction { .await? .context("Failed to load upstream OAuth 2.0 link")?; - let provider = - mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id) - .await? - .context("Failed to load upstream OAuth 2.0 provider")?; + let provider = conn + .upstream_oauth_provider() + .lookup(link.provider_id) + .await? + .context("Failed to load upstream OAuth 2.0 provider")?; let provider = Box::new(provider); let link = Box::new(link); diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 63368ec0..52b0118f 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -116,68 +116,6 @@ }, "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = $1\n " }, - "0af182315b36766eca8e232280986bade0202d1b1d64160a99cd14eadcbfc25b": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_provider_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "issuer", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "scope", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "client_id", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "encrypted_client_secret", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "token_endpoint_signing_alg", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_method", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 7, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n " - }, "0b49cde0b7b79f79ec261502ab89bcffa81f9f5ed2f922a41b1718274b9e3073": { "describe": { "columns": [ @@ -241,6 +179,66 @@ }, "query": "\n UPDATE user_emails\n SET confirmed_at = $2\n WHERE user_email_id = $1\n " }, + "154e2e4488ff87e09163698750b56a43127cee4e1392785416a586d40a4d9b21": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_provider_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "issuer", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "scope", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "client_id", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "encrypted_client_secret", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "token_endpoint_signing_alg", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 7, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + false, + false + ], + "parameters": { + "Left": [] + } + }, + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n " + }, "1eb6d13e75d8f526c2785749a020731c18012f03e07995213acd38ab560ce497": { "describe": { "columns": [], @@ -2089,6 +2087,68 @@ }, "query": "\n SELECT COUNT(*)\n FROM user_emails ue\n WHERE ue.user_id = $1\n " }, + "8f7a9fb1f24c24f8dbc3c193df2a742c9ac730ab958587b67297de2d4b843863": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_provider_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "issuer", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "scope", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "client_id", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "encrypted_client_secret", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "token_endpoint_signing_alg", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 7, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n " + }, "99f5f9eb0adc5ec120ed8194cbf6a8545155bef09e6d94d92fb67fd1b14d4f28": { "describe": { "columns": [], @@ -2586,66 +2646,6 @@ }, "query": "\n INSERT INTO oauth2_clients\n (oauth2_client_id,\n encrypted_client_secret,\n grant_type_authorization_code,\n grant_type_refresh_token,\n client_name,\n logo_uri,\n client_uri,\n policy_uri,\n tos_uri,\n jwks_uri,\n jwks,\n id_token_signed_response_alg,\n userinfo_signed_response_alg,\n token_endpoint_auth_method,\n token_endpoint_auth_signing_alg,\n initiate_login_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " }, - "cf00e0ad529bcb5c0640adcfe0880a3560d9739f355b90ca3ba88dd1eaf26565": { - "describe": { - "columns": [ - { - "name": "upstream_oauth_provider_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "issuer", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "scope", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "client_id", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "encrypted_client_secret", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "token_endpoint_signing_alg", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_method", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "created_at", - "ordinal": 7, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - true, - true, - false, - false - ], - "parameters": { - "Left": [] - } - }, - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n " - }, "d1738c27339b81f0844da4bd9b040b9b07a91aa4d9b199b98f24c9cee5709b2b": { "describe": { "columns": [], diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index 0bfc2521..c1d259fc 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -14,28 +14,43 @@ use sqlx::{PgConnection, Postgres, Transaction}; -use crate::upstream_oauth2::PgUpstreamOAuthLinkRepository; +use crate::upstream_oauth2::{PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository}; pub trait Repository { type UpstreamOAuthLinkRepository<'c> where Self: 'c; + type UpstreamOAuthProviderRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; + fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; } impl Repository for PgConnection { type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; + type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) } + + fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { + PgUpstreamOAuthProviderRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { type UpstreamOAuthLinkRepository<'c> = PgUpstreamOAuthLinkRepository<'c> where Self: 'c; + type UpstreamOAuthProviderRepository<'c> = PgUpstreamOAuthProviderRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) } + + fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_> { + PgUpstreamOAuthProviderRepository::new(self) + } } diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 3849af3c..100e9833 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -56,7 +56,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { user: &User, ) -> Result<(), Self::Error>; - /// Get a paginated list of upstream OAuth links + /// Get a paginated list of upstream OAuth links on a user async fn list_paginated( &mut self, user: &User, diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 4842fb47..d29b5e71 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -18,7 +18,7 @@ mod session; pub use self::{ link::{PgUpstreamOAuthLinkRepository, UpstreamOAuthLinkRepository}, - provider::{add_provider, get_paginated_providers, get_providers, lookup_provider}, + provider::{PgUpstreamOAuthProviderRepository, UpstreamOAuthProviderRepository}, session::{ add_session, complete_session, consume_session, lookup_session, lookup_session_on_link, }, diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 360b9a4a..3d8ba141 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -12,21 +12,66 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::UpstreamOAuthProvider; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use oauth2_types::scope::Scope; -use rand::Rng; -use sqlx::{PgExecutor, QueryBuilder}; +use rand::RngCore; +use sqlx::{PgConnection, QueryBuilder}; use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; use crate::{ - pagination::{process_page, QueryBuilderExt}, + pagination::{process_page, Page, QueryBuilderExt}, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, }; +#[async_trait] +pub trait UpstreamOAuthProviderRepository: Send + Sync { + type Error; + + /// Lookup an upstream OAuth provider by its ID + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + /// Add a new upstream OAuth provider + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, + ) -> Result; + + /// Get a paginated list of upstream OAuth providers + async fn list_paginated( + &mut self, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error>; + + /// Get all upstream OAuth providers + async fn all(&mut self) -> Result, Self::Error>; +} + +pub struct PgUpstreamOAuthProviderRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgUpstreamOAuthProviderRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} + #[derive(sqlx::FromRow)] struct ProviderLookup { upstream_oauth_provider_id: Uuid, @@ -79,71 +124,72 @@ impl TryFrom for UpstreamOAuthProvider { } } -#[tracing::instrument( - skip_all, - fields(upstream_oauth_provider.id = %id), - err, -)] -pub async fn lookup_provider( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - ProviderLookup, - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - WHERE upstream_oauth_provider_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; +#[async_trait] +impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'c> { + type Error = DatabaseError; - let res = res - .map(UpstreamOAuthProvider::try_from) - .transpose() - .map_err(DatabaseError::from)?; + #[tracing::instrument( + skip_all, + fields(upstream_oauth_provider.id = %id), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + ProviderLookup, + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(&mut *self.conn) + .await + .to_option()?; - Ok(res) -} + let res = res + .map(UpstreamOAuthProvider::try_from) + .transpose() + .map_err(DatabaseError::from)?; -#[tracing::instrument( - skip_all, - fields( - upstream_oauth_provider.id, - upstream_oauth_provider.issuer = %issuer, - upstream_oauth_provider.client_id = %client_id, - ), - err, -)] -#[allow(clippy::too_many_arguments)] -pub async fn add_provider( - executor: impl PgExecutor<'_>, - mut rng: impl Rng + Send, - clock: &Clock, - issuer: String, - scope: Scope, - token_endpoint_auth_method: OAuthClientAuthenticationMethod, - token_endpoint_signing_alg: Option, - client_id: String, - encrypted_client_secret: Option, -) -> Result { - let created_at = clock.now(); - let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); - tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); + Ok(res) + } - sqlx::query!( - r#" + #[tracing::instrument( + skip_all, + fields( + upstream_oauth_provider.id, + upstream_oauth_provider.issuer = %issuer, + upstream_oauth_provider.client_id = %client_id, + ), + err, + )] + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, + ) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), rng); + tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); + + sqlx::query!( + r#" INSERT INTO upstream_oauth_providers ( upstream_oauth_provider_id, issuer, @@ -155,94 +201,95 @@ pub async fn add_provider( created_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) "#, - Uuid::from(id), - &issuer, - scope.to_string(), - token_endpoint_auth_method.to_string(), - token_endpoint_signing_alg.as_ref().map(ToString::to_string), - &client_id, - encrypted_client_secret.as_deref(), - created_at, - ) - .execute(executor) - .await?; - - Ok(UpstreamOAuthProvider { - id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at, - }) -} - -#[tracing::instrument(skip_all, err)] -pub async fn get_paginated_providers( - executor: impl PgExecutor<'_>, - before: Option, - after: Option, - first: Option, - last: Option, -) -> Result<(bool, bool, Vec), DatabaseError> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - WHERE 1 = 1 - "#, - ); - - query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?; - - let span = info_span!( - "Fetch paginated upstream OAuth 2.0 providers", - db.statement = query.sql() - ); - let page: Vec = query - .build_query_as() - .fetch_all(executor) - .instrument(span) + Uuid::from(id), + &issuer, + scope.to_string(), + token_endpoint_auth_method.to_string(), + token_endpoint_signing_alg.as_ref().map(ToString::to_string), + &client_id, + encrypted_client_secret.as_deref(), + created_at, + ) + .execute(&mut *self.conn) .await?; - let (has_previous_page, has_next_page, page) = process_page(page, first, last)?; + Ok(UpstreamOAuthProvider { + id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at, + }) + } - let page: Result, _> = page.into_iter().map(TryInto::try_into).collect(); - Ok((has_previous_page, has_next_page, page?)) -} - -#[tracing::instrument(skip_all, err)] -pub async fn get_providers( - executor: impl PgExecutor<'_>, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - ProviderLookup, - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at - FROM upstream_oauth_providers - "#, - ) - .fetch_all(executor) - .await?; - - let res: Result, _> = res.into_iter().map(TryInto::try_into).collect(); - Ok(res?) + async fn list_paginated( + &mut self, + before: Option, + after: Option, + first: Option, + last: Option, + ) -> Result, Self::Error> { + let mut query = QueryBuilder::new( + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE 1 = 1 + "#, + ); + + query.generate_pagination("upstream_oauth_provider_id", before, after, first, last)?; + + let span = info_span!( + "Fetch paginated upstream OAuth 2.0 providers", + db.statement = query.sql() + ); + let page: Vec = query + .build_query_as() + .fetch_all(&mut *self.conn) + .instrument(span) + .await?; + + let (has_previous_page, has_next_page, edges) = process_page(page, first, last)?; + + let edges: Result, _> = edges.into_iter().map(TryInto::try_into).collect(); + Ok(Page { + has_next_page, + has_previous_page, + edges: edges?, + }) + } + #[tracing::instrument(skip_all, err)] + async fn all(&mut self) -> Result, Self::Error> { + let res = sqlx::query_as!( + ProviderLookup, + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + "#, + ) + .fetch_all(&mut *self.conn) + .await?; + + let res: Result, _> = res.into_iter().map(TryInto::try_into).collect(); + Ok(res?) + } }