1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Sync the OAuth2 clients with CLI and remove redundant CLI tools

This commit is contained in:
Quentin Gliech
2023-06-26 16:22:49 +02:00
parent 9caf6251b5
commit dec9310a32
8 changed files with 738 additions and 435 deletions

View File

@ -50,9 +50,21 @@ fn map_claims_imports(
config: &mas_config::UpstreamOAuth2ClaimsImports,
) -> mas_data_model::UpstreamOAuthProviderClaimsImports {
mas_data_model::UpstreamOAuthProviderClaimsImports {
localpart: config.localpart.as_ref().map(map_import_preference).unwrap_or_default(),
displayname: config.displayname.as_ref().map(map_import_preference).unwrap_or_default(),
email: config.email.as_ref().map(map_import_preference).unwrap_or_default(),
localpart: config
.localpart
.as_ref()
.map(map_import_preference)
.unwrap_or_default(),
displayname: config
.displayname
.as_ref()
.map(map_import_preference)
.unwrap_or_default(),
email: config
.email
.as_ref()
.map(map_import_preference)
.unwrap_or_default(),
}
}
@ -116,88 +128,167 @@ impl Options {
}
SC::Sync { prune, dry_run } => {
let _span =
info_span!("cli.config.sync", prune = prune, dry_run = dry_run).entered();
let clock = SystemClock::default();
let config: RootConfig = root.load_config()?;
let encrypter = config.secrets.encrypter();
let pool = database_from_config(&config.database).await?;
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
tracing::info!(
prune,
dry_run,
"Syncing providers and clients defined in config to database"
);
let config_ids = config
.upstream_oauth2
.providers
.iter()
.map(|p| p.id)
.collect::<HashSet<_>>();
let existing = repo.upstream_oauth_provider().all().await?;
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
if prune {
for provider in to_delete {
info!(%provider.id, "Deleting provider");
if dry_run {
continue;
}
repo.upstream_oauth_provider().delete(provider).await?;
}
} else {
let len = to_delete.count();
match len {
0 => {},
1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."),
n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."),
}
}
for provider in config.upstream_oauth2.providers {
info!(%provider.id, "Syncing provider");
if dry_run {
continue;
}
let encrypted_client_secret = provider
.client_secret()
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
.transpose()?;
let client_auth_method = provider.client_auth_method();
let client_auth_signing_alg = provider.client_auth_signing_alg();
repo.upstream_oauth_provider()
.upsert(
&clock,
provider.id,
provider.issuer,
provider.scope.parse()?,
client_auth_method,
client_auth_signing_alg,
provider.client_id,
encrypted_client_secret,
map_claims_imports(&provider.claims_imports),
)
.await?;
}
if dry_run {
info!("Dry run, rolling back changes");
repo.cancel().await?;
} else {
repo.save().await?;
}
sync(root, prune, dry_run).await?;
}
}
Ok(())
}
}
#[tracing::instrument(name = "cli.config.sync", skip(root), err(Debug))]
async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Result<()> {
// XXX: we should disallow SeedableRng::from_entropy
let mut rng = rand_chacha::ChaChaRng::from_entropy();
let clock = SystemClock::default();
let config: RootConfig = root.load_config()?;
let encrypter = config.secrets.encrypter();
let pool = database_from_config(&config.database).await?;
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
tracing::info!(
prune,
dry_run,
"Syncing providers and clients defined in config to database"
);
{
let _span = info_span!("cli.config.sync.providers").entered();
let config_ids = config
.upstream_oauth2
.providers
.iter()
.map(|p| p.id)
.collect::<HashSet<_>>();
let existing = repo.upstream_oauth_provider().all().await?;
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
if prune {
for provider in to_delete {
info!(%provider.id, "Deleting provider");
if dry_run {
continue;
}
repo.upstream_oauth_provider().delete(provider).await?;
}
} else {
let len = to_delete.count();
match len {
0 => {},
1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."),
n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."),
}
}
for provider in config.upstream_oauth2.providers {
if existing_ids.contains(&provider.id) {
info!(%provider.id, "Updating provider");
} else {
info!(%provider.id, "Adding provider");
}
if dry_run {
continue;
}
let encrypted_client_secret = provider
.client_secret()
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
.transpose()?;
let client_auth_method = provider.client_auth_method();
let client_auth_signing_alg = provider.client_auth_signing_alg();
repo.upstream_oauth_provider()
.upsert(
&clock,
provider.id,
provider.issuer,
provider.scope.parse()?,
client_auth_method,
client_auth_signing_alg,
provider.client_id,
encrypted_client_secret,
map_claims_imports(&provider.claims_imports),
)
.await?;
}
}
{
let _span = info_span!("cli.config.sync.clients").entered();
let config_ids = config
.clients
.iter()
.map(|c| c.client_id)
.collect::<HashSet<_>>();
let existing = repo.oauth2_client().all_static().await?;
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
if prune {
for client in to_delete {
info!(client.id = %client.client_id, "Deleting client");
if dry_run {
continue;
}
repo.oauth2_client().delete(client).await?;
}
} else {
let len = to_delete.count();
match len {
0 => {},
1 => warn!("A static client in the database is not in the config. Run with `--prune` to delete it."),
n => warn!("{n} static clients in the database are not in the config. Run with `--prune` to delete them."),
}
}
for client in config.clients.iter() {
if existing_ids.contains(&client.client_id) {
info!(client.id = %client.client_id, "Updating client");
} else {
info!(client.id = %client.client_id, "Adding client");
}
if dry_run {
continue;
}
let client_secret = client.client_secret();
let client_auth_method = client.client_auth_method();
let jwks = client.jwks();
let jwks_uri = client.jwks_uri();
// TODO: should be moved somewhere else
let encrypted_client_secret = client_secret
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
.transpose()?;
repo.oauth2_client()
.upsert_static(
&mut rng,
&clock,
client.client_id,
client_auth_method,
encrypted_client_secret,
jwks.cloned(),
jwks_uri.cloned(),
client.redirect_uris.clone(),
)
.await?;
}
}
if dry_run {
info!("Dry run, rolling back changes");
repo.cancel().await?;
} else {
repo.save().await?;
}
Ok(())
}

View File

@ -13,22 +13,17 @@
// limitations under the License.
use anyhow::Context;
use clap::{Parser, ValueEnum};
use mas_config::{DatabaseConfig, PasswordsConfig, RootConfig};
use mas_data_model::{Device, TokenType, UpstreamOAuthProviderClaimsImports};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_router::UrlBuilder;
use clap::Parser;
use mas_config::{DatabaseConfig, PasswordsConfig};
use mas_data_model::{Device, TokenType};
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository},
oauth2::OAuth2ClientRepository,
upstream_oauth2::UpstreamOAuthProviderRepository,
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
Repository, RepositoryAccess, SystemClock,
};
use mas_storage_pg::PgRepository;
use oauth2_types::scope::Scope;
use rand::SeedableRng;
use tracing::{info, info_span, warn};
use tracing::{info, info_span};
use crate::util::{database_from_config, password_manager_from_config};
@ -38,153 +33,14 @@ pub(super) struct Options {
subcommand: Subcommand,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
enum AuthenticationMethod {
/// Client doesn't use any authentication
None,
/// Client sends its `client_secret` in the request body
ClientSecretPost,
/// Client sends its `client_secret` in the authorization header
ClientSecretBasic,
/// Client uses its `client_secret` to sign a client assertion JWT
ClientSecretJwt,
/// Client uses its private keys to sign a client assertion JWT
PrivateKeyJwt,
}
impl AuthenticationMethod {
fn requires_client_secret(self) -> bool {
matches!(
self,
Self::ClientSecretJwt | Self::ClientSecretPost | Self::ClientSecretBasic
)
}
}
impl From<AuthenticationMethod> for OAuthClientAuthenticationMethod {
fn from(val: AuthenticationMethod) -> Self {
(&val).into()
}
}
impl From<&AuthenticationMethod> for OAuthClientAuthenticationMethod {
fn from(val: &AuthenticationMethod) -> Self {
match val {
AuthenticationMethod::None => OAuthClientAuthenticationMethod::None,
AuthenticationMethod::ClientSecretPost => {
OAuthClientAuthenticationMethod::ClientSecretPost
}
AuthenticationMethod::ClientSecretBasic => {
OAuthClientAuthenticationMethod::ClientSecretBasic
}
AuthenticationMethod::ClientSecretJwt => {
OAuthClientAuthenticationMethod::ClientSecretJwt
}
AuthenticationMethod::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
}
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
enum SigningAlgorithm {
#[value(name = "HS256")]
HS256,
#[value(name = "HS384")]
HS384,
#[value(name = "HS512")]
HS512,
#[value(name = "RS256")]
RS256,
#[value(name = "RS384")]
RS384,
#[value(name = "RS512")]
RS512,
#[value(name = "PS256")]
PS256,
#[value(name = "PS384")]
PS384,
#[value(name = "PS512")]
PS512,
#[value(name = "ES256")]
ES256,
#[value(name = "ES384")]
ES384,
#[value(name = "ES256K")]
ES256K,
}
impl From<SigningAlgorithm> for JsonWebSignatureAlg {
fn from(val: SigningAlgorithm) -> Self {
(&val).into()
}
}
impl From<&SigningAlgorithm> for JsonWebSignatureAlg {
fn from(val: &SigningAlgorithm) -> Self {
match val {
SigningAlgorithm::HS256 => Self::Hs256,
SigningAlgorithm::HS384 => Self::Hs384,
SigningAlgorithm::HS512 => Self::Hs512,
SigningAlgorithm::RS256 => Self::Rs256,
SigningAlgorithm::RS384 => Self::Rs384,
SigningAlgorithm::RS512 => Self::Rs512,
SigningAlgorithm::PS256 => Self::Ps256,
SigningAlgorithm::PS384 => Self::Ps384,
SigningAlgorithm::PS512 => Self::Ps512,
SigningAlgorithm::ES256 => Self::Es256,
SigningAlgorithm::ES384 => Self::Es384,
SigningAlgorithm::ES256K => Self::Es256K,
}
}
}
#[derive(Parser, Debug)]
enum Subcommand {
/// Mark email address as verified
VerifyEmail { username: String, email: String },
/// Import clients from config
ImportClients {
/// Update existing clients
#[arg(long)]
update: bool,
},
/// Set a user password
SetPassword { username: String, password: String },
/// Add an OAuth 2.0 upstream
#[command(name = "add-oauth-upstream")]
AddOAuthUpstream {
/// Issuer URL
issuer: String,
/// Scope to ask for when authorizing with this upstream.
///
/// This should include at least the `openid` scope.
scope: Scope,
/// Client authentication method used when using the token endpoint.
#[arg(value_enum)]
token_endpoint_auth_method: AuthenticationMethod,
/// Client ID
client_id: String,
/// JWT signing algorithm used when authenticating for the token
/// endpoint.
#[arg(long, value_enum)]
signing_alg: Option<SigningAlgorithm>,
/// Client Secret
#[arg(long)]
client_secret: Option<String>,
},
/// Issue a compatibility token
IssueCompatibilityToken {
/// User for which to issue the token
@ -271,128 +127,6 @@ impl Options {
Ok(())
}
SC::ImportClients { update } => {
let _span = info_span!("cli.manage.import_clients").entered();
let config: RootConfig = root.load_config()?;
let pool = database_from_config(&config.database).await?;
let encrypter = config.secrets.encrypter();
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
for client in config.clients.iter() {
let client_id = client.client_id;
let existing = repo.oauth2_client().lookup(client_id).await?.is_some();
if !update && existing {
warn!(%client_id, "Skipping already imported client. Run with --update to update existing clients.");
continue;
}
if existing {
info!(%client_id, "Updating client");
} else {
info!(%client_id, "Importing client");
}
let client_secret = client.client_secret();
let client_auth_method = client.client_auth_method();
let jwks = client.jwks();
let jwks_uri = client.jwks_uri();
// TODO: should be moved somewhere else
let encrypted_client_secret = client_secret
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
.transpose()?;
repo.oauth2_client()
.add_from_config(
&mut rng,
&clock,
client_id,
client_auth_method,
encrypted_client_secret,
jwks.cloned(),
jwks_uri.cloned(),
client.redirect_uris.clone(),
)
.await?;
}
repo.save().await?;
Ok(())
}
SC::AddOAuthUpstream {
issuer,
scope,
token_endpoint_auth_method,
client_id,
client_secret,
signing_alg,
} => {
let _span = info_span!(
"cli.manage.add_oauth_upstream",
upstream_oauth_provider.issuer = issuer,
upstream_oauth_provider.client_id = client_id,
)
.entered();
let config: RootConfig = root.load_config()?;
let encrypter = config.secrets.encrypter();
let pool = database_from_config(&config.database).await?;
let url_builder = UrlBuilder::new(config.http.public_base);
let mut repo = PgRepository::from_pool(&pool).await?.boxed();
let requires_client_secret = token_endpoint_auth_method.requires_client_secret();
let token_endpoint_auth_method: OAuthClientAuthenticationMethod =
token_endpoint_auth_method.into();
let token_endpoint_signing_alg: Option<JsonWebSignatureAlg> =
signing_alg.as_ref().map(Into::into);
tracing::info!(%issuer, %scope, %token_endpoint_auth_method, %client_id, "Adding OAuth upstream");
if client_secret.is_none() && requires_client_secret {
tracing::warn!("Token endpoint auth method requires a client secret, but none were provided");
}
let encrypted_client_secret = client_secret
.as_deref()
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
.transpose()?;
let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
issuer,
scope,
token_endpoint_auth_method,
token_endpoint_signing_alg,
client_id,
encrypted_client_secret,
UpstreamOAuthProviderClaimsImports::default(),
)
.await?;
repo.save().await?;
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
let auth_uri = url_builder.upstream_oauth_authorize(provider.id);
tracing::info!(
%provider.id,
%provider.client_id,
provider.redirect_uri = %redirect_uri,
"Test authorization by going to {auth_uri}"
);
Ok(())
}
SC::IssueCompatibilityToken {
username,
admin,

View File

@ -0,0 +1,19 @@
-- Copyright 2023 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- This adds a flag to the OAuth 2.0 clients to indicate whether they are static (i.e. defined in config) or not.
ALTER TABLE oauth2_clients
ADD COLUMN is_static
BOOLEAN NOT NULL
DEFAULT FALSE;

View File

@ -14,6 +14,18 @@
},
"query": "\n UPDATE oauth2_authorization_grants\n SET fulfilled_at = $2\n , oauth2_session_id = $3\n WHERE oauth2_authorization_grant_id = $1\n "
},
"036e9e2cb7271782e48700fecd3fdd80f596ed433f37f2528c7edbdc88b13646": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n DELETE FROM oauth2_consents\n WHERE oauth2_client_id = $1\n "
},
"0469c1d3ad11fd96febacad33302709c870ead848d6920cdfdb18912d543488e": {
"describe": {
"columns": [
@ -165,6 +177,18 @@
},
"query": "\n SELECT user_email_confirmation_code_id\n , user_email_id\n , code\n , created_at\n , expires_at\n , consumed_at\n FROM user_email_confirmation_codes\n WHERE code = $1\n AND user_email_id = $2\n "
},
"1eb829460407fca22b717b88a1a0a9b7b920d807a4b6c235e1bee524cd73b266": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n DELETE FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n "
},
"1f6297fb323e9f2fbfa1c9e3225c0b3037c8c4714533a6240c62275332aa58dc": {
"describe": {
"columns": [],
@ -190,6 +214,57 @@
},
"query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n "
},
"2a0d8d70d21afa9a2c9c1c432853361bb85911c48f7db6c3873b0f5abf35940b": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n DELETE FROM oauth2_authorization_grants\n WHERE oauth2_client_id = $1\n "
},
"2ee26886c56f04cd53d4c0968f5cf0963f92b6d15e6af0e69378a6447dee677c": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n DELETE FROM oauth2_access_tokens\n WHERE oauth2_session_id IN (\n SELECT oauth2_session_id\n FROM oauth2_sessions\n WHERE oauth2_client_id = $1\n )\n "
},
"31cbbd841029812c6d3500cae04a8e9e5723e4749d339465492b68e072c3a802": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Bool",
"Bool",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Jsonb",
"Text",
"Text",
"Text",
"Text",
"Text"
]
}
},
"query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n , is_static\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, FALSE)\n "
},
"3d66f3121b11ce923b9c60609b510a8ca899640e78cc8f5b03168622928ffe94": {
"describe": {
"columns": [],
@ -706,6 +781,18 @@
},
"query": "\n INSERT INTO oauth2_sessions\n ( oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n )\n VALUES ($1, $2, $3, $4, $5)\n "
},
"5b697dd7834d33ec55972d3ba43d25fe794bc0b69c5938275711faa7a80b811f": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n DELETE FROM oauth2_refresh_tokens\n WHERE oauth2_session_id IN (\n SELECT oauth2_session_id\n FROM oauth2_sessions\n WHERE oauth2_client_id = $1\n )\n "
},
"5f6b7e38ef9bc3b39deabba277d0255fb8cfb2adaa65f47b78a8fac11d8c91c3": {
"describe": {
"columns": [],
@ -721,6 +808,18 @@
},
"query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n "
},
"5fe1bb569d13a7d3ff22887b3fc5b76ff901c183b314f8ccb5018d70c516abf6": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n DELETE FROM oauth2_clients\n WHERE oauth2_client_id = $1\n "
},
"6021c1b9e17b0b2e8b511888f8c6be00683ba0635a13eb7fcd403d3d4a3f90db": {
"describe": {
"columns": [],
@ -913,6 +1012,24 @@
},
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET consumed_at = $1\n WHERE upstream_oauth_authorization_session_id = $2\n "
},
"68c4cd463e4035ba8384f11818b7be602e2fbc34a5582f31f95b0cc5fa2aeb92": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Bool",
"Bool",
"Text",
"Jsonb",
"Text"
]
}
},
"query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n , is_static\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, TRUE)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n , is_static = TRUE\n "
},
"6a3b543ec53ce242866d1e84de26728e6dd275cae745f9c646e3824d859c5384": {
"describe": {
"columns": [
@ -1213,6 +1330,18 @@
},
"query": "\n INSERT INTO oauth2_client_redirect_uris\n (oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)\n SELECT id, $2, redirect_uri\n FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)\n "
},
"7cd0264707100f5b3cb2582f3f840bf66649742374e3643f1902ae69377fc9b6": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n DELETE FROM oauth2_client_redirect_uris\n WHERE oauth2_client_id = $1\n "
},
"7ce387b1b0aaf10e72adde667b19521b66eaafa51f73bf2f95e38b8f3b64a229": {
"describe": {
"columns": [],
@ -1226,17 +1355,119 @@
},
"query": "\n UPDATE upstream_oauth_links\n SET user_id = $1\n WHERE upstream_oauth_link_id = $2\n "
},
"82fec0e13755e7032457d94fe8d212c62f14c80a98b61d82965f1b93f841c014": {
"7e676491b077d4bc8a9cdb4a27ebf119d98cd35ebb52b1064fdb2d9eed78d0e8": {
"describe": {
"columns": [],
"nullable": [],
"columns": [
{
"name": "oauth2_client_id",
"ordinal": 0,
"type_info": "Uuid"
},
{
"name": "encrypted_client_secret",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "redirect_uris!",
"ordinal": 2,
"type_info": "TextArray"
},
{
"name": "grant_type_authorization_code",
"ordinal": 3,
"type_info": "Bool"
},
{
"name": "grant_type_refresh_token",
"ordinal": 4,
"type_info": "Bool"
},
{
"name": "client_name",
"ordinal": 5,
"type_info": "Text"
},
{
"name": "logo_uri",
"ordinal": 6,
"type_info": "Text"
},
{
"name": "client_uri",
"ordinal": 7,
"type_info": "Text"
},
{
"name": "policy_uri",
"ordinal": 8,
"type_info": "Text"
},
{
"name": "tos_uri",
"ordinal": 9,
"type_info": "Text"
},
{
"name": "jwks_uri",
"ordinal": 10,
"type_info": "Text"
},
{
"name": "jwks",
"ordinal": 11,
"type_info": "Jsonb"
},
{
"name": "id_token_signed_response_alg",
"ordinal": 12,
"type_info": "Text"
},
{
"name": "userinfo_signed_response_alg",
"ordinal": 13,
"type_info": "Text"
},
{
"name": "token_endpoint_auth_method",
"ordinal": 14,
"type_info": "Text"
},
{
"name": "token_endpoint_auth_signing_alg",
"ordinal": 15,
"type_info": "Text"
},
{
"name": "initiate_login_uri",
"ordinal": 16,
"type_info": "Text"
}
],
"nullable": [
false,
true,
null,
false,
false,
true,
true,
true,
true,
true,
true,
true,
true,
true,
true,
true,
true
],
"parameters": {
"Left": [
"Uuid"
]
"Left": []
}
},
"query": "\n DELETE FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_provider_id = $1\n "
"query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n WHERE is_static = TRUE\n "
},
"836fb7567d84057fa7f1edaab834c21a158a5762fe220b6bfacd6576be6c613c": {
"describe": {
@ -1392,7 +1623,7 @@
},
"query": "\n SELECT oauth2_client_id\n , encrypted_client_secret\n , ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\"\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n FROM oauth2_clients c\n\n WHERE oauth2_client_id = ANY($1::uuid[])\n "
},
"8a32a39c43147dfd9dbd25ff04686c3cdbc52ea5689ce3454d15e8ed31756f38": {
"8acbdc892d44efb53529da1c2df65bea6b799a43cf4c9264a37d392847e6eff0": {
"describe": {
"columns": [],
"nullable": [],
@ -1402,25 +1633,7 @@
]
}
},
"query": "\n DELETE FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n "
},
"8a79c7c392dd930628caadec80c9b2645501475ab4feacbac59ca1bc52b16c3f": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Bool",
"Bool",
"Text",
"Jsonb",
"Text"
]
}
},
"query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , token_endpoint_auth_method\n , jwks\n , jwks_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7)\n ON CONFLICT (oauth2_client_id)\n DO\n UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret\n , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code\n , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token\n , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method\n , jwks = EXCLUDED.jwks\n , jwks_uri = EXCLUDED.jwks_uri\n "
"query": "\n DELETE FROM oauth2_sessions\n WHERE oauth2_client_id = $1\n "
},
"8b7297c263336d70c2b647212b16f7ae39bc5cb1572e3a2dcfcd67f196a1fa39": {
"describe": {
@ -1919,6 +2132,18 @@
},
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n "
},
"b992283a9b43cbb8f86149f3f55cb47fb628dabd8fadc50e6a5772903f851e1c": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid"
]
}
},
"query": "\n DELETE FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_provider_id = $1\n "
},
"bbf62633c561706a762089bbab2f76a9ba3e2ed3539ef16accb601fb609c2ec9": {
"describe": {
"columns": [],
@ -2515,32 +2740,5 @@
}
},
"query": "\n SELECT oauth2_session_id\n , user_session_id\n , oauth2_client_id\n , scope\n , created_at\n , finished_at\n FROM oauth2_sessions\n\n WHERE oauth2_session_id = $1\n "
},
"f5edcd4c306ca8179cdf9d4aab59fbba971b54611c91345849920954dd8089b3": {
"describe": {
"columns": [],
"nullable": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Bool",
"Bool",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Jsonb",
"Text",
"Text",
"Text",
"Text",
"Text"
]
}
},
"query": "\n INSERT INTO oauth2_clients\n ( oauth2_client_id\n , encrypted_client_secret\n , grant_type_authorization_code\n , grant_type_refresh_token\n , client_name\n , logo_uri\n , client_uri\n , policy_uri\n , tos_uri\n , jwks_uri\n , jwks\n , id_token_signed_response_alg\n , userinfo_signed_response_alg\n , token_endpoint_auth_method\n , token_endpoint_auth_signing_alg\n , initiate_login_uri\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n "
}
}

View File

@ -428,9 +428,10 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
, token_endpoint_auth_method
, token_endpoint_auth_signing_alg
, initiate_login_uri
, is_static
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, FALSE)
"#,
Uuid::from(id),
encrypted_client_secret,
@ -527,7 +528,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
}
#[tracing::instrument(
name = "db.oauth2_client.add_from_config",
name = "db.oauth2_client.upsert_static",
skip_all,
fields(
db.statement,
@ -535,7 +536,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
),
err,
)]
async fn add_from_config(
async fn upsert_static(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
@ -564,9 +565,10 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
, token_endpoint_auth_method
, jwks
, jwks_uri
, is_static
)
VALUES
($1, $2, $3, $4, $5, $6, $7)
($1, $2, $3, $4, $5, $6, $7, TRUE)
ON CONFLICT (oauth2_client_id)
DO
UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
@ -575,6 +577,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
, token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
, jwks = EXCLUDED.jwks
, jwks_uri = EXCLUDED.jwks_uri
, is_static = TRUE
"#,
Uuid::from(client_id),
encrypted_client_secret,
@ -590,7 +593,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
{
let span = info_span!(
"db.oauth2_client.add_from_config.redirect_uris",
"db.oauth2_client.upsert_static.redirect_uris",
client.id = %client_id,
db.statement = tracing::field::Empty,
);
@ -656,6 +659,52 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
})
}
#[tracing::instrument(
name = "db.oauth2_client.all_static",
skip_all,
fields(
db.statement,
),
err,
)]
async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error> {
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT oauth2_client_id
, encrypted_client_secret
, ARRAY(
SELECT redirect_uri
FROM oauth2_client_redirect_uris r
WHERE r.oauth2_client_id = c.oauth2_client_id
) AS "redirect_uris!"
, grant_type_authorization_code
, grant_type_refresh_token
, client_name
, logo_uri
, client_uri
, policy_uri
, tos_uri
, jwks_uri
, jwks
, id_token_signed_response_alg
, userinfo_signed_response_alg
, token_endpoint_auth_method
, token_endpoint_auth_signing_alg
, initiate_login_uri
FROM oauth2_clients c
WHERE is_static = TRUE
"#,
)
.traced()
.fetch_all(&mut *self.conn)
.await?;
res.into_iter()
.map(|r| r.try_into().map_err(DatabaseError::from))
.collect()
}
#[tracing::instrument(
name = "db.oauth2_client.get_consent_for_user",
skip_all,
@ -698,6 +747,7 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
}
#[tracing::instrument(
name = "db.oauth2_client.give_consent_for_user",
skip_all,
fields(
db.statement,
@ -739,10 +789,161 @@ impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
&tokens,
now,
)
.traced()
.traced()
.execute(&mut *self.conn)
.await?;
Ok(())
}
#[tracing::instrument(
name = "db.oauth2_client.delete_by_id",
skip_all,
fields(
db.statement,
client.id = %id,
),
err,
)]
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
// Delete the authorization grants
{
let span = info_span!(
"db.oauth2_client.delete_by_id.authorization_grants",
db.statement = tracing::field::Empty,
);
sqlx::query!(
r#"
DELETE FROM oauth2_authorization_grants
WHERE oauth2_client_id = $1
"#,
Uuid::from(id),
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
// Delete the user consents
{
let span = info_span!(
"db.oauth2_client.delete_by_id.consents",
db.statement = tracing::field::Empty,
);
sqlx::query!(
r#"
DELETE FROM oauth2_consents
WHERE oauth2_client_id = $1
"#,
Uuid::from(id),
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
// Delete the OAuth 2 sessions related data
{
let span = info_span!(
"db.oauth2_client.delete_by_id.access_tokens",
db.statement = tracing::field::Empty,
);
sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE oauth2_session_id IN (
SELECT oauth2_session_id
FROM oauth2_sessions
WHERE oauth2_client_id = $1
)
"#,
Uuid::from(id),
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
{
let span = info_span!(
"db.oauth2_client.delete_by_id.refresh_tokens",
db.statement = tracing::field::Empty,
);
sqlx::query!(
r#"
DELETE FROM oauth2_refresh_tokens
WHERE oauth2_session_id IN (
SELECT oauth2_session_id
FROM oauth2_sessions
WHERE oauth2_client_id = $1
)
"#,
Uuid::from(id),
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
{
let span = info_span!(
"db.oauth2_client.delete_by_id.sessions",
db.statement = tracing::field::Empty,
);
sqlx::query!(
r#"
DELETE FROM oauth2_sessions
WHERE oauth2_client_id = $1
"#,
Uuid::from(id),
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
// Delete the redirect URIs
{
let span = info_span!(
"db.oauth2_client.delete_by_id.redirect_uris",
db.statement = tracing::field::Empty,
);
sqlx::query!(
r#"
DELETE FROM oauth2_client_redirect_uris
WHERE oauth2_client_id = $1
"#,
Uuid::from(id),
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
// Now delete the client itself
let res = sqlx::query!(
r#"
DELETE FROM oauth2_clients
WHERE oauth2_client_id = $1
"#,
Uuid::from(id),
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)
}
}

View File

@ -189,7 +189,10 @@ mod tests {
assert_eq!(links.edges[0].user_id, Some(user.id));
// Try deleting the provider
repo.upstream_oauth_provider().delete(provider).await.unwrap();
repo.upstream_oauth_provider()
.delete(provider)
.await
.unwrap();
let providers = repo.upstream_oauth_provider().all().await.unwrap();
assert!(providers.is_empty());
}

View File

@ -20,6 +20,7 @@ use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Clock, Page,
use oauth2_types::scope::Scope;
use rand::RngCore;
use sqlx::{types::Json, PgConnection, QueryBuilder};
use tracing::{info_span, Instrument};
use ulid::Ulid;
use uuid::Uuid;
@ -298,30 +299,47 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
err,
)]
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
// Delete the authorization sessions first, as they have a foreign key constraint
// on the links and the providers.
sqlx::query!(
r#"
DELETE FROM upstream_oauth_authorization_sessions
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.traced()
// Delete the authorization sessions first, as they have a foreign key
// constraint on the links and the providers.
{
let span = info_span!(
"db.oauth2_client.delete_by_id.authorization_sessions",
upstream_oauth_provider.id = %id,
db.statement = tracing::field::Empty,
);
sqlx::query!(
r#"
DELETE FROM upstream_oauth_authorization_sessions
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
// Delete the links next, as they have a foreign key constraint on the providers.
sqlx::query!(
r#"
DELETE FROM upstream_oauth_links
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.traced()
// Delete the links next, as they have a foreign key constraint on the
// providers.
{
let span = info_span!(
"db.oauth2_client.delete_by_id.links",
upstream_oauth_provider.id = %id,
db.statement = tracing::field::Empty,
);
sqlx::query!(
r#"
DELETE FROM upstream_oauth_links
WHERE upstream_oauth_provider_id = $1
"#,
Uuid::from(id),
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
}
let res = sqlx::query!(
r#"

View File

@ -124,7 +124,7 @@ pub trait OAuth2ClientRepository: Send + Sync {
initiate_login_uri: Option<Url>,
) -> Result<Client, Self::Error>;
/// Add or replace a client from the configuration
/// Add or replace a static client
///
/// Returns the client that was added or replaced
///
@ -143,7 +143,7 @@ pub trait OAuth2ClientRepository: Send + Sync {
///
/// Returns [`Self::Error`] if the underlying repository fails
#[allow(clippy::too_many_arguments)]
async fn add_from_config(
async fn upsert_static(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
@ -155,6 +155,13 @@ pub trait OAuth2ClientRepository: Send + Sync {
redirect_uris: Vec<Url>,
) -> Result<Client, Self::Error>;
/// List all static clients
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error>;
/// Get the list of scopes that the user has given consent for the given
/// client
///
@ -193,6 +200,32 @@ pub trait OAuth2ClientRepository: Send + Sync {
user: &User,
scope: &Scope,
) -> Result<(), Self::Error>;
/// Delete a client
///
/// # Parameters
///
/// * `client`: The client to delete
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails, or if the
/// client does not exist
async fn delete(&mut self, client: Client) -> Result<(), Self::Error> {
self.delete_by_id(client.id).await
}
/// Delete a client by ID
///
/// # Parameters
///
/// * `id`: The ID of the client to delete
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails, or if the
/// client does not exist
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
}
repository_impl!(OAuth2ClientRepository:
@ -225,7 +258,7 @@ repository_impl!(OAuth2ClientRepository:
initiate_login_uri: Option<Url>,
) -> Result<Client, Self::Error>;
async fn add_from_config(
async fn upsert_static(
&mut self,
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
@ -237,6 +270,12 @@ repository_impl!(OAuth2ClientRepository:
redirect_uris: Vec<Url>,
) -> Result<Client, Self::Error>;
async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error>;
async fn delete(&mut self, client: Client) -> Result<(), Self::Error>;
async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
async fn get_consent_for_user(
&mut self,
client: &Client,