diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index 00baf4a5..0e6771b4 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -31,10 +31,10 @@ use mas_http::HttpServiceExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; -use mas_storage::{oauth2::client::lookup_client_by_client_id, DatabaseError}; +use mas_storage::{oauth2::client::OAuth2ClientRepository, DatabaseError, Repository}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; -use sqlx::PgExecutor; +use sqlx::PgConnection; use thiserror::Error; use tower::{Service, ServiceExt}; @@ -73,10 +73,7 @@ pub enum Credentials { } impl Credentials { - pub async fn fetch( - &self, - executor: impl PgExecutor<'_>, - ) -> Result, DatabaseError> { + pub async fn fetch(&self, conn: &mut PgConnection) -> Result, DatabaseError> { let client_id = match self { Credentials::None { client_id } | Credentials::ClientSecretBasic { client_id, .. } @@ -84,7 +81,7 @@ impl Credentials { | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id, }; - lookup_client_by_client_id(executor, client_id).await + conn.oauth2_client().find_by_client_id(client_id).await } #[tracing::instrument(skip_all, err)] diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index b6c0e465..9e2b3988 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -18,7 +18,7 @@ use mas_config::{DatabaseConfig, PasswordsConfig, RootConfig}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_router::UrlBuilder; use mas_storage::{ - oauth2::client::{insert_client_from_config, lookup_client}, + oauth2::client::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, Repository, @@ -254,7 +254,7 @@ impl Options { for client in config.clients.iter() { let client_id = client.client_id; - let existing = lookup_client(&mut txn, client_id).await?.is_some(); + let existing = txn.oauth2_client().lookup(client_id).await?.is_some(); if !update && existing { warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients."); continue; @@ -270,25 +270,24 @@ impl Options { let client_auth_method = client.client_auth_method(); let jwks = client.jwks(); let jwks_uri = client.jwks_uri(); - let redirect_uris = &client.redirect_uris; // TODO: should be moved somewhere else let encrypted_client_secret = client_secret .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) .transpose()?; - insert_client_from_config( - &mut txn, - &mut rng, - &clock, - client_id, - client_auth_method, - encrypted_client_secret.as_deref(), - jwks, - jwks_uri, - redirect_uris, - ) - .await?; + txn.oauth2_client() + .add_from_config( + &mut rng, + &clock, + client_id, + client_auth_method, + encrypted_client_secret, + jwks.cloned(), + jwks_uri.cloned(), + client.redirect_uris.clone(), + ) + .await?; } txn.commit().await?; diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index b79b6fe9..ffa63396 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -31,6 +31,7 @@ use async_graphql::{ Context, Description, EmptyMutation, EmptySubscription, ID, }; use mas_storage::{ + oauth2::client::OAuth2ClientRepository, upstream_oauth2::UpstreamOAuthProviderRepository, user::{BrowserSessionRepository, UserEmailRepository}, Repository, UpstreamOAuthLinkRepository, @@ -95,7 +96,7 @@ impl RootQuery { let database = ctx.data::()?; let mut conn = database.acquire().await?; - let client = mas_storage::oauth2::client::lookup_client(&mut conn, id).await?; + let client = conn.oauth2_client().lookup(id).await?; Ok(client.map(OAuth2Client)) } diff --git a/crates/graphql/src/model/oauth.rs b/crates/graphql/src/model/oauth.rs index 89598ffa..5f1236f2 100644 --- a/crates/graphql/src/model/oauth.rs +++ b/crates/graphql/src/model/oauth.rs @@ -14,7 +14,7 @@ use anyhow::Context as _; use async_graphql::{Context, Description, Object, ID}; -use mas_storage::oauth2::client::lookup_client; +use mas_storage::{oauth2::client::OAuth2ClientRepository, Repository}; use oauth2_types::scope::Scope; use sqlx::PgPool; use ulid::Ulid; @@ -115,7 +115,9 @@ impl OAuth2Consent { /// OAuth 2.0 client for which the user granted access. pub async fn client(&self, ctx: &Context<'_>) -> Result { let mut conn = ctx.data::()?.acquire().await?; - let client = lookup_client(&mut conn, self.client_id) + let client = conn + .oauth2_client() + .lookup(self.client_id) .await? .context("Could not load client")?; Ok(OAuth2Client(client)) diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 1b999ffc..36d15d2b 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -25,8 +25,9 @@ use mas_data_model::{AuthorizationCode, Pkce}; use mas_keystore::Encrypter; use mas_policy::PolicyFactory; use mas_router::{PostAuthAction, Route}; -use mas_storage::oauth2::{ - authorization_grant::new_authorization_grant, client::lookup_client_by_client_id, +use mas_storage::{ + oauth2::{authorization_grant::new_authorization_grant, client::OAuth2ClientRepository}, + Repository, }; use mas_templates::Templates; use oauth2_types::{ @@ -141,7 +142,9 @@ pub(crate) async fn get( let mut txn = pool.begin().await?; // First, figure out what client it is - let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id) + let client = txn + .oauth2_client() + .find_by_client_id(¶ms.auth.client_id) .await? .ok_or(RouteError::ClientNotFound)?; diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index 25b734cf..b12194eb 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -19,7 +19,7 @@ use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; use mas_policy::{PolicyFactory, Violation}; -use mas_storage::oauth2::client::insert_client; +use mas_storage::{oauth2::client::OAuth2ClientRepository, Repository}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, registration::{ @@ -30,7 +30,6 @@ use rand::distributions::{Alphanumeric, DistString}; use sqlx::PgPool; use thiserror::Error; use tracing::info; -use ulid::Ulid; use crate::impl_from_error_for_route; @@ -50,6 +49,7 @@ pub(crate) enum RouteError { } impl_from_error_for_route!(sqlx::Error); +impl_from_error_for_route!(mas_storage::DatabaseError); impl_from_error_for_route!(mas_policy::LoadError); impl_from_error_for_route!(mas_policy::InstanciateError); impl_from_error_for_route!(mas_policy::EvaluationError); @@ -124,16 +124,9 @@ pub(crate) async fn post( return Err(RouteError::PolicyDenied(res.violations)); } - // Contacts was checked by the policy - let contacts = metadata.contacts.as_deref().unwrap_or_default(); - // Grab a txn let mut txn = pool.begin().await?; - let now = clock.now(); - // Let's generate a random client ID - let client_id = Ulid::from_datetime_with_source(now.into(), &mut rng); - let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method { Some( OAuthClientAuthenticationMethod::ClientSecretJwt @@ -148,41 +141,42 @@ pub(crate) async fn post( _ => (None, None), }; - insert_client( - &mut txn, - &mut rng, - &clock, - client_id, - metadata.redirect_uris(), - encrypted_client_secret.as_deref(), - //&metadata.response_types(), - metadata.grant_types(), - contacts, - metadata - .client_name - .as_ref() - .map(|l| l.non_localized().as_ref()), - metadata.logo_uri.as_ref().map(Localized::non_localized), - metadata.client_uri.as_ref().map(Localized::non_localized), - metadata.policy_uri.as_ref().map(Localized::non_localized), - metadata.tos_uri.as_ref().map(Localized::non_localized), - metadata.jwks_uri.as_ref(), - metadata.jwks.as_ref(), - // XXX: those might not be right, should be function calls - metadata.id_token_signed_response_alg.as_ref(), - metadata.userinfo_signed_response_alg.as_ref(), - metadata.token_endpoint_auth_method.as_ref(), - metadata.token_endpoint_auth_signing_alg.as_ref(), - metadata.initiate_login_uri.as_ref(), - ) - .await?; + let client = txn + .oauth2_client() + .add( + &mut rng, + &clock, + metadata.redirect_uris().to_vec(), + encrypted_client_secret, + //&metadata.response_types(), + metadata.grant_types().to_vec(), + metadata.contacts.clone().unwrap_or_default(), + metadata + .client_name + .clone() + .map(Localized::to_non_localized), + metadata.logo_uri.clone().map(Localized::to_non_localized), + metadata.client_uri.clone().map(Localized::to_non_localized), + metadata.policy_uri.clone().map(Localized::to_non_localized), + metadata.tos_uri.clone().map(Localized::to_non_localized), + metadata.jwks_uri.clone(), + metadata.jwks.clone(), + // XXX: those might not be right, should be function calls + metadata.id_token_signed_response_alg.clone(), + metadata.userinfo_signed_response_alg.clone(), + metadata.token_endpoint_auth_method.clone(), + metadata.token_endpoint_auth_signing_alg.clone(), + metadata.initiate_login_uri.clone(), + ) + .await?; txn.commit().await?; let response = ClientRegistrationResponse { - client_id: client_id.to_string(), + client_id: client.client_id, client_secret, - client_id_issued_at: Some(now), + // XXX: we should have a `created_at` field on the clients + client_id_issued_at: Some(client.id.datetime().into()), client_secret_expires_at: None, }; diff --git a/crates/oauth2-types/src/registration/mod.rs b/crates/oauth2-types/src/registration/mod.rs index 18aa24fa..0d958996 100644 --- a/crates/oauth2-types/src/registration/mod.rs +++ b/crates/oauth2-types/src/registration/mod.rs @@ -90,6 +90,11 @@ impl Localized { &self.non_localized } + /// Get the non-localized variant. + pub fn to_non_localized(self) -> T { + self.non_localized + } + /// Get the variant corresponding to the given language, if it exists. pub fn get(&self, language: Option<&LanguageTag>) -> Option<&T> { match language { diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 01036742..feced0ba 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -98,122 +98,6 @@ }, "query": "\n SELECT at.oauth2_access_token_id\n , at.access_token AS \"oauth2_access_token\"\n , at.created_at AS \"oauth2_access_token_created_at\"\n , at.expires_at AS \"oauth2_access_token_expires_at\"\n , os.oauth2_session_id AS \"oauth2_session_id!\"\n , os.oauth2_client_id AS \"oauth2_client_id!\"\n , os.scope AS \"scope!\"\n , us.user_session_id AS \"user_session_id!\"\n , us.created_at AS \"user_session_created_at!\"\n , u.user_id AS \"user_id!\"\n , u.username AS \"user_username!\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , usa.user_session_authentication_id AS \"user_session_last_authentication_id?\"\n , usa.created_at AS \"user_session_last_authentication_created_at?\"\n\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n USING (oauth2_session_id)\n INNER JOIN user_sessions us\n USING (user_session_id)\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications usa\n USING (user_session_id)\n\n WHERE at.access_token = $1\n AND at.revoked_at IS NULL\n AND os.finished_at IS NULL\n\n ORDER BY usa.created_at DESC\n LIMIT 1\n " }, - "05b50b7ae0109063c50fe70e83635a31920e44a7fbaa2b4f07552ba2f83a28d7": { - "describe": { - "columns": [ - { - "name": "oauth2_client_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "encrypted_client_secret", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "redirect_uris!", - "ordinal": 2, - "type_info": "TextArray" - }, - { - "name": "grant_type_authorization_code", - "ordinal": 3, - "type_info": "Bool" - }, - { - "name": "grant_type_refresh_token", - "ordinal": 4, - "type_info": "Bool" - }, - { - "name": "client_name", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "logo_uri", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "client_uri", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "policy_uri", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "tos_uri", - "ordinal": 9, - "type_info": "Text" - }, - { - "name": "jwks_uri", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "jwks", - "ordinal": 11, - "type_info": "Jsonb" - }, - { - "name": "id_token_signed_response_alg", - "ordinal": 12, - "type_info": "Text" - }, - { - "name": "userinfo_signed_response_alg", - "ordinal": 13, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_method", - "ordinal": 14, - "type_info": "Text" - }, - { - "name": "token_endpoint_auth_signing_alg", - "ordinal": 15, - "type_info": "Text" - }, - { - "name": "initiate_login_uri", - "ordinal": 16, - "type_info": "Text" - } - ], - "nullable": [ - false, - true, - null, - false, - false, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true, - true - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = $1\n " - }, "08d7df347c806ef14b6d0fb031cab041d79ba48528420160e23286369db7af35": { "describe": { "columns": [ @@ -1046,20 +930,6 @@ }, "query": "\n SELECT up.user_password_id\n , up.hashed_password\n , up.version\n , up.upgraded_from_id\n , up.created_at\n FROM user_passwords up\n WHERE up.user_id = $1\n ORDER BY up.created_at DESC\n LIMIT 1\n " }, - "4693f2b9b3d51ff4a05e233b6667161ebc97f331d96bf5f1c61069e1c8492105": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "UuidArray", - "Uuid", - "TextArray" - ] - } - }, - "query": "\n INSERT INTO oauth2_client_redirect_uris\n (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)\n SELECT id, $2, redirect_uri\n FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)\n " - }, "46c5ae7052504bfd7b94f20e61b9cf92570779a794bccda23dd654fb8523f340": { "describe": { "columns": [ @@ -1432,7 +1302,147 @@ }, "query": "\n SELECT COUNT(*) as \"count!\"\n FROM user_sessions s\n WHERE s.user_id = $1 AND s.finished_at IS NULL\n " }, - "7756a60c36a64a259f7450d6eb77ee92303638ca374a63f23ac4944ccf9f4436": { + "79295f3d3a75f831e9469aabfa720d381a254d00dbe39fef1e9652029d51b89b": { + "describe": { + "columns": [ + { + "name": "user_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "user_session_created_at", + "ordinal": 1, + "type_info": "Timestamptz" + }, + { + "name": "user_session_finished_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "user_id", + "ordinal": 3, + "type_info": "Uuid" + }, + { + "name": "user_username", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "user_primary_user_email_id", + "ordinal": 5, + "type_info": "Uuid" + }, + { + "name": "last_authentication_id?", + "ordinal": 6, + "type_info": "Uuid" + }, + { + "name": "last_authd_at?", + "ordinal": 7, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , a.user_session_authentication_id AS \"last_authentication_id?\"\n , a.created_at AS \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications a\n USING (user_session_id)\n WHERE s.user_session_id = $1\n ORDER BY a.created_at DESC\n LIMIT 1\n " + }, + "7be139553610ace03193a99fe27fcb4e3d50c90accdaf22ca1cfeefdc9734300": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "UuidArray", + "Uuid", + "TextArray" + ] + } + }, + "query": "\n INSERT INTO oauth2_client_redirect_uris\n (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)\n SELECT id, $2, redirect_uri\n FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)\n " + }, + "7ce387b1b0aaf10e72adde667b19521b66eaafa51f73bf2f95e38b8f3b64a229": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + } + }, + "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " + }, + "7e3247e35ecf5335f0656c53bcde27264a9efb8dccb6246344950614f487dcaf": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n " + }, + "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c": { + "describe": { + "columns": [ + { + "name": "user_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "username", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "primary_user_email_id", + "ordinal": 2, + "type_info": "Uuid" + }, + { + "name": "created_at", + "ordinal": 3, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE username = $1\n " + }, + "85499663f1adc7b7439592063f06914089f6243126a177b365bde37db5f6b33d": { "describe": { "columns": [ { @@ -1546,133 +1556,7 @@ ] } }, - "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = ANY($1::uuid[])\n " - }, - "79295f3d3a75f831e9469aabfa720d381a254d00dbe39fef1e9652029d51b89b": { - "describe": { - "columns": [ - { - "name": "user_session_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "user_session_created_at", - "ordinal": 1, - "type_info": "Timestamptz" - }, - { - "name": "user_session_finished_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "user_id", - "ordinal": 3, - "type_info": "Uuid" - }, - { - "name": "user_username", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "user_primary_user_email_id", - "ordinal": 5, - "type_info": "Uuid" - }, - { - "name": "last_authentication_id?", - "ordinal": 6, - "type_info": "Uuid" - }, - { - "name": "last_authd_at?", - "ordinal": 7, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false, - false, - true, - false, - false - ], - "parameters": { - "Left": [ - "Uuid" - ] - } - }, - "query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.primary_user_email_id AS \"user_primary_user_email_id\"\n , a.user_session_authentication_id AS \"last_authentication_id?\"\n , a.created_at AS \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n LEFT JOIN user_session_authentications a\n USING (user_session_id)\n WHERE s.user_session_id = $1\n ORDER BY a.created_at DESC\n LIMIT 1\n " - }, - "7ce387b1b0aaf10e72adde667b19521b66eaafa51f73bf2f95e38b8f3b64a229": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid" - ] - } - }, - "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " - }, - "7e3247e35ecf5335f0656c53bcde27264a9efb8dccb6246344950614f487dcaf": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE compat_access_tokens\n SET expires_at = $2\n WHERE compat_access_token_id = $1\n " - }, - "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c": { - "describe": { - "columns": [ - { - "name": "user_id", - "ordinal": 0, - "type_info": "Uuid" - }, - { - "name": "username", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "primary_user_email_id", - "ordinal": 2, - "type_info": "Uuid" - }, - { - "name": "created_at", - "ordinal": 3, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - true, - false - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT user_id\n , username\n , primary_user_email_id\n , created_at\n FROM users\n WHERE username = $1\n " + "query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = ANY($1::uuid[])\n " }, "874e677f82c221c5bb621c12f293bcef4e70c68c87ec003fcd475bcb994b5a4c": { "describe": { @@ -2174,33 +2058,6 @@ }, "query": "\n INSERT INTO compat_sessions (compat_session_id, user_id, device_id, created_at)\n VALUES ($1, $2, $3, $4)\n " }, - "cc9e30678d673546efca336ee8e550083eed71459611fa2db52264e51e175901": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Bool", - "Bool", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Jsonb", - "Text", - "Text", - "Text", - "Text", - "Text" - ] - } - }, - "query": "\n INSERT INTO oauth2_clients\n (oauth2_client_id,\n encrypted_client_secret,\n grant_type_authorization_code,\n grant_type_refresh_token,\n client_name,\n logo_uri,\n client_uri,\n policy_uri,\n tos_uri,\n jwks_uri,\n jwks,\n id_token_signed_response_alg,\n userinfo_signed_response_alg,\n token_endpoint_auth_method,\n token_endpoint_auth_signing_alg,\n initiate_login_uri)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " - }, "d023d7346ec1f32da9459db3c39dffd8a4e3d4e91cdf096928de4517d3f8c622": { "describe": { "columns": [ @@ -2368,6 +2225,122 @@ }, "query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,\n refresh_token, created_at)\n VALUES\n ($1, $2, $3, $4, $5)\n " }, + "db90cbc406a399f5447bd2c1d8018464f83b927dec620353516c0285b76fcf24": { + "describe": { + "columns": [ + { + "name": "oauth2_client_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "encrypted_client_secret", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "redirect_uris!", + "ordinal": 2, + "type_info": "TextArray" + }, + { + "name": "grant_type_authorization_code", + "ordinal": 3, + "type_info": "Bool" + }, + { + "name": "grant_type_refresh_token", + "ordinal": 4, + "type_info": "Bool" + }, + { + "name": "client_name", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "logo_uri", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "client_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "policy_uri", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "tos_uri", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "jwks_uri", + "ordinal": 10, + "type_info": "Text" + }, + { + "name": "jwks", + "ordinal": 11, + "type_info": "Jsonb" + }, + { + "name": "id_token_signed_response_alg", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "userinfo_signed_response_alg", + "ordinal": 13, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 14, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_signing_alg", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "initiate_login_uri", + "ordinal": 16, + "type_info": "Text" + } + ], + "nullable": [ + false, + true, + null, + false, + false, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = $1\n " + }, "dbf4be84eeff9ea51b00185faae2d453ab449017ed492bf6711dc7fceb630880": { "describe": { "columns": [], @@ -2426,6 +2399,33 @@ }, "query": "\n SELECT\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n " }, + "f5edcd4c306ca8179cdf9d4aab59fbba971b54611c91345849920954dd8089b3": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Bool", + "Bool", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Jsonb", + "Text", + "Text", + "Text", + "Text", + "Text" + ] + } + }, + "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " + }, "f624e1bdbff4e97b300362d1bbd86035e4a0fdd8ffe16c3bfb9bc451ba60851b": { "describe": { "columns": [ diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index 5c2347d2..e41f4812 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -19,8 +19,8 @@ use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; use uuid::Uuid; -use super::client::lookup_client; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; +use super::client::OAuth2ClientRepository; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError, Repository}; #[tracing::instrument( skip_all, @@ -144,7 +144,9 @@ pub async fn lookup_active_access_token( }; let session_id = res.oauth2_session_id.into(); - let client = lookup_client(&mut *conn, res.oauth2_client_id.into()) + let client = conn + .oauth2_client() + .lookup(res.oauth2_client_id.into()) .await? .ok_or_else(|| { DatabaseInconsistencyError::on("oauth2_sessions") diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index b7ffb30d..bfd91860 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -27,8 +27,8 @@ use ulid::Ulid; use url::Url; use uuid::Uuid; -use super::client::lookup_client; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use super::client::OAuth2ClientRepository; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Repository}; #[tracing::instrument( skip_all, @@ -163,7 +163,7 @@ impl GrantLookup { #[allow(clippy::too_many_lines)] async fn into_authorization_grant( self, - executor: impl PgExecutor<'_>, + conn: &mut PgConnection, ) -> Result { let id = self.oauth2_authorization_grant_id.into(); let scope: Scope = self.oauth2_authorization_grant_scope.parse().map_err(|e| { @@ -173,8 +173,9 @@ impl GrantLookup { .source(e) })?; - // TODO: don't unwrap - let client = lookup_client(executor, self.oauth2_client_id.into()) + let client = conn + .oauth2_client() + .lookup(self.oauth2_client_id.into()) .await? .ok_or_else(|| { DatabaseInconsistencyError::on("oauth2_authorization_grants") diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 164b0e80..afe789db 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -12,8 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashMap, string::ToString}; +use std::{ + collections::{BTreeMap, BTreeSet}, + string::ToString, +}; +use async_trait::async_trait; use mas_data_model::{Client, JwksOrJwksUri}; use mas_iana::{ jose::JsonWebSignatureAlg, @@ -21,17 +25,83 @@ use mas_iana::{ }; use mas_jose::jwk::PublicJsonWebKeySet; use oauth2_types::requests::GrantType; -use rand::Rng; -use sqlx::{PgConnection, PgExecutor}; +use rand::{Rng, RngCore}; +use sqlx::PgConnection; +use tracing::{info_span, Instrument}; use ulid::Ulid; use url::Url; use uuid::Uuid; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt}; +use crate::{ + tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, +}; + +#[async_trait] +pub trait OAuth2ClientRepository: Send + Sync { + type Error; + + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; + + async fn find_by_client_id(&mut self, client_id: &str) -> Result, Self::Error> { + let Ok(id) = client_id.parse() else { return Ok(None) }; + self.lookup(id).await + } + + async fn load_batch( + &mut self, + ids: BTreeSet, + ) -> Result, Self::Error>; + + #[allow(clippy::too_many_arguments)] + async fn add( + &mut self, + rng: &mut (dyn RngCore + Send), + clock: &Clock, + redirect_uris: Vec, + encrypted_client_secret: Option, + grant_types: Vec, + contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, + ) -> Result; + + #[allow(clippy::too_many_arguments)] + async fn add_from_config( + &mut self, + mut rng: impl Rng + Send, + clock: &Clock, + client_id: Ulid, + client_auth_method: OAuthClientAuthenticationMethod, + encrypted_client_secret: Option, + jwks: Option, + jwks_uri: Option, + redirect_uris: Vec, + ) -> Result; +} + +pub struct PgOAuth2ClientRepository<'c> { + conn: &'c mut PgConnection, +} + +impl<'c> PgOAuth2ClientRepository<'c> { + pub fn new(conn: &'c mut PgConnection) -> Self { + Self { conn } + } +} // XXX: response_types & contacts #[derive(Debug)] -pub struct OAuth2ClientLookup { +struct OAuth2ClientLookup { oauth2_client_id: Uuid, encrypted_client_secret: Option, redirect_uris: Vec, @@ -234,252 +304,305 @@ impl TryInto for OAuth2ClientLookup { } } -#[tracing::instrument(skip_all, err)] -pub async fn lookup_clients( - executor: impl PgExecutor<'_>, - ids: impl IntoIterator + Send, -) -> Result, DatabaseError> { - let ids: Vec = ids.into_iter().map(Uuid::from).collect(); - let res = sqlx::query_as!( - OAuth2ClientLookup, - r#" - SELECT - c.oauth2_client_id, - c.encrypted_client_secret, - ARRAY( - SELECT redirect_uri - FROM oauth2_client_redirect_uris r - WHERE r.oauth2_client_id = c.oauth2_client_id - ) AS "redirect_uris!", - c.grant_type_authorization_code, - c.grant_type_refresh_token, - c.client_name, - c.logo_uri, - c.client_uri, - c.policy_uri, - c.tos_uri, - c.jwks_uri, - c.jwks, - c.id_token_signed_response_alg, - c.userinfo_signed_response_alg, - c.token_endpoint_auth_method, - c.token_endpoint_auth_signing_alg, - c.initiate_login_uri - FROM oauth2_clients c +#[async_trait] +impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { + type Error = DatabaseError; - WHERE c.oauth2_client_id = ANY($1::uuid[]) - "#, - &ids, - ) - .fetch_all(executor) - .await?; + #[tracing::instrument( + name = "db.oauth2_client.lookup", + skip_all, + fields( + db.statement, + oauth2_client.id = %id, + ), + err, + )] + async fn lookup(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2ClientLookup, + r#" + SELECT oauth2_client_id + , encrypted_client_secret + , ARRAY( + SELECT redirect_uri + FROM oauth2_client_redirect_uris r + WHERE r.oauth2_client_id = c.oauth2_client_id + ) AS "redirect_uris!" + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + FROM oauth2_clients c - res.into_iter() - .map(|r| { - r.try_into() - .map(|c: Client| (c.id, c)) - .map_err(DatabaseError::from) - }) - .collect() -} + WHERE oauth2_client_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_one(&mut *self.conn) + .await + .to_option()?; -#[tracing::instrument( - skip_all, - fields(client.id = %id), - err, -)] -pub async fn lookup_client( - executor: impl PgExecutor<'_>, - id: Ulid, -) -> Result, DatabaseError> { - let res = sqlx::query_as!( - OAuth2ClientLookup, - r#" - SELECT - c.oauth2_client_id, - c.encrypted_client_secret, - ARRAY( - SELECT redirect_uri - FROM oauth2_client_redirect_uris r - WHERE r.oauth2_client_id = c.oauth2_client_id - ) AS "redirect_uris!", - c.grant_type_authorization_code, - c.grant_type_refresh_token, - c.client_name, - c.logo_uri, - c.client_uri, - c.policy_uri, - c.tos_uri, - c.jwks_uri, - c.jwks, - c.id_token_signed_response_alg, - c.userinfo_signed_response_alg, - c.token_endpoint_auth_method, - c.token_endpoint_auth_signing_alg, - c.initiate_login_uri - FROM oauth2_clients c + let Some(res) = res else { return Ok(None) }; - WHERE c.oauth2_client_id = $1 - "#, - Uuid::from(id), - ) - .fetch_one(executor) - .await - .to_option()?; + Ok(Some(res.try_into()?)) + } - let Some(res) = res else { return Ok(None) }; + #[tracing::instrument( + name = "db.oauth2_client.load_batch", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn load_batch( + &mut self, + ids: BTreeSet, + ) -> Result, Self::Error> { + let ids: Vec = ids.into_iter().map(Uuid::from).collect(); + let res = sqlx::query_as!( + OAuth2ClientLookup, + r#" + SELECT oauth2_client_id + , encrypted_client_secret + , ARRAY( + SELECT redirect_uri + FROM oauth2_client_redirect_uris r + WHERE r.oauth2_client_id = c.oauth2_client_id + ) AS "redirect_uris!" + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + FROM oauth2_clients c - Ok(Some(res.try_into()?)) -} + WHERE oauth2_client_id = ANY($1::uuid[]) + "#, + &ids, + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; -#[tracing::instrument( - skip_all, - fields(client.id = client_id), - err, -)] -pub async fn lookup_client_by_client_id( - executor: impl PgExecutor<'_>, - client_id: &str, -) -> Result, DatabaseError> { - let Ok(id) = client_id.parse() else { return Ok(None) }; - lookup_client(executor, id).await -} + res.into_iter() + .map(|r| { + r.try_into() + .map(|c: Client| (c.id, c)) + .map_err(DatabaseError::from) + }) + .collect() + } -#[tracing::instrument( - skip_all, - fields(client.id = %client_id, client.name = client_name), - err, -)] -#[allow(clippy::too_many_arguments)] -pub async fn insert_client( - conn: &mut PgConnection, - mut rng: impl Rng + Send, - clock: &Clock, - client_id: Ulid, - redirect_uris: &[Url], - encrypted_client_secret: Option<&str>, - grant_types: &[GrantType], - _contacts: &[String], - client_name: Option<&str>, - logo_uri: Option<&Url>, - client_uri: Option<&Url>, - policy_uri: Option<&Url>, - tos_uri: Option<&Url>, - jwks_uri: Option<&Url>, - jwks: Option<&PublicJsonWebKeySet>, - id_token_signed_response_alg: Option<&JsonWebSignatureAlg>, - userinfo_signed_response_alg: Option<&JsonWebSignatureAlg>, - token_endpoint_auth_method: Option<&OAuthClientAuthenticationMethod>, - token_endpoint_auth_signing_alg: Option<&JsonWebSignatureAlg>, - initiate_login_uri: Option<&Url>, -) -> Result<(), sqlx::Error> { - let grant_type_authorization_code = grant_types.contains(&GrantType::AuthorizationCode); - let grant_type_refresh_token = grant_types.contains(&GrantType::RefreshToken); - let logo_uri = logo_uri.map(Url::as_str); - let client_uri = client_uri.map(Url::as_str); - let policy_uri = policy_uri.map(Url::as_str); - let tos_uri = tos_uri.map(Url::as_str); - let jwks = jwks.map(serde_json::to_value).transpose().unwrap(); // TODO - let jwks_uri = jwks_uri.map(Url::as_str); - let id_token_signed_response_alg = id_token_signed_response_alg.map(ToString::to_string); - let userinfo_signed_response_alg = userinfo_signed_response_alg.map(ToString::to_string); - let token_endpoint_auth_method = token_endpoint_auth_method.map(ToString::to_string); - let token_endpoint_auth_signing_alg = token_endpoint_auth_signing_alg.map(ToString::to_string); - let initiate_login_uri = initiate_login_uri.map(Url::as_str); + #[tracing::instrument( + name = "db.oauth2_client.add", + skip_all, + fields( + db.statement, + client.id, + client.name = client_name + ), + err, + )] + #[allow(clippy::too_many_lines)] + async fn add( + &mut self, + mut rng: &mut (dyn RngCore + Send), + clock: &Clock, + redirect_uris: Vec, + encrypted_client_secret: Option, + grant_types: Vec, + contacts: Vec, + client_name: Option, + logo_uri: Option, + client_uri: Option, + policy_uri: Option, + tos_uri: Option, + jwks_uri: Option, + jwks: Option, + id_token_signed_response_alg: Option, + userinfo_signed_response_alg: Option, + token_endpoint_auth_method: Option, + token_endpoint_auth_signing_alg: Option, + initiate_login_uri: Option, + ) -> Result { + let now = clock.now(); + let id = Ulid::from_datetime_with_source(now.into(), rng); + tracing::Span::current().record("client.id", tracing::field::display(id)); - sqlx::query!( - r#" - INSERT INTO oauth2_clients - (oauth2_client_id, - encrypted_client_secret, - grant_type_authorization_code, - grant_type_refresh_token, - client_name, - logo_uri, - client_uri, - policy_uri, - tos_uri, - jwks_uri, - jwks, - id_token_signed_response_alg, - userinfo_signed_response_alg, - token_endpoint_auth_method, - token_endpoint_auth_signing_alg, - initiate_login_uri) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) - "#, - Uuid::from(client_id), - encrypted_client_secret, - grant_type_authorization_code, - grant_type_refresh_token, - client_name, - logo_uri, - client_uri, - policy_uri, - tos_uri, - jwks_uri, - jwks, - id_token_signed_response_alg, - userinfo_signed_response_alg, - token_endpoint_auth_method, - token_endpoint_auth_signing_alg, - initiate_login_uri, - ) - .execute(&mut *conn) - .await?; + let jwks_json = jwks + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(DatabaseError::to_invalid_operation)?; - let now = clock.now(); - let (ids, redirect_uris): (Vec, Vec) = redirect_uris - .iter() - .map(|uri| { - ( - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - uri.as_str().to_owned(), + sqlx::query!( + r#" + INSERT INTO oauth2_clients + ( oauth2_client_id + , encrypted_client_secret + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + "#, + Uuid::from(id), + encrypted_client_secret, + grant_types.contains(&GrantType::AuthorizationCode), + grant_types.contains(&GrantType::RefreshToken), + client_name, + logo_uri.as_ref().map(Url::as_str), + client_uri.as_ref().map(Url::as_str), + policy_uri.as_ref().map(Url::as_str), + tos_uri.as_ref().map(Url::as_str), + jwks_uri.as_ref().map(Url::as_str), + jwks_json, + id_token_signed_response_alg + .as_ref() + .map(ToString::to_string), + userinfo_signed_response_alg + .as_ref() + .map(ToString::to_string), + token_endpoint_auth_method.as_ref().map(ToString::to_string), + token_endpoint_auth_signing_alg + .as_ref() + .map(ToString::to_string), + initiate_login_uri.as_ref().map(Url::as_str), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + { + let span = info_span!( + "db.oauth2_client.add.redirect_uris", + db.statement = tracing::field::Empty, + client.id = %id, + ); + + let (uri_ids, redirect_uris): (Vec, Vec) = redirect_uris + .iter() + .map(|uri| { + ( + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), + uri.as_str().to_owned(), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_client_redirect_uris + (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) + SELECT id, $2, redirect_uri + FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) + "#, + &uri_ids, + Uuid::from(id), + &redirect_uris, ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + let jwks = match (jwks, jwks_uri) { + (None, None) => None, + (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), + (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), + _ => return Err(DatabaseError::invalid_operation()), + }; + + Ok(Client { + id, + client_id: id.to_string(), + encrypted_client_secret, + redirect_uris, + response_types: vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ], + grant_types, + contacts, + client_name, + logo_uri, + client_uri, + policy_uri, + tos_uri, + jwks, + id_token_signed_response_alg, + userinfo_signed_response_alg, + token_endpoint_auth_method, + token_endpoint_auth_signing_alg, + initiate_login_uri, }) - .unzip(); + } - sqlx::query!( - r#" - INSERT INTO oauth2_client_redirect_uris - (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) - SELECT id, $2, redirect_uri - FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) - "#, - &ids, - Uuid::from(client_id), - &redirect_uris, - ) - .execute(&mut *conn) - .await?; + #[tracing::instrument( + name = "db.oauth2_client.add_from_config", + skip_all, + fields( + db.statement, + client.id = %client_id, + ), + err, + )] + async fn add_from_config( + &mut self, + mut rng: impl Rng + Send, + clock: &Clock, + client_id: Ulid, + client_auth_method: OAuthClientAuthenticationMethod, + encrypted_client_secret: Option, + jwks: Option, + jwks_uri: Option, + redirect_uris: Vec, + ) -> Result { + let jwks_json = jwks + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(DatabaseError::to_invalid_operation)?; - Ok(()) -} + let client_auth_method = client_auth_method.to_string(); -#[allow(clippy::too_many_arguments)] -pub async fn insert_client_from_config( - conn: &mut PgConnection, - mut rng: impl Rng + Send, - clock: &Clock, - client_id: Ulid, - client_auth_method: OAuthClientAuthenticationMethod, - encrypted_client_secret: Option<&str>, - jwks: Option<&PublicJsonWebKeySet>, - jwks_uri: Option<&Url>, - redirect_uris: &[Url], -) -> Result<(), DatabaseError> { - let jwks = jwks - .map(serde_json::to_value) - .transpose() - .map_err(DatabaseError::to_invalid_operation)?; - - let jwks_uri = jwks_uri.map(Url::as_str); - - let client_auth_method = client_auth_method.to_string(); - - sqlx::query!( - r#" + sqlx::query!( + r#" INSERT INTO oauth2_clients ( oauth2_client_id , encrypted_client_secret @@ -500,41 +623,83 @@ pub async fn insert_client_from_config( , jwks = EXCLUDED.jwks , jwks_uri = EXCLUDED.jwks_uri "#, - Uuid::from(client_id), - encrypted_client_secret, - true, - true, - client_auth_method, - jwks, - jwks_uri, - ) - .execute(&mut *conn) - .await?; + Uuid::from(client_id), + encrypted_client_secret, + true, + true, + client_auth_method, + jwks_json, + jwks_uri.as_ref().map(Url::as_str), + ) + .traced() + .execute(&mut *self.conn) + .await?; - let now = clock.now(); - let (ids, redirect_uris): (Vec, Vec) = redirect_uris - .iter() - .map(|uri| { - ( - Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), - uri.as_str().to_owned(), + { + let span = info_span!( + "db.oauth2_client.add_from_config.redirect_uris", + client.id = %client_id, + db.statement = tracing::field::Empty, + ); + + let now = clock.now(); + let (ids, redirect_uris): (Vec, Vec) = redirect_uris + .iter() + .map(|uri| { + ( + Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)), + uri.as_str().to_owned(), + ) + }) + .unzip(); + + sqlx::query!( + r#" + INSERT INTO oauth2_client_redirect_uris + (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) + SELECT id, $2, redirect_uri + FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) + "#, + &ids, + Uuid::from(client_id), + &redirect_uris, ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + let jwks = match (jwks, jwks_uri) { + (None, None) => None, + (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)), + (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)), + _ => return Err(DatabaseError::invalid_operation()), + }; + + Ok(Client { + id: client_id, + client_id: client_id.to_string(), + encrypted_client_secret, + redirect_uris, + response_types: vec![ + OAuthAuthorizationEndpointResponseType::Code, + OAuthAuthorizationEndpointResponseType::IdToken, + OAuthAuthorizationEndpointResponseType::None, + ], + grant_types: Vec::new(), + contacts: Vec::new(), + client_name: None, + logo_uri: None, + client_uri: None, + policy_uri: None, + tos_uri: None, + jwks, + id_token_signed_response_alg: None, + userinfo_signed_response_alg: None, + token_endpoint_auth_method: None, + token_endpoint_auth_signing_alg: None, + initiate_login_uri: None, }) - .unzip(); - - sqlx::query!( - r#" - INSERT INTO oauth2_client_redirect_uris - (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri) - SELECT id, $2, redirect_uri - FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri) - "#, - &ids, - Uuid::from(client_id), - &redirect_uris, - ) - .execute(&mut *conn) - .await?; - - Ok(()) + } } diff --git a/crates/storage/src/oauth2/mod.rs b/crates/storage/src/oauth2/mod.rs index bdc9c1b5..c0153a3a 100644 --- a/crates/storage/src/oauth2/mod.rs +++ b/crates/storage/src/oauth2/mod.rs @@ -20,7 +20,7 @@ use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; -use self::client::lookup_clients; +use self::client::OAuth2ClientRepository; use crate::{ pagination::{process_page, QueryBuilderExt}, user::BrowserSessionRepository, @@ -128,7 +128,7 @@ pub async fn get_paginated_user_oauth_sessions( let browser_session_ids: BTreeSet = page.iter().map(|i| Ulid::from(i.user_session_id)).collect(); - let clients = lookup_clients(&mut *conn, client_ids).await?; + let clients = conn.oauth2_client().load_batch(client_ids).await?; // TODO: this can generate N queries instead of batching. This is less than // ideal diff --git a/crates/storage/src/oauth2/refresh_token.rs b/crates/storage/src/oauth2/refresh_token.rs index 5c2b6318..57abf103 100644 --- a/crates/storage/src/oauth2/refresh_token.rs +++ b/crates/storage/src/oauth2/refresh_token.rs @@ -19,8 +19,8 @@ use sqlx::{PgConnection, PgExecutor}; use ulid::Ulid; use uuid::Uuid; -use super::client::lookup_client; -use crate::{Clock, DatabaseError, DatabaseInconsistencyError}; +use super::client::OAuth2ClientRepository; +use crate::{Clock, DatabaseError, DatabaseInconsistencyError, Repository}; #[tracing::instrument( skip_all, @@ -173,7 +173,9 @@ pub async fn lookup_active_refresh_token( }; let session_id = res.oauth2_session_id.into(); - let client = lookup_client(&mut *conn, res.oauth2_client_id.into()) + let client = conn + .oauth2_client() + .lookup(res.oauth2_client_id.into()) .await? .ok_or_else(|| { DatabaseInconsistencyError::on("oauth2_sessions") diff --git a/crates/storage/src/repository.rs b/crates/storage/src/repository.rs index ef1a567a..4bca2253 100644 --- a/crates/storage/src/repository.rs +++ b/crates/storage/src/repository.rs @@ -15,6 +15,7 @@ use sqlx::{PgConnection, Postgres, Transaction}; use crate::{ + oauth2::client::PgOAuth2ClientRepository, upstream_oauth2::{ PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository, PgUpstreamOAuthSessionRepository, @@ -54,6 +55,10 @@ pub trait Repository { where Self: 'c; + type OAuth2ClientRepository<'c> + where + Self: 'c; + fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>; fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>; fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>; @@ -61,6 +66,7 @@ pub trait Repository { fn user_email(&mut self) -> Self::UserEmailRepository<'_>; fn user_password(&mut self) -> Self::UserPasswordRepository<'_>; fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>; + fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>; } impl Repository for PgConnection { @@ -71,6 +77,7 @@ impl Repository for PgConnection { type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; + type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -99,6 +106,10 @@ impl Repository for PgConnection { fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { PgBrowserSessionRepository::new(self) } + + fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { + PgOAuth2ClientRepository::new(self) + } } impl<'t> Repository for Transaction<'t, Postgres> { @@ -109,6 +120,7 @@ impl<'t> Repository for Transaction<'t, Postgres> { type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c; type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c; type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c; + type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c; fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> { PgUpstreamOAuthLinkRepository::new(self) @@ -137,4 +149,8 @@ impl<'t> Repository for Transaction<'t, Postgres> { fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> { PgBrowserSessionRepository::new(self) } + + fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> { + PgOAuth2ClientRepository::new(self) + } } diff --git a/crates/storage/src/tracing.rs b/crates/storage/src/tracing.rs index 60eb284c..08c62e46 100644 --- a/crates/storage/src/tracing.rs +++ b/crates/storage/src/tracing.rs @@ -12,9 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub trait ExecuteExt<'q, DB> { +use tracing::Span; + +pub trait ExecuteExt<'q, DB>: Sized { /// Records the statement as `db.statement` in the current span - fn traced(self) -> Self; + fn traced(self) -> Self { + self.record(&Span::current()) + } + + /// Records the statement as `db.statement` in the given span + fn record(self, span: &Span) -> Self; } impl<'q, DB, T> ExecuteExt<'q, DB> for T @@ -22,8 +29,8 @@ where T: sqlx::Execute<'q, DB>, DB: sqlx::Database, { - fn traced(self) -> Self { - tracing::Span::current().record("db.statement", self.sql()); + fn record(self, span: &Span) -> Self { + span.record("db.statement", self.sql()); self } }