From 5fe0bfa0ffe820a9c44281c15b2725dc7fb80139 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 14 Oct 2025 21:02:33 +0300 Subject: [PATCH] fix(pool): wip, pool reauth should not interfere with handoff --- auth/conn_reauth_credentials_listener.go | 87 ++++++++++++++++++++++++ auth/reauth_credentials_listener.go | 2 +- error.go | 2 + internal/pool/conn.go | 26 ++++--- internal/pool/pool.go | 3 + redis.go | 60 +++++++++++++--- 6 files changed, 158 insertions(+), 22 deletions(-) create mode 100644 auth/conn_reauth_credentials_listener.go diff --git a/auth/conn_reauth_credentials_listener.go b/auth/conn_reauth_credentials_listener.go new file mode 100644 index 00000000..c111c7c4 --- /dev/null +++ b/auth/conn_reauth_credentials_listener.go @@ -0,0 +1,87 @@ +package auth + +import ( + "time" + + "github.com/redis/go-redis/v9/internal/pool" +) + +// ConnReAuthCredentialsListener is a struct that implements the CredentialsListener interface. +// It is used to re-authenticate the credentials when they are updated. +// It holds reference to the connection to re-authenticate and will pass it to the reAuth and onErr callbacks. +// It contains: +// - reAuth: a function that takes the new credentials and returns an error if any. +// - onErr: a function that takes an error and handles it. +// - conn: the connection to re-authenticate. +type ConnReAuthCredentialsListener struct { + reAuth func(conn *pool.Conn, credentials Credentials) error + onErr func(conn *pool.Conn, err error) + conn *pool.Conn +} + +// OnNext is called when the credentials are updated. +// It calls the reAuth function with the new credentials. +// If the reAuth function returns an error, it calls the onErr function with the error. 
+func (c *ConnReAuthCredentialsListener) OnNext(credentials Credentials) {
+	if c.conn.IsClosed() {
+		return
+	}
+
+	if c.reAuth == nil {
+		return
+	}
+
+	// Claim the connection by flipping Usable from true to false before
+	// re-authenticating. Something else may currently hold the flag (e.g. a
+	// pending handoff), so keep retrying, but give up after one second rather
+	// than block the caller indefinitely.
+	timeout := time.After(1 * time.Second)
+
+	// NOTE(review): the retry polls without backoff, so it can spin for the
+	// whole timeout - consider yielding between attempts.
+	for !c.conn.Usable.CompareAndSwap(true, false) {
+		select {
+		case <-timeout:
+			// Report and return from here. A bare break would only leave the
+			// select, and because the timeout channel delivers a single value
+			// the loop would then spin forever while the flag stays false.
+			c.OnError(pool.ErrConnUnusableTimeout)
+			return
+		default:
+			// not claimed yet and still within the timeout: try again
+		}
+	}
+
+	// We cleared the usable flag above, so restore it once re-authentication
+	// is done, whether it succeeded or not.
+	defer c.conn.SetUsable(true)
+
+	// Re-authenticate on the claimed connection; failures go to OnError.
+	err := c.reAuth(c.conn, credentials)
+	if err != nil {
+		c.OnError(err)
+	}
+}
+
+// OnError forwards err, together with the listener's connection, to the onErr callback.
+// It is a no-op when no onErr callback was supplied. OnNext calls it on timeout and on reAuth failure.
+func (c *ConnReAuthCredentialsListener) OnError(err error) {
+	if c.onErr == nil {
+		return
+	}
+
+	c.onErr(c.conn, err)
+}
+
+// NewConnReAuthCredentialsListener creates a ConnReAuthCredentialsListener bound to conn.
+// reAuth and onErr may be nil; the listener then skips the corresponding step.
+func NewConnReAuthCredentialsListener(conn *pool.Conn, reAuth func(conn *pool.Conn, credentials Credentials) error, onErr func(conn *pool.Conn, err error)) *ConnReAuthCredentialsListener {
+	return &ConnReAuthCredentialsListener{
+		conn:   conn,
+		reAuth: reAuth,
+		onErr:  onErr,
+	}
+}
+
+// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
+var _ CredentialsListener = (*ConnReAuthCredentialsListener)(nil) diff --git a/auth/reauth_credentials_listener.go b/auth/reauth_credentials_listener.go index 40076a0b..f4b31983 100644 --- a/auth/reauth_credentials_listener.go +++ b/auth/reauth_credentials_listener.go @@ -44,4 +44,4 @@ func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, on } // Ensure ReAuthCredentialsListener implements the CredentialsListener interface. -var _ CredentialsListener = (*ReAuthCredentialsListener)(nil) +var _ CredentialsListener = (*ReAuthCredentialsListener)(nil) \ No newline at end of file diff --git a/error.go b/error.go index 8013de44..be9cf1a2 100644 --- a/error.go +++ b/error.go @@ -112,6 +112,8 @@ func isBadConn(err error, allowTimeout bool, addr string) bool { return false case context.Canceled, context.DeadlineExceeded: return true + case pool.ErrConnUnusableTimeout: + return true } if isRedisError(err) { diff --git a/internal/pool/conn.go b/internal/pool/conn.go index e4780546..e2f6f8f3 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -40,6 +40,9 @@ func generateConnID() uint64 { } type Conn struct { + // Connection identifier for unique tracking + id uint64 // Unique numeric identifier for this connection + usedAt int64 // atomic // Lock-free netConn access using atomic.Value @@ -54,7 +57,9 @@ type Conn struct { // Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe readerMu sync.RWMutex - Inited atomic.Bool + Usable atomic.Bool + Inited atomic.Bool + pooled bool pubsub bool closed atomic.Bool @@ -75,18 +80,14 @@ type Conn struct { // Connection initialization function for reconnections initConnFunc func(context.Context, *Conn) error - // Connection identifier for unique tracking - id uint64 // Unique numeric identifier for this connection - // Handoff state - using atomic operations for lock-free access - usableAtomic atomic.Bool // Connection usability state handoffRetriesAtomic atomic.Uint32 
// Retry counter for handoff attempts // Atomic handoff state to prevent race conditions // Stores *HandoffState to ensure atomic updates of all handoff-related fields handoffStateAtomic atomic.Value // stores *HandoffState - onClose func() error + onClose func() error } func NewConn(netConn net.Conn) *Conn { @@ -116,7 +117,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) // Initialize atomic state - cn.usableAtomic.Store(false) // false initially, set to true after initialization + cn.Usable.Store(false) // false initially, set to true after initialization cn.handoffRetriesAtomic.Store(0) // 0 initially // Initialize handoff state atomically @@ -162,12 +163,12 @@ func (cn *Conn) setNetConn(netConn net.Conn) { // isUsable returns true if the connection is safe to use (lock-free). func (cn *Conn) isUsable() bool { - return cn.usableAtomic.Load() + return cn.Usable.Load() } // setUsable sets the usable flag atomically (lock-free). func (cn *Conn) setUsable(usable bool) { - cn.usableAtomic.Store(usable) + cn.Usable.Store(usable) } // getHandoffState returns the current handoff state atomically (lock-free). 
@@ -456,6 +457,12 @@ func (cn *Conn) MarkQueuedForHandoff() error { const baseDelay = time.Microsecond for attempt := 0; attempt < maxRetries; attempt++ { + // first we need to mark the connection as not usable + // to prevent the pool from returning it to the caller + if !cn.Usable.CompareAndSwap(true, false) { + continue + } + currentState := cn.getHandoffState() // Check if marked for handoff @@ -472,7 +479,6 @@ func (cn *Conn) MarkQueuedForHandoff() error { // Atomic compare-and-swap to update state if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) { - cn.setUsable(false) return nil } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 88d8105e..83acb6ff 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -24,6 +24,9 @@ var ( // ErrPoolTimeout timed out waiting to get a connection from the connection pool. ErrPoolTimeout = errors.New("redis: connection pool timeout") + // ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable. + ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable") + // popAttempts is the maximum number of attempts to find a usable connection // when popping from the idle connection pool. This handles cases where connections // are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues). 
diff --git a/redis.go b/redis.go index b308263e..cf522a06 100644 --- a/redis.go +++ b/redis.go @@ -224,6 +224,9 @@ type baseClient struct { // Maintenance notifications manager maintNotificationsManager *maintnotifications.Manager maintNotificationsManagerLock sync.RWMutex + + credListeners map[uint64]auth.CredentialsListener + credListenersLock sync.RWMutex } func (c *baseClient) clone() *baseClient { @@ -237,6 +240,7 @@ func (c *baseClient) clone() *baseClient { onClose: c.onClose, pushProcessor: c.pushProcessor, maintNotificationsManager: maintNotificationsManager, + credListeners: c.credListeners, } return clone } @@ -296,18 +300,43 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } -func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener { - return auth.NewReAuthCredentialsListener( - c.reAuthConnection(poolCn), - c.onAuthenticationErr(poolCn), +// connReAuthCredentialsListener returns a credentials listener that can be used to re-authenticate the connection. +// The credentials listener is stored in a map, so that it can be reused for multiple connections. +// The credentials listener is removed from the map when the connection is closed. 
+func (c *baseClient) connReAuthCredentialsListener(poolCn *pool.Conn) (auth.CredentialsListener, func()) {
+	remove := func() { c.removeCredListener(poolCn) }
+	// Look up and insert under one write lock: with a separate read-locked
+	// lookup, two callers could both miss and then overwrite each other.
+	c.credListenersLock.Lock()
+	defer c.credListenersLock.Unlock()
+	if credListener, ok := c.credListeners[poolCn.GetID()]; ok {
+		return credListener, remove
+	}
+	// Create the map on first use; assigning into a nil map panics.
+	if c.credListeners == nil {
+		c.credListeners = make(map[uint64]auth.CredentialsListener)
+	}
+	newCredListener := auth.NewConnReAuthCredentialsListener(
+		poolCn,
+		c.reAuthConnection(),
+		c.onAuthenticationErr(),
 	)
+	c.credListeners[poolCn.GetID()] = newCredListener
+	return newCredListener, remove
 }
 
-func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
-	return func(credentials auth.Credentials) error {
+func (c *baseClient) removeCredListener(poolCn *pool.Conn) {
+	c.credListenersLock.Lock()
+	defer c.credListenersLock.Unlock()
+	delete(c.credListeners, poolCn.GetID())
+}
+
+func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error {
+	return func(poolCn *pool.Conn, credentials auth.Credentials) error {
 		var err error
 		username, password := credentials.BasicAuth()
 		ctx := context.Background()
+
 		connPool := pool.NewSingleConnPool(c.connPool, poolCn)
 		// hooksMixin are intentionally empty here
 		cn := newConn(c.opt, connPool, nil)
@@ -320,8 +349,8 @@ func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.C
 			return err
 		}
 	}
-func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
-	return func(err error) {
+func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) {
+	return func(poolCn *pool.Conn, err error) {
 		if err != nil {
 			if isBadConn(err, false, c.opt.Addr) {
 				// Close the connection to force a reconnection.
@@ -372,13 +401,20 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { username, password := "", "" if c.opt.StreamingCredentialsProvider != nil { + credListener, removeCredListener := c.connReAuthCredentialsListener(cn) credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider. - Subscribe(c.newReAuthCredentialsListener(cn)) + Subscribe(credListener) if err != nil { return fmt.Errorf("failed to subscribe to streaming credentials: %w", err) } - c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider) - cn.SetOnClose(unsubscribeFromCredentialsProvider) + + unsubscribe := func() error { + removeCredListener() + return unsubscribeFromCredentialsProvider() + } + c.onClose = c.wrappedOnClose(unsubscribe) + cn.SetOnClose(unsubscribe) + username, password = credentials.BasicAuth() } else if c.opt.CredentialsProviderContext != nil { username, password, err = c.opt.CredentialsProviderContext(ctx) @@ -496,6 +532,8 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } } + // mark the connection as usable and inited + // once returned to the pool as idle, this connection can be used by other clients cn.SetUsable(true) cn.Inited.Store(true)