mirror of
https://github.com/redis/go-redis.git
synced 2025-07-28 06:42:00 +03:00
feat: Introducing StreamingCredentialsProvider for token based authentication (#3320)
* wip * update documentation * add streamingcredentialsprovider in options * fix: put back option in pool creation * add package level comment * Initial re authentication implementation Introduces the StreamingCredentialsProvider as the CredentialsProvider with the highest priority. TODO: needs to be tested * Change function type name Change CancelProviderFunc to UnsubscribeFunc * add tests * fix race in tests * fix example tests * wip, hooks refactor * fix build * update README.md * update wordlist * update README.md * refactor(auth): early returns in cred listener * fix(doctest): simulate some delay * feat(conn): add close hook on conn * fix(tests): simulate start/stop in mock credentials provider * fix(auth): don't double close the conn * docs(README): mark streaming credentials provider as experimental * fix(auth): streamline auth err proccess * fix(auth): check err on close conn * chore(entraid): use the repo under redis org
This commit is contained in:
152
redis.go
152
redis.go
@ -9,6 +9,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/hscan"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
@ -203,6 +204,7 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
|
||||
type baseClient struct {
|
||||
opt *Options
|
||||
connPool pool.Pooler
|
||||
hooksMixin
|
||||
|
||||
onClose func() error // hook called when client is closed
|
||||
}
|
||||
@ -282,30 +284,107 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
|
||||
return auth.NewReAuthCredentialsListener(
|
||||
c.reAuthConnection(poolCn),
|
||||
c.onAuthenticationErr(poolCn),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
|
||||
return func(credentials auth.Credentials) error {
|
||||
var err error
|
||||
username, password := credentials.BasicAuth()
|
||||
ctx := context.Background()
|
||||
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
|
||||
// hooksMixin are intentionally empty here
|
||||
cn := newConn(c.opt, connPool, nil)
|
||||
|
||||
if username != "" {
|
||||
err = cn.AuthACL(ctx, username, password).Err()
|
||||
} else {
|
||||
err = cn.Auth(ctx, password).Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
|
||||
return func(err error) {
|
||||
if err != nil {
|
||||
if isBadConn(err, false, c.opt.Addr) {
|
||||
// Close the connection to force a reconnection.
|
||||
err := c.connPool.CloseConn(poolCn)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(context.Background(), "redis: failed to close connection: %v", err)
|
||||
// try to close the network connection directly
|
||||
// so that no resource is leaked
|
||||
err := poolCn.Close()
|
||||
if err != nil {
|
||||
internal.Logger.Printf(context.Background(), "redis: failed to close network connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
|
||||
onClose := c.onClose
|
||||
return func() error {
|
||||
var firstErr error
|
||||
err := newOnClose()
|
||||
// Even if we have an error we would like to execute the onClose hook
|
||||
// if it exists. We will return the first error that occurred.
|
||||
// This is to keep error handling consistent with the rest of the code.
|
||||
if err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
if onClose != nil {
|
||||
err = onClose()
|
||||
if err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
if cn.Inited {
|
||||
return nil
|
||||
}
|
||||
cn.Inited = true
|
||||
|
||||
var err error
|
||||
username, password := c.opt.Username, c.opt.Password
|
||||
if c.opt.CredentialsProviderContext != nil {
|
||||
if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil {
|
||||
return err
|
||||
cn.Inited = true
|
||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||
conn := newConn(c.opt, connPool, &c.hooksMixin)
|
||||
|
||||
username, password := "", ""
|
||||
if c.opt.StreamingCredentialsProvider != nil {
|
||||
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||
Subscribe(c.newReAuthCredentialsListener(cn))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
||||
}
|
||||
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
|
||||
cn.SetOnClose(unsubscribeFromCredentialsProvider)
|
||||
username, password = credentials.BasicAuth()
|
||||
} else if c.opt.CredentialsProviderContext != nil {
|
||||
username, password, err = c.opt.CredentialsProviderContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get credentials from context provider: %w", err)
|
||||
}
|
||||
} else if c.opt.CredentialsProvider != nil {
|
||||
username, password = c.opt.CredentialsProvider()
|
||||
} else if c.opt.Username != "" || c.opt.Password != "" {
|
||||
username, password = c.opt.Username, c.opt.Password
|
||||
}
|
||||
|
||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||
conn := newConn(c.opt, connPool)
|
||||
|
||||
var auth bool
|
||||
// for redis-server versions that do not support the HELLO command,
|
||||
// RESP2 will continue to be used.
|
||||
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
||||
auth = true
|
||||
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
||||
// Authentication successful with HELLO command
|
||||
} else if !isRedisError(err) {
|
||||
// When the server responds with the RESP protocol and the result is not a normal
|
||||
// execution result of the HELLO command, we consider it to be an indication that
|
||||
@ -315,17 +394,19 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
// with different error string results for unsupported commands, making it
|
||||
// difficult to rely on error strings to determine all results.
|
||||
return err
|
||||
} else if password != "" {
|
||||
// Try legacy AUTH command if HELLO failed
|
||||
if username != "" {
|
||||
err = conn.AuthACL(ctx, username, password).Err()
|
||||
} else {
|
||||
err = conn.Auth(ctx, password).Err()
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
||||
if !auth && password != "" {
|
||||
if username != "" {
|
||||
pipe.AuthACL(ctx, username, password)
|
||||
} else {
|
||||
pipe.Auth(ctx, password)
|
||||
}
|
||||
}
|
||||
|
||||
if c.opt.DB > 0 {
|
||||
pipe.Select(ctx, c.opt.DB)
|
||||
}
|
||||
@ -341,7 +422,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to initialize connection options: %w", err)
|
||||
}
|
||||
|
||||
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
|
||||
@ -363,6 +444,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
if c.opt.OnConnect != nil {
|
||||
return c.opt.OnConnect(ctx, conn)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -481,6 +563,16 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
|
||||
return c.opt.ReadTimeout
|
||||
}
|
||||
|
||||
// context returns the context for the current connection.
|
||||
// If the context timeout is enabled, it returns the original context.
|
||||
// Otherwise, it returns a new background context.
|
||||
func (c *baseClient) context(ctx context.Context) context.Context {
|
||||
if c.opt.ContextTimeoutEnabled {
|
||||
return ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// Close closes the client, releasing any open resources.
|
||||
//
|
||||
// It is rare to Close a Client, as the Client is meant to be
|
||||
@ -633,13 +725,6 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *baseClient) context(ctx context.Context) context.Context {
|
||||
if c.opt.ContextTimeoutEnabled {
|
||||
return ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Client is a Redis client representing a pool of zero or more underlying connections.
|
||||
@ -650,7 +735,6 @@ func (c *baseClient) context(ctx context.Context) context.Context {
|
||||
type Client struct {
|
||||
*baseClient
|
||||
cmdable
|
||||
hooksMixin
|
||||
}
|
||||
|
||||
// NewClient returns a client to the Redis Server specified by Options.
|
||||
@ -689,7 +773,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
|
||||
}
|
||||
|
||||
func (c *Client) Conn() *Conn {
|
||||
return newConn(c.opt, pool.NewStickyConnPool(c.connPool))
|
||||
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin)
|
||||
}
|
||||
|
||||
// Do create a Cmd from the args and processes the cmd.
|
||||
@ -822,10 +906,12 @@ type Conn struct {
|
||||
baseClient
|
||||
cmdable
|
||||
statefulCmdable
|
||||
hooksMixin
|
||||
}
|
||||
|
||||
func newConn(opt *Options, connPool pool.Pooler) *Conn {
|
||||
// newConn is a helper func to create a new Conn instance.
|
||||
// the Conn instance is not thread-safe and should not be shared between goroutines.
|
||||
// the parentHooks will be cloned, no need to clone before passing it.
|
||||
func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn {
|
||||
c := Conn{
|
||||
baseClient: baseClient{
|
||||
opt: opt,
|
||||
@ -833,6 +919,10 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn {
|
||||
},
|
||||
}
|
||||
|
||||
if parentHooks != nil {
|
||||
c.hooksMixin = parentHooks.clone()
|
||||
}
|
||||
|
||||
c.cmdable = c.Process
|
||||
c.statefulCmdable = c.Process
|
||||
c.initHooks(hooks{
|
||||
|
Reference in New Issue
Block a user