diff --git a/crates/cli/src/commands/config.rs b/crates/cli/src/commands/config.rs index 8d622cdf..c0667330 100644 --- a/crates/cli/src/commands/config.rs +++ b/crates/cli/src/commands/config.rs @@ -16,13 +16,46 @@ use std::collections::HashSet; use clap::Parser; use mas_config::{ConfigurationSection, RootConfig}; -use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository, RepositoryAccess}; +use mas_storage::{ + upstream_oauth2::UpstreamOAuthProviderRepository, Repository, RepositoryAccess, SystemClock, +}; use mas_storage_pg::PgRepository; use rand::SeedableRng; use tracing::{info, info_span, warn}; use crate::util::database_from_config; +fn map_import_preference( + config: &mas_config::UpstreamOAuth2ImportPreference, +) -> mas_data_model::UpstreamOAuthProviderImportPreference { + let action = match &config.action { + mas_config::UpstreamOAuth2ImportAction::Ignore => { + mas_data_model::UpstreamOAuthProviderImportAction::Ignore + } + mas_config::UpstreamOAuth2ImportAction::Suggest => { + mas_data_model::UpstreamOAuthProviderImportAction::Suggest + } + mas_config::UpstreamOAuth2ImportAction::Force => { + mas_data_model::UpstreamOAuthProviderImportAction::Force + } + mas_config::UpstreamOAuth2ImportAction::Require => { + mas_data_model::UpstreamOAuthProviderImportAction::Require + } + }; + + mas_data_model::UpstreamOAuthProviderImportPreference { action } +} + +fn map_claims_imports( + config: &mas_config::UpstreamOAuth2ClaimsImports, +) -> mas_data_model::UpstreamOAuthProviderClaimsImports { + mas_data_model::UpstreamOAuthProviderClaimsImports { + localpart: config.localpart.as_ref().map(map_import_preference).unwrap_or_default(), + displayname: config.displayname.as_ref().map(map_import_preference).unwrap_or_default(), + email: config.email.as_ref().map(map_import_preference).unwrap_or_default(), + } +} + #[derive(Parser, Debug)] pub(super) struct Options { #[command(subcommand)] @@ -64,12 +97,14 @@ impl Options { serde_yaml::to_writer(std::io::stdout(), &config)?; } + SC::Check => { let _span = info_span!("cli.config.check").entered(); let _config: RootConfig = root.load_config()?; info!(path = ?root.config, "Configuration file looks good"); } + SC::Generate => { let _span = info_span!("cli.config.generate").entered(); @@ -81,9 +116,13 @@ impl Options { } SC::Sync { prune, dry_run } => { - let _span = info_span!("cli.config.sync").entered(); + let _span = + info_span!("cli.config.sync", prune = prune, dry_run = dry_run).entered(); + + let clock = SystemClock::default(); let config: RootConfig = root.load_config()?; + let encrypter = config.secrets.encrypter(); let pool = database_from_config(&config.database).await?; let mut repo = PgRepository::from_pool(&pool).await?.boxed(); @@ -93,9 +132,6 @@ impl Options { "Syncing providers and clients defined in config to database" ); - let existing = repo.upstream_oauth_provider().all().await?; - - let existing_ids = existing.iter().map(|p| p.id).collect::>(); let config_ids = config .upstream_oauth2 .providers @@ -103,24 +139,54 @@ impl Options { .map(|p| p.id) .collect::>(); - let needs_pruning = existing_ids.difference(&config_ids).collect::>(); + let existing = repo.upstream_oauth_provider().all().await?; + let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id)); if prune { - for id in needs_pruning { - info!(provider.id = %id, "Deleting provider"); + for provider in to_delete { + info!(%provider.id, "Deleting provider"); + + if dry_run { + continue; + } + + repo.upstream_oauth_provider().delete(provider).await?; + } + } else { + let len = to_delete.count(); + match len { + 0 => {}, + 1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."), + n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."), } - } else if !needs_pruning.is_empty() { - warn!( - "{} provider(s) in the database are not in the config. Run with `--prune` to delete them.", - needs_pruning.len() - ); } for provider in config.upstream_oauth2.providers { - if existing_ids.contains(&provider.id) { - info!(%provider.id, "Updating provider"); - } else { - info!(%provider.id, "Adding provider"); + info!(%provider.id, "Syncing provider"); + + if dry_run { + continue; } + + let encrypted_client_secret = provider + .client_secret() + .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) + .transpose()?; + let client_auth_method = provider.client_auth_method(); + let client_auth_signing_alg = provider.client_auth_signing_alg(); + + repo.upstream_oauth_provider() + .upsert( + &clock, + provider.id, + provider.issuer, + provider.scope.parse()?, + client_auth_method, + client_auth_signing_alg, + provider.client_id, + encrypted_client_secret, + map_claims_imports(&provider.claims_imports), + ) + .await?; } if dry_run { diff --git a/crates/config/src/sections/mod.rs b/crates/config/src/sections/mod.rs index 91a033b5..49a2dc68 100644 --- a/crates/config/src/sections/mod.rs +++ b/crates/config/src/sections/mod.rs @@ -48,7 +48,10 @@ pub use self::{ TelemetryConfig, TracingConfig, TracingExporterConfig, }, templates::TemplatesConfig, - upstream_oauth2::UpstreamOAuth2Config, + upstream_oauth2::{ + ClaimsImports as UpstreamOAuth2ClaimsImports, ImportAction as UpstreamOAuth2ImportAction, + ImportPreference as UpstreamOAuth2ImportPreference, UpstreamOAuth2Config, + }, }; use crate::util::ConfigurationSection; diff --git a/crates/config/src/sections/upstream_oauth2.rs b/crates/config/src/sections/upstream_oauth2.rs index 5e7d954e..a50c5034 100644 --- a/crates/config/src/sections/upstream_oauth2.rs +++ b/crates/config/src/sections/upstream_oauth2.rs @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::ops::Deref; + use async_trait::async_trait; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use rand::Rng; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -65,16 +68,17 @@ pub enum TokenAuthMethod { /// signed using the `client_secret` ClientSecretJwt { client_secret: String, - token_endpoint_auth_signing_alg: Option, + token_endpoint_auth_signing_alg: Option, }, /// `client_secret_basic`: a `client_assertion` sent in the request body and /// signed by an asymmetric key PrivateKeyJwt { - token_endpoint_auth_signing_alg: Option, + token_endpoint_auth_signing_alg: Option, }, } +/// How to handle a claim #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)] #[serde(rename_all = "lowercase")] pub enum ImportAction { @@ -92,6 +96,7 @@ pub enum ImportAction { Require, } +/// What should be done with a claim #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)] pub struct ImportPreference { /// How to handle the claim @@ -99,6 +104,7 @@ pub struct ImportPreference { pub action: ImportAction, } +/// How claims should be imported #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)] pub struct ClaimsImports { /// Import the localpart of the MXID based on the `preferred_username` claim @@ -142,3 +148,58 @@ pub struct Provider { /// provider pub claims_imports: ClaimsImports, } + +impl Deref for Provider { + type Target = TokenAuthMethod; + + fn deref(&self) -> &Self::Target { + &self.token_auth_method + } +} + +impl TokenAuthMethod { + #[doc(hidden)] + #[must_use] + pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod { + match self { + TokenAuthMethod::None => OAuthClientAuthenticationMethod::None, + TokenAuthMethod::ClientSecretBasic { .. } => { + OAuthClientAuthenticationMethod::ClientSecretBasic + } + TokenAuthMethod::ClientSecretPost { .. } => { + OAuthClientAuthenticationMethod::ClientSecretPost + } + TokenAuthMethod::ClientSecretJwt { .. } => { + OAuthClientAuthenticationMethod::ClientSecretJwt + } + TokenAuthMethod::PrivateKeyJwt { .. } => OAuthClientAuthenticationMethod::PrivateKeyJwt, + } + } + + #[doc(hidden)] + #[must_use] + pub fn client_secret(&self) -> Option<&str> { + match self { + TokenAuthMethod::None | TokenAuthMethod::PrivateKeyJwt { .. } => None, + TokenAuthMethod::ClientSecretBasic { client_secret } + | TokenAuthMethod::ClientSecretPost { client_secret } + | TokenAuthMethod::ClientSecretJwt { client_secret, .. } => Some(client_secret), + } + } + + #[doc(hidden)] + #[must_use] + pub fn client_auth_signing_alg(&self) -> Option { + match self { + TokenAuthMethod::ClientSecretJwt { + token_endpoint_auth_signing_alg, + .. + } + | TokenAuthMethod::PrivateKeyJwt { + token_endpoint_auth_signing_alg, + .. + } => token_endpoint_auth_signing_alg.clone(), + _ => None, + } + } +} diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 02b5b296..2e69fb9e 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -48,7 +48,7 @@ pub use self::{ upstream_oauth2::{ UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, - UpstreamOAuthProviderImportPreference, + UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference, }, users::{ Authentication, BrowserSession, Password, User, UserEmail, UserEmailVerification, diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index c5259008..97cc1f81 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -75,14 +75,17 @@ pub enum ImportAction { } impl ImportAction { + #[must_use] pub fn is_forced(&self) -> bool { matches!(self, Self::Force | Self::Require) } + #[must_use] pub fn ignore(&self) -> bool { matches!(self, Self::Ignore) } + #[must_use] pub fn is_required(&self) -> bool { matches!(self, Self::Require) } diff --git a/crates/storage-pg/sqlx-data.json b/crates/storage-pg/sqlx-data.json index 82a82ac1..01ee4539 100644 --- a/crates/storage-pg/sqlx-data.json +++ b/crates/storage-pg/sqlx-data.json @@ -1447,6 +1447,18 @@ }, "query": "\n INSERT INTO user_emails (user_email_id, user_id, email, created_at)\n VALUES ($1, $2, $3, $4)\n " }, + "91a3ee5ad64a947b7807a590f6b014c6856229918b972b98946f98b75686ab6c": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n DELETE FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n " + }, "921d77c194609615a7e9a6fd806e9cc17a7927e3e5deb58f3917ceeb9ab4dede": { "describe": { "columns": [], @@ -2402,6 +2414,34 @@ }, "query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n " }, + "e7ce95415bb6b57cd601393c6abe5febfec2a963ce6eac7b099b761594b1dfaf": { + "describe": { + "columns": [ + { + "name": "created_at", + "ordinal": 0, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false + ], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Timestamptz", + "Jsonb" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at,\n claims_imports\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports\n RETURNING created_at\n " + }, "f0ace1af3775192a555c4ebb59b81183f359771f9f77e5fad759d38d872541d1": { "describe": { "columns": [ diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index 60dc60a3..e3051ec1 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -210,6 +210,108 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' }) } + #[tracing::instrument( + name = "db.upstream_oauth_provider.add", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id = %id, + upstream_oauth_provider.issuer = %issuer, + upstream_oauth_provider.client_id = %client_id, + ), + err, + )] + #[allow(clippy::too_many_arguments)] + async fn upsert( + &mut self, + clock: &dyn Clock, + id: Ulid, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, + claims_imports: UpstreamOAuthProviderClaimsImports, + ) -> Result { + let created_at = clock.now(); + + let created_at = sqlx::query_scalar!( + r#" + INSERT INTO upstream_oauth_providers ( + upstream_oauth_provider_id, + issuer, + scope, + token_endpoint_auth_method, + token_endpoint_signing_alg, + client_id, + encrypted_client_secret, + created_at, + claims_imports + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (upstream_oauth_provider_id) + DO UPDATE + SET + issuer = EXCLUDED.issuer, + scope = EXCLUDED.scope, + token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method, + token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg, + client_id = EXCLUDED.client_id, + encrypted_client_secret = EXCLUDED.encrypted_client_secret, + claims_imports = EXCLUDED.claims_imports + RETURNING created_at + "#, + Uuid::from(id), + &issuer, + scope.to_string(), + token_endpoint_auth_method.to_string(), + token_endpoint_signing_alg.as_ref().map(ToString::to_string), + &client_id, + encrypted_client_secret.as_deref(), + created_at, + Json(&claims_imports) as _, + ) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + Ok(UpstreamOAuthProvider { + id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at, + claims_imports, + }) + } + + #[tracing::instrument( + name = "db.upstream_oauth_provider.delete_by_id", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), + err, + )] + async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> { + sqlx::query!( + r#" + DELETE FROM upstream_oauth_providers + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } + #[tracing::instrument( name = "db.upstream_oauth_provider.list_paginated", skip_all, diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index f1d05e6c..9624a40a 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -78,6 +78,65 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { claims_imports: UpstreamOAuthProviderClaimsImports, ) -> Result; + /// Delete an upstream OAuth provider + /// + /// # Parameters + /// + /// * `provider`: The provider to delete + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> { + self.delete_by_id(provider.id).await + } + + /// Delete an upstream OAuth provider by its ID + /// + /// # Parameters + /// + /// * `id`: The ID of the provider to delete + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>; + + /// Insert or update an upstream OAuth provider + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `id`: The ID of the provider to update + /// * `issuer`: The OIDC issuer of the provider + /// * `scope`: The scope to request during the authorization flow + /// * `token_endpoint_auth_method`: The token endpoint authentication method + /// * `token_endpoint_auth_signing_alg`: The JWT signing algorithm to use + /// when then `client_secret_jwt` or `private_key_jwt` authentication + /// methods are used + /// * `client_id`: The client ID to use when authenticating to the upstream + /// * `encrypted_client_secret`: The encrypted client secret to use when + /// authenticating to the upstream + /// * `claims_imports`: How claims should be imported from the upstream + /// provider + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + #[allow(clippy::too_many_arguments)] + async fn upsert( + &mut self, + clock: &dyn Clock, + id: Ulid, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, + claims_imports: UpstreamOAuthProviderClaimsImports, + ) -> Result; + /// Get a paginated list of upstream OAuth providers /// /// # Parameters @@ -116,6 +175,23 @@ repository_impl!(UpstreamOAuthProviderRepository: claims_imports: UpstreamOAuthProviderClaimsImports ) -> Result; + async fn upsert( + &mut self, + clock: &dyn Clock, + id: Ulid, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, + claims_imports: UpstreamOAuthProviderClaimsImports, + ) -> Result; + + async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>; + + async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>; + async fn list_paginated( &mut self, pagination: Pagination diff --git a/docs/config.schema.json b/docs/config.schema.json index 443f12dc..e98681ae 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -290,6 +290,7 @@ ] }, "ClaimsImports": { + "description": "How claims should be imported", "type": "object", "properties": { "displayname": { @@ -802,6 +803,7 @@ } }, "ImportAction": { + "description": "How to handle a claim", "oneOf": [ { "description": "Ignore the claim", @@ -834,6 +836,7 @@ ] }, "ImportPreference": { + "description": "What should be done with a claim", "type": "object", "properties": { "action": { @@ -1469,7 +1472,7 @@ ] }, "token_endpoint_auth_signing_alg": { - "type": "string" + "$ref": "#/definitions/JsonWebSignatureAlg" } } }, @@ -1487,7 +1490,7 @@ ] }, "token_endpoint_auth_signing_alg": { - "type": "string" + "$ref": "#/definitions/JsonWebSignatureAlg" } } }