1
0
mirror of https://github.com/redis/go-redis.git synced 2025-10-20 09:52:25 +03:00

fix(pool): wip, pool reauth should not interfere with handoff

This commit is contained in:
Nedyalko Dyakov
2025-10-14 21:02:33 +03:00
parent 3ad9f9cb23
commit 5fe0bfa0ff
6 changed files with 158 additions and 22 deletions

View File

@@ -0,0 +1,87 @@
package auth
import (
"time"
"github.com/redis/go-redis/v9/internal/pool"
)
// ConnReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
// It is used to re-authenticate the credentials when they are updated.
// It holds reference to the connection to re-authenticate and will pass it to the reAuth and onErr callbacks.
// It contains:
// - reAuth: a function that takes the new credentials and returns an error if any.
// - onErr: a function that takes an error and handles it.
// - conn: the connection to re-authenticate.
type ConnReAuthCredentialsListener struct {
reAuth func(conn *pool.Conn, credentials Credentials) error
onErr func(conn *pool.Conn, err error)
conn *pool.Conn
}
// OnNext is called when the credentials are updated.
// It calls the reAuth function with the new credentials.
// If the reAuth function returns an error, it calls the onErr function with the error.
func (c *ConnReAuthCredentialsListener) OnNext(credentials Credentials) {
if c.conn.IsClosed() {
return
}
if c.reAuth == nil {
return
}
var err error
timeout := time.After(1 * time.Second)
// wait for the connection to be usable
// this is important because the connection pool may be in the process of reconnecting the connection
// and we don't want to interfere with that process
// but we also don't want to block for too long, so incorporate a timeout
for {
// we were able to mark the connection as unusable
if c.conn.Usable.CompareAndSwap(true, false) {
break
}
select {
case <-timeout:
err = pool.ErrConnUnusableTimeout
break
default:
}
}
if err != nil {
c.OnError(err)
return
}
// we set the usable flag, so restore it back to usable after we're done
defer c.conn.SetUsable(true)
err = c.reAuth(c.conn, credentials)
if err != nil {
c.OnError(err)
}
}
// OnError is called when an error occurs.
// It can be called from both the credentials provider and the reAuth function.
func (c *ConnReAuthCredentialsListener) OnError(err error) {
if c.onErr == nil {
return
}
c.onErr(c.conn, err)
}
// NewConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener.
// Implements the auth.CredentialsListener interface.
func NewConnReAuthCredentialsListener(conn *pool.Conn, reAuth func(conn *pool.Conn, credentials Credentials) error, onErr func(conn *pool.Conn, err error)) *ConnReAuthCredentialsListener {
return &ConnReAuthCredentialsListener{
conn: conn,
reAuth: reAuth,
onErr: onErr,
}
}
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
var _ CredentialsListener = (*ConnReAuthCredentialsListener)(nil)

View File

@@ -44,4 +44,4 @@ func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, on
}
// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)

View File

@@ -112,6 +112,8 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
return false
case context.Canceled, context.DeadlineExceeded:
return true
case pool.ErrConnUnusableTimeout:
return true
}
if isRedisError(err) {

View File

@@ -40,6 +40,9 @@ func generateConnID() uint64 {
}
type Conn struct {
// Connection identifier for unique tracking
id uint64 // Unique numeric identifier for this connection
usedAt int64 // atomic
// Lock-free netConn access using atomic.Value
@@ -54,7 +57,9 @@ type Conn struct {
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
readerMu sync.RWMutex
Inited atomic.Bool
Usable atomic.Bool
Inited atomic.Bool
pooled bool
pubsub bool
closed atomic.Bool
@@ -75,18 +80,14 @@ type Conn struct {
// Connection initialization function for reconnections
initConnFunc func(context.Context, *Conn) error
// Connection identifier for unique tracking
id uint64 // Unique numeric identifier for this connection
// Handoff state - using atomic operations for lock-free access
usableAtomic atomic.Bool // Connection usability state
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
// Atomic handoff state to prevent race conditions
// Stores *HandoffState to ensure atomic updates of all handoff-related fields
handoffStateAtomic atomic.Value // stores *HandoffState
onClose func() error
onClose func() error
}
func NewConn(netConn net.Conn) *Conn {
@@ -116,7 +117,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
// Initialize atomic state
cn.usableAtomic.Store(false) // false initially, set to true after initialization
cn.Usable.Store(false) // false initially, set to true after initialization
cn.handoffRetriesAtomic.Store(0) // 0 initially
// Initialize handoff state atomically
@@ -162,12 +163,12 @@ func (cn *Conn) setNetConn(netConn net.Conn) {
// isUsable returns true if the connection is safe to use (lock-free).
func (cn *Conn) isUsable() bool {
return cn.usableAtomic.Load()
return cn.Usable.Load()
}
// setUsable sets the usable flag atomically (lock-free).
func (cn *Conn) setUsable(usable bool) {
cn.usableAtomic.Store(usable)
cn.Usable.Store(usable)
}
// getHandoffState returns the current handoff state atomically (lock-free).
@@ -456,6 +457,12 @@ func (cn *Conn) MarkQueuedForHandoff() error {
const baseDelay = time.Microsecond
for attempt := 0; attempt < maxRetries; attempt++ {
// first we need to mark the connection as not usable
// to prevent the pool from returning it to the caller
if !cn.Usable.CompareAndSwap(true, false) {
continue
}
currentState := cn.getHandoffState()
// Check if marked for handoff
@@ -472,7 +479,6 @@ func (cn *Conn) MarkQueuedForHandoff() error {
// Atomic compare-and-swap to update state
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
cn.setUsable(false)
return nil
}

View File

@@ -24,6 +24,9 @@ var (
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
ErrPoolTimeout = errors.New("redis: connection pool timeout")
// ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable.
ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable")
// popAttempts is the maximum number of attempts to find a usable connection
// when popping from the idle connection pool. This handles cases where connections
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).

View File

@@ -224,6 +224,9 @@ type baseClient struct {
// Maintenance notifications manager
maintNotificationsManager *maintnotifications.Manager
maintNotificationsManagerLock sync.RWMutex
credListeners map[uint64]auth.CredentialsListener
credListenersLock sync.RWMutex
}
func (c *baseClient) clone() *baseClient {
@@ -237,6 +240,7 @@ func (c *baseClient) clone() *baseClient {
onClose: c.onClose,
pushProcessor: c.pushProcessor,
maintNotificationsManager: maintNotificationsManager,
credListeners: c.credListeners,
}
return clone
}
@@ -296,18 +300,43 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
return auth.NewReAuthCredentialsListener(
c.reAuthConnection(poolCn),
c.onAuthenticationErr(poolCn),
// connReAuthCredentialsListener returns a credentials listener that can be used to re-authenticate the connection.
// The credentials listener is stored in a map, so that it can be reused for multiple connections.
// The credentials listener is removed from the map when the connection is closed.
func (c *baseClient) connReAuthCredentialsListener(poolCn *pool.Conn) (auth.CredentialsListener, func()) {
c.credListenersLock.RLock()
credListener, ok := c.credListeners[poolCn.GetID()]
c.credListenersLock.RUnlock()
if ok {
return credListener.(auth.CredentialsListener), func() {
c.removeCredListener(poolCn)
}
}
c.credListenersLock.Lock()
defer c.credListenersLock.Unlock()
newCredListener := auth.NewConnReAuthCredentialsListener(
poolCn,
c.reAuthConnection(),
c.onAuthenticationErr(),
)
c.credListeners[poolCn.GetID()] = newCredListener
return newCredListener, func() {
c.removeCredListener(poolCn)
}
}
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
return func(credentials auth.Credentials) error {
func (c *baseClient) removeCredListener(poolCn *pool.Conn) {
c.credListenersLock.Lock()
defer c.credListenersLock.Unlock()
delete(c.credListeners, poolCn.GetID())
}
func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error {
return func(poolCn *pool.Conn, credentials auth.Credentials) error {
var err error
username, password := credentials.BasicAuth()
ctx := context.Background()
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
// hooksMixin are intentionally empty here
cn := newConn(c.opt, connPool, nil)
@@ -320,8 +349,8 @@ func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.C
return err
}
}
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
return func(err error) {
func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) {
return func(poolCn *pool.Conn, err error) {
if err != nil {
if isBadConn(err, false, c.opt.Addr) {
// Close the connection to force a reconnection.
@@ -372,13 +401,20 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
username, password := "", ""
if c.opt.StreamingCredentialsProvider != nil {
credListener, removeCredListener := c.connReAuthCredentialsListener(cn)
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
Subscribe(c.newReAuthCredentialsListener(cn))
Subscribe(credListener)
if err != nil {
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
}
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
cn.SetOnClose(unsubscribeFromCredentialsProvider)
unsubscribe := func() error {
removeCredListener()
return unsubscribeFromCredentialsProvider()
}
c.onClose = c.wrappedOnClose(unsubscribe)
cn.SetOnClose(unsubscribe)
username, password = credentials.BasicAuth()
} else if c.opt.CredentialsProviderContext != nil {
username, password, err = c.opt.CredentialsProviderContext(ctx)
@@ -496,6 +532,8 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
}
}
// mark the connection as usable and inited
// once returned to the pool as idle, this connection can be used by other clients
cn.SetUsable(true)
cn.Inited.Store(true)