From dfbbf953b55dcaddbac44bd5b38afd73ebc036d8 Mon Sep 17 00:00:00 2001 From: max furman Date: Sun, 26 Jun 2022 21:37:49 -0700 Subject: [PATCH] A few more fixes - set maximum poll time to 5 minutes - avoid panics for missing or invalid values in Authz discover response --- command/oauth/cmd.go | 128 ++++++++++++++++++++++++------------------- 1 file changed, 73 insertions(+), 55 deletions(-) diff --git a/command/oauth/cmd.go b/command/oauth/cmd.go index 651e49d9..2b90107b 100644 --- a/command/oauth/cmd.go +++ b/command/oauth/cmd.go @@ -50,6 +50,7 @@ const ( defaultDeviceAuthzClientID = "1087160488420-1u0jqoulmv3mfomfh6fhkfs4vk4bdjih.apps.googleusercontent.com" defaultDeviceAuthzClientNotSoSecret = "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" defaultDeviceAuthzInterval = 5 + defaultDeviceAuthzExpiresIn = time.Minute * 5 // The URN for getting verification token offline oobCallbackUrn = "urn:ietf:wg:oauth:2.0:oob" @@ -306,6 +307,48 @@ OpenID standard defines the following values, but your provider may support some command.Register(cmd) } +type consoleFlow int + +const ( + oobConsoleFlow consoleFlow = iota + deviceConsoleFlow +) + +type options struct { + Provider string + Email string + Console bool + ConsoleFlow consoleFlow + Implicit bool + CallbackListener string + CallbackListenerURL string + CallbackPath string + TerminalRedirect string + Browser string +} + +// Validate validates the options. +func (o *options) Validate() error { + if o.Provider != "google" && !strings.HasPrefix(o.Provider, "https://") { + return errors.New("use a valid provider: google") + } + if o.CallbackListener != "" { + if _, _, err := net.SplitHostPort(o.CallbackListener); err != nil { + return errors.Wrapf(err, "invalid value '%s' for flag '--listen'", o.CallbackListener) + } + } + if o.CallbackListenerURL != "" { + u, err := url.Parse(o.CallbackListenerURL) + if err != nil || u.Scheme == "" { + return errors.Wrapf(err, "invalid value '%s' for flag '--listen-url'", o.CallbackListenerURL) + } + if u.Path != "" { + o.CallbackPath = u.Path + } + } + return nil +} + func oauthCmd(c *cli.Context) error { opts := &options{ Provider: c.String("provider"), @@ -325,24 +368,26 @@ func oauthCmd(c *cli.Context) error { return errors.New("flag '--client-id' required with '--provider'") } - if c.Bool("oob") && c.Bool("device") { - return errs.MutuallyExclusiveFlags(c, "oob", "device") - } - isOOBFlow, isDeviceFlow := false, false consoleFlowInput := c.String("console-flow") switch { case strings.EqualFold(consoleFlowInput, "device"): + opts.Console = true + opts.ConsoleFlow = deviceConsoleFlow isDeviceFlow = true case strings.EqualFold(consoleFlowInput, "oob"): + opts.Console = true + opts.ConsoleFlow = oobConsoleFlow isOOBFlow = true case c.IsSet("console-flow"): return errs.InvalidFlagValue(c, "console-flow", consoleFlowInput, "device, oob") case c.Bool("console"): if opts.Provider == "google" || strings.HasPrefix(opts.Provider, "https://accounts.google.com") { isOOBFlow = true + opts.ConsoleFlow = oobConsoleFlow } else { isDeviceFlow = true + opts.ConsoleFlow = deviceConsoleFlow } } @@ -498,41 +543,6 @@ func oauthCmd(c *cli.Context) error { return nil } -type options struct { - Provider string - Email string - Console bool - Device bool - Implicit bool - CallbackListener string - CallbackListenerURL string - CallbackPath string - TerminalRedirect string - Browser string -} - -// Validate validates the options. -func (o *options) Validate() error { - if o.Provider != "google" && !strings.HasPrefix(o.Provider, "https://") { - return errors.New("use a valid provider: google") - } - if o.CallbackListener != "" { - if _, _, err := net.SplitHostPort(o.CallbackListener); err != nil { - return errors.Wrapf(err, "invalid value '%s' for flag '--listen'", o.CallbackListener) - } - } - if o.CallbackListenerURL != "" { - u, err := url.Parse(o.CallbackListenerURL) - if err != nil || u.Scheme == "" { - return errors.Wrapf(err, "invalid value '%s' for flag '--listen-url'", o.CallbackListenerURL) - } - if u.Path != "" { - o.CallbackPath = u.Path - } - } - return nil -} - type oauth struct { provider string clientID string @@ -604,28 +614,30 @@ func newOauth(provider, clientID, clientSecret, authzEp, deviceAuthzEp, tokenEp, default: userinfoEp := "" - if (opts.Device && deviceAuthzEp == "" && tokenEp == "") || - (!opts.Device && authzEp == "" && tokenEp == "") { + isDeviceFlow := opts.Console && opts.ConsoleFlow == deviceConsoleFlow + + if (isDeviceFlow && deviceAuthzEp == "" && tokenEp == "") || + (!isDeviceFlow && authzEp == "" && tokenEp == "") { d, err := disco(provider) if err != nil { return nil, err } - if _, ok := d["device_authorization_endpoint"]; !ok && opts.Device { - return nil, errors.New("missing 'device_authorization_endpoint' in provider metadata") + if v, ok := d["device_authorization_endpoint"].(string); !ok && isDeviceFlow { + return nil, errors.New("missing or invalid 'device_authorization_endpoint' in provider metadata") } else if ok { - deviceAuthzEp = d["device_authorization_endpoint"].(string) + deviceAuthzEp = v } - if _, ok := d["authorization_endpoint"]; !ok && !opts.Device { - return nil, errors.New("missing 'authorization_endpoint' in provider metadata") + if v, ok := d["authorization_endpoint"].(string); !ok && !isDeviceFlow { + return nil, errors.New("missing or invalid 'authorization_endpoint' in provider metadata") } else if ok { - authzEp = d["authorization_endpoint"].(string) + authzEp = v } - if _, ok := d["token_endpoint"]; !ok { - return nil, errors.New("missing 'token_endpoint' in provider metadata") + v, ok := d["token_endpoint"].(string) + if !ok { + return nil, errors.New("missing or invalid 'token_endpoint' in provider metadata") } - tokenEp = d["token_endpoint"].(string) - userinfoEp = d["token_endpoint"].(string) + tokenEp, userinfoEp = v, v } return &oauth{ @@ -868,14 +880,20 @@ func (o *oauth) DoDeviceAuthorization() (*token, error) { data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code") data.Set("device_code", idr.DeviceCode) - var tok *token - t := time.NewTimer(time.Duration(idr.ExpiresIn) * time.Second) + endPollIn := defaultDeviceAuthzExpiresIn + if idr.ExpiresIn > 0 { + expiresIn := time.Duration(idr.ExpiresIn) * time.Second + if expiresIn < endPollIn { + endPollIn = expiresIn + } + } + + t := time.NewTimer(endPollIn) defer t.Stop() for { select { case <-time.After(time.Duration(idr.Interval) * time.Second): - tok, err = o.deviceAuthzTokenPoll(data) - if err != nil { + if tok, err := o.deviceAuthzTokenPoll(data); err != nil { return nil, err } else if tok != nil { return tok, nil