From dec9310a32936a74801c4a1c5ea2d22fe2fb7d12 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 26 Jun 2023 16:22:49 +0200 Subject: [PATCH] Sync the OAuth2 clients with CLI and remove redundant CLI tools --- crates/cli/src/commands/config.rs | 255 ++++++++++----- crates/cli/src/commands/manage.rs | 274 +--------------- .../20230626130338_oauth_clients_static.sql | 19 ++ crates/storage-pg/sqlx-data.json | 306 ++++++++++++++---- crates/storage-pg/src/oauth2/client.rs | 213 +++++++++++- crates/storage-pg/src/upstream_oauth2/mod.rs | 5 +- .../src/upstream_oauth2/provider.rs | 56 ++-- crates/storage/src/oauth2/client.rs | 45 ++- 8 files changed, 738 insertions(+), 435 deletions(-) create mode 100644 crates/storage-pg/migrations/20230626130338_oauth_clients_static.sql diff --git a/crates/cli/src/commands/config.rs b/crates/cli/src/commands/config.rs index c0667330..9bd13dd6 100644 --- a/crates/cli/src/commands/config.rs +++ b/crates/cli/src/commands/config.rs @@ -50,9 +50,21 @@ fn map_claims_imports( config: &mas_config::UpstreamOAuth2ClaimsImports, ) -> mas_data_model::UpstreamOAuthProviderClaimsImports { mas_data_model::UpstreamOAuthProviderClaimsImports { - localpart: config.localpart.as_ref().map(map_import_preference).unwrap_or_default(), - displayname: config.displayname.as_ref().map(map_import_preference).unwrap_or_default(), - email: config.email.as_ref().map(map_import_preference).unwrap_or_default(), + localpart: config + .localpart + .as_ref() + .map(map_import_preference) + .unwrap_or_default(), + displayname: config + .displayname + .as_ref() + .map(map_import_preference) + .unwrap_or_default(), + email: config + .email + .as_ref() + .map(map_import_preference) + .unwrap_or_default(), } } @@ -116,88 +128,167 @@ impl Options { } SC::Sync { prune, dry_run } => { - let _span = - info_span!("cli.config.sync", prune = prune, dry_run = dry_run).entered(); - - let clock = SystemClock::default(); - - let config: RootConfig = root.load_config()?; - let encrypter = config.secrets.encrypter(); - let pool = database_from_config(&config.database).await?; - let mut repo = PgRepository::from_pool(&pool).await?.boxed(); - - tracing::info!( - prune, - dry_run, - "Syncing providers and clients defined in config to database" - ); - - let config_ids = config - .upstream_oauth2 - .providers - .iter() - .map(|p| p.id) - .collect::>(); - - let existing = repo.upstream_oauth_provider().all().await?; - let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); - if prune { - for provider in to_delete { - info!(%provider.id, "Deleting provider"); - - if dry_run { - continue; - } - - repo.upstream_oauth_provider().delete(provider).await?; - } - } else { - let len = to_delete.count(); - match len { - 0 => {}, - 1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."), - n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."), - } - } - - for provider in config.upstream_oauth2.providers { - info!(%provider.id, "Syncing provider"); - - if dry_run { - continue; - } - - let encrypted_client_secret = provider - .client_secret() - .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) - .transpose()?; - let client_auth_method = provider.client_auth_method(); - let client_auth_signing_alg = provider.client_auth_signing_alg(); - - repo.upstream_oauth_provider() - .upsert( - &clock, - provider.id, - provider.issuer, - provider.scope.parse()?, - client_auth_method, - client_auth_signing_alg, - provider.client_id, - encrypted_client_secret, - map_claims_imports(&provider.claims_imports), - ) - .await?; - } - - if dry_run { - info!("Dry run, rolling back changes"); - repo.cancel().await?; - } else { - repo.save().await?; - } + sync(root, prune, dry_run).await?; } } Ok(()) } } + +#[tracing::instrument(name = "cli.config.sync", skip(root), err(Debug))] +async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Result<()> { + // XXX: we should disallow SeedableRng::from_entropy + let mut rng = rand_chacha::ChaChaRng::from_entropy(); + let clock = SystemClock::default(); + + let config: RootConfig = root.load_config()?; + let encrypter = config.secrets.encrypter(); + let pool = database_from_config(&config.database).await?; + let mut repo = PgRepository::from_pool(&pool).await?.boxed(); + + tracing::info!( + prune, + dry_run, + "Syncing providers and clients defined in config to database" + ); + + { + let _span = info_span!("cli.config.sync.providers").entered(); + let config_ids = config + .upstream_oauth2 + .providers + .iter() + .map(|p| p.id) + .collect::>(); + + let existing = repo.upstream_oauth_provider().all().await?; + let existing_ids = existing.iter().map(|p| p.id).collect::>(); + let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); + if prune { + for provider in to_delete { + info!(%provider.id, "Deleting provider"); + + if dry_run { + continue; + } + + repo.upstream_oauth_provider().delete(provider).await?; + } + } else { + let len = to_delete.count(); + match len { + 0 => {}, + 1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."), + n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."), + } + } + + for provider in config.upstream_oauth2.providers { + if existing_ids.contains(&provider.id) { + info!(%provider.id, "Updating provider"); + } else { + info!(%provider.id, "Adding provider"); + } + + if dry_run { + continue; + } + + let encrypted_client_secret = provider + .client_secret() + .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) + .transpose()?; + let client_auth_method = provider.client_auth_method(); + let client_auth_signing_alg = provider.client_auth_signing_alg(); + + repo.upstream_oauth_provider() + .upsert( + &clock, + provider.id, + provider.issuer, + provider.scope.parse()?, + client_auth_method, + client_auth_signing_alg, + provider.client_id, + encrypted_client_secret, + map_claims_imports(&provider.claims_imports), + ) + .await?; + } + } + + { + let _span = info_span!("cli.config.sync.clients").entered(); + let config_ids = config + .clients + .iter() + .map(|c| c.client_id) + .collect::>(); + + let existing = repo.oauth2_client().all_static().await?; + let existing_ids = existing.iter().map(|p| p.id).collect::>(); + let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); + if prune { + for client in to_delete { + info!(client.id = %client.client_id, "Deleting client"); + + if dry_run { + continue; + } + + repo.oauth2_client().delete(client).await?; + } + } else { + let len = to_delete.count(); + match len { + 0 => {}, + 1 => warn!("A static client in the database is not in the config. Run with `--prune` to delete it."), + n => warn!("{n} static clients in the database are not in the config. Run with `--prune` to delete them."), + } + } + + for client in config.clients.iter() { + if existing_ids.contains(&client.client_id) { + info!(client.id = %client.client_id, "Updating client"); + } else { + info!(client.id = %client.client_id, "Adding client"); + } + + if dry_run { + continue; + } + + let client_secret = client.client_secret(); + let client_auth_method = client.client_auth_method(); + let jwks = client.jwks(); + let jwks_uri = client.jwks_uri(); + + // TODO: should be moved somewhere else + let encrypted_client_secret = client_secret + .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) + .transpose()?; + + repo.oauth2_client() + .upsert_static( + &mut rng, + &clock, + client.client_id, + client_auth_method, + encrypted_client_secret, + jwks.cloned(), + jwks_uri.cloned(), + client.redirect_uris.clone(), + ) + .await?; + } + } + + if dry_run { + info!("Dry run, rolling back changes"); + repo.cancel().await?; + } else { + repo.save().await?; + } + Ok(()) +} diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index eb6e22b7..c59dca5a 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -13,22 +13,17 @@ // limitations under the License. use anyhow::Context; -use clap::{Parser, ValueEnum}; -use mas_config::{DatabaseConfig, PasswordsConfig, RootConfig}; -use mas_data_model::{Device, TokenType, UpstreamOAuthProviderClaimsImports}; -use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; -use mas_router::UrlBuilder; +use clap::Parser; +use mas_config::{DatabaseConfig, PasswordsConfig}; +use mas_data_model::{Device, TokenType}; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - oauth2::OAuth2ClientRepository, - upstream_oauth2::UpstreamOAuthProviderRepository, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Repository, RepositoryAccess, SystemClock, }; use mas_storage_pg::PgRepository; -use oauth2_types::scope::Scope; use rand::SeedableRng; -use tracing::{info, info_span, warn}; +use tracing::{info, info_span}; use crate::util::{database_from_config, password_manager_from_config}; @@ -38,153 +33,14 @@ pub(super) struct Options { subcommand: Subcommand, } -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] -enum AuthenticationMethod { - /// Client doesn't use any authentication - None, - - /// Client sends its `client_secret` in the request body - ClientSecretPost, - - /// Client sends its `client_secret` in the authorization header - ClientSecretBasic, - - /// Client uses its `client_secret` to sign a client assertion JWT - ClientSecretJwt, - - /// Client uses its private keys to sign a client assertion JWT - PrivateKeyJwt, -} - -impl AuthenticationMethod { - fn requires_client_secret(self) -> bool { - matches!( - self, - Self::ClientSecretJwt | Self::ClientSecretPost | Self::ClientSecretBasic - ) - } -} - -impl From for OAuthClientAuthenticationMethod { - fn from(val: AuthenticationMethod) -> Self { - (&val).into() - } -} - -impl From<&AuthenticationMethod> for OAuthClientAuthenticationMethod { - fn from(val: &AuthenticationMethod) -> Self { - match val { - AuthenticationMethod::None => OAuthClientAuthenticationMethod::None, - AuthenticationMethod::ClientSecretPost => { - OAuthClientAuthenticationMethod::ClientSecretPost - } - AuthenticationMethod::ClientSecretBasic => { - OAuthClientAuthenticationMethod::ClientSecretBasic - } - AuthenticationMethod::ClientSecretJwt => { - OAuthClientAuthenticationMethod::ClientSecretJwt - } - AuthenticationMethod::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt, - } - } -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] -enum SigningAlgorithm { - #[value(name = "HS256")] - HS256, - #[value(name = "HS384")] - HS384, - #[value(name = "HS512")] - HS512, - #[value(name = "RS256")] - RS256, - #[value(name = "RS384")] - RS384, - #[value(name = "RS512")] - RS512, - #[value(name = "PS256")] - PS256, - #[value(name = "PS384")] - PS384, - #[value(name = "PS512")] - PS512, - #[value(name = "ES256")] - ES256, - #[value(name = "ES384")] - ES384, - #[value(name = "ES256K")] - ES256K, -} - -impl From for JsonWebSignatureAlg { - fn from(val: SigningAlgorithm) -> Self { - (&val).into() - } -} - -impl From<&SigningAlgorithm> for JsonWebSignatureAlg { - fn from(val: &SigningAlgorithm) -> Self { - match val { - SigningAlgorithm::HS256 => Self::Hs256, - SigningAlgorithm::HS384 => Self::Hs384, - SigningAlgorithm::HS512 => Self::Hs512, - SigningAlgorithm::RS256 => Self::Rs256, - SigningAlgorithm::RS384 => Self::Rs384, - SigningAlgorithm::RS512 => Self::Rs512, - SigningAlgorithm::PS256 => Self::Ps256, - SigningAlgorithm::PS384 => Self::Ps384, - SigningAlgorithm::PS512 => Self::Ps512, - SigningAlgorithm::ES256 => Self::Es256, - SigningAlgorithm::ES384 => Self::Es384, - SigningAlgorithm::ES256K => Self::Es256K, - } - } -} - #[derive(Parser, Debug)] enum Subcommand { /// Mark email address as verified VerifyEmail { username: String, email: String }, - /// Import clients from config - ImportClients { - /// Update existing clients - #[arg(long)] - update: bool, - }, - /// Set a user password SetPassword { username: String, password: String }, - /// Add an OAuth 2.0 upstream - #[command(name = "add-oauth-upstream")] - AddOAuthUpstream { - /// Issuer URL - issuer: String, - - /// Scope to ask for when authorizing with this upstream. - /// - /// This should include at least the `openid` scope. - scope: Scope, - - /// Client authentication method used when using the token endpoint. - #[arg(value_enum)] - token_endpoint_auth_method: AuthenticationMethod, - - /// Client ID - client_id: String, - - /// JWT signing algorithm used when authenticating for the token - /// endpoint. - #[arg(long, value_enum)] - signing_alg: Option, - - /// Client Secret - #[arg(long)] - client_secret: Option, - }, - /// Issue a compatibility token IssueCompatibilityToken { /// User for which to issue the token @@ -271,128 +127,6 @@ impl Options { Ok(()) } - SC::ImportClients { update } => { - let _span = info_span!("cli.manage.import_clients").entered(); - - let config: RootConfig = root.load_config()?; - let pool = database_from_config(&config.database).await?; - let encrypter = config.secrets.encrypter(); - - let mut repo = PgRepository::from_pool(&pool).await?.boxed(); - - for client in config.clients.iter() { - let client_id = client.client_id; - - let existing = repo.oauth2_client().lookup(client_id).await?.is_some(); - if !update && existing { - warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients."); - continue; - } - - if existing { - info!(%client_id, "Updating client"); - } else { - info!(%client_id, "Importing client"); - } - - let client_secret = client.client_secret(); - let client_auth_method = client.client_auth_method(); - let jwks = client.jwks(); - let jwks_uri = client.jwks_uri(); - - // TODO: should be moved somewhere else - let encrypted_client_secret = client_secret - .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) - .transpose()?; - - repo.oauth2_client() - .add_from_config( - &mut rng, - &clock, - client_id, - client_auth_method, - encrypted_client_secret, - jwks.cloned(), - jwks_uri.cloned(), - client.redirect_uris.clone(), - ) - .await?; - } - - repo.save().await?; - - Ok(()) - } - - SC::AddOAuthUpstream { - issuer, - scope, - token_endpoint_auth_method, - client_id, - client_secret, - signing_alg, - } => { - let _span = info_span!( - "cli.manage.add_oauth_upstream", - upstream_oauth_provider.issuer = issuer, - upstream_oauth_provider.client_id = client_id, - ) - .entered(); - - let config: RootConfig = root.load_config()?; - let encrypter = config.secrets.encrypter(); - let pool = database_from_config(&config.database).await?; - let url_builder = UrlBuilder::new(config.http.public_base); - let mut repo = PgRepository::from_pool(&pool).await?.boxed(); - - let requires_client_secret = token_endpoint_auth_method.requires_client_secret(); - - let token_endpoint_auth_method: OAuthClientAuthenticationMethod = - token_endpoint_auth_method.into(); - - let token_endpoint_signing_alg: Option = - signing_alg.as_ref().map(Into::into); - - tracing::info!(%issuer, %scope, %token_endpoint_auth_method, %client_id, "Adding OAuth upstream"); - - if client_secret.is_none() && requires_client_secret { - tracing::warn!("Token endpoint auth method requires a client secret, but none were provided"); - } - - let encrypted_client_secret = client_secret - .as_deref() - .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) - .transpose()?; - - let provider = repo - .upstream_oauth_provider() - .add( - &mut rng, - &clock, - issuer, - scope, - token_endpoint_auth_method, - token_endpoint_signing_alg, - client_id, - encrypted_client_secret, - UpstreamOAuthProviderClaimsImports::default(), - ) - .await?; - - repo.save().await?; - - let redirect_uri = url_builder.upstream_oauth_callback(provider.id); - let auth_uri = url_builder.upstream_oauth_authorize(provider.id); - tracing::info!( - %provider.id, - %provider.client_id, - provider.redirect_uri = %redirect_uri, - "Test authorization by going to {auth_uri}" - ); - - Ok(()) - } - SC::IssueCompatibilityToken { username, admin, diff --git a/crates/storage-pg/migrations/20230626130338_oauth_clients_static.sql b/crates/storage-pg/migrations/20230626130338_oauth_clients_static.sql new file mode 100644 index 00000000..6df88b33 --- /dev/null +++ b/crates/storage-pg/migrations/20230626130338_oauth_clients_static.sql @@ -0,0 +1,19 @@ +-- Copyright 2023 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- This adds a flag to the OAuth 2.0 clients to indicate whether they are static (i.e. defined in config) or not. +ALTER TABLE oauth2_clients + ADD COLUMN is_static + BOOLEAN NOT NULL + DEFAULT FALSE; \ No newline at end of file diff --git a/crates/storage-pg/sqlx-data.json b/crates/storage-pg/sqlx-data.json index de4cc692..cc4ed673 100644 --- a/crates/storage-pg/sqlx-data.json +++ b/crates/storage-pg/sqlx-data.json @@ -14,6 +14,18 @@ }, "query": "\n UPDATE oauth2_authorization_grants\n SET fulfilled_at = $2\n , oauth2_session_id = $3\n WHERE oauth2_authorization_grant_id = $1\n " }, + "036e9e2cb7271782e48700fecd3fdd80f596ed433f37f2528c7edbdc88b13646": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM oauth2_consents\n WHERE oauth2_client_id = $1\n " + }, "0469c1d3ad11fd96febacad33302709c870ead848d6920cdfdb18912d543488e": { "describe": { "columns": [ @@ -165,6 +177,18 @@ }, "query": "\n SELECT user_email_confirmation_code_id\n , user_email_id\n , code\n , created_at\n , expires_at\n , consumed_at\n FROM user_email_confirmation_codes\n WHERE code = $1\n AND user_email_id = $2\n " }, + "1eb829460407fca22b717b88a1a0a9b7b920d807a4b6c235e1bee524cd73b266": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n " + }, "1f6297fb323e9f2fbfa1c9e3225c0b3037c8c4714533a6240c62275332aa58dc": { "describe": { "columns": [], @@ -190,6 +214,57 @@ }, "query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n " }, + "2a0d8d70d21afa9a2c9c1c432853361bb85911c48f7db6c3873b0f5abf35940b": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM oauth2_authorization_grants\n WHERE oauth2_client_id = $1\n " + }, + "2ee26886c56f04cd53d4c0968f5cf0963f92b6d15e6af0e69378a6447dee677c": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM oauth2_access_tokens\n WHERE oauth2_session_id IN (\n SELECT oauth2_session_id\n FROM oauth2_sessions\n WHERE oauth2_client_id = $1\n )\n " + }, + "31cbbd841029812c6d3500cae04a8e9e5723e4749d339465492b68e072c3a802": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Bool", + "Bool", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Jsonb", + "Text", + "Text", + "Text", + "Text", + "Text" + ] + } + }, + "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n , is_static\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, FALSE)\n " + }, "3d66f3121b11ce923b9c60609b510a8ca899640e78cc8f5b03168622928ffe94": { "describe": { "columns": [], @@ -706,6 +781,18 @@ }, "query": "\n INSERT INTO oauth2_sessions\n ( oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n )\n VALUES ($1, $2, $3, $4, $5)\n " }, + "5b697dd7834d33ec55972d3ba43d25fe794bc0b69c5938275711faa7a80b811f": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM oauth2_refresh_tokens\n WHERE oauth2_session_id IN (\n SELECT oauth2_session_id\n FROM oauth2_sessions\n WHERE oauth2_client_id = $1\n )\n " + }, "5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": { "describe": { "columns": [], @@ -721,6 +808,18 @@ }, "query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n " }, + "5fe1bb569d13a7d3ff22887b3fc5b76ff901c183b314f8ccb5018d70c516abf6": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM oauth2_clients\n WHERE oauth2_client_id = $1\n " + }, "6021c1b9e17b0b2e8b511888f8c6be00683ba0635a13eb7fcd403d3d4a3f90db": { "describe": { "columns": [], @@ -913,6 +1012,24 @@ }, "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n " }, + "68c4cd463e4035ba8384f11818b7be602e2fbc34a5582f31f95b0cc5fa2aeb92": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Bool", + "Bool", + "Text", + "Jsonb", + "Text" + ] + } + }, + "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n , is_static\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, TRUE)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n , is_static = TRUE\n " + }, "6a3b543ec53ce242866d1e84de26728e6dd275cae745f9c646e3824d859c5384": { "describe": { "columns": [ @@ -1213,6 +1330,18 @@ }, "query": "\n INSERT INTO oauth2_client_redirect_uris\n (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)\n SELECT id, $2, redirect_uri\n FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)\n " }, + "7cd0264707100f5b3cb2582f3f840bf66649742374e3643f1902ae69377fc9b6": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM oauth2_client_redirect_uris\n WHERE oauth2_client_id = $1\n " + }, "7ce387b1b0aaf10e72adde667b19521b66eaafa51f73bf2f95e38b8f3b64a229": { "describe": { "columns": [], @@ -1226,17 +1355,119 @@ }, "query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n " }, - "82fec0e13755e7032457d94fe8d212c62f14c80a98b61d82965f1b93f841c014": { + "7e676491b077d4bc8a9cdb4a27ebf119d98cd35ebb52b1064fdb2d9eed78d0e8": { "describe": { - "columns": [], - "nullable": [], + "columns": [ + { + "name": "oauth2_client_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "encrypted_client_secret", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "redirect_uris!", + "ordinal": 2, + "type_info": "TextArray" + }, + { + "name": "grant_type_authorization_code", + "ordinal": 3, + "type_info": "Bool" + }, + { + "name": "grant_type_refresh_token", + "ordinal": 4, + "type_info": "Bool" + }, + { + "name": "client_name", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "logo_uri", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "client_uri", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "policy_uri", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "tos_uri", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "jwks_uri", + "ordinal": 10, + "type_info": "Text" + }, + { + "name": "jwks", + "ordinal": 11, + "type_info": "Jsonb" + }, + { + "name": "id_token_signed_response_alg", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "userinfo_signed_response_alg", + "ordinal": 13, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 14, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_signing_alg", + "ordinal": 15, + "type_info": "Text" + }, + { + "name": "initiate_login_uri", + "ordinal": 16, + "type_info": "Text" + } + ], + "nullable": [ + false, + true, + null, + false, + false, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], "parameters": { - "Left": [ - "Uuid" - ] + "Left": [] } }, - "query": "\n DELETE FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_provider_id = $1\n " + "query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n WHERE is_static = TRUE\n " }, "836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c": { "describe": { @@ -1392,7 +1623,7 @@ }, "query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = ANY($1::uuid[])\n " }, - "8a32a39c43147dfd9dbd25ff04686c3cdbc52ea5689ce3454d15e8ed31756f38": { + "8acbdc892d44efb53529da1c2df65bea6b799a43cf4c9264a37d392847e6eff0": { "describe": { "columns": [], "nullable": [], @@ -1402,25 +1633,7 @@ ] } }, - "query": "\n DELETE FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n " - }, - "8a79c7c392dd930628caadec80c9b2645501475ab4feacbac59ca1bc52b16c3f": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Bool", - "Bool", - "Text", - "Jsonb", - "Text" - ] - } - }, - "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n " + "query": "\n DELETE FROM oauth2_sessions\n WHERE oauth2_client_id = $1\n " }, "8b7297c263336d70c2b647212b16f7ae39bc5cb1572e3a2dcfcd67f196a1fa39": { "describe": { @@ -1919,6 +2132,18 @@ }, "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n " }, + "b992283a9b43cbb8f86149f3f55cb47fb628dabd8fadc50e6a5772903f851e1c": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_provider_id = $1\n " + }, "bbf62633c561706a762089bbab2f76a9ba3e2ed3539ef16accb601fb609c2ec9": { "describe": { "columns": [], @@ -2515,32 +2740,5 @@ } }, "query": "\n SELECT oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n , finished_at\n FROM oauth2_sessions\n\n WHERE oauth2_session_id = $1\n " - }, - "f5edcd4c306ca8179cdf9d4aab59fbba971b54611c91345849920954dd8089b3": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Bool", - "Bool", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Jsonb", - "Text", - "Text", - "Text", - "Text", - "Text" - ] - } - }, - "query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n " } } \ No newline at end of file diff --git a/crates/storage-pg/src/oauth2/client.rs b/crates/storage-pg/src/oauth2/client.rs index cc2ed8b8..e737645b 100644 --- a/crates/storage-pg/src/oauth2/client.rs +++ b/crates/storage-pg/src/oauth2/client.rs @@ -428,9 +428,10 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { , token_endpoint_auth_method , token_endpoint_auth_signing_alg , initiate_login_uri + , is_static ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, FALSE) "#, Uuid::from(id), encrypted_client_secret, @@ -527,7 +528,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { } #[tracing::instrument( - name = "db.oauth2_client.add_from_config", + name = "db.oauth2_client.upsert_static", skip_all, fields( db.statement, @@ -535,7 +536,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { ), err, )] - async fn add_from_config( + async fn upsert_static( &mut self, rng: &mut (dyn RngCore + Send), clock: &dyn Clock, @@ -564,9 +565,10 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { , token_endpoint_auth_method , jwks , jwks_uri + , is_static ) VALUES - ($1, $2, $3, $4, $5, $6, $7) + ($1, $2, $3, $4, $5, $6, $7, TRUE) ON CONFLICT (oauth2_client_id) DO UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret @@ -575,6 +577,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method , jwks = EXCLUDED.jwks , jwks_uri = EXCLUDED.jwks_uri + , is_static = TRUE "#, Uuid::from(client_id), encrypted_client_secret, @@ -590,7 +593,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { { let span = info_span!( - "db.oauth2_client.add_from_config.redirect_uris", + "db.oauth2_client.upsert_static.redirect_uris", client.id = %client_id, db.statement = tracing::field::Empty, ); @@ -656,6 +659,52 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { }) } + #[tracing::instrument( + name = "db.oauth2_client.all_static", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn all_static(&mut self) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuth2ClientLookup, + r#" + SELECT oauth2_client_id + , encrypted_client_secret + , ARRAY( + SELECT redirect_uri + FROM oauth2_client_redirect_uris r + WHERE r.oauth2_client_id = c.oauth2_client_id + ) AS "redirect_uris!" + , grant_type_authorization_code + , grant_type_refresh_token + , client_name + , logo_uri + , client_uri + , policy_uri + , tos_uri + , jwks_uri + , jwks + , id_token_signed_response_alg + , userinfo_signed_response_alg + , token_endpoint_auth_method + , token_endpoint_auth_signing_alg + , initiate_login_uri + FROM oauth2_clients c + WHERE is_static = TRUE + "#, + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + res.into_iter() + .map(|r| r.try_into().map_err(DatabaseError::from)) + .collect() + } + #[tracing::instrument( name = "db.oauth2_client.get_consent_for_user", skip_all, @@ -698,6 +747,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { } #[tracing::instrument( + name = "db.oauth2_client.give_consent_for_user", skip_all, fields( db.statement, @@ -739,10 +789,161 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> { &tokens, now, ) - .traced() + .traced() .execute(&mut *self.conn) .await?; Ok(()) } + + #[tracing::instrument( + name = "db.oauth2_client.delete_by_id", + skip_all, + fields( + db.statement, + client.id = %id, + ), + err, + )] + async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> { + // Delete the authorization grants + { + let span = info_span!( + "db.oauth2_client.delete_by_id.authorization_grants", + db.statement = tracing::field::Empty, + ); + + sqlx::query!( + r#" + DELETE FROM oauth2_authorization_grants + WHERE oauth2_client_id = $1 + "#, + Uuid::from(id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + // Delete the user consents + { + let span = info_span!( + "db.oauth2_client.delete_by_id.consents", + db.statement = tracing::field::Empty, + ); + + sqlx::query!( + r#" + DELETE FROM oauth2_consents + WHERE oauth2_client_id = $1 + "#, + Uuid::from(id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + // Delete the OAuth 2 sessions related data + { + let span = info_span!( + "db.oauth2_client.delete_by_id.access_tokens", + db.statement = tracing::field::Empty, + ); + + sqlx::query!( + r#" + DELETE FROM oauth2_access_tokens + WHERE oauth2_session_id IN ( + SELECT oauth2_session_id + FROM oauth2_sessions + WHERE oauth2_client_id = $1 + ) + "#, + Uuid::from(id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + { + let span = info_span!( + "db.oauth2_client.delete_by_id.refresh_tokens", + db.statement = tracing::field::Empty, + ); + + sqlx::query!( + r#" + DELETE FROM oauth2_refresh_tokens + WHERE oauth2_session_id IN ( + SELECT oauth2_session_id + FROM oauth2_sessions + WHERE oauth2_client_id = $1 + ) + "#, + Uuid::from(id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + { + let span = info_span!( + "db.oauth2_client.delete_by_id.sessions", + db.statement = tracing::field::Empty, + ); + + sqlx::query!( + r#" + DELETE FROM oauth2_sessions + WHERE oauth2_client_id = $1 + "#, + Uuid::from(id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + // Delete the redirect URIs + { + let span = info_span!( + "db.oauth2_client.delete_by_id.redirect_uris", + db.statement = tracing::field::Empty, + ); + + sqlx::query!( + r#" + DELETE FROM oauth2_client_redirect_uris + WHERE oauth2_client_id = $1 + "#, + Uuid::from(id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + // Now delete the client itself + let res = sqlx::query!( + r#" + DELETE FROM oauth2_clients + WHERE oauth2_client_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1) + } } diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index fd699db5..48ecab7c 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -189,7 +189,10 @@ mod tests { assert_eq!(links.edges[0].user_id, Some(user.id)); // Try deleting the provider - repo.upstream_oauth_provider().delete(provider).await.unwrap(); + repo.upstream_oauth_provider() + .delete(provider) + .await + .unwrap(); let providers = repo.upstream_oauth_provider().all().await.unwrap(); assert!(providers.is_empty()); } diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index eb187dbd..7400943f 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -20,6 +20,7 @@ use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Clock, Page, use oauth2_types::scope::Scope; use rand::RngCore; use sqlx::{types::Json, PgConnection, QueryBuilder}; +use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; @@ -298,30 +299,47 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' err, )] async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> { - // Delete the authorization sessions first, as they have a foreign key constraint - // on the links and the providers. - sqlx::query!( - r#" - DELETE FROM upstream_oauth_authorization_sessions - WHERE upstream_oauth_provider_id = $1 - "#, - Uuid::from(id), - ) - .traced() + // Delete the authorization sessions first, as they have a foreign key + // constraint on the links and the providers. + { + let span = info_span!( + "db.oauth2_client.delete_by_id.authorization_sessions", + upstream_oauth_provider.id = %id, + db.statement = tracing::field::Empty, + ); + sqlx::query!( + r#" + DELETE FROM upstream_oauth_authorization_sessions + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .record(&span) .execute(&mut *self.conn) + .instrument(span) .await?; + } - // Delete the links next, as they have a foreign key constraint on the providers. - sqlx::query!( - r#" - DELETE FROM upstream_oauth_links - WHERE upstream_oauth_provider_id = $1 - "#, - Uuid::from(id), - ) - .traced() + // Delete the links next, as they have a foreign key constraint on the + // providers. + { + let span = info_span!( + "db.oauth2_client.delete_by_id.links", + upstream_oauth_provider.id = %id, + db.statement = tracing::field::Empty, + ); + sqlx::query!( + r#" + DELETE FROM upstream_oauth_links + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .record(&span) .execute(&mut *self.conn) + .instrument(span) .await?; + } let res = sqlx::query!( r#" diff --git a/crates/storage/src/oauth2/client.rs b/crates/storage/src/oauth2/client.rs index 18f0108b..275ad05e 100644 --- a/crates/storage/src/oauth2/client.rs +++ b/crates/storage/src/oauth2/client.rs @@ -124,7 +124,7 @@ pub trait OAuth2ClientRepository: Send + Sync { initiate_login_uri: Option, ) -> Result; - /// Add or replace a client from the configuration + /// Add or replace a static client /// /// Returns the client that was added or replaced /// @@ -143,7 +143,7 @@ pub trait OAuth2ClientRepository: Send + Sync { /// /// Returns [`Self::Error`] if the underlying repository fails #[allow(clippy::too_many_arguments)] - async fn add_from_config( + async fn upsert_static( &mut self, rng: &mut (dyn RngCore + Send), clock: &dyn Clock, @@ -155,6 +155,13 @@ pub trait OAuth2ClientRepository: Send + Sync { redirect_uris: Vec, ) -> Result; + /// List all static clients + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn all_static(&mut self) -> Result, Self::Error>; + /// Get the list of scopes that the user has given consent for the given /// client /// @@ -193,6 +200,32 @@ pub trait OAuth2ClientRepository: Send + Sync { user: &User, scope: &Scope, ) -> Result<(), Self::Error>; + + /// Delete a client + /// + /// # Parameters + /// + /// * `client`: The client to delete + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails, or if the + /// client does not exist + async fn delete(&mut self, client: Client) -> Result<(), Self::Error> { + self.delete_by_id(client.id).await + } + + /// Delete a client by ID + /// + /// # Parameters + /// + /// * `id`: The ID of the client to delete + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails, or if the + /// client does not exist + async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>; } repository_impl!(OAuth2ClientRepository: @@ -225,7 +258,7 @@ repository_impl!(OAuth2ClientRepository: initiate_login_uri: Option, ) -> Result; - async fn add_from_config( + async fn upsert_static( &mut self, rng: &mut (dyn RngCore + Send), clock: &dyn Clock, @@ -237,6 +270,12 @@ repository_impl!(OAuth2ClientRepository: redirect_uris: Vec, ) -> Result; + async fn all_static(&mut self) -> Result, Self::Error>; + + async fn delete(&mut self, client: Client) -> Result<(), Self::Error>; + + async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>; + async fn get_consent_for_user( &mut self, client: &Client,