1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

CLI tool to sync the upstream IDPs with the config

This commit is contained in:
Quentin Gliech
2023-06-26 14:21:57 +02:00
parent 4f1b201c74
commit de13d3ef19
9 changed files with 377 additions and 23 deletions

View File

@ -16,13 +16,46 @@ use std::collections::HashSet;
use clap::Parser;
use mas_config::{ConfigurationSection, RootConfig};
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository, RepositoryAccess};
use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, Repository, RepositoryAccess, SystemClock,
};
use mas_storage_pg::PgRepository;
use rand::SeedableRng;
use tracing::{info, info_span, warn};
use crate::util::database_from_config;
fn map_import_preference(
config: &mas_config::UpstreamOAuth2ImportPreference,
) -> mas_data_model::UpstreamOAuthProviderImportPreference {
let action = match &config.action {
mas_config::UpstreamOAuth2ImportAction::Ignore => {
mas_data_model::UpstreamOAuthProviderImportAction::Ignore
}
mas_config::UpstreamOAuth2ImportAction::Suggest => {
mas_data_model::UpstreamOAuthProviderImportAction::Suggest
}
mas_config::UpstreamOAuth2ImportAction::Force => {
mas_data_model::UpstreamOAuthProviderImportAction::Force
}
mas_config::UpstreamOAuth2ImportAction::Require => {
mas_data_model::UpstreamOAuthProviderImportAction::Require
}
};
mas_data_model::UpstreamOAuthProviderImportPreference { action }
}
fn map_claims_imports(
config: &mas_config::UpstreamOAuth2ClaimsImports,
) -> mas_data_model::UpstreamOAuthProviderClaimsImports {
mas_data_model::UpstreamOAuthProviderClaimsImports {
localpart: config.localpart.as_ref().map(map_import_preference).unwrap_or_default(),
displayname: config.displayname.as_ref().map(map_import_preference).unwrap_or_default(),
email: config.email.as_ref().map(map_import_preference).unwrap_or_default(),
}
}
#[derive(Parser, Debug)]
pub(super) struct Options {
#[command(subcommand)]
@ -64,12 +97,14 @@ impl Options {
serde_yaml::to_writer(std::io::stdout(), &config)?;
}
SC::Check => {
let _span = info_span!("cli.config.check").entered();
let _config: RootConfig = root.load_config()?;
info!(path = ?root.config, "Configuration file looks good");
}
SC::Generate => {
let _span = info_span!("cli.config.generate").entered();
@ -81,9 +116,13 @@ impl Options {
}
SC::Sync { prune, dry_run } => {
let _span = info_span!("cli.config.sync").entered();
let _span =
info_span!("cli.config.sync", prune = prune, dry_run = dry_run).entered();
let clock = SystemClock::default();
let config: RootConfig = root.load_config()?;
let encrypter = config.secrets.encrypter();
let pool = database_from_config(&config.database).await?;
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
@ -93,9 +132,6 @@ impl Options {
"Syncing providers and clients defined in config to database"
);
let existing = repo.upstream_oauth_provider().all().await?;
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
let config_ids = config
.upstream_oauth2
.providers
@ -103,24 +139,54 @@ impl Options {
.map(|p| p.id)
.collect::<HashSet<_>>();
let needs_pruning = existing_ids.difference(&config_ids).collect::<Vec<_>>();
let existing = repo.upstream_oauth_provider().all().await?;
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
if prune {
for id in needs_pruning {
info!(provider.id = %id, "Deleting provider");
for provider in to_delete {
info!(%provider.id, "Deleting provider");
if dry_run {
continue;
}
repo.upstream_oauth_provider().delete(provider).await?;
}
} else {
let len = to_delete.count();
match len {
0 => {},
1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."),
n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."),
}
} else if !needs_pruning.is_empty() {
warn!(
"{} provider(s) in the database are not in the config. Run with `--prune` to delete them.",
needs_pruning.len()
);
}
for provider in config.upstream_oauth2.providers {
if existing_ids.contains(&provider.id) {
info!(%provider.id, "Updating provider");
} else {
info!(%provider.id, "Adding provider");
info!(%provider.id, "Syncing provider");
if dry_run {
continue;
}
let encrypted_client_secret = provider
.client_secret()
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
.transpose()?;
let client_auth_method = provider.client_auth_method();
let client_auth_signing_alg = provider.client_auth_signing_alg();
repo.upstream_oauth_provider()
.upsert(
&clock,
provider.id,
provider.issuer,
provider.scope.parse()?,
client_auth_method,
client_auth_signing_alg,
provider.client_id,
encrypted_client_secret,
map_claims_imports(&provider.claims_imports),
)
.await?;
}
if dry_run {

View File

@ -48,7 +48,10 @@ pub use self::{
TelemetryConfig, TracingConfig, TracingExporterConfig,
},
templates::TemplatesConfig,
upstream_oauth2::UpstreamOAuth2Config,
upstream_oauth2::{
ClaimsImports as UpstreamOAuth2ClaimsImports, ImportAction as UpstreamOAuth2ImportAction,
ImportPreference as UpstreamOAuth2ImportPreference, UpstreamOAuth2Config,
},
};
use crate::util::ConfigurationSection;

View File

@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::Deref;
use async_trait::async_trait;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use rand::Rng;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -65,16 +68,17 @@ pub enum TokenAuthMethod {
/// signed using the `client_secret`
ClientSecretJwt {
client_secret: String,
token_endpoint_auth_signing_alg: Option<String>,
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
},
/// `client_secret_basic`: a `client_assertion` sent in the request body and
/// signed by an asymmetric key
PrivateKeyJwt {
token_endpoint_auth_signing_alg: Option<String>,
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
},
}
/// How to handle a claim
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ImportAction {
@ -92,6 +96,7 @@ pub enum ImportAction {
Require,
}
/// What should be done with a claim
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ImportPreference {
/// How to handle the claim
@ -99,6 +104,7 @@ pub struct ImportPreference {
pub action: ImportAction,
}
/// How claims should be imported
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ClaimsImports {
/// Import the localpart of the MXID based on the `preferred_username` claim
@ -142,3 +148,58 @@ pub struct Provider {
/// provider
pub claims_imports: ClaimsImports,
}
impl Deref for Provider {
type Target = TokenAuthMethod;
fn deref(&self) -> &Self::Target {
&self.token_auth_method
}
}
impl TokenAuthMethod {
#[doc(hidden)]
#[must_use]
pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
match self {
TokenAuthMethod::None => OAuthClientAuthenticationMethod::None,
TokenAuthMethod::ClientSecretBasic { .. } => {
OAuthClientAuthenticationMethod::ClientSecretBasic
}
TokenAuthMethod::ClientSecretPost { .. } => {
OAuthClientAuthenticationMethod::ClientSecretPost
}
TokenAuthMethod::ClientSecretJwt { .. } => {
OAuthClientAuthenticationMethod::ClientSecretJwt
}
TokenAuthMethod::PrivateKeyJwt { .. } => OAuthClientAuthenticationMethod::PrivateKeyJwt,
}
}
#[doc(hidden)]
#[must_use]
pub fn client_secret(&self) -> Option<&str> {
match self {
TokenAuthMethod::None | TokenAuthMethod::PrivateKeyJwt { .. } => None,
TokenAuthMethod::ClientSecretBasic { client_secret }
| TokenAuthMethod::ClientSecretPost { client_secret }
| TokenAuthMethod::ClientSecretJwt { client_secret, .. } => Some(client_secret),
}
}
#[doc(hidden)]
#[must_use]
pub fn client_auth_signing_alg(&self) -> Option<JsonWebSignatureAlg> {
match self {
TokenAuthMethod::ClientSecretJwt {
token_endpoint_auth_signing_alg,
..
}
| TokenAuthMethod::PrivateKeyJwt {
token_endpoint_auth_signing_alg,
..
} => token_endpoint_auth_signing_alg.clone(),
_ => None,
}
}
}

View File

@ -48,7 +48,7 @@ pub use self::{
upstream_oauth2::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState,
UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports,
UpstreamOAuthProviderImportPreference,
UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference,
},
users::{
Authentication, BrowserSession, Password, User, UserEmail, UserEmailVerification,

View File

@ -75,14 +75,17 @@ pub enum ImportAction {
}
impl ImportAction {
#[must_use]
pub fn is_forced(&self) -> bool {
matches!(self, Self::Force | Self::Require)
}
#[must_use]
pub fn ignore(&self) -> bool {
matches!(self, Self::Ignore)
}
#[must_use]
pub fn is_required(&self) -> bool {
matches!(self, Self::Require)
}

View File

@ -1447,6 +1447,18 @@
},
"query": "\n INSERT INTO user_emails (user_email_id, user_id, email, created_at)\n VALUES ($1, $2, $3, $4)\n "
},
"91a3ee5ad64a947b7807a590f6b014c6856229918b972b98946f98b75686ab6c": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n DELETE FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n "
},
"921d77c194609615a7e9a6fd806e9cc17a7927e3e5deb58f3917ceeb9ab4dede": {
"describe": {
"columns": [],
@ -2402,6 +2414,34 @@
},
"query": "\n SELECT oauth2_refresh_token_id\n , refresh_token\n , created_at\n , consumed_at\n , oauth2_access_token_id\n , oauth2_session_id\n FROM oauth2_refresh_tokens\n\n WHERE refresh_token = $1\n "
},
"e7ce95415bb6b57cd601393c6abe5febfec2a963ce6eac7b099b761594b1dfaf": {
"describe": {
"columns": [
{
"name": "created_at",
"ordinal": 0,
"type_info": "Timestamptz"
}
],
"nullable": [
false
],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Timestamptz",
"Jsonb"
]
}
},
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at,\n claims_imports\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n ON CONFLICT (upstream_oauth_provider_id) \n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports\n RETURNING created_at\n "
},
"f0ace1af3775192a555c4ebb59b81183f359771f9f77e5fad759d38d872541d1": {
"describe": {
"columns": [

View File

@ -210,6 +210,108 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
})
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.add",
skip_all,
fields(
db.statement,
upstream_oauth_provider.id = %id,
upstream_oauth_provider.issuer = %issuer,
upstream_oauth_provider.client_id = %client_id,
),
err,
)]
#[allow(clippy::too_many_arguments)]
async fn upsert(
&mut self,
clock: &dyn Clock,
id: Ulid,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
claims_imports: UpstreamOAuthProviderClaimsImports,
) -> Result<UpstreamOAuthProvider, Self::Error> {
let created_at = clock.now();
let created_at = sqlx::query_scalar!(
r#"
INSERT INTO upstream_oauth_providers (
upstream_oauth_provider_id,
issuer,
scope,
token_endpoint_auth_method,
token_endpoint_signing_alg,
client_id,
encrypted_client_secret,
created_at,
claims_imports
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (upstream_oauth_provider_id)
DO UPDATE
SET
issuer = EXCLUDED.issuer,
scope = EXCLUDED.scope,
token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
client_id = EXCLUDED.client_id,
encrypted_client_secret = EXCLUDED.encrypted_client_secret,
claims_imports = EXCLUDED.claims_imports
RETURNING created_at
"#,
Uuid::from(id),
&issuer,
scope.to_string(),
token_endpoint_auth_method.to_string(),
token_endpoint_signing_alg.as_ref().map(ToString::to_string),
&client_id,
encrypted_client_secret.as_deref(),
created_at,
Json(&claims_imports) as _,
)
.traced()
.fetch_one(&mut *self.conn)
.await?;
Ok(UpstreamOAuthProvider {
id,
issuer,
scope,
client_id,
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
created_at,
claims_imports,
})
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.delete_by_id",
skip_all,
fields(
db.statement,
upstream_oauth_provider.id = %id,
),
err,
)]
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
sqlx::query!(
r#"
DELETE FROM upstream_oauth_providers
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(())
}
#[tracing::instrument(
name = "db.upstream_oauth_provider.list_paginated",
skip_all,

View File

@ -78,6 +78,65 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync {
claims_imports: UpstreamOAuthProviderClaimsImports,
) -> Result<UpstreamOAuthProvider, Self::Error>;
/// Delete an upstream OAuth provider
///
/// # Parameters
///
/// * `provider`: The provider to delete
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
self.delete_by_id(provider.id).await
}
/// Delete an upstream OAuth provider by its ID
///
/// # Parameters
///
/// * `id`: The ID of the provider to delete
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
/// Insert or update an upstream OAuth provider
///
/// # Parameters
///
/// * `clock`: The clock used to generate timestamps
/// * `id`: The ID of the provider to update
/// * `issuer`: The OIDC issuer of the provider
/// * `scope`: The scope to request during the authorization flow
/// * `token_endpoint_auth_method`: The token endpoint authentication method
/// * `token_endpoint_auth_signing_alg`: The JWT signing algorithm to use
/// when then `client_secret_jwt` or `private_key_jwt` authentication
/// methods are used
/// * `client_id`: The client ID to use when authenticating to the upstream
/// * `encrypted_client_secret`: The encrypted client secret to use when
/// authenticating to the upstream
/// * `claims_imports`: How claims should be imported from the upstream
/// provider
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
#[allow(clippy::too_many_arguments)]
async fn upsert(
&mut self,
clock: &dyn Clock,
id: Ulid,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
claims_imports: UpstreamOAuthProviderClaimsImports,
) -> Result<UpstreamOAuthProvider, Self::Error>;
/// Get a paginated list of upstream OAuth providers
///
/// # Parameters
@ -116,6 +175,23 @@ repository_impl!(UpstreamOAuthProviderRepository:
claims_imports: UpstreamOAuthProviderClaimsImports
) -> Result<UpstreamOAuthProvider, Self::Error>;
async fn upsert(
&mut self,
clock: &dyn Clock,
id: Ulid,
issuer: String,
scope: Scope,
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
client_id: String,
encrypted_client_secret: Option<String>,
claims_imports: UpstreamOAuthProviderClaimsImports,
) -> Result<UpstreamOAuthProvider, Self::Error>;
async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
async fn list_paginated(
&mut self,
pagination: Pagination

View File

@ -290,6 +290,7 @@
]
},
"ClaimsImports": {
"description": "How claims should be imported",
"type": "object",
"properties": {
"displayname": {
@ -802,6 +803,7 @@
}
},
"ImportAction": {
"description": "How to handle a claim",
"oneOf": [
{
"description": "Ignore the claim",
@ -834,6 +836,7 @@
]
},
"ImportPreference": {
"description": "What should be done with a claim",
"type": "object",
"properties": {
"action": {
@ -1469,7 +1472,7 @@
]
},
"token_endpoint_auth_signing_alg": {
"type": "string"
"$ref": "#/definitions/JsonWebSignatureAlg"
}
}
},
@ -1487,7 +1490,7 @@
]
},
"token_endpoint_auth_signing_alg": {
"type": "string"
"$ref": "#/definitions/JsonWebSignatureAlg"
}
}
}