You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Sync the OAuth2 clients with CLI and remove redundant CLI tools
This commit is contained in:
@ -50,9 +50,21 @@ fn map_claims_imports(
|
||||
config: &mas_config::UpstreamOAuth2ClaimsImports,
|
||||
) -> mas_data_model::UpstreamOAuthProviderClaimsImports {
|
||||
mas_data_model::UpstreamOAuthProviderClaimsImports {
|
||||
localpart: config.localpart.as_ref().map(map_import_preference).unwrap_or_default(),
|
||||
displayname: config.displayname.as_ref().map(map_import_preference).unwrap_or_default(),
|
||||
email: config.email.as_ref().map(map_import_preference).unwrap_or_default(),
|
||||
localpart: config
|
||||
.localpart
|
||||
.as_ref()
|
||||
.map(map_import_preference)
|
||||
.unwrap_or_default(),
|
||||
displayname: config
|
||||
.displayname
|
||||
.as_ref()
|
||||
.map(map_import_preference)
|
||||
.unwrap_or_default(),
|
||||
email: config
|
||||
.email
|
||||
.as_ref()
|
||||
.map(map_import_preference)
|
||||
.unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -116,88 +128,167 @@ impl Options {
|
||||
}
|
||||
|
||||
SC::Sync { prune, dry_run } => {
|
||||
let _span =
|
||||
info_span!("cli.config.sync", prune = prune, dry_run = dry_run).entered();
|
||||
|
||||
let clock = SystemClock::default();
|
||||
|
||||
let config: RootConfig = root.load_config()?;
|
||||
let encrypter = config.secrets.encrypter();
|
||||
let pool = database_from_config(&config.database).await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
|
||||
tracing::info!(
|
||||
prune,
|
||||
dry_run,
|
||||
"Syncing providers and clients defined in config to database"
|
||||
);
|
||||
|
||||
let config_ids = config
|
||||
.upstream_oauth2
|
||||
.providers
|
||||
.iter()
|
||||
.map(|p| p.id)
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let existing = repo.upstream_oauth_provider().all().await?;
|
||||
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
||||
if prune {
|
||||
for provider in to_delete {
|
||||
info!(%provider.id, "Deleting provider");
|
||||
|
||||
if dry_run {
|
||||
continue;
|
||||
}
|
||||
|
||||
repo.upstream_oauth_provider().delete(provider).await?;
|
||||
}
|
||||
} else {
|
||||
let len = to_delete.count();
|
||||
match len {
|
||||
0 => {},
|
||||
1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."),
|
||||
n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."),
|
||||
}
|
||||
}
|
||||
|
||||
for provider in config.upstream_oauth2.providers {
|
||||
info!(%provider.id, "Syncing provider");
|
||||
|
||||
if dry_run {
|
||||
continue;
|
||||
}
|
||||
|
||||
let encrypted_client_secret = provider
|
||||
.client_secret()
|
||||
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
||||
.transpose()?;
|
||||
let client_auth_method = provider.client_auth_method();
|
||||
let client_auth_signing_alg = provider.client_auth_signing_alg();
|
||||
|
||||
repo.upstream_oauth_provider()
|
||||
.upsert(
|
||||
&clock,
|
||||
provider.id,
|
||||
provider.issuer,
|
||||
provider.scope.parse()?,
|
||||
client_auth_method,
|
||||
client_auth_signing_alg,
|
||||
provider.client_id,
|
||||
encrypted_client_secret,
|
||||
map_claims_imports(&provider.claims_imports),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if dry_run {
|
||||
info!("Dry run, rolling back changes");
|
||||
repo.cancel().await?;
|
||||
} else {
|
||||
repo.save().await?;
|
||||
}
|
||||
sync(root, prune, dry_run).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "cli.config.sync", skip(root), err(Debug))]
|
||||
async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Result<()> {
|
||||
// XXX: we should disallow SeedableRng::from_entropy
|
||||
let mut rng = rand_chacha::ChaChaRng::from_entropy();
|
||||
let clock = SystemClock::default();
|
||||
|
||||
let config: RootConfig = root.load_config()?;
|
||||
let encrypter = config.secrets.encrypter();
|
||||
let pool = database_from_config(&config.database).await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
|
||||
tracing::info!(
|
||||
prune,
|
||||
dry_run,
|
||||
"Syncing providers and clients defined in config to database"
|
||||
);
|
||||
|
||||
{
|
||||
let _span = info_span!("cli.config.sync.providers").entered();
|
||||
let config_ids = config
|
||||
.upstream_oauth2
|
||||
.providers
|
||||
.iter()
|
||||
.map(|p| p.id)
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let existing = repo.upstream_oauth_provider().all().await?;
|
||||
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
||||
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
||||
if prune {
|
||||
for provider in to_delete {
|
||||
info!(%provider.id, "Deleting provider");
|
||||
|
||||
if dry_run {
|
||||
continue;
|
||||
}
|
||||
|
||||
repo.upstream_oauth_provider().delete(provider).await?;
|
||||
}
|
||||
} else {
|
||||
let len = to_delete.count();
|
||||
match len {
|
||||
0 => {},
|
||||
1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."),
|
||||
n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."),
|
||||
}
|
||||
}
|
||||
|
||||
for provider in config.upstream_oauth2.providers {
|
||||
if existing_ids.contains(&provider.id) {
|
||||
info!(%provider.id, "Updating provider");
|
||||
} else {
|
||||
info!(%provider.id, "Adding provider");
|
||||
}
|
||||
|
||||
if dry_run {
|
||||
continue;
|
||||
}
|
||||
|
||||
let encrypted_client_secret = provider
|
||||
.client_secret()
|
||||
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
||||
.transpose()?;
|
||||
let client_auth_method = provider.client_auth_method();
|
||||
let client_auth_signing_alg = provider.client_auth_signing_alg();
|
||||
|
||||
repo.upstream_oauth_provider()
|
||||
.upsert(
|
||||
&clock,
|
||||
provider.id,
|
||||
provider.issuer,
|
||||
provider.scope.parse()?,
|
||||
client_auth_method,
|
||||
client_auth_signing_alg,
|
||||
provider.client_id,
|
||||
encrypted_client_secret,
|
||||
map_claims_imports(&provider.claims_imports),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let _span = info_span!("cli.config.sync.clients").entered();
|
||||
let config_ids = config
|
||||
.clients
|
||||
.iter()
|
||||
.map(|c| c.client_id)
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let existing = repo.oauth2_client().all_static().await?;
|
||||
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
||||
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
|
||||
if prune {
|
||||
for client in to_delete {
|
||||
info!(client.id = %client.client_id, "Deleting client");
|
||||
|
||||
if dry_run {
|
||||
continue;
|
||||
}
|
||||
|
||||
repo.oauth2_client().delete(client).await?;
|
||||
}
|
||||
} else {
|
||||
let len = to_delete.count();
|
||||
match len {
|
||||
0 => {},
|
||||
1 => warn!("A static client in the database is not in the config. Run with `--prune` to delete it."),
|
||||
n => warn!("{n} static clients in the database are not in the config. Run with `--prune` to delete them."),
|
||||
}
|
||||
}
|
||||
|
||||
for client in config.clients.iter() {
|
||||
if existing_ids.contains(&client.client_id) {
|
||||
info!(client.id = %client.client_id, "Updating client");
|
||||
} else {
|
||||
info!(client.id = %client.client_id, "Adding client");
|
||||
}
|
||||
|
||||
if dry_run {
|
||||
continue;
|
||||
}
|
||||
|
||||
let client_secret = client.client_secret();
|
||||
let client_auth_method = client.client_auth_method();
|
||||
let jwks = client.jwks();
|
||||
let jwks_uri = client.jwks_uri();
|
||||
|
||||
// TODO: should be moved somewhere else
|
||||
let encrypted_client_secret = client_secret
|
||||
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
||||
.transpose()?;
|
||||
|
||||
repo.oauth2_client()
|
||||
.upsert_static(
|
||||
&mut rng,
|
||||
&clock,
|
||||
client.client_id,
|
||||
client_auth_method,
|
||||
encrypted_client_secret,
|
||||
jwks.cloned(),
|
||||
jwks_uri.cloned(),
|
||||
client.redirect_uris.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
if dry_run {
|
||||
info!("Dry run, rolling back changes");
|
||||
repo.cancel().await?;
|
||||
} else {
|
||||
repo.save().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -13,22 +13,17 @@
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use mas_config::{DatabaseConfig, PasswordsConfig, RootConfig};
|
||||
use mas_data_model::{Device, TokenType, UpstreamOAuthProviderClaimsImports};
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use mas_router::UrlBuilder;
|
||||
use clap::Parser;
|
||||
use mas_config::{DatabaseConfig, PasswordsConfig};
|
||||
use mas_data_model::{Device, TokenType};
|
||||
use mas_storage::{
|
||||
compat::{CompatAccessTokenRepository, CompatSessionRepository},
|
||||
oauth2::OAuth2ClientRepository,
|
||||
upstream_oauth2::UpstreamOAuthProviderRepository,
|
||||
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
|
||||
Repository, RepositoryAccess, SystemClock,
|
||||
};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use oauth2_types::scope::Scope;
|
||||
use rand::SeedableRng;
|
||||
use tracing::{info, info_span, warn};
|
||||
use tracing::{info, info_span};
|
||||
|
||||
use crate::util::{database_from_config, password_manager_from_config};
|
||||
|
||||
@ -38,153 +33,14 @@ pub(super) struct Options {
|
||||
subcommand: Subcommand,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
|
||||
enum AuthenticationMethod {
|
||||
/// Client doesn't use any authentication
|
||||
None,
|
||||
|
||||
/// Client sends its `client_secret` in the request body
|
||||
ClientSecretPost,
|
||||
|
||||
/// Client sends its `client_secret` in the authorization header
|
||||
ClientSecretBasic,
|
||||
|
||||
/// Client uses its `client_secret` to sign a client assertion JWT
|
||||
ClientSecretJwt,
|
||||
|
||||
/// Client uses its private keys to sign a client assertion JWT
|
||||
PrivateKeyJwt,
|
||||
}
|
||||
|
||||
impl AuthenticationMethod {
|
||||
fn requires_client_secret(self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::ClientSecretJwt | Self::ClientSecretPost | Self::ClientSecretBasic
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AuthenticationMethod> for OAuthClientAuthenticationMethod {
|
||||
fn from(val: AuthenticationMethod) -> Self {
|
||||
(&val).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&AuthenticationMethod> for OAuthClientAuthenticationMethod {
|
||||
fn from(val: &AuthenticationMethod) -> Self {
|
||||
match val {
|
||||
AuthenticationMethod::None => OAuthClientAuthenticationMethod::None,
|
||||
AuthenticationMethod::ClientSecretPost => {
|
||||
OAuthClientAuthenticationMethod::ClientSecretPost
|
||||
}
|
||||
AuthenticationMethod::ClientSecretBasic => {
|
||||
OAuthClientAuthenticationMethod::ClientSecretBasic
|
||||
}
|
||||
AuthenticationMethod::ClientSecretJwt => {
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt
|
||||
}
|
||||
AuthenticationMethod::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
|
||||
enum SigningAlgorithm {
|
||||
#[value(name = "HS256")]
|
||||
HS256,
|
||||
#[value(name = "HS384")]
|
||||
HS384,
|
||||
#[value(name = "HS512")]
|
||||
HS512,
|
||||
#[value(name = "RS256")]
|
||||
RS256,
|
||||
#[value(name = "RS384")]
|
||||
RS384,
|
||||
#[value(name = "RS512")]
|
||||
RS512,
|
||||
#[value(name = "PS256")]
|
||||
PS256,
|
||||
#[value(name = "PS384")]
|
||||
PS384,
|
||||
#[value(name = "PS512")]
|
||||
PS512,
|
||||
#[value(name = "ES256")]
|
||||
ES256,
|
||||
#[value(name = "ES384")]
|
||||
ES384,
|
||||
#[value(name = "ES256K")]
|
||||
ES256K,
|
||||
}
|
||||
|
||||
impl From<SigningAlgorithm> for JsonWebSignatureAlg {
|
||||
fn from(val: SigningAlgorithm) -> Self {
|
||||
(&val).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&SigningAlgorithm> for JsonWebSignatureAlg {
|
||||
fn from(val: &SigningAlgorithm) -> Self {
|
||||
match val {
|
||||
SigningAlgorithm::HS256 => Self::Hs256,
|
||||
SigningAlgorithm::HS384 => Self::Hs384,
|
||||
SigningAlgorithm::HS512 => Self::Hs512,
|
||||
SigningAlgorithm::RS256 => Self::Rs256,
|
||||
SigningAlgorithm::RS384 => Self::Rs384,
|
||||
SigningAlgorithm::RS512 => Self::Rs512,
|
||||
SigningAlgorithm::PS256 => Self::Ps256,
|
||||
SigningAlgorithm::PS384 => Self::Ps384,
|
||||
SigningAlgorithm::PS512 => Self::Ps512,
|
||||
SigningAlgorithm::ES256 => Self::Es256,
|
||||
SigningAlgorithm::ES384 => Self::Es384,
|
||||
SigningAlgorithm::ES256K => Self::Es256K,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
enum Subcommand {
|
||||
/// Mark email address as verified
|
||||
VerifyEmail { username: String, email: String },
|
||||
|
||||
/// Import clients from config
|
||||
ImportClients {
|
||||
/// Update existing clients
|
||||
#[arg(long)]
|
||||
update: bool,
|
||||
},
|
||||
|
||||
/// Set a user password
|
||||
SetPassword { username: String, password: String },
|
||||
|
||||
/// Add an OAuth 2.0 upstream
|
||||
#[command(name = "add-oauth-upstream")]
|
||||
AddOAuthUpstream {
|
||||
/// Issuer URL
|
||||
issuer: String,
|
||||
|
||||
/// Scope to ask for when authorizing with this upstream.
|
||||
///
|
||||
/// This should include at least the `openid` scope.
|
||||
scope: Scope,
|
||||
|
||||
/// Client authentication method used when using the token endpoint.
|
||||
#[arg(value_enum)]
|
||||
token_endpoint_auth_method: AuthenticationMethod,
|
||||
|
||||
/// Client ID
|
||||
client_id: String,
|
||||
|
||||
/// JWT signing algorithm used when authenticating for the token
|
||||
/// endpoint.
|
||||
#[arg(long, value_enum)]
|
||||
signing_alg: Option<SigningAlgorithm>,
|
||||
|
||||
/// Client Secret
|
||||
#[arg(long)]
|
||||
client_secret: Option<String>,
|
||||
},
|
||||
|
||||
/// Issue a compatibility token
|
||||
IssueCompatibilityToken {
|
||||
/// User for which to issue the token
|
||||
@ -271,128 +127,6 @@ impl Options {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SC::ImportClients { update } => {
|
||||
let _span = info_span!("cli.manage.import_clients").entered();
|
||||
|
||||
let config: RootConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config.database).await?;
|
||||
let encrypter = config.secrets.encrypter();
|
||||
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
|
||||
for client in config.clients.iter() {
|
||||
let client_id = client.client_id;
|
||||
|
||||
let existing = repo.oauth2_client().lookup(client_id).await?.is_some();
|
||||
if !update && existing {
|
||||
warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients.");
|
||||
continue;
|
||||
}
|
||||
|
||||
if existing {
|
||||
info!(%client_id, "Updating client");
|
||||
} else {
|
||||
info!(%client_id, "Importing client");
|
||||
}
|
||||
|
||||
let client_secret = client.client_secret();
|
||||
let client_auth_method = client.client_auth_method();
|
||||
let jwks = client.jwks();
|
||||
let jwks_uri = client.jwks_uri();
|
||||
|
||||
// TODO: should be moved somewhere else
|
||||
let encrypted_client_secret = client_secret
|
||||
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
||||
.transpose()?;
|
||||
|
||||
repo.oauth2_client()
|
||||
.add_from_config(
|
||||
&mut rng,
|
||||
&clock,
|
||||
client_id,
|
||||
client_auth_method,
|
||||
encrypted_client_secret,
|
||||
jwks.cloned(),
|
||||
jwks_uri.cloned(),
|
||||
client.redirect_uris.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SC::AddOAuthUpstream {
|
||||
issuer,
|
||||
scope,
|
||||
token_endpoint_auth_method,
|
||||
client_id,
|
||||
client_secret,
|
||||
signing_alg,
|
||||
} => {
|
||||
let _span = info_span!(
|
||||
"cli.manage.add_oauth_upstream",
|
||||
upstream_oauth_provider.issuer = issuer,
|
||||
upstream_oauth_provider.client_id = client_id,
|
||||
)
|
||||
.entered();
|
||||
|
||||
let config: RootConfig = root.load_config()?;
|
||||
let encrypter = config.secrets.encrypter();
|
||||
let pool = database_from_config(&config.database).await?;
|
||||
let url_builder = UrlBuilder::new(config.http.public_base);
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
|
||||
let requires_client_secret = token_endpoint_auth_method.requires_client_secret();
|
||||
|
||||
let token_endpoint_auth_method: OAuthClientAuthenticationMethod =
|
||||
token_endpoint_auth_method.into();
|
||||
|
||||
let token_endpoint_signing_alg: Option<JsonWebSignatureAlg> =
|
||||
signing_alg.as_ref().map(Into::into);
|
||||
|
||||
tracing::info!(%issuer, %scope, %token_endpoint_auth_method, %client_id, "Adding OAuth upstream");
|
||||
|
||||
if client_secret.is_none() && requires_client_secret {
|
||||
tracing::warn!("Token endpoint auth method requires a client secret, but none were provided");
|
||||
}
|
||||
|
||||
let encrypted_client_secret = client_secret
|
||||
.as_deref()
|
||||
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
||||
.transpose()?;
|
||||
|
||||
let provider = repo
|
||||
.upstream_oauth_provider()
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
issuer,
|
||||
scope,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_signing_alg,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
UpstreamOAuthProviderClaimsImports::default(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
||||
let auth_uri = url_builder.upstream_oauth_authorize(provider.id);
|
||||
tracing::info!(
|
||||
%provider.id,
|
||||
%provider.client_id,
|
||||
provider.redirect_uri = %redirect_uri,
|
||||
"Test authorization by going to {auth_uri}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SC::IssueCompatibilityToken {
|
||||
username,
|
||||
admin,
|
||||
|
Reference in New Issue
Block a user