
fix(pool): Pool ReAuth should not interfere with handoff (#3547)

* fix(pool): wip, pool reauth should not interfere with handoff

* fix credListeners map

* fix race in tests

* better conn usable timeout

* add design decision comment

* few small improvements

* update marked as queued

* add Used to clarify the state of the conn

* rename test

* fix(test): fix flaky test

* lock inside the listeners collection

* address pr comments

* Update internal/auth/cred_listeners.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update internal/pool/buffer_size_test.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* wip refactor entraid

* fix maintnotif pool hook

* fix mocks

* fix nil listener

* sync and async reauth based on conn lifecycle

* be able to reject connection OnGet

* pass hooks so the tests can observe reauth

* give some time for the background to execute commands

* fix tests

* only async reauth

* Update internal/pool/pool.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update internal/auth/streaming/pool_hook.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update internal/pool/conn.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* chore(redisotel): use metric.WithAttributeSet to avoid copy (#3552)

To improve performance, replace `WithAttributes` with `WithAttributeSet`.
This avoids the slice allocation and copy that `WithAttributes` performs.

For more information see https://github.com/open-telemetry/opentelemetry-go/blob/v1.38.0/metric/instrument.go#L357-L376
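
A minimal sketch of the change, assuming an Int64Counter named opsCounter and an attribute slice attrs (both names are illustrative, not taken from the redisotel code):

import (
	"context"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/metric"
)

// record shows the before/after: WithAttributes copies the variadic slice into
// a fresh attribute.Set on every call, while WithAttributeSet accepts a set
// that was built once up front.
func record(ctx context.Context, opsCounter metric.Int64Counter, attrs []attribute.KeyValue) {
	// Before: allocates and copies attrs on every measurement.
	opsCounter.Add(ctx, 1, metric.WithAttributes(attrs...))

	// After: build the set once (ideally outside the hot path) and pass it directly.
	set := attribute.NewSet(attrs...)
	opsCounter.Add(ctx, 1, metric.WithAttributeSet(set))
}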

* chore(docs): explain why MaxRetries is disabled for ClusterClient (#3551)

Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com>

* exponential backoff

* address pr comments

* address pr comments

* remove rlock

* add some comments

* add comments

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Warnar Boekkooi <wboekkooi@impossiblecloud.com>
Co-authored-by: Justin <justindsouza80@gmail.com>
Authored by Nedyalko Dyakov on 2025-10-22 12:45:30 +03:00; committed via GitHub.
parent 14a8814540
commit a15e76394c
23 changed files with 1138 additions and 143 deletions

View File

@@ -44,4 +44,4 @@ func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, on
}
// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)

View File

@@ -108,10 +108,12 @@ func isRedisError(err error) bool {
func isBadConn(err error, allowTimeout bool, addr string) bool {
switch err {
case nil:
return false
case context.Canceled, context.DeadlineExceeded:
return true
case nil:
return false
case context.Canceled, context.DeadlineExceeded:
return true
case pool.ErrConnUnusableTimeout:
return true
}
if isRedisError(err) {

View File

@@ -0,0 +1,100 @@
package streaming
import (
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
)
// ConnReAuthCredentialsListener is a credentials listener for a specific connection
// that triggers re-authentication when credentials change.
//
// This listener implements the auth.CredentialsListener interface and is subscribed
// to a StreamingCredentialsProvider. When new credentials are received via OnNext,
// it marks the connection for re-authentication through the manager.
//
// The re-authentication is always performed asynchronously to avoid blocking the
// credentials provider and to prevent potential deadlocks with the pool semaphore.
// The actual re-auth happens when the connection is returned to the pool in an idle state.
//
// Lifecycle:
// - Created during connection initialization via Manager.Listener()
// - Subscribed to the StreamingCredentialsProvider
// - Receives credential updates via OnNext()
// - Cleaned up when connection is removed from pool via Manager.RemoveListener()
type ConnReAuthCredentialsListener struct {
// reAuth is the function to re-authenticate the connection with new credentials
reAuth func(conn *pool.Conn, credentials auth.Credentials) error
// onErr is the function to call when re-authentication or acquisition fails
onErr func(conn *pool.Conn, err error)
// conn is the connection this listener is associated with
conn *pool.Conn
// manager is the streaming credentials manager for coordinating re-auth
manager *Manager
}
// OnNext is called when new credentials are received from the StreamingCredentialsProvider.
//
// This method marks the connection for asynchronous re-authentication. The actual
// re-authentication happens in the background when the connection is returned to the
// pool and is in an idle state.
//
// Asynchronous re-auth is used to:
// - Avoid blocking the credentials provider's notification goroutine
// - Prevent deadlocks with the pool's semaphore (especially with small pool sizes)
// - Ensure re-auth happens when the connection is safe to use (not processing commands)
//
// The reAuthFn callback receives:
// - nil if the connection was successfully acquired for re-auth
// - error if acquisition timed out or failed
//
// Thread-safe: Called by the credentials provider's notification goroutine.
func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
if c.conn == nil || c.conn.IsClosed() || c.manager == nil || c.reAuth == nil {
return
}
// Always use async reauth to avoid complex pool semaphore issues
// The synchronous path can cause deadlocks in the pool's semaphore mechanism
// when called from the Subscribe goroutine, especially with small pool sizes.
// The connection pool hook will re-authenticate the connection when it is
// returned to the pool in a clean, idle state.
c.manager.MarkForReAuth(c.conn, func(err error) {
// err is from connection acquisition (timeout, etc.)
if err != nil {
// Log the error
c.OnError(err)
return
}
// err is from reauth command execution
err = c.reAuth(c.conn, credentials)
if err != nil {
// Log the error
c.OnError(err)
return
}
})
}
// OnError is called when an error occurs during credential streaming or re-authentication.
//
// This method can be called from:
// - The StreamingCredentialsProvider when there's an error in the credentials stream
// - The re-auth process when connection acquisition times out
// - The re-auth process when the AUTH command fails
//
// The error is delegated to the onErr callback provided during listener creation.
//
// Thread-safe: Can be called from multiple goroutines (provider, re-auth worker).
func (c *ConnReAuthCredentialsListener) OnError(err error) {
if c.onErr == nil {
return
}
c.onErr(c.conn, err)
}
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
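
For context, a rough sketch of how this listener gets wired up during connection initialization, assuming this package's imports and the Subscribe signature of auth.StreamingCredentialsProvider (the helper name is illustrative, not the client's actual wiring code):

// subscribeConn obtains (or reuses) the per-connection listener from the
// manager and subscribes it to the streaming provider. Subscribe delivers the
// current credentials immediately; later rotations arrive via listener.OnNext,
// which schedules the asynchronous re-auth described above.
func subscribeConn(
	provider auth.StreamingCredentialsProvider,
	manager *Manager,
	cn *pool.Conn,
	reAuth func(*pool.Conn, auth.Credentials) error,
	onErr func(*pool.Conn, error),
) (auth.Credentials, auth.UnsubscribeFunc, error) {
	listener, err := manager.Listener(cn, reAuth, onErr) // same instance is reused after a handoff
	if err != nil {
		return nil, nil, err
	}
	return provider.Subscribe(listener)
}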

View File

@@ -0,0 +1,77 @@
package streaming
import (
"sync"
"github.com/redis/go-redis/v9/auth"
)
// CredentialsListeners is a thread-safe collection of credentials listeners
// indexed by connection ID.
//
// This collection is used by the Manager to maintain a registry of listeners
// for each connection in the pool. Listeners are reused when connections are
// reinitialized (e.g., after a handoff) to avoid creating duplicate subscriptions
// to the StreamingCredentialsProvider.
//
// The collection supports concurrent access from multiple goroutines during
// connection initialization, credential updates, and connection removal.
type CredentialsListeners struct {
// listeners maps connection ID to credentials listener
listeners map[uint64]auth.CredentialsListener
// lock protects concurrent access to the listeners map
lock sync.RWMutex
}
// NewCredentialsListeners creates a new thread-safe credentials listeners collection.
func NewCredentialsListeners() *CredentialsListeners {
return &CredentialsListeners{
listeners: make(map[uint64]auth.CredentialsListener),
}
}
// Add adds or updates a credentials listener for a connection.
//
// If a listener already exists for the connection ID, it is replaced.
// This is safe because the old listener should have been unsubscribed
// before the connection was reinitialized.
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (c *CredentialsListeners) Add(connID uint64, listener auth.CredentialsListener) {
c.lock.Lock()
defer c.lock.Unlock()
if c.listeners == nil {
c.listeners = make(map[uint64]auth.CredentialsListener)
}
c.listeners[connID] = listener
}
// Get retrieves the credentials listener for a connection.
//
// Returns:
// - listener: The credentials listener for the connection, or nil if not found
// - ok: true if a listener exists for the connection ID, false otherwise
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (c *CredentialsListeners) Get(connID uint64) (auth.CredentialsListener, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
if len(c.listeners) == 0 {
return nil, false
}
listener, ok := c.listeners[connID]
return listener, ok
}
// Remove removes the credentials listener for a connection.
//
// This is called when a connection is removed from the pool to prevent
// memory leaks. If no listener exists for the connection ID, this is a no-op.
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (c *CredentialsListeners) Remove(connID uint64) {
c.lock.Lock()
defer c.lock.Unlock()
delete(c.listeners, connID)
}

View File

@@ -0,0 +1,137 @@
package streaming
import (
"errors"
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
)
// Manager coordinates streaming credentials and re-authentication for a connection pool.
//
// The manager is responsible for:
// - Creating and managing per-connection credentials listeners
// - Providing the pool hook for re-authentication
// - Coordinating between credentials updates and pool operations
//
// When credentials change via a StreamingCredentialsProvider:
// 1. The credentials listener (ConnReAuthCredentialsListener) receives the update
// 2. It calls MarkForReAuth on the manager
// 3. The manager delegates to the pool hook
// 4. The pool hook schedules background re-authentication
//
// The manager maintains a registry of credentials listeners indexed by connection ID,
// allowing listener reuse when connections are reinitialized (e.g., after handoff).
type Manager struct {
// credentialsListeners maps connection ID to credentials listener
credentialsListeners *CredentialsListeners
// pool is the connection pool being managed
pool pool.Pooler
// poolHookRef is the re-authentication pool hook
poolHookRef *ReAuthPoolHook
}
// NewManager creates a new streaming credentials manager.
//
// Parameters:
// - pl: The connection pool to manage
// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication
//
// The manager creates a ReAuthPoolHook sized to match the pool size, ensuring that
// re-auth operations don't exhaust the connection pool.
func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager {
m := &Manager{
pool: pl,
poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout),
credentialsListeners: NewCredentialsListeners(),
}
m.poolHookRef.manager = m
return m
}
// PoolHook returns the pool hook for re-authentication.
//
// This hook should be registered with the connection pool to enable
// automatic re-authentication when credentials change.
func (m *Manager) PoolHook() pool.PoolHook {
return m.poolHookRef
}
// Listener returns or creates a credentials listener for a connection.
//
// This method is called during connection initialization to set up the
// credentials listener. If a listener already exists for the connection ID
// (e.g., after a handoff), it is reused.
//
// Parameters:
// - poolCn: The connection to create/get a listener for
// - reAuth: Function to re-authenticate the connection with new credentials
// - onErr: Function to call when re-authentication fails
//
// Returns:
// - auth.CredentialsListener: The listener to subscribe to the credentials provider
// - error: Non-nil if poolCn is nil
//
// Note: The reAuth and onErr callbacks are captured once when the listener is
// created and reused for the connection's lifetime. They should not change.
//
// Thread-safe: Can be called concurrently during connection initialization.
func (m *Manager) Listener(
poolCn *pool.Conn,
reAuth func(*pool.Conn, auth.Credentials) error,
onErr func(*pool.Conn, error),
) (auth.CredentialsListener, error) {
if poolCn == nil {
return nil, errors.New("poolCn cannot be nil")
}
connID := poolCn.GetID()
// if we reconnect the underlying network connection, the streaming credentials listener will continue to work
// so we can get the old listener from the cache and use it.
// subscribing the same (an already subscribed) listener for a StreamingCredentialsProvider SHOULD be a no-op
listener, ok := m.credentialsListeners.Get(connID)
if !ok || listener == nil {
// Create new listener for this connection
// Note: Callbacks (reAuth, onErr) are captured once and reused for the connection's lifetime
newCredListener := &ConnReAuthCredentialsListener{
conn: poolCn,
reAuth: reAuth,
onErr: onErr,
manager: m,
}
m.credentialsListeners.Add(connID, newCredListener)
listener = newCredListener
}
return listener, nil
}
// MarkForReAuth marks a connection for re-authentication.
//
// This method is called by the credentials listener when new credentials are
// received. It delegates to the pool hook to schedule background re-authentication.
//
// Parameters:
// - poolCn: The connection to re-authenticate
// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails
//
// Thread-safe: Called by credentials listeners when credentials change.
func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
connID := poolCn.GetID()
m.poolHookRef.MarkForReAuth(connID, reAuthFn)
}
// RemoveListener removes the credentials listener for a connection.
//
// This method is called by the pool hook's OnRemove to clean up listeners
// when connections are removed from the pool.
//
// Parameters:
// - connID: The connection ID whose listener should be removed
//
// Thread-safe: Called during connection removal.
func (m *Manager) RemoveListener(connID uint64) {
m.credentialsListeners.Remove(connID)
}
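
For illustration, a sketch of how the manager could be wired to a pool during client setup, using only the APIs defined in this package (the helper name is hypothetical; the real wiring lives in the client initialization code):

// newStreamingAuth creates the manager, registers its re-auth hook with the
// pool, and returns the manager so connection initialization can call Listener().
func newStreamingAuth(connPool pool.Pooler, reAuthTimeout time.Duration) *Manager {
	m := NewManager(connPool, reAuthTimeout)
	// The hook observes OnGet/OnPut/OnRemove and drives background re-auth.
	connPool.AddPoolHook(m.PoolHook())
	return m
}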

View File

@@ -0,0 +1,101 @@
package streaming
import (
"context"
"testing"
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
)
// Test that Listener returns the newly created listener, not nil
func TestManager_Listener_ReturnsNewListener(t *testing.T) {
// Create a mock pool
mockPool := &mockPooler{}
// Create manager
manager := NewManager(mockPool, time.Second)
// Create a mock connection
conn := &pool.Conn{}
// Mock functions
reAuth := func(cn *pool.Conn, creds auth.Credentials) error {
return nil
}
onErr := func(cn *pool.Conn, err error) {
}
// Get listener - this should create a new one
listener, err := manager.Listener(conn, reAuth, onErr)
// Verify no error
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
// Verify listener is not nil (this was the bug!)
if listener == nil {
t.Fatal("Expected listener to be non-nil, but got nil")
}
// Verify it's the correct type
if _, ok := listener.(*ConnReAuthCredentialsListener); !ok {
t.Fatalf("Expected listener to be *ConnReAuthCredentialsListener, got %T", listener)
}
// Get the same listener again - should return the existing one
listener2, err := manager.Listener(conn, reAuth, onErr)
if err != nil {
t.Fatalf("Expected no error on second call, got: %v", err)
}
if listener2 == nil {
t.Fatal("Expected listener2 to be non-nil")
}
// Should be the same instance
if listener != listener2 {
t.Error("Expected to get the same listener instance on second call")
}
}
// Test that Listener returns error when conn is nil
func TestManager_Listener_NilConn(t *testing.T) {
mockPool := &mockPooler{}
manager := NewManager(mockPool, time.Second)
listener, err := manager.Listener(nil, nil, nil)
if err == nil {
t.Fatal("Expected error when conn is nil, got nil")
}
if listener != nil {
t.Error("Expected listener to be nil when error occurs")
}
expectedErr := "poolCn cannot be nil"
if err.Error() != expectedErr {
t.Errorf("Expected error message %q, got %q", expectedErr, err.Error())
}
}
// Mock pooler for testing
type mockPooler struct{}
func (m *mockPooler) NewConn(ctx context.Context) (*pool.Conn, error) { return nil, nil }
func (m *mockPooler) CloseConn(*pool.Conn) error { return nil }
func (m *mockPooler) Get(ctx context.Context) (*pool.Conn, error) { return nil, nil }
func (m *mockPooler) Put(ctx context.Context, conn *pool.Conn) {}
func (m *mockPooler) Remove(ctx context.Context, conn *pool.Conn, reason error) {}
func (m *mockPooler) Len() int { return 0 }
func (m *mockPooler) IdleLen() int { return 0 }
func (m *mockPooler) Stats() *pool.Stats { return &pool.Stats{} }
func (m *mockPooler) Size() int { return 10 }
func (m *mockPooler) AddPoolHook(hook pool.PoolHook) {}
func (m *mockPooler) RemovePoolHook(hook pool.PoolHook) {}
func (m *mockPooler) Close() error { return nil }

View File

@@ -0,0 +1,259 @@
package streaming
import (
"context"
"sync"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
)
// ReAuthPoolHook is a pool hook that manages background re-authentication of connections
// when credentials change via a streaming credentials provider.
//
// The hook uses a semaphore-based worker pool to limit concurrent re-authentication
// operations and prevent pool exhaustion. When credentials change, connections are
// marked for re-authentication and processed asynchronously in the background.
//
// The re-authentication process:
// 1. OnPut: When a connection is returned to the pool, check if it needs re-auth
// 2. If yes, schedule it for background processing (move from shouldReAuth to scheduledReAuth)
// 3. A worker goroutine acquires the connection (waits until it's not in use)
// 4. Executes the re-auth function while holding the connection
// 5. Releases the connection back to the pool
//
// The hook ensures that:
// - Only one re-auth operation runs per connection at a time
// - Connections are not used for commands during re-authentication
// - Re-auth operations timeout if they can't acquire the connection
// - Resources are properly cleaned up on connection removal
type ReAuthPoolHook struct {
// shouldReAuth maps connection ID to re-auth function
// Connections in this map need re-authentication but haven't been scheduled yet
shouldReAuth map[uint64]func(error)
shouldReAuthLock sync.RWMutex
// workers is a semaphore channel limiting concurrent re-auth operations
// Initialized with poolSize tokens to prevent pool exhaustion
workers chan struct{}
// reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth
reAuthTimeout time.Duration
// scheduledReAuth maps connection ID to scheduled status
// Connections in this map have a background worker attempting re-authentication
scheduledReAuth map[uint64]bool
scheduledLock sync.RWMutex
// manager is a back-reference for cleanup operations
manager *Manager
}
// NewReAuthPoolHook creates a new re-authentication pool hook.
//
// Parameters:
// - poolSize: Maximum number of concurrent re-auth operations (typically matches pool size)
// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication
//
// The poolSize parameter is used to initialize the worker semaphore, ensuring that
// re-auth operations don't exhaust the connection pool.
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
workers := make(chan struct{}, poolSize)
// Initialize the workers channel with tokens (semaphore pattern)
for i := 0; i < poolSize; i++ {
workers <- struct{}{}
}
return &ReAuthPoolHook{
shouldReAuth: make(map[uint64]func(error)),
scheduledReAuth: make(map[uint64]bool),
workers: workers,
reAuthTimeout: reAuthTimeout,
}
}
// MarkForReAuth marks a connection for re-authentication.
//
// This method is called when credentials change and a connection needs to be
// re-authenticated. The actual re-authentication happens asynchronously when
// the connection is returned to the pool (in OnPut).
//
// Parameters:
// - connID: The connection ID to mark for re-authentication
// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
r.shouldReAuthLock.Lock()
defer r.shouldReAuthLock.Unlock()
r.shouldReAuth[connID] = reAuthFn
}
// OnGet is called when a connection is retrieved from the pool.
//
// This hook checks if the connection needs re-authentication or has a scheduled
// re-auth operation. If so, it rejects the connection (returns accept=false),
// causing the pool to try another connection.
//
// Returns:
// - accept: false if connection needs re-auth, true otherwise
// - err: always nil (errors are not used in this hook)
//
// Thread-safe: Called concurrently by multiple goroutines getting connections.
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
connID := conn.GetID()
r.shouldReAuthLock.RLock()
_, shouldReAuth := r.shouldReAuth[connID]
r.shouldReAuthLock.RUnlock()
// This connection was marked for reauth while in the pool,
// reject the connection
if shouldReAuth {
// simply reject the connection, it will be re-authenticated in OnPut
return false, nil
}
r.scheduledLock.RLock()
_, hasScheduled := r.scheduledReAuth[connID]
r.scheduledLock.RUnlock()
// has scheduled reauth, reject the connection
if hasScheduled {
// simply reject the connection, it currently has a reauth scheduled
// and the worker is waiting for a slot to execute the reauth
return false, nil
}
return true, nil
}
// OnPut is called when a connection is returned to the pool.
//
// This hook checks if the connection needs re-authentication. If so, it schedules
// a background goroutine to perform the re-auth asynchronously. The goroutine:
// 1. Waits for a worker slot (semaphore)
// 2. Acquires the connection (waits until not in use)
// 3. Executes the re-auth function
// 4. Releases the connection and worker slot
//
// The connection is always pooled (not removed) since re-auth happens in background.
//
// Returns:
// - shouldPool: always true (connection stays in pool during background re-auth)
// - shouldRemove: always false
// - err: always nil
//
// Thread-safe: Called concurrently by multiple goroutines returning connections.
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
if conn == nil {
// noop
return true, false, nil
}
connID := conn.GetID()
// Check if reauth is needed and get the function with proper locking
r.shouldReAuthLock.RLock()
reAuthFn, ok := r.shouldReAuth[connID]
r.shouldReAuthLock.RUnlock()
if ok {
// Acquire both locks to atomically move from shouldReAuth to scheduledReAuth
// This prevents race conditions where OnGet might miss the transition
r.shouldReAuthLock.Lock()
r.scheduledLock.Lock()
r.scheduledReAuth[connID] = true
delete(r.shouldReAuth, connID)
r.scheduledLock.Unlock()
r.shouldReAuthLock.Unlock()
go func() {
<-r.workers
// safety first
if conn == nil || (conn != nil && conn.IsClosed()) {
r.workers <- struct{}{}
return
}
defer func() {
if rec := recover(); rec != nil {
// once again - safety first
internal.Logger.Printf(context.Background(), "panic in reauth worker: %v", rec)
}
r.scheduledLock.Lock()
delete(r.scheduledReAuth, connID)
r.scheduledLock.Unlock()
r.workers <- struct{}{}
}()
var err error
timeout := time.After(r.reAuthTimeout)
// Try to acquire the connection
// We need to ensure the connection is both Usable and not Used
// to prevent data races with concurrent operations
const baseDelay = 10 * time.Microsecond
acquired := false
attempt := 0
for !acquired {
select {
case <-timeout:
// Timeout occurred, cannot acquire connection
err = pool.ErrConnUnusableTimeout
reAuthFn(err)
return
default:
// Try to acquire: set Usable=false, then check Used
if conn.CompareAndSwapUsable(true, false) {
if !conn.IsUsed() {
acquired = true
} else {
// Release Usable and retry with exponential backoff
// todo(ndyakov): think of a better way to do this without the need
// to release the connection, but just wait till it is not used
conn.SetUsable(true)
}
}
if !acquired {
// Exponential backoff: 10, 20, 40, 80... up to 5120 microseconds
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
attempt++
}
}
}
// safety first
if !conn.IsClosed() {
// Successfully acquired the connection, perform reauth
reAuthFn(nil)
}
// Release the connection
conn.SetUsable(true)
}()
}
// the reauth will happen in background, as far as the pool is concerned:
// pool the connection, don't remove it, no error
return true, false, nil
}
// OnRemove is called when a connection is removed from the pool.
//
// This hook cleans up all state associated with the connection:
// - Removes from shouldReAuth map (pending re-auth)
// - Removes from scheduledReAuth map (active re-auth)
// - Removes credentials listener from manager
//
// This prevents memory leaks and ensures that removed connections don't have
// lingering re-auth operations or listeners.
//
// Thread-safe: Called when connections are removed due to errors, timeouts, or pool closure.
func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) {
connID := conn.GetID()
r.shouldReAuthLock.Lock()
r.scheduledLock.Lock()
delete(r.scheduledReAuth, connID)
delete(r.shouldReAuth, connID)
r.scheduledLock.Unlock()
r.shouldReAuthLock.Unlock()
if r.manager != nil {
r.manager.RemoveListener(connID)
}
}
var _ pool.PoolHook = (*ReAuthPoolHook)(nil)
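
The worker limit used here is a plain channel semaphore. Stripped of the re-auth specifics, the pattern looks roughly like this (a standalone sketch, not code from this file):

// newSemaphore mirrors NewReAuthPoolHook: a buffered channel filled with
// poolSize tokens acts as the concurrency limit.
func newSemaphore(poolSize int) chan struct{} {
	workers := make(chan struct{}, poolSize)
	for i := 0; i < poolSize; i++ {
		workers <- struct{}{}
	}
	return workers
}

// runLimited executes fn in a goroutine, but only after taking a token, so at
// most poolSize jobs run concurrently; the token is always returned.
func runLimited(workers chan struct{}, fn func()) {
	go func() {
		<-workers
		defer func() { workers <- struct{}{} }()
		fn()
	}()
}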

View File

@@ -3,7 +3,6 @@ package pool_test
import (
"bufio"
"context"
"net"
"unsafe"
. "github.com/bsm/ginkgo/v2"
@@ -124,20 +123,26 @@ var _ = Describe("Buffer Size Configuration", func() {
})
// Helper functions to extract buffer sizes using unsafe pointers
// The struct layout must match pool.Conn exactly to avoid checkptr violations.
// checkptr is Go's pointer safety checker, which ensures that unsafe pointer
// conversions are valid. If the struct layouts do not match exactly, this can
// cause runtime panics or incorrect memory access due to invalid pointer dereferencing.
func getWriterBufSizeUnsafe(cn *pool.Conn) int {
cnPtr := (*struct {
usedAt int64
netConn net.Conn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
// ... other fields
id uint64 // First field in pool.Conn
usedAt int64 // Second field (atomic)
netConnAtomic interface{} // atomic.Value (interface{} has same size)
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
// We only need fields up to bw, so we can stop here
})(unsafe.Pointer(cn))
if cnPtr.bw == nil {
return -1
}
// bufio.Writer internal structure
bwPtr := (*struct {
err error
buf []byte
@@ -150,18 +155,20 @@ func getWriterBufSizeUnsafe(cn *pool.Conn) int {
func getReaderBufSizeUnsafe(cn *pool.Conn) int {
cnPtr := (*struct {
usedAt int64
netConn net.Conn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
// ... other fields
id uint64 // First field in pool.Conn
usedAt int64 // Second field (atomic)
netConnAtomic interface{} // atomic.Value (interface{} has same size)
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
// We only need fields up to rd, so we can stop here
})(unsafe.Pointer(cn))
if cnPtr.rd == nil {
return -1
}
// proto.Reader internal structure
rdPtr := (*struct {
rd *bufio.Reader
})(unsafe.Pointer(cnPtr.rd))
@@ -170,6 +177,7 @@ func getReaderBufSizeUnsafe(cn *pool.Conn) int {
return -1
}
// bufio.Reader internal structure
bufReaderPtr := (*struct {
buf []byte
rd interface{}

View File

@@ -40,6 +40,9 @@ func generateConnID() uint64 {
}
type Conn struct {
// Connection identifier for unique tracking
id uint64
usedAt int64 // atomic
// Lock-free netConn access using atomic.Value
@@ -54,7 +57,34 @@ type Conn struct {
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
readerMu sync.RWMutex
Inited atomic.Bool
// Design note:
// Why have both Usable and Used?
// _Usable_ marks a connection as safe for use by clients; the connection can still
// be in the pool but not Usable at the moment (e.g. a handoff in progress).
// _Used_ marks a connection as in use when a command is going to be processed on that connection.
// This happens once the connection is picked from the pool.
//
// If a background operation needs to use the connection, it will mark it as Not Usable and only use it when it
// is not in use. That way, the connection won't be used to send multiple commands at the same time and
// potentially corrupt the command stream.
// usable flag to mark connection as safe for use
// It is false before initialization and after a handoff is marked
// It will be false during other background operations like re-authentication
usable atomic.Bool
// used flag to mark connection as used when a command is going to be
// processed on that connection. This is used to prevent a race condition with
// background operations that may execute commands, like re-authentication.
used atomic.Bool
// Inited flag to mark the connection as initialized. This is almost the same as usable,
// but it is used to make sure we don't initialize a network connection twice.
// On handoff, the network connection is replaced, but the Conn struct is reused;
// this flag is set to false when the network connection is replaced and
// set to true after the new network connection is initialized.
Inited atomic.Bool
pooled bool
pubsub bool
closed atomic.Bool
@@ -75,11 +105,7 @@ type Conn struct {
// Connection initialization function for reconnections
initConnFunc func(context.Context, *Conn) error
// Connection identifier for unique tracking
id uint64 // Unique numeric identifier for this connection
// Handoff state - using atomic operations for lock-free access
usableAtomic atomic.Bool // Connection usability state
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
// Atomic handoff state to prevent race conditions
@@ -116,7 +142,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
// Initialize atomic state
cn.usableAtomic.Store(false) // false initially, set to true after initialization
cn.usable.Store(false) // false initially, set to true after initialization
cn.handoffRetriesAtomic.Store(0) // 0 initially
// Initialize handoff state atomically
@@ -141,6 +167,73 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix())
}
// Usable
// CompareAndSwapUsable atomically compares and swaps the usable flag (lock-free).
//
// This is used by background operations (handoff, re-auth) to acquire exclusive
// access to a connection. The operation sets usable to false, preventing the pool
// from returning the connection to clients.
//
// Returns true if the swap was successful (old value matched), false otherwise.
func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
return cn.usable.CompareAndSwap(old, new)
}
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
//
// A connection is "usable" when it's in a stable state and can be returned to clients.
// It becomes unusable during:
// - Initialization (before first use)
// - Handoff operations (network connection replacement)
// - Re-authentication (credential updates)
// - Other background operations that need exclusive access
func (cn *Conn) IsUsable() bool {
return cn.usable.Load()
}
// SetUsable sets the usable flag for the connection (lock-free).
//
// This should be called to mark a connection as usable after initialization or
// to release it after a background operation completes.
//
// Prefer CompareAndSwapUsable() when acquiring exclusive access to avoid race conditions.
func (cn *Conn) SetUsable(usable bool) {
cn.usable.Store(usable)
}
// Used
// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
//
// This is the preferred method for acquiring a connection from the pool, as it
// ensures that only one goroutine marks the connection as used.
//
// Returns true if the swap was successful (old value matched), false otherwise.
func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
return cn.used.CompareAndSwap(old, new)
}
// IsUsed returns true if the connection is currently in use (lock-free).
//
// A connection is "used" when it has been retrieved from the pool and is
// actively processing a command. Background operations (like re-auth) should
// wait until the connection is not used before executing commands.
func (cn *Conn) IsUsed() bool {
return cn.used.Load()
}
// SetUsed sets the used flag for the connection (lock-free).
//
// This should be called when returning a connection to the pool (set to false)
// or when a single-connection pool retrieves its connection (set to true).
//
// Prefer CompareAndSwapUsed() when acquiring from a multi-connection pool to
// avoid race conditions.
func (cn *Conn) SetUsed(val bool) {
cn.used.Store(val)
}
// getNetConn returns the current network connection using atomic load (lock-free).
// This is the fast path for accessing netConn without mutex overhead.
func (cn *Conn) getNetConn() net.Conn {
@@ -158,18 +251,6 @@ func (cn *Conn) setNetConn(netConn net.Conn) {
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
}
// Lock-free helper methods for handoff state management
// isUsable returns true if the connection is safe to use (lock-free).
func (cn *Conn) isUsable() bool {
return cn.usableAtomic.Load()
}
// setUsable sets the usable flag atomically (lock-free).
func (cn *Conn) setUsable(usable bool) {
cn.usableAtomic.Store(usable)
}
// getHandoffState returns the current handoff state atomically (lock-free).
func (cn *Conn) getHandoffState() *HandoffState {
state := cn.handoffStateAtomic.Load()
@@ -214,11 +295,6 @@ func (cn *Conn) incrementHandoffRetries(delta int) int {
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
}
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
func (cn *Conn) IsUsable() bool {
return cn.isUsable()
}
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
func (cn *Conn) IsPooled() bool {
return cn.pooled
@@ -233,11 +309,6 @@ func (cn *Conn) IsInited() bool {
return cn.Inited.Load()
}
// SetUsable sets the usable flag for the connection (lock-free).
func (cn *Conn) SetUsable(usable bool) {
cn.setUsable(usable)
}
// SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades.
// These timeouts will be used for all subsequent commands until the deadline expires.
// Uses atomic operations for lock-free access.
@@ -455,9 +526,26 @@ func (cn *Conn) MarkQueuedForHandoff() error {
const maxRetries = 50
const baseDelay = time.Microsecond
connAcquired := false
for attempt := 0; attempt < maxRetries; attempt++ {
currentState := cn.getHandoffState()
// If CAS failed, add exponential backoff to reduce contention
// the delay will be 1, 2, 4... up to 512 microseconds
// Moving this to the top of the loop to avoid "continue" without delay
if attempt > 0 && attempt < maxRetries-1 {
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
}
// first we need to mark the connection as not usable
// to prevent the pool from returning it to the caller
if !connAcquired {
if !cn.usable.CompareAndSwap(true, false) {
continue
}
connAcquired = true
}
currentState := cn.getHandoffState()
// Check if marked for handoff
if !currentState.ShouldHandoff {
return errors.New("connection was not marked for handoff")
@@ -472,16 +560,12 @@ func (cn *Conn) MarkQueuedForHandoff() error {
// Atomic compare-and-swap to update state
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
cn.setUsable(false)
// queue the handoff for processing
// the connection is now "acquired" (marked as not usable) by the handoff
// and it won't be returned to any other callers until the handoff is complete
return nil
}
// If CAS failed, add exponential backoff to reduce contention
// the delay will be 1, 2, 4... up to 512 microseconds
if attempt < maxRetries-1 {
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
}
}
return fmt.Errorf("failed to mark connection as queued for handoff after %d attempts due to high contention", maxRetries)
@@ -527,7 +611,8 @@ func (cn *Conn) ClearHandoffState() {
// Atomically set clean state
cn.setHandoffState(cleanState)
cn.setHandoffRetries(0)
cn.setUsable(true) // Connection is safe to use again after handoff completes
// Clearing handoff state also means the connection is usable again
cn.SetUsable(true)
}
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
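
Together, the usable and used flags give background operations a simple acquire/release protocol. A sketch of how a background task (for example, the re-auth worker shown earlier) might use them, with the retry and backoff details omitted:

// withExclusiveConn runs fn only while the connection is neither handed out as
// usable nor actively processing a command. It returns false if the connection
// could not be acquired on this attempt (callers retry with backoff).
func withExclusiveConn(cn *Conn, fn func() error) bool {
	// Take the connection out of circulation so the pool stops handing it out.
	if !cn.CompareAndSwapUsable(true, false) {
		return false
	}
	// Give it back immediately if a command is still in flight on it.
	if cn.IsUsed() {
		cn.SetUsable(true)
		return false
	}
	defer cn.SetUsable(true) // release once the background work is done
	if !cn.IsClosed() {
		_ = fn()
	}
	return true
}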

View File

@@ -9,13 +9,27 @@ import (
type PoolHook interface {
// OnGet is called when a connection is retrieved from the pool.
// It can modify the connection or return an error to prevent its use.
// The accept flag can be used to prevent the connection from being used:
// when accept is false, the connection is rejected and returned to the pool.
// A non-nil error also prevents the connection from being used;
// on error, the connection is removed from the pool instead.
// It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool)
// The flag can be used for gathering metrics on pool hit/miss ratio.
OnGet(ctx context.Context, conn *Conn, isNewConn bool) error
OnGet(ctx context.Context, conn *Conn, isNewConn bool) (accept bool, err error)
// OnPut is called when a connection is returned to the pool.
// It returns whether the connection should be pooled and whether it should be removed.
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
// OnRemove is called when a connection is removed from the pool.
// This happens when:
// - Connection fails health check
// - Connection exceeds max lifetime
// - Pool is being closed
// - Connection encounters an error
// Implementations should clean up any per-connection state.
// The reason parameter indicates why the connection was removed.
OnRemove(ctx context.Context, conn *Conn, reason error)
}
// PoolHookManager manages multiple pool hooks.
@@ -56,16 +70,21 @@ func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
// ProcessOnGet calls all OnGet hooks in order.
// If any hook returns an error, processing stops and the error is returned.
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) (acceptConn bool, err error) {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
for _, hook := range phm.hooks {
if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
return err
acceptConn, err := hook.OnGet(ctx, conn, isNewConn)
if err != nil {
return false, err
}
if !acceptConn {
return false, nil
}
}
return nil
return true, nil
}
// ProcessOnPut calls all OnPut hooks in order.
@@ -96,6 +115,15 @@ func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shoul
return shouldPool, false, nil
}
// ProcessOnRemove calls all OnRemove hooks in order.
func (phm *PoolHookManager) ProcessOnRemove(ctx context.Context, conn *Conn, reason error) {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
for _, hook := range phm.hooks {
hook.OnRemove(ctx, conn, reason)
}
}
// GetHookCount returns the number of registered hooks (for testing).
func (phm *PoolHookManager) GetHookCount() int {
phm.hooksMu.RLock()

View File

@@ -10,17 +10,19 @@ import (
// TestHook for testing hook functionality
type TestHook struct {
OnGetCalled int
OnPutCalled int
GetError error
PutError error
ShouldPool bool
ShouldRemove bool
OnGetCalled int
OnPutCalled int
OnRemoveCalled int
GetError error
PutError error
ShouldPool bool
ShouldRemove bool
ShouldAccept bool
}
func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) (bool, error) {
th.OnGetCalled++
return th.GetError
return th.ShouldAccept, th.GetError
}
func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
@@ -28,6 +30,10 @@ func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, sho
return th.ShouldPool, th.ShouldRemove, th.PutError
}
func (th *TestHook) OnRemove(ctx context.Context, conn *Conn, reason error) {
th.OnRemoveCalled++
}
func TestPoolHookManager(t *testing.T) {
manager := NewPoolHookManager()
@@ -37,8 +43,8 @@ func TestPoolHookManager(t *testing.T) {
}
// Add hooks
hook1 := &TestHook{ShouldPool: true}
hook2 := &TestHook{ShouldPool: true}
hook1 := &TestHook{ShouldPool: true, ShouldAccept: true}
hook2 := &TestHook{ShouldPool: true, ShouldAccept: true}
manager.AddHook(hook1)
manager.AddHook(hook2)
@@ -51,10 +57,13 @@ func TestPoolHookManager(t *testing.T) {
ctx := context.Background()
conn := &Conn{} // Mock connection
err := manager.ProcessOnGet(ctx, conn, false)
accept, err := manager.ProcessOnGet(ctx, conn, false)
if err != nil {
t.Errorf("ProcessOnGet should not error: %v", err)
}
if !accept {
t.Error("Expected accept to be true")
}
if hook1.OnGetCalled != 1 {
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
@@ -99,11 +108,12 @@ func TestHookErrorHandling(t *testing.T) {
// Hook that returns error on Get
errorHook := &TestHook{
GetError: errors.New("test error"),
ShouldPool: true,
GetError: errors.New("test error"),
ShouldPool: true,
ShouldAccept: true,
}
normalHook := &TestHook{ShouldPool: true}
normalHook := &TestHook{ShouldPool: true, ShouldAccept: true}
manager.AddHook(errorHook)
manager.AddHook(normalHook)
@@ -112,10 +122,13 @@ func TestHookErrorHandling(t *testing.T) {
conn := &Conn{}
// Test that error stops processing
err := manager.ProcessOnGet(ctx, conn, false)
accept, err := manager.ProcessOnGet(ctx, conn, false)
if err == nil {
t.Error("Expected error from ProcessOnGet")
}
if accept {
t.Error("Expected accept to be false")
}
if errorHook.OnGetCalled != 1 {
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
@@ -134,9 +147,10 @@ func TestHookShouldRemove(t *testing.T) {
removeHook := &TestHook{
ShouldPool: false,
ShouldRemove: true,
ShouldAccept: true,
}
normalHook := &TestHook{ShouldPool: true}
normalHook := &TestHook{ShouldPool: true, ShouldAccept: true}
manager.AddHook(removeHook)
manager.AddHook(normalHook)
@@ -170,7 +184,7 @@ func TestHookShouldRemove(t *testing.T) {
func TestPoolWithHooks(t *testing.T) {
// Create a pool with hooks
hookManager := NewPoolHookManager()
testHook := &TestHook{ShouldPool: true}
testHook := &TestHook{ShouldPool: true, ShouldAccept: true}
hookManager.AddHook(testHook)
opt := &Options{
@@ -197,7 +211,7 @@ func TestPoolWithHooks(t *testing.T) {
}
// Test adding hook to pool
additionalHook := &TestHook{ShouldPool: true}
additionalHook := &TestHook{ShouldPool: true, ShouldAccept: true}
pool.AddPoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 2 {

View File

@@ -24,6 +24,9 @@ var (
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
ErrPoolTimeout = errors.New("redis: connection pool timeout")
// ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable.
ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable")
// popAttempts is the maximum number of attempts to find a usable connection
// when popping from the idle connection pool. This handles cases where connections
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).
@@ -78,6 +81,10 @@ type Pooler interface {
IdleLen() int
Stats() *Stats
// Size returns the maximum pool size (capacity).
// This is used by the streaming credentials manager to size the re-auth worker pool.
Size() int
AddPoolHook(hook PoolHook)
RemovePoolHook(hook PoolHook)
@@ -236,6 +243,7 @@ func (p *ConnPool) addIdleConn() error {
if err != nil {
return err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
@@ -277,6 +285,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
if err != nil {
return nil, err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
@@ -428,6 +437,13 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
now := time.Now()
attempts := 0
// Get hooks manager once for this getConn call for performance.
// Note: Hooks added/removed during this call won't be reflected.
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
for {
if attempts >= getAttempts {
internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts)
@@ -454,17 +470,19 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false)
if err != nil {
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
// Failed to process connection, discard it
_ = p.CloseConn(cn)
continue
}
if !acceptConn {
internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID())
p.Put(ctx, cn)
cn = nil
continue
}
}
atomic.AddUint32(&p.stats.Hits, 1)
@@ -480,14 +498,13 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
acceptConn, err := hookManager.ProcessOnGet(ctx, newcn, true)
// both errors and accept=false mean a hook rejected the connection
// this should not happen with a new connection, but we handle it gracefully
if err != nil || !acceptConn {
// Failed to process connection, discard it
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err)
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err)
_ = p.CloseConn(newcn)
return nil, err
}
@@ -567,9 +584,12 @@ func (p *ConnPool) popIdle() (*Conn, error) {
}
attempts++
if cn.IsUsable() {
p.idleConnsLen.Add(-1)
break
if cn.CompareAndSwapUsed(false, true) {
if cn.IsUsable() {
p.idleConnsLen.Add(-1)
break
}
cn.SetUsed(false)
}
// Connection is not usable, put it back in the pool
@@ -664,6 +684,11 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
shouldCloseConn = true
}
// if the connection is not going to be closed, mark it as not used
if !shouldCloseConn {
cn.SetUsed(false)
}
p.freeTurn()
if shouldCloseConn {
@@ -671,7 +696,15 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
}
}
func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
hookManager.ProcessOnRemove(ctx, cn, reason)
}
p.removeConnWithLock(cn)
p.freeTurn()
@@ -733,6 +766,14 @@ func (p *ConnPool) IdleLen() int {
return int(n)
}
// Size returns the maximum pool size (capacity).
//
// This is used by the streaming credentials manager to size the re-auth worker pool,
// ensuring that re-auth operations don't exhaust the connection pool.
func (p *ConnPool) Size() int {
return int(p.cfg.PoolSize)
}
func (p *ConnPool) Stats() *Stats {
return &Stats{
Hits: atomic.LoadUint32(&p.stats.Hits),

View File

@@ -2,8 +2,12 @@ package pool
import (
"context"
"time"
)
// SingleConnPool is a pool that always returns the same connection.
// Note: This pool is not thread-safe.
// It is intended to be used by clients that need a single connection.
type SingleConnPool struct {
pool Pooler
cn *Conn
@@ -12,6 +16,12 @@ type SingleConnPool struct {
var _ Pooler = (*SingleConnPool)(nil)
// NewSingleConnPool creates a new single connection pool.
// The pool will always return the same connection.
// The pool will not:
// - Close the connection
// - Reconnect the connection
// - Track the connection in any way
func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool {
return &SingleConnPool{
pool: pool,
@@ -27,16 +37,30 @@ func (p *SingleConnPool) CloseConn(cn *Conn) error {
return p.pool.CloseConn(cn)
}
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) {
if p.stickyErr != nil {
return nil, p.stickyErr
}
if p.cn == nil {
return nil, ErrClosed
}
p.cn.SetUsed(true)
p.cn.SetUsedAt(time.Now())
return p.cn, nil
}
func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {}
func (p *SingleConnPool) Put(_ context.Context, cn *Conn) {
if p.cn == nil {
return
}
if p.cn != cn {
return
}
p.cn.SetUsed(false)
}
func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
func (p *SingleConnPool) Remove(_ context.Context, cn *Conn, reason error) {
cn.SetUsed(false)
p.cn = nil
p.stickyErr = reason
}
@@ -55,10 +79,13 @@ func (p *SingleConnPool) IdleLen() int {
return 0
}
// Size returns the maximum pool size, which is always 1 for SingleConnPool.
func (p *SingleConnPool) Size() int { return 1 }
func (p *SingleConnPool) Stats() *Stats {
return &Stats{}
}
func (p *SingleConnPool) AddPoolHook(hook PoolHook) {}
func (p *SingleConnPool) AddPoolHook(_ PoolHook) {}
func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {}
func (p *SingleConnPool) RemovePoolHook(_ PoolHook) {}

View File

@@ -196,6 +196,9 @@ func (p *StickyConnPool) IdleLen() int {
return len(p.ch)
}
// Size returns the maximum pool size, which is always 1 for StickyConnPool.
func (p *StickyConnPool) Size() int { return 1 }
func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}

View File

@@ -497,9 +497,14 @@ func TestDialerRetryConfiguration(t *testing.T) {
}
// Should have attempted 5 times (default DialerRetries = 5)
// Note: there may be one additional attempt from the tryDial() goroutine,
// which is launched when dialErrorsNum reaches PoolSize.
finalAttempts := atomic.LoadInt64(&attempts)
if finalAttempts != 5 {
t.Errorf("Expected 5 dial attempts (default), got %d", finalAttempts)
if finalAttempts < 5 {
t.Errorf("Expected at least 5 dial attempts (default), got %d", finalAttempts)
}
if finalAttempts > 6 {
t.Errorf("Expected around 5 dial attempts, got %d (too many)", finalAttempts)
}
})
}

View File

@@ -24,6 +24,8 @@ type PubSubPool struct {
stats PubSubStats
}
// PubSubPool implements a pool for PubSub connections.
// It intentionally does not implement the Pooler interface
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
return &PubSubPool{
opt: opt,

View File

@@ -378,8 +378,12 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c
}
// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) {
func (hwm *handoffWorkerManager) performHandoffInternal(
ctx context.Context,
conn *pool.Conn,
newEndpoint string,
connID uint64,
) (shouldRetry bool, err error) {
retries := conn.IncrementAndGetHandoffRetries(1)
internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String()))
maxRetries := 3 // Default fallback
@@ -438,9 +442,14 @@ func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, con
}
}()
// Clear handoff state will:
// - set the connection as usable again
// - clear the handoff state (shouldHandoff, endpoint, seqID)
// - reset the handoff retries to 0
conn.ClearHandoffState()
internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint))
// successfully completed the handoff, no retry needed and no error
return false, nil
}
@@ -472,7 +481,10 @@ func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, reque
internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err))
}
} else {
conn.Close()
err := conn.Close() // Close the connection if no pool provided
if err != nil {
internal.Logger.Printf(ctx, "redis: failed to close connection: %v", err)
}
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err))
}

View File

@@ -116,22 +116,22 @@ func (ph *PoolHook) ResetCircuitBreakers() {
}
// OnGet is called when a connection is retrieved from the pool
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error {
func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
// NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is
// in a handoff state at the moment.
// Check if connection is usable (not in a handoff state)
// Should not happen since the pool will not return a connection that is not usable.
if !conn.IsUsable() {
return ErrConnectionMarkedForHandoff
return false, ErrConnectionMarkedForHandoff
}
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
if conn.ShouldHandoff() {
return ErrConnectionMarkedForHandoff
return false, ErrConnectionMarkedForHandoff
}
return nil
return true, nil
}
// OnPut is called when a connection is returned to the pool
@@ -174,6 +174,10 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool
return true, false, nil
}
func (ph *PoolHook) OnRemove(_ context.Context, _ *pool.Conn, _ error) {
// Not used
}
// Shutdown gracefully shuts down the processor, waiting for workers to complete
func (ph *PoolHook) Shutdown(ctx context.Context) error {
return ph.workerManager.shutdownWorkers(ctx)

View File

@@ -92,6 +92,10 @@ func (mp *mockPool) Stats() *pool.Stats {
return &pool.Stats{}
}
func (mp *mockPool) Size() int {
return 0
}
func (mp *mockPool) AddPoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
@@ -356,10 +360,13 @@ func TestConnectionHook(t *testing.T) {
conn := createMockPoolConnection()
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
acceptCon, err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should not error for normal connection: %v", err)
}
if !acceptCon {
t.Error("Connection should be accepted for normal connection")
}
})
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
@@ -381,10 +388,13 @@ func TestConnectionHook(t *testing.T) {
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
acceptCon, err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
if acceptCon {
t.Error("Connection should not be accepted when marked for handoff")
}
// Clean up
processor.GetPendingMap().Delete(conn)
@@ -412,10 +422,13 @@ func TestConnectionHook(t *testing.T) {
// Test OnGet with pending handoff
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
acceptCon, err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
}
if acceptCon {
t.Error("Should not accept connection with pending handoff")
}
// Test removing from pending map and clearing handoff state
processor.GetPendingMap().Delete(conn)
@@ -428,10 +441,13 @@ func TestConnectionHook(t *testing.T) {
conn.SetUsable(true) // Make connection usable again
// Test OnGet without pending handoff
err = processor.OnGet(ctx, conn, false)
acceptCon, err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("Should not return error for non-pending connection: %v", err)
}
if !acceptCon {
t.Error("Should accept connection without pending handoff")
}
})
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
@@ -624,11 +640,15 @@ func TestConnectionHook(t *testing.T) {
}
// OnGet should succeed for usable connection
err := processor.OnGet(ctx, conn, false)
acceptConn, err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed for usable connection: %v", err)
}
if !acceptConn {
t.Error("Connection should be accepted when usable")
}
// Mark connection for handoff
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
@@ -648,13 +668,17 @@ func TestConnectionHook(t *testing.T) {
}
// OnGet should fail for connection marked for handoff
err = processor.OnGet(ctx, conn, false)
acceptConn, err = processor.OnGet(ctx, conn, false)
if err == nil {
t.Error("OnGet should fail for connection marked for handoff")
}
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
if acceptConn {
t.Error("Connection should not be accepted when marked for handoff")
}
// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
@@ -674,11 +698,15 @@ func TestConnectionHook(t *testing.T) {
}
// OnGet should succeed again
err = processor.OnGet(ctx, conn, false)
acceptConn, err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed after handoff completion: %v", err)
}
if !acceptConn {
t.Error("Connection should be accepted after handoff completion")
}
t.Logf("Usable flag behavior test completed successfully")
})

View File

@@ -465,7 +465,6 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
}
// Don't hold the lock to allow subscriptions and pings.
cn, err := c.connWithLock(ctx)
if err != nil {
return nil, err

View File

@@ -11,6 +11,7 @@ import (
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/auth/streaming"
"github.com/redis/go-redis/v9/internal/hscan"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
@@ -224,6 +225,9 @@ type baseClient struct {
// Maintenance notifications manager
maintNotificationsManager *maintnotifications.Manager
maintNotificationsManagerLock sync.RWMutex
// streamingCredentialsManager is used to manage streaming credentials
streamingCredentialsManager *streaming.Manager
}
func (c *baseClient) clone() *baseClient {
@@ -232,11 +236,12 @@ func (c *baseClient) clone() *baseClient {
c.maintNotificationsManagerLock.RUnlock()
clone := &baseClient{
opt: c.opt,
connPool: c.connPool,
onClose: c.onClose,
pushProcessor: c.pushProcessor,
maintNotificationsManager: maintNotificationsManager,
opt: c.opt,
connPool: c.connPool,
onClose: c.onClose,
pushProcessor: c.pushProcessor,
maintNotificationsManager: maintNotificationsManager,
streamingCredentialsManager: c.streamingCredentialsManager,
}
return clone
}
@@ -296,32 +301,30 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
return auth.NewReAuthCredentialsListener(
c.reAuthConnection(poolCn),
c.onAuthenticationErr(poolCn),
)
}
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
return func(credentials auth.Credentials) error {
func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error {
return func(poolCn *pool.Conn, credentials auth.Credentials) error {
var err error
username, password := credentials.BasicAuth()
// Use background context - timeout is handled by ReadTimeout in WithReader/WithWriter
ctx := context.Background()
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
// hooksMixin are intentionally empty here
cn := newConn(c.opt, connPool, nil)
// Pass hooks so that reauth commands are recorded/traced
cn := newConn(c.opt, connPool, &c.hooksMixin)
if username != "" {
err = cn.AuthACL(ctx, username, password).Err()
} else {
err = cn.Auth(ctx, password).Err()
}
return err
}
}
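// Note on the refactor above (editorial sketch, not part of the commit): the *pool.Conn
// is now supplied at call time rather than captured when the listener is created, so a
// single reAuth/onErr pair can serve every pooled connection managed by the streaming
// credentials manager. Conceptually, when new credentials arrive for a subscribed
// connection cn, the listener ends up doing something like (names illustrative):
//
//	reAuth := c.reAuthConnection()
//	onErr := c.onAuthenticationErr()
//	if err := reAuth(cn, creds); err != nil {
//		onErr(cn, err)
//	}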
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
return func(err error) {
func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) {
return func(poolCn *pool.Conn, err error) {
if err != nil {
if isBadConn(err, false, c.opt.Addr) {
// Close the connection to force a reconnection.
@@ -372,13 +375,24 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
username, password := "", ""
if c.opt.StreamingCredentialsProvider != nil {
credListener, err := c.streamingCredentialsManager.Listener(
cn,
c.reAuthConnection(),
c.onAuthenticationErr(),
)
if err != nil {
return fmt.Errorf("failed to create credentials listener: %w", err)
}
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
Subscribe(c.newReAuthCredentialsListener(cn))
Subscribe(credListener)
if err != nil {
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
}
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
cn.SetOnClose(unsubscribeFromCredentialsProvider)
username, password = credentials.BasicAuth()
} else if c.opt.CredentialsProviderContext != nil {
username, password, err = c.opt.CredentialsProviderContext(ctx)
@@ -496,7 +510,10 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
}
}
// mark the connection as usable and inited
// once returned to the pool as idle, this connection can be used by other clients
cn.SetUsable(true)
cn.SetUsed(false)
cn.Inited.Store(true)
// Set the connection initialization function for potential reconnections
@@ -952,6 +969,11 @@ func NewClient(opt *Options) *Client {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}
if opt.StreamingCredentialsProvider != nil {
c.streamingCredentialsManager = streaming.NewManager(c.connPool, c.opt.PoolTimeout)
c.connPool.AddPoolHook(c.streamingCredentialsManager.PoolHook())
}
// Initialize maintnotifications first if enabled and protocol is RESP3
if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 {
err := c.enableMaintNotificationsUpgrades()
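Taken together with the NewClient wiring above, an application opts into streaming re-authentication simply by supplying a StreamingCredentialsProvider; the pool hook and per-connection listeners are attached internally. A minimal sketch follows, assuming the auth package's public surface (StreamingCredentialsProvider, CredentialsListener, UnsubscribeFunc, NewBasicCredentials) behaves as the test mock further down suggests; the provider type and its Rotate helper are hypothetical stand-ins for a real token source.

package main

import (
	"context"
	"fmt"
	"sync"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/auth"
)

// rotatingProvider is a hypothetical provider that pushes rotated credentials to
// every subscribed listener, similar in spirit to the mockStreamingProvider below.
type rotatingProvider struct {
	mu        sync.Mutex
	current   auth.Credentials
	listeners []auth.CredentialsListener
}

func (p *rotatingProvider) Subscribe(l auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.listeners = append(p.listeners, l)
	// Unsubscribe is a no-op here; a real provider would remove the listener.
	return p.current, func() error { return nil }, nil
}

// Rotate publishes new credentials; go-redis re-authenticates pooled connections
// asynchronously via the listeners registered in Subscribe.
func (p *rotatingProvider) Rotate(creds auth.Credentials) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.current = creds
	for _, l := range p.listeners {
		l.OnNext(creds)
	}
}

func main() {
	provider := &rotatingProvider{current: auth.NewBasicCredentials("app", "initial-password")}
	rdb := redis.NewClient(&redis.Options{
		Addr:                         "localhost:6379",
		StreamingCredentialsProvider: provider,
	})
	defer rdb.Close()

	fmt.Println(rdb.Ping(context.Background()).Err())
	provider.Rotate(auth.NewBasicCredentials("app", "rotated-password"))
}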

View File

@@ -854,24 +854,34 @@ var _ = Describe("Credentials Provider Priority", func() {
credentials: initialCreds,
updates: updatesChan,
},
PoolSize: 1, // Force single connection to ensure reauth is tested
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
time.Sleep(10 * time.Millisecond)
Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
// Update credentials
opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())
// Wait for reauth to complete and verify updated credentials are used
// We need to keep trying Ping until we see the updated AUTH command
// because the reauth happens asynchronously
Eventually(func() bool {
// wrongpass
_ = client.Ping(context.Background()).Err()
return recorder.Contains("AUTH updated_user")
}, "1s", "50ms").Should(BeTrue())
close(updatesChan)
})
})
type mockStreamingProvider struct {
mu sync.RWMutex
credentials auth.Credentials
err error
updates chan auth.Credentials
@@ -882,21 +892,50 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au
return nil, nil, m.err
}
if listener == nil {
return nil, nil, errors.New("listener cannot be nil")
}
// Create a done channel to stop the goroutine
done := make(chan struct{})
// Start goroutine to handle updates
go func() {
for creds := range m.updates {
m.credentials = creds
listener.OnNext(creds)
defer func() {
if r := recover(); r != nil {
// this is just a mock: recover from listener panics
// so the test process doesn't crash
}
}()
for {
select {
case <-done:
return
case creds, ok := <-m.updates:
if !ok {
return
}
m.mu.Lock()
m.credentials = creds
m.mu.Unlock()
listener.OnNext(creds)
}
}
}()
return m.credentials, func() (err error) {
m.mu.RLock()
currentCreds := m.credentials
m.mu.RUnlock()
return currentCreds, func() (err error) {
defer func() {
if r := recover(); r != nil {
// this is just a mock: tolerate the unsubscribe func
// being called more than once (from multiple listeners)
}
}()
close(done)
return
}, nil
}

View File

@@ -410,7 +410,9 @@ var _ = Describe("SentinelAclAuth", func() {
})
})
func TestParseFailoverURL(t *testing.T) {
// renamed from TestParseFailoverURL to TestParseSentinelURL
// so that failed tests are easier to locate in the test output
func TestParseSentinelURL(t *testing.T) {
cases := []struct {
url string
o *redis.FailoverOptions