mirror of
https://github.com/redis/go-redis.git
synced 2025-12-02 06:22:31 +03:00
fix(pool): pool performance (#3565)
* perf(pool): replace hookManager RWMutex with atomic.Pointer and add predefined state slices
- Replace hookManager RWMutex with atomic.Pointer for lock-free reads in hot paths
- Add predefined state slices to avoid allocations (validFromInUse, validFromCreatedOrIdle, etc.)
- Add Clone() method to PoolHookManager for atomic updates
- Update AddPoolHook/RemovePoolHook to use copy-on-write pattern
- Update all hookManager access points to use atomic Load()
Performance improvements:
- Eliminates RWMutex contention in Get/Put/Remove hot paths
- Reduces allocations by reusing predefined state slices
- Lock-free reads allow better CPU cache utilization
* perf(pool): eliminate mutex overhead in state machine hot path
The state machine was calling notifyWaiters() on EVERY Get/Put operation,
which acquired a mutex even when no waiters were present (the common case).
Fix: Use atomic waiterCount to check for waiters BEFORE acquiring mutex.
This eliminates mutex contention in the hot path (Get/Put operations).
Implementation:
- Added atomic.Int32 waiterCount field to ConnStateMachine
- Increment when adding waiter, decrement when removing
- Check waiterCount atomically before acquiring mutex in notifyWaiters()
Performance impact:
- Before: mutex lock/unlock on every Get/Put (even with no waiters)
- After: lock-free atomic check, only acquire mutex if waiters exist
- Expected improvement: ~30-50% for Get/Put operations
* perf(pool): use predefined state slices to eliminate allocations in hot path
The pool was creating new slice literals on EVERY Get/Put operation:
- popIdle(): []ConnState{StateCreated, StateIdle}
- putConn(): []ConnState{StateInUse}
- CompareAndSwapUsed(): []ConnState{StateIdle} and []ConnState{StateInUse}
- MarkUnusableForHandoff(): []ConnState{StateInUse, StateIdle, StateCreated}
These allocations were happening millions of times per second in the hot path.
Fix: Use predefined global slices defined in conn_state.go:
- validFromInUse
- validFromCreatedOrIdle
- validFromCreatedInUseOrIdle
Performance impact:
- Before: 4 slice allocations per Get/Put cycle
- After: 0 allocations (use predefined slices)
- Expected improvement: ~30-40% reduction in allocations and GC pressure
* perf(pool): optimize TryTransition to reduce atomic operations
Further optimize the hot path by:
1. Remove redundant GetState() call in the loop
2. Only check waiterCount after successful CAS (not before loop)
3. Inline the waiterCount check to avoid notifyWaiters() call overhead
This reduces atomic operations from 4-5 per Get/Put to 2-3:
- Before: GetState() + CAS + waiterCount.Load() + notifyWaiters mutex check
- After: CAS + waiterCount.Load() (only if CAS succeeds)
Performance impact:
- Eliminates 1-2 atomic operations per Get/Put
- Expected improvement: ~10-15% for Get/Put operations
* perf(pool): add fast path for Get/Put to match master performance
Introduced TryTransitionFast() for the hot path (Get/Put operations):
- Single CAS operation (same as master's atomic bool)
- No waiter notification overhead
- No loop through valid states
- No error allocation
Hot path flow:
1. popIdle(): Try IDLE → IN_USE (fast), fallback to CREATED → IN_USE
2. putConn(): Try IN_USE → IDLE (fast)
This matches master's performance while preserving state machine for:
- Background operations (handoff/reauth use UNUSABLE state)
- State validation (TryTransition still available)
- Waiter notification (AwaitAndTransition for blocking)
Performance comparison per Get/Put cycle:
- Master: 2 atomic CAS operations
- State machine (before): 5 atomic operations (2.5x slower)
- State machine (after): 2 atomic CAS operations (same as master!)
Expected improvement: Restore to baseline ~11,373 ops/sec
* combine cas
* fix linter
* try faster approach
* fast semaphore
* better inlining for hot path
* fix linter issues
* use new semaphore in auth as well
* linter should be happy now
* add comments
* Update internal/pool/conn_state.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* address comment
* slight reordering
* try to cache time if for non-critical calculation
* fix wrong benchmark
* add concurrent test
* fix benchmark report
* add additional expect to check output
* comment and variable rename
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -3,6 +3,7 @@ package redis_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -100,7 +101,82 @@ func benchmarkHSETOperations(b *testing.B, rdb *redis.Client, ctx context.Contex
|
||||
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
|
||||
b.ReportMetric(float64(avgTimePerOp), "ns/op")
|
||||
// report average time in milliseconds from totalTimes
|
||||
avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes))
|
||||
sumTime := time.Duration(0)
|
||||
for _, t := range totalTimes {
|
||||
sumTime += t
|
||||
}
|
||||
avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes))
|
||||
b.ReportMetric(float64(avgTimePerOpMs), "ms")
|
||||
}
|
||||
|
||||
// benchmarkHSETOperationsConcurrent performs the actual HSET benchmark for a given scale
|
||||
func benchmarkHSETOperationsConcurrent(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) {
|
||||
hashKey := fmt.Sprintf("benchmark_hash_%d", operations)
|
||||
|
||||
b.ResetTimer()
|
||||
b.StartTimer()
|
||||
totalTimes := []time.Duration{}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
// Clean up the hash before each iteration
|
||||
rdb.Del(ctx, hashKey)
|
||||
b.StartTimer()
|
||||
|
||||
startTime := time.Now()
|
||||
// Perform the specified number of HSET operations
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
timesCh := make(chan time.Duration, operations)
|
||||
errCh := make(chan error, operations)
|
||||
|
||||
for j := 0; j < operations; j++ {
|
||||
wg.Add(1)
|
||||
go func(j int) {
|
||||
defer wg.Done()
|
||||
field := fmt.Sprintf("field_%d", j)
|
||||
value := fmt.Sprintf("value_%d", j)
|
||||
|
||||
err := rdb.HSet(ctx, hashKey, field, value).Err()
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
timesCh <- time.Since(startTime)
|
||||
}(j)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(timesCh)
|
||||
close(errCh)
|
||||
|
||||
// Check for errors
|
||||
for err := range errCh {
|
||||
b.Errorf("HSET operation failed: %v", err)
|
||||
}
|
||||
|
||||
for d := range timesCh {
|
||||
totalTimes = append(totalTimes, d)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop the timer to calculate metrics
|
||||
b.StopTimer()
|
||||
|
||||
// Report operations per second
|
||||
opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds()
|
||||
b.ReportMetric(opsPerSec, "ops/sec")
|
||||
|
||||
// Report average time per operation
|
||||
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
|
||||
b.ReportMetric(float64(avgTimePerOp), "ns/op")
|
||||
// report average time in milliseconds from totalTimes
|
||||
|
||||
sumTime := time.Duration(0)
|
||||
for _, t := range totalTimes {
|
||||
sumTime += t
|
||||
}
|
||||
avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes))
|
||||
b.ReportMetric(float64(avgTimePerOpMs), "ms")
|
||||
}
|
||||
|
||||
@@ -134,6 +210,37 @@ func BenchmarkHSETPipelined(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHSET_Concurrent(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup Redis client
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
DB: 0,
|
||||
PoolSize: 100,
|
||||
})
|
||||
defer rdb.Close()
|
||||
|
||||
// Test connection
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
b.Skipf("Redis server not available: %v", err)
|
||||
}
|
||||
|
||||
// Clean up before and after tests
|
||||
defer func() {
|
||||
rdb.FlushDB(ctx)
|
||||
}()
|
||||
|
||||
// Reduced scales to avoid overwhelming the system with too many concurrent goroutines
|
||||
scales := []int{1, 10, 100, 1000}
|
||||
|
||||
for _, scale := range scales {
|
||||
b.Run(fmt.Sprintf("HSET_%d_operations_concurrent", scale), func(b *testing.B) {
|
||||
benchmarkHSETOperationsConcurrent(b, rdb, ctx, scale)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkHSETPipelined performs HSET benchmark using pipelining
|
||||
func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) {
|
||||
hashKey := fmt.Sprintf("benchmark_hash_pipelined_%d", operations)
|
||||
@@ -177,7 +284,11 @@ func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context
|
||||
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
|
||||
b.ReportMetric(float64(avgTimePerOp), "ns/op")
|
||||
// report average time in milliseconds from totalTimes
|
||||
avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes))
|
||||
sumTime := time.Duration(0)
|
||||
for _, t := range totalTimes {
|
||||
sumTime += t
|
||||
}
|
||||
avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes))
|
||||
b.ReportMetric(float64(avgTimePerOpMs), "ms")
|
||||
}
|
||||
|
||||
|
||||
@@ -34,9 +34,10 @@ type ReAuthPoolHook struct {
|
||||
shouldReAuth map[uint64]func(error)
|
||||
shouldReAuthLock sync.RWMutex
|
||||
|
||||
// workers is a semaphore channel limiting concurrent re-auth operations
|
||||
// workers is a semaphore limiting concurrent re-auth operations
|
||||
// Initialized with poolSize tokens to prevent pool exhaustion
|
||||
workers chan struct{}
|
||||
// Uses FastSemaphore for consistency and better performance
|
||||
workers *internal.FastSemaphore
|
||||
|
||||
// reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth
|
||||
reAuthTimeout time.Duration
|
||||
@@ -59,16 +60,10 @@ type ReAuthPoolHook struct {
|
||||
// The poolSize parameter is used to initialize the worker semaphore, ensuring that
|
||||
// re-auth operations don't exhaust the connection pool.
|
||||
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
|
||||
workers := make(chan struct{}, poolSize)
|
||||
// Initialize the workers channel with tokens (semaphore pattern)
|
||||
for i := 0; i < poolSize; i++ {
|
||||
workers <- struct{}{}
|
||||
}
|
||||
|
||||
return &ReAuthPoolHook{
|
||||
shouldReAuth: make(map[uint64]func(error)),
|
||||
scheduledReAuth: make(map[uint64]bool),
|
||||
workers: workers,
|
||||
workers: internal.NewFastSemaphore(int32(poolSize)),
|
||||
reAuthTimeout: reAuthTimeout,
|
||||
}
|
||||
}
|
||||
@@ -162,10 +157,10 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
|
||||
r.scheduledLock.Unlock()
|
||||
r.shouldReAuthLock.Unlock()
|
||||
go func() {
|
||||
<-r.workers
|
||||
r.workers.AcquireBlocking()
|
||||
// safety first
|
||||
if conn == nil || (conn != nil && conn.IsClosed()) {
|
||||
r.workers <- struct{}{}
|
||||
r.workers.Release()
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -176,7 +171,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
|
||||
r.scheduledLock.Lock()
|
||||
delete(r.scheduledReAuth, connID)
|
||||
r.scheduledLock.Unlock()
|
||||
r.workers <- struct{}{}
|
||||
r.workers.Release()
|
||||
}()
|
||||
|
||||
// Create timeout context for connection acquisition
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package pool implements the pool management
|
||||
package pool
|
||||
|
||||
import (
|
||||
@@ -17,6 +18,35 @@ import (
|
||||
|
||||
var noDeadline = time.Time{}
|
||||
|
||||
// Global time cache updated every 50ms by background goroutine.
|
||||
// This avoids expensive time.Now() syscalls in hot paths like getEffectiveReadTimeout.
|
||||
// Max staleness: 50ms, which is acceptable for timeout deadline checks (timeouts are typically 3-30 seconds).
|
||||
var globalTimeCache struct {
|
||||
nowNs atomic.Int64
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Initialize immediately
|
||||
globalTimeCache.nowNs.Store(time.Now().UnixNano())
|
||||
|
||||
// Start background updater
|
||||
go func() {
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
globalTimeCache.nowNs.Store(time.Now().UnixNano())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// getCachedTimeNs returns the current time in nanoseconds from the global cache.
|
||||
// This is updated every 50ms by a background goroutine, avoiding expensive syscalls.
|
||||
// Max staleness: 50ms.
|
||||
func getCachedTimeNs() int64 {
|
||||
return globalTimeCache.nowNs.Load()
|
||||
}
|
||||
|
||||
// Global atomic counter for connection IDs
|
||||
var connIDCounter uint64
|
||||
|
||||
@@ -79,6 +109,7 @@ type Conn struct {
|
||||
expiresAt time.Time
|
||||
|
||||
// maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers
|
||||
|
||||
// Using atomic operations for lock-free access to avoid mutex contention
|
||||
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
@@ -260,11 +291,13 @@ func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
|
||||
|
||||
if !old && new {
|
||||
// Acquiring: IDLE → IN_USE
|
||||
_, err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse)
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(validFromCreatedOrIdle, StateInUse)
|
||||
return err == nil
|
||||
} else {
|
||||
// Releasing: IN_USE → IDLE
|
||||
_, err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle)
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(validFromInUse, StateIdle)
|
||||
return err == nil
|
||||
}
|
||||
}
|
||||
@@ -454,7 +487,8 @@ func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Durati
|
||||
return time.Duration(readTimeoutNs)
|
||||
}
|
||||
|
||||
nowNs := time.Now().UnixNano()
|
||||
// Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
|
||||
nowNs := getCachedTimeNs()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
@@ -487,7 +521,8 @@ func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Durat
|
||||
return time.Duration(writeTimeoutNs)
|
||||
}
|
||||
|
||||
nowNs := time.Now().UnixNano()
|
||||
// Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
|
||||
nowNs := getCachedTimeNs()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
@@ -632,7 +667,8 @@ func (cn *Conn) MarkQueuedForHandoff() error {
|
||||
// The connection is typically in IN_USE state when OnPut is called (normal Put flow)
|
||||
// But in some edge cases or tests, it might be in IDLE or CREATED state
|
||||
// The pool will detect this state change and preserve it (not overwrite with IDLE)
|
||||
finalState, err := cn.stateMachine.TryTransition([]ConnState{StateInUse, StateIdle, StateCreated}, StateUnusable)
|
||||
// Use predefined slice to avoid allocation
|
||||
finalState, err := cn.stateMachine.TryTransition(validFromCreatedInUseOrIdle, StateUnusable)
|
||||
if err != nil {
|
||||
// Check if already in UNUSABLE state (race condition or retry)
|
||||
// ShouldHandoff should be false now, but check just in case
|
||||
@@ -658,6 +694,42 @@ func (cn *Conn) GetStateMachine() *ConnStateMachine {
|
||||
return cn.stateMachine
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire the connection for use.
|
||||
// This is an optimized inline method for the hot path (Get operation).
|
||||
//
|
||||
// It tries to transition from IDLE -> IN_USE or CREATED -> IN_USE.
|
||||
// Returns true if the connection was successfully acquired, false otherwise.
|
||||
//
|
||||
// Performance: This is faster than calling GetStateMachine() + TryTransitionFast()
|
||||
//
|
||||
// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's
|
||||
// methods. This breaks encapsulation but is necessary for performance.
|
||||
// The IDLE->IN_USE and CREATED->IN_USE transitions don't need
|
||||
// waiter notification, and benchmarks show 1-3% improvement. If the state machine ever
|
||||
// needs to notify waiters on these transitions, update this to use TryTransitionFast().
|
||||
func (cn *Conn) TryAcquire() bool {
|
||||
// The || operator short-circuits, so only 1 CAS in the common case
|
||||
return cn.stateMachine.state.CompareAndSwap(uint32(StateIdle), uint32(StateInUse)) ||
|
||||
cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateInUse))
|
||||
}
|
||||
|
||||
// Release releases the connection back to the pool.
|
||||
// This is an optimized inline method for the hot path (Put operation).
|
||||
//
|
||||
// It tries to transition from IN_USE -> IDLE.
|
||||
// Returns true if the connection was successfully released, false otherwise.
|
||||
//
|
||||
// Performance: This is faster than calling GetStateMachine() + TryTransitionFast().
|
||||
//
|
||||
// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's
|
||||
// methods. This breaks encapsulation but is necessary for performance.
|
||||
// If the state machine ever needs to notify waiters
|
||||
// on this transition, update this to use TryTransitionFast().
|
||||
func (cn *Conn) Release() bool {
|
||||
// Inline the hot path - single CAS operation
|
||||
return cn.stateMachine.state.CompareAndSwap(uint32(StateInUse), uint32(StateIdle))
|
||||
}
|
||||
|
||||
// ClearHandoffState clears the handoff state after successful handoff.
|
||||
// Makes the connection usable again.
|
||||
func (cn *Conn) ClearHandoffState() {
|
||||
@@ -800,8 +872,12 @@ func (cn *Conn) MaybeHasData() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// deadline computes the effective deadline time based on context and timeout.
|
||||
// It updates the usedAt timestamp to now.
|
||||
// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation).
|
||||
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
|
||||
tm := time.Now()
|
||||
// Use cached time for deadline calculation (called 2x per command: read + write)
|
||||
tm := time.Unix(0, getCachedTimeNs())
|
||||
cn.SetUsedAt(tm)
|
||||
|
||||
if timeout > 0 {
|
||||
|
||||
@@ -41,6 +41,13 @@ const (
|
||||
StateClosed
|
||||
)
|
||||
|
||||
// Predefined state slices to avoid allocations in hot paths
|
||||
var (
|
||||
validFromInUse = []ConnState{StateInUse}
|
||||
validFromCreatedOrIdle = []ConnState{StateCreated, StateIdle}
|
||||
validFromCreatedInUseOrIdle = []ConnState{StateCreated, StateInUse, StateIdle}
|
||||
)
|
||||
|
||||
// String returns a human-readable string representation of the state.
|
||||
func (s ConnState) String() string {
|
||||
switch s {
|
||||
@@ -92,8 +99,9 @@ type ConnStateMachine struct {
|
||||
state atomic.Uint32
|
||||
|
||||
// FIFO queue for waiters - only locked during waiter add/remove/notify
|
||||
mu sync.Mutex
|
||||
waiters *list.List // List of *waiter
|
||||
mu sync.Mutex
|
||||
waiters *list.List // List of *waiter
|
||||
waiterCount atomic.Int32 // Fast lock-free check for waiters (avoids mutex in hot path)
|
||||
}
|
||||
|
||||
// NewConnStateMachine creates a new connection state machine.
|
||||
@@ -114,6 +122,23 @@ func (sm *ConnStateMachine) GetState() ConnState {
|
||||
return ConnState(sm.state.Load())
|
||||
}
|
||||
|
||||
// TryTransitionFast is an optimized version for the hot path (Get/Put operations).
|
||||
// It only handles simple state transitions without waiter notification.
|
||||
// This is safe because:
|
||||
// 1. Get/Put don't need to wait for state changes
|
||||
// 2. Background operations (handoff/reauth) use UNUSABLE state, which this won't match
|
||||
// 3. If a background operation is in progress (state is UNUSABLE), this fails fast
|
||||
//
|
||||
// Returns true if transition succeeded, false otherwise.
|
||||
// Use this for performance-critical paths where you don't need error details.
|
||||
//
|
||||
// Performance: Single CAS operation - as fast as the old atomic bool!
|
||||
// For multiple from states, use: sm.TryTransitionFast(State1, Target) || sm.TryTransitionFast(State2, Target)
|
||||
// The || operator short-circuits, so only 1 CAS is executed in the common case.
|
||||
func (sm *ConnStateMachine) TryTransitionFast(fromState, targetState ConnState) bool {
|
||||
return sm.state.CompareAndSwap(uint32(fromState), uint32(targetState))
|
||||
}
|
||||
|
||||
// TryTransition attempts an immediate state transition without waiting.
|
||||
// Returns the current state after the transition attempt and an error if the transition failed.
|
||||
// The returned state is the CURRENT state (after the attempt), not the previous state.
|
||||
@@ -126,17 +151,15 @@ func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetSta
|
||||
// Try each valid from state with CAS
|
||||
// This ensures only ONE goroutine can successfully transition at a time
|
||||
for _, fromState := range validFromStates {
|
||||
// Fast path: check if we're already in target state
|
||||
if fromState == targetState && sm.GetState() == targetState {
|
||||
return targetState, nil
|
||||
}
|
||||
|
||||
// Try to atomically swap from fromState to targetState
|
||||
// If successful, we won the race and can proceed
|
||||
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
|
||||
// Success! We transitioned atomically
|
||||
// Notify any waiters
|
||||
sm.notifyWaiters()
|
||||
// Hot path optimization: only check for waiters if transition succeeded
|
||||
// This avoids atomic load on every Get/Put when no waiters exist
|
||||
if sm.waiterCount.Load() > 0 {
|
||||
sm.notifyWaiters()
|
||||
}
|
||||
return targetState, nil
|
||||
}
|
||||
}
|
||||
@@ -213,6 +236,7 @@ func (sm *ConnStateMachine) AwaitAndTransition(
|
||||
// Add to FIFO queue
|
||||
sm.mu.Lock()
|
||||
elem := sm.waiters.PushBack(w)
|
||||
sm.waiterCount.Add(1)
|
||||
sm.mu.Unlock()
|
||||
|
||||
// Wait for state change or timeout
|
||||
@@ -221,10 +245,13 @@ func (sm *ConnStateMachine) AwaitAndTransition(
|
||||
// Timeout or cancellation - remove from queue
|
||||
sm.mu.Lock()
|
||||
sm.waiters.Remove(elem)
|
||||
sm.waiterCount.Add(-1)
|
||||
sm.mu.Unlock()
|
||||
return sm.GetState(), ctx.Err()
|
||||
case err := <-w.done:
|
||||
// Transition completed (or failed)
|
||||
// Note: waiterCount is decremented either in notifyWaiters (when the waiter is notified and removed)
|
||||
// or here (on timeout/cancellation).
|
||||
return sm.GetState(), err
|
||||
}
|
||||
}
|
||||
@@ -232,9 +259,16 @@ func (sm *ConnStateMachine) AwaitAndTransition(
|
||||
// notifyWaiters checks if any waiters can proceed and notifies them in FIFO order.
|
||||
// This is called after every state transition.
|
||||
func (sm *ConnStateMachine) notifyWaiters() {
|
||||
// Fast path: check atomic counter without acquiring lock
|
||||
// This eliminates mutex overhead in the common case (no waiters)
|
||||
if sm.waiterCount.Load() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring lock (waiters might have been processed)
|
||||
if sm.waiters.Len() == 0 {
|
||||
return
|
||||
}
|
||||
@@ -255,6 +289,7 @@ func (sm *ConnStateMachine) notifyWaiters() {
|
||||
if _, valid := w.validStates[currentState]; valid {
|
||||
// Remove from queue first
|
||||
sm.waiters.Remove(elem)
|
||||
sm.waiterCount.Add(-1)
|
||||
|
||||
// Use CAS to ensure state hasn't changed since we checked
|
||||
// This prevents race condition where another thread changes state
|
||||
@@ -267,6 +302,7 @@ func (sm *ConnStateMachine) notifyWaiters() {
|
||||
} else {
|
||||
// State changed - re-add waiter to front of queue and retry
|
||||
sm.waiters.PushFront(w)
|
||||
sm.waiterCount.Add(1)
|
||||
// Continue to next iteration to re-read state
|
||||
processed = true
|
||||
break
|
||||
|
||||
@@ -20,5 +20,5 @@ func (p *ConnPool) CheckMinIdleConns() {
|
||||
}
|
||||
|
||||
func (p *ConnPool) QueueLen() int {
|
||||
return len(p.queue)
|
||||
return int(p.semaphore.Len())
|
||||
}
|
||||
|
||||
@@ -140,3 +140,16 @@ func (phm *PoolHookManager) GetHooks() []PoolHook {
|
||||
copy(hooks, phm.hooks)
|
||||
return hooks
|
||||
}
|
||||
|
||||
// Clone creates a copy of the hook manager with the same hooks.
|
||||
// This is used for lock-free atomic updates of the hook manager.
|
||||
func (phm *PoolHookManager) Clone() *PoolHookManager {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
newManager := &PoolHookManager{
|
||||
hooks: make([]PoolHook, len(phm.hooks)),
|
||||
}
|
||||
copy(newManager.hooks, phm.hooks)
|
||||
return newManager
|
||||
}
|
||||
|
||||
@@ -202,26 +202,29 @@ func TestPoolWithHooks(t *testing.T) {
|
||||
pool.AddPoolHook(testHook)
|
||||
|
||||
// Verify hooks are initialized
|
||||
if pool.hookManager == nil {
|
||||
manager := pool.hookManager.Load()
|
||||
if manager == nil {
|
||||
t.Error("Expected hookManager to be initialized")
|
||||
}
|
||||
|
||||
if pool.hookManager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount())
|
||||
if manager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook in pool, got %d", manager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test adding hook to pool
|
||||
additionalHook := &TestHook{ShouldPool: true, ShouldAccept: true}
|
||||
pool.AddPoolHook(additionalHook)
|
||||
|
||||
if pool.hookManager.GetHookCount() != 2 {
|
||||
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount())
|
||||
manager = pool.hookManager.Load()
|
||||
if manager.GetHookCount() != 2 {
|
||||
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test removing hook from pool
|
||||
pool.RemovePoolHook(additionalHook)
|
||||
|
||||
if pool.hookManager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount())
|
||||
manager = pool.hookManager.Load()
|
||||
if manager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,12 @@ var (
|
||||
// ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable.
|
||||
ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable")
|
||||
|
||||
// errHookRequestedRemoval is returned when a hook requests connection removal.
|
||||
errHookRequestedRemoval = errors.New("hook requested removal")
|
||||
|
||||
// errConnNotPooled is returned when trying to return a non-pooled connection to the pool.
|
||||
errConnNotPooled = errors.New("connection not pooled")
|
||||
|
||||
// popAttempts is the maximum number of attempts to find a usable connection
|
||||
// when popping from the idle connection pool. This handles cases where connections
|
||||
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).
|
||||
@@ -45,14 +51,6 @@ var (
|
||||
noExpiration = maxTime
|
||||
)
|
||||
|
||||
var timers = sync.Pool{
|
||||
New: func() interface{} {
|
||||
t := time.NewTimer(time.Hour)
|
||||
t.Stop()
|
||||
return t
|
||||
},
|
||||
}
|
||||
|
||||
// Stats contains pool state information and accumulated stats.
|
||||
type Stats struct {
|
||||
Hits uint32 // number of times free connection was found in the pool
|
||||
@@ -132,7 +130,9 @@ type ConnPool struct {
|
||||
dialErrorsNum uint32 // atomic
|
||||
lastDialError atomic.Value
|
||||
|
||||
queue chan struct{}
|
||||
// Fast atomic semaphore for connection limiting
|
||||
// Replaces the old channel-based queue for better performance
|
||||
semaphore *internal.FastSemaphore
|
||||
|
||||
connsMu sync.Mutex
|
||||
conns map[uint64]*Conn
|
||||
@@ -148,8 +148,8 @@ type ConnPool struct {
|
||||
_closed uint32 // atomic
|
||||
|
||||
// Pool hooks manager for flexible connection processing
|
||||
hookManagerMu sync.RWMutex
|
||||
hookManager *PoolHookManager
|
||||
// Using atomic.Pointer for lock-free reads in hot paths (Get/Put)
|
||||
hookManager atomic.Pointer[PoolHookManager]
|
||||
}
|
||||
|
||||
var _ Pooler = (*ConnPool)(nil)
|
||||
@@ -158,7 +158,7 @@ func NewConnPool(opt *Options) *ConnPool {
|
||||
p := &ConnPool{
|
||||
cfg: opt,
|
||||
|
||||
queue: make(chan struct{}, opt.PoolSize),
|
||||
semaphore: internal.NewFastSemaphore(opt.PoolSize),
|
||||
conns: make(map[uint64]*Conn),
|
||||
idleConns: make([]*Conn, 0, opt.PoolSize),
|
||||
}
|
||||
@@ -176,27 +176,37 @@ func NewConnPool(opt *Options) *ConnPool {
|
||||
|
||||
// initializeHooks sets up the pool hooks system.
|
||||
func (p *ConnPool) initializeHooks() {
|
||||
p.hookManager = NewPoolHookManager()
|
||||
manager := NewPoolHookManager()
|
||||
p.hookManager.Store(manager)
|
||||
}
|
||||
|
||||
// AddPoolHook adds a pool hook to the pool.
|
||||
func (p *ConnPool) AddPoolHook(hook PoolHook) {
|
||||
p.hookManagerMu.Lock()
|
||||
defer p.hookManagerMu.Unlock()
|
||||
|
||||
if p.hookManager == nil {
|
||||
// Lock-free read of current manager
|
||||
manager := p.hookManager.Load()
|
||||
if manager == nil {
|
||||
p.initializeHooks()
|
||||
manager = p.hookManager.Load()
|
||||
}
|
||||
p.hookManager.AddHook(hook)
|
||||
|
||||
// Create new manager with added hook
|
||||
newManager := manager.Clone()
|
||||
newManager.AddHook(hook)
|
||||
|
||||
// Atomically swap to new manager
|
||||
p.hookManager.Store(newManager)
|
||||
}
|
||||
|
||||
// RemovePoolHook removes a pool hook from the pool.
|
||||
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
|
||||
p.hookManagerMu.Lock()
|
||||
defer p.hookManagerMu.Unlock()
|
||||
manager := p.hookManager.Load()
|
||||
if manager != nil {
|
||||
// Create new manager with removed hook
|
||||
newManager := manager.Clone()
|
||||
newManager.RemoveHook(hook)
|
||||
|
||||
if p.hookManager != nil {
|
||||
p.hookManager.RemoveHook(hook)
|
||||
// Atomically swap to new manager
|
||||
p.hookManager.Store(newManager)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,31 +223,32 @@ func (p *ConnPool) checkMinIdleConns() {
|
||||
// Only create idle connections if we haven't reached the total pool size limit
|
||||
// MinIdleConns should be a subset of PoolSize, not additional connections
|
||||
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
|
||||
select {
|
||||
case p.queue <- struct{}{}:
|
||||
p.poolSize.Add(1)
|
||||
p.idleConnsLen.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
|
||||
p.freeTurn()
|
||||
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err := p.addIdleConn()
|
||||
if err != nil && err != ErrClosed {
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
}
|
||||
p.freeTurn()
|
||||
}()
|
||||
default:
|
||||
// Try to acquire a semaphore token
|
||||
if !p.semaphore.TryAcquire() {
|
||||
// Semaphore is full, can't create more connections
|
||||
return
|
||||
}
|
||||
|
||||
p.poolSize.Add(1)
|
||||
p.idleConnsLen.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
|
||||
p.freeTurn()
|
||||
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err := p.addIdleConn()
|
||||
if err != nil && err != ErrClosed {
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
}
|
||||
p.freeTurn()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,7 +292,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= p.cfg.MaxActiveConns {
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
|
||||
@@ -296,7 +307,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
// when first used. Do NOT transition to IDLE here - that happens after initialization completes.
|
||||
// The state machine flow is: CREATED → INITIALIZING (in initConn) → IDLE (after init success)
|
||||
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > p.cfg.MaxActiveConns {
|
||||
_ = cn.Close()
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
@@ -441,14 +452,12 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
// Use cached time for health checks (max 50ms staleness is acceptable)
|
||||
now := time.Unix(0, getCachedTimeNs())
|
||||
attempts := 0
|
||||
|
||||
// Get hooks manager once for this getConn call for performance.
|
||||
// Note: Hooks added/removed during this call won't be reflected.
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
// Lock-free atomic read - no mutex overhead!
|
||||
hookManager := p.hookManager.Load()
|
||||
|
||||
for {
|
||||
if attempts >= getAttempts {
|
||||
@@ -476,19 +485,20 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
// Combine error and rejection checks to reduce branches
|
||||
if hookManager != nil {
|
||||
acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
|
||||
_ = p.CloseConn(cn)
|
||||
continue
|
||||
}
|
||||
if !acceptConn {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID())
|
||||
// Return connection to pool without freeing the turn that this Get() call holds.
|
||||
// We use putConnWithoutTurn() to run all the Put hooks and logic without freeing a turn.
|
||||
p.putConnWithoutTurn(ctx, cn)
|
||||
cn = nil
|
||||
if err != nil || !acceptConn {
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
|
||||
_ = p.CloseConn(cn)
|
||||
} else {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID())
|
||||
// Return connection to pool without freeing the turn that this Get() call holds.
|
||||
// We use putConnWithoutTurn() to run all the Put hooks and logic without freeing a turn.
|
||||
p.putConnWithoutTurn(ctx, cn)
|
||||
cn = nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -521,44 +531,36 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
}
|
||||
|
||||
func (p *ConnPool) waitTurn(ctx context.Context) error {
|
||||
// Fast path: check context first
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case p.queue <- struct{}{}:
|
||||
// Fast path: try to acquire without blocking
|
||||
if p.semaphore.TryAcquire() {
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// Slow path: need to wait
|
||||
start := time.Now()
|
||||
timer := timers.Get().(*time.Timer)
|
||||
defer timers.Put(timer)
|
||||
timer.Reset(p.cfg.PoolTimeout)
|
||||
err := p.semaphore.Acquire(ctx, p.cfg.PoolTimeout, ErrPoolTimeout)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return ctx.Err()
|
||||
case p.queue <- struct{}{}:
|
||||
switch err {
|
||||
case nil:
|
||||
// Successfully acquired after waiting
|
||||
p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
|
||||
atomic.AddUint32(&p.stats.WaitCount, 1)
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return nil
|
||||
case <-timer.C:
|
||||
case ErrPoolTimeout:
|
||||
atomic.AddUint32(&p.stats.Timeouts, 1)
|
||||
return ErrPoolTimeout
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ConnPool) freeTurn() {
|
||||
<-p.queue
|
||||
p.semaphore.Release()
|
||||
}
|
||||
|
||||
func (p *ConnPool) popIdle() (*Conn, error) {
|
||||
@@ -592,15 +594,16 @@ func (p *ConnPool) popIdle() (*Conn, error) {
|
||||
}
|
||||
attempts++
|
||||
|
||||
// Try to atomically transition to IN_USE using state machine
|
||||
// Accept both CREATED (uninitialized) and IDLE (initialized) states
|
||||
_, err := cn.GetStateMachine().TryTransition([]ConnState{StateCreated, StateIdle}, StateInUse)
|
||||
if err == nil {
|
||||
// Hot path optimization: try IDLE → IN_USE or CREATED → IN_USE transition
|
||||
// Using inline TryAcquire() method for better performance (avoids pointer dereference)
|
||||
if cn.TryAcquire() {
|
||||
// Successfully acquired the connection
|
||||
p.idleConnsLen.Add(-1)
|
||||
break
|
||||
}
|
||||
|
||||
// Connection is in UNUSABLE, INITIALIZING, or other state - skip it
|
||||
|
||||
// Connection is not in a valid state (might be UNUSABLE for handoff/re-auth, INITIALIZING, etc.)
|
||||
// Put it back in the pool and try the next one
|
||||
if p.cfg.PoolFIFO {
|
||||
@@ -651,9 +654,8 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) {
|
||||
// It's a push notification, allow pooling (client will handle it)
|
||||
}
|
||||
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
// Lock-free atomic read - no mutex overhead!
|
||||
hookManager := p.hookManager.Load()
|
||||
|
||||
if hookManager != nil {
|
||||
shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
|
||||
@@ -664,41 +666,35 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// If hooks say to remove the connection, do so
|
||||
if shouldRemove {
|
||||
p.removeConnInternal(ctx, cn, errors.New("hook requested removal"), freeTurn)
|
||||
return
|
||||
}
|
||||
|
||||
// If processor says not to pool the connection, remove it
|
||||
if !shouldPool {
|
||||
p.removeConnInternal(ctx, cn, errors.New("hook requested no pooling"), freeTurn)
|
||||
// Combine all removal checks into one - reduces branches
|
||||
if shouldRemove || !shouldPool {
|
||||
p.removeConnInternal(ctx, cn, errHookRequestedRemoval, freeTurn)
|
||||
return
|
||||
}
|
||||
|
||||
if !cn.pooled {
|
||||
p.removeConnInternal(ctx, cn, errors.New("connection not pooled"), freeTurn)
|
||||
p.removeConnInternal(ctx, cn, errConnNotPooled, freeTurn)
|
||||
return
|
||||
}
|
||||
|
||||
var shouldCloseConn bool
|
||||
|
||||
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
|
||||
// Try to transition to IDLE state BEFORE adding to pool
|
||||
// Only transition if connection is still IN_USE (hooks might have changed state)
|
||||
// This prevents:
|
||||
// 1. Race condition where another goroutine could acquire a connection that's still in IN_USE state
|
||||
// 2. Overwriting state changes made by hooks (e.g., IN_USE → UNUSABLE for handoff)
|
||||
currentState, err := cn.GetStateMachine().TryTransition([]ConnState{StateInUse}, StateIdle)
|
||||
if err != nil {
|
||||
// Hook changed the state (e.g., to UNUSABLE for handoff)
|
||||
// Hot path optimization: try fast IN_USE → IDLE transition
|
||||
// Using inline Release() method for better performance (avoids pointer dereference)
|
||||
transitionedToIdle := cn.Release()
|
||||
|
||||
if !transitionedToIdle {
|
||||
// Fast path failed - hook might have changed state (e.g., to UNUSABLE for handoff)
|
||||
// Keep the state set by the hook and pool the connection anyway
|
||||
currentState := cn.GetStateMachine().GetState()
|
||||
internal.Logger.Printf(ctx, "Connection state changed by hook to %v, pooling as-is", currentState)
|
||||
}
|
||||
|
||||
// unusable conns are expected to become usable at some point (background process is reconnecting them)
|
||||
// put them at the opposite end of the queue
|
||||
if !cn.IsUsable() {
|
||||
// Optimization: if we just transitioned to IDLE, we know it's usable - skip the check
|
||||
if !transitionedToIdle && !cn.IsUsable() {
|
||||
if p.cfg.PoolFIFO {
|
||||
p.connsMu.Lock()
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
@@ -742,9 +738,8 @@ func (p *ConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error
|
||||
|
||||
// removeConnInternal is the internal implementation of Remove that optionally frees a turn.
|
||||
func (p *ConnPool) removeConnInternal(ctx context.Context, cn *Conn, reason error, freeTurn bool) {
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
// Lock-free atomic read - no mutex overhead!
|
||||
hookManager := p.hookManager.Load()
|
||||
|
||||
if hookManager != nil {
|
||||
hookManager.ProcessOnRemove(ctx, cn, reason)
|
||||
@@ -877,36 +872,53 @@ func (p *ConnPool) Close() error {
|
||||
}
|
||||
|
||||
func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
|
||||
// slight optimization, check expiresAt first.
|
||||
if cn.expiresAt.Before(now) {
|
||||
return false
|
||||
// Performance optimization: check conditions from cheapest to most expensive,
|
||||
// and from most likely to fail to least likely to fail.
|
||||
|
||||
// Only fails if ConnMaxLifetime is set AND connection is old.
|
||||
// Most pools don't set ConnMaxLifetime, so this rarely fails.
|
||||
if p.cfg.ConnMaxLifetime > 0 {
|
||||
if cn.expiresAt.Before(now) {
|
||||
return false // Connection has exceeded max lifetime
|
||||
}
|
||||
}
|
||||
|
||||
// Check if connection has exceeded idle timeout
|
||||
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
|
||||
return false
|
||||
// Most pools set ConnMaxIdleTime, and idle connections are common.
|
||||
// Checking this first allows us to fail fast without expensive syscalls.
|
||||
if p.cfg.ConnMaxIdleTime > 0 {
|
||||
if now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
|
||||
return false // Connection has been idle too long
|
||||
}
|
||||
}
|
||||
|
||||
cn.SetUsedAt(now)
|
||||
// Check basic connection health
|
||||
// Use GetNetConn() to safely access netConn and avoid data races
|
||||
// Only run this if the cheap checks passed.
|
||||
if err := connCheck(cn.getNetConn()); err != nil {
|
||||
// If there's unexpected data, it might be push notifications (RESP3)
|
||||
// However, push notification processing is now handled by the client
|
||||
// before WithReader to ensure proper context is available to handlers
|
||||
if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
|
||||
// we know that there is something in the buffer, so peek at the next reply type without
|
||||
// the potential to block
|
||||
// Peek at the reply type to check if it's a push notification
|
||||
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
|
||||
// For RESP3 connections with push notifications, we allow some buffered data
|
||||
// The client will process these notifications before using the connection
|
||||
internal.Logger.Printf(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID())
|
||||
return true // Connection is healthy, client will handle notifications
|
||||
internal.Logger.Printf(
|
||||
context.Background(),
|
||||
"push: conn[%d] has buffered data, likely push notifications - will be processed by client",
|
||||
cn.GetID(),
|
||||
)
|
||||
|
||||
// Update timestamp for healthy connection
|
||||
cn.SetUsedAt(now)
|
||||
|
||||
// Connection is healthy, client will handle notifications
|
||||
return true
|
||||
}
|
||||
return false // Unexpected data, not push notifications, connection is unhealthy
|
||||
} else {
|
||||
// Not a push notification - treat as unhealthy
|
||||
return false
|
||||
}
|
||||
// Connection failed health check
|
||||
return false
|
||||
}
|
||||
|
||||
// Only update UsedAt if connection is healthy (avoids unnecessary atomic store)
|
||||
cn.SetUsedAt(now)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ type PubSubPool struct {
|
||||
stats PubSubStats
|
||||
}
|
||||
|
||||
// PubSubPool implements a pool for PubSub connections.
|
||||
// NewPubSubPool implements a pool for PubSub connections.
|
||||
// It intentionally does not implement the Pooler interface
|
||||
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
|
||||
return &PubSubPool{
|
||||
|
||||
@@ -370,9 +370,17 @@ func BenchmarkPeekPushNotificationName(b *testing.B) {
|
||||
buf := createValidPushNotification(tc.notification, "data")
|
||||
data := buf.Bytes()
|
||||
|
||||
// Reuse both bytes.Reader and proto.Reader to avoid allocations
|
||||
bytesReader := bytes.NewReader(data)
|
||||
reader := NewReader(bytesReader)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
reader := NewReader(bytes.NewReader(data))
|
||||
// Reset the bytes.Reader to the beginning without allocating
|
||||
bytesReader.Reset(data)
|
||||
// Reset the proto.Reader to reuse the bufio buffer
|
||||
reader.Reset(bytesReader)
|
||||
_, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
b.Errorf("PeekPushNotificationName should not error: %v", err)
|
||||
|
||||
161
internal/semaphore.go
Normal file
161
internal/semaphore.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var semTimers = sync.Pool{
|
||||
New: func() interface{} {
|
||||
t := time.NewTimer(time.Hour)
|
||||
t.Stop()
|
||||
return t
|
||||
},
|
||||
}
|
||||
|
||||
// FastSemaphore is a counting semaphore implementation using atomic operations.
|
||||
// It's optimized for the fast path (no blocking) while still supporting timeouts and context cancellation.
|
||||
//
|
||||
// Performance characteristics:
|
||||
// - Fast path (no blocking): Single atomic CAS operation
|
||||
// - Slow path (blocking): Falls back to channel-based waiting
|
||||
// - Release: Single atomic decrement + optional channel notification
|
||||
//
|
||||
// This is significantly faster than a pure channel-based semaphore because:
|
||||
// 1. The fast path avoids channel operations entirely (no scheduler involvement)
|
||||
// 2. Atomic operations are much cheaper than channel send/receive
|
||||
type FastSemaphore struct {
|
||||
// Current number of acquired tokens (atomic)
|
||||
count atomic.Int32
|
||||
|
||||
// Maximum number of tokens (capacity)
|
||||
max int32
|
||||
|
||||
// Channel for blocking waiters
|
||||
// Only used when fast path fails (semaphore is full)
|
||||
waitCh chan struct{}
|
||||
}
|
||||
|
||||
// NewFastSemaphore creates a new fast semaphore with the given capacity.
|
||||
func NewFastSemaphore(capacity int32) *FastSemaphore {
|
||||
return &FastSemaphore{
|
||||
max: capacity,
|
||||
waitCh: make(chan struct{}, capacity),
|
||||
}
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire a token without blocking.
|
||||
// Returns true if successful, false if the semaphore is full.
|
||||
//
|
||||
// This is the fast path - just a single CAS operation.
|
||||
func (s *FastSemaphore) TryAcquire() bool {
|
||||
for {
|
||||
current := s.count.Load()
|
||||
if current >= s.max {
|
||||
return false // Semaphore is full
|
||||
}
|
||||
if s.count.CompareAndSwap(current, current+1) {
|
||||
return true // Successfully acquired
|
||||
}
|
||||
// CAS failed due to concurrent modification, retry
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire acquires a token, blocking if necessary until one is available or the context is cancelled.
|
||||
// Returns an error if the context is cancelled or the timeout expires.
|
||||
// Returns timeoutErr when the timeout expires.
|
||||
//
|
||||
// Performance optimization:
|
||||
// 1. First try fast path (no blocking)
|
||||
// 2. If that fails, fall back to channel-based waiting
|
||||
func (s *FastSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error {
|
||||
// Fast path: try to acquire without blocking
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Try fast acquire first
|
||||
if s.TryAcquire() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fast path failed, need to wait
|
||||
// Use timer pool to avoid allocation
|
||||
timer := semTimers.Get().(*time.Timer)
|
||||
defer semTimers.Put(timer)
|
||||
timer.Reset(timeout)
|
||||
|
||||
start := time.Now()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return ctx.Err()
|
||||
|
||||
case <-s.waitCh:
|
||||
// Someone released a token, try to acquire it
|
||||
if s.TryAcquire() {
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Failed to acquire (race with another goroutine), continue waiting
|
||||
|
||||
case <-timer.C:
|
||||
return timeoutErr
|
||||
}
|
||||
|
||||
// Periodically check if we can acquire (handles race conditions)
|
||||
if time.Since(start) > timeout {
|
||||
return timeoutErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireBlocking acquires a token, blocking indefinitely until one is available.
|
||||
// This is useful for cases where you don't need timeout or context cancellation.
|
||||
// Returns immediately if a token is available (fast path).
|
||||
func (s *FastSemaphore) AcquireBlocking() {
|
||||
// Try fast path first
|
||||
if s.TryAcquire() {
|
||||
return
|
||||
}
|
||||
|
||||
// Slow path: wait for a token
|
||||
for {
|
||||
<-s.waitCh
|
||||
if s.TryAcquire() {
|
||||
return
|
||||
}
|
||||
// Failed to acquire (race with another goroutine), continue waiting
|
||||
}
|
||||
}
|
||||
|
||||
// Release releases a token back to the semaphore.
|
||||
// This wakes up one waiting goroutine if any are blocked.
|
||||
func (s *FastSemaphore) Release() {
|
||||
s.count.Add(-1)
|
||||
|
||||
// Try to wake up a waiter (non-blocking)
|
||||
// If no one is waiting, this is a no-op
|
||||
select {
|
||||
case s.waitCh <- struct{}{}:
|
||||
// Successfully notified a waiter
|
||||
default:
|
||||
// No waiters, that's fine
|
||||
}
|
||||
}
|
||||
|
||||
// Len returns the current number of acquired tokens.
|
||||
// Used by tests to check semaphore state.
|
||||
func (s *FastSemaphore) Len() int32 {
|
||||
return s.count.Load()
|
||||
}
|
||||
@@ -323,6 +323,7 @@ var _ = Describe("Client", func() {
|
||||
cn, err = client.Pool().Get(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cn).NotTo(BeNil())
|
||||
Expect(cn.UsedAt().UnixNano()).To(BeNumerically(">", createdAt.UnixNano()))
|
||||
Expect(cn.UsedAt().After(createdAt)).To(BeTrue())
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user