You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
Define upstream OAuth providers in the config
And adds CLI tool to sync them with the database (WIP)
This commit is contained in:
@@ -12,10 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use clap::Parser;
|
||||
use mas_config::{ConfigurationSection, RootConfig};
|
||||
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Repository, RepositoryAccess};
|
||||
use mas_storage_pg::PgRepository;
|
||||
use rand::SeedableRng;
|
||||
use tracing::{info, info_span};
|
||||
use tracing::{info, info_span, warn};
|
||||
|
||||
use crate::util::database_from_config;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
pub(super) struct Options {
|
||||
@@ -33,27 +39,36 @@ enum Subcommand {
|
||||
|
||||
/// Generate a new config file
|
||||
Generate,
|
||||
|
||||
/// Sync the clients and providers from the config file to the database
|
||||
Sync {
|
||||
/// Prune elements that are in the database but not in the config file
|
||||
/// anymore
|
||||
#[clap(long)]
|
||||
prune: bool,
|
||||
|
||||
/// Do not actually write to the database
|
||||
#[clap(long)]
|
||||
dry_run: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl Options {
|
||||
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
||||
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
|
||||
use Subcommand as SC;
|
||||
match &self.subcommand {
|
||||
match self.subcommand {
|
||||
SC::Dump => {
|
||||
let _span = info_span!("cli.config.dump").entered();
|
||||
|
||||
let config: RootConfig = root.load_config()?;
|
||||
|
||||
serde_yaml::to_writer(std::io::stdout(), &config)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
SC::Check => {
|
||||
let _span = info_span!("cli.config.check").entered();
|
||||
|
||||
let _config: RootConfig = root.load_config()?;
|
||||
info!(path = ?root.config, "Configuration file looks good");
|
||||
Ok(())
|
||||
}
|
||||
SC::Generate => {
|
||||
let _span = info_span!("cli.config.generate").entered();
|
||||
@@ -63,9 +78,60 @@ impl Options {
|
||||
let config = RootConfig::load_and_generate(rng).await?;
|
||||
|
||||
serde_yaml::to_writer(std::io::stdout(), &config)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
SC::Sync { prune, dry_run } => {
|
||||
let _span = info_span!("cli.config.sync").entered();
|
||||
|
||||
let config: RootConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config.database).await?;
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
|
||||
tracing::info!(
|
||||
prune,
|
||||
dry_run,
|
||||
"Syncing providers and clients defined in config to database"
|
||||
);
|
||||
|
||||
let existing = repo.upstream_oauth_provider().all().await?;
|
||||
|
||||
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
|
||||
let config_ids = config
|
||||
.upstream_oauth2
|
||||
.providers
|
||||
.iter()
|
||||
.map(|p| p.id)
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let needs_pruning = existing_ids.difference(&config_ids).collect::<Vec<_>>();
|
||||
if prune {
|
||||
for id in needs_pruning {
|
||||
info!(provider.id = %id, "Deleting provider");
|
||||
}
|
||||
} else if !needs_pruning.is_empty() {
|
||||
warn!(
|
||||
"{} provider(s) in the database are not in the config. Run with `--prune` to delete them.",
|
||||
needs_pruning.len()
|
||||
);
|
||||
}
|
||||
|
||||
for provider in config.upstream_oauth2.providers {
|
||||
if existing_ids.contains(&provider.id) {
|
||||
info!(%provider.id, "Updating provider");
|
||||
} else {
|
||||
info!(%provider.id, "Adding provider");
|
||||
}
|
||||
}
|
||||
|
||||
if dry_run {
|
||||
info!("Dry run, rolling back changes");
|
||||
repo.cancel().await?;
|
||||
} else {
|
||||
repo.save().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@@ -33,7 +33,7 @@ enum Subcommand {
|
||||
}
|
||||
|
||||
impl Options {
|
||||
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
||||
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
|
||||
let _span = info_span!("cli.database.migrate").entered();
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = database_from_config(&config).await?;
|
||||
|
@@ -65,10 +65,10 @@ fn print_headers(parts: &hyper::http::response::Parts) {
|
||||
|
||||
impl Options {
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
||||
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
|
||||
use Subcommand as SC;
|
||||
let http_client_factory = HttpClientFactory::new(10);
|
||||
match &self.subcommand {
|
||||
match self.subcommand {
|
||||
SC::Http {
|
||||
show_headers,
|
||||
json: false,
|
||||
@@ -83,15 +83,13 @@ impl Options {
|
||||
let response = client.ready().await?.call(request).await?;
|
||||
let (parts, body) = response.into_parts();
|
||||
|
||||
if *show_headers {
|
||||
if show_headers {
|
||||
print_headers(&parts);
|
||||
}
|
||||
|
||||
let mut body = hyper::body::aggregate(body).await?;
|
||||
let mut stdout = tokio::io::stdout();
|
||||
stdout.write_all_buf(&mut body).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SC::Http {
|
||||
@@ -113,14 +111,12 @@ impl Options {
|
||||
client.ready().await?.call(request).await?;
|
||||
let (parts, body) = response.into_parts();
|
||||
|
||||
if *show_headers {
|
||||
if show_headers {
|
||||
print_headers(&parts);
|
||||
}
|
||||
|
||||
let body = serde_json::to_string_pretty(&body)?;
|
||||
println!("{body}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SC::Policy => {
|
||||
@@ -130,8 +126,9 @@ impl Options {
|
||||
let policy_factory = policy_factory_from_config(&config).await?;
|
||||
|
||||
let _instance = policy_factory.instantiate().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@@ -202,13 +202,13 @@ enum Subcommand {
|
||||
|
||||
impl Options {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
||||
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
|
||||
use Subcommand as SC;
|
||||
let clock = SystemClock::default();
|
||||
// XXX: we should disallow SeedableRng::from_entropy
|
||||
let mut rng = rand_chacha::ChaChaRng::from_entropy();
|
||||
|
||||
match &self.subcommand {
|
||||
match self.subcommand {
|
||||
SC::SetPassword { username, password } => {
|
||||
let _span =
|
||||
info_span!("cli.manage.set_password", user.username = %username).entered();
|
||||
@@ -222,11 +222,11 @@ impl Options {
|
||||
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username(username)
|
||||
.find_by_username(&username)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
let password = password.as_bytes().to_vec().into();
|
||||
let password = password.into_bytes().into();
|
||||
|
||||
let (version, hashed_password) = password_manager.hash(&mut rng, password).await?;
|
||||
|
||||
@@ -254,13 +254,13 @@ impl Options {
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username(username)
|
||||
.find_by_username(&username)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
let email = repo
|
||||
.user_email()
|
||||
.find(&user, email)
|
||||
.find(&user, &email)
|
||||
.await?
|
||||
.context("Email not found")?;
|
||||
let email = repo.user_email().mark_as_verified(&clock, email).await?;
|
||||
@@ -302,7 +302,7 @@ impl Options {
|
||||
|
||||
// TODO: should be moved somewhere else
|
||||
let encrypted_client_secret = client_secret
|
||||
.map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes()))
|
||||
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
||||
.transpose()?;
|
||||
|
||||
repo.oauth2_client()
|
||||
@@ -361,7 +361,7 @@ impl Options {
|
||||
|
||||
let encrypted_client_secret = client_secret
|
||||
.as_deref()
|
||||
.map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes()))
|
||||
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
|
||||
.transpose()?;
|
||||
|
||||
let provider = repo
|
||||
@@ -369,11 +369,11 @@ impl Options {
|
||||
.add(
|
||||
&mut rng,
|
||||
&clock,
|
||||
issuer.clone(),
|
||||
scope.clone(),
|
||||
issuer,
|
||||
scope,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_signing_alg,
|
||||
client_id.clone(),
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
UpstreamOAuthProviderClaimsImports::default(),
|
||||
)
|
||||
@@ -404,19 +404,19 @@ impl Options {
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.find_by_username(username)
|
||||
.find_by_username(&username)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
let device = if let Some(device_id) = device_id {
|
||||
device_id.clone().try_into()?
|
||||
device_id.try_into()?
|
||||
} else {
|
||||
Device::generate(&mut rng)
|
||||
};
|
||||
|
||||
let compat_session = repo
|
||||
.compat_session()
|
||||
.add(&mut rng, &clock, &user, device, *admin)
|
||||
.add(&mut rng, &clock, &user, device, admin)
|
||||
.await?;
|
||||
|
||||
let token = TokenType::CompatAccessToken.generate(&mut rng);
|
||||
|
@@ -60,17 +60,17 @@ pub struct Options {
|
||||
}
|
||||
|
||||
impl Options {
|
||||
pub async fn run(&self) -> anyhow::Result<()> {
|
||||
pub async fn run(mut self) -> anyhow::Result<()> {
|
||||
use Subcommand as S;
|
||||
match &self.subcommand {
|
||||
Some(S::Config(c)) => c.run(self).await,
|
||||
Some(S::Database(c)) => c.run(self).await,
|
||||
Some(S::Server(c)) => c.run(self).await,
|
||||
Some(S::Worker(c)) => c.run(self).await,
|
||||
Some(S::Manage(c)) => c.run(self).await,
|
||||
Some(S::Templates(c)) => c.run(self).await,
|
||||
Some(S::Debug(c)) => c.run(self).await,
|
||||
None => self::server::Options::default().run(self).await,
|
||||
match self.subcommand.take() {
|
||||
Some(S::Config(c)) => c.run(&self).await,
|
||||
Some(S::Database(c)) => c.run(&self).await,
|
||||
Some(S::Server(c)) => c.run(&self).await,
|
||||
Some(S::Worker(c)) => c.run(&self).await,
|
||||
Some(S::Manage(c)) => c.run(&self).await,
|
||||
Some(S::Templates(c)) => c.run(&self).await,
|
||||
Some(S::Debug(c)) => c.run(&self).await,
|
||||
None => self::server::Options::default().run(&self).await,
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -52,7 +52,7 @@ pub(super) struct Options {
|
||||
|
||||
impl Options {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
||||
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
|
||||
let span = info_span!("cli.run.init").entered();
|
||||
let config: RootConfig = root.load_config()?;
|
||||
|
||||
|
@@ -35,9 +35,9 @@ enum Subcommand {
|
||||
}
|
||||
|
||||
impl Options {
|
||||
pub async fn run(&self, _root: &super::Options) -> anyhow::Result<()> {
|
||||
pub async fn run(self, _root: &super::Options) -> anyhow::Result<()> {
|
||||
use Subcommand as SC;
|
||||
match &self.subcommand {
|
||||
match self.subcommand {
|
||||
SC::Check { path } => {
|
||||
let _span = info_span!("cli.templates.check").entered();
|
||||
|
||||
@@ -45,7 +45,7 @@ impl Options {
|
||||
// XXX: we should disallow SeedableRng::from_entropy
|
||||
let mut rng = rand_chacha::ChaChaRng::from_entropy();
|
||||
let url_builder = mas_router::UrlBuilder::new("https://example.com/".parse()?);
|
||||
let templates = Templates::load(path.clone(), url_builder).await?;
|
||||
let templates = Templates::load(path, url_builder).await?;
|
||||
templates.check_render(clock.now(), &mut rng).await?;
|
||||
|
||||
Ok(())
|
||||
|
@@ -29,7 +29,7 @@ use crate::util::{database_from_config, mailer_from_config, templates_from_confi
|
||||
pub(super) struct Options {}
|
||||
|
||||
impl Options {
|
||||
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
||||
pub async fn run(self, root: &super::Options) -> anyhow::Result<()> {
|
||||
let span = info_span!("cli.worker.init").entered();
|
||||
let config: RootConfig = root.load_config()?;
|
||||
|
||||
|
Reference in New Issue
Block a user