1
0
mirror of https://github.com/smallstep/cli.git synced 2025-08-09 03:22:43 +03:00

A few more fixes

- set maximum poll time to 5 minutes
- avoid panics for missing or invalid values in Authz discover response
This commit is contained in:
max furman
2022-06-26 21:37:49 -07:00
parent 8d1b0cb938
commit dfbbf953b5

View File

@@ -50,6 +50,7 @@ const (
defaultDeviceAuthzClientID = "1087160488420-1u0jqoulmv3mfomfh6fhkfs4vk4bdjih.apps.googleusercontent.com" defaultDeviceAuthzClientID = "1087160488420-1u0jqoulmv3mfomfh6fhkfs4vk4bdjih.apps.googleusercontent.com"
defaultDeviceAuthzClientNotSoSecret = "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" defaultDeviceAuthzClientNotSoSecret = "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"
defaultDeviceAuthzInterval = 5 defaultDeviceAuthzInterval = 5
defaultDeviceAuthzExpiresIn = time.Minute * 5
// The URN for getting verification token offline // The URN for getting verification token offline
oobCallbackUrn = "urn:ietf:wg:oauth:2.0:oob" oobCallbackUrn = "urn:ietf:wg:oauth:2.0:oob"
@@ -306,6 +307,48 @@ OpenID standard defines the following values, but your provider may support some
command.Register(cmd) command.Register(cmd)
} }
type consoleFlow int
const (
oobConsoleFlow consoleFlow = iota
deviceConsoleFlow
)
type options struct {
Provider string
Email string
Console bool
ConsoleFlow consoleFlow
Implicit bool
CallbackListener string
CallbackListenerURL string
CallbackPath string
TerminalRedirect string
Browser string
}
// Validate validates the options.
func (o *options) Validate() error {
if o.Provider != "google" && !strings.HasPrefix(o.Provider, "https://") {
return errors.New("use a valid provider: google")
}
if o.CallbackListener != "" {
if _, _, err := net.SplitHostPort(o.CallbackListener); err != nil {
return errors.Wrapf(err, "invalid value '%s' for flag '--listen'", o.CallbackListener)
}
}
if o.CallbackListenerURL != "" {
u, err := url.Parse(o.CallbackListenerURL)
if err != nil || u.Scheme == "" {
return errors.Wrapf(err, "invalid value '%s' for flag '--listen-url'", o.CallbackListenerURL)
}
if u.Path != "" {
o.CallbackPath = u.Path
}
}
return nil
}
func oauthCmd(c *cli.Context) error { func oauthCmd(c *cli.Context) error {
opts := &options{ opts := &options{
Provider: c.String("provider"), Provider: c.String("provider"),
@@ -325,24 +368,26 @@ func oauthCmd(c *cli.Context) error {
return errors.New("flag '--client-id' required with '--provider'") return errors.New("flag '--client-id' required with '--provider'")
} }
if c.Bool("oob") && c.Bool("device") {
return errs.MutuallyExclusiveFlags(c, "oob", "device")
}
isOOBFlow, isDeviceFlow := false, false isOOBFlow, isDeviceFlow := false, false
consoleFlowInput := c.String("console-flow") consoleFlowInput := c.String("console-flow")
switch { switch {
case strings.EqualFold(consoleFlowInput, "device"): case strings.EqualFold(consoleFlowInput, "device"):
opts.Console = true
opts.ConsoleFlow = deviceConsoleFlow
isDeviceFlow = true isDeviceFlow = true
case strings.EqualFold(consoleFlowInput, "oob"): case strings.EqualFold(consoleFlowInput, "oob"):
opts.Console = true
opts.ConsoleFlow = oobConsoleFlow
isOOBFlow = true isOOBFlow = true
case c.IsSet("console-flow"): case c.IsSet("console-flow"):
return errs.InvalidFlagValue(c, "console-flow", consoleFlowInput, "device, oob") return errs.InvalidFlagValue(c, "console-flow", consoleFlowInput, "device, oob")
case c.Bool("console"): case c.Bool("console"):
if opts.Provider == "google" || strings.HasPrefix(opts.Provider, "https://accounts.google.com") { if opts.Provider == "google" || strings.HasPrefix(opts.Provider, "https://accounts.google.com") {
isOOBFlow = true isOOBFlow = true
opts.ConsoleFlow = oobConsoleFlow
} else { } else {
isDeviceFlow = true isDeviceFlow = true
opts.ConsoleFlow = deviceConsoleFlow
} }
} }
@@ -498,41 +543,6 @@ func oauthCmd(c *cli.Context) error {
return nil return nil
} }
type options struct {
Provider string
Email string
Console bool
Device bool
Implicit bool
CallbackListener string
CallbackListenerURL string
CallbackPath string
TerminalRedirect string
Browser string
}
// Validate validates the options.
func (o *options) Validate() error {
if o.Provider != "google" && !strings.HasPrefix(o.Provider, "https://") {
return errors.New("use a valid provider: google")
}
if o.CallbackListener != "" {
if _, _, err := net.SplitHostPort(o.CallbackListener); err != nil {
return errors.Wrapf(err, "invalid value '%s' for flag '--listen'", o.CallbackListener)
}
}
if o.CallbackListenerURL != "" {
u, err := url.Parse(o.CallbackListenerURL)
if err != nil || u.Scheme == "" {
return errors.Wrapf(err, "invalid value '%s' for flag '--listen-url'", o.CallbackListenerURL)
}
if u.Path != "" {
o.CallbackPath = u.Path
}
}
return nil
}
type oauth struct { type oauth struct {
provider string provider string
clientID string clientID string
@@ -604,28 +614,30 @@ func newOauth(provider, clientID, clientSecret, authzEp, deviceAuthzEp, tokenEp,
default: default:
userinfoEp := "" userinfoEp := ""
if (opts.Device && deviceAuthzEp == "" && tokenEp == "") || isDeviceFlow := opts.Console && opts.ConsoleFlow == deviceConsoleFlow
(!opts.Device && authzEp == "" && tokenEp == "") {
if (isDeviceFlow && deviceAuthzEp == "" && tokenEp == "") ||
(!isDeviceFlow && authzEp == "" && tokenEp == "") {
d, err := disco(provider) d, err := disco(provider)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if _, ok := d["device_authorization_endpoint"]; !ok && opts.Device { if v, ok := d["device_authorization_endpoint"].(string); !ok && isDeviceFlow {
return nil, errors.New("missing 'device_authorization_endpoint' in provider metadata") return nil, errors.New("missing or invalid 'device_authorization_endpoint' in provider metadata")
} else if ok { } else if ok {
deviceAuthzEp = d["device_authorization_endpoint"].(string) deviceAuthzEp = v
} }
if _, ok := d["authorization_endpoint"]; !ok && !opts.Device { if v, ok := d["authorization_endpoint"].(string); !ok && !isDeviceFlow {
return nil, errors.New("missing 'authorization_endpoint' in provider metadata") return nil, errors.New("missing or invalid 'authorization_endpoint' in provider metadata")
} else if ok { } else if ok {
authzEp = d["authorization_endpoint"].(string) authzEp = v
} }
if _, ok := d["token_endpoint"]; !ok { v, ok := d["token_endpoint"].(string)
return nil, errors.New("missing 'token_endpoint' in provider metadata") if !ok {
return nil, errors.New("missing or invalid 'token_endpoint' in provider metadata")
} }
tokenEp = d["token_endpoint"].(string) tokenEp, userinfoEp = v, v
userinfoEp = d["token_endpoint"].(string)
} }
return &oauth{ return &oauth{
@@ -868,14 +880,20 @@ func (o *oauth) DoDeviceAuthorization() (*token, error) {
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code") data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
data.Set("device_code", idr.DeviceCode) data.Set("device_code", idr.DeviceCode)
var tok *token endPollIn := defaultDeviceAuthzExpiresIn
t := time.NewTimer(time.Duration(idr.ExpiresIn) * time.Second) if idr.ExpiresIn > 0 {
expiresIn := time.Duration(idr.ExpiresIn) * time.Second
if expiresIn < endPollIn {
endPollIn = expiresIn
}
}
t := time.NewTimer(endPollIn)
defer t.Stop() defer t.Stop()
for { for {
select { select {
case <-time.After(time.Duration(idr.Interval) * time.Second): case <-time.After(time.Duration(idr.Interval) * time.Second):
tok, err = o.deviceAuthzTokenPoll(data) if tok, err := o.deviceAuthzTokenPoll(data); err != nil {
if err != nil {
return nil, err return nil, err
} else if tok != nil { } else if tok != nil {
return tok, nil return tok, nil