1
0
mirror of https://github.com/redis/go-redis.git synced 2025-10-24 19:32:57 +03:00

address pr comments

This commit is contained in:
Nedyalko Dyakov
2025-10-21 18:01:14 +03:00
parent 528f2e92a9
commit 1ee52937e2
9 changed files with 142 additions and 93 deletions

View File

@@ -27,11 +27,7 @@ type ConnReAuthCredentialsListener struct {
// It calls the reAuth function with the new credentials. // It calls the reAuth function with the new credentials.
// If the reAuth function returns an error, it calls the onErr function with the error. // If the reAuth function returns an error, it calls the onErr function with the error.
func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) { func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
if c.conn.IsClosed() { if c.conn == nil || c.conn.IsClosed() || c.manager == nil || c.reAuth == nil {
return
}
if c.reAuth == nil {
return return
} }
@@ -41,17 +37,20 @@ func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
// The connection pool hook will re-authenticate the connection when it is // The connection pool hook will re-authenticate the connection when it is
// returned to the pool in a clean, idle state. // returned to the pool in a clean, idle state.
c.manager.MarkForReAuth(c.conn, func(err error) { c.manager.MarkForReAuth(c.conn, func(err error) {
// err is from connection acquisition (timeout, etc.)
if err != nil { if err != nil {
// Log the error
c.OnError(err) c.OnError(err)
return return
} }
// err is from reauth command execution
err = c.reAuth(c.conn, credentials) err = c.reAuth(c.conn, credentials)
if err != nil { if err != nil {
// Log the error
c.OnError(err) c.OnError(err)
return return
} }
}) })
} }
// OnError is called when an error occurs. // OnError is called when an error occurs.

View File

@@ -15,11 +15,13 @@ type Manager struct {
} }
func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager { func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager {
return &Manager{ m := &Manager{
pool: pl, pool: pl,
poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout), poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout),
credentialsListeners: NewCredentialsListeners(), credentialsListeners: NewCredentialsListeners(),
} }
m.poolHookRef.manager = m
return m
} }
func (m *Manager) PoolHook() pool.PoolHook { func (m *Manager) PoolHook() pool.PoolHook {
@@ -35,6 +37,10 @@ func (m *Manager) Listener(
return nil, errors.New("poolCn cannot be nil") return nil, errors.New("poolCn cannot be nil")
} }
connID := poolCn.GetID() connID := poolCn.GetID()
// if we reconnect the underlying network connection, the streaming credentials listener will continue to work
// so we can get the old listener from the cache and use it.
// subscribing the same (an already subscribed) listener for a StreamingCredentialsProvider SHOULD be a no-op
listener, ok := m.credentialsListeners.Get(connID) listener, ok := m.credentialsListeners.Get(connID)
if !ok || listener == nil { if !ok || listener == nil {
newCredListener := &ConnReAuthCredentialsListener{ newCredListener := &ConnReAuthCredentialsListener{
@@ -54,3 +60,7 @@ func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
connID := poolCn.GetID() connID := poolCn.GetID()
m.poolHookRef.MarkForReAuth(connID, reAuthFn) m.poolHookRef.MarkForReAuth(connID, reAuthFn)
} }
func (m *Manager) RemoveListener(connID uint64) {
m.credentialsListeners.Remove(connID)
}

View File

@@ -5,6 +5,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/pool"
) )
@@ -17,6 +18,9 @@ type ReAuthPoolHook struct {
// conn id -> bool // conn id -> bool
scheduledReAuth map[uint64]bool scheduledReAuth map[uint64]bool
scheduledLock sync.RWMutex scheduledLock sync.RWMutex
// for cleanup
manager *Manager
} }
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook { func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
@@ -32,7 +36,6 @@ func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHoo
workers: workers, workers: workers,
reAuthTimeout: reAuthTimeout, reAuthTimeout: reAuthTimeout,
} }
} }
func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) { func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
@@ -41,27 +44,22 @@ func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
r.shouldReAuth[connID] = reAuthFn r.shouldReAuth[connID] = reAuthFn
} }
func (r *ReAuthPoolHook) ClearReAuthMark(connID uint64) {
r.shouldReAuthLock.Lock()
defer r.shouldReAuthLock.Unlock()
delete(r.shouldReAuth, connID)
}
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) { func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
connID := conn.GetID()
r.shouldReAuthLock.RLock() r.shouldReAuthLock.RLock()
_, ok := r.shouldReAuth[conn.GetID()] _, shouldReAuth := r.shouldReAuth[connID]
r.shouldReAuthLock.RUnlock() r.shouldReAuthLock.RUnlock()
// This connection was marked for reauth while in the pool, // This connection was marked for reauth while in the pool,
// reject the connection // reject the connection
if ok { if shouldReAuth {
// simply reject the connection, it will be re-authenticated in OnPut // simply reject the connection, it will be re-authenticated in OnPut
return false, nil return false, nil
} }
r.scheduledLock.RLock() r.scheduledLock.RLock()
hasScheduled, ok := r.scheduledReAuth[conn.GetID()] _, hasScheduled := r.scheduledReAuth[connID]
r.scheduledLock.RUnlock() r.scheduledLock.RUnlock()
// has scheduled reauth, reject the connection // has scheduled reauth, reject the connection
if ok && hasScheduled { if hasScheduled {
// simply reject the connection, it currently has a reauth scheduled // simply reject the connection, it currently has a reauth scheduled
// and the worker is waiting for slot to execute the reauth // and the worker is waiting for slot to execute the reauth
return false, nil return false, nil
@@ -70,22 +68,38 @@ func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (acce
} }
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) { func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
if conn == nil {
// noop
return true, false, nil
}
connID := conn.GetID()
// Check if reauth is needed and get the function with proper locking // Check if reauth is needed and get the function with proper locking
r.shouldReAuthLock.RLock() r.shouldReAuthLock.RLock()
reAuthFn, ok := r.shouldReAuth[conn.GetID()] r.scheduledLock.RLock()
reAuthFn, ok := r.shouldReAuth[connID]
r.shouldReAuthLock.RUnlock() r.shouldReAuthLock.RUnlock()
if ok { if ok {
r.shouldReAuthLock.Lock()
r.scheduledLock.Lock() r.scheduledLock.Lock()
r.scheduledReAuth[conn.GetID()] = true r.scheduledReAuth[connID] = true
delete(r.shouldReAuth, connID)
r.scheduledLock.Unlock() r.scheduledLock.Unlock()
// Clear the mark immediately to prevent duplicate reauth attempts r.shouldReAuthLock.Unlock()
r.ClearReAuthMark(conn.GetID())
go func() { go func() {
<-r.workers <-r.workers
// safety first
if conn == nil || (conn != nil && conn.IsClosed()) {
r.workers <- struct{}{}
return
}
defer func() { defer func() {
if rec := recover(); rec != nil {
// once again - safety first
internal.Logger.Printf(context.Background(), "panic in reauth worker: %v", rec)
}
r.scheduledLock.Lock() r.scheduledLock.Lock()
delete(r.scheduledReAuth, conn.GetID()) delete(r.scheduledReAuth, connID)
r.scheduledLock.Unlock() r.scheduledLock.Unlock()
r.workers <- struct{}{} r.workers <- struct{}{}
}() }()
@@ -96,7 +110,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
// Try to acquire the connection // Try to acquire the connection
// We need to ensure the connection is both Usable and not Used // We need to ensure the connection is both Usable and not Used
// to prevent data races with concurrent operations // to prevent data races with concurrent operations
const baseDelay = time.Microsecond const baseDelay = 10 * time.Microsecond
acquired := false acquired := false
attempt := 0 attempt := 0
for !acquired { for !acquired {
@@ -108,36 +122,33 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
return return
default: default:
// Try to acquire: set Usable=false, then check Used // Try to acquire: set Usable=false, then check Used
if conn.Usable.CompareAndSwap(true, false) { if conn.CompareAndSwapUsable(true, false) {
if !conn.Used.Load() { if !conn.IsUsed() {
acquired = true acquired = true
} else { } else {
// Release Usable and retry with exponential backoff // Release Usable and retry with exponential backoff
conn.Usable.Store(true) // todo(ndyakov): think of a better way to do this without the need
if attempt > 0 { // to release the connection, but just wait till it is not used
// Exponential backoff: 1, 2, 4, 8... up to 512 microseconds conn.SetUsable(true)
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
}
attempt++
}
} else {
// Connection not usable, retry with exponential backoff
if attempt > 0 {
// Exponential backoff: 1, 2, 4, 8... up to 512 microseconds
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
} }
}
if !acquired {
// Exponential backoff: 10, 20, 40, 80... up to 5120 microseconds
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
attempt++ attempt++
} }
} }
} }
// Successfully acquired the connection, perform reauth // safety first
reAuthFn(nil) if !conn.IsClosed() {
// Successfully acquired the connection, perform reauth
reAuthFn(nil)
}
// Release the connection // Release the connection
conn.Usable.Store(true) conn.SetUsable(true)
}() }()
} }
@@ -147,10 +158,16 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
} }
func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) { func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) {
connID := conn.GetID()
r.shouldReAuthLock.Lock()
r.scheduledLock.Lock() r.scheduledLock.Lock()
delete(r.scheduledReAuth, conn.GetID()) delete(r.scheduledReAuth, connID)
delete(r.shouldReAuth, connID)
r.scheduledLock.Unlock() r.scheduledLock.Unlock()
r.ClearReAuthMark(conn.GetID()) r.shouldReAuthLock.Unlock()
if r.manager != nil {
r.manager.RemoveListener(connID)
}
} }
var _ pool.PoolHook = (*ReAuthPoolHook)(nil) var _ pool.PoolHook = (*ReAuthPoolHook)(nil)

View File

@@ -68,15 +68,15 @@ type Conn struct {
// is not in use. That way, the connection won't be used to send multiple commands at the same time and // is not in use. That way, the connection won't be used to send multiple commands at the same time and
// potentially corrupt the command stream. // potentially corrupt the command stream.
// Usable flag to mark connection as safe for use // usable flag to mark connection as safe for use
// It is false before initialization and after a handoff is marked // It is false before initialization and after a handoff is marked
// It will be false during other background operations like re-authentication // It will be false during other background operations like re-authentication
Usable atomic.Bool usable atomic.Bool
// Used flag to mark connection as used when a command is going to be // used flag to mark connection as used when a command is going to be
// processed on that connection. This is used to prevent a race condition with // processed on that connection. This is used to prevent a race condition with
// background operations that may execute commands, like re-authentication. // background operations that may execute commands, like re-authentication.
Used atomic.Bool used atomic.Bool
// Inited flag to mark connection as initialized, this is almost the same as usable // Inited flag to mark connection as initialized, this is almost the same as usable
// but it is used to make sure we don't initialize a network connection twice // but it is used to make sure we don't initialize a network connection twice
@@ -142,7 +142,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
// Initialize atomic state // Initialize atomic state
cn.Usable.Store(false) // false initially, set to true after initialization cn.usable.Store(false) // false initially, set to true after initialization
cn.handoffRetriesAtomic.Store(0) // 0 initially cn.handoffRetriesAtomic.Store(0) // 0 initially
// Initialize handoff state atomically // Initialize handoff state atomically
@@ -167,6 +167,42 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix()) atomic.StoreInt64(&cn.usedAt, tm.Unix())
} }
// Usable
// CompareAndSwapUsable atomically compares and swaps the usable flag (lock-free).
func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
return cn.usable.CompareAndSwap(old, new)
}
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
func (cn *Conn) IsUsable() bool {
return cn.usable.Load()
}
// SetUsable sets the usable flag for the connection (lock-free).
// prefer CompareAndSwapUsable() when possible
func (cn *Conn) SetUsable(usable bool) {
cn.usable.Store(usable)
}
// Used
// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
return cn.used.CompareAndSwap(old, new)
}
// IsUsed returns true if the connection is currently in use (lock-free).
func (cn *Conn) IsUsed() bool {
return cn.used.Load()
}
// SetUsed sets the used flag for the connection (lock-free).
// prefer CompareAndSwapUsed() when possible
func (cn *Conn) SetUsed(val bool) {
cn.used.Store(val)
}
// getNetConn returns the current network connection using atomic load (lock-free). // getNetConn returns the current network connection using atomic load (lock-free).
// This is the fast path for accessing netConn without mutex overhead. // This is the fast path for accessing netConn without mutex overhead.
func (cn *Conn) getNetConn() net.Conn { func (cn *Conn) getNetConn() net.Conn {
@@ -184,18 +220,6 @@ func (cn *Conn) setNetConn(netConn net.Conn) {
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
} }
// Lock-free helper methods for handoff state management
// isUsable returns true if the connection is safe to use (lock-free).
func (cn *Conn) isUsable() bool {
return cn.Usable.Load()
}
// setUsable sets the usable flag atomically (lock-free).
func (cn *Conn) setUsable(usable bool) {
cn.Usable.Store(usable)
}
// getHandoffState returns the current handoff state atomically (lock-free). // getHandoffState returns the current handoff state atomically (lock-free).
func (cn *Conn) getHandoffState() *HandoffState { func (cn *Conn) getHandoffState() *HandoffState {
state := cn.handoffStateAtomic.Load() state := cn.handoffStateAtomic.Load()
@@ -240,11 +264,6 @@ func (cn *Conn) incrementHandoffRetries(delta int) int {
return int(cn.handoffRetriesAtomic.Add(uint32(delta))) return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
} }
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
func (cn *Conn) IsUsable() bool {
return cn.isUsable()
}
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put. // IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
func (cn *Conn) IsPooled() bool { func (cn *Conn) IsPooled() bool {
return cn.pooled return cn.pooled
@@ -259,11 +278,6 @@ func (cn *Conn) IsInited() bool {
return cn.Inited.Load() return cn.Inited.Load()
} }
// SetUsable sets the usable flag for the connection (lock-free).
func (cn *Conn) SetUsable(usable bool) {
cn.setUsable(usable)
}
// SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades. // SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades.
// These timeouts will be used for all subsequent commands until the deadline expires. // These timeouts will be used for all subsequent commands until the deadline expires.
// Uses atomic operations for lock-free access. // Uses atomic operations for lock-free access.
@@ -494,11 +508,10 @@ func (cn *Conn) MarkQueuedForHandoff() error {
// first we need to mark the connection as not usable // first we need to mark the connection as not usable
// to prevent the pool from returning it to the caller // to prevent the pool from returning it to the caller
if !connAcquired { if !connAcquired {
if cn.Usable.CompareAndSwap(true, false) { if !cn.usable.CompareAndSwap(true, false) {
connAcquired = true
} else {
continue continue
} }
connAcquired = true
} }
currentState := cn.getHandoffState() currentState := cn.getHandoffState()
@@ -568,7 +581,7 @@ func (cn *Conn) ClearHandoffState() {
cn.setHandoffState(cleanState) cn.setHandoffState(cleanState)
cn.setHandoffRetries(0) cn.setHandoffRetries(0)
// Clearing handoff state also means the connection is usable again // Clearing handoff state also means the connection is usable again
cn.setUsable(true) cn.SetUsable(true)
} }
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free). // IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).

View File

@@ -22,7 +22,13 @@ type PoolHook interface {
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
// OnRemove is called when a connection is removed from the pool. // OnRemove is called when a connection is removed from the pool.
// It can be used for cleanup or logging purposes. // This happens when:
// - Connection fails health check
// - Connection exceeds max lifetime
// - Pool is being closed
// - Connection encounters an error
// Implementations should clean up any per-connection state.
// The reason parameter indicates why the connection was removed.
OnRemove(ctx context.Context, conn *Conn, reason error) OnRemove(ctx context.Context, conn *Conn, reason error)
} }

View File

@@ -435,7 +435,8 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
now := time.Now() now := time.Now()
attempts := 0 attempts := 0
// get hooks manager // Get hooks manager once for this getConn call for performance.
// Note: Hooks added/removed during this call won't be reflected.
p.hookManagerMu.RLock() p.hookManagerMu.RLock()
hookManager := p.hookManager hookManager := p.hookManager
p.hookManagerMu.RUnlock() p.hookManagerMu.RUnlock()
@@ -580,11 +581,12 @@ func (p *ConnPool) popIdle() (*Conn, error) {
} }
attempts++ attempts++
if cn.IsUsable() { if cn.CompareAndSwapUsed(false, true) {
if cn.Used.CompareAndSwap(false, true) { if cn.IsUsable() {
p.idleConnsLen.Add(-1) p.idleConnsLen.Add(-1)
break break
} }
cn.SetUsed(false)
} }
// Connection is not usable, put it back in the pool // Connection is not usable, put it back in the pool
@@ -679,10 +681,9 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
shouldCloseConn = true shouldCloseConn = true
} }
// Mark connection as not used only // if the connection is not going to be closed, mark it as not used
// if it's not being closed
if !shouldCloseConn { if !shouldCloseConn {
cn.Used.Store(false) cn.SetUsed(false)
} }
p.freeTurn() p.freeTurn()

View File

@@ -28,30 +28,30 @@ func (p *SingleConnPool) CloseConn(cn *Conn) error {
return p.pool.CloseConn(cn) return p.pool.CloseConn(cn)
} }
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) {
if p.stickyErr != nil { if p.stickyErr != nil {
return nil, p.stickyErr return nil, p.stickyErr
} }
if p.cn == nil { if p.cn == nil {
return nil, ErrClosed return nil, ErrClosed
} }
p.cn.Used.Store(true) p.cn.SetUsed(true)
p.cn.SetUsedAt(time.Now()) p.cn.SetUsedAt(time.Now())
return p.cn, nil return p.cn, nil
} }
func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) { func (p *SingleConnPool) Put(_ context.Context, cn *Conn) {
if p.cn == nil { if p.cn == nil {
return return
} }
if p.cn != cn { if p.cn != cn {
return return
} }
p.cn.Used.Store(false) p.cn.SetUsed(false)
} }
func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) { func (p *SingleConnPool) Remove(_ context.Context, cn *Conn, reason error) {
cn.Used.Store(false) cn.SetUsed(false)
p.cn = nil p.cn = nil
p.stickyErr = reason p.stickyErr = reason
} }
@@ -76,6 +76,6 @@ func (p *SingleConnPool) Stats() *Stats {
return &Stats{} return &Stats{}
} }
func (p *SingleConnPool) AddPoolHook(hook PoolHook) {} func (p *SingleConnPool) AddPoolHook(_ PoolHook) {}
func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {} func (p *SingleConnPool) RemovePoolHook(_ PoolHook) {}

View File

@@ -481,7 +481,10 @@ func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, reque
internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err)) internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err))
} }
} else { } else {
conn.Close() err := conn.Close() // Close the connection if no pool provided
if err != nil {
internal.Logger.Printf(ctx, "redis: failed to close connection: %v", err)
}
if internal.LogLevel.WarnOrAbove() { if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err)) internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err))
} }

View File

@@ -513,7 +513,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
// mark the connection as usable and inited // mark the connection as usable and inited
// once returned to the pool as idle, this connection can be used by other clients // once returned to the pool as idle, this connection can be used by other clients
cn.SetUsable(true) cn.SetUsable(true)
cn.Used.Store(false) cn.SetUsed(false)
cn.Inited.Store(true) cn.Inited.Store(true)
// Set the connection initialization function for potential reconnections // Set the connection initialization function for potential reconnections