1
0
mirror of https://github.com/redis/go-redis.git synced 2025-12-02 06:22:31 +03:00

fix(pool): pool performance (#3565)

* perf(pool): replace hookManager RWMutex with atomic.Pointer and add predefined state slices

- Replace hookManager RWMutex with atomic.Pointer for lock-free reads in hot paths
- Add predefined state slices to avoid allocations (validFromInUse, validFromCreatedOrIdle, etc.)
- Add Clone() method to PoolHookManager for atomic updates
- Update AddPoolHook/RemovePoolHook to use copy-on-write pattern
- Update all hookManager access points to use atomic Load()

Performance improvements:
- Eliminates RWMutex contention in Get/Put/Remove hot paths
- Reduces allocations by reusing predefined state slices
- Lock-free reads allow better CPU cache utilization

* perf(pool): eliminate mutex overhead in state machine hot path

The state machine was calling notifyWaiters() on EVERY Get/Put operation,
which acquired a mutex even when no waiters were present (the common case).

Fix: Use atomic waiterCount to check for waiters BEFORE acquiring mutex.
This eliminates mutex contention in the hot path (Get/Put operations).

Implementation:
- Added atomic.Int32 waiterCount field to ConnStateMachine
- Increment when adding waiter, decrement when removing
- Check waiterCount atomically before acquiring mutex in notifyWaiters()

Performance impact:
- Before: mutex lock/unlock on every Get/Put (even with no waiters)
- After: lock-free atomic check, only acquire mutex if waiters exist
- Expected improvement: ~30-50% for Get/Put operations

* perf(pool): use predefined state slices to eliminate allocations in hot path

The pool was creating new slice literals on EVERY Get/Put operation:
- popIdle(): []ConnState{StateCreated, StateIdle}
- putConn(): []ConnState{StateInUse}
- CompareAndSwapUsed(): []ConnState{StateIdle} and []ConnState{StateInUse}
- MarkUnusableForHandoff(): []ConnState{StateInUse, StateIdle, StateCreated}

These allocations were happening millions of times per second in the hot path.

Fix: Use predefined global slices defined in conn_state.go:
- validFromInUse
- validFromCreatedOrIdle
- validFromCreatedInUseOrIdle

Performance impact:
- Before: 4 slice allocations per Get/Put cycle
- After: 0 allocations (use predefined slices)
- Expected improvement: ~30-40% reduction in allocations and GC pressure

* perf(pool): optimize TryTransition to reduce atomic operations

Further optimize the hot path by:
1. Remove redundant GetState() call in the loop
2. Only check waiterCount after successful CAS (not before loop)
3. Inline the waiterCount check to avoid notifyWaiters() call overhead

This reduces atomic operations from 4-5 per Get/Put to 2-3:
- Before: GetState() + CAS + waiterCount.Load() + notifyWaiters mutex check
- After: CAS + waiterCount.Load() (only if CAS succeeds)

Performance impact:
- Eliminates 1-2 atomic operations per Get/Put
- Expected improvement: ~10-15% for Get/Put operations

* perf(pool): add fast path for Get/Put to match master performance

Introduced TryTransitionFast() for the hot path (Get/Put operations):
- Single CAS operation (same as master's atomic bool)
- No waiter notification overhead
- No loop through valid states
- No error allocation

Hot path flow:
1. popIdle(): Try IDLE → IN_USE (fast), fallback to CREATED → IN_USE
2. putConn(): Try IN_USE → IDLE (fast)

This matches master's performance while preserving state machine for:
- Background operations (handoff/reauth use UNUSABLE state)
- State validation (TryTransition still available)
- Waiter notification (AwaitAndTransition for blocking)

Performance comparison per Get/Put cycle:
- Master: 2 atomic CAS operations
- State machine (before): 5 atomic operations (2.5x slower)
- State machine (after): 2 atomic CAS operations (same as master!)

Expected improvement: Restore to baseline ~11,373 ops/sec

* combine cas

* fix linter

* try faster approach

* fast semaphore

* better inlining for hot path

* fix linter issues

* use new semaphore in auth as well

* linter should be happy now

* add comments

* Update internal/pool/conn_state.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* address comment

* slight reordering

* try to cache time for non-critical calculation

* fix wrong benchmark

* add concurrent test

* fix benchmark report

* add additional expect to check output

* comment and variable rename

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Nedyalko Dyakov
2025-10-27 15:06:30 +02:00
committed by GitHub
parent 07e665f7af
commit 080a33c3a8
12 changed files with 585 additions and 169 deletions

View File

@@ -3,6 +3,7 @@ package redis_test
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"testing" "testing"
"time" "time"
@@ -100,7 +101,82 @@ func benchmarkHSETOperations(b *testing.B, rdb *redis.Client, ctx context.Contex
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N) avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
b.ReportMetric(float64(avgTimePerOp), "ns/op") b.ReportMetric(float64(avgTimePerOp), "ns/op")
// report average time in milliseconds from totalTimes // report average time in milliseconds from totalTimes
avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes)) sumTime := time.Duration(0)
for _, t := range totalTimes {
sumTime += t
}
avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes))
b.ReportMetric(float64(avgTimePerOpMs), "ms")
}
// benchmarkHSETOperationsConcurrent performs the actual HSET benchmark for a given scale
func benchmarkHSETOperationsConcurrent(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) {
hashKey := fmt.Sprintf("benchmark_hash_%d", operations)
b.ResetTimer()
b.StartTimer()
totalTimes := []time.Duration{}
for i := 0; i < b.N; i++ {
b.StopTimer()
// Clean up the hash before each iteration
rdb.Del(ctx, hashKey)
b.StartTimer()
startTime := time.Now()
// Perform the specified number of HSET operations
wg := sync.WaitGroup{}
timesCh := make(chan time.Duration, operations)
errCh := make(chan error, operations)
for j := 0; j < operations; j++ {
wg.Add(1)
go func(j int) {
defer wg.Done()
field := fmt.Sprintf("field_%d", j)
value := fmt.Sprintf("value_%d", j)
err := rdb.HSet(ctx, hashKey, field, value).Err()
if err != nil {
errCh <- err
return
}
timesCh <- time.Since(startTime)
}(j)
}
wg.Wait()
close(timesCh)
close(errCh)
// Check for errors
for err := range errCh {
b.Errorf("HSET operation failed: %v", err)
}
for d := range timesCh {
totalTimes = append(totalTimes, d)
}
}
// Stop the timer to calculate metrics
b.StopTimer()
// Report operations per second
opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds()
b.ReportMetric(opsPerSec, "ops/sec")
// Report average time per operation
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
b.ReportMetric(float64(avgTimePerOp), "ns/op")
// report average time in milliseconds from totalTimes
sumTime := time.Duration(0)
for _, t := range totalTimes {
sumTime += t
}
avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes))
b.ReportMetric(float64(avgTimePerOpMs), "ms") b.ReportMetric(float64(avgTimePerOpMs), "ms")
} }
@@ -134,6 +210,37 @@ func BenchmarkHSETPipelined(b *testing.B) {
} }
} }
// BenchmarkHSET_Concurrent drives benchmarkHSETOperationsConcurrent at several
// scales against a Redis server on localhost:6379. The benchmark is skipped
// when the server does not answer PING.
func BenchmarkHSET_Concurrent(b *testing.B) {
	ctx := context.Background()

	// Client with a pool of 100 connections
	opts := &redis.Options{
		Addr:     "localhost:6379",
		DB:       0,
		PoolSize: 100,
	}
	rdb := redis.NewClient(opts)
	defer rdb.Close()

	// Bail out early if no server is reachable
	if pingErr := rdb.Ping(ctx).Err(); pingErr != nil {
		b.Skipf("Redis server not available: %v", pingErr)
	}

	// Flush the database once all sub-benchmarks have finished
	defer rdb.FlushDB(ctx)

	// Scales are capped at 1000 because each operation gets its own goroutine
	for _, scale := range []int{1, 10, 100, 1000} {
		name := fmt.Sprintf("HSET_%d_operations_concurrent", scale)
		b.Run(name, func(b *testing.B) {
			benchmarkHSETOperationsConcurrent(b, rdb, ctx, scale)
		})
	}
}
// benchmarkHSETPipelined performs HSET benchmark using pipelining // benchmarkHSETPipelined performs HSET benchmark using pipelining
func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) { func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) {
hashKey := fmt.Sprintf("benchmark_hash_pipelined_%d", operations) hashKey := fmt.Sprintf("benchmark_hash_pipelined_%d", operations)
@@ -177,7 +284,11 @@ func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N) avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
b.ReportMetric(float64(avgTimePerOp), "ns/op") b.ReportMetric(float64(avgTimePerOp), "ns/op")
// report average time in milliseconds from totalTimes // report average time in milliseconds from totalTimes
avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes)) sumTime := time.Duration(0)
for _, t := range totalTimes {
sumTime += t
}
avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes))
b.ReportMetric(float64(avgTimePerOpMs), "ms") b.ReportMetric(float64(avgTimePerOpMs), "ms")
} }

View File

@@ -34,9 +34,10 @@ type ReAuthPoolHook struct {
shouldReAuth map[uint64]func(error) shouldReAuth map[uint64]func(error)
shouldReAuthLock sync.RWMutex shouldReAuthLock sync.RWMutex
// workers is a semaphore channel limiting concurrent re-auth operations // workers is a semaphore limiting concurrent re-auth operations
// Initialized with poolSize tokens to prevent pool exhaustion // Initialized with poolSize tokens to prevent pool exhaustion
workers chan struct{} // Uses FastSemaphore for consistency and better performance
workers *internal.FastSemaphore
// reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth // reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth
reAuthTimeout time.Duration reAuthTimeout time.Duration
@@ -59,16 +60,10 @@ type ReAuthPoolHook struct {
// The poolSize parameter is used to initialize the worker semaphore, ensuring that // The poolSize parameter is used to initialize the worker semaphore, ensuring that
// re-auth operations don't exhaust the connection pool. // re-auth operations don't exhaust the connection pool.
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook { func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
workers := make(chan struct{}, poolSize)
// Initialize the workers channel with tokens (semaphore pattern)
for i := 0; i < poolSize; i++ {
workers <- struct{}{}
}
return &ReAuthPoolHook{ return &ReAuthPoolHook{
shouldReAuth: make(map[uint64]func(error)), shouldReAuth: make(map[uint64]func(error)),
scheduledReAuth: make(map[uint64]bool), scheduledReAuth: make(map[uint64]bool),
workers: workers, workers: internal.NewFastSemaphore(int32(poolSize)),
reAuthTimeout: reAuthTimeout, reAuthTimeout: reAuthTimeout,
} }
} }
@@ -162,10 +157,10 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
r.scheduledLock.Unlock() r.scheduledLock.Unlock()
r.shouldReAuthLock.Unlock() r.shouldReAuthLock.Unlock()
go func() { go func() {
<-r.workers r.workers.AcquireBlocking()
// safety first // safety first
if conn == nil || (conn != nil && conn.IsClosed()) { if conn == nil || (conn != nil && conn.IsClosed()) {
r.workers <- struct{}{} r.workers.Release()
return return
} }
defer func() { defer func() {
@@ -176,7 +171,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
r.scheduledLock.Lock() r.scheduledLock.Lock()
delete(r.scheduledReAuth, connID) delete(r.scheduledReAuth, connID)
r.scheduledLock.Unlock() r.scheduledLock.Unlock()
r.workers <- struct{}{} r.workers.Release()
}() }()
// Create timeout context for connection acquisition // Create timeout context for connection acquisition

View File

@@ -1,3 +1,4 @@
// Package pool implements the pool management
package pool package pool
import ( import (
@@ -17,6 +18,35 @@ import (
var noDeadline = time.Time{} var noDeadline = time.Time{}
// globalTimeCache holds a coarse wall-clock reading (Unix nanoseconds) that a
// background goroutine refreshes every 50ms. Hot paths read it instead of
// calling time.Now() themselves; the value can therefore lag real time by up
// to roughly 50ms, which is small next to typical multi-second timeouts.
var globalTimeCache struct {
	nowNs atomic.Int64
}

func init() {
	refresh := func() {
		globalTimeCache.nowNs.Store(time.Now().UnixNano())
	}

	// Seed synchronously so the first reader never observes zero
	refresh()

	// Keep the cached value fresh for the lifetime of the process
	go func() {
		tick := time.NewTicker(50 * time.Millisecond)
		defer tick.Stop()

		for {
			<-tick.C
			refresh()
		}
	}()
}

// getCachedTimeNs returns the most recently cached wall-clock time as Unix
// nanoseconds. It is a single atomic load; the value is at most about 50ms old.
func getCachedTimeNs() int64 {
	return globalTimeCache.nowNs.Load()
}
// Global atomic counter for connection IDs // Global atomic counter for connection IDs
var connIDCounter uint64 var connIDCounter uint64
@@ -79,6 +109,7 @@ type Conn struct {
expiresAt time.Time expiresAt time.Time
// maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers // maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers
// Using atomic operations for lock-free access to avoid mutex contention // Using atomic operations for lock-free access to avoid mutex contention
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
@@ -260,11 +291,13 @@ func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
if !old && new { if !old && new {
// Acquiring: IDLE → IN_USE // Acquiring: IDLE → IN_USE
_, err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse) // Use predefined slice to avoid allocation
_, err := cn.stateMachine.TryTransition(validFromCreatedOrIdle, StateInUse)
return err == nil return err == nil
} else { } else {
// Releasing: IN_USE → IDLE // Releasing: IN_USE → IDLE
_, err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle) // Use predefined slice to avoid allocation
_, err := cn.stateMachine.TryTransition(validFromInUse, StateIdle)
return err == nil return err == nil
} }
} }
@@ -454,7 +487,8 @@ func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Durati
return time.Duration(readTimeoutNs) return time.Duration(readTimeoutNs)
} }
nowNs := time.Now().UnixNano() // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
nowNs := getCachedTimeNs()
// Check if deadline has passed // Check if deadline has passed
if nowNs < deadlineNs { if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout // Deadline is in the future, use relaxed timeout
@@ -487,7 +521,8 @@ func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Durat
return time.Duration(writeTimeoutNs) return time.Duration(writeTimeoutNs)
} }
nowNs := time.Now().UnixNano() // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
nowNs := getCachedTimeNs()
// Check if deadline has passed // Check if deadline has passed
if nowNs < deadlineNs { if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout // Deadline is in the future, use relaxed timeout
@@ -632,7 +667,8 @@ func (cn *Conn) MarkQueuedForHandoff() error {
// The connection is typically in IN_USE state when OnPut is called (normal Put flow) // The connection is typically in IN_USE state when OnPut is called (normal Put flow)
// But in some edge cases or tests, it might be in IDLE or CREATED state // But in some edge cases or tests, it might be in IDLE or CREATED state
// The pool will detect this state change and preserve it (not overwrite with IDLE) // The pool will detect this state change and preserve it (not overwrite with IDLE)
finalState, err := cn.stateMachine.TryTransition([]ConnState{StateInUse, StateIdle, StateCreated}, StateUnusable) // Use predefined slice to avoid allocation
finalState, err := cn.stateMachine.TryTransition(validFromCreatedInUseOrIdle, StateUnusable)
if err != nil { if err != nil {
// Check if already in UNUSABLE state (race condition or retry) // Check if already in UNUSABLE state (race condition or retry)
// ShouldHandoff should be false now, but check just in case // ShouldHandoff should be false now, but check just in case
@@ -658,6 +694,42 @@ func (cn *Conn) GetStateMachine() *ConnStateMachine {
return cn.stateMachine return cn.stateMachine
} }
// TryAcquire claims the connection for use on the Get hot path.
//
// It moves the state to IN_USE from IDLE, or failing that from CREATED, and
// reports whether either transition succeeded.
//
// NOTE: this reaches into cn.stateMachine.state directly rather than going
// through the state machine's methods. That breaks encapsulation on purpose:
// these two transitions do not notify waiters, and skipping the extra calls
// was measured as a 1-3% win. If waiters ever need to be notified on these
// transitions, switch this over to TryTransitionFast().
func (cn *Conn) TryAcquire() bool {
	state := &cn.stateMachine.state
	inUse := uint32(StateInUse)

	// Common case: an idle pooled connection, one CAS and done
	if state.CompareAndSwap(uint32(StateIdle), inUse) {
		return true
	}
	// Otherwise it may be a freshly created connection
	return state.CompareAndSwap(uint32(StateCreated), inUse)
}
// Release hands the connection back on the Put hot path.
//
// It moves the state from IN_USE to IDLE with a single CAS and reports whether
// that transition succeeded.
//
// NOTE: this reaches into cn.stateMachine.state directly rather than going
// through the state machine's methods, trading encapsulation for speed. If
// waiters ever need to be notified on this transition, switch this over to
// TryTransitionFast().
func (cn *Conn) Release() bool {
	from, to := uint32(StateInUse), uint32(StateIdle)
	return cn.stateMachine.state.CompareAndSwap(from, to)
}
// ClearHandoffState clears the handoff state after successful handoff. // ClearHandoffState clears the handoff state after successful handoff.
// Makes the connection usable again. // Makes the connection usable again.
func (cn *Conn) ClearHandoffState() { func (cn *Conn) ClearHandoffState() {
@@ -800,8 +872,12 @@ func (cn *Conn) MaybeHasData() bool {
return false return false
} }
// deadline computes the effective deadline time based on context and timeout.
// It updates the usedAt timestamp to now.
// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation).
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
tm := time.Now() // Use cached time for deadline calculation (called 2x per command: read + write)
tm := time.Unix(0, getCachedTimeNs())
cn.SetUsedAt(tm) cn.SetUsedAt(tm)
if timeout > 0 { if timeout > 0 {

View File

@@ -41,6 +41,13 @@ const (
StateClosed StateClosed
) )
// Predefined state slices to avoid allocations in hot paths.
// They are passed as the validFromStates argument of TryTransition, which
// tries the listed states in order. Being shared package-level slices, they
// should be treated as read-only.
var (
// validFromInUse: source state for releasing a connection (IN_USE -> IDLE).
validFromInUse = []ConnState{StateInUse}
// validFromCreatedOrIdle: source states for acquiring a connection (-> IN_USE).
validFromCreatedOrIdle = []ConnState{StateCreated, StateIdle}
// validFromCreatedInUseOrIdle: source states for marking a connection UNUSABLE for handoff.
validFromCreatedInUseOrIdle = []ConnState{StateCreated, StateInUse, StateIdle}
)
// String returns a human-readable string representation of the state. // String returns a human-readable string representation of the state.
func (s ConnState) String() string { func (s ConnState) String() string {
switch s { switch s {
@@ -92,8 +99,9 @@ type ConnStateMachine struct {
state atomic.Uint32 state atomic.Uint32
// FIFO queue for waiters - only locked during waiter add/remove/notify // FIFO queue for waiters - only locked during waiter add/remove/notify
mu sync.Mutex mu sync.Mutex
waiters *list.List // List of *waiter waiters *list.List // List of *waiter
waiterCount atomic.Int32 // Fast lock-free check for waiters (avoids mutex in hot path)
} }
// NewConnStateMachine creates a new connection state machine. // NewConnStateMachine creates a new connection state machine.
@@ -114,6 +122,23 @@ func (sm *ConnStateMachine) GetState() ConnState {
return ConnState(sm.state.Load()) return ConnState(sm.state.Load())
} }
// TryTransitionFast performs one compare-and-swap from fromState to
// targetState and reports whether it succeeded. It is meant for the Get/Put
// hot path.
//
// Unlike TryTransition it does not notify waiters and returns no error. That
// is acceptable for Get/Put because:
//  1. they never wait for a state change,
//  2. background work (handoff/reauth) parks the connection in UNUSABLE,
//     which is not one of the from-states Get/Put pass here,
//  3. so while such work is in progress the CAS simply fails fast.
//
// To accept several from-states, chain calls with ||, e.g.
// sm.TryTransitionFast(State1, Target) || sm.TryTransitionFast(State2, Target);
// short-circuiting keeps the common case at a single CAS.
func (sm *ConnStateMachine) TryTransitionFast(fromState, targetState ConnState) bool {
	from, to := uint32(fromState), uint32(targetState)
	return sm.state.CompareAndSwap(from, to)
}
// TryTransition attempts an immediate state transition without waiting. // TryTransition attempts an immediate state transition without waiting.
// Returns the current state after the transition attempt and an error if the transition failed. // Returns the current state after the transition attempt and an error if the transition failed.
// The returned state is the CURRENT state (after the attempt), not the previous state. // The returned state is the CURRENT state (after the attempt), not the previous state.
@@ -126,17 +151,15 @@ func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetSta
// Try each valid from state with CAS // Try each valid from state with CAS
// This ensures only ONE goroutine can successfully transition at a time // This ensures only ONE goroutine can successfully transition at a time
for _, fromState := range validFromStates { for _, fromState := range validFromStates {
// Fast path: check if we're already in target state
if fromState == targetState && sm.GetState() == targetState {
return targetState, nil
}
// Try to atomically swap from fromState to targetState // Try to atomically swap from fromState to targetState
// If successful, we won the race and can proceed // If successful, we won the race and can proceed
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) { if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
// Success! We transitioned atomically // Success! We transitioned atomically
// Notify any waiters // Hot path optimization: only check for waiters if transition succeeded
sm.notifyWaiters() // This avoids atomic load on every Get/Put when no waiters exist
if sm.waiterCount.Load() > 0 {
sm.notifyWaiters()
}
return targetState, nil return targetState, nil
} }
} }
@@ -213,6 +236,7 @@ func (sm *ConnStateMachine) AwaitAndTransition(
// Add to FIFO queue // Add to FIFO queue
sm.mu.Lock() sm.mu.Lock()
elem := sm.waiters.PushBack(w) elem := sm.waiters.PushBack(w)
sm.waiterCount.Add(1)
sm.mu.Unlock() sm.mu.Unlock()
// Wait for state change or timeout // Wait for state change or timeout
@@ -221,10 +245,13 @@ func (sm *ConnStateMachine) AwaitAndTransition(
// Timeout or cancellation - remove from queue // Timeout or cancellation - remove from queue
sm.mu.Lock() sm.mu.Lock()
sm.waiters.Remove(elem) sm.waiters.Remove(elem)
sm.waiterCount.Add(-1)
sm.mu.Unlock() sm.mu.Unlock()
return sm.GetState(), ctx.Err() return sm.GetState(), ctx.Err()
case err := <-w.done: case err := <-w.done:
// Transition completed (or failed) // Transition completed (or failed)
// Note: waiterCount is decremented either in notifyWaiters (when the waiter is notified and removed)
// or here (on timeout/cancellation).
return sm.GetState(), err return sm.GetState(), err
} }
} }
@@ -232,9 +259,16 @@ func (sm *ConnStateMachine) AwaitAndTransition(
// notifyWaiters checks if any waiters can proceed and notifies them in FIFO order. // notifyWaiters checks if any waiters can proceed and notifies them in FIFO order.
// This is called after every state transition. // This is called after every state transition.
func (sm *ConnStateMachine) notifyWaiters() { func (sm *ConnStateMachine) notifyWaiters() {
// Fast path: check atomic counter without acquiring lock
// This eliminates mutex overhead in the common case (no waiters)
if sm.waiterCount.Load() == 0 {
return
}
sm.mu.Lock() sm.mu.Lock()
defer sm.mu.Unlock() defer sm.mu.Unlock()
// Double-check after acquiring lock (waiters might have been processed)
if sm.waiters.Len() == 0 { if sm.waiters.Len() == 0 {
return return
} }
@@ -255,6 +289,7 @@ func (sm *ConnStateMachine) notifyWaiters() {
if _, valid := w.validStates[currentState]; valid { if _, valid := w.validStates[currentState]; valid {
// Remove from queue first // Remove from queue first
sm.waiters.Remove(elem) sm.waiters.Remove(elem)
sm.waiterCount.Add(-1)
// Use CAS to ensure state hasn't changed since we checked // Use CAS to ensure state hasn't changed since we checked
// This prevents race condition where another thread changes state // This prevents race condition where another thread changes state
@@ -267,6 +302,7 @@ func (sm *ConnStateMachine) notifyWaiters() {
} else { } else {
// State changed - re-add waiter to front of queue and retry // State changed - re-add waiter to front of queue and retry
sm.waiters.PushFront(w) sm.waiters.PushFront(w)
sm.waiterCount.Add(1)
// Continue to next iteration to re-read state // Continue to next iteration to re-read state
processed = true processed = true
break break

View File

@@ -20,5 +20,5 @@ func (p *ConnPool) CheckMinIdleConns() {
} }
func (p *ConnPool) QueueLen() int { func (p *ConnPool) QueueLen() int {
return len(p.queue) return int(p.semaphore.Len())
} }

View File

@@ -140,3 +140,16 @@ func (phm *PoolHookManager) GetHooks() []PoolHook {
copy(hooks, phm.hooks) copy(hooks, phm.hooks)
return hooks return hooks
} }
// Clone returns a new PoolHookManager holding its own copy of the current
// hook list. The pool uses it to build a modified manager on the side and
// then swap it in atomically.
func (phm *PoolHookManager) Clone() *PoolHookManager {
	phm.hooksMu.RLock()
	defer phm.hooksMu.RUnlock()

	// Copy into a fresh backing array so later edits to either manager do
	// not show up in the other.
	cloned := make([]PoolHook, len(phm.hooks))
	copy(cloned, phm.hooks)

	return &PoolHookManager{hooks: cloned}
}

View File

@@ -202,26 +202,29 @@ func TestPoolWithHooks(t *testing.T) {
pool.AddPoolHook(testHook) pool.AddPoolHook(testHook)
// Verify hooks are initialized // Verify hooks are initialized
if pool.hookManager == nil { manager := pool.hookManager.Load()
if manager == nil {
t.Error("Expected hookManager to be initialized") t.Error("Expected hookManager to be initialized")
} }
if pool.hookManager.GetHookCount() != 1 { if manager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount()) t.Errorf("Expected 1 hook in pool, got %d", manager.GetHookCount())
} }
// Test adding hook to pool // Test adding hook to pool
additionalHook := &TestHook{ShouldPool: true, ShouldAccept: true} additionalHook := &TestHook{ShouldPool: true, ShouldAccept: true}
pool.AddPoolHook(additionalHook) pool.AddPoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 2 { manager = pool.hookManager.Load()
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount()) if manager.GetHookCount() != 2 {
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
} }
// Test removing hook from pool // Test removing hook from pool
pool.RemovePoolHook(additionalHook) pool.RemovePoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 1 { manager = pool.hookManager.Load()
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount()) if manager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
} }
} }

View File

@@ -27,6 +27,12 @@ var (
// ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable. // ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable.
ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable") ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable")
// errHookRequestedRemoval is returned when a hook requests connection removal.
errHookRequestedRemoval = errors.New("hook requested removal")
// errConnNotPooled is returned when trying to return a non-pooled connection to the pool.
errConnNotPooled = errors.New("connection not pooled")
// popAttempts is the maximum number of attempts to find a usable connection // popAttempts is the maximum number of attempts to find a usable connection
// when popping from the idle connection pool. This handles cases where connections // when popping from the idle connection pool. This handles cases where connections
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues). // are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).
@@ -45,14 +51,6 @@ var (
noExpiration = maxTime noExpiration = maxTime
) )
var timers = sync.Pool{
New: func() interface{} {
t := time.NewTimer(time.Hour)
t.Stop()
return t
},
}
// Stats contains pool state information and accumulated stats. // Stats contains pool state information and accumulated stats.
type Stats struct { type Stats struct {
Hits uint32 // number of times free connection was found in the pool Hits uint32 // number of times free connection was found in the pool
@@ -132,7 +130,9 @@ type ConnPool struct {
dialErrorsNum uint32 // atomic dialErrorsNum uint32 // atomic
lastDialError atomic.Value lastDialError atomic.Value
queue chan struct{} // Fast atomic semaphore for connection limiting
// Replaces the old channel-based queue for better performance
semaphore *internal.FastSemaphore
connsMu sync.Mutex connsMu sync.Mutex
conns map[uint64]*Conn conns map[uint64]*Conn
@@ -148,8 +148,8 @@ type ConnPool struct {
_closed uint32 // atomic _closed uint32 // atomic
// Pool hooks manager for flexible connection processing // Pool hooks manager for flexible connection processing
hookManagerMu sync.RWMutex // Using atomic.Pointer for lock-free reads in hot paths (Get/Put)
hookManager *PoolHookManager hookManager atomic.Pointer[PoolHookManager]
} }
var _ Pooler = (*ConnPool)(nil) var _ Pooler = (*ConnPool)(nil)
@@ -158,7 +158,7 @@ func NewConnPool(opt *Options) *ConnPool {
p := &ConnPool{ p := &ConnPool{
cfg: opt, cfg: opt,
queue: make(chan struct{}, opt.PoolSize), semaphore: internal.NewFastSemaphore(opt.PoolSize),
conns: make(map[uint64]*Conn), conns: make(map[uint64]*Conn),
idleConns: make([]*Conn, 0, opt.PoolSize), idleConns: make([]*Conn, 0, opt.PoolSize),
} }
@@ -176,27 +176,37 @@ func NewConnPool(opt *Options) *ConnPool {
// initializeHooks sets up the pool hooks system. // initializeHooks sets up the pool hooks system.
func (p *ConnPool) initializeHooks() { func (p *ConnPool) initializeHooks() {
p.hookManager = NewPoolHookManager() manager := NewPoolHookManager()
p.hookManager.Store(manager)
} }
// AddPoolHook adds a pool hook to the pool. // AddPoolHook adds a pool hook to the pool.
func (p *ConnPool) AddPoolHook(hook PoolHook) { func (p *ConnPool) AddPoolHook(hook PoolHook) {
p.hookManagerMu.Lock() // Lock-free read of current manager
defer p.hookManagerMu.Unlock() manager := p.hookManager.Load()
if manager == nil {
if p.hookManager == nil {
p.initializeHooks() p.initializeHooks()
manager = p.hookManager.Load()
} }
p.hookManager.AddHook(hook)
// Create new manager with added hook
newManager := manager.Clone()
newManager.AddHook(hook)
// Atomically swap to new manager
p.hookManager.Store(newManager)
} }
// RemovePoolHook removes a pool hook from the pool. // RemovePoolHook removes a pool hook from the pool.
func (p *ConnPool) RemovePoolHook(hook PoolHook) { func (p *ConnPool) RemovePoolHook(hook PoolHook) {
p.hookManagerMu.Lock() manager := p.hookManager.Load()
defer p.hookManagerMu.Unlock() if manager != nil {
// Create new manager with removed hook
newManager := manager.Clone()
newManager.RemoveHook(hook)
if p.hookManager != nil { // Atomically swap to new manager
p.hookManager.RemoveHook(hook) p.hookManager.Store(newManager)
} }
} }
@@ -213,31 +223,32 @@ func (p *ConnPool) checkMinIdleConns() {
// Only create idle connections if we haven't reached the total pool size limit // Only create idle connections if we haven't reached the total pool size limit
// MinIdleConns should be a subset of PoolSize, not additional connections // MinIdleConns should be a subset of PoolSize, not additional connections
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns { for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
select { // Try to acquire a semaphore token
case p.queue <- struct{}{}: if !p.semaphore.TryAcquire() {
p.poolSize.Add(1) // Semaphore is full, can't create more connections
p.idleConnsLen.Add(1)
go func() {
defer func() {
if err := recover(); err != nil {
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
p.freeTurn()
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
}
}()
err := p.addIdleConn()
if err != nil && err != ErrClosed {
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
}
p.freeTurn()
}()
default:
return return
} }
p.poolSize.Add(1)
p.idleConnsLen.Add(1)
go func() {
defer func() {
if err := recover(); err != nil {
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
p.freeTurn()
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
}
}()
err := p.addIdleConn()
if err != nil && err != ErrClosed {
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
}
p.freeTurn()
}()
} }
} }
@@ -281,7 +292,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, ErrClosed return nil, ErrClosed
} }
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) { if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= p.cfg.MaxActiveConns {
return nil, ErrPoolExhausted return nil, ErrPoolExhausted
} }
@@ -296,7 +307,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
// when first used. Do NOT transition to IDLE here - that happens after initialization completes. // when first used. Do NOT transition to IDLE here - that happens after initialization completes.
// The state machine flow is: CREATED → INITIALIZING (in initConn) → IDLE (after init success) // The state machine flow is: CREATED → INITIALIZING (in initConn) → IDLE (after init success)
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) { if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > p.cfg.MaxActiveConns {
_ = cn.Close() _ = cn.Close()
return nil, ErrPoolExhausted return nil, ErrPoolExhausted
} }
@@ -441,14 +452,12 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
return nil, err return nil, err
} }
now := time.Now() // Use cached time for health checks (max 50ms staleness is acceptable)
now := time.Unix(0, getCachedTimeNs())
attempts := 0 attempts := 0
// Get hooks manager once for this getConn call for performance. // Lock-free atomic read - no mutex overhead!
// Note: Hooks added/removed during this call won't be reflected. hookManager := p.hookManager.Load()
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
for { for {
if attempts >= getAttempts { if attempts >= getAttempts {
@@ -476,19 +485,20 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
} }
// Process connection using the hooks system // Process connection using the hooks system
// Combine error and rejection checks to reduce branches
if hookManager != nil { if hookManager != nil {
acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false) acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false)
if err != nil { if err != nil || !acceptConn {
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) if err != nil {
_ = p.CloseConn(cn) internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
continue _ = p.CloseConn(cn)
} } else {
if !acceptConn { internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID())
internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) // Return connection to pool without freeing the turn that this Get() call holds.
// Return connection to pool without freeing the turn that this Get() call holds. // We use putConnWithoutTurn() to run all the Put hooks and logic without freeing a turn.
// We use putConnWithoutTurn() to run all the Put hooks and logic without freeing a turn. p.putConnWithoutTurn(ctx, cn)
p.putConnWithoutTurn(ctx, cn) cn = nil
cn = nil }
continue continue
} }
} }
@@ -521,44 +531,36 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
} }
func (p *ConnPool) waitTurn(ctx context.Context) error { func (p *ConnPool) waitTurn(ctx context.Context) error {
// Fast path: check context first
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
default: default:
} }
select { // Fast path: try to acquire without blocking
case p.queue <- struct{}{}: if p.semaphore.TryAcquire() {
return nil return nil
default:
} }
// Slow path: need to wait
start := time.Now() start := time.Now()
timer := timers.Get().(*time.Timer) err := p.semaphore.Acquire(ctx, p.cfg.PoolTimeout, ErrPoolTimeout)
defer timers.Put(timer)
timer.Reset(p.cfg.PoolTimeout)
select { switch err {
case <-ctx.Done(): case nil:
if !timer.Stop() { // Successfully acquired after waiting
<-timer.C
}
return ctx.Err()
case p.queue <- struct{}{}:
p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano()) p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
atomic.AddUint32(&p.stats.WaitCount, 1) atomic.AddUint32(&p.stats.WaitCount, 1)
if !timer.Stop() { case ErrPoolTimeout:
<-timer.C
}
return nil
case <-timer.C:
atomic.AddUint32(&p.stats.Timeouts, 1) atomic.AddUint32(&p.stats.Timeouts, 1)
return ErrPoolTimeout
} }
return err
} }
func (p *ConnPool) freeTurn() { func (p *ConnPool) freeTurn() {
<-p.queue p.semaphore.Release()
} }
func (p *ConnPool) popIdle() (*Conn, error) { func (p *ConnPool) popIdle() (*Conn, error) {
@@ -592,15 +594,16 @@ func (p *ConnPool) popIdle() (*Conn, error) {
} }
attempts++ attempts++
// Try to atomically transition to IN_USE using state machine // Hot path optimization: try IDLE → IN_USE or CREATED → IN_USE transition
// Accept both CREATED (uninitialized) and IDLE (initialized) states // Using inline TryAcquire() method for better performance (avoids pointer dereference)
_, err := cn.GetStateMachine().TryTransition([]ConnState{StateCreated, StateIdle}, StateInUse) if cn.TryAcquire() {
if err == nil {
// Successfully acquired the connection // Successfully acquired the connection
p.idleConnsLen.Add(-1) p.idleConnsLen.Add(-1)
break break
} }
// Connection is in UNUSABLE, INITIALIZING, or other state - skip it
// Connection is not in a valid state (might be UNUSABLE for handoff/re-auth, INITIALIZING, etc.) // Connection is not in a valid state (might be UNUSABLE for handoff/re-auth, INITIALIZING, etc.)
// Put it back in the pool and try the next one // Put it back in the pool and try the next one
if p.cfg.PoolFIFO { if p.cfg.PoolFIFO {
@@ -651,9 +654,8 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) {
// It's a push notification, allow pooling (client will handle it) // It's a push notification, allow pooling (client will handle it)
} }
p.hookManagerMu.RLock() // Lock-free atomic read - no mutex overhead!
hookManager := p.hookManager hookManager := p.hookManager.Load()
p.hookManagerMu.RUnlock()
if hookManager != nil { if hookManager != nil {
shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn) shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
@@ -664,41 +666,35 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) {
} }
} }
// If hooks say to remove the connection, do so // Combine all removal checks into one - reduces branches
if shouldRemove { if shouldRemove || !shouldPool {
p.removeConnInternal(ctx, cn, errors.New("hook requested removal"), freeTurn) p.removeConnInternal(ctx, cn, errHookRequestedRemoval, freeTurn)
return
}
// If processor says not to pool the connection, remove it
if !shouldPool {
p.removeConnInternal(ctx, cn, errors.New("hook requested no pooling"), freeTurn)
return return
} }
if !cn.pooled { if !cn.pooled {
p.removeConnInternal(ctx, cn, errors.New("connection not pooled"), freeTurn) p.removeConnInternal(ctx, cn, errConnNotPooled, freeTurn)
return return
} }
var shouldCloseConn bool var shouldCloseConn bool
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns { if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
// Try to transition to IDLE state BEFORE adding to pool // Hot path optimization: try fast IN_USE → IDLE transition
// Only transition if connection is still IN_USE (hooks might have changed state) // Using inline Release() method for better performance (avoids pointer dereference)
// This prevents: transitionedToIdle := cn.Release()
// 1. Race condition where another goroutine could acquire a connection that's still in IN_USE state
// 2. Overwriting state changes made by hooks (e.g., IN_USE → UNUSABLE for handoff) if !transitionedToIdle {
currentState, err := cn.GetStateMachine().TryTransition([]ConnState{StateInUse}, StateIdle) // Fast path failed - hook might have changed state (e.g., to UNUSABLE for handoff)
if err != nil {
// Hook changed the state (e.g., to UNUSABLE for handoff)
// Keep the state set by the hook and pool the connection anyway // Keep the state set by the hook and pool the connection anyway
currentState := cn.GetStateMachine().GetState()
internal.Logger.Printf(ctx, "Connection state changed by hook to %v, pooling as-is", currentState) internal.Logger.Printf(ctx, "Connection state changed by hook to %v, pooling as-is", currentState)
} }
// unusable conns are expected to become usable at some point (background process is reconnecting them) // unusable conns are expected to become usable at some point (background process is reconnecting them)
// put them at the opposite end of the queue // put them at the opposite end of the queue
if !cn.IsUsable() { // Optimization: if we just transitioned to IDLE, we know it's usable - skip the check
if !transitionedToIdle && !cn.IsUsable() {
if p.cfg.PoolFIFO { if p.cfg.PoolFIFO {
p.connsMu.Lock() p.connsMu.Lock()
p.idleConns = append(p.idleConns, cn) p.idleConns = append(p.idleConns, cn)
@@ -742,9 +738,8 @@ func (p *ConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error
// removeConnInternal is the internal implementation of Remove that optionally frees a turn. // removeConnInternal is the internal implementation of Remove that optionally frees a turn.
func (p *ConnPool) removeConnInternal(ctx context.Context, cn *Conn, reason error, freeTurn bool) { func (p *ConnPool) removeConnInternal(ctx context.Context, cn *Conn, reason error, freeTurn bool) {
p.hookManagerMu.RLock() // Lock-free atomic read - no mutex overhead!
hookManager := p.hookManager hookManager := p.hookManager.Load()
p.hookManagerMu.RUnlock()
if hookManager != nil { if hookManager != nil {
hookManager.ProcessOnRemove(ctx, cn, reason) hookManager.ProcessOnRemove(ctx, cn, reason)
@@ -877,36 +872,53 @@ func (p *ConnPool) Close() error {
} }
func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
// slight optimization, check expiresAt first. // Performance optimization: check conditions from cheapest to most expensive,
if cn.expiresAt.Before(now) { // and from most likely to fail to least likely to fail.
return false
// Only fails if ConnMaxLifetime is set AND connection is old.
// Most pools don't set ConnMaxLifetime, so this rarely fails.
if p.cfg.ConnMaxLifetime > 0 {
if cn.expiresAt.Before(now) {
return false // Connection has exceeded max lifetime
}
} }
// Check if connection has exceeded idle timeout // Most pools set ConnMaxIdleTime, and idle connections are common.
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { // Checking this first allows us to fail fast without expensive syscalls.
return false if p.cfg.ConnMaxIdleTime > 0 {
if now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
return false // Connection has been idle too long
}
} }
cn.SetUsedAt(now) // Only run this if the cheap checks passed.
// Check basic connection health
// Use GetNetConn() to safely access netConn and avoid data races
if err := connCheck(cn.getNetConn()); err != nil { if err := connCheck(cn.getNetConn()); err != nil {
// If there's unexpected data, it might be push notifications (RESP3) // If there's unexpected data, it might be push notifications (RESP3)
// However, push notification processing is now handled by the client
// before WithReader to ensure proper context is available to handlers
if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead { if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
// we know that there is something in the buffer, so peek at the next reply type without // Peek at the reply type to check if it's a push notification
// the potential to block
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush { if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
// For RESP3 connections with push notifications, we allow some buffered data // For RESP3 connections with push notifications, we allow some buffered data
// The client will process these notifications before using the connection // The client will process these notifications before using the connection
internal.Logger.Printf(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID()) internal.Logger.Printf(
return true // Connection is healthy, client will handle notifications context.Background(),
"push: conn[%d] has buffered data, likely push notifications - will be processed by client",
cn.GetID(),
)
// Update timestamp for healthy connection
cn.SetUsedAt(now)
// Connection is healthy, client will handle notifications
return true
} }
return false // Unexpected data, not push notifications, connection is unhealthy // Not a push notification - treat as unhealthy
} else {
return false return false
} }
// Connection failed health check
return false
} }
// Only update UsedAt if connection is healthy (avoids unnecessary atomic store)
cn.SetUsedAt(now)
return true return true
} }

View File

@@ -24,7 +24,7 @@ type PubSubPool struct {
stats PubSubStats stats PubSubStats
} }
// PubSubPool implements a pool for PubSub connections. // NewPubSubPool implements a pool for PubSub connections.
// It intentionally does not implement the Pooler interface // It intentionally does not implement the Pooler interface
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool { func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
return &PubSubPool{ return &PubSubPool{

View File

@@ -370,9 +370,17 @@ func BenchmarkPeekPushNotificationName(b *testing.B) {
buf := createValidPushNotification(tc.notification, "data") buf := createValidPushNotification(tc.notification, "data")
data := buf.Bytes() data := buf.Bytes()
// Reuse both bytes.Reader and proto.Reader to avoid allocations
bytesReader := bytes.NewReader(data)
reader := NewReader(bytesReader)
b.ResetTimer() b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
reader := NewReader(bytes.NewReader(data)) // Reset the bytes.Reader to the beginning without allocating
bytesReader.Reset(data)
// Reset the proto.Reader to reuse the bufio buffer
reader.Reset(bytesReader)
_, err := reader.PeekPushNotificationName() _, err := reader.PeekPushNotificationName()
if err != nil { if err != nil {
b.Errorf("PeekPushNotificationName should not error: %v", err) b.Errorf("PeekPushNotificationName should not error: %v", err)

161
internal/semaphore.go Normal file
View File

@@ -0,0 +1,161 @@
package internal
import (
"context"
"sync"
"sync/atomic"
"time"
)
// semTimers recycles timers for the blocking path of FastSemaphore.Acquire so
// that waiting does not allocate a new timer on every call.
//
// Every timer stored in the pool is stopped and has an empty channel: New
// creates it that way, and Acquire restores that state before putting a timer
// back.
var semTimers = sync.Pool{
	New: func() interface{} {
		t := time.NewTimer(time.Hour)
		t.Stop()
		return t
	},
}

// FastSemaphore is a counting semaphore built on an atomic counter.
//
// Acquiring a token when one is free is a load plus a compare-and-swap on
// count; waitCh is only touched by callers that found the semaphore full and
// by Release, which uses it to wake them.
//
// Release does not check that a token is actually held, so every Release must
// be paired with an earlier successful acquire.
type FastSemaphore struct {
	// count is the number of tokens currently held.
	count atomic.Int32

	// max is the capacity: the largest value count may reach via TryAcquire.
	max int32

	// waitCh carries wake-up notifications from Release to blocked waiters.
	// It is buffered with room for max notifications. A notification is only
	// a hint that a token may be free; receivers must still call TryAcquire.
	waitCh chan struct{}
}

// NewFastSemaphore creates a semaphore that allows up to capacity tokens to be
// held at the same time.
func NewFastSemaphore(capacity int32) *FastSemaphore {
	return &FastSemaphore{
		max:    capacity,
		waitCh: make(chan struct{}, capacity),
	}
}

// TryAcquire attempts to take a token without blocking.
// It returns true on success and false if all tokens are currently held.
func (s *FastSemaphore) TryAcquire() bool {
	for {
		current := s.count.Load()
		if current >= s.max {
			return false // semaphore is full
		}
		if s.count.CompareAndSwap(current, current+1) {
			return true
		}
		// Another goroutine changed count between the load and the CAS; retry.
	}
}

// Acquire takes a token, waiting until one is released, ctx is done, or
// timeout elapses.
//
// It returns nil once a token has been acquired, ctx.Err() if ctx is done
// (including when it is already done on entry), and timeoutErr if no token
// could be acquired within timeout.
func (s *FastSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error {
	// Do not hand a token to a caller whose context is already done.
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
	}

	// Fast path: a token is free right now.
	if s.TryAcquire() {
		return nil
	}

	// Slow path: wait for a release notification, bounded by a pooled timer.
	timer := semTimers.Get().(*time.Timer)
	timer.Reset(timeout)
	start := time.Now()

	// timerFired records whether this call already received from timer.C.
	// On every other exit the timer is stopped and, if it expired in the
	// meantime, its channel is drained, so the pool never receives a timer
	// that is still armed or carries an unread tick that a later Acquire
	// could mistake for its own timeout.
	timerFired := false
	defer func() {
		if !timerFired && !timer.Stop() {
			<-timer.C
		}
		semTimers.Put(timer)
	}()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-s.waitCh:
			// A token was released at some point; it may already have been
			// taken by another goroutine, so the acquire can still fail.
			if s.TryAcquire() {
				return nil
			}
		case <-timer.C:
			timerFired = true
			return timeoutErr
		}

		// Enforce the deadline even when notifications keep arriving and
		// select keeps choosing waitCh over the timer.
		if time.Since(start) > timeout {
			return timeoutErr
		}
	}
}

// AcquireBlocking takes a token, waiting indefinitely until one is available.
// It returns immediately when a token is free. There is no timeout and no
// cancellation.
func (s *FastSemaphore) AcquireBlocking() {
	if s.TryAcquire() {
		return
	}

	for {
		<-s.waitCh
		if s.TryAcquire() {
			return
		}
		// Lost the race for the released token; wait for the next notification.
	}
}

// Release returns a token to the semaphore and posts a wake-up notification.
//
// The notification is sent without blocking. Because waitCh is buffered, a
// notification posted while nobody is waiting stays in the buffer and is
// consumed by a later waiter, which then re-checks availability with
// TryAcquire. If the buffer is already full the notification is dropped.
func (s *FastSemaphore) Release() {
	s.count.Add(-1)

	select {
	case s.waitCh <- struct{}{}:
		// Notification delivered to a waiter or stored in the buffer.
	default:
		// Buffer is full; pending notifications already cover any waiters.
	}
}

// Len returns the number of tokens currently held.
// Used by tests to check semaphore state.
func (s *FastSemaphore) Len() int32 {
	return s.count.Load()
}

View File

@@ -323,6 +323,7 @@ var _ = Describe("Client", func() {
cn, err = client.Pool().Get(context.Background()) cn, err = client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
Expect(cn.UsedAt().UnixNano()).To(BeNumerically(">", createdAt.UnixNano()))
Expect(cn.UsedAt().After(createdAt)).To(BeTrue()) Expect(cn.UsedAt().After(createdAt)).To(BeTrue())
}) })