diff --git a/auth/auth.go b/auth/auth.go index dcfd09eb..ae9310e0 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -9,6 +9,7 @@ type StreamingCredentialsProvider interface { // Subscribe subscribes to the credentials provider for updates. // It returns the current credentials, a cancel function to unsubscribe from the provider, // and an error if any. + // TODO(ndyakov): Should we add context to the Subscribe method? Subscribe(listener CredentialsListener) (Credentials, CancelProviderFunc, error) } diff --git a/auth/reauth_credentials_listener.go b/auth/reauth_credentials_listener.go new file mode 100644 index 00000000..12eb2956 --- /dev/null +++ b/auth/reauth_credentials_listener.go @@ -0,0 +1,45 @@ +package auth + +// ReAuthCredentialsListener is a struct that implements the CredentialsListener interface. +// It is used to re-authenticate the credentials when they are updated. +// It contains: +// - reAuth: a function that takes the new credentials and returns an error if any. +// - onErr: a function that takes an error and handles it. +type ReAuthCredentialsListener struct { + reAuth func(credentials Credentials) error + onErr func(err error) +} + +// OnNext is called when the credentials are updated. +// It calls the reAuth function with the new credentials. +// If the reAuth function returns an error, it calls the onErr function with the error. +func (c *ReAuthCredentialsListener) OnNext(credentials Credentials) { + if c.reAuth != nil { + err := c.reAuth(credentials) + if err != nil { + if c.onErr != nil { + c.onErr(err) + } + } + } +} + +// OnError is called when an error occurs. +// It can be called from both the credentials provider and the reAuth function. +func (c *ReAuthCredentialsListener) OnError(err error) { + if c.onErr != nil { + c.onErr(err) + } +} + +// NewReAuthCredentialsListener creates a new ReAuthCredentialsListener. +// Implements the auth.CredentialsListener interface. +func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, onErr func(err error)) *ReAuthCredentialsListener { + return &ReAuthCredentialsListener{ + reAuth: reAuth, + onErr: onErr, + } +} + +// Ensure ReAuthCredentialsListener implements the CredentialsListener interface. +var _ CredentialsListener = (*ReAuthCredentialsListener)(nil) diff --git a/internal_test.go b/internal_test.go index a6317196..c2cbff70 100644 --- a/internal_test.go +++ b/internal_test.go @@ -212,10 +212,10 @@ func TestRingShardsCleanup(t *testing.T) { }, NewClient: func(opt *Options) *Client { c := NewClient(opt) - c.baseClient.onClose = func() error { + c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error { closeCounter.increment(opt.Addr) return nil - } + }) return c }, }) @@ -261,10 +261,10 @@ func TestRingShardsCleanup(t *testing.T) { } createCounter.increment(opt.Addr) c := NewClient(opt) - c.baseClient.onClose = func() error { + c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error { closeCounter.increment(opt.Addr) return nil - } + }) return c }, }) diff --git a/redis.go b/redis.go index 2026ff9d..94de3fc7 100644 --- a/redis.go +++ b/redis.go @@ -283,15 +283,57 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } -func (c *baseClient) reAuth(ctx context.Context, cn *Conn, credentials auth.Credentials) error { - var err error - username, password := credentials.BasicAuth() - if username != "" { - err = cn.AuthACL(ctx, username, password).Err() - } else { - err = cn.Auth(ctx, password).Err() +func (c *baseClient) newReAuthCredentialsListener(ctx context.Context, conn *Conn) auth.CredentialsListener { + return auth.NewReAuthCredentialsListener( + c.reAuthConnection(c.context(ctx), conn), + c.onAuthenticationErr(c.context(ctx), conn), + ) +} + +func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(credentials auth.Credentials) error { + return func(credentials auth.Credentials) error { + var err error + username, password := credentials.BasicAuth() + if username != "" { + err = cn.AuthACL(ctx, username, password).Err() + } else { + err = cn.Auth(ctx, password).Err() + } + return err + } +} +func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err error) { + return func(err error) { + // since the connection pool of the *Conn will actually return us the underlying pool.Conn, + // we can get it from the *Conn and remove it from the clients pool. + if err != nil { + if isBadConn(err, false, c.opt.Addr) { + poolCn, _ := cn.connPool.Get(ctx) + c.connPool.Remove(ctx, poolCn, err) + } + } + } +} + +func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error { + onClose := c.onClose + return func() error { + var firstErr error + err := newOnClose() + // Even if we have an error we would like to execute the onClose hook + // if it exists. We will return the first error that occurred. + // This is to keep error handling consistent with the rest of the code. + if err != nil { + firstErr = err + } + if onClose != nil { + err = onClose() + if err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr } - return err } func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { @@ -312,7 +354,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { var authenticated bool username, password := c.opt.Username, c.opt.Password - if c.opt.CredentialsProviderContext != nil { + if c.opt.StreamingCredentialsProvider != nil { + credentials, cancelCredentialsProvider, err := c.opt.StreamingCredentialsProvider. + Subscribe(c.newReAuthCredentialsListener(ctx, conn)) + if err != nil { + return err + } + c.onClose = c.wrappedOnClose(cancelCredentialsProvider) + username, password = credentials.BasicAuth() + } else if c.opt.CredentialsProviderContext != nil { if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil { return err } @@ -336,7 +386,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } if !authenticated && password != "" { - err = c.reAuth(ctx, conn, auth.NewBasicCredentials(username, password)) + err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password)) if err != nil { return err } diff --git a/sentinel.go b/sentinel.go index a4c9f53c..55346735 100644 --- a/sentinel.go +++ b/sentinel.go @@ -257,7 +257,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { connPool = newConnPool(opt, rdb.dialHook) rdb.connPool = connPool - rdb.onClose = failover.Close + rdb.onClose = rdb.wrappedOnClose(failover.Close) failover.mu.Lock() failover.onFailover = func(ctx context.Context, addr string) {