mirror of
https://github.com/redis/go-redis.git
synced 2025-11-26 06:23:09 +03:00
fix(conn): conn to have state machine (#3559)
* wip * wip, used and unusable states * polish state machine * correct handling OnPut * better errors for tests, hook should work now * fix linter * improve reauth state management. fix tests * Update internal/pool/conn.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update internal/pool/conn.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * better timeouts * empty endpoint handoff case * fix handoff state when queued for handoff * try to detect the deadlock * try to detect the deadlock x2 * delete should be called * improve tests * fix mark on uninitialized connection * Update internal/pool/conn_state_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update internal/pool/conn_state_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update internal/pool/pool.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update internal/pool/conn_state.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update internal/pool/conn.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix error from copilot * address copilot comment * fix(pool): pool performance (#3565) * perf(pool): replace hookManager RWMutex with atomic.Pointer and add predefined state slices - Replace hookManager RWMutex with atomic.Pointer for lock-free reads in hot paths - Add predefined state slices to avoid allocations (validFromInUse, validFromCreatedOrIdle, etc.) - Add Clone() method to PoolHookManager for atomic updates - Update AddPoolHook/RemovePoolHook to use copy-on-write pattern - Update all hookManager access points to use atomic Load() Performance improvements: - Eliminates RWMutex contention in Get/Put/Remove hot paths - Reduces allocations by reusing predefined state slices - Lock-free reads allow better CPU cache utilization * perf(pool): eliminate mutex overhead in state machine hot path The state machine was calling notifyWaiters() on EVERY Get/Put operation, which acquired a mutex even when no waiters were present (the common case). Fix: Use atomic waiterCount to check for waiters BEFORE acquiring mutex. This eliminates mutex contention in the hot path (Get/Put operations). Implementation: - Added atomic.Int32 waiterCount field to ConnStateMachine - Increment when adding waiter, decrement when removing - Check waiterCount atomically before acquiring mutex in notifyWaiters() Performance impact: - Before: mutex lock/unlock on every Get/Put (even with no waiters) - After: lock-free atomic check, only acquire mutex if waiters exist - Expected improvement: ~30-50% for Get/Put operations * perf(pool): use predefined state slices to eliminate allocations in hot path The pool was creating new slice literals on EVERY Get/Put operation: - popIdle(): []ConnState{StateCreated, StateIdle} - putConn(): []ConnState{StateInUse} - CompareAndSwapUsed(): []ConnState{StateIdle} and []ConnState{StateInUse} - MarkUnusableForHandoff(): []ConnState{StateInUse, StateIdle, StateCreated} These allocations were happening millions of times per second in the hot path. Fix: Use predefined global slices defined in conn_state.go: - validFromInUse - validFromCreatedOrIdle - validFromCreatedInUseOrIdle Performance impact: - Before: 4 slice allocations per Get/Put cycle - After: 0 allocations (use predefined slices) - Expected improvement: ~30-40% reduction in allocations and GC pressure * perf(pool): optimize TryTransition to reduce atomic operations Further optimize the hot path by: 1. Remove redundant GetState() call in the loop 2. Only check waiterCount after successful CAS (not before loop) 3. Inline the waiterCount check to avoid notifyWaiters() call overhead This reduces atomic operations from 4-5 per Get/Put to 2-3: - Before: GetState() + CAS + waiterCount.Load() + notifyWaiters mutex check - After: CAS + waiterCount.Load() (only if CAS succeeds) Performance impact: - Eliminates 1-2 atomic operations per Get/Put - Expected improvement: ~10-15% for Get/Put operations * perf(pool): add fast path for Get/Put to match master performance Introduced TryTransitionFast() for the hot path (Get/Put operations): - Single CAS operation (same as master's atomic bool) - No waiter notification overhead - No loop through valid states - No error allocation Hot path flow: 1. popIdle(): Try IDLE → IN_USE (fast), fallback to CREATED → IN_USE 2. putConn(): Try IN_USE → IDLE (fast) This matches master's performance while preserving state machine for: - Background operations (handoff/reauth use UNUSABLE state) - State validation (TryTransition still available) - Waiter notification (AwaitAndTransition for blocking) Performance comparison per Get/Put cycle: - Master: 2 atomic CAS operations - State machine (before): 5 atomic operations (2.5x slower) - State machine (after): 2 atomic CAS operations (same as master!) Expected improvement: Restore to baseline ~11,373 ops/sec * combine cas * fix linter * try faster approach * fast semaphore * better inlining for hot path * fix linter issues * use new semaphore in auth as well * linter should be happy now * add comments * Update internal/pool/conn_state.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * address comment * slight reordering * try to cache time if for non-critical calculation * fix wrong benchmark * add concurrent test * fix benchmark report * add additional expect to check output * comment and variable rename --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * initConn sets IDLE state - Handle unexpected conn state changes * fix precision of time cache and usedAt * allow e2e tests to run longer * Fix broken initialization of idle connections * optimize push notif * 100ms -> 50ms * use correct timer for last health check * verify pass auth on conn creation * fix assertion * fix unsafe test * fix benchmark test * improve remove conn * re doesn't support requirepass * wait more in e2e test * flaky test * add missed method in interface * fix test assertions * silence logs and faster hooks manager * address linter comment * fix flaky test * use read instad of control * use pool size for semsize * CAS instead of reading the state * preallocate errors and states * preallocate state slices * fix flaky test * fix fast semaphore that could have been starved * try to fix the semaphore * should properly notify the waiters - this way a waiter that timesout at the same time a releaser is releasing, won't throw token. the releaser will fail to notify and will pick another waiter. this hybrid approach should be faster than channels and maintains FIFO * waiter may double-release (if closed/times out) * priority of operations * use simple approach of fifo waiters * use simple channel based semaphores * address linter and tests * remove unused benchs * change log message * address pr comments * address pr comments * fix data race --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -85,14 +85,14 @@ func BenchmarkPoolGetRemove(b *testing.B) {
|
||||
})
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
rmvErr := errors.New("Bench test remove")
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cn, err := connPool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
connPool.Remove(ctx, cn, errors.New("Bench test remove"))
|
||||
connPool.Remove(ctx, cn, rmvErr)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@@ -3,6 +3,7 @@ package pool_test
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
. "github.com/bsm/ginkgo/v2"
|
||||
@@ -133,9 +134,10 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
// cause runtime panics or incorrect memory access due to invalid pointer dereferencing.
|
||||
func getWriterBufSizeUnsafe(cn *pool.Conn) int {
|
||||
cnPtr := (*struct {
|
||||
id uint64 // First field in pool.Conn
|
||||
usedAt int64 // Second field (atomic)
|
||||
netConnAtomic interface{} // atomic.Value (interface{} has same size)
|
||||
id uint64 // First field in pool.Conn
|
||||
usedAt atomic.Int64 // Second field (atomic)
|
||||
lastPutAt atomic.Int64 // Third field (atomic)
|
||||
netConnAtomic interface{} // atomic.Value (interface{} has same size)
|
||||
rd *proto.Reader
|
||||
bw *bufio.Writer
|
||||
wr *proto.Writer
|
||||
@@ -159,9 +161,10 @@ func getWriterBufSizeUnsafe(cn *pool.Conn) int {
|
||||
|
||||
func getReaderBufSizeUnsafe(cn *pool.Conn) int {
|
||||
cnPtr := (*struct {
|
||||
id uint64 // First field in pool.Conn
|
||||
usedAt int64 // Second field (atomic)
|
||||
netConnAtomic interface{} // atomic.Value (interface{} has same size)
|
||||
id uint64 // First field in pool.Conn
|
||||
usedAt atomic.Int64 // Second field (atomic)
|
||||
lastPutAt atomic.Int64 // Third field (atomic)
|
||||
netConnAtomic interface{} // atomic.Value (interface{} has same size)
|
||||
rd *proto.Reader
|
||||
bw *bufio.Writer
|
||||
wr *proto.Writer
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package pool implements the pool management
|
||||
package pool
|
||||
|
||||
import (
|
||||
@@ -17,6 +18,30 @@ import (
|
||||
|
||||
var noDeadline = time.Time{}
|
||||
|
||||
// Preallocated errors for hot paths to avoid allocations
|
||||
var (
|
||||
errAlreadyMarkedForHandoff = errors.New("connection is already marked for handoff")
|
||||
errNotMarkedForHandoff = errors.New("connection was not marked for handoff")
|
||||
errHandoffStateChanged = errors.New("handoff state changed during marking")
|
||||
errConnectionNotAvailable = errors.New("redis: connection not available")
|
||||
errConnNotAvailableForWrite = errors.New("redis: connection not available for write operation")
|
||||
)
|
||||
|
||||
// getCachedTimeNs returns the current time in nanoseconds from the global cache.
|
||||
// This is updated every 50ms by a background goroutine, avoiding expensive syscalls.
|
||||
// Max staleness: 50ms.
|
||||
func getCachedTimeNs() int64 {
|
||||
return globalTimeCache.nowNs.Load()
|
||||
}
|
||||
|
||||
// GetCachedTimeNs returns the current time in nanoseconds from the global cache.
|
||||
// This is updated every 50ms by a background goroutine, avoiding expensive syscalls.
|
||||
// Max staleness: 50ms.
|
||||
// Exported for use by other packages that need fast time access.
|
||||
func GetCachedTimeNs() int64 {
|
||||
return getCachedTimeNs()
|
||||
}
|
||||
|
||||
// Global atomic counter for connection IDs
|
||||
var connIDCounter uint64
|
||||
|
||||
@@ -43,7 +68,8 @@ type Conn struct {
|
||||
// Connection identifier for unique tracking
|
||||
id uint64
|
||||
|
||||
usedAt int64 // atomic
|
||||
usedAt atomic.Int64
|
||||
lastPutAt atomic.Int64
|
||||
|
||||
// Lock-free netConn access using atomic.Value
|
||||
// Contains *atomicNetConn wrapper, accessed atomically for better performance
|
||||
@@ -57,33 +83,20 @@ type Conn struct {
|
||||
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
|
||||
readerMu sync.RWMutex
|
||||
|
||||
// Design note:
|
||||
// Why have both Usable and Used?
|
||||
// _Usable_ is used to mark a connection as safe for use by clients, the connection can still
|
||||
// be in the pool but not Usable at the moment (e.g. handoff in progress).
|
||||
// _Used_ is used to mark a connection as used when a command is going to be processed on that connection.
|
||||
// this is going to happen once the connection is picked from the pool.
|
||||
//
|
||||
// If a background operation needs to use the connection, it will mark it as Not Usable and only use it when it
|
||||
// is not in use. That way, the connection won't be used to send multiple commands at the same time and
|
||||
// potentially corrupt the command stream.
|
||||
// State machine for connection state management
|
||||
// Replaces: usable, Inited, used
|
||||
// Provides thread-safe state transitions with FIFO waiting queue
|
||||
// States: CREATED → INITIALIZING → IDLE ⇄ IN_USE
|
||||
// ↓
|
||||
// UNUSABLE (handoff/reauth)
|
||||
// ↓
|
||||
// IDLE/CLOSED
|
||||
stateMachine *ConnStateMachine
|
||||
|
||||
// usable flag to mark connection as safe for use
|
||||
// It is false before initialization and after a handoff is marked
|
||||
// It will be false during other background operations like re-authentication
|
||||
usable atomic.Bool
|
||||
|
||||
// used flag to mark connection as used when a command is going to be
|
||||
// processed on that connection. This is used to prevent a race condition with
|
||||
// background operations that may execute commands, like re-authentication.
|
||||
used atomic.Bool
|
||||
|
||||
// Inited flag to mark connection as initialized, this is almost the same as usable
|
||||
// but it is used to make sure we don't initialize a network connection twice
|
||||
// On handoff, the network connection is replaced, but the Conn struct is reused
|
||||
// this flag will be set to false when the network connection is replaced and
|
||||
// set to true after the new network connection is initialized
|
||||
Inited atomic.Bool
|
||||
// Handoff metadata - managed separately from state machine
|
||||
// These are atomic for lock-free access during handoff operations
|
||||
handoffStateAtomic atomic.Value // stores *HandoffState
|
||||
handoffRetriesAtomic atomic.Uint32 // retry counter
|
||||
|
||||
pooled bool
|
||||
pubsub bool
|
||||
@@ -92,6 +105,7 @@ type Conn struct {
|
||||
expiresAt time.Time
|
||||
|
||||
// maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers
|
||||
|
||||
// Using atomic operations for lock-free access to avoid mutex contention
|
||||
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
@@ -105,13 +119,6 @@ type Conn struct {
|
||||
// Connection initialization function for reconnections
|
||||
initConnFunc func(context.Context, *Conn) error
|
||||
|
||||
// Handoff state - using atomic operations for lock-free access
|
||||
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
|
||||
|
||||
// Atomic handoff state to prevent race conditions
|
||||
// Stores *HandoffState to ensure atomic updates of all handoff-related fields
|
||||
handoffStateAtomic atomic.Value // stores *HandoffState
|
||||
|
||||
onClose func() error
|
||||
}
|
||||
|
||||
@@ -120,9 +127,11 @@ func NewConn(netConn net.Conn) *Conn {
|
||||
}
|
||||
|
||||
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
|
||||
now := time.Now()
|
||||
cn := &Conn{
|
||||
createdAt: time.Now(),
|
||||
id: generateConnID(), // Generate unique ID for this connection
|
||||
createdAt: now,
|
||||
id: generateConnID(), // Generate unique ID for this connection
|
||||
stateMachine: NewConnStateMachine(),
|
||||
}
|
||||
|
||||
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
|
||||
@@ -141,10 +150,8 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
|
||||
// Store netConn atomically for lock-free access using wrapper
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
|
||||
// Initialize atomic state
|
||||
cn.usable.Store(false) // false initially, set to true after initialization
|
||||
cn.handoffRetriesAtomic.Store(0) // 0 initially
|
||||
|
||||
cn.wr = proto.NewWriter(cn.bw)
|
||||
cn.SetUsedAt(now)
|
||||
// Initialize handoff state atomically
|
||||
initialHandoffState := &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
@@ -152,22 +159,32 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
|
||||
SeqID: 0,
|
||||
}
|
||||
cn.handoffStateAtomic.Store(initialHandoffState)
|
||||
|
||||
cn.wr = proto.NewWriter(cn.bw)
|
||||
cn.SetUsedAt(time.Now())
|
||||
return cn
|
||||
}
|
||||
|
||||
func (cn *Conn) UsedAt() time.Time {
|
||||
unix := atomic.LoadInt64(&cn.usedAt)
|
||||
return time.Unix(unix, 0)
|
||||
return time.Unix(0, cn.usedAt.Load())
|
||||
}
|
||||
|
||||
func (cn *Conn) SetUsedAt(tm time.Time) {
|
||||
atomic.StoreInt64(&cn.usedAt, tm.Unix())
|
||||
cn.usedAt.Store(tm.UnixNano())
|
||||
}
|
||||
|
||||
// Usable
|
||||
func (cn *Conn) UsedAtNs() int64 {
|
||||
return cn.usedAt.Load()
|
||||
}
|
||||
func (cn *Conn) SetUsedAtNs(ns int64) {
|
||||
cn.usedAt.Store(ns)
|
||||
}
|
||||
|
||||
func (cn *Conn) LastPutAtNs() int64 {
|
||||
return cn.lastPutAt.Load()
|
||||
}
|
||||
func (cn *Conn) SetLastPutAtNs(ns int64) {
|
||||
cn.lastPutAt.Store(ns)
|
||||
}
|
||||
|
||||
// Backward-compatible wrapper methods for state machine
|
||||
// These maintain the existing API while using the new state machine internally
|
||||
|
||||
// CompareAndSwapUsable atomically compares and swaps the usable flag (lock-free).
|
||||
//
|
||||
@@ -176,51 +193,135 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
|
||||
// from returning the connection to clients.
|
||||
//
|
||||
// Returns true if the swap was successful (old value matched), false otherwise.
|
||||
//
|
||||
// Implementation note: This is a compatibility wrapper around the state machine.
|
||||
// It checks if the current state is "usable" (IDLE or IN_USE) and transitions accordingly.
|
||||
// Deprecated: Use GetStateMachine().TryTransition() directly for better state management.
|
||||
func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
|
||||
return cn.usable.CompareAndSwap(old, new)
|
||||
currentState := cn.stateMachine.GetState()
|
||||
|
||||
// Check if current state matches the "old" usable value
|
||||
currentUsable := (currentState == StateIdle || currentState == StateInUse)
|
||||
if currentUsable != old {
|
||||
return false
|
||||
}
|
||||
|
||||
// If we're trying to set to the same value, succeed immediately
|
||||
if old == new {
|
||||
return true
|
||||
}
|
||||
|
||||
// Transition based on new value
|
||||
if new {
|
||||
// Trying to make usable - transition from UNUSABLE to IDLE
|
||||
// This should only work from UNUSABLE or INITIALIZING states
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(
|
||||
validFromInitializingOrUnusable,
|
||||
StateIdle,
|
||||
)
|
||||
return err == nil
|
||||
}
|
||||
// Trying to make unusable - transition from IDLE to UNUSABLE
|
||||
// This is typically for acquiring the connection for background operations
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(
|
||||
validFromIdle,
|
||||
StateUnusable,
|
||||
)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
|
||||
//
|
||||
// A connection is "usable" when it's in a stable state and can be returned to clients.
|
||||
// It becomes unusable during:
|
||||
// - Initialization (before first use)
|
||||
// - Handoff operations (network connection replacement)
|
||||
// - Re-authentication (credential updates)
|
||||
// - Other background operations that need exclusive access
|
||||
//
|
||||
// Note: CREATED state is considered usable because new connections need to pass OnGet() hook
|
||||
// before initialization. The initialization happens after OnGet() in the client code.
|
||||
func (cn *Conn) IsUsable() bool {
|
||||
return cn.usable.Load()
|
||||
state := cn.stateMachine.GetState()
|
||||
// CREATED, IDLE, and IN_USE states are considered usable
|
||||
// CREATED: new connection, not yet initialized (will be initialized by client)
|
||||
// IDLE: initialized and ready to be acquired
|
||||
// IN_USE: usable but currently acquired by someone
|
||||
return state == StateCreated || state == StateIdle || state == StateInUse
|
||||
}
|
||||
|
||||
// SetUsable sets the usable flag for the connection (lock-free).
|
||||
//
|
||||
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// This should be called to mark a connection as usable after initialization or
|
||||
// to release it after a background operation completes.
|
||||
//
|
||||
// Prefer CompareAndSwapUsable() when acquiring exclusive access to avoid race conditions.
|
||||
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
|
||||
func (cn *Conn) SetUsable(usable bool) {
|
||||
cn.usable.Store(usable)
|
||||
if usable {
|
||||
// Transition to IDLE state (ready to be acquired)
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
} else {
|
||||
// Transition to UNUSABLE state (for background operations)
|
||||
cn.stateMachine.Transition(StateUnusable)
|
||||
}
|
||||
}
|
||||
|
||||
// Used
|
||||
// IsInited returns true if the connection has been initialized.
|
||||
// This is a backward-compatible wrapper around the state machine.
|
||||
func (cn *Conn) IsInited() bool {
|
||||
state := cn.stateMachine.GetState()
|
||||
// Connection is initialized if it's in IDLE or any post-initialization state
|
||||
return state != StateCreated && state != StateInitializing && state != StateClosed
|
||||
}
|
||||
|
||||
// Used - State machine based implementation
|
||||
|
||||
// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// This is the preferred method for acquiring a connection from the pool, as it
|
||||
// ensures that only one goroutine marks the connection as used.
|
||||
//
|
||||
// Implementation: Uses state machine transitions IDLE ⇄ IN_USE
|
||||
//
|
||||
// Returns true if the swap was successful (old value matched), false otherwise.
|
||||
// Deprecated: Use GetStateMachine().TryTransition() directly for better state management.
|
||||
func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
|
||||
return cn.used.CompareAndSwap(old, new)
|
||||
if old == new {
|
||||
// No change needed
|
||||
currentState := cn.stateMachine.GetState()
|
||||
currentUsed := (currentState == StateInUse)
|
||||
return currentUsed == old
|
||||
}
|
||||
|
||||
if !old && new {
|
||||
// Acquiring: IDLE → IN_USE
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(validFromCreatedOrIdle, StateInUse)
|
||||
return err == nil
|
||||
} else {
|
||||
// Releasing: IN_USE → IDLE
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(validFromInUse, StateIdle)
|
||||
return err == nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsUsed returns true if the connection is currently in use (lock-free).
|
||||
//
|
||||
// Deprecated: Use GetStateMachine().GetState() == StateInUse directly for better clarity.
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// A connection is "used" when it has been retrieved from the pool and is
|
||||
// actively processing a command. Background operations (like re-auth) should
|
||||
// wait until the connection is not used before executing commands.
|
||||
func (cn *Conn) IsUsed() bool {
|
||||
return cn.used.Load()
|
||||
return cn.stateMachine.GetState() == StateInUse
|
||||
}
|
||||
|
||||
// SetUsed sets the used flag for the connection (lock-free).
|
||||
@@ -230,8 +331,13 @@ func (cn *Conn) IsUsed() bool {
|
||||
//
|
||||
// Prefer CompareAndSwapUsed() when acquiring from a multi-connection pool to
|
||||
// avoid race conditions.
|
||||
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
|
||||
func (cn *Conn) SetUsed(val bool) {
|
||||
cn.used.Store(val)
|
||||
if val {
|
||||
cn.stateMachine.Transition(StateInUse)
|
||||
} else {
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
}
|
||||
}
|
||||
|
||||
// getNetConn returns the current network connection using atomic load (lock-free).
|
||||
@@ -251,48 +357,51 @@ func (cn *Conn) setNetConn(netConn net.Conn) {
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
}
|
||||
|
||||
// getHandoffState returns the current handoff state atomically (lock-free).
|
||||
func (cn *Conn) getHandoffState() *HandoffState {
|
||||
state := cn.handoffStateAtomic.Load()
|
||||
if state == nil {
|
||||
// Return default state if not initialized
|
||||
return &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: "",
|
||||
SeqID: 0,
|
||||
}
|
||||
// Handoff state management - atomic access to handoff metadata
|
||||
|
||||
// ShouldHandoff returns true if connection needs handoff (lock-free).
|
||||
func (cn *Conn) ShouldHandoff() bool {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
return v.(*HandoffState).ShouldHandoff
|
||||
}
|
||||
return state.(*HandoffState)
|
||||
return false
|
||||
}
|
||||
|
||||
// setHandoffState sets the handoff state atomically (lock-free).
|
||||
func (cn *Conn) setHandoffState(state *HandoffState) {
|
||||
cn.handoffStateAtomic.Store(state)
|
||||
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
|
||||
func (cn *Conn) GetHandoffEndpoint() string {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
return v.(*HandoffState).Endpoint
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// shouldHandoff returns true if connection needs handoff (lock-free).
|
||||
func (cn *Conn) shouldHandoff() bool {
|
||||
return cn.getHandoffState().ShouldHandoff
|
||||
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
|
||||
func (cn *Conn) GetMovingSeqID() int64 {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
return v.(*HandoffState).SeqID
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// getMovingSeqID returns the sequence ID atomically (lock-free).
|
||||
func (cn *Conn) getMovingSeqID() int64 {
|
||||
return cn.getHandoffState().SeqID
|
||||
// GetHandoffInfo returns all handoff information atomically (lock-free).
|
||||
// This method prevents race conditions by returning all handoff state in a single atomic operation.
|
||||
// Returns (shouldHandoff, endpoint, seqID).
|
||||
func (cn *Conn) GetHandoffInfo() (bool, string, int64) {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
state := v.(*HandoffState)
|
||||
return state.ShouldHandoff, state.Endpoint, state.SeqID
|
||||
}
|
||||
return false, "", 0
|
||||
}
|
||||
|
||||
// getNewEndpoint returns the new endpoint atomically (lock-free).
|
||||
func (cn *Conn) getNewEndpoint() string {
|
||||
return cn.getHandoffState().Endpoint
|
||||
// HandoffRetries returns the current handoff retry count (lock-free).
|
||||
func (cn *Conn) HandoffRetries() int {
|
||||
return int(cn.handoffRetriesAtomic.Load())
|
||||
}
|
||||
|
||||
// setHandoffRetries sets the retry count atomically (lock-free).
|
||||
func (cn *Conn) setHandoffRetries(retries int) {
|
||||
cn.handoffRetriesAtomic.Store(uint32(retries))
|
||||
}
|
||||
|
||||
// incrementHandoffRetries atomically increments and returns the new retry count (lock-free).
|
||||
func (cn *Conn) incrementHandoffRetries(delta int) int {
|
||||
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
|
||||
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
|
||||
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
|
||||
return int(cn.handoffRetriesAtomic.Add(uint32(n)))
|
||||
}
|
||||
|
||||
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
|
||||
@@ -305,10 +414,6 @@ func (cn *Conn) IsPubSub() bool {
|
||||
return cn.pubsub
|
||||
}
|
||||
|
||||
func (cn *Conn) IsInited() bool {
|
||||
return cn.Inited.Load()
|
||||
}
|
||||
|
||||
// SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades.
|
||||
// These timeouts will be used for all subsequent commands until the deadline expires.
|
||||
// Uses atomic operations for lock-free access.
|
||||
@@ -392,7 +497,8 @@ func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Durati
|
||||
return time.Duration(readTimeoutNs)
|
||||
}
|
||||
|
||||
nowNs := time.Now().UnixNano()
|
||||
// Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
|
||||
nowNs := getCachedTimeNs()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
@@ -425,7 +531,8 @@ func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Durat
|
||||
return time.Duration(writeTimeoutNs)
|
||||
}
|
||||
|
||||
nowNs := time.Now().UnixNano()
|
||||
// Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
|
||||
nowNs := getCachedTimeNs()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
@@ -477,121 +584,115 @@ func (cn *Conn) GetNetConn() net.Conn {
|
||||
}
|
||||
|
||||
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
|
||||
// This method ensures only one initialization can happen at a time by using atomic state transitions.
|
||||
// If another goroutine is currently initializing, this will wait for it to complete.
|
||||
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
|
||||
// New connection is not initialized yet
|
||||
cn.Inited.Store(false)
|
||||
// Wait for and transition to INITIALIZING state - this prevents concurrent initializations
|
||||
// Valid from states: CREATED (first init), IDLE (reconnect), UNUSABLE (handoff/reauth)
|
||||
// If another goroutine is initializing, we'll wait for it to finish
|
||||
// if the context has a deadline, use that, otherwise use the connection read (relaxed) timeout
|
||||
// which should be set during handoff. If it is not set, use a 5 second default
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
deadline = time.Now().Add(cn.getEffectiveReadTimeout(5 * time.Second))
|
||||
}
|
||||
waitCtx, cancel := context.WithDeadline(ctx, deadline)
|
||||
defer cancel()
|
||||
// Use predefined slice to avoid allocation
|
||||
finalState, err := cn.stateMachine.AwaitAndTransition(
|
||||
waitCtx,
|
||||
validFromCreatedIdleOrUnusable,
|
||||
StateInitializing,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot initialize connection from state %s: %w", finalState, err)
|
||||
}
|
||||
|
||||
// Replace the underlying connection
|
||||
cn.SetNetConn(netConn)
|
||||
return cn.ExecuteInitConn(ctx)
|
||||
|
||||
// Execute initialization
|
||||
// NOTE: ExecuteInitConn (via baseClient.initConn) will transition to IDLE on success
|
||||
// or CLOSED on failure. We don't need to do it here.
|
||||
// NOTE: Initconn returns conn in IDLE state
|
||||
initErr := cn.ExecuteInitConn(ctx)
|
||||
if initErr != nil {
|
||||
// ExecuteInitConn already transitioned to CLOSED, just return the error
|
||||
return initErr
|
||||
}
|
||||
|
||||
// ExecuteInitConn already transitioned to IDLE
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
|
||||
// MarkForHandoff marks the connection for handoff due to MOVING notification.
|
||||
// Returns an error if the connection is already marked for handoff.
|
||||
// This method uses atomic compare-and-swap to ensure all handoff state is updated atomically.
|
||||
// Note: This only sets metadata - the connection state is not changed until OnPut.
|
||||
// This allows the current user to finish using the connection before handoff.
|
||||
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
|
||||
const maxRetries = 50
|
||||
const baseDelay = time.Microsecond
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
currentState := cn.getHandoffState()
|
||||
|
||||
// Check if already marked for handoff
|
||||
if currentState.ShouldHandoff {
|
||||
return errors.New("connection is already marked for handoff")
|
||||
}
|
||||
|
||||
// Create new state with handoff enabled
|
||||
newState := &HandoffState{
|
||||
ShouldHandoff: true,
|
||||
Endpoint: newEndpoint,
|
||||
SeqID: seqID,
|
||||
}
|
||||
|
||||
// Atomic compare-and-swap to update entire state
|
||||
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If CAS failed, add exponential backoff to reduce contention
|
||||
if attempt < maxRetries-1 {
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
|
||||
time.Sleep(delay)
|
||||
}
|
||||
// Check if already marked for handoff
|
||||
if cn.ShouldHandoff() {
|
||||
return errAlreadyMarkedForHandoff
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to mark connection for handoff after %d attempts due to high contention", maxRetries)
|
||||
// Set handoff metadata atomically
|
||||
cn.handoffStateAtomic.Store(&HandoffState{
|
||||
ShouldHandoff: true,
|
||||
Endpoint: newEndpoint,
|
||||
SeqID: seqID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkQueuedForHandoff marks the connection as queued for handoff processing.
|
||||
// This makes the connection unusable until handoff completes.
|
||||
// This is called from OnPut hook, where the connection is typically in IN_USE state.
|
||||
// The pool will preserve the UNUSABLE state and not overwrite it with IDLE.
|
||||
func (cn *Conn) MarkQueuedForHandoff() error {
|
||||
const maxRetries = 50
|
||||
const baseDelay = time.Microsecond
|
||||
|
||||
connAcquired := false
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
// If CAS failed, add exponential backoff to reduce contention
|
||||
// the delay will be 1, 2, 4... up to 512 microseconds
|
||||
// Moving this to the top of the loop to avoid "continue" without delay
|
||||
if attempt > 0 && attempt < maxRetries-1 {
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
// first we need to mark the connection as not usable
|
||||
// to prevent the pool from returning it to the caller
|
||||
if !connAcquired {
|
||||
if !cn.usable.CompareAndSwap(true, false) {
|
||||
continue
|
||||
}
|
||||
connAcquired = true
|
||||
}
|
||||
|
||||
currentState := cn.getHandoffState()
|
||||
// Check if marked for handoff
|
||||
if !currentState.ShouldHandoff {
|
||||
return errors.New("connection was not marked for handoff")
|
||||
}
|
||||
|
||||
// Create new state with handoff disabled (queued)
|
||||
newState := &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: currentState.Endpoint, // Preserve endpoint for handoff processing
|
||||
SeqID: currentState.SeqID, // Preserve seqID for handoff processing
|
||||
}
|
||||
|
||||
// Atomic compare-and-swap to update state
|
||||
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
|
||||
// queue the handoff for processing
|
||||
// the connection is now "acquired" (marked as not usable) by the handoff
|
||||
// and it won't be returned to any other callers until the handoff is complete
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get current handoff state
|
||||
currentState := cn.handoffStateAtomic.Load()
|
||||
if currentState == nil {
|
||||
return errNotMarkedForHandoff
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to mark connection as queued for handoff after %d attempts due to high contention", maxRetries)
|
||||
}
|
||||
state := currentState.(*HandoffState)
|
||||
if !state.ShouldHandoff {
|
||||
return errNotMarkedForHandoff
|
||||
}
|
||||
|
||||
// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
|
||||
func (cn *Conn) ShouldHandoff() bool {
|
||||
return cn.shouldHandoff()
|
||||
}
|
||||
// Create new state with ShouldHandoff=false but preserve endpoint and seqID
|
||||
// This prevents the connection from being queued multiple times while still
|
||||
// allowing the worker to access the handoff metadata
|
||||
newState := &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: state.Endpoint, // Preserve endpoint for handoff processing
|
||||
SeqID: state.SeqID, // Preserve seqID for handoff processing
|
||||
}
|
||||
|
||||
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
|
||||
func (cn *Conn) GetHandoffEndpoint() string {
|
||||
return cn.getNewEndpoint()
|
||||
}
|
||||
// Atomic compare-and-swap to update state
|
||||
if !cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
|
||||
// State changed between load and CAS - retry or return error
|
||||
return errHandoffStateChanged
|
||||
}
|
||||
|
||||
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
|
||||
func (cn *Conn) GetMovingSeqID() int64 {
|
||||
return cn.getMovingSeqID()
|
||||
}
|
||||
|
||||
// GetHandoffInfo returns all handoff information atomically (lock-free).
|
||||
// This method prevents race conditions by returning all handoff state in a single atomic operation.
|
||||
// Returns (shouldHandoff, endpoint, seqID).
|
||||
func (cn *Conn) GetHandoffInfo() (bool, string, int64) {
|
||||
state := cn.getHandoffState()
|
||||
return state.ShouldHandoff, state.Endpoint, state.SeqID
|
||||
// Transition to UNUSABLE from IN_USE (normal flow), IDLE (edge cases), or CREATED (tests/uninitialized)
|
||||
// The connection is typically in IN_USE state when OnPut is called (normal Put flow)
|
||||
// But in some edge cases or tests, it might be in IDLE or CREATED state
|
||||
// The pool will detect this state change and preserve it (not overwrite with IDLE)
|
||||
// Use predefined slice to avoid allocation
|
||||
finalState, err := cn.stateMachine.TryTransition(validFromCreatedInUseOrIdle, StateUnusable)
|
||||
if err != nil {
|
||||
// Check if already in UNUSABLE state (race condition or retry)
|
||||
// ShouldHandoff should be false now, but check just in case
|
||||
if finalState == StateUnusable && !cn.ShouldHandoff() {
|
||||
// Already unusable - this is fine, keep the new handoff state
|
||||
return nil
|
||||
}
|
||||
// Restore the original state if transition fails for other reasons
|
||||
cn.handoffStateAtomic.Store(currentState)
|
||||
return fmt.Errorf("failed to mark connection as unusable: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetID returns the unique identifier for this connection.
|
||||
@@ -599,30 +700,67 @@ func (cn *Conn) GetID() uint64 {
|
||||
return cn.id
|
||||
}
|
||||
|
||||
// ClearHandoffState clears the handoff state after successful handoff (lock-free).
|
||||
// GetStateMachine returns the connection's state machine for advanced state management.
|
||||
// This is primarily used by internal packages like maintnotifications for handoff processing.
|
||||
func (cn *Conn) GetStateMachine() *ConnStateMachine {
|
||||
return cn.stateMachine
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire the connection for use.
|
||||
// This is an optimized inline method for the hot path (Get operation).
|
||||
//
|
||||
// It tries to transition from IDLE -> IN_USE or CREATED -> CREATED.
|
||||
// Returns true if the connection was successfully acquired, false otherwise.
|
||||
// The CREATED->CREATED is done so we can keep the state correct for later
|
||||
// initialization of the connection in initConn.
|
||||
//
|
||||
// Performance: This is faster than calling GetStateMachine() + TryTransitionFast()
|
||||
//
|
||||
// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's
|
||||
// methods. This breaks encapsulation but is necessary for performance.
|
||||
// The IDLE->IN_USE and CREATED->CREATED transitions don't need
|
||||
// waiter notification, and benchmarks show 1-3% improvement. If the state machine ever
|
||||
// needs to notify waiters on these transitions, update this to use TryTransitionFast().
|
||||
func (cn *Conn) TryAcquire() bool {
|
||||
// The || operator short-circuits, so only 1 CAS in the common case
|
||||
return cn.stateMachine.state.CompareAndSwap(uint32(StateIdle), uint32(StateInUse)) ||
|
||||
cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateCreated))
|
||||
}
|
||||
|
||||
// Release releases the connection back to the pool.
|
||||
// This is an optimized inline method for the hot path (Put operation).
|
||||
//
|
||||
// It tries to transition from IN_USE -> IDLE.
|
||||
// Returns true if the connection was successfully released, false otherwise.
|
||||
//
|
||||
// Performance: This is faster than calling GetStateMachine() + TryTransitionFast().
|
||||
//
|
||||
// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's
|
||||
// methods. This breaks encapsulation but is necessary for performance.
|
||||
// If the state machine ever needs to notify waiters
|
||||
// on this transition, update this to use TryTransitionFast().
|
||||
func (cn *Conn) Release() bool {
|
||||
// Inline the hot path - single CAS operation
|
||||
return cn.stateMachine.state.CompareAndSwap(uint32(StateInUse), uint32(StateIdle))
|
||||
}
|
||||
|
||||
// ClearHandoffState clears the handoff state after successful handoff.
|
||||
// Makes the connection usable again.
|
||||
func (cn *Conn) ClearHandoffState() {
|
||||
// Create clean state
|
||||
cleanState := &HandoffState{
|
||||
// Clear handoff metadata
|
||||
cn.handoffStateAtomic.Store(&HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: "",
|
||||
SeqID: 0,
|
||||
}
|
||||
})
|
||||
|
||||
// Atomically set clean state
|
||||
cn.setHandoffState(cleanState)
|
||||
cn.setHandoffRetries(0)
|
||||
// Clearing handoff state also means the connection is usable again
|
||||
cn.SetUsable(true)
|
||||
}
|
||||
// Reset retry counter
|
||||
cn.handoffRetriesAtomic.Store(0)
|
||||
|
||||
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
|
||||
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
|
||||
return cn.incrementHandoffRetries(n)
|
||||
}
|
||||
|
||||
// GetHandoffRetries returns the current handoff retry count (lock-free).
|
||||
func (cn *Conn) HandoffRetries() int {
|
||||
return int(cn.handoffRetriesAtomic.Load())
|
||||
// Mark connection as usable again
|
||||
// Use state machine directly instead of deprecated SetUsable
|
||||
// probably done by initConn
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
}
|
||||
|
||||
// HasBufferedData safely checks if the connection has buffered data.
|
||||
@@ -673,7 +811,7 @@ func (cn *Conn) WithReader(
|
||||
// Get the connection directly from atomic storage
|
||||
netConn := cn.getNetConn()
|
||||
if netConn == nil {
|
||||
return fmt.Errorf("redis: connection not available")
|
||||
return errConnectionNotAvailable
|
||||
}
|
||||
|
||||
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
@@ -690,19 +828,18 @@ func (cn *Conn) WithWriter(
|
||||
// Use relaxed timeout if set, otherwise use provided timeout
|
||||
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
|
||||
|
||||
// Always set write deadline, even if getNetConn() returns nil
|
||||
// This prevents write operations from hanging indefinitely
|
||||
// Set write deadline on the connection
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// If getNetConn() returns nil, we still need to respect the timeout
|
||||
// Return an error to prevent indefinite blocking
|
||||
return fmt.Errorf("redis: conn[%d] not available for write operation", cn.GetID())
|
||||
// Connection is not available - return preallocated error
|
||||
return errConnNotAvailableForWrite
|
||||
}
|
||||
}
|
||||
|
||||
// Reset the buffered writer if needed, should not happen
|
||||
if cn.bw.Buffered() > 0 {
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
cn.bw.Reset(netConn)
|
||||
@@ -717,11 +854,15 @@ func (cn *Conn) WithWriter(
|
||||
}
|
||||
|
||||
func (cn *Conn) IsClosed() bool {
|
||||
return cn.closed.Load()
|
||||
return cn.closed.Load() || cn.stateMachine.GetState() == StateClosed
|
||||
}
|
||||
|
||||
func (cn *Conn) Close() error {
|
||||
cn.closed.Store(true)
|
||||
|
||||
// Transition to CLOSED state
|
||||
cn.stateMachine.Transition(StateClosed)
|
||||
|
||||
if cn.onClose != nil {
|
||||
// ignore error
|
||||
_ = cn.onClose()
|
||||
@@ -745,9 +886,14 @@ func (cn *Conn) MaybeHasData() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// deadline computes the effective deadline time based on context and timeout.
|
||||
// It updates the usedAt timestamp to now.
|
||||
// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation).
|
||||
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
|
||||
tm := time.Now()
|
||||
cn.SetUsedAt(tm)
|
||||
// Use cached time for deadline calculation (called 2x per command: read + write)
|
||||
nowNs := getCachedTimeNs()
|
||||
cn.SetUsedAtNs(nowNs)
|
||||
tm := time.Unix(0, nowNs)
|
||||
|
||||
if timeout > 0 {
|
||||
tm = tm.Add(timeout)
|
||||
|
||||
343
internal/pool/conn_state.go
Normal file
343
internal/pool/conn_state.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// ConnState represents the connection state in the state machine.
|
||||
// States are designed to be lightweight and fast to check.
|
||||
//
|
||||
// State Transitions:
|
||||
// CREATED → INITIALIZING → IDLE ⇄ IN_USE
|
||||
// ↓
|
||||
// UNUSABLE (handoff/reauth)
|
||||
// ↓
|
||||
// IDLE/CLOSED
|
||||
type ConnState uint32
|
||||
|
||||
const (
|
||||
// StateCreated - Connection just created, not yet initialized
|
||||
StateCreated ConnState = iota
|
||||
|
||||
// StateInitializing - Connection initialization in progress
|
||||
StateInitializing
|
||||
|
||||
// StateIdle - Connection initialized and idle in pool, ready to be acquired
|
||||
StateIdle
|
||||
|
||||
// StateInUse - Connection actively processing a command (retrieved from pool)
|
||||
StateInUse
|
||||
|
||||
// StateUnusable - Connection temporarily unusable due to background operation
|
||||
// (handoff, reauth, etc.). Cannot be acquired from pool.
|
||||
StateUnusable
|
||||
|
||||
// StateClosed - Connection closed
|
||||
StateClosed
|
||||
)
|
||||
|
||||
// Predefined state slices to avoid allocations in hot paths
|
||||
var (
|
||||
validFromInUse = []ConnState{StateInUse}
|
||||
validFromCreatedOrIdle = []ConnState{StateCreated, StateIdle}
|
||||
validFromCreatedInUseOrIdle = []ConnState{StateCreated, StateInUse, StateIdle}
|
||||
// For AwaitAndTransition calls
|
||||
validFromCreatedIdleOrUnusable = []ConnState{StateCreated, StateIdle, StateUnusable}
|
||||
validFromIdle = []ConnState{StateIdle}
|
||||
// For CompareAndSwapUsable
|
||||
validFromInitializingOrUnusable = []ConnState{StateInitializing, StateUnusable}
|
||||
)
|
||||
|
||||
// Accessor functions for predefined slices to avoid allocations in external packages
|
||||
// These return the same slice instance, so they're zero-allocation
|
||||
|
||||
// ValidFromIdle returns a predefined slice containing only StateIdle.
|
||||
// Use this to avoid allocations when calling AwaitAndTransition or TryTransition.
|
||||
func ValidFromIdle() []ConnState {
|
||||
return validFromIdle
|
||||
}
|
||||
|
||||
// ValidFromCreatedIdleOrUnusable returns a predefined slice for initialization transitions.
|
||||
// Use this to avoid allocations when calling AwaitAndTransition or TryTransition.
|
||||
func ValidFromCreatedIdleOrUnusable() []ConnState {
|
||||
return validFromCreatedIdleOrUnusable
|
||||
}
|
||||
|
||||
// String returns a human-readable string representation of the state.
|
||||
func (s ConnState) String() string {
|
||||
switch s {
|
||||
case StateCreated:
|
||||
return "CREATED"
|
||||
case StateInitializing:
|
||||
return "INITIALIZING"
|
||||
case StateIdle:
|
||||
return "IDLE"
|
||||
case StateInUse:
|
||||
return "IN_USE"
|
||||
case StateUnusable:
|
||||
return "UNUSABLE"
|
||||
case StateClosed:
|
||||
return "CLOSED"
|
||||
default:
|
||||
return fmt.Sprintf("UNKNOWN(%d)", s)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrInvalidStateTransition is returned when a state transition is not allowed
|
||||
ErrInvalidStateTransition = errors.New("invalid state transition")
|
||||
|
||||
// ErrStateMachineClosed is returned when operating on a closed state machine
|
||||
ErrStateMachineClosed = errors.New("state machine is closed")
|
||||
|
||||
// ErrTimeout is returned when a state transition times out
|
||||
ErrTimeout = errors.New("state transition timeout")
|
||||
)
|
||||
|
||||
// waiter represents a goroutine waiting for a state transition.
|
||||
// Designed for minimal allocations and fast processing.
|
||||
type waiter struct {
|
||||
validStates map[ConnState]struct{} // States we're waiting for
|
||||
targetState ConnState // State to transition to
|
||||
done chan error // Signaled when transition completes or times out
|
||||
}
|
||||
|
||||
// ConnStateMachine manages connection state transitions with FIFO waiting queue.
|
||||
// Optimized for:
|
||||
// - Lock-free reads (hot path)
|
||||
// - Minimal allocations
|
||||
// - Fast state transitions
|
||||
// - FIFO fairness for waiters
|
||||
// Note: Handoff metadata (endpoint, seqID, retries) is managed separately in the Conn struct.
|
||||
type ConnStateMachine struct {
|
||||
// Current state - atomic for lock-free reads
|
||||
state atomic.Uint32
|
||||
|
||||
// FIFO queue for waiters - only locked during waiter add/remove/notify
|
||||
mu sync.Mutex
|
||||
waiters *list.List // List of *waiter
|
||||
waiterCount atomic.Int32 // Fast lock-free check for waiters (avoids mutex in hot path)
|
||||
}
|
||||
|
||||
// NewConnStateMachine creates a new connection state machine.
|
||||
// Initial state is StateCreated.
|
||||
func NewConnStateMachine() *ConnStateMachine {
|
||||
sm := &ConnStateMachine{
|
||||
waiters: list.New(),
|
||||
}
|
||||
sm.state.Store(uint32(StateCreated))
|
||||
return sm
|
||||
}
|
||||
|
||||
// GetState returns the current state (lock-free read).
|
||||
// This is the hot path - optimized for zero allocations and minimal overhead.
|
||||
// Note: Zero allocations applies to state reads; converting the returned state to a string
|
||||
// (via String()) may allocate if the state is unknown.
|
||||
func (sm *ConnStateMachine) GetState() ConnState {
|
||||
return ConnState(sm.state.Load())
|
||||
}
|
||||
|
||||
// TryTransitionFast is an optimized version for the hot path (Get/Put operations).
|
||||
// It only handles simple state transitions without waiter notification.
|
||||
// This is safe because:
|
||||
// 1. Get/Put don't need to wait for state changes
|
||||
// 2. Background operations (handoff/reauth) use UNUSABLE state, which this won't match
|
||||
// 3. If a background operation is in progress (state is UNUSABLE), this fails fast
|
||||
//
|
||||
// Returns true if transition succeeded, false otherwise.
|
||||
// Use this for performance-critical paths where you don't need error details.
|
||||
//
|
||||
// Performance: Single CAS operation - as fast as the old atomic bool!
|
||||
// For multiple from states, use: sm.TryTransitionFast(State1, Target) || sm.TryTransitionFast(State2, Target)
|
||||
// The || operator short-circuits, so only 1 CAS is executed in the common case.
|
||||
func (sm *ConnStateMachine) TryTransitionFast(fromState, targetState ConnState) bool {
|
||||
return sm.state.CompareAndSwap(uint32(fromState), uint32(targetState))
|
||||
}
|
||||
|
||||
// TryTransition attempts an immediate state transition without waiting.
|
||||
// Returns the current state after the transition attempt and an error if the transition failed.
|
||||
// The returned state is the CURRENT state (after the attempt), not the previous state.
|
||||
// This is faster than AwaitAndTransition when you don't need to wait.
|
||||
// Uses compare-and-swap to atomically transition, preventing concurrent transitions.
|
||||
// This method does NOT wait - it fails immediately if the transition cannot be performed.
|
||||
//
|
||||
// Performance: Zero allocations on success path (hot path).
|
||||
func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) (ConnState, error) {
|
||||
// Try each valid from state with CAS
|
||||
// This ensures only ONE goroutine can successfully transition at a time
|
||||
for _, fromState := range validFromStates {
|
||||
// Try to atomically swap from fromState to targetState
|
||||
// If successful, we won the race and can proceed
|
||||
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
|
||||
// Success! We transitioned atomically
|
||||
// Hot path optimization: only check for waiters if transition succeeded
|
||||
// This avoids atomic load on every Get/Put when no waiters exist
|
||||
if sm.waiterCount.Load() > 0 {
|
||||
sm.notifyWaiters()
|
||||
}
|
||||
return targetState, nil
|
||||
}
|
||||
}
|
||||
|
||||
// All CAS attempts failed - state is not valid for this transition
|
||||
// Return the current state so caller can decide what to do
|
||||
// Note: This error path allocates, but it's the exceptional case
|
||||
currentState := sm.GetState()
|
||||
return currentState, fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)",
|
||||
ErrInvalidStateTransition, currentState, targetState, validFromStates)
|
||||
}
|
||||
|
||||
// Transition unconditionally transitions to the target state.
|
||||
// Use with caution - prefer AwaitAndTransition or TryTransition for safety.
|
||||
// This is useful for error paths or when you know the transition is valid.
|
||||
func (sm *ConnStateMachine) Transition(targetState ConnState) {
|
||||
sm.state.Store(uint32(targetState))
|
||||
sm.notifyWaiters()
|
||||
}
|
||||
|
||||
// AwaitAndTransition waits for the connection to reach one of the valid states,
|
||||
// then atomically transitions to the target state.
|
||||
// Returns the current state after the transition attempt and an error if the operation failed.
|
||||
// The returned state is the CURRENT state (after the attempt), not the previous state.
|
||||
// Returns error if timeout expires or context is cancelled.
|
||||
//
|
||||
// This method implements FIFO fairness - the first caller to wait gets priority
|
||||
// when the state becomes available.
|
||||
//
|
||||
// Performance notes:
|
||||
// - If already in a valid state, this is very fast (no allocation, no waiting)
|
||||
// - If waiting is required, allocates one waiter struct and one channel
|
||||
func (sm *ConnStateMachine) AwaitAndTransition(
|
||||
ctx context.Context,
|
||||
validFromStates []ConnState,
|
||||
targetState ConnState,
|
||||
) (ConnState, error) {
|
||||
// Fast path: try immediate transition with CAS to prevent race conditions
|
||||
// BUT: only if there are no waiters in the queue (to maintain FIFO ordering)
|
||||
if sm.waiterCount.Load() == 0 {
|
||||
for _, fromState := range validFromStates {
|
||||
// Check if we're already in target state
|
||||
if fromState == targetState && sm.GetState() == targetState {
|
||||
return targetState, nil
|
||||
}
|
||||
|
||||
// Try to atomically swap from fromState to targetState
|
||||
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
|
||||
// Success! We transitioned atomically
|
||||
sm.notifyWaiters()
|
||||
return targetState, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fast path failed - check if we should wait or fail
|
||||
currentState := sm.GetState()
|
||||
|
||||
// Check if closed
|
||||
if currentState == StateClosed {
|
||||
return currentState, ErrStateMachineClosed
|
||||
}
|
||||
|
||||
// Slow path: need to wait for state change
|
||||
// Create waiter with valid states map for fast lookup
|
||||
validStatesMap := make(map[ConnState]struct{}, len(validFromStates))
|
||||
for _, s := range validFromStates {
|
||||
validStatesMap[s] = struct{}{}
|
||||
}
|
||||
|
||||
w := &waiter{
|
||||
validStates: validStatesMap,
|
||||
targetState: targetState,
|
||||
done: make(chan error, 1), // Buffered to avoid goroutine leak
|
||||
}
|
||||
|
||||
// Add to FIFO queue
|
||||
sm.mu.Lock()
|
||||
elem := sm.waiters.PushBack(w)
|
||||
sm.waiterCount.Add(1)
|
||||
sm.mu.Unlock()
|
||||
|
||||
// Wait for state change or timeout
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Timeout or cancellation - remove from queue
|
||||
sm.mu.Lock()
|
||||
sm.waiters.Remove(elem)
|
||||
sm.waiterCount.Add(-1)
|
||||
sm.mu.Unlock()
|
||||
return sm.GetState(), ctx.Err()
|
||||
case err := <-w.done:
|
||||
// Transition completed (or failed)
|
||||
// Note: waiterCount is decremented either in notifyWaiters (when the waiter is notified and removed)
|
||||
// or here (on timeout/cancellation).
|
||||
return sm.GetState(), err
|
||||
}
|
||||
}
|
||||
|
||||
// notifyWaiters checks if any waiters can proceed and notifies them in FIFO order.
|
||||
// This is called after every state transition.
|
||||
func (sm *ConnStateMachine) notifyWaiters() {
|
||||
// Fast path: check atomic counter without acquiring lock
|
||||
// This eliminates mutex overhead in the common case (no waiters)
|
||||
if sm.waiterCount.Load() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring lock (waiters might have been processed)
|
||||
if sm.waiters.Len() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Process waiters in FIFO order until no more can be processed
|
||||
// We loop instead of recursing to avoid stack overflow and mutex issues
|
||||
for {
|
||||
processed := false
|
||||
|
||||
// Find the first waiter that can proceed
|
||||
for elem := sm.waiters.Front(); elem != nil; elem = elem.Next() {
|
||||
w := elem.Value.(*waiter)
|
||||
|
||||
// Read current state inside the loop to get the latest value
|
||||
currentState := sm.GetState()
|
||||
|
||||
// Check if current state is valid for this waiter
|
||||
if _, valid := w.validStates[currentState]; valid {
|
||||
// Remove from queue first
|
||||
sm.waiters.Remove(elem)
|
||||
sm.waiterCount.Add(-1)
|
||||
|
||||
// Use CAS to ensure state hasn't changed since we checked
|
||||
// This prevents race condition where another thread changes state
|
||||
// between our check and our transition
|
||||
if sm.state.CompareAndSwap(uint32(currentState), uint32(w.targetState)) {
|
||||
// Successfully transitioned - notify waiter
|
||||
w.done <- nil
|
||||
processed = true
|
||||
break
|
||||
} else {
|
||||
// State changed - re-add waiter to front of queue to maintain FIFO ordering
|
||||
// This waiter was first in line and should retain priority
|
||||
sm.waiters.PushFront(w)
|
||||
sm.waiterCount.Add(1)
|
||||
// Continue to next iteration to re-read state
|
||||
processed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we didn't process any waiter, we're done
|
||||
if !processed {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
169
internal/pool/conn_state_alloc_test.go
Normal file
169
internal/pool/conn_state_alloc_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPredefinedSlicesAvoidAllocations verifies that using predefined slices
|
||||
// avoids allocations in AwaitAndTransition calls
|
||||
func TestPredefinedSlicesAvoidAllocations(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test with predefined slice - should have 0 allocations on fast path
|
||||
allocs := testing.AllocsPerRun(100, func() {
|
||||
_, _ = sm.AwaitAndTransition(ctx, validFromIdle, StateUnusable)
|
||||
sm.Transition(StateIdle)
|
||||
})
|
||||
|
||||
if allocs > 0 {
|
||||
t.Errorf("Expected 0 allocations with predefined slice, got %.2f", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInlineSliceAllocations shows that inline slices cause allocations
|
||||
func TestInlineSliceAllocations(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test with inline slice - will allocate
|
||||
allocs := testing.AllocsPerRun(100, func() {
|
||||
_, _ = sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
||||
sm.Transition(StateIdle)
|
||||
})
|
||||
|
||||
if allocs == 0 {
|
||||
t.Logf("Inline slice had 0 allocations (compiler optimization)")
|
||||
} else {
|
||||
t.Logf("Inline slice caused %.2f allocations per run (expected)", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAwaitAndTransition_PredefinedSlice benchmarks with predefined slice
|
||||
func BenchmarkAwaitAndTransition_PredefinedSlice(b *testing.B) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = sm.AwaitAndTransition(ctx, validFromIdle, StateUnusable)
|
||||
sm.Transition(StateIdle)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAwaitAndTransition_InlineSlice benchmarks with inline slice
|
||||
func BenchmarkAwaitAndTransition_InlineSlice(b *testing.B) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
||||
sm.Transition(StateIdle)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAwaitAndTransition_MultipleStates_Predefined benchmarks with predefined multi-state slice
|
||||
func BenchmarkAwaitAndTransition_MultipleStates_Predefined(b *testing.B) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = sm.AwaitAndTransition(ctx, validFromCreatedIdleOrUnusable, StateInitializing)
|
||||
sm.Transition(StateIdle)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAwaitAndTransition_MultipleStates_Inline benchmarks with inline multi-state slice
|
||||
func BenchmarkAwaitAndTransition_MultipleStates_Inline(b *testing.B) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = sm.AwaitAndTransition(ctx, []ConnState{StateCreated, StateIdle, StateUnusable}, StateInitializing)
|
||||
sm.Transition(StateIdle)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPreallocatedErrorsAvoidAllocations verifies that preallocated errors
|
||||
// avoid allocations in hot paths
|
||||
func TestPreallocatedErrorsAvoidAllocations(t *testing.T) {
|
||||
cn := NewConn(nil)
|
||||
|
||||
// Test MarkForHandoff - first call should succeed
|
||||
err := cn.MarkForHandoff("localhost:6379", 123)
|
||||
if err != nil {
|
||||
t.Fatalf("First MarkForHandoff should succeed: %v", err)
|
||||
}
|
||||
|
||||
// Second call should return preallocated error with 0 allocations
|
||||
allocs := testing.AllocsPerRun(100, func() {
|
||||
_ = cn.MarkForHandoff("localhost:6380", 124)
|
||||
})
|
||||
|
||||
if allocs > 0 {
|
||||
t.Errorf("Expected 0 allocations for preallocated error, got %.2f", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkHandoffErrors_Preallocated benchmarks handoff errors with preallocated errors
|
||||
func BenchmarkHandoffErrors_Preallocated(b *testing.B) {
|
||||
cn := NewConn(nil)
|
||||
cn.MarkForHandoff("localhost:6379", 123)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = cn.MarkForHandoff("localhost:6380", 124)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCompareAndSwapUsable_Preallocated benchmarks with preallocated slices
|
||||
func BenchmarkCompareAndSwapUsable_Preallocated(b *testing.B) {
|
||||
cn := NewConn(nil)
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
cn.CompareAndSwapUsable(true, false) // IDLE -> UNUSABLE
|
||||
cn.CompareAndSwapUsable(false, true) // UNUSABLE -> IDLE
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllTryTransitionUsePredefinedSlices verifies all TryTransition calls use predefined slices
|
||||
func TestAllTryTransitionUsePredefinedSlices(t *testing.T) {
|
||||
cn := NewConn(nil)
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
|
||||
// Test CompareAndSwapUsable - should have minimal allocations
|
||||
allocs := testing.AllocsPerRun(100, func() {
|
||||
cn.CompareAndSwapUsable(true, false) // IDLE -> UNUSABLE
|
||||
cn.CompareAndSwapUsable(false, true) // UNUSABLE -> IDLE
|
||||
})
|
||||
|
||||
// Allow some allocations for error objects, but should be minimal
|
||||
if allocs > 2 {
|
||||
t.Errorf("Expected <= 2 allocations with predefined slices, got %.2f", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
742
internal/pool/conn_state_test.go
Normal file
742
internal/pool/conn_state_test.go
Normal file
@@ -0,0 +1,742 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConnStateMachine_GetState(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
|
||||
if state := sm.GetState(); state != StateCreated {
|
||||
t.Errorf("expected initial state to be CREATED, got %s", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnStateMachine_Transition(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
|
||||
// Unconditional transition
|
||||
sm.Transition(StateInitializing)
|
||||
if state := sm.GetState(); state != StateInitializing {
|
||||
t.Errorf("expected state to be INITIALIZING, got %s", state)
|
||||
}
|
||||
|
||||
sm.Transition(StateIdle)
|
||||
if state := sm.GetState(); state != StateIdle {
|
||||
t.Errorf("expected state to be IDLE, got %s", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnStateMachine_TryTransition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState ConnState
|
||||
validStates []ConnState
|
||||
targetState ConnState
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid transition from CREATED to INITIALIZING",
|
||||
initialState: StateCreated,
|
||||
validStates: []ConnState{StateCreated},
|
||||
targetState: StateInitializing,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid transition from CREATED to IDLE",
|
||||
initialState: StateCreated,
|
||||
validStates: []ConnState{StateInitializing},
|
||||
targetState: StateIdle,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "transition to same state",
|
||||
initialState: StateIdle,
|
||||
validStates: []ConnState{StateIdle},
|
||||
targetState: StateIdle,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "multiple valid from states",
|
||||
initialState: StateIdle,
|
||||
validStates: []ConnState{StateInitializing, StateIdle, StateUnusable},
|
||||
targetState: StateUnusable,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(tt.initialState)
|
||||
|
||||
_, err := sm.TryTransition(tt.validStates, tt.targetState)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !tt.expectError {
|
||||
if state := sm.GetState(); state != tt.targetState {
|
||||
t.Errorf("expected state %s, got %s", tt.targetState, state)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnStateMachine_AwaitAndTransition_FastPath(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Fast path: already in valid state
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if state := sm.GetState(); state != StateUnusable {
|
||||
t.Errorf("expected state UNUSABLE, got %s", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnStateMachine_AwaitAndTransition_Timeout(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateCreated)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Wait for a state that will never come
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateUnusable)
|
||||
if err == nil {
|
||||
t.Error("expected timeout error but got none")
|
||||
}
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("expected DeadlineExceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnStateMachine_AwaitAndTransition_FIFO(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateCreated)
|
||||
|
||||
const numWaiters = 10
|
||||
order := make([]int, 0, numWaiters)
|
||||
var orderMu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
var startBarrier sync.WaitGroup
|
||||
startBarrier.Add(numWaiters)
|
||||
|
||||
// Start multiple waiters
|
||||
for i := 0; i < numWaiters; i++ {
|
||||
wg.Add(1)
|
||||
waiterID := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Signal that this goroutine is ready
|
||||
startBarrier.Done()
|
||||
// Wait for all goroutines to be ready before starting
|
||||
startBarrier.Wait()
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateIdle)
|
||||
if err != nil {
|
||||
t.Errorf("waiter %d got error: %v", waiterID, err)
|
||||
return
|
||||
}
|
||||
|
||||
orderMu.Lock()
|
||||
order = append(order, waiterID)
|
||||
orderMu.Unlock()
|
||||
|
||||
// Transition back to READY for next waiter
|
||||
sm.Transition(StateIdle)
|
||||
}()
|
||||
}
|
||||
|
||||
// Give waiters time to queue up
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Transition to READY to start processing waiters
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
// Wait for all waiters to complete
|
||||
wg.Wait()
|
||||
|
||||
// Verify all waiters completed (FIFO order is not guaranteed due to goroutine scheduling)
|
||||
if len(order) != numWaiters {
|
||||
t.Errorf("expected %d waiters to complete, got %d", numWaiters, len(order))
|
||||
}
|
||||
|
||||
// Verify no duplicates
|
||||
seen := make(map[int]bool)
|
||||
for _, id := range order {
|
||||
if seen[id] {
|
||||
t.Errorf("duplicate waiter ID %d in order", id)
|
||||
}
|
||||
seen[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnStateMachine_ConcurrentAccess(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
const numGoroutines = 100
|
||||
const numIterations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var successCount atomic.Int32
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < numIterations; j++ {
|
||||
// Try to transition from READY to REAUTH_IN_PROGRESS
|
||||
_, err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||
if err == nil {
|
||||
successCount.Add(1)
|
||||
// Transition back to READY
|
||||
sm.Transition(StateIdle)
|
||||
}
|
||||
|
||||
// Read state (hot path)
|
||||
_ = sm.GetState()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// At least some transitions should have succeeded
|
||||
if successCount.Load() == 0 {
|
||||
t.Error("expected at least some successful transitions")
|
||||
}
|
||||
|
||||
t.Logf("Successful transitions: %d out of %d attempts", successCount.Load(), numGoroutines*numIterations)
|
||||
}
|
||||
|
||||
|
||||
|
||||
func TestConnStateMachine_StateString(t *testing.T) {
|
||||
tests := []struct {
|
||||
state ConnState
|
||||
expected string
|
||||
}{
|
||||
{StateCreated, "CREATED"},
|
||||
{StateInitializing, "INITIALIZING"},
|
||||
{StateIdle, "IDLE"},
|
||||
{StateInUse, "IN_USE"},
|
||||
{StateUnusable, "UNUSABLE"},
|
||||
{StateClosed, "CLOSED"},
|
||||
{ConnState(999), "UNKNOWN(999)"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
if got := tt.state.String(); got != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConnStateMachine_GetState(b *testing.B) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = sm.GetState()
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnStateMachine_PreventsConcurrentInitialization(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
const numGoroutines = 10
|
||||
var inInitializing atomic.Int32
|
||||
var maxConcurrent atomic.Int32
|
||||
var successCount atomic.Int32
|
||||
var wg sync.WaitGroup
|
||||
var startBarrier sync.WaitGroup
|
||||
startBarrier.Add(numGoroutines)
|
||||
|
||||
// Try to initialize concurrently from multiple goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Wait for all goroutines to be ready
|
||||
startBarrier.Done()
|
||||
startBarrier.Wait()
|
||||
|
||||
// Try to transition to INITIALIZING
|
||||
_, err := sm.TryTransition([]ConnState{StateIdle}, StateInitializing)
|
||||
if err == nil {
|
||||
successCount.Add(1)
|
||||
|
||||
// We successfully transitioned - increment concurrent count
|
||||
current := inInitializing.Add(1)
|
||||
|
||||
// Track maximum concurrent initializations
|
||||
for {
|
||||
max := maxConcurrent.Load()
|
||||
if current <= max || maxConcurrent.CompareAndSwap(max, current) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Goroutine %d: entered INITIALIZING (concurrent=%d)", id, current)
|
||||
|
||||
// Simulate initialization work
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Decrement before transitioning back
|
||||
inInitializing.Add(-1)
|
||||
|
||||
// Transition back to READY
|
||||
sm.Transition(StateIdle)
|
||||
} else {
|
||||
t.Logf("Goroutine %d: failed to enter INITIALIZING - %v", id, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
t.Logf("Total successful transitions: %d, Max concurrent: %d", successCount.Load(), maxConcurrent.Load())
|
||||
|
||||
// The maximum number of concurrent initializations should be 1
|
||||
if maxConcurrent.Load() != 1 {
|
||||
t.Errorf("expected max 1 concurrent initialization, got %d", maxConcurrent.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnStateMachine_AwaitAndTransitionWaitsForInitialization(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
const numGoroutines = 5
|
||||
var completedCount atomic.Int32
|
||||
var executionOrder []int
|
||||
var orderMu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
var startBarrier sync.WaitGroup
|
||||
startBarrier.Add(numGoroutines)
|
||||
|
||||
// All goroutines try to initialize concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Wait for all goroutines to be ready
|
||||
startBarrier.Done()
|
||||
startBarrier.Wait()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Try to transition to INITIALIZING - should wait if another is initializing
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Record execution order
|
||||
orderMu.Lock()
|
||||
executionOrder = append(executionOrder, id)
|
||||
orderMu.Unlock()
|
||||
|
||||
t.Logf("Goroutine %d: entered INITIALIZING (position %d)", id, len(executionOrder))
|
||||
|
||||
// Simulate initialization work
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Transition back to READY
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
completedCount.Add(1)
|
||||
t.Logf("Goroutine %d: completed initialization (total=%d)", id, completedCount.Load())
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// All goroutines should have completed successfully
|
||||
if completedCount.Load() != numGoroutines {
|
||||
t.Errorf("expected %d completions, got %d", numGoroutines, completedCount.Load())
|
||||
}
|
||||
|
||||
// Final state should be IDLE
|
||||
if sm.GetState() != StateIdle {
|
||||
t.Errorf("expected final state IDLE, got %s", sm.GetState())
|
||||
}
|
||||
|
||||
t.Logf("Execution order: %v", executionOrder)
|
||||
}
|
||||
|
||||
func TestConnStateMachine_FIFOOrdering(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateInitializing) // Start in INITIALIZING so all waiters must queue
|
||||
|
||||
const numGoroutines = 10
|
||||
var executionOrder []int
|
||||
var orderMu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Launch goroutines one at a time, ensuring each is queued before launching the next
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
expectedWaiters := int32(i + 1)
|
||||
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// This should queue in FIFO order
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Record execution order
|
||||
orderMu.Lock()
|
||||
executionOrder = append(executionOrder, id)
|
||||
orderMu.Unlock()
|
||||
|
||||
t.Logf("Goroutine %d: executed (position %d)", id, len(executionOrder))
|
||||
|
||||
// Transition back to IDLE to allow next waiter
|
||||
sm.Transition(StateIdle)
|
||||
}(i)
|
||||
|
||||
// Wait until this goroutine has been queued before launching the next
|
||||
// Poll the waiter count to ensure the goroutine is actually queued
|
||||
timeout := time.After(100 * time.Millisecond)
|
||||
for {
|
||||
if sm.waiterCount.Load() >= expectedWaiters {
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("Timeout waiting for goroutine %d to queue", i)
|
||||
case <-time.After(1 * time.Millisecond):
|
||||
// Continue polling
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Give all goroutines time to fully settle in the queue
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Transition to IDLE to start processing the queue
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
t.Logf("Execution order: %v", executionOrder)
|
||||
|
||||
// Verify FIFO ordering - should be [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
if executionOrder[i] != i {
|
||||
t.Errorf("FIFO violation: expected goroutine %d at position %d, got %d", i, i, executionOrder[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnStateMachine_FIFOWithFastPath(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle) // Start in READY so fast path is available
|
||||
|
||||
const numGoroutines = 10
|
||||
var executionOrder []int
|
||||
var orderMu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
var startBarrier sync.WaitGroup
|
||||
startBarrier.Add(numGoroutines)
|
||||
|
||||
// Launch goroutines that will all try the fast path
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Wait for all goroutines to be ready
|
||||
startBarrier.Done()
|
||||
startBarrier.Wait()
|
||||
|
||||
// Small stagger to establish arrival order
|
||||
time.Sleep(time.Duration(id) * 100 * time.Microsecond)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// This might use fast path (CAS) or slow path (queue)
|
||||
_, err := sm.AwaitAndTransition(ctx, []ConnState{StateIdle}, StateInitializing)
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: failed to transition: %v", id, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Record execution order
|
||||
orderMu.Lock()
|
||||
executionOrder = append(executionOrder, id)
|
||||
orderMu.Unlock()
|
||||
|
||||
t.Logf("Goroutine %d: executed (position %d)", id, len(executionOrder))
|
||||
|
||||
// Simulate work
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
// Transition back to READY to allow next waiter
|
||||
sm.Transition(StateIdle)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
t.Logf("Execution order: %v", executionOrder)
|
||||
|
||||
// Check if FIFO was maintained
|
||||
// With the current fast-path implementation, this might NOT be FIFO
|
||||
fifoViolations := 0
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
if executionOrder[i] != i {
|
||||
fifoViolations++
|
||||
}
|
||||
}
|
||||
|
||||
if fifoViolations > 0 {
|
||||
t.Logf("WARNING: %d FIFO violations detected (fast path bypasses queue)", fifoViolations)
|
||||
t.Logf("This is expected with current implementation - fast path uses CAS race")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConnStateMachine_TryTransition(b *testing.B) {
|
||||
sm := NewConnStateMachine()
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||
sm.Transition(StateIdle)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
func TestConnStateMachine_IdleInUseTransitions(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
|
||||
// Initialize to IDLE state
|
||||
sm.Transition(StateInitializing)
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
// Test IDLE → IN_USE transition
|
||||
_, err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from IDLE to IN_USE: %v", err)
|
||||
}
|
||||
if state := sm.GetState(); state != StateInUse {
|
||||
t.Errorf("expected state IN_USE, got %s", state)
|
||||
}
|
||||
|
||||
// Test IN_USE → IDLE transition
|
||||
_, err = sm.TryTransition([]ConnState{StateInUse}, StateIdle)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from IN_USE to IDLE: %v", err)
|
||||
}
|
||||
if state := sm.GetState(); state != StateIdle {
|
||||
t.Errorf("expected state IDLE, got %s", state)
|
||||
}
|
||||
|
||||
// Test concurrent acquisition (only one should succeed)
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
var successCount atomic.Int32
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := sm.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||
if err == nil {
|
||||
successCount.Add(1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if count := successCount.Load(); count != 1 {
|
||||
t.Errorf("expected exactly 1 successful transition, got %d", count)
|
||||
}
|
||||
|
||||
if state := sm.GetState(); state != StateInUse {
|
||||
t.Errorf("expected final state IN_USE, got %s", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_UsedMethods(t *testing.T) {
|
||||
cn := NewConn(nil)
|
||||
|
||||
// Initialize connection to IDLE state
|
||||
cn.stateMachine.Transition(StateInitializing)
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
|
||||
// Test IsUsed - should be false when IDLE
|
||||
if cn.IsUsed() {
|
||||
t.Error("expected IsUsed to be false for IDLE connection")
|
||||
}
|
||||
|
||||
// Test CompareAndSwapUsed - acquire connection
|
||||
if !cn.CompareAndSwapUsed(false, true) {
|
||||
t.Error("failed to acquire connection with CompareAndSwapUsed")
|
||||
}
|
||||
|
||||
// Test IsUsed - should be true when IN_USE
|
||||
if !cn.IsUsed() {
|
||||
t.Error("expected IsUsed to be true for IN_USE connection")
|
||||
}
|
||||
|
||||
// Test CompareAndSwapUsed - release connection
|
||||
if !cn.CompareAndSwapUsed(true, false) {
|
||||
t.Error("failed to release connection with CompareAndSwapUsed")
|
||||
}
|
||||
|
||||
// Test IsUsed - should be false again
|
||||
if cn.IsUsed() {
|
||||
t.Error("expected IsUsed to be false after release")
|
||||
}
|
||||
|
||||
// Test SetUsed
|
||||
cn.SetUsed(true)
|
||||
if !cn.IsUsed() {
|
||||
t.Error("expected IsUsed to be true after SetUsed(true)")
|
||||
}
|
||||
|
||||
cn.SetUsed(false)
|
||||
if cn.IsUsed() {
|
||||
t.Error("expected IsUsed to be false after SetUsed(false)")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestConnStateMachine_UnusableState(t *testing.T) {
|
||||
sm := NewConnStateMachine()
|
||||
|
||||
// Initialize to IDLE state
|
||||
sm.Transition(StateInitializing)
|
||||
sm.Transition(StateIdle)
|
||||
|
||||
// Test IDLE → UNUSABLE transition (for background operations)
|
||||
_, err := sm.TryTransition([]ConnState{StateIdle}, StateUnusable)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from IDLE to UNUSABLE: %v", err)
|
||||
}
|
||||
if state := sm.GetState(); state != StateUnusable {
|
||||
t.Errorf("expected state UNUSABLE, got %s", state)
|
||||
}
|
||||
|
||||
// Test UNUSABLE → IDLE transition (after background operation completes)
|
||||
_, err = sm.TryTransition([]ConnState{StateUnusable}, StateIdle)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from UNUSABLE to IDLE: %v", err)
|
||||
}
|
||||
if state := sm.GetState(); state != StateIdle {
|
||||
t.Errorf("expected state IDLE, got %s", state)
|
||||
}
|
||||
|
||||
// Test that we can transition from IN_USE to UNUSABLE if needed
|
||||
// (e.g., for urgent handoff while connection is in use)
|
||||
sm.Transition(StateInUse)
|
||||
_, err = sm.TryTransition([]ConnState{StateInUse}, StateUnusable)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from IN_USE to UNUSABLE: %v", err)
|
||||
}
|
||||
if state := sm.GetState(); state != StateUnusable {
|
||||
t.Errorf("expected state UNUSABLE, got %s", state)
|
||||
}
|
||||
|
||||
// Test UNUSABLE → INITIALIZING transition (for handoff)
|
||||
sm.Transition(StateIdle)
|
||||
sm.Transition(StateUnusable)
|
||||
_, err = sm.TryTransition([]ConnState{StateUnusable}, StateInitializing)
|
||||
if err != nil {
|
||||
t.Errorf("failed to transition from UNUSABLE to INITIALIZING: %v", err)
|
||||
}
|
||||
if state := sm.GetState(); state != StateInitializing {
|
||||
t.Errorf("expected state INITIALIZING, got %s", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_UsableUnusable(t *testing.T) {
|
||||
cn := NewConn(nil)
|
||||
|
||||
// Initialize connection to IDLE state
|
||||
cn.stateMachine.Transition(StateInitializing)
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
|
||||
// Test IsUsable - should be true when IDLE
|
||||
if !cn.IsUsable() {
|
||||
t.Error("expected IsUsable to be true for IDLE connection")
|
||||
}
|
||||
|
||||
// Test CompareAndSwapUsable - make unusable for background operation
|
||||
if !cn.CompareAndSwapUsable(true, false) {
|
||||
t.Error("failed to make connection unusable with CompareAndSwapUsable")
|
||||
}
|
||||
|
||||
// Verify state is UNUSABLE
|
||||
if state := cn.stateMachine.GetState(); state != StateUnusable {
|
||||
t.Errorf("expected state UNUSABLE, got %s", state)
|
||||
}
|
||||
|
||||
// Test IsUsable - should be false when UNUSABLE
|
||||
if cn.IsUsable() {
|
||||
t.Error("expected IsUsable to be false for UNUSABLE connection")
|
||||
}
|
||||
|
||||
// Test CompareAndSwapUsable - make usable again
|
||||
if !cn.CompareAndSwapUsable(false, true) {
|
||||
t.Error("failed to make connection usable with CompareAndSwapUsable")
|
||||
}
|
||||
|
||||
// Verify state is IDLE
|
||||
if state := cn.stateMachine.GetState(); state != StateIdle {
|
||||
t.Errorf("expected state IDLE, got %s", state)
|
||||
}
|
||||
|
||||
// Test SetUsable(false)
|
||||
cn.SetUsable(false)
|
||||
if state := cn.stateMachine.GetState(); state != StateUnusable {
|
||||
t.Errorf("expected state UNUSABLE after SetUsable(false), got %s", state)
|
||||
}
|
||||
|
||||
// Test SetUsable(true)
|
||||
cn.SetUsable(true)
|
||||
if state := cn.stateMachine.GetState(); state != StateIdle {
|
||||
t.Errorf("expected state IDLE after SetUsable(true), got %s", state)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
259
internal/pool/conn_used_at_test.go
Normal file
259
internal/pool/conn_used_at_test.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
)
|
||||
|
||||
// TestConn_UsedAtUpdatedOnRead verifies that usedAt is updated when reading from connection
|
||||
func TestConn_UsedAtUpdatedOnRead(t *testing.T) {
|
||||
// Create a mock connection
|
||||
server, client := net.Pipe()
|
||||
defer server.Close()
|
||||
defer client.Close()
|
||||
|
||||
cn := NewConn(client)
|
||||
defer cn.Close()
|
||||
|
||||
// Get initial usedAt time
|
||||
initialUsedAt := cn.UsedAt()
|
||||
|
||||
// Wait 100ms to ensure time difference (usedAt has ~50ms precision from cached time)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Simulate a read operation by calling WithReader
|
||||
ctx := context.Background()
|
||||
err := cn.WithReader(ctx, time.Second, func(rd *proto.Reader) error {
|
||||
// Don't actually read anything, just trigger the deadline update
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("WithReader failed: %v", err)
|
||||
}
|
||||
|
||||
// Get updated usedAt time
|
||||
updatedUsedAt := cn.UsedAt()
|
||||
|
||||
// Verify that usedAt was updated
|
||||
if !updatedUsedAt.After(initialUsedAt) {
|
||||
t.Errorf("Expected usedAt to be updated after read. Initial: %v, Updated: %v",
|
||||
initialUsedAt, updatedUsedAt)
|
||||
}
|
||||
|
||||
// Verify the difference is reasonable (should be around 100ms, accounting for ~50ms cache precision and ~5ms sleep precision)
|
||||
diff := updatedUsedAt.Sub(initialUsedAt)
|
||||
if diff < 45*time.Millisecond || diff > 155*time.Millisecond {
|
||||
t.Errorf("Expected usedAt difference to be around 100ms (±50ms for cache, ±5ms for sleep), got %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConn_UsedAtUpdatedOnWrite verifies that usedAt is updated when writing to connection
|
||||
func TestConn_UsedAtUpdatedOnWrite(t *testing.T) {
|
||||
// Create a mock connection
|
||||
server, client := net.Pipe()
|
||||
defer server.Close()
|
||||
defer client.Close()
|
||||
|
||||
cn := NewConn(client)
|
||||
defer cn.Close()
|
||||
|
||||
// Get initial usedAt time
|
||||
initialUsedAt := cn.UsedAt()
|
||||
|
||||
// Wait at least 100ms to ensure time difference (usedAt has ~50ms precision from cached time)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Simulate a write operation by calling WithWriter
|
||||
ctx := context.Background()
|
||||
err := cn.WithWriter(ctx, time.Second, func(wr *proto.Writer) error {
|
||||
// Don't actually write anything, just trigger the deadline update
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("WithWriter failed: %v", err)
|
||||
}
|
||||
|
||||
// Get updated usedAt time
|
||||
updatedUsedAt := cn.UsedAt()
|
||||
|
||||
// Verify that usedAt was updated
|
||||
if !updatedUsedAt.After(initialUsedAt) {
|
||||
t.Errorf("Expected usedAt to be updated after write. Initial: %v, Updated: %v",
|
||||
initialUsedAt, updatedUsedAt)
|
||||
}
|
||||
|
||||
// Verify the difference is reasonable (should be around 100ms, accounting for ~50ms cache precision)
|
||||
diff := updatedUsedAt.Sub(initialUsedAt)
|
||||
|
||||
// 50 ms is the cache precision, so we allow up to 110ms difference
|
||||
if diff < 45*time.Millisecond || diff > 155*time.Millisecond {
|
||||
t.Errorf("Expected usedAt difference to be around 100 (±50ms for cache) (+-5ms for sleep precision), got %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConn_UsedAtUpdatedOnMultipleOperations verifies that usedAt is updated on each operation
|
||||
func TestConn_UsedAtUpdatedOnMultipleOperations(t *testing.T) {
|
||||
// Create a mock connection
|
||||
server, client := net.Pipe()
|
||||
defer server.Close()
|
||||
defer client.Close()
|
||||
|
||||
cn := NewConn(client)
|
||||
defer cn.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var previousUsedAt time.Time
|
||||
|
||||
// Perform multiple operations and verify usedAt is updated each time
|
||||
// Note: usedAt has ~50ms precision from cached time
|
||||
for i := 0; i < 5; i++ {
|
||||
currentUsedAt := cn.UsedAt()
|
||||
|
||||
if i > 0 {
|
||||
// Verify usedAt was updated from previous iteration
|
||||
if !currentUsedAt.After(previousUsedAt) {
|
||||
t.Errorf("Iteration %d: Expected usedAt to be updated. Previous: %v, Current: %v",
|
||||
i, previousUsedAt, currentUsedAt)
|
||||
}
|
||||
}
|
||||
|
||||
previousUsedAt = currentUsedAt
|
||||
|
||||
// Wait at least 100ms (accounting for ~50ms cache precision)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Perform a read operation
|
||||
err := cn.WithReader(ctx, time.Second, func(rd *proto.Reader) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Iteration %d: WithReader failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify final usedAt is significantly later than initial
|
||||
finalUsedAt := cn.UsedAt()
|
||||
if !finalUsedAt.After(previousUsedAt) {
|
||||
t.Errorf("Expected final usedAt to be updated. Previous: %v, Final: %v",
|
||||
previousUsedAt, finalUsedAt)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConn_UsedAtNotUpdatedWithoutOperation verifies that usedAt is NOT updated without operations
|
||||
func TestConn_UsedAtNotUpdatedWithoutOperation(t *testing.T) {
|
||||
// Create a mock connection
|
||||
server, client := net.Pipe()
|
||||
defer server.Close()
|
||||
defer client.Close()
|
||||
|
||||
cn := NewConn(client)
|
||||
defer cn.Close()
|
||||
|
||||
// Get initial usedAt time
|
||||
initialUsedAt := cn.UsedAt()
|
||||
|
||||
// Wait without performing any operations
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Get usedAt time again
|
||||
currentUsedAt := cn.UsedAt()
|
||||
|
||||
// Verify that usedAt was NOT updated (should be the same)
|
||||
if !currentUsedAt.Equal(initialUsedAt) {
|
||||
t.Errorf("Expected usedAt to remain unchanged without operations. Initial: %v, Current: %v",
|
||||
initialUsedAt, currentUsedAt)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConn_UsedAtConcurrentUpdates verifies that usedAt updates are thread-safe
|
||||
func TestConn_UsedAtConcurrentUpdates(t *testing.T) {
|
||||
// Create a mock connection
|
||||
server, client := net.Pipe()
|
||||
defer server.Close()
|
||||
defer client.Close()
|
||||
|
||||
cn := NewConn(client)
|
||||
defer cn.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
const numGoroutines = 10
|
||||
const numIterations = 10
|
||||
|
||||
// Launch multiple goroutines that perform operations concurrently
|
||||
done := make(chan bool, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
for j := 0; j < numIterations; j++ {
|
||||
// Alternate between read and write operations
|
||||
if j%2 == 0 {
|
||||
_ = cn.WithReader(ctx, time.Second, func(rd *proto.Reader) error {
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
_ = cn.WithWriter(ctx, time.Second, func(wr *proto.Writer) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify that usedAt was updated (should be recent)
|
||||
usedAt := cn.UsedAt()
|
||||
timeSinceUsed := time.Since(usedAt)
|
||||
|
||||
// Should be very recent (within last second)
|
||||
if timeSinceUsed > time.Second {
|
||||
t.Errorf("Expected usedAt to be recent, but it was %v ago", timeSinceUsed)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConn_UsedAtPrecision verifies that usedAt has 50ms precision (not nanosecond)
|
||||
func TestConn_UsedAtPrecision(t *testing.T) {
|
||||
// Create a mock connection
|
||||
server, client := net.Pipe()
|
||||
defer server.Close()
|
||||
defer client.Close()
|
||||
|
||||
cn := NewConn(client)
|
||||
defer cn.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Perform an operation
|
||||
err := cn.WithReader(ctx, time.Second, func(rd *proto.Reader) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("WithReader failed: %v", err)
|
||||
}
|
||||
|
||||
// Get usedAt time
|
||||
usedAt := cn.UsedAt()
|
||||
|
||||
// Verify that usedAt has nanosecond precision (from the cached time which updates every 50ms)
|
||||
// The value should be reasonable (not year 1970 or something)
|
||||
if usedAt.Year() < 2020 {
|
||||
t.Errorf("Expected usedAt to be a recent time, got %v", usedAt)
|
||||
}
|
||||
|
||||
// The nanoseconds might be non-zero depending on when the cache was updated
|
||||
// We just verify the time is stored with full precision (not truncated to seconds)
|
||||
initialNanos := usedAt.UnixNano()
|
||||
if initialNanos == 0 {
|
||||
t.Error("Expected usedAt to have nanosecond precision, got 0")
|
||||
}
|
||||
}
|
||||
@@ -20,5 +20,5 @@ func (p *ConnPool) CheckMinIdleConns() {
|
||||
}
|
||||
|
||||
func (p *ConnPool) QueueLen() int {
|
||||
return len(p.queue)
|
||||
return int(p.semaphore.Len())
|
||||
}
|
||||
|
||||
74
internal/pool/global_time_cache.go
Normal file
74
internal/pool/global_time_cache.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Global time cache updated every 50ms by background goroutine.
|
||||
// This avoids expensive time.Now() syscalls in hot paths like getEffectiveReadTimeout.
|
||||
// Max staleness: 50ms, which is acceptable for timeout deadline checks (timeouts are typically 3-30 seconds).
|
||||
var globalTimeCache struct {
|
||||
nowNs atomic.Int64
|
||||
lock sync.Mutex
|
||||
started bool
|
||||
stop chan struct{}
|
||||
subscribers int32
|
||||
}
|
||||
|
||||
func subscribeToGlobalTimeCache() {
|
||||
globalTimeCache.lock.Lock()
|
||||
globalTimeCache.subscribers += 1
|
||||
globalTimeCache.lock.Unlock()
|
||||
}
|
||||
|
||||
func unsubscribeFromGlobalTimeCache() {
|
||||
globalTimeCache.lock.Lock()
|
||||
globalTimeCache.subscribers -= 1
|
||||
globalTimeCache.lock.Unlock()
|
||||
}
|
||||
|
||||
func startGlobalTimeCache() {
|
||||
globalTimeCache.lock.Lock()
|
||||
if globalTimeCache.started {
|
||||
globalTimeCache.lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
globalTimeCache.started = true
|
||||
globalTimeCache.nowNs.Store(time.Now().UnixNano())
|
||||
globalTimeCache.stop = make(chan struct{})
|
||||
globalTimeCache.lock.Unlock()
|
||||
// Start background updater
|
||||
go func(stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
default:
|
||||
}
|
||||
globalTimeCache.nowNs.Store(time.Now().UnixNano())
|
||||
}
|
||||
}(globalTimeCache.stop)
|
||||
}
|
||||
|
||||
// stopGlobalTimeCache stops the global time cache if there are no subscribers.
|
||||
// This should only be called when the last subscriber is removed.
|
||||
func stopGlobalTimeCache() {
|
||||
globalTimeCache.lock.Lock()
|
||||
if !globalTimeCache.started || globalTimeCache.subscribers > 0 {
|
||||
globalTimeCache.lock.Unlock()
|
||||
return
|
||||
}
|
||||
globalTimeCache.started = false
|
||||
close(globalTimeCache.stop)
|
||||
globalTimeCache.lock.Unlock()
|
||||
}
|
||||
|
||||
func init() {
|
||||
startGlobalTimeCache()
|
||||
}
|
||||
@@ -71,10 +71,13 @@ func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
|
||||
// ProcessOnGet calls all OnGet hooks in order.
|
||||
// If any hook returns an error, processing stops and the error is returned.
|
||||
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) (acceptConn bool, err error) {
|
||||
// Copy slice reference while holding lock (fast)
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
hooks := phm.hooks
|
||||
phm.hooksMu.RUnlock()
|
||||
|
||||
for _, hook := range phm.hooks {
|
||||
// Call hooks without holding lock (slow operations)
|
||||
for _, hook := range hooks {
|
||||
acceptConn, err := hook.OnGet(ctx, conn, isNewConn)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -90,12 +93,15 @@ func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewC
|
||||
// ProcessOnPut calls all OnPut hooks in order.
|
||||
// The first hook that returns shouldRemove=true or shouldPool=false will stop processing.
|
||||
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
// Copy slice reference while holding lock (fast)
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
hooks := phm.hooks
|
||||
phm.hooksMu.RUnlock()
|
||||
|
||||
shouldPool = true // Default to pooling the connection
|
||||
|
||||
for _, hook := range phm.hooks {
|
||||
// Call hooks without holding lock (slow operations)
|
||||
for _, hook := range hooks {
|
||||
hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)
|
||||
|
||||
if hookErr != nil {
|
||||
@@ -117,9 +123,13 @@ func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shoul
|
||||
|
||||
// ProcessOnRemove calls all OnRemove hooks in order.
|
||||
func (phm *PoolHookManager) ProcessOnRemove(ctx context.Context, conn *Conn, reason error) {
|
||||
// Copy slice reference while holding lock (fast)
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
for _, hook := range phm.hooks {
|
||||
hooks := phm.hooks
|
||||
phm.hooksMu.RUnlock()
|
||||
|
||||
// Call hooks without holding lock (slow operations)
|
||||
for _, hook := range hooks {
|
||||
hook.OnRemove(ctx, conn, reason)
|
||||
}
|
||||
}
|
||||
@@ -140,3 +150,16 @@ func (phm *PoolHookManager) GetHooks() []PoolHook {
|
||||
copy(hooks, phm.hooks)
|
||||
return hooks
|
||||
}
|
||||
|
||||
// Clone creates a copy of the hook manager with the same hooks.
|
||||
// This is used for lock-free atomic updates of the hook manager.
|
||||
func (phm *PoolHookManager) Clone() *PoolHookManager {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
newManager := &PoolHookManager{
|
||||
hooks: make([]PoolHook, len(phm.hooks)),
|
||||
}
|
||||
copy(newManager.hooks, phm.hooks)
|
||||
return newManager
|
||||
}
|
||||
|
||||
@@ -203,26 +203,29 @@ func TestPoolWithHooks(t *testing.T) {
|
||||
pool.AddPoolHook(testHook)
|
||||
|
||||
// Verify hooks are initialized
|
||||
if pool.hookManager == nil {
|
||||
manager := pool.hookManager.Load()
|
||||
if manager == nil {
|
||||
t.Error("Expected hookManager to be initialized")
|
||||
}
|
||||
|
||||
if pool.hookManager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount())
|
||||
if manager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook in pool, got %d", manager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test adding hook to pool
|
||||
additionalHook := &TestHook{ShouldPool: true, ShouldAccept: true}
|
||||
pool.AddPoolHook(additionalHook)
|
||||
|
||||
if pool.hookManager.GetHookCount() != 2 {
|
||||
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount())
|
||||
manager = pool.hookManager.Load()
|
||||
if manager.GetHookCount() != 2 {
|
||||
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test removing hook from pool
|
||||
pool.RemovePoolHook(additionalHook)
|
||||
|
||||
if pool.hookManager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount())
|
||||
manager = pool.hookManager.Load()
|
||||
if manager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,12 @@ var (
|
||||
// ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable.
|
||||
ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable")
|
||||
|
||||
// errHookRequestedRemoval is returned when a hook requests connection removal.
|
||||
errHookRequestedRemoval = errors.New("hook requested removal")
|
||||
|
||||
// errConnNotPooled is returned when trying to return a non-pooled connection to the pool.
|
||||
errConnNotPooled = errors.New("connection not pooled")
|
||||
|
||||
// popAttempts is the maximum number of attempts to find a usable connection
|
||||
// when popping from the idle connection pool. This handles cases where connections
|
||||
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).
|
||||
@@ -45,14 +51,6 @@ var (
|
||||
noExpiration = maxTime
|
||||
)
|
||||
|
||||
var timers = sync.Pool{
|
||||
New: func() interface{} {
|
||||
t := time.NewTimer(time.Hour)
|
||||
t.Stop()
|
||||
return t
|
||||
},
|
||||
}
|
||||
|
||||
// Stats contains pool state information and accumulated stats.
|
||||
type Stats struct {
|
||||
Hits uint32 // number of times free connection was found in the pool
|
||||
@@ -88,6 +86,12 @@ type Pooler interface {
|
||||
AddPoolHook(hook PoolHook)
|
||||
RemovePoolHook(hook PoolHook)
|
||||
|
||||
// RemoveWithoutTurn removes a connection from the pool without freeing a turn.
|
||||
// This should be used when removing a connection from a context that didn't acquire
|
||||
// a turn via Get() (e.g., background workers, cleanup tasks).
|
||||
// For normal removal after Get(), use Remove() instead.
|
||||
RemoveWithoutTurn(context.Context, *Conn, error)
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -130,6 +134,9 @@ type ConnPool struct {
|
||||
queue chan struct{}
|
||||
dialsInProgress chan struct{}
|
||||
dialsQueue *wantConnQueue
|
||||
// Fast semaphore for connection limiting with eventual fairness
|
||||
// Uses fast path optimization to avoid timer allocation when tokens are available
|
||||
semaphore *internal.FastSemaphore
|
||||
|
||||
connsMu sync.Mutex
|
||||
conns map[uint64]*Conn
|
||||
@@ -145,16 +152,16 @@ type ConnPool struct {
|
||||
_closed uint32 // atomic
|
||||
|
||||
// Pool hooks manager for flexible connection processing
|
||||
hookManagerMu sync.RWMutex
|
||||
hookManager *PoolHookManager
|
||||
// Using atomic.Pointer for lock-free reads in hot paths (Get/Put)
|
||||
hookManager atomic.Pointer[PoolHookManager]
|
||||
}
|
||||
|
||||
var _ Pooler = (*ConnPool)(nil)
|
||||
|
||||
func NewConnPool(opt *Options) *ConnPool {
|
||||
p := &ConnPool{
|
||||
cfg: opt,
|
||||
|
||||
cfg: opt,
|
||||
semaphore: internal.NewFastSemaphore(opt.PoolSize),
|
||||
queue: make(chan struct{}, opt.PoolSize),
|
||||
conns: make(map[uint64]*Conn),
|
||||
dialsInProgress: make(chan struct{}, opt.MaxConcurrentDials),
|
||||
@@ -170,32 +177,45 @@ func NewConnPool(opt *Options) *ConnPool {
|
||||
p.connsMu.Unlock()
|
||||
}
|
||||
|
||||
startGlobalTimeCache()
|
||||
subscribeToGlobalTimeCache()
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// initializeHooks sets up the pool hooks system.
|
||||
func (p *ConnPool) initializeHooks() {
|
||||
p.hookManager = NewPoolHookManager()
|
||||
manager := NewPoolHookManager()
|
||||
p.hookManager.Store(manager)
|
||||
}
|
||||
|
||||
// AddPoolHook adds a pool hook to the pool.
|
||||
func (p *ConnPool) AddPoolHook(hook PoolHook) {
|
||||
p.hookManagerMu.Lock()
|
||||
defer p.hookManagerMu.Unlock()
|
||||
|
||||
if p.hookManager == nil {
|
||||
// Lock-free read of current manager
|
||||
manager := p.hookManager.Load()
|
||||
if manager == nil {
|
||||
p.initializeHooks()
|
||||
manager = p.hookManager.Load()
|
||||
}
|
||||
p.hookManager.AddHook(hook)
|
||||
|
||||
// Create new manager with added hook
|
||||
newManager := manager.Clone()
|
||||
newManager.AddHook(hook)
|
||||
|
||||
// Atomically swap to new manager
|
||||
p.hookManager.Store(newManager)
|
||||
}
|
||||
|
||||
// RemovePoolHook removes a pool hook from the pool.
|
||||
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
|
||||
p.hookManagerMu.Lock()
|
||||
defer p.hookManagerMu.Unlock()
|
||||
manager := p.hookManager.Load()
|
||||
if manager != nil {
|
||||
// Create new manager with removed hook
|
||||
newManager := manager.Clone()
|
||||
newManager.RemoveHook(hook)
|
||||
|
||||
if p.hookManager != nil {
|
||||
p.hookManager.RemoveHook(hook)
|
||||
// Atomically swap to new manager
|
||||
p.hookManager.Store(newManager)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,33 +232,33 @@ func (p *ConnPool) checkMinIdleConns() {
|
||||
// Only create idle connections if we haven't reached the total pool size limit
|
||||
// MinIdleConns should be a subset of PoolSize, not additional connections
|
||||
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
|
||||
select {
|
||||
case p.queue <- struct{}{}:
|
||||
p.poolSize.Add(1)
|
||||
p.idleConnsLen.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
|
||||
p.freeTurn()
|
||||
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err := p.addIdleConn()
|
||||
if err != nil && err != ErrClosed {
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
}
|
||||
p.freeTurn()
|
||||
}()
|
||||
default:
|
||||
// Try to acquire a semaphore token
|
||||
if !p.semaphore.TryAcquire() {
|
||||
// Semaphore is full, can't create more connections
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
p.poolSize.Add(1)
|
||||
p.idleConnsLen.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
|
||||
p.freeTurn()
|
||||
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err := p.addIdleConn()
|
||||
if err != nil && err != ErrClosed {
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
}
|
||||
p.freeTurn()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ConnPool) addIdleConn() error {
|
||||
@@ -250,9 +270,9 @@ func (p *ConnPool) addIdleConn() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
// NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn()
|
||||
// when first acquired from the pool. Do NOT transition to IDLE here - that happens
|
||||
// after initialization completes.
|
||||
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
@@ -281,7 +301,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= p.cfg.MaxActiveConns {
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
|
||||
@@ -292,11 +312,11 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
// NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn()
|
||||
// when first used. Do NOT transition to IDLE here - that happens after initialization completes.
|
||||
// The state machine flow is: CREATED → INITIALIZING (in initConn) → IDLE (after init success)
|
||||
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > p.cfg.MaxActiveConns {
|
||||
_ = cn.Close()
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
@@ -352,7 +372,8 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
// when the timeout is reached, we should stop retrying
|
||||
// but keep the lastErr to return to the caller
|
||||
// instead of a generic context deadline exceeded error
|
||||
for attempt := 0; (attempt < maxRetries) && shouldLoop; attempt++ {
|
||||
attempt := 0
|
||||
for attempt = 0; (attempt < maxRetries) && shouldLoop; attempt++ {
|
||||
netConn, err := p.cfg.Dialer(ctx)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
@@ -379,7 +400,7 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr)
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", attempt, lastErr)
|
||||
// All retries failed - handle error tracking
|
||||
p.setLastDialError(lastErr)
|
||||
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
|
||||
@@ -441,21 +462,13 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
attempts := 0
|
||||
// Use cached time for health checks (max 50ms staleness is acceptable)
|
||||
nowNs := getCachedTimeNs()
|
||||
|
||||
// Get hooks manager once for this getConn call for performance.
|
||||
// Note: Hooks added/removed during this call won't be reflected.
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
// Lock-free atomic read - no mutex overhead!
|
||||
hookManager := p.hookManager.Load()
|
||||
|
||||
for {
|
||||
if attempts >= getAttempts {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts)
|
||||
break
|
||||
}
|
||||
attempts++
|
||||
for attempts := 0; attempts < getAttempts; attempts++ {
|
||||
|
||||
p.connsMu.Lock()
|
||||
cn, err = p.popIdle()
|
||||
@@ -470,23 +483,26 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
break
|
||||
}
|
||||
|
||||
if !p.isHealthyConn(cn, now) {
|
||||
if !p.isHealthyConn(cn, nowNs) {
|
||||
_ = p.CloseConn(cn)
|
||||
continue
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
// Combine error and rejection checks to reduce branches
|
||||
if hookManager != nil {
|
||||
acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
|
||||
_ = p.CloseConn(cn)
|
||||
continue
|
||||
}
|
||||
if !acceptConn {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID())
|
||||
p.Put(ctx, cn)
|
||||
cn = nil
|
||||
if err != nil || !acceptConn {
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
|
||||
_ = p.CloseConn(cn)
|
||||
} else {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID())
|
||||
// Return connection to pool without freeing the turn that this Get() call holds.
|
||||
// We use putConnWithoutTurn() to run all the Put hooks and logic without freeing a turn.
|
||||
p.putConnWithoutTurn(ctx, cn)
|
||||
cn = nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -595,8 +611,6 @@ func (p *ConnPool) putIdleConn(ctx context.Context, cn *Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
cn.SetUsable(true)
|
||||
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
|
||||
@@ -611,44 +625,36 @@ func (p *ConnPool) putIdleConn(ctx context.Context, cn *Conn) {
|
||||
}
|
||||
|
||||
func (p *ConnPool) waitTurn(ctx context.Context) error {
|
||||
// Fast path: check context first
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case p.queue <- struct{}{}:
|
||||
// Fast path: try to acquire without blocking
|
||||
if p.semaphore.TryAcquire() {
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// Slow path: need to wait
|
||||
start := time.Now()
|
||||
timer := timers.Get().(*time.Timer)
|
||||
defer timers.Put(timer)
|
||||
timer.Reset(p.cfg.PoolTimeout)
|
||||
err := p.semaphore.Acquire(ctx, p.cfg.PoolTimeout, ErrPoolTimeout)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return ctx.Err()
|
||||
case p.queue <- struct{}{}:
|
||||
switch err {
|
||||
case nil:
|
||||
// Successfully acquired after waiting
|
||||
p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
|
||||
atomic.AddUint32(&p.stats.WaitCount, 1)
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return nil
|
||||
case <-timer.C:
|
||||
case ErrPoolTimeout:
|
||||
atomic.AddUint32(&p.stats.Timeouts, 1)
|
||||
return ErrPoolTimeout
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ConnPool) freeTurn() {
|
||||
<-p.queue
|
||||
p.semaphore.Release()
|
||||
}
|
||||
|
||||
func (p *ConnPool) popIdle() (*Conn, error) {
|
||||
@@ -682,15 +688,18 @@ func (p *ConnPool) popIdle() (*Conn, error) {
|
||||
}
|
||||
attempts++
|
||||
|
||||
if cn.CompareAndSwapUsed(false, true) {
|
||||
if cn.IsUsable() {
|
||||
p.idleConnsLen.Add(-1)
|
||||
break
|
||||
}
|
||||
cn.SetUsed(false)
|
||||
// Hot path optimization: try IDLE → IN_USE or CREATED → IN_USE transition
|
||||
// Using inline TryAcquire() method for better performance (avoids pointer dereference)
|
||||
if cn.TryAcquire() {
|
||||
// Successfully acquired the connection
|
||||
p.idleConnsLen.Add(-1)
|
||||
break
|
||||
}
|
||||
|
||||
// Connection is not usable, put it back in the pool
|
||||
// Connection is in UNUSABLE, INITIALIZING, or other state - skip it
|
||||
|
||||
// Connection is not in a valid state (might be UNUSABLE for handoff/re-auth, INITIALIZING, etc.)
|
||||
// Put it back in the pool and try the next one
|
||||
if p.cfg.PoolFIFO {
|
||||
// FIFO: put at end (will be picked up last since we pop from front)
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
@@ -711,6 +720,18 @@ func (p *ConnPool) popIdle() (*Conn, error) {
|
||||
}
|
||||
|
||||
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
||||
p.putConn(ctx, cn, true)
|
||||
}
|
||||
|
||||
// putConnWithoutTurn is an internal method that puts a connection back to the pool
|
||||
// without freeing a turn. This is used when returning a rejected connection from
|
||||
// within Get(), where the turn is still held by the Get() call.
|
||||
func (p *ConnPool) putConnWithoutTurn(ctx context.Context, cn *Conn) {
|
||||
p.putConn(ctx, cn, false)
|
||||
}
|
||||
|
||||
// putConn is the internal implementation of Put that optionally frees a turn.
|
||||
func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) {
|
||||
// Process connection using the hooks system
|
||||
shouldPool := true
|
||||
shouldRemove := false
|
||||
@@ -721,47 +742,64 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
||||
if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush {
|
||||
// Not a push notification or error peeking, remove connection
|
||||
internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it")
|
||||
p.Remove(ctx, cn, err)
|
||||
p.removeConnInternal(ctx, cn, err, freeTurn)
|
||||
return
|
||||
}
|
||||
// It's a push notification, allow pooling (client will handle it)
|
||||
}
|
||||
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
// Lock-free atomic read - no mutex overhead!
|
||||
hookManager := p.hookManager.Load()
|
||||
|
||||
if hookManager != nil {
|
||||
shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "Connection hook error: %v", err)
|
||||
p.Remove(ctx, cn, err)
|
||||
p.removeConnInternal(ctx, cn, err, freeTurn)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If hooks say to remove the connection, do so
|
||||
if shouldRemove {
|
||||
p.Remove(ctx, cn, errors.New("hook requested removal"))
|
||||
return
|
||||
}
|
||||
|
||||
// If processor says not to pool the connection, remove it
|
||||
if !shouldPool {
|
||||
p.Remove(ctx, cn, errors.New("hook requested no pooling"))
|
||||
// Combine all removal checks into one - reduces branches
|
||||
if shouldRemove || !shouldPool {
|
||||
p.removeConnInternal(ctx, cn, errHookRequestedRemoval, freeTurn)
|
||||
return
|
||||
}
|
||||
|
||||
if !cn.pooled {
|
||||
p.Remove(ctx, cn, errors.New("connection not pooled"))
|
||||
p.removeConnInternal(ctx, cn, errConnNotPooled, freeTurn)
|
||||
return
|
||||
}
|
||||
|
||||
var shouldCloseConn bool
|
||||
|
||||
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
|
||||
// Hot path optimization: try fast IN_USE → IDLE transition
|
||||
// Using inline Release() method for better performance (avoids pointer dereference)
|
||||
transitionedToIdle := cn.Release()
|
||||
|
||||
// Handle unexpected state changes
|
||||
if !transitionedToIdle {
|
||||
// Fast path failed - hook might have changed state (e.g., to UNUSABLE for handoff)
|
||||
// Keep the state set by the hook and pool the connection anyway
|
||||
currentState := cn.GetStateMachine().GetState()
|
||||
switch currentState {
|
||||
case StateUnusable:
|
||||
// expected state, don't log it
|
||||
case StateClosed:
|
||||
internal.Logger.Printf(ctx, "Unexpected conn[%d] state changed by hook to %v, closing it", cn.GetID(), currentState)
|
||||
shouldCloseConn = true
|
||||
p.removeConnWithLock(cn)
|
||||
default:
|
||||
// Pool as-is
|
||||
internal.Logger.Printf(ctx, "Unexpected conn[%d] state changed by hook to %v, pooling as-is", cn.GetID(), currentState)
|
||||
}
|
||||
}
|
||||
|
||||
// unusable conns are expected to become usable at some point (background process is reconnecting them)
|
||||
// put them at the opposite end of the queue
|
||||
if !cn.IsUsable() {
|
||||
// Optimization: if we just transitioned to IDLE, we know it's usable - skip the check
|
||||
if !transitionedToIdle && !cn.IsUsable() {
|
||||
if p.cfg.PoolFIFO {
|
||||
p.connsMu.Lock()
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
@@ -771,33 +809,45 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
||||
p.idleConns = append([]*Conn{cn}, p.idleConns...)
|
||||
p.connsMu.Unlock()
|
||||
}
|
||||
} else {
|
||||
p.idleConnsLen.Add(1)
|
||||
} else if !shouldCloseConn {
|
||||
p.connsMu.Lock()
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
p.connsMu.Unlock()
|
||||
p.idleConnsLen.Add(1)
|
||||
}
|
||||
p.idleConnsLen.Add(1)
|
||||
} else {
|
||||
p.removeConnWithLock(cn)
|
||||
shouldCloseConn = true
|
||||
p.removeConnWithLock(cn)
|
||||
}
|
||||
|
||||
// if the connection is not going to be closed, mark it as not used
|
||||
if !shouldCloseConn {
|
||||
cn.SetUsed(false)
|
||||
if freeTurn {
|
||||
p.freeTurn()
|
||||
}
|
||||
|
||||
p.freeTurn()
|
||||
|
||||
if shouldCloseConn {
|
||||
_ = p.closeConn(cn)
|
||||
}
|
||||
|
||||
cn.SetLastPutAtNs(getCachedTimeNs())
|
||||
}
|
||||
|
||||
func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
p.removeConnInternal(ctx, cn, reason, true)
|
||||
}
|
||||
|
||||
// RemoveWithoutTurn removes a connection from the pool without freeing a turn.
|
||||
// This should be used when removing a connection from a context that didn't acquire
|
||||
// a turn via Get() (e.g., background workers, cleanup tasks).
|
||||
// For normal removal after Get(), use Remove() instead.
|
||||
func (p *ConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) {
|
||||
p.removeConnInternal(ctx, cn, reason, false)
|
||||
}
|
||||
|
||||
// removeConnInternal is the internal implementation of Remove that optionally frees a turn.
|
||||
func (p *ConnPool) removeConnInternal(ctx context.Context, cn *Conn, reason error, freeTurn bool) {
|
||||
// Lock-free atomic read - no mutex overhead!
|
||||
hookManager := p.hookManager.Load()
|
||||
|
||||
if hookManager != nil {
|
||||
hookManager.ProcessOnRemove(ctx, cn, reason)
|
||||
@@ -805,7 +855,9 @@ func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
|
||||
|
||||
p.removeConnWithLock(cn)
|
||||
|
||||
p.freeTurn()
|
||||
if freeTurn {
|
||||
p.freeTurn()
|
||||
}
|
||||
|
||||
_ = p.closeConn(cn)
|
||||
|
||||
@@ -834,8 +886,7 @@ func (p *ConnPool) removeConn(cn *Conn) {
|
||||
p.poolSize.Add(-1)
|
||||
// this can be idle conn
|
||||
for idx, ic := range p.idleConns {
|
||||
if ic.GetID() == cid {
|
||||
internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid)
|
||||
if ic == cn {
|
||||
p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...)
|
||||
p.idleConnsLen.Add(-1)
|
||||
break
|
||||
@@ -911,6 +962,9 @@ func (p *ConnPool) Close() error {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
unsubscribeFromGlobalTimeCache()
|
||||
stopGlobalTimeCache()
|
||||
|
||||
var firstErr error
|
||||
p.connsMu.Lock()
|
||||
for _, cn := range p.conns {
|
||||
@@ -927,37 +981,54 @@ func (p *ConnPool) Close() error {
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
|
||||
// slight optimization, check expiresAt first.
|
||||
if cn.expiresAt.Before(now) {
|
||||
return false
|
||||
func (p *ConnPool) isHealthyConn(cn *Conn, nowNs int64) bool {
|
||||
// Performance optimization: check conditions from cheapest to most expensive,
|
||||
// and from most likely to fail to least likely to fail.
|
||||
|
||||
// Only fails if ConnMaxLifetime is set AND connection is old.
|
||||
// Most pools don't set ConnMaxLifetime, so this rarely fails.
|
||||
if p.cfg.ConnMaxLifetime > 0 {
|
||||
if cn.expiresAt.UnixNano() < nowNs {
|
||||
return false // Connection has exceeded max lifetime
|
||||
}
|
||||
}
|
||||
|
||||
// Check if connection has exceeded idle timeout
|
||||
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
|
||||
return false
|
||||
// Most pools set ConnMaxIdleTime, and idle connections are common.
|
||||
// Checking this first allows us to fail fast without expensive syscalls.
|
||||
if p.cfg.ConnMaxIdleTime > 0 {
|
||||
if nowNs-cn.UsedAtNs() >= int64(p.cfg.ConnMaxIdleTime) {
|
||||
return false // Connection has been idle too long
|
||||
}
|
||||
}
|
||||
|
||||
cn.SetUsedAt(now)
|
||||
// Check basic connection health
|
||||
// Use GetNetConn() to safely access netConn and avoid data races
|
||||
// Only run this if the cheap checks passed.
|
||||
if err := connCheck(cn.getNetConn()); err != nil {
|
||||
// If there's unexpected data, it might be push notifications (RESP3)
|
||||
// However, push notification processing is now handled by the client
|
||||
// before WithReader to ensure proper context is available to handlers
|
||||
if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
|
||||
// we know that there is something in the buffer, so peek at the next reply type without
|
||||
// the potential to block
|
||||
// Peek at the reply type to check if it's a push notification
|
||||
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
|
||||
// For RESP3 connections with push notifications, we allow some buffered data
|
||||
// The client will process these notifications before using the connection
|
||||
internal.Logger.Printf(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID())
|
||||
return true // Connection is healthy, client will handle notifications
|
||||
internal.Logger.Printf(
|
||||
context.Background(),
|
||||
"push: conn[%d] has buffered data, likely push notifications - will be processed by client",
|
||||
cn.GetID(),
|
||||
)
|
||||
|
||||
// Update timestamp for healthy connection
|
||||
cn.SetUsedAtNs(nowNs)
|
||||
|
||||
// Connection is healthy, client will handle notifications
|
||||
return true
|
||||
}
|
||||
return false // Unexpected data, not push notifications, connection is unhealthy
|
||||
} else {
|
||||
// Not a push notification - treat as unhealthy
|
||||
return false
|
||||
}
|
||||
// Connection failed health check
|
||||
return false
|
||||
}
|
||||
|
||||
// Only update UsedAt if connection is healthy (avoids unnecessary atomic store)
|
||||
cn.SetUsedAtNs(nowNs)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -44,6 +44,13 @@ func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) {
|
||||
if p.cn == nil {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
// NOTE: SingleConnPool is NOT thread-safe by design and is used in special scenarios:
|
||||
// - During initialization (connection is in INITIALIZING state)
|
||||
// - During re-authentication (connection is in UNUSABLE state)
|
||||
// - For transactions (connection might be in various states)
|
||||
// We use SetUsed() which forces the transition, rather than TryTransition() which
|
||||
// would fail if the connection is not in IDLE/CREATED state.
|
||||
p.cn.SetUsed(true)
|
||||
p.cn.SetUsedAt(time.Now())
|
||||
return p.cn, nil
|
||||
@@ -65,6 +72,12 @@ func (p *SingleConnPool) Remove(_ context.Context, cn *Conn, reason error) {
|
||||
p.stickyErr = reason
|
||||
}
|
||||
|
||||
// RemoveWithoutTurn has the same behavior as Remove for SingleConnPool
|
||||
// since SingleConnPool doesn't use a turn-based queue system.
|
||||
func (p *SingleConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) {
|
||||
p.Remove(ctx, cn, reason)
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Close() error {
|
||||
p.cn = nil
|
||||
p.stickyErr = ErrClosed
|
||||
|
||||
@@ -123,6 +123,12 @@ func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
|
||||
p.ch <- cn
|
||||
}
|
||||
|
||||
// RemoveWithoutTurn has the same behavior as Remove for StickyConnPool
|
||||
// since StickyConnPool doesn't use a turn-based queue system.
|
||||
func (p *StickyConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) {
|
||||
p.Remove(ctx, cn, reason)
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) Close() error {
|
||||
if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
|
||||
return nil
|
||||
|
||||
@@ -24,7 +24,7 @@ type PubSubPool struct {
|
||||
stats PubSubStats
|
||||
}
|
||||
|
||||
// PubSubPool implements a pool for PubSub connections.
|
||||
// NewPubSubPool implements a pool for PubSub connections.
|
||||
// It intentionally does not implement the Pooler interface
|
||||
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
|
||||
return &PubSubPool{
|
||||
|
||||
Reference in New Issue
Block a user