1
0
mirror of https://github.com/smallstep/cli.git synced 2025-04-19 10:42:15 +03:00

Simplify ingoring usage of provisioner flag when managing policies

This commit is contained in:
Herman Slatman 2025-03-27 21:48:15 +01:00
parent 064866f86c
commit c153ef3e75
No known key found for this signature in database
GPG Key ID: F4D8A44EA0A75A4F
16 changed files with 96 additions and 275 deletions

View File

@ -22,7 +22,6 @@ import (
"github.com/smallstep/cli/command/version"
"github.com/smallstep/cli/internal/plugin"
"github.com/smallstep/cli/internal/provisionerflag"
"github.com/smallstep/cli/utils"
// Enabled cas interfaces.
@ -127,17 +126,11 @@ func newApp(stdout, stderr io.Writer) *cli.App {
app.Copyright = fmt.Sprintf("(c) 2018-%d Smallstep Labs, Inc.", time.Now().Year())
// Flag of custom configuration flag
app.Flags = append(app.Flags, cli.StringFlag{ //nolint:gocritic // intentionally split for documentation
app.Flags = append(app.Flags, cli.StringFlag{
Name: "config",
Usage: "path to the config file to use for CLI flags",
})
// add a hidden flag that can be used to signal that the provisioner
// flag should be ignored in certain commands. By defining it on the
// app level it can be ignored in multiple (sub)commands without having
// to specify the flag in each command.
app.Flags = append(app.Flags, provisionerflag.DisabledSentinelFlag)
// Action runs on `step` or `step <command>` if the command is not enabled.
app.Action = func(ctx *cli.Context) error {
args := ctx.Args()

View File

@ -3,13 +3,9 @@ package main
import (
"bytes"
"regexp"
"slices"
"testing"
"github.com/stretchr/testify/require"
"github.com/urfave/cli"
"github.com/smallstep/cli/internal/provisionerflag"
)
func TestAppHasAllCommands(t *testing.T) {
@ -48,15 +44,3 @@ func TestAppRuns(t *testing.T) {
output := ansiRegex.ReplaceAllString(stdout.String(), "")
require.Contains(t, output, "step -- plumbing for distributed systems")
}
func TestAppHasSentinelFlagForIgnoringProvisionersFlag(t *testing.T) {
app := newApp(nil, nil)
require.NotNil(t, app)
// this test only checks if the flag is present when an app is created
// through [getApp]. This is sufficient for now to proof that the flag
// exists in the actual released CLI binary.
require.True(t, slices.ContainsFunc(app.Flags, func(f cli.Flag) bool {
return f.GetName() == provisionerflag.DisabledSentinelFlagName()
}))
}

View File

@ -76,11 +76,12 @@ $ step ca policy authority x509 deny cn "My Bad CA Name"
}
func commonNamesAction(ctx context.Context) (err error) {
ignoreProvisionerFlagIfRequired(ctx)
var (
provisioner = retrieveAndUnsetProvisionerFlagIfRequired(ctx)
clictx = command.CLIContextFromContext(ctx)
args = clictx.Args()
)
clictx := command.CLIContextFromContext(ctx)
args := clictx.Args()
if len(args) == 0 {
return errs.TooFewArguments(clictx)
}
@ -90,7 +91,7 @@ func commonNamesAction(ctx context.Context) (err error) {
return fmt.Errorf("error creating admin client: %w", err)
}
policy, err := retrieveAndInitializePolicy(ctx, client)
policy, err := retrieveAndInitializePolicy(ctx, client, provisioner)
if err != nil {
return fmt.Errorf("error retrieving policy: %w", err)
}
@ -115,7 +116,7 @@ func commonNamesAction(ctx context.Context) (err error) {
panic("no SSH nor X.509 context set")
}
updatedPolicy, err := updatePolicy(ctx, client, policy)
updatedPolicy, err := updatePolicy(ctx, client, policy, provisioner)
if err != nil {
return fmt.Errorf("error updating policy: %w", err)
}

View File

@ -94,11 +94,12 @@ $ step ca policy authority ssh host allow dns "badsshhost.local"
}
func dnsAction(ctx context.Context) (err error) {
ignoreProvisionerFlagIfRequired(ctx)
var (
provisioner = retrieveAndUnsetProvisionerFlagIfRequired(ctx)
clictx = command.CLIContextFromContext(ctx)
args = clictx.Args()
)
clictx := command.CLIContextFromContext(ctx)
args := clictx.Args()
if len(args) == 0 {
return errs.TooFewArguments(clictx)
}
@ -108,7 +109,7 @@ func dnsAction(ctx context.Context) (err error) {
return fmt.Errorf("error creating admin client: %w", err)
}
policy, err := retrieveAndInitializePolicy(ctx, client)
policy, err := retrieveAndInitializePolicy(ctx, client, provisioner)
if err != nil {
return fmt.Errorf("error retrieving policy: %w", err)
}
@ -140,7 +141,7 @@ func dnsAction(ctx context.Context) (err error) {
panic("no SSH nor X.509 context set")
}
updatedPolicy, err := updatePolicy(ctx, client, policy)
updatedPolicy, err := updatePolicy(ctx, client, policy, provisioner)
if err != nil {
return fmt.Errorf("error updating policy: %w", err)
}

View File

@ -81,11 +81,12 @@ $ step ca policy provisioner ssh user deny email @example.com --provisioner my_p
}
func emailAction(ctx context.Context) (err error) {
ignoreProvisionerFlagIfRequired(ctx)
var (
provisioner = retrieveAndUnsetProvisionerFlagIfRequired(ctx)
clictx = command.CLIContextFromContext(ctx)
args = clictx.Args()
)
clictx := command.CLIContextFromContext(ctx)
args := clictx.Args()
if len(args) == 0 {
return errs.TooFewArguments(clictx)
}
@ -95,7 +96,7 @@ func emailAction(ctx context.Context) (err error) {
return fmt.Errorf("error creating admin client: %w", err)
}
policy, err := retrieveAndInitializePolicy(ctx, client)
policy, err := retrieveAndInitializePolicy(ctx, client, provisioner)
if err != nil {
return err
}
@ -127,7 +128,7 @@ func emailAction(ctx context.Context) (err error) {
panic("no SSH nor X.509 context set")
}
updatedPolicy, err := updatePolicy(ctx, client, policy)
updatedPolicy, err := updatePolicy(ctx, client, policy, provisioner)
if err != nil {
return fmt.Errorf("error updating policy: %w", err)
}

View File

@ -114,11 +114,12 @@ $ step ca policy authority ssh host deny ip 192.168.0.40
}
func ipAction(ctx context.Context) (err error) {
ignoreProvisionerFlagIfRequired(ctx)
var (
provisioner = retrieveAndUnsetProvisionerFlagIfRequired(ctx)
clictx = command.CLIContextFromContext(ctx)
args = clictx.Args()
)
clictx := command.CLIContextFromContext(ctx)
args := clictx.Args()
if len(args) == 0 {
return errs.TooFewArguments(clictx)
}
@ -128,7 +129,7 @@ func ipAction(ctx context.Context) (err error) {
return fmt.Errorf("error creating admin client: %w", err)
}
policy, err := retrieveAndInitializePolicy(ctx, client)
policy, err := retrieveAndInitializePolicy(ctx, client, provisioner)
if err != nil {
return err
}
@ -160,7 +161,7 @@ func ipAction(ctx context.Context) (err error) {
panic("no SSH nor X.509 context set")
}
updatedPolicy, err := updatePolicy(ctx, client, policy)
updatedPolicy, err := updatePolicy(ctx, client, policy, provisioner)
if err != nil {
return fmt.Errorf("error updating policy: %w", err)
}

View File

@ -16,7 +16,6 @@ import (
"github.com/smallstep/cli/command/ca/policy/policycontext"
"github.com/smallstep/cli/internal/command"
"github.com/smallstep/cli/internal/provisionerflag"
)
var provisionerFilterFlag = cli.StringFlag{
@ -24,28 +23,33 @@ var provisionerFilterFlag = cli.StringFlag{
Usage: `The provisioner <name>`,
}
// ignoreProvisionerFlagIfRequired is a helper function that marks the provisioner
// flag to be ignored when managing a provisioner or ACME account level policy. In
// those cases the provisioner flag is used to filter which provisioner the policy
// applies to, as opposed to its normal usage, where it can be used to select the
// (admin) provisioner to use for authentication.
func ignoreProvisionerFlagIfRequired(ctx context.Context) {
clictx := command.CLIContextFromContext(ctx)
if policycontext.IsProvisionerPolicyLevel(ctx) || policycontext.IsACMEPolicyLevel(ctx) {
provisionerflag.Ignore(clictx)
func retrieveAndUnsetProvisionerFlagIfRequired(ctx context.Context) string {
// when managing policies on the authority level there's no need
// to select a provisioner, so the flag does not need to be unset.
if policycontext.IsAuthorityPolicyLevel(ctx) {
return ""
}
}
func retrieveAndInitializePolicy(ctx context.Context, client *ca.AdminClient) (*linkedca.Policy, error) {
var (
policy *linkedca.Policy
err error
)
clictx := command.CLIContextFromContext(ctx)
provisioner := clictx.String("provisioner")
reference := clictx.String("eab-key-reference")
keyID := clictx.String("eab-key-id")
// unset the provisioner flag value, so that it's not used
// automatically in token flows.
if err := clictx.Set("provisioner", ""); err != nil {
panic(fmt.Errorf("failed unsetting provisioner flag: %w", err))
}
return provisioner
}
func retrieveAndInitializePolicy(ctx context.Context, client *ca.AdminClient, provisioner string) (*linkedca.Policy, error) {
var (
clictx = command.CLIContextFromContext(ctx)
reference = clictx.String("eab-key-reference")
keyID = clictx.String("eab-key-id")
policy *linkedca.Policy
err error
)
switch {
case policycontext.IsAuthorityPolicyLevel(ctx):
@ -160,13 +164,11 @@ func initPolicy(p *linkedca.Policy) *linkedca.Policy {
return p
}
func updatePolicy(ctx context.Context, client *ca.AdminClient, policy *linkedca.Policy) (*linkedca.Policy, error) {
clictx := command.CLIContextFromContext(ctx)
provisioner := clictx.String("provisioner")
reference := clictx.String("eab-key-reference")
keyID := clictx.String("eab-key-id")
func updatePolicy(ctx context.Context, client *ca.AdminClient, policy *linkedca.Policy, provisioner string) (*linkedca.Policy, error) {
var (
clictx = command.CLIContextFromContext(ctx)
reference = clictx.String("eab-key-reference")
keyID = clictx.String("eab-key-id")
updatedPolicy *linkedca.Policy
err error
)

View File

@ -76,11 +76,12 @@ $ step ca policy provisioner ssh host deny principal root --provisioner my_ssh_u
}
func principalAction(ctx context.Context) (err error) {
ignoreProvisionerFlagIfRequired(ctx)
var (
provisioner = retrieveAndUnsetProvisionerFlagIfRequired(ctx)
clictx = command.CLIContextFromContext(ctx)
args = clictx.Args()
)
clictx := command.CLIContextFromContext(ctx)
args := clictx.Args()
if len(args) == 0 {
return errs.TooFewArguments(clictx)
}
@ -90,7 +91,7 @@ func principalAction(ctx context.Context) (err error) {
return fmt.Errorf("error creating admin client: %w", err)
}
policy, err := retrieveAndInitializePolicy(ctx, client)
policy, err := retrieveAndInitializePolicy(ctx, client, provisioner)
if err != nil {
return err
}
@ -122,7 +123,7 @@ func principalAction(ctx context.Context) (err error) {
panic("no SSH nor X.509 context set")
}
updatedPolicy, err := updatePolicy(ctx, client, policy)
updatedPolicy, err := updatePolicy(ctx, client, policy, provisioner)
if err != nil {
return fmt.Errorf("error updating policy: %w", err)
}

View File

@ -71,12 +71,12 @@ $ step ca policy acme remove --provisioner my_acme_provisioner --eab-key-id "lUO
}
func removeAction(ctx context.Context) (err error) {
ignoreProvisionerFlagIfRequired(ctx)
clictx := command.CLIContextFromContext(ctx)
provisioner := clictx.String("provisioner")
reference := clictx.String("eab-key-reference")
keyID := clictx.String("eab-key-id")
var (
provisioner = retrieveAndUnsetProvisionerFlagIfRequired(ctx)
clictx = command.CLIContextFromContext(ctx)
reference = clictx.String("eab-key-reference")
keyID = clictx.String("eab-key-id")
)
client, err := cautils.NewAdminClient(clictx)
if err != nil {

View File

@ -71,11 +71,12 @@ $ step ca policy provisioner x509 allow uri "*.example.com" --provisioner my_pro
}
func uriAction(ctx context.Context) (err error) {
ignoreProvisionerFlagIfRequired(ctx)
var (
provisioner = retrieveAndUnsetProvisionerFlagIfRequired(ctx)
clictx = command.CLIContextFromContext(ctx)
args = clictx.Args()
)
clictx := command.CLIContextFromContext(ctx)
args := clictx.Args()
if len(args) == 0 {
return errs.TooFewArguments(clictx)
}
@ -85,7 +86,7 @@ func uriAction(ctx context.Context) (err error) {
return fmt.Errorf("error creating admin client: %w", err)
}
policy, err := retrieveAndInitializePolicy(ctx, client)
policy, err := retrieveAndInitializePolicy(ctx, client, provisioner)
if err != nil {
return fmt.Errorf("error retrieving policy: %w", err)
}
@ -110,7 +111,7 @@ func uriAction(ctx context.Context) (err error) {
panic("no SSH nor X.509 context set")
}
updatedPolicy, err := updatePolicy(ctx, client, policy)
updatedPolicy, err := updatePolicy(ctx, client, policy, provisioner)
if err != nil {
return fmt.Errorf("error updating policy: %w", err)
}

View File

@ -72,22 +72,19 @@ $ step ca policy acme view --provisioner my_acme_provisioner --eab-key-id "lUOTG
}
func viewAction(ctx context.Context) (err error) {
ignoreProvisionerFlagIfRequired(ctx)
clictx := command.CLIContextFromContext(ctx)
provisioner := clictx.String("provisioner")
reference := clictx.String("eab-key-reference")
keyID := clictx.String("eab-key-id")
var (
provisioner = retrieveAndUnsetProvisionerFlagIfRequired(ctx)
clictx = command.CLIContextFromContext(ctx)
reference = clictx.String("eab-key-reference")
keyID = clictx.String("eab-key-id")
policy *linkedca.Policy
)
client, err := cautils.NewAdminClient(clictx)
if err != nil {
return fmt.Errorf("error creating admin client: %w", err)
}
var (
policy *linkedca.Policy
)
switch {
case policycontext.IsAuthorityPolicyLevel(ctx):
policy, err = client.GetAuthorityPolicy()

View File

@ -10,23 +10,24 @@ import (
// AllowWildcardsAction updates the policy to allow wildcard names.
func AllowWildcardsAction(ctx context.Context) (err error) {
ignoreProvisionerFlagIfRequired(ctx)
clictx := command.CLIContextFromContext(ctx)
var (
provisioner = retrieveAndUnsetProvisionerFlagIfRequired(ctx)
clictx = command.CLIContextFromContext(ctx)
)
client, err := cautils.NewAdminClient(clictx)
if err != nil {
return fmt.Errorf("error creating admin client: %w", err)
}
policy, err := retrieveAndInitializePolicy(ctx, client)
policy, err := retrieveAndInitializePolicy(ctx, client, provisioner)
if err != nil {
return fmt.Errorf("error retrieving policy: %w", err)
}
policy.X509.AllowWildcardNames = true
updatedPolicy, err := updatePolicy(ctx, client, policy)
updatedPolicy, err := updatePolicy(ctx, client, policy, provisioner)
if err != nil {
return fmt.Errorf("error updating policy: %w", err)
}
@ -36,23 +37,24 @@ func AllowWildcardsAction(ctx context.Context) (err error) {
// DenyWildcardsAction updates the policy to deny wildcard names.
func DenyWildcardsAction(ctx context.Context) (err error) {
ignoreProvisionerFlagIfRequired(ctx)
clictx := command.CLIContextFromContext(ctx)
var (
provisioner = retrieveAndUnsetProvisionerFlagIfRequired(ctx)
clictx = command.CLIContextFromContext(ctx)
)
client, err := cautils.NewAdminClient(clictx)
if err != nil {
return fmt.Errorf("error creating admin client: %w", err)
}
policy, err := retrieveAndInitializePolicy(ctx, client)
policy, err := retrieveAndInitializePolicy(ctx, client, provisioner)
if err != nil {
return fmt.Errorf("error retrieving policy: %w", err)
}
policy.X509.AllowWildcardNames = false
updatedPolicy, err := updatePolicy(ctx, client, policy)
updatedPolicy, err := updatePolicy(ctx, client, policy, provisioner)
if err != nil {
return fmt.Errorf("error updating policy: %w", err)
}

View File

@ -1,58 +0,0 @@
// this package is used for ignoring the provisioner flag in specific
// cli commands.
package provisionerflag
import (
"github.com/urfave/cli"
)
var disabledSentinel = "/x-disable-provisioner-flag"
// DisabledSentinelFlagName returns the name of the sentinel flag
// that can be used to ignore the provisioner flag in specific cli commands.
func DisabledSentinelFlagName() string {
return disabledSentinel
}
// DisabledSentinelFlag is a sentinel flag that can be used to ignore
// the provisioner flag in specific cli commands.
var DisabledSentinelFlag = cli.BoolFlag{
Name: disabledSentinel,
Hidden: true,
}
// Ignore marks the provisioner flag to be ignored. If an error occurs it
// will traverse the [cli.Context] recursively until setting the value
// succeeds or the root context is reached. If the value is not set along
// the way, it will panic.
func Ignore(ctx *cli.Context) {
if ctx == nil {
panic("context is nil")
}
err := ctx.Set(disabledSentinel, "true")
switch {
case err == nil:
return
case ctx.Parent() != nil:
Ignore(ctx.Parent())
default:
panic(err)
}
}
// ShouldBeIgnored returns whether the provisioner flag should be ignored.
// If the [cli.Context] does not contain the sentinel flag value, it will
// recursively look for it up to the root context.
func ShouldBeIgnored(ctx *cli.Context) bool {
if ctx.IsSet(disabledSentinel) && ctx.String(disabledSentinel) == "true" {
return true
}
// recursively look for the sentinel value in the parent context
if ctx.Parent() != nil {
return ShouldBeIgnored(ctx.Parent())
}
return false
}

View File

@ -1,90 +0,0 @@
// this package is used for ignoring the provisioner flag in specific
// cli commands.
package provisionerflag_test
import (
"flag"
"testing"
"github.com/stretchr/testify/require"
"github.com/urfave/cli"
"github.com/smallstep/cli/internal/provisionerflag"
)
func TestProvisionerFlagCanBeIgnored(t *testing.T) {
t.Parallel()
app := cli.NewApp()
t.Run("not-ignored", func(t *testing.T) {
t.Parallel()
parentFlags := flag.NewFlagSet("parent", 0)
parentFlags.String(provisionerflag.DisabledSentinelFlagName(), "", "")
parent := cli.NewContext(app, parentFlags, nil)
ctx := cli.NewContext(app, flag.NewFlagSet("test", 0), parent)
require.False(t, provisionerflag.ShouldBeIgnored(ctx))
})
t.Run("child", func(t *testing.T) {
t.Parallel()
parent := cli.NewContext(app, flag.NewFlagSet("parent", 0), nil)
childFlags := flag.NewFlagSet("test", 0)
childFlags.String(provisionerflag.DisabledSentinelFlagName(), "", "")
ctx := cli.NewContext(app, childFlags, parent)
provisionerflag.Ignore(ctx)
require.True(t, provisionerflag.ShouldBeIgnored(ctx))
})
t.Run("parent", func(t *testing.T) {
t.Parallel()
parentFlags := flag.NewFlagSet("parent", 0)
parentFlags.String(provisionerflag.DisabledSentinelFlagName(), "", "")
parent := cli.NewContext(app, parentFlags, nil)
ctx := cli.NewContext(app, flag.NewFlagSet("test", 0), parent)
provisionerflag.Ignore(ctx)
require.True(t, provisionerflag.ShouldBeIgnored(ctx))
})
t.Run("chain", func(t *testing.T) {
t.Parallel()
parentFlags := flag.NewFlagSet("parent", 0)
parentFlags.String(provisionerflag.DisabledSentinelFlagName(), "", "")
parent := cli.NewContext(app, parentFlags, nil)
ctx := cli.NewContext(app, flag.NewFlagSet("test-1", 0), parent)
ctx = cli.NewContext(app, flag.NewFlagSet("test-2", 0), ctx)
ctx = cli.NewContext(app, flag.NewFlagSet("test-3", 0), ctx)
provisionerflag.Ignore(ctx)
require.True(t, provisionerflag.ShouldBeIgnored(ctx))
})
t.Run("nil-context", func(t *testing.T) {
t.Parallel()
require.Panics(t, func() { provisionerflag.Ignore(nil) })
})
t.Run("flag-undefined", func(t *testing.T) {
t.Parallel()
parent := cli.NewContext(app, flag.NewFlagSet("parent", 0), nil)
ctx := cli.NewContext(app, flag.NewFlagSet("test", 0), parent)
require.Panics(t, func() { provisionerflag.Ignore(ctx) })
})
}

View File

@ -15,7 +15,6 @@ import (
"github.com/smallstep/cli-utils/ui"
"github.com/smallstep/cli/flags"
"github.com/smallstep/cli/internal/provisionerflag"
"github.com/smallstep/cli/utils"
)
@ -352,10 +351,6 @@ func provisionerPrompt(ctx *cli.Context, provisioners provisioner.List) (provisi
// Filter by provisioner / issuer (provisioner name)
if provisionerName, flag := flags.FirstStringOf(ctx, "provisioner", "issuer"); provisionerName != "" {
provisioners = provisionerFilter(provisioners, func(p provisioner.Interface) bool {
if provisionerflag.ShouldBeIgnored(ctx) {
return true // fake match; effectively skipping provisioner flag value for provisioner-dependent policy commands
}
return p.GetName() == provisionerName
})
if len(provisioners) == 0 {

View File

@ -9,8 +9,6 @@ import (
"github.com/urfave/cli"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/internal/provisionerflag"
)
func newContext(t *testing.T) *cli.Context {
@ -19,8 +17,6 @@ func newContext(t *testing.T) *cli.Context {
app := cli.NewApp()
parentFlags := flag.NewFlagSet(fmt.Sprintf("parent-%s", t.Name()), 0)
parentFlags.String(provisionerflag.DisabledSentinelFlagName(), "", "")
parentCtx := cli.NewContext(app, parentFlags, nil)
set := flag.NewFlagSet(fmt.Sprintf("child-%s", t.Name()), 0)
@ -89,13 +85,7 @@ func TestProvisionerPromptPrompts(t *testing.T) {
})
t.Run("ignore-provisioner-flag", func(t *testing.T) {
clictx := newContext(t)
require.NoError(t, clictx.Set("provisioner", "scep"))
// by ignoring the provisioner flag the prompt should fail, because
// there will be multiple provisioners to select from, which it can't do
// if it can't open a tty to get user input.
provisionerflag.Ignore(clictx)
clictx := newContext(t) // provisioner flag is not set; in reality it'll be unset based on policy level
p1 := &provisioner.OIDC{Name: "oidc", ClientID: "client-id"}
p2 := &provisioner.SCEP{Name: "scep"}