mirror of
https://github.com/redis/go-redis.git
synced 2025-10-27 18:15:32 +03:00
fix(pool): Pool ReAuth should not interfere with handoff (#3547)
* fix(pool): wip, pool reauth should not interfere with handoff
* fix credListeners map
* fix race in tests
* better conn usable timeout
* add design decision comment
* few small improvements
* update marked as queued
* add Used to clarify the state of the conn
* rename test
* fix(test): fix flaky test
* lock inside the listeners collection
* address pr comments
* Update internal/auth/cred_listeners.go (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update internal/pool/buffer_size_test.go (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* wip refactor entraid
* fix maintnotif pool hook
* fix mocks
* fix nil listener
* sync and async reauth based on conn lifecycle
* be able to reject connection OnGet
* pass hooks so the tests can observe reauth
* give some time for the background to execute commands
* fix tests
* only async reauth
* Update internal/pool/pool.go (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update internal/auth/streaming/pool_hook.go (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update internal/pool/conn.go (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* chore(redisotel): use metric.WithAttributeSet to avoid copy (#3552)
  In order to improve performance replace `WithAttributes` with `WithAttributeSet`. This avoids the slice allocation and copy that is done in `WithAttributes`. For more information see https://github.com/open-telemetry/opentelemetry-go/blob/v1.38.0/metric/instrument.go#L357-L376
* chore(docs): explain why MaxRetries is disabled for ClusterClient (#3551) (Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com>)
* exponential backoff
* address pr comments
* address pr comments
* remove rlock
* add some comments
* add comments

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Warnar Boekkooi <wboekkooi@impossiblecloud.com>
Co-authored-by: Justin <justindsouza80@gmail.com>
@@ -44,4 +44,4 @@ func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, on
 }

 // Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
 var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
10	error.go
@@ -108,10 +108,12 @@ func isRedisError(err error) bool {

 func isBadConn(err error, allowTimeout bool, addr string) bool {
 	switch err {
 	case nil:
 		return false
 	case context.Canceled, context.DeadlineExceeded:
 		return true
+	case pool.ErrConnUnusableTimeout:
+		return true
 	}

 	if isRedisError(err) {
100	internal/auth/streaming/conn_reauth_credentials_listener.go (new file)
@@ -0,0 +1,100 @@
package streaming

import (
	"github.com/redis/go-redis/v9/auth"
	"github.com/redis/go-redis/v9/internal/pool"
)

// ConnReAuthCredentialsListener is a credentials listener for a specific connection
// that triggers re-authentication when credentials change.
//
// This listener implements the auth.CredentialsListener interface and is subscribed
// to a StreamingCredentialsProvider. When new credentials are received via OnNext,
// it marks the connection for re-authentication through the manager.
//
// The re-authentication is always performed asynchronously to avoid blocking the
// credentials provider and to prevent potential deadlocks with the pool semaphore.
// The actual re-auth happens when the connection is returned to the pool in an idle state.
//
// Lifecycle:
//   - Created during connection initialization via Manager.Listener()
//   - Subscribed to the StreamingCredentialsProvider
//   - Receives credential updates via OnNext()
//   - Cleaned up when connection is removed from pool via Manager.RemoveListener()
type ConnReAuthCredentialsListener struct {
	// reAuth is the function to re-authenticate the connection with new credentials
	reAuth func(conn *pool.Conn, credentials auth.Credentials) error

	// onErr is the function to call when re-authentication or acquisition fails
	onErr func(conn *pool.Conn, err error)

	// conn is the connection this listener is associated with
	conn *pool.Conn

	// manager is the streaming credentials manager for coordinating re-auth
	manager *Manager
}

// OnNext is called when new credentials are received from the StreamingCredentialsProvider.
//
// This method marks the connection for asynchronous re-authentication. The actual
// re-authentication happens in the background when the connection is returned to the
// pool and is in an idle state.
//
// Asynchronous re-auth is used to:
//   - Avoid blocking the credentials provider's notification goroutine
//   - Prevent deadlocks with the pool's semaphore (especially with small pool sizes)
//   - Ensure re-auth happens when the connection is safe to use (not processing commands)
//
// The reAuthFn callback receives:
//   - nil if the connection was successfully acquired for re-auth
//   - error if acquisition timed out or failed
//
// Thread-safe: Called by the credentials provider's notification goroutine.
func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
	if c.conn == nil || c.conn.IsClosed() || c.manager == nil || c.reAuth == nil {
		return
	}

	// Always use async reauth to avoid complex pool semaphore issues.
	// The synchronous path can cause deadlocks in the pool's semaphore mechanism
	// when called from the Subscribe goroutine, especially with small pool sizes.
	// The connection pool hook will re-authenticate the connection when it is
	// returned to the pool in a clean, idle state.
	c.manager.MarkForReAuth(c.conn, func(err error) {
		// err is from connection acquisition (timeout, etc.)
		if err != nil {
			// Log the error
			c.OnError(err)
			return
		}
		// err is from reauth command execution
		err = c.reAuth(c.conn, credentials)
		if err != nil {
			// Log the error
			c.OnError(err)
			return
		}
	})
}

// OnError is called when an error occurs during credential streaming or re-authentication.
//
// This method can be called from:
//   - The StreamingCredentialsProvider when there's an error in the credentials stream
//   - The re-auth process when connection acquisition times out
//   - The re-auth process when the AUTH command fails
//
// The error is delegated to the onErr callback provided during listener creation.
//
// Thread-safe: Can be called from multiple goroutines (provider, re-auth worker).
func (c *ConnReAuthCredentialsListener) OnError(err error) {
	if c.onErr == nil {
		return
	}

	c.onErr(c.conn, err)
}

// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
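For orientation, the auth.CredentialsListener contract that this file implements consists of just the two methods shown above, OnNext and OnError. A minimal, hypothetical listener that only logs credential updates (assuming nothing beyond that two-method contract and the standard library log package; not part of this commit) could look like this:

package example // hypothetical

import (
	"log"

	"github.com/redis/go-redis/v9/auth"
)

type loggingCredentialsListener struct{}

func (l *loggingCredentialsListener) OnNext(credentials auth.Credentials) {
	// a real listener would trigger re-authentication here; this one only records the event
	log.Println("received new credentials")
}

func (l *loggingCredentialsListener) OnError(err error) {
	log.Printf("credentials stream error: %v", err)
}

var _ auth.CredentialsListener = (*loggingCredentialsListener)(nil)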
77	internal/auth/streaming/cred_listeners.go (new file)
@@ -0,0 +1,77 @@
package streaming

import (
	"sync"

	"github.com/redis/go-redis/v9/auth"
)

// CredentialsListeners is a thread-safe collection of credentials listeners
// indexed by connection ID.
//
// This collection is used by the Manager to maintain a registry of listeners
// for each connection in the pool. Listeners are reused when connections are
// reinitialized (e.g., after a handoff) to avoid creating duplicate subscriptions
// to the StreamingCredentialsProvider.
//
// The collection supports concurrent access from multiple goroutines during
// connection initialization, credential updates, and connection removal.
type CredentialsListeners struct {
	// listeners maps connection ID to credentials listener
	listeners map[uint64]auth.CredentialsListener

	// lock protects concurrent access to the listeners map
	lock sync.RWMutex
}

// NewCredentialsListeners creates a new thread-safe credentials listeners collection.
func NewCredentialsListeners() *CredentialsListeners {
	return &CredentialsListeners{
		listeners: make(map[uint64]auth.CredentialsListener),
	}
}

// Add adds or updates a credentials listener for a connection.
//
// If a listener already exists for the connection ID, it is replaced.
// This is safe because the old listener should have been unsubscribed
// before the connection was reinitialized.
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (c *CredentialsListeners) Add(connID uint64, listener auth.CredentialsListener) {
	c.lock.Lock()
	defer c.lock.Unlock()
	if c.listeners == nil {
		c.listeners = make(map[uint64]auth.CredentialsListener)
	}
	c.listeners[connID] = listener
}

// Get retrieves the credentials listener for a connection.
//
// Returns:
//   - listener: The credentials listener for the connection, or nil if not found
//   - ok: true if a listener exists for the connection ID, false otherwise
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (c *CredentialsListeners) Get(connID uint64) (auth.CredentialsListener, bool) {
	c.lock.RLock()
	defer c.lock.RUnlock()
	if len(c.listeners) == 0 {
		return nil, false
	}
	listener, ok := c.listeners[connID]
	return listener, ok
}

// Remove removes the credentials listener for a connection.
//
// This is called when a connection is removed from the pool to prevent
// memory leaks. If no listener exists for the connection ID, this is a no-op.
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (c *CredentialsListeners) Remove(connID uint64) {
	c.lock.Lock()
	defer c.lock.Unlock()
	delete(c.listeners, connID)
}
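A quick, hypothetical usage sketch of this collection (test-style code in the same streaming package, with 42 as an arbitrary connection ID; not part of this commit):

func ExampleCredentialsListeners_usage() { // hypothetical
	listeners := NewCredentialsListeners()

	// register a listener under a connection ID
	listeners.Add(42, &ConnReAuthCredentialsListener{})

	// later lookups reuse the same listener instead of subscribing a new one
	if l, ok := listeners.Get(42); ok {
		_ = l
	}

	// drop the entry when the connection is removed from the pool
	listeners.Remove(42)
}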
137	internal/auth/streaming/manager.go (new file)
@@ -0,0 +1,137 @@
package streaming

import (
	"errors"
	"time"

	"github.com/redis/go-redis/v9/auth"
	"github.com/redis/go-redis/v9/internal/pool"
)

// Manager coordinates streaming credentials and re-authentication for a connection pool.
//
// The manager is responsible for:
//   - Creating and managing per-connection credentials listeners
//   - Providing the pool hook for re-authentication
//   - Coordinating between credentials updates and pool operations
//
// When credentials change via a StreamingCredentialsProvider:
//  1. The credentials listener (ConnReAuthCredentialsListener) receives the update
//  2. It calls MarkForReAuth on the manager
//  3. The manager delegates to the pool hook
//  4. The pool hook schedules background re-authentication
//
// The manager maintains a registry of credentials listeners indexed by connection ID,
// allowing listener reuse when connections are reinitialized (e.g., after handoff).
type Manager struct {
	// credentialsListeners maps connection ID to credentials listener
	credentialsListeners *CredentialsListeners

	// pool is the connection pool being managed
	pool pool.Pooler

	// poolHookRef is the re-authentication pool hook
	poolHookRef *ReAuthPoolHook
}

// NewManager creates a new streaming credentials manager.
//
// Parameters:
//   - pl: The connection pool to manage
//   - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication
//
// The manager creates a ReAuthPoolHook sized to match the pool size, ensuring that
// re-auth operations don't exhaust the connection pool.
func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager {
	m := &Manager{
		pool:                 pl,
		poolHookRef:          NewReAuthPoolHook(pl.Size(), reAuthTimeout),
		credentialsListeners: NewCredentialsListeners(),
	}
	m.poolHookRef.manager = m
	return m
}

// PoolHook returns the pool hook for re-authentication.
//
// This hook should be registered with the connection pool to enable
// automatic re-authentication when credentials change.
func (m *Manager) PoolHook() pool.PoolHook {
	return m.poolHookRef
}

// Listener returns or creates a credentials listener for a connection.
//
// This method is called during connection initialization to set up the
// credentials listener. If a listener already exists for the connection ID
// (e.g., after a handoff), it is reused.
//
// Parameters:
//   - poolCn: The connection to create/get a listener for
//   - reAuth: Function to re-authenticate the connection with new credentials
//   - onErr: Function to call when re-authentication fails
//
// Returns:
//   - auth.CredentialsListener: The listener to subscribe to the credentials provider
//   - error: Non-nil if poolCn is nil
//
// Note: The reAuth and onErr callbacks are captured once when the listener is
// created and reused for the connection's lifetime. They should not change.
//
// Thread-safe: Can be called concurrently during connection initialization.
func (m *Manager) Listener(
	poolCn *pool.Conn,
	reAuth func(*pool.Conn, auth.Credentials) error,
	onErr func(*pool.Conn, error),
) (auth.CredentialsListener, error) {
	if poolCn == nil {
		return nil, errors.New("poolCn cannot be nil")
	}
	connID := poolCn.GetID()
	// if we reconnect the underlying network connection, the streaming credentials listener will continue to work
	// so we can get the old listener from the cache and use it.
	// subscribing the same (an already subscribed) listener for a StreamingCredentialsProvider SHOULD be a no-op
	listener, ok := m.credentialsListeners.Get(connID)
	if !ok || listener == nil {
		// Create new listener for this connection
		// Note: Callbacks (reAuth, onErr) are captured once and reused for the connection's lifetime
		newCredListener := &ConnReAuthCredentialsListener{
			conn:    poolCn,
			reAuth:  reAuth,
			onErr:   onErr,
			manager: m,
		}

		m.credentialsListeners.Add(connID, newCredListener)
		listener = newCredListener
	}
	return listener, nil
}

// MarkForReAuth marks a connection for re-authentication.
//
// This method is called by the credentials listener when new credentials are
// received. It delegates to the pool hook to schedule background re-authentication.
//
// Parameters:
//   - poolCn: The connection to re-authenticate
//   - reAuthFn: Function to call for re-authentication, receives error if acquisition fails
//
// Thread-safe: Called by credentials listeners when credentials change.
func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
	connID := poolCn.GetID()
	m.poolHookRef.MarkForReAuth(connID, reAuthFn)
}

// RemoveListener removes the credentials listener for a connection.
//
// This method is called by the pool hook's OnRemove to clean up listeners
// when connections are removed from the pool.
//
// Parameters:
//   - connID: The connection ID whose listener should be removed
//
// Thread-safe: Called during connection removal.
func (m *Manager) RemoveListener(connID uint64) {
	m.credentialsListeners.Remove(connID)
}
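Putting the manager, hook, and listener together: the sketch below (hypothetical helper setupStreamingReAuth, example 5-second timeout) shows the wiring order implied by the comments above — create the manager from the pool, register its hook, then hand the per-connection listener to the StreamingCredentialsProvider subscription. It uses internal packages, so it only applies to code inside go-redis itself; it is not code from this commit.

package example // hypothetical; compiles only inside the go-redis module

import (
	"context"
	"time"

	"github.com/redis/go-redis/v9/auth"
	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/auth/streaming"
	"github.com/redis/go-redis/v9/internal/pool"
)

func setupStreamingReAuth(pl pool.Pooler, cn *pool.Conn) (auth.CredentialsListener, error) {
	mgr := streaming.NewManager(pl, 5*time.Second) // example re-auth acquisition timeout
	pl.AddPoolHook(mgr.PoolHook())                 // enables the OnGet/OnPut/OnRemove re-auth handling

	reAuth := func(c *pool.Conn, creds auth.Credentials) error {
		// send AUTH on c using creds here (omitted in this sketch)
		return nil
	}
	onErr := func(c *pool.Conn, err error) {
		internal.Logger.Printf(context.Background(), "re-auth failed for conn[%d]: %v", c.GetID(), err)
	}

	// The returned listener is what gets subscribed to the StreamingCredentialsProvider.
	return mgr.Listener(cn, reAuth, onErr)
}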
101	internal/auth/streaming/manager_test.go (new file)
@@ -0,0 +1,101 @@
package streaming

import (
	"context"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/auth"
	"github.com/redis/go-redis/v9/internal/pool"
)

// Test that Listener returns the newly created listener, not nil
func TestManager_Listener_ReturnsNewListener(t *testing.T) {
	// Create a mock pool
	mockPool := &mockPooler{}

	// Create manager
	manager := NewManager(mockPool, time.Second)

	// Create a mock connection
	conn := &pool.Conn{}

	// Mock functions
	reAuth := func(cn *pool.Conn, creds auth.Credentials) error {
		return nil
	}

	onErr := func(cn *pool.Conn, err error) {
	}

	// Get listener - this should create a new one
	listener, err := manager.Listener(conn, reAuth, onErr)

	// Verify no error
	if err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}

	// Verify listener is not nil (this was the bug!)
	if listener == nil {
		t.Fatal("Expected listener to be non-nil, but got nil")
	}

	// Verify it's the correct type
	if _, ok := listener.(*ConnReAuthCredentialsListener); !ok {
		t.Fatalf("Expected listener to be *ConnReAuthCredentialsListener, got %T", listener)
	}

	// Get the same listener again - should return the existing one
	listener2, err := manager.Listener(conn, reAuth, onErr)
	if err != nil {
		t.Fatalf("Expected no error on second call, got: %v", err)
	}

	if listener2 == nil {
		t.Fatal("Expected listener2 to be non-nil")
	}

	// Should be the same instance
	if listener != listener2 {
		t.Error("Expected to get the same listener instance on second call")
	}
}

// Test that Listener returns error when conn is nil
func TestManager_Listener_NilConn(t *testing.T) {
	mockPool := &mockPooler{}
	manager := NewManager(mockPool, time.Second)

	listener, err := manager.Listener(nil, nil, nil)

	if err == nil {
		t.Fatal("Expected error when conn is nil, got nil")
	}

	if listener != nil {
		t.Error("Expected listener to be nil when error occurs")
	}

	expectedErr := "poolCn cannot be nil"
	if err.Error() != expectedErr {
		t.Errorf("Expected error message %q, got %q", expectedErr, err.Error())
	}
}

// Mock pooler for testing
type mockPooler struct{}

func (m *mockPooler) NewConn(ctx context.Context) (*pool.Conn, error)           { return nil, nil }
func (m *mockPooler) CloseConn(*pool.Conn) error                                { return nil }
func (m *mockPooler) Get(ctx context.Context) (*pool.Conn, error)               { return nil, nil }
func (m *mockPooler) Put(ctx context.Context, conn *pool.Conn)                  {}
func (m *mockPooler) Remove(ctx context.Context, conn *pool.Conn, reason error) {}
func (m *mockPooler) Len() int                                                  { return 0 }
func (m *mockPooler) IdleLen() int                                              { return 0 }
func (m *mockPooler) Stats() *pool.Stats                                        { return &pool.Stats{} }
func (m *mockPooler) Size() int                                                 { return 10 }
func (m *mockPooler) AddPoolHook(hook pool.PoolHook)                            {}
func (m *mockPooler) RemovePoolHook(hook pool.PoolHook)                         {}
func (m *mockPooler) Close() error                                              { return nil }
259	internal/auth/streaming/pool_hook.go (new file)
@@ -0,0 +1,259 @@
package streaming

import (
	"context"
	"sync"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/pool"
)

// ReAuthPoolHook is a pool hook that manages background re-authentication of connections
// when credentials change via a streaming credentials provider.
//
// The hook uses a semaphore-based worker pool to limit concurrent re-authentication
// operations and prevent pool exhaustion. When credentials change, connections are
// marked for re-authentication and processed asynchronously in the background.
//
// The re-authentication process:
//  1. OnPut: When a connection is returned to the pool, check if it needs re-auth
//  2. If yes, schedule it for background processing (move from shouldReAuth to scheduledReAuth)
//  3. A worker goroutine acquires the connection (waits until it's not in use)
//  4. Executes the re-auth function while holding the connection
//  5. Releases the connection back to the pool
//
// The hook ensures that:
//   - Only one re-auth operation runs per connection at a time
//   - Connections are not used for commands during re-authentication
//   - Re-auth operations timeout if they can't acquire the connection
//   - Resources are properly cleaned up on connection removal
type ReAuthPoolHook struct {
	// shouldReAuth maps connection ID to re-auth function
	// Connections in this map need re-authentication but haven't been scheduled yet
	shouldReAuth     map[uint64]func(error)
	shouldReAuthLock sync.RWMutex

	// workers is a semaphore channel limiting concurrent re-auth operations
	// Initialized with poolSize tokens to prevent pool exhaustion
	workers chan struct{}

	// reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth
	reAuthTimeout time.Duration

	// scheduledReAuth maps connection ID to scheduled status
	// Connections in this map have a background worker attempting re-authentication
	scheduledReAuth map[uint64]bool
	scheduledLock   sync.RWMutex

	// manager is a back-reference for cleanup operations
	manager *Manager
}

// NewReAuthPoolHook creates a new re-authentication pool hook.
//
// Parameters:
//   - poolSize: Maximum number of concurrent re-auth operations (typically matches pool size)
//   - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication
//
// The poolSize parameter is used to initialize the worker semaphore, ensuring that
// re-auth operations don't exhaust the connection pool.
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
	workers := make(chan struct{}, poolSize)
	// Initialize the workers channel with tokens (semaphore pattern)
	for i := 0; i < poolSize; i++ {
		workers <- struct{}{}
	}

	return &ReAuthPoolHook{
		shouldReAuth:    make(map[uint64]func(error)),
		scheduledReAuth: make(map[uint64]bool),
		workers:         workers,
		reAuthTimeout:   reAuthTimeout,
	}
}

// MarkForReAuth marks a connection for re-authentication.
//
// This method is called when credentials change and a connection needs to be
// re-authenticated. The actual re-authentication happens asynchronously when
// the connection is returned to the pool (in OnPut).
//
// Parameters:
//   - connID: The connection ID to mark for re-authentication
//   - reAuthFn: Function to call for re-authentication, receives error if acquisition fails
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
	r.shouldReAuthLock.Lock()
	defer r.shouldReAuthLock.Unlock()
	r.shouldReAuth[connID] = reAuthFn
}

// OnGet is called when a connection is retrieved from the pool.
//
// This hook checks if the connection needs re-authentication or has a scheduled
// re-auth operation. If so, it rejects the connection (returns accept=false),
// causing the pool to try another connection.
//
// Returns:
//   - accept: false if connection needs re-auth, true otherwise
//   - err: always nil (errors are not used in this hook)
//
// Thread-safe: Called concurrently by multiple goroutines getting connections.
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
	connID := conn.GetID()
	r.shouldReAuthLock.RLock()
	_, shouldReAuth := r.shouldReAuth[connID]
	r.shouldReAuthLock.RUnlock()
	// This connection was marked for reauth while in the pool,
	// reject the connection
	if shouldReAuth {
		// simply reject the connection, it will be re-authenticated in OnPut
		return false, nil
	}
	r.scheduledLock.RLock()
	_, hasScheduled := r.scheduledReAuth[connID]
	r.scheduledLock.RUnlock()
	// has scheduled reauth, reject the connection
	if hasScheduled {
		// simply reject the connection, it currently has a reauth scheduled
		// and the worker is waiting for slot to execute the reauth
		return false, nil
	}
	return true, nil
}

// OnPut is called when a connection is returned to the pool.
//
// This hook checks if the connection needs re-authentication. If so, it schedules
// a background goroutine to perform the re-auth asynchronously. The goroutine:
//  1. Waits for a worker slot (semaphore)
//  2. Acquires the connection (waits until not in use)
//  3. Executes the re-auth function
//  4. Releases the connection and worker slot
//
// The connection is always pooled (not removed) since re-auth happens in background.
//
// Returns:
//   - shouldPool: always true (connection stays in pool during background re-auth)
//   - shouldRemove: always false
//   - err: always nil
//
// Thread-safe: Called concurrently by multiple goroutines returning connections.
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
	if conn == nil {
		// noop
		return true, false, nil
	}
	connID := conn.GetID()
	// Check if reauth is needed and get the function with proper locking
	r.shouldReAuthLock.RLock()
	reAuthFn, ok := r.shouldReAuth[connID]
	r.shouldReAuthLock.RUnlock()

	if ok {
		// Acquire both locks to atomically move from shouldReAuth to scheduledReAuth
		// This prevents race conditions where OnGet might miss the transition
		r.shouldReAuthLock.Lock()
		r.scheduledLock.Lock()
		r.scheduledReAuth[connID] = true
		delete(r.shouldReAuth, connID)
		r.scheduledLock.Unlock()
		r.shouldReAuthLock.Unlock()
		go func() {
			<-r.workers
			// safety first
			if conn == nil || (conn != nil && conn.IsClosed()) {
				r.workers <- struct{}{}
				return
			}
			defer func() {
				if rec := recover(); rec != nil {
					// once again - safety first
					internal.Logger.Printf(context.Background(), "panic in reauth worker: %v", rec)
				}
				r.scheduledLock.Lock()
				delete(r.scheduledReAuth, connID)
				r.scheduledLock.Unlock()
				r.workers <- struct{}{}
			}()

			var err error
			timeout := time.After(r.reAuthTimeout)

			// Try to acquire the connection
			// We need to ensure the connection is both Usable and not Used
			// to prevent data races with concurrent operations
			const baseDelay = 10 * time.Microsecond
			acquired := false
			attempt := 0
			for !acquired {
				select {
				case <-timeout:
					// Timeout occurred, cannot acquire connection
					err = pool.ErrConnUnusableTimeout
					reAuthFn(err)
					return
				default:
					// Try to acquire: set Usable=false, then check Used
					if conn.CompareAndSwapUsable(true, false) {
						if !conn.IsUsed() {
							acquired = true
						} else {
							// Release Usable and retry with exponential backoff
							// todo(ndyakov): think of a better way to do this without the need
							// to release the connection, but just wait till it is not used
							conn.SetUsable(true)
						}
					}
					if !acquired {
						// Exponential backoff: 10, 20, 40, 80... up to 5120 microseconds
						delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
						time.Sleep(delay)
						attempt++
					}
				}
			}

			// safety first
			if !conn.IsClosed() {
				// Successfully acquired the connection, perform reauth
				reAuthFn(nil)
			}

			// Release the connection
			conn.SetUsable(true)
		}()
	}

	// the reauth will happen in background, as far as the pool is concerned:
	// pool the connection, don't remove it, no error
	return true, false, nil
}

// OnRemove is called when a connection is removed from the pool.
//
// This hook cleans up all state associated with the connection:
//   - Removes from shouldReAuth map (pending re-auth)
//   - Removes from scheduledReAuth map (active re-auth)
//   - Removes credentials listener from manager
//
// This prevents memory leaks and ensures that removed connections don't have
// lingering re-auth operations or listeners.
//
// Thread-safe: Called when connections are removed due to errors, timeouts, or pool closure.
func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) {
	connID := conn.GetID()
	r.shouldReAuthLock.Lock()
	r.scheduledLock.Lock()
	delete(r.scheduledReAuth, connID)
	delete(r.shouldReAuth, connID)
	r.scheduledLock.Unlock()
	r.shouldReAuthLock.Unlock()
	if r.manager != nil {
		r.manager.RemoveListener(connID)
	}
}

var _ pool.PoolHook = (*ReAuthPoolHook)(nil)
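The workers channel above is a counted semaphore: NewReAuthPoolHook pre-fills it with poolSize tokens, each background worker takes a token before re-authenticating and returns it in a deferred call. A standalone sketch of the same pattern (plain Go with hypothetical job names; not code from this commit):

package main

import (
	"fmt"
	"sync"
)

func main() {
	const maxWorkers = 3
	sem := make(chan struct{}, maxWorkers)
	for i := 0; i < maxWorkers; i++ {
		sem <- struct{}{} // pre-fill with tokens, as NewReAuthPoolHook does
	}

	var wg sync.WaitGroup
	for job := 0; job < 10; job++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			<-sem                                // acquire a token; blocks while all workers are busy
			defer func() { sem <- struct{}{} }() // return the token even if the work panics
			fmt.Println("background job", id)    // stand-in for the actual re-auth work
		}(job)
	}
	wg.Wait()
}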
@@ -3,7 +3,6 @@ package pool_test
 import (
 	"bufio"
 	"context"
-	"net"
 	"unsafe"

 	. "github.com/bsm/ginkgo/v2"
@@ -124,20 +123,26 @@ var _ = Describe("Buffer Size Configuration", func() {
 })

 // Helper functions to extract buffer sizes using unsafe pointers
+// The struct layout must match pool.Conn exactly to avoid checkptr violations.
+// checkptr is Go's pointer safety checker, which ensures that unsafe pointer
+// conversions are valid. If the struct layouts do not match exactly, this can
+// cause runtime panics or incorrect memory access due to invalid pointer dereferencing.
 func getWriterBufSizeUnsafe(cn *pool.Conn) int {
 	cnPtr := (*struct {
-		usedAt  int64
-		netConn net.Conn
-		rd      *proto.Reader
-		bw      *bufio.Writer
-		wr      *proto.Writer
-		// ... other fields
+		id            uint64      // First field in pool.Conn
+		usedAt        int64       // Second field (atomic)
+		netConnAtomic interface{} // atomic.Value (interface{} has same size)
+		rd            *proto.Reader
+		bw            *bufio.Writer
+		wr            *proto.Writer
+		// We only need fields up to bw, so we can stop here
 	})(unsafe.Pointer(cn))

 	if cnPtr.bw == nil {
 		return -1
 	}

 	// bufio.Writer internal structure
 	bwPtr := (*struct {
 		err error
 		buf []byte
@@ -150,18 +155,20 @@ func getWriterBufSizeUnsafe(cn *pool.Conn) int {

 func getReaderBufSizeUnsafe(cn *pool.Conn) int {
 	cnPtr := (*struct {
-		usedAt  int64
-		netConn net.Conn
-		rd      *proto.Reader
-		bw      *bufio.Writer
-		wr      *proto.Writer
-		// ... other fields
+		id            uint64      // First field in pool.Conn
+		usedAt        int64       // Second field (atomic)
+		netConnAtomic interface{} // atomic.Value (interface{} has same size)
+		rd            *proto.Reader
+		bw            *bufio.Writer
+		wr            *proto.Writer
+		// We only need fields up to rd, so we can stop here
 	})(unsafe.Pointer(cn))

 	if cnPtr.rd == nil {
 		return -1
 	}

 	// proto.Reader internal structure
 	rdPtr := (*struct {
 		rd *bufio.Reader
 	})(unsafe.Pointer(cnPtr.rd))
@@ -170,6 +177,7 @@ func getReaderBufSizeUnsafe(cn *pool.Conn) int {
 		return -1
 	}

+	// bufio.Reader internal structure
 	bufReaderPtr := (*struct {
 		buf []byte
 		rd  interface{}
@@ -40,6 +40,9 @@ func generateConnID() uint64 {
 }

 type Conn struct {
+	// Connection identifier for unique tracking
+	id uint64
+
 	usedAt int64 // atomic

 	// Lock-free netConn access using atomic.Value
@@ -54,7 +57,34 @@ type Conn struct {
 	// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
 	readerMu sync.RWMutex

-	Inited atomic.Bool
+	// Design note:
+	// Why have both Usable and Used?
+	// _Usable_ is used to mark a connection as safe for use by clients, the connection can still
+	// be in the pool but not Usable at the moment (e.g. handoff in progress).
+	// _Used_ is used to mark a connection as used when a command is going to be processed on that connection.
+	// this is going to happen once the connection is picked from the pool.
+	//
+	// If a background operation needs to use the connection, it will mark it as Not Usable and only use it when it
+	// is not in use. That way, the connection won't be used to send multiple commands at the same time and
+	// potentially corrupt the command stream.
+
+	// usable flag to mark connection as safe for use
+	// It is false before initialization and after a handoff is marked
+	// It will be false during other background operations like re-authentication
+	usable atomic.Bool
+
+	// used flag to mark connection as used when a command is going to be
+	// processed on that connection. This is used to prevent a race condition with
+	// background operations that may execute commands, like re-authentication.
+	used atomic.Bool
+
+	// Inited flag to mark connection as initialized, this is almost the same as usable
+	// but it is used to make sure we don't initialize a network connection twice
+	// On handoff, the network connection is replaced, but the Conn struct is reused
+	// this flag will be set to false when the network connection is replaced and
+	// set to true after the new network connection is initialized
+	Inited atomic.Bool

 	pooled bool
 	pubsub bool
 	closed atomic.Bool
@@ -75,11 +105,7 @@ type Conn struct {
 	// Connection initialization function for reconnections
 	initConnFunc func(context.Context, *Conn) error

-	// Connection identifier for unique tracking
-	id uint64 // Unique numeric identifier for this connection
-
 	// Handoff state - using atomic operations for lock-free access
-	usableAtomic         atomic.Bool   // Connection usability state
 	handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts

 	// Atomic handoff state to prevent race conditions
@@ -116,7 +142,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
 	cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})

 	// Initialize atomic state
-	cn.usableAtomic.Store(false) // false initially, set to true after initialization
+	cn.usable.Store(false) // false initially, set to true after initialization
 	cn.handoffRetriesAtomic.Store(0) // 0 initially

 	// Initialize handoff state atomically
@@ -141,6 +167,73 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
 	atomic.StoreInt64(&cn.usedAt, tm.Unix())
 }

+// Usable
+
+// CompareAndSwapUsable atomically compares and swaps the usable flag (lock-free).
+//
+// This is used by background operations (handoff, re-auth) to acquire exclusive
+// access to a connection. The operation sets usable to false, preventing the pool
+// from returning the connection to clients.
+//
+// Returns true if the swap was successful (old value matched), false otherwise.
+func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
+	return cn.usable.CompareAndSwap(old, new)
+}
+
+// IsUsable returns true if the connection is safe to use for new commands (lock-free).
+//
+// A connection is "usable" when it's in a stable state and can be returned to clients.
+// It becomes unusable during:
+//   - Initialization (before first use)
+//   - Handoff operations (network connection replacement)
+//   - Re-authentication (credential updates)
+//   - Other background operations that need exclusive access
+func (cn *Conn) IsUsable() bool {
+	return cn.usable.Load()
+}
+
+// SetUsable sets the usable flag for the connection (lock-free).
+//
+// This should be called to mark a connection as usable after initialization or
+// to release it after a background operation completes.
+//
+// Prefer CompareAndSwapUsable() when acquiring exclusive access to avoid race conditions.
+func (cn *Conn) SetUsable(usable bool) {
+	cn.usable.Store(usable)
+}
+
+// Used
+
+// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
+//
+// This is the preferred method for acquiring a connection from the pool, as it
+// ensures that only one goroutine marks the connection as used.
+//
+// Returns true if the swap was successful (old value matched), false otherwise.
+func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
+	return cn.used.CompareAndSwap(old, new)
+}
+
+// IsUsed returns true if the connection is currently in use (lock-free).
+//
+// A connection is "used" when it has been retrieved from the pool and is
+// actively processing a command. Background operations (like re-auth) should
+// wait until the connection is not used before executing commands.
+func (cn *Conn) IsUsed() bool {
+	return cn.used.Load()
+}
+
+// SetUsed sets the used flag for the connection (lock-free).
+//
+// This should be called when returning a connection to the pool (set to false)
+// or when a single-connection pool retrieves its connection (set to true).
+//
+// Prefer CompareAndSwapUsed() when acquiring from a multi-connection pool to
+// avoid race conditions.
+func (cn *Conn) SetUsed(val bool) {
+	cn.used.Store(val)
+}
+
 // getNetConn returns the current network connection using atomic load (lock-free).
 // This is the fast path for accessing netConn without mutex overhead.
 func (cn *Conn) getNetConn() net.Conn {
@@ -158,18 +251,6 @@ func (cn *Conn) setNetConn(netConn net.Conn) {
 	cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
 }

-// Lock-free helper methods for handoff state management
-
-// isUsable returns true if the connection is safe to use (lock-free).
-func (cn *Conn) isUsable() bool {
-	return cn.usableAtomic.Load()
-}
-
-// setUsable sets the usable flag atomically (lock-free).
-func (cn *Conn) setUsable(usable bool) {
-	cn.usableAtomic.Store(usable)
-}
-
 // getHandoffState returns the current handoff state atomically (lock-free).
 func (cn *Conn) getHandoffState() *HandoffState {
 	state := cn.handoffStateAtomic.Load()
@@ -214,11 +295,6 @@ func (cn *Conn) incrementHandoffRetries(delta int) int {
 	return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
 }

-// IsUsable returns true if the connection is safe to use for new commands (lock-free).
-func (cn *Conn) IsUsable() bool {
-	return cn.isUsable()
-}
-
 // IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
 func (cn *Conn) IsPooled() bool {
 	return cn.pooled
@@ -233,11 +309,6 @@ func (cn *Conn) IsInited() bool {
 	return cn.Inited.Load()
 }

-// SetUsable sets the usable flag for the connection (lock-free).
-func (cn *Conn) SetUsable(usable bool) {
-	cn.setUsable(usable)
-}
-
 // SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades.
 // These timeouts will be used for all subsequent commands until the deadline expires.
 // Uses atomic operations for lock-free access.
@@ -455,9 +526,26 @@ func (cn *Conn) MarkQueuedForHandoff() error {
 	const maxRetries = 50
 	const baseDelay = time.Microsecond

+	connAcquired := false
 	for attempt := 0; attempt < maxRetries; attempt++ {
-		currentState := cn.getHandoffState()
+		// If CAS failed, add exponential backoff to reduce contention
+		// the delay will be 1, 2, 4... up to 512 microseconds
+		// Moving this to the top of the loop to avoid "continue" without delay
+		if attempt > 0 && attempt < maxRetries-1 {
+			delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
+			time.Sleep(delay)
+		}
+
+		// first we need to mark the connection as not usable
+		// to prevent the pool from returning it to the caller
+		if !connAcquired {
+			if !cn.usable.CompareAndSwap(true, false) {
+				continue
+			}
+			connAcquired = true
+		}
+
+		currentState := cn.getHandoffState()
 		// Check if marked for handoff
 		if !currentState.ShouldHandoff {
 			return errors.New("connection was not marked for handoff")
@@ -472,16 +560,12 @@ func (cn *Conn) MarkQueuedForHandoff() error {

 		// Atomic compare-and-swap to update state
 		if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
-			cn.setUsable(false)
+			// queue the handoff for processing
+			// the connection is now "acquired" (marked as not usable) by the handoff
+			// and it won't be returned to any other callers until the handoff is complete
 			return nil
 		}

-		// If CAS failed, add exponential backoff to reduce contention
-		// the delay will be 1, 2, 4... up to 512 microseconds
-		if attempt < maxRetries-1 {
-			delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
-			time.Sleep(delay)
-		}
 	}

 	return fmt.Errorf("failed to mark connection as queued for handoff after %d attempts due to high contention", maxRetries)
@@ -527,7 +611,8 @@ func (cn *Conn) ClearHandoffState() {
 	// Atomically set clean state
 	cn.setHandoffState(cleanState)
 	cn.setHandoffRetries(0)
-	cn.setUsable(true) // Connection is safe to use again after handoff completes
+	// Clearing handoff state also means the connection is usable again
+	cn.SetUsable(true)
 }

 // IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
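The design note above boils down to an acquisition order for background work: flip usable off first so the pool stops handing the connection out, then only proceed once used is false so no in-flight command shares the socket. A minimal standalone sketch of that handshake, using plain atomics and hypothetical names (fakeConn, acquireForBackground); it mirrors the order used by the re-auth worker but is not code from this commit:

package main

import (
	"errors"
	"sync/atomic"
	"time"
)

// fakeConn models just the two flags from the design note.
type fakeConn struct {
	usable atomic.Bool // may the pool hand this connection to clients?
	used   atomic.Bool // is a command currently running on it?
}

// acquireForBackground takes usable first, then waits for used to clear, with a deadline.
func acquireForBackground(c *fakeConn, timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if c.usable.CompareAndSwap(true, false) { // pool can no longer return it
			if !c.used.Load() { // no command in flight: we own the socket
				return nil
			}
			c.usable.Store(true) // still busy: release and retry
		}
		time.Sleep(10 * time.Microsecond)
	}
	return errors.New("timed out acquiring connection for background work")
}

func main() {
	c := &fakeConn{}
	c.usable.Store(true)
	if err := acquireForBackground(c, time.Second); err == nil {
		// ... run AUTH or a handoff here, then release:
		c.usable.Store(true)
	}
}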
@@ -9,13 +9,27 @@ import (
 type PoolHook interface {
 	// OnGet is called when a connection is retrieved from the pool.
 	// It can modify the connection or return an error to prevent its use.
+	// The accept flag can be used to prevent the connection from being used.
+	// On Accept = false the connection is rejected and returned to the pool.
+	// The error can be used to prevent the connection from being used and returned to the pool.
+	// On Errors, the connection is removed from the pool.
 	// It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool)
 	// The flag can be used for gathering metrics on pool hit/miss ratio.
-	OnGet(ctx context.Context, conn *Conn, isNewConn bool) error
+	OnGet(ctx context.Context, conn *Conn, isNewConn bool) (accept bool, err error)

 	// OnPut is called when a connection is returned to the pool.
 	// It returns whether the connection should be pooled and whether it should be removed.
 	OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)

+	// OnRemove is called when a connection is removed from the pool.
+	// This happens when:
+	//   - Connection fails health check
+	//   - Connection exceeds max lifetime
+	//   - Pool is being closed
+	//   - Connection encounters an error
+	// Implementations should clean up any per-connection state.
+	// The reason parameter indicates why the connection was removed.
+	OnRemove(ctx context.Context, conn *Conn, reason error)
 }

 // PoolHookManager manages multiple pool hooks.
@@ -56,16 +70,21 @@ func (phm *PoolHookManager) RemoveHook(hook PoolHook) {

 // ProcessOnGet calls all OnGet hooks in order.
 // If any hook returns an error, processing stops and the error is returned.
-func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
+func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) (acceptConn bool, err error) {
 	phm.hooksMu.RLock()
 	defer phm.hooksMu.RUnlock()

 	for _, hook := range phm.hooks {
-		if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
-			return err
+		acceptConn, err := hook.OnGet(ctx, conn, isNewConn)
+		if err != nil {
+			return false, err
 		}
+
+		if !acceptConn {
+			return false, nil
+		}
 	}
-	return nil
+	return true, nil
 }

 // ProcessOnPut calls all OnPut hooks in order.
@@ -96,6 +115,15 @@ func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shoul
 	return shouldPool, false, nil
 }

+// ProcessOnRemove calls all OnRemove hooks in order.
+func (phm *PoolHookManager) ProcessOnRemove(ctx context.Context, conn *Conn, reason error) {
+	phm.hooksMu.RLock()
+	defer phm.hooksMu.RUnlock()
+	for _, hook := range phm.hooks {
+		hook.OnRemove(ctx, conn, reason)
+	}
+}
+
 // GetHookCount returns the number of registered hooks (for testing).
 func (phm *PoolHookManager) GetHookCount() int {
 	phm.hooksMu.RLock()
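To make the extended contract concrete, here is a hypothetical hook (not part of this commit) that would live in the same internal pool package: it accepts every connection and only uses the isNewConn flag to count pool hits and misses, which is the metrics use case the OnGet comment mentions.

package pool // hypothetical example file in the internal pool package

import (
	"context"
	"sync/atomic"
)

// metricsHook counts pool hits and misses; it never rejects or removes connections.
type metricsHook struct {
	hits, misses atomic.Int64
}

func (h *metricsHook) OnGet(_ context.Context, _ *Conn, isNewConn bool) (bool, error) {
	if isNewConn {
		h.misses.Add(1) // a freshly dialed connection means the idle pool could not serve the Get
	} else {
		h.hits.Add(1)
	}
	return true, nil // accept; returning false would make the pool put the conn back and retry
}

func (h *metricsHook) OnPut(_ context.Context, _ *Conn) (bool, bool, error) {
	return true, false, nil // keep pooling, don't remove, no error
}

func (h *metricsHook) OnRemove(_ context.Context, _ *Conn, _ error) {
	// nothing to clean up for this hook
}

var _ PoolHook = (*metricsHook)(nil)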
@@ -10,17 +10,19 @@ import (
 // TestHook for testing hook functionality
 type TestHook struct {
-	OnGetCalled  int
-	OnPutCalled  int
-	GetError     error
-	PutError     error
-	ShouldPool   bool
-	ShouldRemove bool
+	OnGetCalled    int
+	OnPutCalled    int
+	OnRemoveCalled int
+	GetError       error
+	PutError       error
+	ShouldPool     bool
+	ShouldRemove   bool
+	ShouldAccept   bool
 }

-func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
+func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) (bool, error) {
 	th.OnGetCalled++
-	return th.GetError
+	return th.ShouldAccept, th.GetError
 }

 func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
@@ -28,6 +30,10 @@ func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, sho
 	return th.ShouldPool, th.ShouldRemove, th.PutError
 }

+func (th *TestHook) OnRemove(ctx context.Context, conn *Conn, reason error) {
+	th.OnRemoveCalled++
+}
+
 func TestPoolHookManager(t *testing.T) {
 	manager := NewPoolHookManager()
@@ -37,8 +43,8 @@ func TestPoolHookManager(t *testing.T) {
 	}

 	// Add hooks
-	hook1 := &TestHook{ShouldPool: true}
-	hook2 := &TestHook{ShouldPool: true}
+	hook1 := &TestHook{ShouldPool: true, ShouldAccept: true}
+	hook2 := &TestHook{ShouldPool: true, ShouldAccept: true}

 	manager.AddHook(hook1)
 	manager.AddHook(hook2)
@@ -51,10 +57,13 @@ func TestPoolHookManager(t *testing.T) {
 	ctx := context.Background()
 	conn := &Conn{} // Mock connection

-	err := manager.ProcessOnGet(ctx, conn, false)
+	accept, err := manager.ProcessOnGet(ctx, conn, false)
 	if err != nil {
 		t.Errorf("ProcessOnGet should not error: %v", err)
 	}
+	if !accept {
+		t.Error("Expected accept to be true")
+	}

 	if hook1.OnGetCalled != 1 {
 		t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
@@ -99,11 +108,12 @@ func TestHookErrorHandling(t *testing.T) {

 	// Hook that returns error on Get
 	errorHook := &TestHook{
-		GetError:   errors.New("test error"),
-		ShouldPool: true,
+		GetError:     errors.New("test error"),
+		ShouldPool:   true,
+		ShouldAccept: true,
 	}

-	normalHook := &TestHook{ShouldPool: true}
+	normalHook := &TestHook{ShouldPool: true, ShouldAccept: true}

 	manager.AddHook(errorHook)
 	manager.AddHook(normalHook)
@@ -112,10 +122,13 @@ func TestHookErrorHandling(t *testing.T) {
 	conn := &Conn{}

 	// Test that error stops processing
-	err := manager.ProcessOnGet(ctx, conn, false)
+	accept, err := manager.ProcessOnGet(ctx, conn, false)
 	if err == nil {
 		t.Error("Expected error from ProcessOnGet")
 	}
+	if accept {
+		t.Error("Expected accept to be false")
+	}

 	if errorHook.OnGetCalled != 1 {
 		t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
@@ -134,9 +147,10 @@ func TestHookShouldRemove(t *testing.T) {
 	removeHook := &TestHook{
 		ShouldPool:   false,
 		ShouldRemove: true,
+		ShouldAccept: true,
 	}

-	normalHook := &TestHook{ShouldPool: true}
+	normalHook := &TestHook{ShouldPool: true, ShouldAccept: true}

 	manager.AddHook(removeHook)
 	manager.AddHook(normalHook)
@@ -170,7 +184,7 @@ func TestHookShouldRemove(t *testing.T) {
 func TestPoolWithHooks(t *testing.T) {
 	// Create a pool with hooks
 	hookManager := NewPoolHookManager()
-	testHook := &TestHook{ShouldPool: true}
+	testHook := &TestHook{ShouldPool: true, ShouldAccept: true}
 	hookManager.AddHook(testHook)

 	opt := &Options{
@@ -197,7 +211,7 @@ func TestPoolWithHooks(t *testing.T) {
 	}

 	// Test adding hook to pool
-	additionalHook := &TestHook{ShouldPool: true}
+	additionalHook := &TestHook{ShouldPool: true, ShouldAccept: true}
 	pool.AddPoolHook(additionalHook)

 	if pool.hookManager.GetHookCount() != 2 {
@@ -24,6 +24,9 @@ var (
|
||||
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
|
||||
ErrPoolTimeout = errors.New("redis: connection pool timeout")
|
||||
|
||||
// ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable.
|
||||
ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable")
|
||||
|
||||
// popAttempts is the maximum number of attempts to find a usable connection
|
||||
// when popping from the idle connection pool. This handles cases where connections
|
||||
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).
|
||||
@@ -78,6 +81,10 @@ type Pooler interface {
|
||||
IdleLen() int
|
||||
Stats() *Stats
|
||||
|
||||
// Size returns the maximum pool size (capacity).
|
||||
// This is used by the streaming credentials manager to size the re-auth worker pool.
|
||||
Size() int
|
||||
|
||||
AddPoolHook(hook PoolHook)
|
||||
RemovePoolHook(hook PoolHook)
|
||||
|
||||
@@ -236,6 +243,7 @@ func (p *ConnPool) addIdleConn() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
@@ -277,6 +285,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
@@ -428,6 +437,13 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
|
||||
now := time.Now()
|
||||
attempts := 0
|
||||
|
||||
// Get hooks manager once for this getConn call for performance.
|
||||
// Note: Hooks added/removed during this call won't be reflected.
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
for {
|
||||
if attempts >= getAttempts {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts)
|
||||
@@ -454,17 +470,19 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
|
||||
acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
|
||||
// Failed to process connection, discard it
|
||||
_ = p.CloseConn(cn)
|
||||
continue
|
||||
}
|
||||
if !acceptConn {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID())
|
||||
p.Put(ctx, cn)
|
||||
cn = nil
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddUint32(&p.stats.Hits, 1)
|
||||
@@ -480,14 +498,13 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
|
||||
acceptConn, err := hookManager.ProcessOnGet(ctx, newcn, true)
|
||||
// both errors and accept=false mean a hook rejected the connection
|
||||
// this should not happen with a new connection, but we handle it gracefully
|
||||
if err != nil || !acceptConn {
|
||||
// Failed to process connection, discard it
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err)
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err)
|
||||
_ = p.CloseConn(newcn)
|
||||
return nil, err
|
||||
}
|
||||
@@ -567,9 +584,12 @@ func (p *ConnPool) popIdle() (*Conn, error) {
}
attempts++

if cn.IsUsable() {
p.idleConnsLen.Add(-1)
break
if cn.CompareAndSwapUsed(false, true) {
if cn.IsUsable() {
p.idleConnsLen.Add(-1)
break
}
cn.SetUsed(false)
}

// Connection is not usable, put it back in the pool
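Aside: popIdle now claims the connection's used flag with a compare-and-swap before checking usability, so a connection that a background worker (re-auth or handoff) still owns can never be handed out. A standalone illustration of the same claim-then-verify pattern; the conn struct here is an assumption about Conn's internals, not the actual implementation:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// conn is a stand-in for pool.Conn: a "used" flag claimed via CAS plus a usability flag.
type conn struct {
	used   atomic.Bool
	usable atomic.Bool
}

// tryClaim mirrors the popIdle logic above: claim the used flag first,
// then check usability, releasing the claim if the check fails.
func tryClaim(c *conn) bool {
	if c.used.CompareAndSwap(false, true) {
		if c.usable.Load() {
			return true // caller now owns the connection
		}
		c.used.Store(false) // release; e.g. re-auth or handoff in progress
	}
	return false
}

func main() {
	c := &conn{}
	c.usable.Store(true)
	fmt.Println(tryClaim(c)) // true: claimed
	fmt.Println(tryClaim(c)) // false: already in use
}
```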
@@ -664,6 +684,11 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
shouldCloseConn = true
}

// if the connection is not going to be closed, mark it as not used
if !shouldCloseConn {
cn.SetUsed(false)
}

p.freeTurn()

if shouldCloseConn {
@@ -671,7 +696,15 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
}
}

func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()

if hookManager != nil {
hookManager.ProcessOnRemove(ctx, cn, reason)
}

p.removeConnWithLock(cn)

p.freeTurn()
@@ -733,6 +766,14 @@ func (p *ConnPool) IdleLen() int {
return int(n)
}

// Size returns the maximum pool size (capacity).
//
// This is used by the streaming credentials manager to size the re-auth worker pool,
// ensuring that re-auth operations don't exhaust the connection pool.
func (p *ConnPool) Size() int {
return int(p.cfg.PoolSize)
}

func (p *ConnPool) Stats() *Stats {
return &Stats{
Hits: atomic.LoadUint32(&p.stats.Hits),
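Aside: the point of exposing Size() is that the streaming credentials manager can bound its re-auth workers by pool capacity, so background re-authentication never consumes every connection. The actual sizing rule is not shown in this diff, so the worker count below is purely illustrative:

```go
package main

import "fmt"

// sizer is the slice of the Pooler interface a re-auth manager would need.
type sizer interface {
	Size() int
}

type fixedPool struct{ size int }

func (p fixedPool) Size() int { return p.size }

// reAuthWorkers bounds background re-auth workers by pool capacity,
// always leaving at least one connection free for regular traffic.
// This formula is an illustrative assumption, not the library's rule.
func reAuthWorkers(p sizer) int {
	n := p.Size() - 1
	if n < 1 {
		n = 1
	}
	return n
}

func main() {
	fmt.Println(reAuthWorkers(fixedPool{size: 10})) // 9
	fmt.Println(reAuthWorkers(fixedPool{size: 1}))  // 1
}
```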
@@ -2,8 +2,12 @@ package pool

import (
"context"
"time"
)

// SingleConnPool is a pool that always returns the same connection.
// Note: This pool is not thread-safe.
// It is intended to be used by clients that need a single connection.
type SingleConnPool struct {
pool Pooler
cn *Conn
@@ -12,6 +16,12 @@ type SingleConnPool struct {

var _ Pooler = (*SingleConnPool)(nil)

// NewSingleConnPool creates a new single connection pool.
// The pool will always return the same connection.
// The pool will not:
// - Close the connection
// - Reconnect the connection
// - Track the connection in any way
func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool {
return &SingleConnPool{
pool: pool,
@@ -27,16 +37,30 @@ func (p *SingleConnPool) CloseConn(cn *Conn) error {
return p.pool.CloseConn(cn)
}

func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) {
if p.stickyErr != nil {
return nil, p.stickyErr
}
if p.cn == nil {
return nil, ErrClosed
}
p.cn.SetUsed(true)
p.cn.SetUsedAt(time.Now())
return p.cn, nil
}

func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {}
func (p *SingleConnPool) Put(_ context.Context, cn *Conn) {
if p.cn == nil {
return
}
if p.cn != cn {
return
}
p.cn.SetUsed(false)
}

func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
func (p *SingleConnPool) Remove(_ context.Context, cn *Conn, reason error) {
cn.SetUsed(false)
p.cn = nil
p.stickyErr = reason
}
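Aside: together, Get and Put now keep the wrapped connection's used flag in sync while still never closing, reconnecting, or tracking it. A hypothetical helper, written as it might appear inside internal/pool, showing the intended borrow-and-wrap usage (the helper itself is not part of this change, and the parent Pooler's Get/Put signatures are assumed):

```go
package pool

import "context"

// borrowSingleConn is a hypothetical helper: take one connection from the
// parent pool and wrap it so commands run against exactly that connection.
// SingleConnPool marks it used on Get and not-used on Put, but leaves
// closing and lifecycle management to the parent pool.
func borrowSingleConn(ctx context.Context, parent Pooler) (*SingleConnPool, *Conn, error) {
	cn, err := parent.Get(ctx)
	if err != nil {
		return nil, nil, err
	}
	single := NewSingleConnPool(parent, cn)
	// Get on the single pool returns the same *Conn and marks it used.
	if _, err := single.Get(ctx); err != nil {
		parent.Put(ctx, cn)
		return nil, nil, err
	}
	return single, cn, nil
}
```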
@@ -55,10 +79,13 @@ func (p *SingleConnPool) IdleLen() int {
return 0
}

// Size returns the maximum pool size, which is always 1 for SingleConnPool.
func (p *SingleConnPool) Size() int { return 1 }

func (p *SingleConnPool) Stats() *Stats {
return &Stats{}
}

func (p *SingleConnPool) AddPoolHook(hook PoolHook) {}
func (p *SingleConnPool) AddPoolHook(_ PoolHook) {}

func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {}
func (p *SingleConnPool) RemovePoolHook(_ PoolHook) {}

@@ -196,6 +196,9 @@ func (p *StickyConnPool) IdleLen() int {
return len(p.ch)
}

// Size returns the maximum pool size, which is always 1 for StickyConnPool.
func (p *StickyConnPool) Size() int { return 1 }

func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}

@@ -497,9 +497,14 @@ func TestDialerRetryConfiguration(t *testing.T) {
}

// Should have attempted 5 times (default DialerRetries = 5)
// Note: There may be one additional attempt from tryDial() goroutine
// which is launched when dialErrorsNum reaches PoolSize
finalAttempts := atomic.LoadInt64(&attempts)
if finalAttempts != 5 {
t.Errorf("Expected 5 dial attempts (default), got %d", finalAttempts)
if finalAttempts < 5 {
t.Errorf("Expected at least 5 dial attempts (default), got %d", finalAttempts)
}
if finalAttempts > 6 {
t.Errorf("Expected around 5 dial attempts, got %d (too many)", finalAttempts)
}
})
}

@@ -24,6 +24,8 @@ type PubSubPool struct {
stats PubSubStats
}

// PubSubPool implements a pool for PubSub connections.
// It intentionally does not implement the Pooler interface
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
return &PubSubPool{
opt: opt,

@@ -378,8 +378,12 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c
}

// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) {

func (hwm *handoffWorkerManager) performHandoffInternal(
ctx context.Context,
conn *pool.Conn,
newEndpoint string,
connID uint64,
) (shouldRetry bool, err error) {
retries := conn.IncrementAndGetHandoffRetries(1)
internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String()))
maxRetries := 3 // Default fallback
@@ -438,9 +442,14 @@ func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, con
}
}()

// Clear handoff state will:
// - set the connection as usable again
// - clear the handoff state (shouldHandoff, endpoint, seqID)
// - reset the handoff retries to 0
conn.ClearHandoffState()
internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint))

// successfully completed the handoff, no retry needed and no error
return false, nil
}

@@ -472,7 +481,10 @@ func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, reque
internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err))
}
} else {
conn.Close()
err := conn.Close() // Close the connection if no pool provided
if err != nil {
internal.Logger.Printf(ctx, "redis: failed to close connection: %v", err)
}
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err))
}

@@ -116,22 +116,22 @@ func (ph *PoolHook) ResetCircuitBreakers() {
}

// OnGet is called when a connection is retrieved from the pool
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error {
func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
// NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is
// in a handoff state at the moment.

// Check if connection is usable (not in a handoff state)
// Should not happen since the pool will not return a connection that is not usable.
if !conn.IsUsable() {
return ErrConnectionMarkedForHandoff
return false, ErrConnectionMarkedForHandoff
}

// Check if connection is marked for handoff, which means it will be queued for handoff on put.
if conn.ShouldHandoff() {
return ErrConnectionMarkedForHandoff
return false, ErrConnectionMarkedForHandoff
}

return nil
return true, nil
}

// OnPut is called when a connection is returned to the pool
@@ -174,6 +174,10 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool
return true, false, nil
}

func (ph *PoolHook) OnRemove(_ context.Context, _ *pool.Conn, _ error) {
// Not used
}

// Shutdown gracefully shuts down the processor, waiting for workers to complete
func (ph *PoolHook) Shutdown(ctx context.Context) error {
return ph.workerManager.shutdownWorkers(ctx)

@@ -92,6 +92,10 @@ func (mp *mockPool) Stats() *pool.Stats {
return &pool.Stats{}
}

func (mp *mockPool) Size() int {
return 0
}

func (mp *mockPool) AddPoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
@@ -356,10 +360,13 @@ func TestConnectionHook(t *testing.T) {
conn := createMockPoolConnection()

ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
acceptCon, err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should not error for normal connection: %v", err)
}
if !acceptCon {
t.Error("Connection should be accepted for normal connection")
}
})

t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
@@ -381,10 +388,13 @@ func TestConnectionHook(t *testing.T) {
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)

ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
acceptCon, err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
if acceptCon {
t.Error("Connection should not be accepted when marked for handoff")
}

// Clean up
processor.GetPendingMap().Delete(conn)
@@ -412,10 +422,13 @@ func TestConnectionHook(t *testing.T) {

// Test OnGet with pending handoff
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
acceptCon, err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
}
if acceptCon {
t.Error("Should not accept connection with pending handoff")
}

// Test removing from pending map and clearing handoff state
processor.GetPendingMap().Delete(conn)
@@ -428,10 +441,13 @@ func TestConnectionHook(t *testing.T) {
conn.SetUsable(true) // Make connection usable again

// Test OnGet without pending handoff
err = processor.OnGet(ctx, conn, false)
acceptCon, err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("Should not return error for non-pending connection: %v", err)
}
if !acceptCon {
t.Error("Should accept connection without pending handoff")
}
})

t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
@@ -624,11 +640,15 @@ func TestConnectionHook(t *testing.T) {
}

// OnGet should succeed for usable connection
err := processor.OnGet(ctx, conn, false)
acceptConn, err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed for usable connection: %v", err)
}

if !acceptConn {
t.Error("Connection should be accepted when usable")
}

// Mark connection for handoff
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
@@ -648,13 +668,17 @@ func TestConnectionHook(t *testing.T) {
}

// OnGet should fail for connection marked for handoff
err = processor.OnGet(ctx, conn, false)
acceptConn, err = processor.OnGet(ctx, conn, false)
if err == nil {
t.Error("OnGet should fail for connection marked for handoff")
}

if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
if acceptConn {
t.Error("Connection should not be accepted when marked for handoff")
}

// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
@@ -674,11 +698,15 @@ func TestConnectionHook(t *testing.T) {
}

// OnGet should succeed again
err = processor.OnGet(ctx, conn, false)
acceptConn, err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed after handoff completion: %v", err)
}

if !acceptConn {
t.Error("Connection should be accepted after handoff completion")
}

t.Logf("Usable flag behavior test completed successfully")
})

@@ -465,7 +465,6 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
}

// Don't hold the lock to allow subscriptions and pings.

cn, err := c.connWithLock(ctx)
if err != nil {
return nil, err

redis.go (60 changed lines)
@@ -11,6 +11,7 @@ import (

"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/auth/streaming"
"github.com/redis/go-redis/v9/internal/hscan"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
@@ -224,6 +225,9 @@ type baseClient struct {
// Maintenance notifications manager
maintNotificationsManager *maintnotifications.Manager
maintNotificationsManagerLock sync.RWMutex

// streamingCredentialsManager is used to manage streaming credentials
streamingCredentialsManager *streaming.Manager
}

func (c *baseClient) clone() *baseClient {
@@ -232,11 +236,12 @@ func (c *baseClient) clone() *baseClient {
c.maintNotificationsManagerLock.RUnlock()

clone := &baseClient{
opt: c.opt,
connPool: c.connPool,
onClose: c.onClose,
pushProcessor: c.pushProcessor,
maintNotificationsManager: maintNotificationsManager,
opt: c.opt,
connPool: c.connPool,
onClose: c.onClose,
pushProcessor: c.pushProcessor,
maintNotificationsManager: maintNotificationsManager,
streamingCredentialsManager: c.streamingCredentialsManager,
}
return clone
}
@@ -296,32 +301,30 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}

func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
return auth.NewReAuthCredentialsListener(
c.reAuthConnection(poolCn),
c.onAuthenticationErr(poolCn),
)
}

func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
return func(credentials auth.Credentials) error {
func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error {
return func(poolCn *pool.Conn, credentials auth.Credentials) error {
var err error
username, password := credentials.BasicAuth()

// Use background context - timeout is handled by ReadTimeout in WithReader/WithWriter
ctx := context.Background()

connPool := pool.NewSingleConnPool(c.connPool, poolCn)
// hooksMixin are intentionally empty here
cn := newConn(c.opt, connPool, nil)

// Pass hooks so that reauth commands are recorded/traced
cn := newConn(c.opt, connPool, &c.hooksMixin)

if username != "" {
err = cn.AuthACL(ctx, username, password).Err()
} else {
err = cn.Auth(ctx, password).Err()
}

return err
}
}
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
return func(err error) {
func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) {
return func(poolCn *pool.Conn, err error) {
if err != nil {
if isBadConn(err, false, c.opt.Addr) {
// Close the connection to force a reconnection.
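Aside: the re-auth and error callbacks now take the *pool.Conn as a parameter, so the streaming manager can build one listener per connection from a single pair of connection-agnostic callbacks. A simplified, self-contained sketch of that per-connection binding; all type and field names below are illustrative, not the internal/auth/streaming implementation:

```go
package main

import (
	"errors"
	"fmt"
)

// Simplified stand-ins for pool.Conn and auth.Credentials.
type poolConn struct{ id uint64 }

type credentials interface{ BasicAuth() (string, string) }

type basicCreds struct{ user, pass string }

func (c basicCreds) BasicAuth() (string, string) { return c.user, c.pass }

// connListener binds connection-agnostic callbacks (like reAuthConnection()
// and onAuthenticationErr() above) to one specific connection.
type connListener struct {
	conn      *poolConn
	reAuth    func(cn *poolConn, creds credentials) error
	onAuthErr func(cn *poolConn, err error)
}

// OnNext re-authenticates just this listener's connection with the new credentials.
func (l *connListener) OnNext(creds credentials) {
	if err := l.reAuth(l.conn, creds); err != nil {
		l.onAuthErr(l.conn, err)
	}
}

// OnError surfaces provider errors for this connection.
func (l *connListener) OnError(err error) { l.onAuthErr(l.conn, err) }

func main() {
	l := &connListener{
		conn: &poolConn{id: 1},
		reAuth: func(cn *poolConn, creds credentials) error {
			u, _ := creds.BasicAuth()
			if u == "" {
				return errors.New("missing username")
			}
			fmt.Printf("conn[%d]: AUTH %s\n", cn.id, u)
			return nil
		},
		onAuthErr: func(cn *poolConn, err error) {
			fmt.Printf("conn[%d]: auth error: %v\n", cn.id, err)
		},
	}
	l.OnNext(basicCreds{user: "app", pass: "secret"})
}
```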
@@ -372,13 +375,24 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {

username, password := "", ""
if c.opt.StreamingCredentialsProvider != nil {
credListener, err := c.streamingCredentialsManager.Listener(
cn,
c.reAuthConnection(),
c.onAuthenticationErr(),
)
if err != nil {
return fmt.Errorf("failed to create credentials listener: %w", err)
}

credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
Subscribe(c.newReAuthCredentialsListener(cn))
Subscribe(credListener)
if err != nil {
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
}

c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
cn.SetOnClose(unsubscribeFromCredentialsProvider)

username, password = credentials.BasicAuth()
} else if c.opt.CredentialsProviderContext != nil {
username, password, err = c.opt.CredentialsProviderContext(ctx)
@@ -496,7 +510,10 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
}
}

// mark the connection as usable and inited
// once returned to the pool as idle, this connection can be used by other clients
cn.SetUsable(true)
cn.SetUsed(false)
cn.Inited.Store(true)

// Set the connection initialization function for potential reconnections
@@ -952,6 +969,11 @@ func NewClient(opt *Options) *Client {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}

if opt.StreamingCredentialsProvider != nil {
c.streamingCredentialsManager = streaming.NewManager(c.connPool, c.opt.PoolTimeout)
c.connPool.AddPoolHook(c.streamingCredentialsManager.PoolHook())
}

// Initialize maintnotifications first if enabled and protocol is RESP3
if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 {
err := c.enableMaintNotificationsUpgrades()

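Aside: from the application side, all of the wiring above is opt-in through a single option. A sketch of the public-API usage, assuming a StreamingCredentialsProvider implementation comes from elsewhere (for example an Entra ID token source); nil is passed below only so the example runs without one:

```go
package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/auth"
)

// newClientWithStreamingCreds shows the opt-in: setting StreamingCredentialsProvider
// makes NewClient install the streaming re-auth manager as a pool hook, as in the
// hunk above; everything else about the client stays the same.
func newClientWithStreamingCreds(provider auth.StreamingCredentialsProvider) *redis.Client {
	return redis.NewClient(&redis.Options{
		Addr:                         "localhost:6379",
		Protocol:                     3,
		StreamingCredentialsProvider: provider,
	})
}

func main() {
	// A real provider would come from a credentials library; with nil the
	// client simply skips the streaming path and behaves like a normal client.
	client := newClientWithStreamingCreds(nil)
	defer client.Close()
	fmt.Println(client.Ping(context.Background()).Err())
}
```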
@@ -854,24 +854,34 @@ var _ = Describe("Credentials Provider Priority", func() {
credentials: initialCreds,
updates: updatesChan,
},
PoolSize: 1, // Force single connection to ensure reauth is tested
}

client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
time.Sleep(10 * time.Millisecond)
Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())

// Update credentials
opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())

// Wait for reauth to complete and verify updated credentials are used
// We need to keep trying Ping until we see the updated AUTH command
// because the reauth happens asynchronously
Eventually(func() bool {
// wrongpass
_ = client.Ping(context.Background()).Err()
return recorder.Contains("AUTH updated_user")
}, "1s", "50ms").Should(BeTrue())

close(updatesChan)
})
})

type mockStreamingProvider struct {
mu sync.RWMutex
credentials auth.Credentials
err error
updates chan auth.Credentials
@@ -882,21 +892,50 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au
return nil, nil, m.err
}

if listener == nil {
return nil, nil, errors.New("listener cannot be nil")
}

// Create a done channel to stop the goroutine
done := make(chan struct{})

// Start goroutine to handle updates
go func() {
for creds := range m.updates {
m.credentials = creds
listener.OnNext(creds)
defer func() {
if r := recover(); r != nil {
// this is just a mock:
// allow panics to be caught without crashing
}
}()

for {
select {
case <-done:
return
case creds, ok := <-m.updates:
if !ok {
return
}
m.mu.Lock()
m.credentials = creds
m.mu.Unlock()
listener.OnNext(creds)
}
}
}()

return m.credentials, func() (err error) {
m.mu.RLock()
currentCreds := m.credentials
m.mu.RUnlock()

return currentCreds, func() (err error) {
defer func() {
if r := recover(); r != nil {
// this is just a mock:
// allow multiple closes from multiple listeners
}
}()
close(done)
return
}, nil
}

@@ -410,7 +410,9 @@ var _ = Describe("SentinelAclAuth", func() {
})
})

func TestParseFailoverURL(t *testing.T) {
// renaming from TestParseFailoverURL to TestParseSentinelURL
// to be easier to find Failed tests in the test output
func TestParseSentinelURL(t *testing.T) {
cases := []struct {
url string
o *redis.FailoverOptions