From 080a33c3a896b6f1c6d037b1ebde3afad60e6b5b Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> Date: Mon, 27 Oct 2025 15:06:30 +0200 Subject: [PATCH 01/20] fix(pool): pool performance (#3565) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * perf(pool): replace hookManager RWMutex with atomic.Pointer and add predefined state slices - Replace hookManager RWMutex with atomic.Pointer for lock-free reads in hot paths - Add predefined state slices to avoid allocations (validFromInUse, validFromCreatedOrIdle, etc.) - Add Clone() method to PoolHookManager for atomic updates - Update AddPoolHook/RemovePoolHook to use copy-on-write pattern - Update all hookManager access points to use atomic Load() Performance improvements: - Eliminates RWMutex contention in Get/Put/Remove hot paths - Reduces allocations by reusing predefined state slices - Lock-free reads allow better CPU cache utilization * perf(pool): eliminate mutex overhead in state machine hot path The state machine was calling notifyWaiters() on EVERY Get/Put operation, which acquired a mutex even when no waiters were present (the common case). Fix: Use atomic waiterCount to check for waiters BEFORE acquiring mutex. This eliminates mutex contention in the hot path (Get/Put operations). Implementation: - Added atomic.Int32 waiterCount field to ConnStateMachine - Increment when adding waiter, decrement when removing - Check waiterCount atomically before acquiring mutex in notifyWaiters() Performance impact: - Before: mutex lock/unlock on every Get/Put (even with no waiters) - After: lock-free atomic check, only acquire mutex if waiters exist - Expected improvement: ~30-50% for Get/Put operations * perf(pool): use predefined state slices to eliminate allocations in hot path The pool was creating new slice literals on EVERY Get/Put operation: - popIdle(): []ConnState{StateCreated, StateIdle} - putConn(): []ConnState{StateInUse} - CompareAndSwapUsed(): []ConnState{StateIdle} and []ConnState{StateInUse} - MarkUnusableForHandoff(): []ConnState{StateInUse, StateIdle, StateCreated} These allocations were happening millions of times per second in the hot path. Fix: Use predefined global slices defined in conn_state.go: - validFromInUse - validFromCreatedOrIdle - validFromCreatedInUseOrIdle Performance impact: - Before: 4 slice allocations per Get/Put cycle - After: 0 allocations (use predefined slices) - Expected improvement: ~30-40% reduction in allocations and GC pressure * perf(pool): optimize TryTransition to reduce atomic operations Further optimize the hot path by: 1. Remove redundant GetState() call in the loop 2. Only check waiterCount after successful CAS (not before loop) 3. Inline the waiterCount check to avoid notifyWaiters() call overhead This reduces atomic operations from 4-5 per Get/Put to 2-3: - Before: GetState() + CAS + waiterCount.Load() + notifyWaiters mutex check - After: CAS + waiterCount.Load() (only if CAS succeeds) Performance impact: - Eliminates 1-2 atomic operations per Get/Put - Expected improvement: ~10-15% for Get/Put operations * perf(pool): add fast path for Get/Put to match master performance Introduced TryTransitionFast() for the hot path (Get/Put operations): - Single CAS operation (same as master's atomic bool) - No waiter notification overhead - No loop through valid states - No error allocation Hot path flow: 1. popIdle(): Try IDLE → IN_USE (fast), fallback to CREATED → IN_USE 2. putConn(): Try IN_USE → IDLE (fast) This matches master's performance while preserving state machine for: - Background operations (handoff/reauth use UNUSABLE state) - State validation (TryTransition still available) - Waiter notification (AwaitAndTransition for blocking) Performance comparison per Get/Put cycle: - Master: 2 atomic CAS operations - State machine (before): 5 atomic operations (2.5x slower) - State machine (after): 2 atomic CAS operations (same as master!) Expected improvement: Restore to baseline ~11,373 ops/sec * combine cas * fix linter * try faster approach * fast semaphore * better inlining for hot path * fix linter issues * use new semaphore in auth as well * linter should be happy now * add comments * Update internal/pool/conn_state.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * address comment * slight reordering * try to cache time if for non-critical calculation * fix wrong benchmark * add concurrent test * fix benchmark report * add additional expect to check output * comment and variable rename --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- hset_benchmark_test.go | 115 +++++++- internal/auth/streaming/pool_hook.go | 19 +- internal/pool/conn.go | 88 +++++- internal/pool/conn_state.go | 54 +++- internal/pool/export_test.go | 2 +- internal/pool/hooks.go | 13 + internal/pool/hooks_test.go | 17 +- internal/pool/pool.go | 272 +++++++++--------- internal/pool/pubsub.go | 2 +- internal/proto/peek_push_notification_test.go | 10 +- internal/semaphore.go | 161 +++++++++++ redis_test.go | 1 + 12 files changed, 585 insertions(+), 169 deletions(-) create mode 100644 internal/semaphore.go diff --git a/hset_benchmark_test.go b/hset_benchmark_test.go index 8d141f41..649c9352 100644 --- a/hset_benchmark_test.go +++ b/hset_benchmark_test.go @@ -3,6 +3,7 @@ package redis_test import ( "context" "fmt" + "sync" "testing" "time" @@ -100,7 +101,82 @@ func benchmarkHSETOperations(b *testing.B, rdb *redis.Client, ctx context.Contex avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N) b.ReportMetric(float64(avgTimePerOp), "ns/op") // report average time in milliseconds from totalTimes - avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes)) + sumTime := time.Duration(0) + for _, t := range totalTimes { + sumTime += t + } + avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes)) + b.ReportMetric(float64(avgTimePerOpMs), "ms") +} + +// benchmarkHSETOperationsConcurrent performs the actual HSET benchmark for a given scale +func benchmarkHSETOperationsConcurrent(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) { + hashKey := fmt.Sprintf("benchmark_hash_%d", operations) + + b.ResetTimer() + b.StartTimer() + totalTimes := []time.Duration{} + + for i := 0; i < b.N; i++ { + b.StopTimer() + // Clean up the hash before each iteration + rdb.Del(ctx, hashKey) + b.StartTimer() + + startTime := time.Now() + // Perform the specified number of HSET operations + + wg := sync.WaitGroup{} + timesCh := make(chan time.Duration, operations) + errCh := make(chan error, operations) + + for j := 0; j < operations; j++ { + wg.Add(1) + go func(j int) { + defer wg.Done() + field := fmt.Sprintf("field_%d", j) + value := fmt.Sprintf("value_%d", j) + + err := rdb.HSet(ctx, hashKey, field, value).Err() + if err != nil { + errCh <- err + return + } + timesCh <- time.Since(startTime) + }(j) + } + + wg.Wait() + close(timesCh) + close(errCh) + + // Check for errors + for err := range errCh { + b.Errorf("HSET operation failed: %v", err) + } + + for d := range timesCh { + totalTimes = append(totalTimes, d) + } + } + + // Stop the timer to calculate metrics + b.StopTimer() + + // Report operations per second + opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds() + b.ReportMetric(opsPerSec, "ops/sec") + + // Report average time per operation + avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N) + b.ReportMetric(float64(avgTimePerOp), "ns/op") + // report average time in milliseconds from totalTimes + + sumTime := time.Duration(0) + for _, t := range totalTimes { + sumTime += t + } + avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes)) b.ReportMetric(float64(avgTimePerOpMs), "ms") } @@ -134,6 +210,37 @@ func BenchmarkHSETPipelined(b *testing.B) { } } +func BenchmarkHSET_Concurrent(b *testing.B) { + ctx := context.Background() + + // Setup Redis client + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + DB: 0, + PoolSize: 100, + }) + defer rdb.Close() + + // Test connection + if err := rdb.Ping(ctx).Err(); err != nil { + b.Skipf("Redis server not available: %v", err) + } + + // Clean up before and after tests + defer func() { + rdb.FlushDB(ctx) + }() + + // Reduced scales to avoid overwhelming the system with too many concurrent goroutines + scales := []int{1, 10, 100, 1000} + + for _, scale := range scales { + b.Run(fmt.Sprintf("HSET_%d_operations_concurrent", scale), func(b *testing.B) { + benchmarkHSETOperationsConcurrent(b, rdb, ctx, scale) + }) + } +} + // benchmarkHSETPipelined performs HSET benchmark using pipelining func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) { hashKey := fmt.Sprintf("benchmark_hash_pipelined_%d", operations) @@ -177,7 +284,11 @@ func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N) b.ReportMetric(float64(avgTimePerOp), "ns/op") // report average time in milliseconds from totalTimes - avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes)) + sumTime := time.Duration(0) + for _, t := range totalTimes { + sumTime += t + } + avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes)) b.ReportMetric(float64(avgTimePerOpMs), "ms") } diff --git a/internal/auth/streaming/pool_hook.go b/internal/auth/streaming/pool_hook.go index a5647be0..f37fe557 100644 --- a/internal/auth/streaming/pool_hook.go +++ b/internal/auth/streaming/pool_hook.go @@ -34,9 +34,10 @@ type ReAuthPoolHook struct { shouldReAuth map[uint64]func(error) shouldReAuthLock sync.RWMutex - // workers is a semaphore channel limiting concurrent re-auth operations + // workers is a semaphore limiting concurrent re-auth operations // Initialized with poolSize tokens to prevent pool exhaustion - workers chan struct{} + // Uses FastSemaphore for consistency and better performance + workers *internal.FastSemaphore // reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth reAuthTimeout time.Duration @@ -59,16 +60,10 @@ type ReAuthPoolHook struct { // The poolSize parameter is used to initialize the worker semaphore, ensuring that // re-auth operations don't exhaust the connection pool. func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook { - workers := make(chan struct{}, poolSize) - // Initialize the workers channel with tokens (semaphore pattern) - for i := 0; i < poolSize; i++ { - workers <- struct{}{} - } - return &ReAuthPoolHook{ shouldReAuth: make(map[uint64]func(error)), scheduledReAuth: make(map[uint64]bool), - workers: workers, + workers: internal.NewFastSemaphore(int32(poolSize)), reAuthTimeout: reAuthTimeout, } } @@ -162,10 +157,10 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, r.scheduledLock.Unlock() r.shouldReAuthLock.Unlock() go func() { - <-r.workers + r.workers.AcquireBlocking() // safety first if conn == nil || (conn != nil && conn.IsClosed()) { - r.workers <- struct{}{} + r.workers.Release() return } defer func() { @@ -176,7 +171,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, r.scheduledLock.Lock() delete(r.scheduledReAuth, connID) r.scheduledLock.Unlock() - r.workers <- struct{}{} + r.workers.Release() }() // Create timeout context for connection acquisition diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 4d38184a..56be7098 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -1,3 +1,4 @@ +// Package pool implements the pool management package pool import ( @@ -17,6 +18,35 @@ import ( var noDeadline = time.Time{} +// Global time cache updated every 50ms by background goroutine. +// This avoids expensive time.Now() syscalls in hot paths like getEffectiveReadTimeout. +// Max staleness: 50ms, which is acceptable for timeout deadline checks (timeouts are typically 3-30 seconds). +var globalTimeCache struct { + nowNs atomic.Int64 +} + +func init() { + // Initialize immediately + globalTimeCache.nowNs.Store(time.Now().UnixNano()) + + // Start background updater + go func() { + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for range ticker.C { + globalTimeCache.nowNs.Store(time.Now().UnixNano()) + } + }() +} + +// getCachedTimeNs returns the current time in nanoseconds from the global cache. +// This is updated every 50ms by a background goroutine, avoiding expensive syscalls. +// Max staleness: 50ms. +func getCachedTimeNs() int64 { + return globalTimeCache.nowNs.Load() +} + // Global atomic counter for connection IDs var connIDCounter uint64 @@ -79,6 +109,7 @@ type Conn struct { expiresAt time.Time // maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers + // Using atomic operations for lock-free access to avoid mutex contention relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds @@ -260,11 +291,13 @@ func (cn *Conn) CompareAndSwapUsed(old, new bool) bool { if !old && new { // Acquiring: IDLE → IN_USE - _, err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse) + // Use predefined slice to avoid allocation + _, err := cn.stateMachine.TryTransition(validFromCreatedOrIdle, StateInUse) return err == nil } else { // Releasing: IN_USE → IDLE - _, err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle) + // Use predefined slice to avoid allocation + _, err := cn.stateMachine.TryTransition(validFromInUse, StateIdle) return err == nil } } @@ -454,7 +487,8 @@ func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Durati return time.Duration(readTimeoutNs) } - nowNs := time.Now().UnixNano() + // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) + nowNs := getCachedTimeNs() // Check if deadline has passed if nowNs < deadlineNs { // Deadline is in the future, use relaxed timeout @@ -487,7 +521,8 @@ func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Durat return time.Duration(writeTimeoutNs) } - nowNs := time.Now().UnixNano() + // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) + nowNs := getCachedTimeNs() // Check if deadline has passed if nowNs < deadlineNs { // Deadline is in the future, use relaxed timeout @@ -632,7 +667,8 @@ func (cn *Conn) MarkQueuedForHandoff() error { // The connection is typically in IN_USE state when OnPut is called (normal Put flow) // But in some edge cases or tests, it might be in IDLE or CREATED state // The pool will detect this state change and preserve it (not overwrite with IDLE) - finalState, err := cn.stateMachine.TryTransition([]ConnState{StateInUse, StateIdle, StateCreated}, StateUnusable) + // Use predefined slice to avoid allocation + finalState, err := cn.stateMachine.TryTransition(validFromCreatedInUseOrIdle, StateUnusable) if err != nil { // Check if already in UNUSABLE state (race condition or retry) // ShouldHandoff should be false now, but check just in case @@ -658,6 +694,42 @@ func (cn *Conn) GetStateMachine() *ConnStateMachine { return cn.stateMachine } +// TryAcquire attempts to acquire the connection for use. +// This is an optimized inline method for the hot path (Get operation). +// +// It tries to transition from IDLE -> IN_USE or CREATED -> IN_USE. +// Returns true if the connection was successfully acquired, false otherwise. +// +// Performance: This is faster than calling GetStateMachine() + TryTransitionFast() +// +// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's +// methods. This breaks encapsulation but is necessary for performance. +// The IDLE->IN_USE and CREATED->IN_USE transitions don't need +// waiter notification, and benchmarks show 1-3% improvement. If the state machine ever +// needs to notify waiters on these transitions, update this to use TryTransitionFast(). +func (cn *Conn) TryAcquire() bool { + // The || operator short-circuits, so only 1 CAS in the common case + return cn.stateMachine.state.CompareAndSwap(uint32(StateIdle), uint32(StateInUse)) || + cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateInUse)) +} + +// Release releases the connection back to the pool. +// This is an optimized inline method for the hot path (Put operation). +// +// It tries to transition from IN_USE -> IDLE. +// Returns true if the connection was successfully released, false otherwise. +// +// Performance: This is faster than calling GetStateMachine() + TryTransitionFast(). +// +// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's +// methods. This breaks encapsulation but is necessary for performance. +// If the state machine ever needs to notify waiters +// on this transition, update this to use TryTransitionFast(). +func (cn *Conn) Release() bool { + // Inline the hot path - single CAS operation + return cn.stateMachine.state.CompareAndSwap(uint32(StateInUse), uint32(StateIdle)) +} + // ClearHandoffState clears the handoff state after successful handoff. // Makes the connection usable again. func (cn *Conn) ClearHandoffState() { @@ -800,8 +872,12 @@ func (cn *Conn) MaybeHasData() bool { return false } +// deadline computes the effective deadline time based on context and timeout. +// It updates the usedAt timestamp to now. +// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation). func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { - tm := time.Now() + // Use cached time for deadline calculation (called 2x per command: read + write) + tm := time.Unix(0, getCachedTimeNs()) cn.SetUsedAt(tm) if timeout > 0 { diff --git a/internal/pool/conn_state.go b/internal/pool/conn_state.go index 93147d17..32fc5058 100644 --- a/internal/pool/conn_state.go +++ b/internal/pool/conn_state.go @@ -41,6 +41,13 @@ const ( StateClosed ) +// Predefined state slices to avoid allocations in hot paths +var ( + validFromInUse = []ConnState{StateInUse} + validFromCreatedOrIdle = []ConnState{StateCreated, StateIdle} + validFromCreatedInUseOrIdle = []ConnState{StateCreated, StateInUse, StateIdle} +) + // String returns a human-readable string representation of the state. func (s ConnState) String() string { switch s { @@ -92,8 +99,9 @@ type ConnStateMachine struct { state atomic.Uint32 // FIFO queue for waiters - only locked during waiter add/remove/notify - mu sync.Mutex - waiters *list.List // List of *waiter + mu sync.Mutex + waiters *list.List // List of *waiter + waiterCount atomic.Int32 // Fast lock-free check for waiters (avoids mutex in hot path) } // NewConnStateMachine creates a new connection state machine. @@ -114,6 +122,23 @@ func (sm *ConnStateMachine) GetState() ConnState { return ConnState(sm.state.Load()) } +// TryTransitionFast is an optimized version for the hot path (Get/Put operations). +// It only handles simple state transitions without waiter notification. +// This is safe because: +// 1. Get/Put don't need to wait for state changes +// 2. Background operations (handoff/reauth) use UNUSABLE state, which this won't match +// 3. If a background operation is in progress (state is UNUSABLE), this fails fast +// +// Returns true if transition succeeded, false otherwise. +// Use this for performance-critical paths where you don't need error details. +// +// Performance: Single CAS operation - as fast as the old atomic bool! +// For multiple from states, use: sm.TryTransitionFast(State1, Target) || sm.TryTransitionFast(State2, Target) +// The || operator short-circuits, so only 1 CAS is executed in the common case. +func (sm *ConnStateMachine) TryTransitionFast(fromState, targetState ConnState) bool { + return sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) +} + // TryTransition attempts an immediate state transition without waiting. // Returns the current state after the transition attempt and an error if the transition failed. // The returned state is the CURRENT state (after the attempt), not the previous state. @@ -126,17 +151,15 @@ func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetSta // Try each valid from state with CAS // This ensures only ONE goroutine can successfully transition at a time for _, fromState := range validFromStates { - // Fast path: check if we're already in target state - if fromState == targetState && sm.GetState() == targetState { - return targetState, nil - } - // Try to atomically swap from fromState to targetState // If successful, we won the race and can proceed if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) { // Success! We transitioned atomically - // Notify any waiters - sm.notifyWaiters() + // Hot path optimization: only check for waiters if transition succeeded + // This avoids atomic load on every Get/Put when no waiters exist + if sm.waiterCount.Load() > 0 { + sm.notifyWaiters() + } return targetState, nil } } @@ -213,6 +236,7 @@ func (sm *ConnStateMachine) AwaitAndTransition( // Add to FIFO queue sm.mu.Lock() elem := sm.waiters.PushBack(w) + sm.waiterCount.Add(1) sm.mu.Unlock() // Wait for state change or timeout @@ -221,10 +245,13 @@ func (sm *ConnStateMachine) AwaitAndTransition( // Timeout or cancellation - remove from queue sm.mu.Lock() sm.waiters.Remove(elem) + sm.waiterCount.Add(-1) sm.mu.Unlock() return sm.GetState(), ctx.Err() case err := <-w.done: // Transition completed (or failed) + // Note: waiterCount is decremented either in notifyWaiters (when the waiter is notified and removed) + // or here (on timeout/cancellation). return sm.GetState(), err } } @@ -232,9 +259,16 @@ func (sm *ConnStateMachine) AwaitAndTransition( // notifyWaiters checks if any waiters can proceed and notifies them in FIFO order. // This is called after every state transition. func (sm *ConnStateMachine) notifyWaiters() { + // Fast path: check atomic counter without acquiring lock + // This eliminates mutex overhead in the common case (no waiters) + if sm.waiterCount.Load() == 0 { + return + } + sm.mu.Lock() defer sm.mu.Unlock() + // Double-check after acquiring lock (waiters might have been processed) if sm.waiters.Len() == 0 { return } @@ -255,6 +289,7 @@ func (sm *ConnStateMachine) notifyWaiters() { if _, valid := w.validStates[currentState]; valid { // Remove from queue first sm.waiters.Remove(elem) + sm.waiterCount.Add(-1) // Use CAS to ensure state hasn't changed since we checked // This prevents race condition where another thread changes state @@ -267,6 +302,7 @@ func (sm *ConnStateMachine) notifyWaiters() { } else { // State changed - re-add waiter to front of queue and retry sm.waiters.PushFront(w) + sm.waiterCount.Add(1) // Continue to next iteration to re-read state processed = true break diff --git a/internal/pool/export_test.go b/internal/pool/export_test.go index 20456b81..2d178038 100644 --- a/internal/pool/export_test.go +++ b/internal/pool/export_test.go @@ -20,5 +20,5 @@ func (p *ConnPool) CheckMinIdleConns() { } func (p *ConnPool) QueueLen() int { - return len(p.queue) + return int(p.semaphore.Len()) } diff --git a/internal/pool/hooks.go b/internal/pool/hooks.go index bfbd9e14..1c365dba 100644 --- a/internal/pool/hooks.go +++ b/internal/pool/hooks.go @@ -140,3 +140,16 @@ func (phm *PoolHookManager) GetHooks() []PoolHook { copy(hooks, phm.hooks) return hooks } + +// Clone creates a copy of the hook manager with the same hooks. +// This is used for lock-free atomic updates of the hook manager. +func (phm *PoolHookManager) Clone() *PoolHookManager { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + + newManager := &PoolHookManager{ + hooks: make([]PoolHook, len(phm.hooks)), + } + copy(newManager.hooks, phm.hooks) + return newManager +} diff --git a/internal/pool/hooks_test.go b/internal/pool/hooks_test.go index ec1d6da3..b8f504df 100644 --- a/internal/pool/hooks_test.go +++ b/internal/pool/hooks_test.go @@ -202,26 +202,29 @@ func TestPoolWithHooks(t *testing.T) { pool.AddPoolHook(testHook) // Verify hooks are initialized - if pool.hookManager == nil { + manager := pool.hookManager.Load() + if manager == nil { t.Error("Expected hookManager to be initialized") } - if pool.hookManager.GetHookCount() != 1 { - t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount()) + if manager.GetHookCount() != 1 { + t.Errorf("Expected 1 hook in pool, got %d", manager.GetHookCount()) } // Test adding hook to pool additionalHook := &TestHook{ShouldPool: true, ShouldAccept: true} pool.AddPoolHook(additionalHook) - if pool.hookManager.GetHookCount() != 2 { - t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount()) + manager = pool.hookManager.Load() + if manager.GetHookCount() != 2 { + t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount()) } // Test removing hook from pool pool.RemovePoolHook(additionalHook) - if pool.hookManager.GetHookCount() != 1 { - t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount()) + manager = pool.hookManager.Load() + if manager.GetHookCount() != 1 { + t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount()) } } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 59b8e194..2dedca05 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -27,6 +27,12 @@ var ( // ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable. ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable") + // errHookRequestedRemoval is returned when a hook requests connection removal. + errHookRequestedRemoval = errors.New("hook requested removal") + + // errConnNotPooled is returned when trying to return a non-pooled connection to the pool. + errConnNotPooled = errors.New("connection not pooled") + // popAttempts is the maximum number of attempts to find a usable connection // when popping from the idle connection pool. This handles cases where connections // are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues). @@ -45,14 +51,6 @@ var ( noExpiration = maxTime ) -var timers = sync.Pool{ - New: func() interface{} { - t := time.NewTimer(time.Hour) - t.Stop() - return t - }, -} - // Stats contains pool state information and accumulated stats. type Stats struct { Hits uint32 // number of times free connection was found in the pool @@ -132,7 +130,9 @@ type ConnPool struct { dialErrorsNum uint32 // atomic lastDialError atomic.Value - queue chan struct{} + // Fast atomic semaphore for connection limiting + // Replaces the old channel-based queue for better performance + semaphore *internal.FastSemaphore connsMu sync.Mutex conns map[uint64]*Conn @@ -148,8 +148,8 @@ type ConnPool struct { _closed uint32 // atomic // Pool hooks manager for flexible connection processing - hookManagerMu sync.RWMutex - hookManager *PoolHookManager + // Using atomic.Pointer for lock-free reads in hot paths (Get/Put) + hookManager atomic.Pointer[PoolHookManager] } var _ Pooler = (*ConnPool)(nil) @@ -158,7 +158,7 @@ func NewConnPool(opt *Options) *ConnPool { p := &ConnPool{ cfg: opt, - queue: make(chan struct{}, opt.PoolSize), + semaphore: internal.NewFastSemaphore(opt.PoolSize), conns: make(map[uint64]*Conn), idleConns: make([]*Conn, 0, opt.PoolSize), } @@ -176,27 +176,37 @@ func NewConnPool(opt *Options) *ConnPool { // initializeHooks sets up the pool hooks system. func (p *ConnPool) initializeHooks() { - p.hookManager = NewPoolHookManager() + manager := NewPoolHookManager() + p.hookManager.Store(manager) } // AddPoolHook adds a pool hook to the pool. func (p *ConnPool) AddPoolHook(hook PoolHook) { - p.hookManagerMu.Lock() - defer p.hookManagerMu.Unlock() - - if p.hookManager == nil { + // Lock-free read of current manager + manager := p.hookManager.Load() + if manager == nil { p.initializeHooks() + manager = p.hookManager.Load() } - p.hookManager.AddHook(hook) + + // Create new manager with added hook + newManager := manager.Clone() + newManager.AddHook(hook) + + // Atomically swap to new manager + p.hookManager.Store(newManager) } // RemovePoolHook removes a pool hook from the pool. func (p *ConnPool) RemovePoolHook(hook PoolHook) { - p.hookManagerMu.Lock() - defer p.hookManagerMu.Unlock() + manager := p.hookManager.Load() + if manager != nil { + // Create new manager with removed hook + newManager := manager.Clone() + newManager.RemoveHook(hook) - if p.hookManager != nil { - p.hookManager.RemoveHook(hook) + // Atomically swap to new manager + p.hookManager.Store(newManager) } } @@ -213,31 +223,32 @@ func (p *ConnPool) checkMinIdleConns() { // Only create idle connections if we haven't reached the total pool size limit // MinIdleConns should be a subset of PoolSize, not additional connections for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns { - select { - case p.queue <- struct{}{}: - p.poolSize.Add(1) - p.idleConnsLen.Add(1) - go func() { - defer func() { - if err := recover(); err != nil { - p.poolSize.Add(-1) - p.idleConnsLen.Add(-1) - - p.freeTurn() - internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err) - } - }() - - err := p.addIdleConn() - if err != nil && err != ErrClosed { - p.poolSize.Add(-1) - p.idleConnsLen.Add(-1) - } - p.freeTurn() - }() - default: + // Try to acquire a semaphore token + if !p.semaphore.TryAcquire() { + // Semaphore is full, can't create more connections return } + + p.poolSize.Add(1) + p.idleConnsLen.Add(1) + go func() { + defer func() { + if err := recover(); err != nil { + p.poolSize.Add(-1) + p.idleConnsLen.Add(-1) + + p.freeTurn() + internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err) + } + }() + + err := p.addIdleConn() + if err != nil && err != ErrClosed { + p.poolSize.Add(-1) + p.idleConnsLen.Add(-1) + } + p.freeTurn() + }() } } @@ -281,7 +292,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { return nil, ErrClosed } - if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) { + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= p.cfg.MaxActiveConns { return nil, ErrPoolExhausted } @@ -296,7 +307,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { // when first used. Do NOT transition to IDLE here - that happens after initialization completes. // The state machine flow is: CREATED → INITIALIZING (in initConn) → IDLE (after init success) - if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) { + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > p.cfg.MaxActiveConns { _ = cn.Close() return nil, ErrPoolExhausted } @@ -441,14 +452,12 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { return nil, err } - now := time.Now() + // Use cached time for health checks (max 50ms staleness is acceptable) + now := time.Unix(0, getCachedTimeNs()) attempts := 0 - // Get hooks manager once for this getConn call for performance. - // Note: Hooks added/removed during this call won't be reflected. - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() + // Lock-free atomic read - no mutex overhead! + hookManager := p.hookManager.Load() for { if attempts >= getAttempts { @@ -476,19 +485,20 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } // Process connection using the hooks system + // Combine error and rejection checks to reduce branches if hookManager != nil { acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false) - if err != nil { - internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) - _ = p.CloseConn(cn) - continue - } - if !acceptConn { - internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) - // Return connection to pool without freeing the turn that this Get() call holds. - // We use putConnWithoutTurn() to run all the Put hooks and logic without freeing a turn. - p.putConnWithoutTurn(ctx, cn) - cn = nil + if err != nil || !acceptConn { + if err != nil { + internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) + _ = p.CloseConn(cn) + } else { + internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) + // Return connection to pool without freeing the turn that this Get() call holds. + // We use putConnWithoutTurn() to run all the Put hooks and logic without freeing a turn. + p.putConnWithoutTurn(ctx, cn) + cn = nil + } continue } } @@ -521,44 +531,36 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } func (p *ConnPool) waitTurn(ctx context.Context) error { + // Fast path: check context first select { case <-ctx.Done(): return ctx.Err() default: } - select { - case p.queue <- struct{}{}: + // Fast path: try to acquire without blocking + if p.semaphore.TryAcquire() { return nil - default: } + // Slow path: need to wait start := time.Now() - timer := timers.Get().(*time.Timer) - defer timers.Put(timer) - timer.Reset(p.cfg.PoolTimeout) + err := p.semaphore.Acquire(ctx, p.cfg.PoolTimeout, ErrPoolTimeout) - select { - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - return ctx.Err() - case p.queue <- struct{}{}: + switch err { + case nil: + // Successfully acquired after waiting p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano()) atomic.AddUint32(&p.stats.WaitCount, 1) - if !timer.Stop() { - <-timer.C - } - return nil - case <-timer.C: + case ErrPoolTimeout: atomic.AddUint32(&p.stats.Timeouts, 1) - return ErrPoolTimeout } + + return err } func (p *ConnPool) freeTurn() { - <-p.queue + p.semaphore.Release() } func (p *ConnPool) popIdle() (*Conn, error) { @@ -592,15 +594,16 @@ func (p *ConnPool) popIdle() (*Conn, error) { } attempts++ - // Try to atomically transition to IN_USE using state machine - // Accept both CREATED (uninitialized) and IDLE (initialized) states - _, err := cn.GetStateMachine().TryTransition([]ConnState{StateCreated, StateIdle}, StateInUse) - if err == nil { + // Hot path optimization: try IDLE → IN_USE or CREATED → IN_USE transition + // Using inline TryAcquire() method for better performance (avoids pointer dereference) + if cn.TryAcquire() { // Successfully acquired the connection p.idleConnsLen.Add(-1) break } + // Connection is in UNUSABLE, INITIALIZING, or other state - skip it + // Connection is not in a valid state (might be UNUSABLE for handoff/re-auth, INITIALIZING, etc.) // Put it back in the pool and try the next one if p.cfg.PoolFIFO { @@ -651,9 +654,8 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { // It's a push notification, allow pooling (client will handle it) } - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() + // Lock-free atomic read - no mutex overhead! + hookManager := p.hookManager.Load() if hookManager != nil { shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn) @@ -664,41 +666,35 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { } } - // If hooks say to remove the connection, do so - if shouldRemove { - p.removeConnInternal(ctx, cn, errors.New("hook requested removal"), freeTurn) - return - } - - // If processor says not to pool the connection, remove it - if !shouldPool { - p.removeConnInternal(ctx, cn, errors.New("hook requested no pooling"), freeTurn) + // Combine all removal checks into one - reduces branches + if shouldRemove || !shouldPool { + p.removeConnInternal(ctx, cn, errHookRequestedRemoval, freeTurn) return } if !cn.pooled { - p.removeConnInternal(ctx, cn, errors.New("connection not pooled"), freeTurn) + p.removeConnInternal(ctx, cn, errConnNotPooled, freeTurn) return } var shouldCloseConn bool if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns { - // Try to transition to IDLE state BEFORE adding to pool - // Only transition if connection is still IN_USE (hooks might have changed state) - // This prevents: - // 1. Race condition where another goroutine could acquire a connection that's still in IN_USE state - // 2. Overwriting state changes made by hooks (e.g., IN_USE → UNUSABLE for handoff) - currentState, err := cn.GetStateMachine().TryTransition([]ConnState{StateInUse}, StateIdle) - if err != nil { - // Hook changed the state (e.g., to UNUSABLE for handoff) + // Hot path optimization: try fast IN_USE → IDLE transition + // Using inline Release() method for better performance (avoids pointer dereference) + transitionedToIdle := cn.Release() + + if !transitionedToIdle { + // Fast path failed - hook might have changed state (e.g., to UNUSABLE for handoff) // Keep the state set by the hook and pool the connection anyway + currentState := cn.GetStateMachine().GetState() internal.Logger.Printf(ctx, "Connection state changed by hook to %v, pooling as-is", currentState) } // unusable conns are expected to become usable at some point (background process is reconnecting them) // put them at the opposite end of the queue - if !cn.IsUsable() { + // Optimization: if we just transitioned to IDLE, we know it's usable - skip the check + if !transitionedToIdle && !cn.IsUsable() { if p.cfg.PoolFIFO { p.connsMu.Lock() p.idleConns = append(p.idleConns, cn) @@ -742,9 +738,8 @@ func (p *ConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error // removeConnInternal is the internal implementation of Remove that optionally frees a turn. func (p *ConnPool) removeConnInternal(ctx context.Context, cn *Conn, reason error, freeTurn bool) { - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() + // Lock-free atomic read - no mutex overhead! + hookManager := p.hookManager.Load() if hookManager != nil { hookManager.ProcessOnRemove(ctx, cn, reason) @@ -877,36 +872,53 @@ func (p *ConnPool) Close() error { } func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { - // slight optimization, check expiresAt first. - if cn.expiresAt.Before(now) { - return false + // Performance optimization: check conditions from cheapest to most expensive, + // and from most likely to fail to least likely to fail. + + // Only fails if ConnMaxLifetime is set AND connection is old. + // Most pools don't set ConnMaxLifetime, so this rarely fails. + if p.cfg.ConnMaxLifetime > 0 { + if cn.expiresAt.Before(now) { + return false // Connection has exceeded max lifetime + } } - // Check if connection has exceeded idle timeout - if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { - return false + // Most pools set ConnMaxIdleTime, and idle connections are common. + // Checking this first allows us to fail fast without expensive syscalls. + if p.cfg.ConnMaxIdleTime > 0 { + if now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { + return false // Connection has been idle too long + } } - cn.SetUsedAt(now) - // Check basic connection health - // Use GetNetConn() to safely access netConn and avoid data races + // Only run this if the cheap checks passed. if err := connCheck(cn.getNetConn()); err != nil { // If there's unexpected data, it might be push notifications (RESP3) - // However, push notification processing is now handled by the client - // before WithReader to ensure proper context is available to handlers if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead { - // we know that there is something in the buffer, so peek at the next reply type without - // the potential to block + // Peek at the reply type to check if it's a push notification if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush { // For RESP3 connections with push notifications, we allow some buffered data // The client will process these notifications before using the connection - internal.Logger.Printf(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID()) - return true // Connection is healthy, client will handle notifications + internal.Logger.Printf( + context.Background(), + "push: conn[%d] has buffered data, likely push notifications - will be processed by client", + cn.GetID(), + ) + + // Update timestamp for healthy connection + cn.SetUsedAt(now) + + // Connection is healthy, client will handle notifications + return true } - return false // Unexpected data, not push notifications, connection is unhealthy - } else { + // Not a push notification - treat as unhealthy return false } + // Connection failed health check + return false } + + // Only update UsedAt if connection is healthy (avoids unnecessary atomic store) + cn.SetUsedAt(now) return true } diff --git a/internal/pool/pubsub.go b/internal/pool/pubsub.go index ed87d1bb..5b29659e 100644 --- a/internal/pool/pubsub.go +++ b/internal/pool/pubsub.go @@ -24,7 +24,7 @@ type PubSubPool struct { stats PubSubStats } -// PubSubPool implements a pool for PubSub connections. +// NewPubSubPool implements a pool for PubSub connections. // It intentionally does not implement the Pooler interface func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool { return &PubSubPool{ diff --git a/internal/proto/peek_push_notification_test.go b/internal/proto/peek_push_notification_test.go index 58a794b8..49186759 100644 --- a/internal/proto/peek_push_notification_test.go +++ b/internal/proto/peek_push_notification_test.go @@ -370,9 +370,17 @@ func BenchmarkPeekPushNotificationName(b *testing.B) { buf := createValidPushNotification(tc.notification, "data") data := buf.Bytes() + // Reuse both bytes.Reader and proto.Reader to avoid allocations + bytesReader := bytes.NewReader(data) + reader := NewReader(bytesReader) + b.ResetTimer() + b.ReportAllocs() for i := 0; i < b.N; i++ { - reader := NewReader(bytes.NewReader(data)) + // Reset the bytes.Reader to the beginning without allocating + bytesReader.Reset(data) + // Reset the proto.Reader to reuse the bufio buffer + reader.Reset(bytesReader) _, err := reader.PeekPushNotificationName() if err != nil { b.Errorf("PeekPushNotificationName should not error: %v", err) diff --git a/internal/semaphore.go b/internal/semaphore.go new file mode 100644 index 00000000..091b6635 --- /dev/null +++ b/internal/semaphore.go @@ -0,0 +1,161 @@ +package internal + +import ( + "context" + "sync" + "sync/atomic" + "time" +) + +var semTimers = sync.Pool{ + New: func() interface{} { + t := time.NewTimer(time.Hour) + t.Stop() + return t + }, +} + +// FastSemaphore is a counting semaphore implementation using atomic operations. +// It's optimized for the fast path (no blocking) while still supporting timeouts and context cancellation. +// +// Performance characteristics: +// - Fast path (no blocking): Single atomic CAS operation +// - Slow path (blocking): Falls back to channel-based waiting +// - Release: Single atomic decrement + optional channel notification +// +// This is significantly faster than a pure channel-based semaphore because: +// 1. The fast path avoids channel operations entirely (no scheduler involvement) +// 2. Atomic operations are much cheaper than channel send/receive +type FastSemaphore struct { + // Current number of acquired tokens (atomic) + count atomic.Int32 + + // Maximum number of tokens (capacity) + max int32 + + // Channel for blocking waiters + // Only used when fast path fails (semaphore is full) + waitCh chan struct{} +} + +// NewFastSemaphore creates a new fast semaphore with the given capacity. +func NewFastSemaphore(capacity int32) *FastSemaphore { + return &FastSemaphore{ + max: capacity, + waitCh: make(chan struct{}, capacity), + } +} + +// TryAcquire attempts to acquire a token without blocking. +// Returns true if successful, false if the semaphore is full. +// +// This is the fast path - just a single CAS operation. +func (s *FastSemaphore) TryAcquire() bool { + for { + current := s.count.Load() + if current >= s.max { + return false // Semaphore is full + } + if s.count.CompareAndSwap(current, current+1) { + return true // Successfully acquired + } + // CAS failed due to concurrent modification, retry + } +} + +// Acquire acquires a token, blocking if necessary until one is available or the context is cancelled. +// Returns an error if the context is cancelled or the timeout expires. +// Returns timeoutErr when the timeout expires. +// +// Performance optimization: +// 1. First try fast path (no blocking) +// 2. If that fails, fall back to channel-based waiting +func (s *FastSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error { + // Fast path: try to acquire without blocking + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // Try fast acquire first + if s.TryAcquire() { + return nil + } + + // Fast path failed, need to wait + // Use timer pool to avoid allocation + timer := semTimers.Get().(*time.Timer) + defer semTimers.Put(timer) + timer.Reset(timeout) + + start := time.Now() + + for { + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return ctx.Err() + + case <-s.waitCh: + // Someone released a token, try to acquire it + if s.TryAcquire() { + if !timer.Stop() { + <-timer.C + } + return nil + } + // Failed to acquire (race with another goroutine), continue waiting + + case <-timer.C: + return timeoutErr + } + + // Periodically check if we can acquire (handles race conditions) + if time.Since(start) > timeout { + return timeoutErr + } + } +} + +// AcquireBlocking acquires a token, blocking indefinitely until one is available. +// This is useful for cases where you don't need timeout or context cancellation. +// Returns immediately if a token is available (fast path). +func (s *FastSemaphore) AcquireBlocking() { + // Try fast path first + if s.TryAcquire() { + return + } + + // Slow path: wait for a token + for { + <-s.waitCh + if s.TryAcquire() { + return + } + // Failed to acquire (race with another goroutine), continue waiting + } +} + +// Release releases a token back to the semaphore. +// This wakes up one waiting goroutine if any are blocked. +func (s *FastSemaphore) Release() { + s.count.Add(-1) + + // Try to wake up a waiter (non-blocking) + // If no one is waiting, this is a no-op + select { + case s.waitCh <- struct{}{}: + // Successfully notified a waiter + default: + // No waiters, that's fine + } +} + +// Len returns the current number of acquired tokens. +// Used by tests to check semaphore state. +func (s *FastSemaphore) Len() int32 { + return s.count.Load() +} diff --git a/redis_test.go b/redis_test.go index 0906d420..5cce3f25 100644 --- a/redis_test.go +++ b/redis_test.go @@ -323,6 +323,7 @@ var _ = Describe("Client", func() { cn, err = client.Pool().Get(context.Background()) Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) + Expect(cn.UsedAt().UnixNano()).To(BeNumerically(">", createdAt.UnixNano())) Expect(cn.UsedAt().After(createdAt)).To(BeTrue()) }) From 9448059c01adc988a840b4e0de80b42c620a46f6 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 28 Oct 2025 00:39:13 +0200 Subject: [PATCH 02/20] initConn sets IDLE state - Handle unexpected conn state changes --- internal/pool/conn.go | 2 ++ internal/pool/pool.go | 20 ++++++++++++++++---- maintnotifications/handoff_worker.go | 2 ++ redis.go | 6 ++++++ 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 56be7098..5ba3db41 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -603,6 +603,7 @@ func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) err // Execute initialization // NOTE: ExecuteInitConn (via baseClient.initConn) will transition to IDLE on success // or CLOSED on failure. We don't need to do it here. + // NOTE: Initconn returns conn in IDLE state initErr := cn.ExecuteInitConn(ctx) if initErr != nil { // ExecuteInitConn already transitioned to CLOSED, just return the error @@ -745,6 +746,7 @@ func (cn *Conn) ClearHandoffState() { // Mark connection as usable again // Use state machine directly instead of deprecated SetUsable + // probably done by initConn cn.stateMachine.Transition(StateIdle) } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 2dedca05..5df4962b 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -684,11 +684,22 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { // Using inline Release() method for better performance (avoids pointer dereference) transitionedToIdle := cn.Release() + // Handle unexpected state changes if !transitionedToIdle { // Fast path failed - hook might have changed state (e.g., to UNUSABLE for handoff) // Keep the state set by the hook and pool the connection anyway currentState := cn.GetStateMachine().GetState() - internal.Logger.Printf(ctx, "Connection state changed by hook to %v, pooling as-is", currentState) + switch currentState { + case StateUnusable: + // expected state, don't log it + case StateClosed: + internal.Logger.Printf(ctx, "Unexpected conn[%d] state changed by hook to %v, closing it", cn.GetID(), currentState) + shouldCloseConn = true + p.removeConnWithLock(cn) + default: + // Pool as-is + internal.Logger.Printf(ctx, "Unexpected conn[%d] state changed by hook to %v, pooling as-is", cn.GetID(), currentState) + } } // unusable conns are expected to become usable at some point (background process is reconnecting them) @@ -704,15 +715,16 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { p.idleConns = append([]*Conn{cn}, p.idleConns...) p.connsMu.Unlock() } - } else { + p.idleConnsLen.Add(1) + } else if !shouldCloseConn { p.connsMu.Lock() p.idleConns = append(p.idleConns, cn) p.connsMu.Unlock() + p.idleConnsLen.Add(1) } - p.idleConnsLen.Add(1) } else { - p.removeConnWithLock(cn) shouldCloseConn = true + p.removeConnWithLock(cn) } if freeTurn { diff --git a/maintnotifications/handoff_worker.go b/maintnotifications/handoff_worker.go index 2fdeec16..5b60e39b 100644 --- a/maintnotifications/handoff_worker.go +++ b/maintnotifications/handoff_worker.go @@ -456,6 +456,8 @@ func (hwm *handoffWorkerManager) performHandoffInternal( // - set the connection as usable again // - clear the handoff state (shouldHandoff, endpoint, seqID) // - reset the handoff retries to 0 + // Note: Theoretically there may be a short window where the connection is in the pool + // and IDLE (initConn completed) but still has handoff state set. conn.ClearHandoffState() internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint)) diff --git a/redis.go b/redis.go index a355531c..1f4b0224 100644 --- a/redis.go +++ b/redis.go @@ -298,6 +298,12 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return nil, err } + // initConn will transition to IDLE state, so we need to acquire it + // before returning it to the user. + if !cn.TryAcquire() { + return nil, fmt.Errorf("redis: connection is not usable") + } + return cn, nil } From d5db5340cbcc85997a861e5fb3d05e2750932b20 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 28 Oct 2025 12:34:09 +0200 Subject: [PATCH 03/20] fix precision of time cache and usedAt --- internal/pool/conn.go | 11 +- internal/pool/conn_used_at_test.go | 257 +++++++++++++++++++++++++++++ 2 files changed, 263 insertions(+), 5 deletions(-) create mode 100644 internal/pool/conn_used_at_test.go diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 5ba3db41..0d18e274 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -131,8 +131,9 @@ func NewConn(netConn net.Conn) *Conn { } func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn { + now := time.Now() cn := &Conn{ - createdAt: time.Now(), + createdAt: now, id: generateConnID(), // Generate unique ID for this connection stateMachine: NewConnStateMachine(), } @@ -154,7 +155,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) cn.wr = proto.NewWriter(cn.bw) - cn.SetUsedAt(time.Now()) + cn.SetUsedAt(now) // Initialize handoff state atomically initialHandoffState := &HandoffState{ ShouldHandoff: false, @@ -166,12 +167,12 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con } func (cn *Conn) UsedAt() time.Time { - unix := atomic.LoadInt64(&cn.usedAt) - return time.Unix(unix, 0) + unixNano := atomic.LoadInt64(&cn.usedAt) + return time.Unix(0, unixNano) } func (cn *Conn) SetUsedAt(tm time.Time) { - atomic.StoreInt64(&cn.usedAt, tm.Unix()) + atomic.StoreInt64(&cn.usedAt, tm.UnixNano()) } // Backward-compatible wrapper methods for state machine diff --git a/internal/pool/conn_used_at_test.go b/internal/pool/conn_used_at_test.go new file mode 100644 index 00000000..74b447f2 --- /dev/null +++ b/internal/pool/conn_used_at_test.go @@ -0,0 +1,257 @@ +package pool + +import ( + "context" + "net" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/proto" +) + +// TestConn_UsedAtUpdatedOnRead verifies that usedAt is updated when reading from connection +func TestConn_UsedAtUpdatedOnRead(t *testing.T) { + // Create a mock connection + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + cn := NewConn(client) + defer cn.Close() + + // Get initial usedAt time + initialUsedAt := cn.UsedAt() + + // Wait at least 100ms to ensure time difference (usedAt has ~50ms precision from cached time) + time.Sleep(100 * time.Millisecond) + + // Simulate a read operation by calling WithReader + ctx := context.Background() + err := cn.WithReader(ctx, time.Second, func(rd *proto.Reader) error { + // Don't actually read anything, just trigger the deadline update + return nil + }) + + if err != nil { + t.Fatalf("WithReader failed: %v", err) + } + + // Get updated usedAt time + updatedUsedAt := cn.UsedAt() + + // Verify that usedAt was updated + if !updatedUsedAt.After(initialUsedAt) { + t.Errorf("Expected usedAt to be updated after read. Initial: %v, Updated: %v", + initialUsedAt, updatedUsedAt) + } + + // Verify the difference is reasonable (should be around 100ms, accounting for ~50ms cache precision) + diff := updatedUsedAt.Sub(initialUsedAt) + if diff < 50*time.Millisecond || diff > 200*time.Millisecond { + t.Errorf("Expected usedAt difference to be around 100ms (±50ms for cache), got %v", diff) + } +} + +// TestConn_UsedAtUpdatedOnWrite verifies that usedAt is updated when writing to connection +func TestConn_UsedAtUpdatedOnWrite(t *testing.T) { + // Create a mock connection + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + cn := NewConn(client) + defer cn.Close() + + // Get initial usedAt time + initialUsedAt := cn.UsedAt() + + // Wait at least 100ms to ensure time difference (usedAt has ~50ms precision from cached time) + time.Sleep(100 * time.Millisecond) + + // Simulate a write operation by calling WithWriter + ctx := context.Background() + err := cn.WithWriter(ctx, time.Second, func(wr *proto.Writer) error { + // Don't actually write anything, just trigger the deadline update + return nil + }) + + if err != nil { + t.Fatalf("WithWriter failed: %v", err) + } + + // Get updated usedAt time + updatedUsedAt := cn.UsedAt() + + // Verify that usedAt was updated + if !updatedUsedAt.After(initialUsedAt) { + t.Errorf("Expected usedAt to be updated after write. Initial: %v, Updated: %v", + initialUsedAt, updatedUsedAt) + } + + // Verify the difference is reasonable (should be around 100ms, accounting for ~50ms cache precision) + diff := updatedUsedAt.Sub(initialUsedAt) + if diff < 50*time.Millisecond || diff > 200*time.Millisecond { + t.Errorf("Expected usedAt difference to be around 100ms (±50ms for cache), got %v", diff) + } +} + +// TestConn_UsedAtUpdatedOnMultipleOperations verifies that usedAt is updated on each operation +func TestConn_UsedAtUpdatedOnMultipleOperations(t *testing.T) { + // Create a mock connection + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + cn := NewConn(client) + defer cn.Close() + + ctx := context.Background() + var previousUsedAt time.Time + + // Perform multiple operations and verify usedAt is updated each time + // Note: usedAt has ~50ms precision from cached time + for i := 0; i < 5; i++ { + currentUsedAt := cn.UsedAt() + + if i > 0 { + // Verify usedAt was updated from previous iteration + if !currentUsedAt.After(previousUsedAt) { + t.Errorf("Iteration %d: Expected usedAt to be updated. Previous: %v, Current: %v", + i, previousUsedAt, currentUsedAt) + } + } + + previousUsedAt = currentUsedAt + + // Wait at least 100ms (accounting for ~50ms cache precision) + time.Sleep(100 * time.Millisecond) + + // Perform a read operation + err := cn.WithReader(ctx, time.Second, func(rd *proto.Reader) error { + return nil + }) + if err != nil { + t.Fatalf("Iteration %d: WithReader failed: %v", i, err) + } + } + + // Verify final usedAt is significantly later than initial + finalUsedAt := cn.UsedAt() + if !finalUsedAt.After(previousUsedAt) { + t.Errorf("Expected final usedAt to be updated. Previous: %v, Final: %v", + previousUsedAt, finalUsedAt) + } +} + +// TestConn_UsedAtNotUpdatedWithoutOperation verifies that usedAt is NOT updated without operations +func TestConn_UsedAtNotUpdatedWithoutOperation(t *testing.T) { + // Create a mock connection + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + cn := NewConn(client) + defer cn.Close() + + // Get initial usedAt time + initialUsedAt := cn.UsedAt() + + // Wait without performing any operations + time.Sleep(100 * time.Millisecond) + + // Get usedAt time again + currentUsedAt := cn.UsedAt() + + // Verify that usedAt was NOT updated (should be the same) + if !currentUsedAt.Equal(initialUsedAt) { + t.Errorf("Expected usedAt to remain unchanged without operations. Initial: %v, Current: %v", + initialUsedAt, currentUsedAt) + } +} + +// TestConn_UsedAtConcurrentUpdates verifies that usedAt updates are thread-safe +func TestConn_UsedAtConcurrentUpdates(t *testing.T) { + // Create a mock connection + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + cn := NewConn(client) + defer cn.Close() + + ctx := context.Background() + const numGoroutines = 10 + const numIterations = 10 + + // Launch multiple goroutines that perform operations concurrently + done := make(chan bool, numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + for j := 0; j < numIterations; j++ { + // Alternate between read and write operations + if j%2 == 0 { + _ = cn.WithReader(ctx, time.Second, func(rd *proto.Reader) error { + return nil + }) + } else { + _ = cn.WithWriter(ctx, time.Second, func(wr *proto.Writer) error { + return nil + }) + } + time.Sleep(time.Millisecond) + } + done <- true + }() + } + + // Wait for all goroutines to complete + for i := 0; i < numGoroutines; i++ { + <-done + } + + // Verify that usedAt was updated (should be recent) + usedAt := cn.UsedAt() + timeSinceUsed := time.Since(usedAt) + + // Should be very recent (within last second) + if timeSinceUsed > time.Second { + t.Errorf("Expected usedAt to be recent, but it was %v ago", timeSinceUsed) + } +} + +// TestConn_UsedAtPrecision verifies that usedAt has 50ms precision (not nanosecond) +func TestConn_UsedAtPrecision(t *testing.T) { + // Create a mock connection + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + cn := NewConn(client) + defer cn.Close() + + ctx := context.Background() + + // Perform an operation + err := cn.WithReader(ctx, time.Second, func(rd *proto.Reader) error { + return nil + }) + if err != nil { + t.Fatalf("WithReader failed: %v", err) + } + + // Get usedAt time + usedAt := cn.UsedAt() + + // Verify that usedAt has nanosecond precision (from the cached time which updates every 50ms) + // The value should be reasonable (not year 1970 or something) + if usedAt.Year() < 2020 { + t.Errorf("Expected usedAt to be a recent time, got %v", usedAt) + } + + // The nanoseconds might be non-zero depending on when the cache was updated + // We just verify the time is stored with full precision (not truncated to seconds) + initialNanos := usedAt.UnixNano() + if initialNanos == 0 { + t.Error("Expected usedAt to have nanosecond precision, got 0") + } +} From dcd8f9cf7f343f270c80117195577530ac1f58ab Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 28 Oct 2025 15:43:58 +0200 Subject: [PATCH 04/20] allow e2e tests to run longer --- maintnotifications/e2e/main_test.go | 3 +++ maintnotifications/e2e/scenario_endpoint_types_test.go | 2 +- maintnotifications/e2e/scenario_push_notifications_test.go | 5 +++-- maintnotifications/e2e/scenario_stress_test.go | 2 +- maintnotifications/e2e/scenario_tls_configs_test.go | 2 +- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/maintnotifications/e2e/main_test.go b/maintnotifications/e2e/main_test.go index 5b1d6c94..ba24303d 100644 --- a/maintnotifications/e2e/main_test.go +++ b/maintnotifications/e2e/main_test.go @@ -4,6 +4,7 @@ import ( "log" "os" "testing" + "time" "github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9/logging" @@ -12,6 +13,8 @@ import ( // Global log collector var logCollector *TestLogCollector +const defaultTestTimeout = 30 * time.Minute + // Global fault injector client var faultInjector *FaultInjectorClient diff --git a/maintnotifications/e2e/scenario_endpoint_types_test.go b/maintnotifications/e2e/scenario_endpoint_types_test.go index 57bd9439..90115ecb 100644 --- a/maintnotifications/e2e/scenario_endpoint_types_test.go +++ b/maintnotifications/e2e/scenario_endpoint_types_test.go @@ -21,7 +21,7 @@ func TestEndpointTypesPushNotifications(t *testing.T) { t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true") } - ctx, cancel := context.WithTimeout(context.Background(), 25*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() var dump = true diff --git a/maintnotifications/e2e/scenario_push_notifications_test.go b/maintnotifications/e2e/scenario_push_notifications_test.go index ffe74ace..ccc648b0 100644 --- a/maintnotifications/e2e/scenario_push_notifications_test.go +++ b/maintnotifications/e2e/scenario_push_notifications_test.go @@ -19,7 +19,7 @@ func TestPushNotifications(t *testing.T) { t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true") } - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() // Setup: Create fresh database and client factory for this test @@ -395,8 +395,9 @@ func TestPushNotifications(t *testing.T) { p("Executing commands and collecting logs for analysis... This will take 30 seconds...") go commandsRunner.FireCommandsUntilStop(ctx) - time.Sleep(30 * time.Second) + time.Sleep(time.Minute) commandsRunner.Stop() + time.Sleep(time.Minute) allLogsAnalysis := logCollector.GetAnalysis() trackerAnalysis := tracker.GetAnalysis() diff --git a/maintnotifications/e2e/scenario_stress_test.go b/maintnotifications/e2e/scenario_stress_test.go index 2eea1444..ec069d60 100644 --- a/maintnotifications/e2e/scenario_stress_test.go +++ b/maintnotifications/e2e/scenario_stress_test.go @@ -19,7 +19,7 @@ func TestStressPushNotifications(t *testing.T) { t.Skip("[STRESS][SKIP] Scenario tests require E2E_SCENARIO_TESTS=true") } - ctx, cancel := context.WithTimeout(context.Background(), 35*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Minute) defer cancel() // Setup: Create fresh database and client factory for this test diff --git a/maintnotifications/e2e/scenario_tls_configs_test.go b/maintnotifications/e2e/scenario_tls_configs_test.go index 243ea3b7..673fcacc 100644 --- a/maintnotifications/e2e/scenario_tls_configs_test.go +++ b/maintnotifications/e2e/scenario_tls_configs_test.go @@ -20,7 +20,7 @@ func ТestTLSConfigurationsPushNotifications(t *testing.T) { t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true") } - ctx, cancel := context.WithTimeout(context.Background(), 25*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() var dump = true From 0752aecdfba6d4a707c256df8676adeb41e6cbe3 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 28 Oct 2025 19:45:23 +0200 Subject: [PATCH 05/20] Fix broken initialization of idle connections --- example/{pubsub => maintnotifiations-pubsub}/go.mod | 0 example/{pubsub => maintnotifiations-pubsub}/go.sum | 0 example/{pubsub => maintnotifiations-pubsub}/main.go | 0 internal/pool/conn.go | 8 +++++--- maintnotifications/e2e/config_parser_test.go | 8 +++++++- 5 files changed, 12 insertions(+), 4 deletions(-) rename example/{pubsub => maintnotifiations-pubsub}/go.mod (100%) rename example/{pubsub => maintnotifiations-pubsub}/go.sum (100%) rename example/{pubsub => maintnotifiations-pubsub}/main.go (100%) diff --git a/example/pubsub/go.mod b/example/maintnotifiations-pubsub/go.mod similarity index 100% rename from example/pubsub/go.mod rename to example/maintnotifiations-pubsub/go.mod diff --git a/example/pubsub/go.sum b/example/maintnotifiations-pubsub/go.sum similarity index 100% rename from example/pubsub/go.sum rename to example/maintnotifiations-pubsub/go.sum diff --git a/example/pubsub/main.go b/example/maintnotifiations-pubsub/main.go similarity index 100% rename from example/pubsub/main.go rename to example/maintnotifiations-pubsub/main.go diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 0d18e274..7c195cd7 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -699,20 +699,22 @@ func (cn *Conn) GetStateMachine() *ConnStateMachine { // TryAcquire attempts to acquire the connection for use. // This is an optimized inline method for the hot path (Get operation). // -// It tries to transition from IDLE -> IN_USE or CREATED -> IN_USE. +// It tries to transition from IDLE -> IN_USE or CREATED -> CREATED. // Returns true if the connection was successfully acquired, false otherwise. +// The CREATED->CREATED is done so we can keep the state correct for later +// initialization of the connection in initConn. // // Performance: This is faster than calling GetStateMachine() + TryTransitionFast() // // NOTE: We directly access cn.stateMachine.state here instead of using the state machine's // methods. This breaks encapsulation but is necessary for performance. -// The IDLE->IN_USE and CREATED->IN_USE transitions don't need +// The IDLE->IN_USE and CREATED->CREATED transitions don't need // waiter notification, and benchmarks show 1-3% improvement. If the state machine ever // needs to notify waiters on these transitions, update this to use TryTransitionFast(). func (cn *Conn) TryAcquire() bool { // The || operator short-circuits, so only 1 CAS in the common case return cn.stateMachine.state.CompareAndSwap(uint32(StateIdle), uint32(StateInUse)) || - cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateInUse)) + cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateCreated)) } // Release releases the connection back to the pool. diff --git a/maintnotifications/e2e/config_parser_test.go b/maintnotifications/e2e/config_parser_test.go index 9c2d5373..735f6f05 100644 --- a/maintnotifications/e2e/config_parser_test.go +++ b/maintnotifications/e2e/config_parser_test.go @@ -319,6 +319,7 @@ func (cf *ClientFactory) Create(key string, options *CreateClientOptions) (redis } var client redis.UniversalClient + var opts interface{} // Determine if this is a cluster configuration if len(cf.config.Endpoints) > 1 || cf.isClusterEndpoint() { @@ -349,6 +350,7 @@ func (cf *ClientFactory) Create(key string, options *CreateClientOptions) (redis } } + opts = clusterOptions client = redis.NewClusterClient(clusterOptions) } else { // Create single client @@ -379,9 +381,14 @@ func (cf *ClientFactory) Create(key string, options *CreateClientOptions) (redis } } + opts = clientOptions client = redis.NewClient(clientOptions) } + if err := client.Ping(context.Background()).Err(); err != nil { + return nil, fmt.Errorf("failed to connect to Redis: %w\nOptions: %+v", err, opts) + } + // Store the client cf.clients[key] = client @@ -832,7 +839,6 @@ func (m *TestDatabaseManager) DeleteDatabase(ctx context.Context) error { return fmt.Errorf("failed to trigger database deletion: %w", err) } - // Wait for deletion to complete status, err := m.faultInjector.WaitForAction(ctx, resp.ActionID, WithMaxWaitTime(2*time.Minute), From 54281d687c320fdc0534c1754a1ca686bf1746d4 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 28 Oct 2025 23:32:27 +0200 Subject: [PATCH 06/20] optimize push notif --- internal/pool/conn.go | 27 +++++++++++++++------- internal/pool/conn_check.go | 3 +-- internal/pool/pool.go | 14 ++++++++--- redis.go | 34 +++++++++++++++++++++++---- redis_test.go | 46 +++++++++++++++++++++++++++++++++++++ 5 files changed, 107 insertions(+), 17 deletions(-) diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 7c195cd7..e504dfbc 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -18,9 +18,9 @@ import ( var noDeadline = time.Time{} -// Global time cache updated every 50ms by background goroutine. +// Global time cache updated every 100ms by background goroutine. // This avoids expensive time.Now() syscalls in hot paths like getEffectiveReadTimeout. -// Max staleness: 50ms, which is acceptable for timeout deadline checks (timeouts are typically 3-30 seconds). +// Max staleness: 100ms, which is acceptable for timeout deadline checks (timeouts are typically 3-30 seconds). var globalTimeCache struct { nowNs atomic.Int64 } @@ -31,7 +31,7 @@ func init() { // Start background updater go func() { - ticker := time.NewTicker(50 * time.Millisecond) + ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() for range ticker.C { @@ -41,12 +41,20 @@ func init() { } // getCachedTimeNs returns the current time in nanoseconds from the global cache. -// This is updated every 50ms by a background goroutine, avoiding expensive syscalls. -// Max staleness: 50ms. +// This is updated every 100ms by a background goroutine, avoiding expensive syscalls. +// Max staleness: 100ms. func getCachedTimeNs() int64 { return globalTimeCache.nowNs.Load() } +// GetCachedTimeNs returns the current time in nanoseconds from the global cache. +// This is updated every 100ms by a background goroutine, avoiding expensive syscalls. +// Max staleness: 100ms. +// Exported for use by other packages that need fast time access. +func GetCachedTimeNs() int64 { + return getCachedTimeNs() +} + // Global atomic counter for connection IDs var connIDCounter uint64 @@ -170,6 +178,9 @@ func (cn *Conn) UsedAt() time.Time { unixNano := atomic.LoadInt64(&cn.usedAt) return time.Unix(0, unixNano) } +func (cn *Conn) UsedAtNs() int64 { + return atomic.LoadInt64(&cn.usedAt) +} func (cn *Conn) SetUsedAt(tm time.Time) { atomic.StoreInt64(&cn.usedAt, tm.UnixNano()) @@ -488,7 +499,7 @@ func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Durati return time.Duration(readTimeoutNs) } - // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) + // Use cached time to avoid expensive syscall (max 100ms staleness is acceptable for timeout checks) nowNs := getCachedTimeNs() // Check if deadline has passed if nowNs < deadlineNs { @@ -522,7 +533,7 @@ func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Durat return time.Duration(writeTimeoutNs) } - // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) + // Use cached time to avoid expensive syscall (max 100ms staleness is acceptable for timeout checks) nowNs := getCachedTimeNs() // Check if deadline has passed if nowNs < deadlineNs { @@ -879,7 +890,7 @@ func (cn *Conn) MaybeHasData() bool { // deadline computes the effective deadline time based on context and timeout. // It updates the usedAt timestamp to now. -// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation). +// Uses cached time to avoid expensive syscall (max 100ms staleness is acceptable for deadline calculation). func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { // Use cached time for deadline calculation (called 2x per command: read + write) tm := time.Unix(0, getCachedTimeNs()) diff --git a/internal/pool/conn_check.go b/internal/pool/conn_check.go index 9e83dd83..cfdf5e5d 100644 --- a/internal/pool/conn_check.go +++ b/internal/pool/conn_check.go @@ -30,7 +30,7 @@ func connCheck(conn net.Conn) error { var sysErr error - if err := rawConn.Read(func(fd uintptr) bool { + if err := rawConn.Control(func(fd uintptr) { var buf [1]byte // Use MSG_PEEK to peek at data without consuming it n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK|syscall.MSG_DONTWAIT) @@ -45,7 +45,6 @@ func connCheck(conn net.Conn) error { default: sysErr = err } - return true }); err != nil { return err } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 5df4962b..dcb6213d 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -155,10 +155,18 @@ type ConnPool struct { var _ Pooler = (*ConnPool)(nil) func NewConnPool(opt *Options) *ConnPool { - p := &ConnPool{ - cfg: opt, + semSize := opt.PoolSize + if opt.MaxActiveConns > 0 && opt.MaxActiveConns < opt.PoolSize { + if opt.MaxActiveConns < opt.PoolSize { + opt.MaxActiveConns = opt.PoolSize + } + semSize = opt.MaxActiveConns + } + //semSize = opt.PoolSize - semaphore: internal.NewFastSemaphore(opt.PoolSize), + p := &ConnPool{ + cfg: opt, + semaphore: internal.NewFastSemaphore(semSize), conns: make(map[uint64]*Conn), idleConns: make([]*Conn, 0, opt.PoolSize), } diff --git a/redis.go b/redis.go index 1f4b0224..8cd961e5 100644 --- a/redis.go +++ b/redis.go @@ -1351,13 +1351,39 @@ func (c *Conn) TxPipeline() Pipeliner { // processPushNotifications processes all pending push notifications on a connection // This ensures that cluster topology changes are handled immediately before the connection is used -// This method should be called by the client before using WithReader for command execution +// This method should be called by the client before using WithWriter for command execution +// +// Performance optimization: Skip the expensive MaybeHasData() syscall if a health check +// was performed recently (within 5 seconds). The health check already verified the connection +// is healthy and checked for unexpected data (push notifications). func (c *baseClient) processPushNotifications(ctx context.Context, cn *pool.Conn) error { // Only process push notifications for RESP3 connections with a processor - // Also check if there is any data to read before processing - // Which is an optimization on UNIX systems where MaybeHasData is a syscall + if c.opt.Protocol != 3 || c.pushProcessor == nil { + return nil + } + + // Performance optimization: Skip MaybeHasData() syscall if health check was recent + // If the connection was health-checked within the last 5 seconds, we can skip the + // expensive syscall since the health check already verified no unexpected data. + // This is safe because: + // 1. Health check (connCheck) uses the same syscall (Recvfrom with MSG_PEEK) + // 2. If push notifications arrived, they would have been detected by health check + // 3. 5 seconds is short enough that connection state is still fresh + // 4. Push notifications will be processed by the next WithReader call + lastHealthCheckNs := cn.UsedAtNs() + if lastHealthCheckNs > 0 { + // Use pool's cached time to avoid expensive time.Now() syscall + nowNs := pool.GetCachedTimeNs() + if nowNs-lastHealthCheckNs < int64(5*time.Second) { + // Recent health check confirmed no unexpected data, skip the syscall + return nil + } + } + + // Check if there is any data to read before processing + // This is an optimization on UNIX systems where MaybeHasData is a syscall // On Windows, MaybeHasData always returns true, so this check is a no-op - if c.opt.Protocol != 3 || c.pushProcessor == nil || !cn.MaybeHasData() { + if !cn.MaybeHasData() { return nil } diff --git a/redis_test.go b/redis_test.go index 5cce3f25..9dd00f19 100644 --- a/redis_test.go +++ b/redis_test.go @@ -245,6 +245,52 @@ var _ = Describe("Client", func() { Expect(val).Should(HaveKeyWithValue("proto", int64(3))) }) + It("should initialize idle connections created by MinIdleConns", func() { + opt := redisOptions() + opt.MinIdleConns = 5 + opt.Password = "asdf" // Set password to require AUTH + opt.DB = 1 // Set DB to require SELECT + + db := redis.NewClient(opt) + defer func() { + Expect(db.Close()).NotTo(HaveOccurred()) + }() + + // Wait for minIdle connections to be created + time.Sleep(100 * time.Millisecond) + + // Verify that idle connections were created + stats := db.PoolStats() + Expect(stats.IdleConns).To(BeNumerically(">=", 5)) + + // Now use these connections - they should be properly initialized + // If they're not initialized, we'll get NOAUTH or WRONGDB errors + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + // Each goroutine performs multiple operations + for j := 0; j < 5; j++ { + key := fmt.Sprintf("test_key_%d_%d", id, j) + err := db.Set(ctx, key, "value", 0).Err() + Expect(err).NotTo(HaveOccurred()) + + val, err := db.Get(ctx, key).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("value")) + + err = db.Del(ctx, key).Err() + Expect(err).NotTo(HaveOccurred()) + } + }(i) + } + wg.Wait() + + // Verify no errors occurred + Expect(db.Ping(ctx).Err()).NotTo(HaveOccurred()) + }) + It("processes custom commands", func() { cmd := redis.NewCmd(ctx, "PING") _ = client.Process(ctx, cmd) From 600dfe258167df5bd3f3bad697bbf97c3f2da06c Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 29 Oct 2025 15:29:33 +0200 Subject: [PATCH 07/20] 100ms -> 50ms --- internal/pool/conn.go | 29 ++++++++++++++--------------- internal/pool/conn_used_at_test.go | 6 +++--- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/internal/pool/conn.go b/internal/pool/conn.go index e504dfbc..898a274d 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -18,9 +18,9 @@ import ( var noDeadline = time.Time{} -// Global time cache updated every 100ms by background goroutine. +// Global time cache updated every 50ms by background goroutine. // This avoids expensive time.Now() syscalls in hot paths like getEffectiveReadTimeout. -// Max staleness: 100ms, which is acceptable for timeout deadline checks (timeouts are typically 3-30 seconds). +// Max staleness: 50ms, which is acceptable for timeout deadline checks (timeouts are typically 3-30 seconds). var globalTimeCache struct { nowNs atomic.Int64 } @@ -31,7 +31,7 @@ func init() { // Start background updater go func() { - ticker := time.NewTicker(100 * time.Millisecond) + ticker := time.NewTicker(50 * time.Millisecond) defer ticker.Stop() for range ticker.C { @@ -41,15 +41,15 @@ func init() { } // getCachedTimeNs returns the current time in nanoseconds from the global cache. -// This is updated every 100ms by a background goroutine, avoiding expensive syscalls. -// Max staleness: 100ms. +// This is updated every 50ms by a background goroutine, avoiding expensive syscalls. +// Max staleness: 50ms. func getCachedTimeNs() int64 { return globalTimeCache.nowNs.Load() } // GetCachedTimeNs returns the current time in nanoseconds from the global cache. -// This is updated every 100ms by a background goroutine, avoiding expensive syscalls. -// Max staleness: 100ms. +// This is updated every 50ms by a background goroutine, avoiding expensive syscalls. +// Max staleness: 50ms. // Exported for use by other packages that need fast time access. func GetCachedTimeNs() int64 { return getCachedTimeNs() @@ -499,7 +499,7 @@ func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Durati return time.Duration(readTimeoutNs) } - // Use cached time to avoid expensive syscall (max 100ms staleness is acceptable for timeout checks) + // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) nowNs := getCachedTimeNs() // Check if deadline has passed if nowNs < deadlineNs { @@ -533,7 +533,7 @@ func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Durat return time.Duration(writeTimeoutNs) } - // Use cached time to avoid expensive syscall (max 100ms staleness is acceptable for timeout checks) + // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) nowNs := getCachedTimeNs() // Check if deadline has passed if nowNs < deadlineNs { @@ -725,7 +725,7 @@ func (cn *Conn) GetStateMachine() *ConnStateMachine { func (cn *Conn) TryAcquire() bool { // The || operator short-circuits, so only 1 CAS in the common case return cn.stateMachine.state.CompareAndSwap(uint32(StateIdle), uint32(StateInUse)) || - cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateCreated)) + cn.stateMachine.state.Load() == uint32(StateCreated) } // Release releases the connection back to the pool. @@ -829,19 +829,18 @@ func (cn *Conn) WithWriter( // Use relaxed timeout if set, otherwise use provided timeout effectiveTimeout := cn.getEffectiveWriteTimeout(timeout) - // Always set write deadline, even if getNetConn() returns nil - // This prevents write operations from hanging indefinitely + // Set write deadline on the connection if netConn := cn.getNetConn(); netConn != nil { if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil { return err } } else { - // If getNetConn() returns nil, we still need to respect the timeout - // Return an error to prevent indefinite blocking + // Connection is not available - return error return fmt.Errorf("redis: conn[%d] not available for write operation", cn.GetID()) } } + // Reset the buffered writer if needed, should not happen if cn.bw.Buffered() > 0 { if netConn := cn.getNetConn(); netConn != nil { cn.bw.Reset(netConn) @@ -890,7 +889,7 @@ func (cn *Conn) MaybeHasData() bool { // deadline computes the effective deadline time based on context and timeout. // It updates the usedAt timestamp to now. -// Uses cached time to avoid expensive syscall (max 100ms staleness is acceptable for deadline calculation). +// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation). func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { // Use cached time for deadline calculation (called 2x per command: read + write) tm := time.Unix(0, getCachedTimeNs()) diff --git a/internal/pool/conn_used_at_test.go b/internal/pool/conn_used_at_test.go index 74b447f2..67976f6d 100644 --- a/internal/pool/conn_used_at_test.go +++ b/internal/pool/conn_used_at_test.go @@ -22,7 +22,7 @@ func TestConn_UsedAtUpdatedOnRead(t *testing.T) { // Get initial usedAt time initialUsedAt := cn.UsedAt() - // Wait at least 100ms to ensure time difference (usedAt has ~50ms precision from cached time) + // Wait at least 50ms to ensure time difference (usedAt has ~50ms precision from cached time) time.Sleep(100 * time.Millisecond) // Simulate a read operation by calling WithReader @@ -45,10 +45,10 @@ func TestConn_UsedAtUpdatedOnRead(t *testing.T) { initialUsedAt, updatedUsedAt) } - // Verify the difference is reasonable (should be around 100ms, accounting for ~50ms cache precision) + // Verify the difference is reasonable (should be around 50ms, accounting for ~50ms cache precision) diff := updatedUsedAt.Sub(initialUsedAt) if diff < 50*time.Millisecond || diff > 200*time.Millisecond { - t.Errorf("Expected usedAt difference to be around 100ms (±50ms for cache), got %v", diff) + t.Errorf("Expected usedAt difference to be around 50ms (±50ms for cache), got %v", diff) } } From dccf01f396c9e512bcd5f938aade8ea8ecb7407b Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 29 Oct 2025 16:00:14 +0200 Subject: [PATCH 08/20] use correct timer for last health check --- internal/pool/conn.go | 24 +++++++++++++++++------- internal/pool/pool.go | 16 +++++++++------- redis.go | 4 +++- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 898a274d..ad846651 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -81,7 +81,8 @@ type Conn struct { // Connection identifier for unique tracking id uint64 - usedAt int64 // atomic + usedAt atomic.Int64 + lastPutAt atomic.Int64 // Lock-free netConn access using atomic.Value // Contains *atomicNetConn wrapper, accessed atomically for better performance @@ -175,15 +176,24 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con } func (cn *Conn) UsedAt() time.Time { - unixNano := atomic.LoadInt64(&cn.usedAt) - return time.Unix(0, unixNano) + return time.Unix(0, cn.usedAt.Load()) } -func (cn *Conn) UsedAtNs() int64 { - return atomic.LoadInt64(&cn.usedAt) +func (cn *Conn) SetUsedAt(tm time.Time) { + cn.usedAt.Store(tm.UnixNano()) } -func (cn *Conn) SetUsedAt(tm time.Time) { - atomic.StoreInt64(&cn.usedAt, tm.UnixNano()) +func (cn *Conn) UsedAtNs() int64 { + return cn.usedAt.Load() +} +func (cn *Conn) SetUsedAtNs(ns int64) { + cn.usedAt.Store(ns) +} + +func (cn *Conn) LastPutAtNs() int64 { + return cn.lastPutAt.Load() +} +func (cn *Conn) SetLastPutAtNs(ns int64) { + cn.lastPutAt.Store(ns) } // Backward-compatible wrapper methods for state machine diff --git a/internal/pool/pool.go b/internal/pool/pool.go index dcb6213d..be847b1d 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -461,7 +461,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } // Use cached time for health checks (max 50ms staleness is acceptable) - now := time.Unix(0, getCachedTimeNs()) + nowNs := getCachedTimeNs() attempts := 0 // Lock-free atomic read - no mutex overhead! @@ -487,7 +487,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { break } - if !p.isHealthyConn(cn, now) { + if !p.isHealthyConn(cn, nowNs) { _ = p.CloseConn(cn) continue } @@ -742,6 +742,8 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { if shouldCloseConn { _ = p.closeConn(cn) } + + cn.SetLastPutAtNs(getCachedTimeNs()) } func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { @@ -891,14 +893,14 @@ func (p *ConnPool) Close() error { return firstErr } -func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { +func (p *ConnPool) isHealthyConn(cn *Conn, nowNs int64) bool { // Performance optimization: check conditions from cheapest to most expensive, // and from most likely to fail to least likely to fail. // Only fails if ConnMaxLifetime is set AND connection is old. // Most pools don't set ConnMaxLifetime, so this rarely fails. if p.cfg.ConnMaxLifetime > 0 { - if cn.expiresAt.Before(now) { + if cn.expiresAt.UnixNano() < nowNs { return false // Connection has exceeded max lifetime } } @@ -906,7 +908,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { // Most pools set ConnMaxIdleTime, and idle connections are common. // Checking this first allows us to fail fast without expensive syscalls. if p.cfg.ConnMaxIdleTime > 0 { - if now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { + if nowNs-cn.UsedAtNs() >= int64(p.cfg.ConnMaxIdleTime) { return false // Connection has been idle too long } } @@ -926,7 +928,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { ) // Update timestamp for healthy connection - cn.SetUsedAt(now) + cn.SetUsedAtNs(nowNs) // Connection is healthy, client will handle notifications return true @@ -939,6 +941,6 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { } // Only update UsedAt if connection is healthy (avoids unnecessary atomic store) - cn.SetUsedAt(now) + cn.SetUsedAtNs(nowNs) return true } diff --git a/redis.go b/redis.go index 8cd961e5..ac97c2ca 100644 --- a/redis.go +++ b/redis.go @@ -1366,11 +1366,13 @@ func (c *baseClient) processPushNotifications(ctx context.Context, cn *pool.Conn // If the connection was health-checked within the last 5 seconds, we can skip the // expensive syscall since the health check already verified no unexpected data. // This is safe because: + // 0. lastHealthCheckNs is set in pool/conn.go:putConn() after a successful health check // 1. Health check (connCheck) uses the same syscall (Recvfrom with MSG_PEEK) // 2. If push notifications arrived, they would have been detected by health check // 3. 5 seconds is short enough that connection state is still fresh // 4. Push notifications will be processed by the next WithReader call - lastHealthCheckNs := cn.UsedAtNs() + // used it is set on getConn, so we should use another timer (lastPutAt?) + lastHealthCheckNs := cn.LastPutAtNs() if lastHealthCheckNs > 0 { // Use pool's cached time to avoid expensive time.Now() syscall nowNs := pool.GetCachedTimeNs() From 7201275eb523b7dadc82edc5715d06de6610d732 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 29 Oct 2025 16:06:35 +0200 Subject: [PATCH 09/20] verify pass auth on conn creation --- redis_test.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/redis_test.go b/redis_test.go index 9dd00f19..5e062125 100644 --- a/redis_test.go +++ b/redis_test.go @@ -247,9 +247,19 @@ var _ = Describe("Client", func() { It("should initialize idle connections created by MinIdleConns", func() { opt := redisOptions() + passwrd := "asdf" + db0 := redis.NewClient(opt) + // set password + err := db0.Do(ctx, "CONFIG", "SET", "requirepass", passwrd).Err() + Expect(err).NotTo(HaveOccurred()) + defer func() { + err = db0.Do(ctx, "CONFIG", "SET", "requirepass", "").Err() + Expect(err).NotTo(HaveOccurred()) + Expect(db0.Close()).NotTo(HaveOccurred()) + }() opt.MinIdleConns = 5 - opt.Password = "asdf" // Set password to require AUTH - opt.DB = 1 // Set DB to require SELECT + opt.Password = passwrd + opt.DB = 1 // Set DB to require SELECT db := redis.NewClient(opt) defer func() { From 7f48276660e1df02cee1b8a038fcbcef2ef87ca1 Mon Sep 17 00:00:00 2001 From: pvragov Date: Wed, 29 Oct 2025 21:09:12 +0700 Subject: [PATCH 10/20] feat(otel): Add a 'error_type' metrics attribute to separate context errors (#3566) Co-authored-by: vragov_pf --- extra/redisotel/metrics.go | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/extra/redisotel/metrics.go b/extra/redisotel/metrics.go index 7fe55452..77aa5d14 100644 --- a/extra/redisotel/metrics.go +++ b/extra/redisotel/metrics.go @@ -2,6 +2,7 @@ package redisotel import ( "context" + "errors" "fmt" "net" "sync" @@ -271,9 +272,10 @@ func (mh *metricsHook) DialHook(hook redis.DialHook) redis.DialHook { dur := time.Since(start) - attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+1) + attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+2) attrs = append(attrs, mh.attrs...) attrs = append(attrs, statusAttr(err)) + attrs = append(attrs, errorTypeAttribute(err)) mh.createTime.Record(ctx, milliseconds(dur), metric.WithAttributeSet(attribute.NewSet(attrs...))) return conn, err @@ -288,10 +290,11 @@ func (mh *metricsHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { dur := time.Since(start) - attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+2) + attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+3) attrs = append(attrs, mh.attrs...) attrs = append(attrs, attribute.String("type", "command")) attrs = append(attrs, statusAttr(err)) + attrs = append(attrs, errorTypeAttribute(err)) mh.useTime.Record(ctx, milliseconds(dur), metric.WithAttributeSet(attribute.NewSet(attrs...))) @@ -309,10 +312,11 @@ func (mh *metricsHook) ProcessPipelineHook( dur := time.Since(start) - attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+2) + attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+3) attrs = append(attrs, mh.attrs...) attrs = append(attrs, attribute.String("type", "pipeline")) attrs = append(attrs, statusAttr(err)) + attrs = append(attrs, errorTypeAttribute(err)) mh.useTime.Record(ctx, milliseconds(dur), metric.WithAttributeSet(attribute.NewSet(attrs...))) @@ -330,3 +334,16 @@ func statusAttr(err error) attribute.KeyValue { } return attribute.String("status", "ok") } + +func errorTypeAttribute(err error) attribute.KeyValue { + switch { + case err == nil: + return attribute.String("error_type", "none") + case errors.Is(err, context.Canceled): + return attribute.String("error_type", "context_canceled") + case errors.Is(err, context.DeadlineExceeded): + return attribute.String("error_type", "context_timeout") + default: + return attribute.String("error_type", "other") + } +} From 62eecaa75ef13b1416cabd0374cbf685e2f59c7e Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 29 Oct 2025 16:11:27 +0200 Subject: [PATCH 11/20] fix assertion --- internal/pool/conn_used_at_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/pool/conn_used_at_test.go b/internal/pool/conn_used_at_test.go index 67976f6d..8461d6d1 100644 --- a/internal/pool/conn_used_at_test.go +++ b/internal/pool/conn_used_at_test.go @@ -90,8 +90,9 @@ func TestConn_UsedAtUpdatedOnWrite(t *testing.T) { // Verify the difference is reasonable (should be around 100ms, accounting for ~50ms cache precision) diff := updatedUsedAt.Sub(initialUsedAt) - if diff < 50*time.Millisecond || diff > 200*time.Millisecond { - t.Errorf("Expected usedAt difference to be around 100ms (±50ms for cache), got %v", diff) + + if diff > 100*time.Millisecond { + t.Errorf("Expected usedAt difference to be no more than 100ms (±50ms for cache), got %v", diff) } } From 43eeae70abe6b10be6414a4de825fbeba78ee790 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 29 Oct 2025 16:19:04 +0200 Subject: [PATCH 12/20] fix unsafe test --- internal/pool/buffer_size_test.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/internal/pool/buffer_size_test.go b/internal/pool/buffer_size_test.go index bffe495c..525e96db 100644 --- a/internal/pool/buffer_size_test.go +++ b/internal/pool/buffer_size_test.go @@ -3,6 +3,7 @@ package pool_test import ( "bufio" "context" + "sync/atomic" "unsafe" . "github.com/bsm/ginkgo/v2" @@ -129,9 +130,10 @@ var _ = Describe("Buffer Size Configuration", func() { // cause runtime panics or incorrect memory access due to invalid pointer dereferencing. func getWriterBufSizeUnsafe(cn *pool.Conn) int { cnPtr := (*struct { - id uint64 // First field in pool.Conn - usedAt int64 // Second field (atomic) - netConnAtomic interface{} // atomic.Value (interface{} has same size) + id uint64 // First field in pool.Conn + usedAt atomic.Int64 // Second field (atomic) + lastPutAt atomic.Int64 // Third field (atomic) + netConnAtomic interface{} // atomic.Value (interface{} has same size) rd *proto.Reader bw *bufio.Writer wr *proto.Writer @@ -155,9 +157,10 @@ func getWriterBufSizeUnsafe(cn *pool.Conn) int { func getReaderBufSizeUnsafe(cn *pool.Conn) int { cnPtr := (*struct { - id uint64 // First field in pool.Conn - usedAt int64 // Second field (atomic) - netConnAtomic interface{} // atomic.Value (interface{} has same size) + id uint64 // First field in pool.Conn + usedAt atomic.Int64 // Second field (atomic) + lastPutAt atomic.Int64 // Third field (atomic) + netConnAtomic interface{} // atomic.Value (interface{} has same size) rd *proto.Reader bw *bufio.Writer wr *proto.Writer From 2965e3d35c57a8e645101c6c27a0418d98c2c531 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 29 Oct 2025 16:21:21 +0200 Subject: [PATCH 13/20] fix benchmark test --- internal/pool/bench_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index fc37b821..9b1fc57d 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -83,14 +83,14 @@ func BenchmarkPoolGetRemove(b *testing.B) { }) b.ResetTimer() - + rmvErr := errors.New("Bench test remove") b.RunParallel(func(pb *testing.PB) { for pb.Next() { cn, err := connPool.Get(ctx) if err != nil { b.Fatal(err) } - connPool.Remove(ctx, cn, errors.New("Bench test remove")) + connPool.Remove(ctx, cn, rmvErr) } }) }) From 59da35ba2d357056180c115a7b39880fb9a75638 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 29 Oct 2025 16:23:21 +0200 Subject: [PATCH 14/20] improve remove conn --- internal/pool/pool.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index be847b1d..25a7de33 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -800,8 +800,7 @@ func (p *ConnPool) removeConn(cn *Conn) { p.poolSize.Add(-1) // this can be idle conn for idx, ic := range p.idleConns { - if ic.GetID() == cid { - internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid) + if ic == cn { p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...) p.idleConnsLen.Add(-1) break From 09a2f07ac3846995ba3cb54d9a229994eacea439 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 29 Oct 2025 16:34:01 +0200 Subject: [PATCH 15/20] re doesn't support requirepass --- redis_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis_test.go b/redis_test.go index 5e062125..bc0db6ad 100644 --- a/redis_test.go +++ b/redis_test.go @@ -245,7 +245,7 @@ var _ = Describe("Client", func() { Expect(val).Should(HaveKeyWithValue("proto", int64(3))) }) - It("should initialize idle connections created by MinIdleConns", func() { + It("should initialize idle connections created by MinIdleConns", Label("NonRedisEnterprise"), func() { opt := redisOptions() passwrd := "asdf" db0 := redis.NewClient(opt) From fc2da240f8ebaed3a59df49ea7a75461babf4b6e Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 30 Oct 2025 16:35:44 +0200 Subject: [PATCH 16/20] wait more in e2e test --- maintnotifications/e2e/command_runner_test.go | 5 +++++ .../e2e/scenario_push_notifications_test.go | 18 ++++++++---------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/maintnotifications/e2e/command_runner_test.go b/maintnotifications/e2e/command_runner_test.go index b80a434b..27c19c3a 100644 --- a/maintnotifications/e2e/command_runner_test.go +++ b/maintnotifications/e2e/command_runner_test.go @@ -20,6 +20,7 @@ type CommandRunnerStats struct { // CommandRunner provides utilities for running commands during tests type CommandRunner struct { + executing atomic.Bool client redis.UniversalClient stopCh chan struct{} operationCount atomic.Int64 @@ -56,6 +57,10 @@ func (cr *CommandRunner) Close() { // FireCommandsUntilStop runs commands continuously until stop signal func (cr *CommandRunner) FireCommandsUntilStop(ctx context.Context) { + if !cr.executing.CompareAndSwap(false, true) { + return + } + defer cr.executing.Store(false) fmt.Printf("[CR] Starting command runner...\n") defer fmt.Printf("[CR] Command runner stopped\n") // High frequency for timeout testing diff --git a/maintnotifications/e2e/scenario_push_notifications_test.go b/maintnotifications/e2e/scenario_push_notifications_test.go index ccc648b0..28677fa5 100644 --- a/maintnotifications/e2e/scenario_push_notifications_test.go +++ b/maintnotifications/e2e/scenario_push_notifications_test.go @@ -297,12 +297,6 @@ func TestPushNotifications(t *testing.T) { // once moving is received, start a second client commands runner p("Starting commands on second client") go commandsRunner2.FireCommandsUntilStop(ctx) - defer func() { - // stop the second runner - commandsRunner2.Stop() - // destroy the second client - factory.Destroy("push-notification-client-2") - }() p("Waiting for MOVING notification on second client") matchNotif, fnd := tracker2.FindOrWaitForNotification("MOVING", 3*time.Minute) @@ -393,11 +387,15 @@ func TestPushNotifications(t *testing.T) { p("MOVING notification test completed successfully") - p("Executing commands and collecting logs for analysis... This will take 30 seconds...") + p("Executing commands and collecting logs for analysis... ") go commandsRunner.FireCommandsUntilStop(ctx) - time.Sleep(time.Minute) + go commandsRunner2.FireCommandsUntilStop(ctx) + go commandsRunner3.FireCommandsUntilStop(ctx) + time.Sleep(30 * time.Second) commandsRunner.Stop() - time.Sleep(time.Minute) + commandsRunner2.Stop() + commandsRunner3.Stop() + time.Sleep(5 * time.Minute) allLogsAnalysis := logCollector.GetAnalysis() trackerAnalysis := tracker.GetAnalysis() @@ -473,7 +471,7 @@ func TestPushNotifications(t *testing.T) { // unrelaxed (and relaxed) after moving wont be tracked by the hook, so we have to exclude it if trackerAnalysis.UnrelaxedTimeoutCount != allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving { - e("Expected %d unrelaxed timeouts, got %d", trackerAnalysis.UnrelaxedTimeoutCount, allLogsAnalysis.UnrelaxedTimeoutCount) + e("Expected %d unrelaxed timeouts, got %d", trackerAnalysis.UnrelaxedTimeoutCount, allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving) } if trackerAnalysis.RelaxedTimeoutCount != allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount { e("Expected %d relaxed timeouts, got %d", trackerAnalysis.RelaxedTimeoutCount, allLogsAnalysis.RelaxedTimeoutCount) From d207749af5bbaea41dc9309effb09765f7b14455 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 30 Oct 2025 18:48:33 +0200 Subject: [PATCH 17/20] flaky test --- internal/pool/conn.go | 5 +++-- internal/pool/conn_used_at_test.go | 5 +++-- maintnotifications/e2e/scenario_push_notifications_test.go | 4 ++-- maintnotifications/pool_hook_test.go | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/internal/pool/conn.go b/internal/pool/conn.go index ad846651..71e71a8d 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -902,8 +902,9 @@ func (cn *Conn) MaybeHasData() bool { // Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation). func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { // Use cached time for deadline calculation (called 2x per command: read + write) - tm := time.Unix(0, getCachedTimeNs()) - cn.SetUsedAt(tm) + nowNs := getCachedTimeNs() + cn.SetUsedAtNs(nowNs) + tm := time.Unix(0, nowNs) if timeout > 0 { tm = tm.Add(timeout) diff --git a/internal/pool/conn_used_at_test.go b/internal/pool/conn_used_at_test.go index 8461d6d1..d6dd27a0 100644 --- a/internal/pool/conn_used_at_test.go +++ b/internal/pool/conn_used_at_test.go @@ -91,8 +91,9 @@ func TestConn_UsedAtUpdatedOnWrite(t *testing.T) { // Verify the difference is reasonable (should be around 100ms, accounting for ~50ms cache precision) diff := updatedUsedAt.Sub(initialUsedAt) - if diff > 100*time.Millisecond { - t.Errorf("Expected usedAt difference to be no more than 100ms (±50ms for cache), got %v", diff) + // 50 ms is the cache precision, so we allow up to 110ms difference + if diff < 45*time.Millisecond || diff > 155*time.Millisecond { + t.Errorf("Expected usedAt difference to be around 100 (±50ms for cache) (+-5ms for sleep precision), got %v", diff) } } diff --git a/maintnotifications/e2e/scenario_push_notifications_test.go b/maintnotifications/e2e/scenario_push_notifications_test.go index 28677fa5..99139860 100644 --- a/maintnotifications/e2e/scenario_push_notifications_test.go +++ b/maintnotifications/e2e/scenario_push_notifications_test.go @@ -391,11 +391,11 @@ func TestPushNotifications(t *testing.T) { go commandsRunner.FireCommandsUntilStop(ctx) go commandsRunner2.FireCommandsUntilStop(ctx) go commandsRunner3.FireCommandsUntilStop(ctx) - time.Sleep(30 * time.Second) + time.Sleep(2 * time.Minute) commandsRunner.Stop() commandsRunner2.Stop() commandsRunner3.Stop() - time.Sleep(5 * time.Minute) + time.Sleep(1 * time.Minute) allLogsAnalysis := logCollector.GetAnalysis() trackerAnalysis := tracker.GetAnalysis() diff --git a/maintnotifications/pool_hook_test.go b/maintnotifications/pool_hook_test.go index 41120af2..6ec61eed 100644 --- a/maintnotifications/pool_hook_test.go +++ b/maintnotifications/pool_hook_test.go @@ -174,7 +174,7 @@ func TestConnectionHook(t *testing.T) { select { case <-initConnCalled: // Good, initialization was called - case <-time.After(1 * time.Second): + case <-time.After(5 * time.Second): t.Fatal("Timeout waiting for initialization function to be called") } From ae5434ce660f0275440ab78a6a706b9691a98a25 Mon Sep 17 00:00:00 2001 From: cyningsun Date: Fri, 31 Oct 2025 01:21:12 +0800 Subject: [PATCH 18/20] feat(pool): Improve success rate of new connections (#3518) * async create conn * update default values and testcase * fix comments * fix data race * remove context.WithoutCancel, which is a function introduced in Go 1.21 * fix TestDialerRetryConfiguration/DefaultDialerRetries, because tryDial are likely done in async flow * change to share failed to delivery connection to other waiting * remove chinese comment * fix: optimize WantConnQueue benchmarks to prevent memory exhaustion - Fix BenchmarkWantConnQueue_Dequeue timeout issue by limiting pre-population - Use object pooling in BenchmarkWantConnQueue_Enqueue to reduce allocations - Optimize BenchmarkWantConnQueue_EnqueueDequeue with reusable wantConn pool - Prevent GitHub Actions benchmark failures due to excessive memory usage Before: BenchmarkWantConnQueue_Dequeue ran for 11+ minutes and was killed After: All benchmarks complete in ~8 seconds with consistent performance * format * fix turn leaks --------- Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> Co-authored-by: Hristo Temelski --- async_handoff_integration_test.go | 20 +- internal/pool/bench_test.go | 22 +- internal/pool/buffer_size_test.go | 36 +- internal/pool/hooks_test.go | 5 +- internal/pool/pool.go | 110 +++++- internal/pool/pool_test.go | 602 ++++++++++++++++++++++++++++-- internal/pool/want_conn.go | 93 +++++ internal/pool/want_conn_test.go | 444 ++++++++++++++++++++++ options.go | 13 +- options_test.go | 74 ++++ pool_pubsub_bench_test.go | 39 +- 11 files changed, 1361 insertions(+), 97 deletions(-) create mode 100644 internal/pool/want_conn.go create mode 100644 internal/pool/want_conn_test.go diff --git a/async_handoff_integration_test.go b/async_handoff_integration_test.go index 29960df5..b8925cad 100644 --- a/async_handoff_integration_test.go +++ b/async_handoff_integration_test.go @@ -53,8 +53,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { Dialer: func(ctx context.Context) (net.Conn, error) { return &mockNetConn{addr: "original:6379"}, nil }, - PoolSize: int32(5), - PoolTimeout: time.Second, + PoolSize: int32(5), + MaxConcurrentDials: 5, + PoolTimeout: time.Second, }) // Add the hook to the pool after creation @@ -153,8 +154,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { return &mockNetConn{addr: "original:6379"}, nil }, - PoolSize: int32(10), - PoolTimeout: time.Second, + PoolSize: int32(10), + MaxConcurrentDials: 10, + PoolTimeout: time.Second, }) defer testPool.Close() @@ -225,8 +227,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { return &mockNetConn{addr: "original:6379"}, nil }, - PoolSize: int32(3), - PoolTimeout: time.Second, + PoolSize: int32(3), + MaxConcurrentDials: 3, + PoolTimeout: time.Second, }) defer testPool.Close() @@ -288,8 +291,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { return &mockNetConn{addr: "original:6379"}, nil }, - PoolSize: int32(2), - PoolTimeout: time.Second, + PoolSize: int32(2), + MaxConcurrentDials: 2, + PoolTimeout: time.Second, }) defer testPool.Close() diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index fc37b821..5bbd549d 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -31,11 +31,12 @@ func BenchmarkPoolGetPut(b *testing.B) { for _, bm := range benchmarks { b.Run(bm.String(), func(b *testing.B) { connPool := pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(bm.poolSize), - PoolTimeout: time.Second, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: time.Hour, + Dialer: dummyDialer, + PoolSize: int32(bm.poolSize), + MaxConcurrentDials: bm.poolSize, + PoolTimeout: time.Second, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, }) b.ResetTimer() @@ -75,11 +76,12 @@ func BenchmarkPoolGetRemove(b *testing.B) { for _, bm := range benchmarks { b.Run(bm.String(), func(b *testing.B) { connPool := pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(bm.poolSize), - PoolTimeout: time.Second, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: time.Hour, + Dialer: dummyDialer, + PoolSize: int32(bm.poolSize), + MaxConcurrentDials: bm.poolSize, + PoolTimeout: time.Second, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, }) b.ResetTimer() diff --git a/internal/pool/buffer_size_test.go b/internal/pool/buffer_size_test.go index bffe495c..b5de38f2 100644 --- a/internal/pool/buffer_size_test.go +++ b/internal/pool/buffer_size_test.go @@ -24,9 +24,10 @@ var _ = Describe("Buffer Size Configuration", func() { It("should use default buffer sizes when not specified", func() { connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(1), - PoolTimeout: 1000, + Dialer: dummyDialer, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: 1000, }) cn, err := connPool.NewConn(ctx) @@ -46,11 +47,12 @@ var _ = Describe("Buffer Size Configuration", func() { customWriteSize := 64 * 1024 // 64KB connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(1), - PoolTimeout: 1000, - ReadBufferSize: customReadSize, - WriteBufferSize: customWriteSize, + Dialer: dummyDialer, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: 1000, + ReadBufferSize: customReadSize, + WriteBufferSize: customWriteSize, }) cn, err := connPool.NewConn(ctx) @@ -67,11 +69,12 @@ var _ = Describe("Buffer Size Configuration", func() { It("should handle zero buffer sizes by using defaults", func() { connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(1), - PoolTimeout: 1000, - ReadBufferSize: 0, // Should use default - WriteBufferSize: 0, // Should use default + Dialer: dummyDialer, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: 1000, + ReadBufferSize: 0, // Should use default + WriteBufferSize: 0, // Should use default }) cn, err := connPool.NewConn(ctx) @@ -103,9 +106,10 @@ var _ = Describe("Buffer Size Configuration", func() { // Test the scenario where someone creates a pool directly (like in tests) // without setting ReadBufferSize and WriteBufferSize connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(1), - PoolTimeout: 1000, + Dialer: dummyDialer, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: 1000, // ReadBufferSize and WriteBufferSize are not set (will be 0) }) diff --git a/internal/pool/hooks_test.go b/internal/pool/hooks_test.go index ec1d6da3..ad1a2db3 100644 --- a/internal/pool/hooks_test.go +++ b/internal/pool/hooks_test.go @@ -191,8 +191,9 @@ func TestPoolWithHooks(t *testing.T) { Dialer: func(ctx context.Context) (net.Conn, error) { return &net.TCPConn{}, nil // Mock connection }, - PoolSize: 1, - DialTimeout: time.Second, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: time.Second, } pool := NewConnPool(opt) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 4915bf62..0a6453c7 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -98,6 +98,7 @@ type Options struct { PoolFIFO bool PoolSize int32 + MaxConcurrentDials int DialTimeout time.Duration PoolTimeout time.Duration MinIdleConns int32 @@ -126,7 +127,9 @@ type ConnPool struct { dialErrorsNum uint32 // atomic lastDialError atomic.Value - queue chan struct{} + queue chan struct{} + dialsInProgress chan struct{} + dialsQueue *wantConnQueue connsMu sync.Mutex conns map[uint64]*Conn @@ -152,9 +155,11 @@ func NewConnPool(opt *Options) *ConnPool { p := &ConnPool{ cfg: opt, - queue: make(chan struct{}, opt.PoolSize), - conns: make(map[uint64]*Conn), - idleConns: make([]*Conn, 0, opt.PoolSize), + queue: make(chan struct{}, opt.PoolSize), + conns: make(map[uint64]*Conn), + dialsInProgress: make(chan struct{}, opt.MaxConcurrentDials), + dialsQueue: newWantConnQueue(), + idleConns: make([]*Conn, 0, opt.PoolSize), } // Only create MinIdleConns if explicitly requested (> 0) @@ -233,6 +238,7 @@ func (p *ConnPool) checkMinIdleConns() { return } } + } func (p *ConnPool) addIdleConn() error { @@ -491,9 +497,8 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { atomic.AddUint32(&p.stats.Misses, 1) - newcn, err := p.newConn(ctx, true) + newcn, err := p.queuedNewConn(ctx) if err != nil { - p.freeTurn() return nil, err } @@ -512,6 +517,99 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { return newcn, nil } +func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) { + select { + case p.dialsInProgress <- struct{}{}: + // Got permission, proceed to create connection + case <-ctx.Done(): + p.freeTurn() + return nil, ctx.Err() + } + + dialCtx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout) + + w := &wantConn{ + ctx: dialCtx, + cancelCtx: cancel, + result: make(chan wantConnResult, 1), + } + var err error + defer func() { + if err != nil { + if cn := w.cancel(); cn != nil { + p.putIdleConn(ctx, cn) + p.freeTurn() + } + } + }() + + p.dialsQueue.enqueue(w) + + go func(w *wantConn) { + var freeTurnCalled bool + defer func() { + if err := recover(); err != nil { + if !freeTurnCalled { + p.freeTurn() + } + internal.Logger.Printf(context.Background(), "queuedNewConn panic: %+v", err) + } + }() + + defer w.cancelCtx() + defer func() { <-p.dialsInProgress }() // Release connection creation permission + + dialCtx := w.getCtxForDial() + cn, cnErr := p.newConn(dialCtx, true) + delivered := w.tryDeliver(cn, cnErr) + if cnErr == nil && delivered { + return + } else if cnErr == nil && !delivered { + p.putIdleConn(dialCtx, cn) + p.freeTurn() + freeTurnCalled = true + } else { + p.freeTurn() + freeTurnCalled = true + } + }(w) + + select { + case <-ctx.Done(): + err = ctx.Err() + return nil, err + case result := <-w.result: + err = result.err + return result.cn, err + } +} + +func (p *ConnPool) putIdleConn(ctx context.Context, cn *Conn) { + for { + w, ok := p.dialsQueue.dequeue() + if !ok { + break + } + if w.tryDeliver(cn, nil) { + return + } + } + + cn.SetUsable(true) + + p.connsMu.Lock() + defer p.connsMu.Unlock() + + if p.closed() { + _ = cn.Close() + return + } + + // poolSize is increased in newConn + p.idleConns = append(p.idleConns, cn) + p.idleConnsLen.Add(1) +} + func (p *ConnPool) waitTurn(ctx context.Context) error { select { case <-ctx.Done(): diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 6aa6dc09..680370a7 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -3,6 +3,7 @@ package pool_test import ( "context" "errors" + "fmt" "net" "sync" "sync/atomic" @@ -21,11 +22,12 @@ var _ = Describe("ConnPool", func() { BeforeEach(func() { connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(10), - PoolTimeout: time.Hour, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: time.Millisecond, + Dialer: dummyDialer, + PoolSize: int32(10), + MaxConcurrentDials: 10, + PoolTimeout: time.Hour, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Millisecond, }) }) @@ -47,17 +49,18 @@ var _ = Describe("ConnPool", func() { <-closedChan return &net.TCPConn{}, nil }, - PoolSize: int32(10), - PoolTimeout: time.Hour, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: time.Millisecond, - MinIdleConns: int32(minIdleConns), + PoolSize: int32(10), + MaxConcurrentDials: 10, + PoolTimeout: time.Hour, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Millisecond, + MinIdleConns: int32(minIdleConns), }) wg.Wait() Expect(connPool.Close()).NotTo(HaveOccurred()) close(closedChan) - // We wait for 1 second and believe that checkMinIdleConns has been executed. + // We wait for 1 second and believe that checkIdleConns has been executed. time.Sleep(time.Second) Expect(connPool.Stats()).To(Equal(&pool.Stats{ @@ -131,12 +134,13 @@ var _ = Describe("MinIdleConns", func() { newConnPool := func() *pool.ConnPool { connPool := pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(poolSize), - MinIdleConns: int32(minIdleConns), - PoolTimeout: 100 * time.Millisecond, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: -1, + Dialer: dummyDialer, + PoolSize: int32(poolSize), + MaxConcurrentDials: poolSize, + MinIdleConns: int32(minIdleConns), + PoolTimeout: 100 * time.Millisecond, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: -1, }) Eventually(func() int { return connPool.Len() @@ -310,11 +314,12 @@ var _ = Describe("race", func() { It("does not happen on Get, Put, and Remove", func() { connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(10), - PoolTimeout: time.Minute, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: time.Millisecond, + Dialer: dummyDialer, + PoolSize: int32(10), + MaxConcurrentDials: 10, + PoolTimeout: time.Minute, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Millisecond, }) perform(C, func(id int) { @@ -341,10 +346,11 @@ var _ = Describe("race", func() { Dialer: func(ctx context.Context) (net.Conn, error) { return &net.TCPConn{}, nil }, - PoolSize: int32(1000), - MinIdleConns: int32(50), - PoolTimeout: 3 * time.Second, - DialTimeout: 1 * time.Second, + PoolSize: int32(1000), + MaxConcurrentDials: 1000, + MinIdleConns: int32(50), + PoolTimeout: 3 * time.Second, + DialTimeout: 1 * time.Second, } p := pool.NewConnPool(opt) @@ -368,8 +374,9 @@ var _ = Describe("race", func() { Dialer: func(ctx context.Context) (net.Conn, error) { panic("test panic") }, - PoolSize: int32(100), - MinIdleConns: int32(30), + PoolSize: int32(100), + MaxConcurrentDials: 100, + MinIdleConns: int32(30), } p := pool.NewConnPool(opt) @@ -386,8 +393,9 @@ var _ = Describe("race", func() { Dialer: func(ctx context.Context) (net.Conn, error) { return &net.TCPConn{}, nil }, - PoolSize: int32(1), - PoolTimeout: 3 * time.Second, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: 3 * time.Second, } p := pool.NewConnPool(opt) @@ -417,8 +425,9 @@ var _ = Describe("race", func() { return &net.TCPConn{}, nil }, - PoolSize: int32(1), - PoolTimeout: testPoolTimeout, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: testPoolTimeout, } p := pool.NewConnPool(opt) @@ -452,6 +461,7 @@ func TestDialerRetryConfiguration(t *testing.T) { connPool := pool.NewConnPool(&pool.Options{ Dialer: failingDialer, PoolSize: 1, + MaxConcurrentDials: 1, PoolTimeout: time.Second, DialTimeout: time.Second, DialerRetries: 3, // Custom retry count @@ -483,10 +493,11 @@ func TestDialerRetryConfiguration(t *testing.T) { } connPool := pool.NewConnPool(&pool.Options{ - Dialer: failingDialer, - PoolSize: 1, - PoolTimeout: time.Second, - DialTimeout: time.Second, + Dialer: failingDialer, + PoolSize: 1, + MaxConcurrentDials: 1, + PoolTimeout: time.Second, + DialTimeout: time.Second, // DialerRetries and DialerRetryTimeout not set - should use defaults }) defer connPool.Close() @@ -509,6 +520,525 @@ func TestDialerRetryConfiguration(t *testing.T) { }) } +var _ = Describe("queuedNewConn", func() { + ctx := context.Background() + + It("should successfully create connection when pool is exhausted", func() { + testPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: 1, + MaxConcurrentDials: 2, + DialTimeout: 1 * time.Second, + PoolTimeout: 2 * time.Second, + }) + defer testPool.Close() + + // Fill the pool + conn1, err := testPool.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + Expect(conn1).NotTo(BeNil()) + + // Get second connection in another goroutine + done := make(chan struct{}) + var conn2 *pool.Conn + var err2 error + + go func() { + defer GinkgoRecover() + conn2, err2 = testPool.Get(ctx) + close(done) + }() + + // Wait a bit to let the second Get start waiting + time.Sleep(100 * time.Millisecond) + + // Release first connection to let second Get acquire Turn + testPool.Put(ctx, conn1) + + // Wait for second Get to complete + <-done + Expect(err2).NotTo(HaveOccurred()) + Expect(conn2).NotTo(BeNil()) + + // Clean up second connection + testPool.Put(ctx, conn2) + }) + + It("should handle context cancellation before acquiring dialsInProgress", func() { + slowDialer := func(ctx context.Context) (net.Conn, error) { + // Simulate slow dialing to let first connection creation occupy dialsInProgress + time.Sleep(200 * time.Millisecond) + return newDummyConn(), nil + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: slowDialer, + PoolSize: 2, + MaxConcurrentDials: 1, // Limit to 1 so second request cannot get dialsInProgress permission + DialTimeout: 1 * time.Second, + PoolTimeout: 1 * time.Second, + }) + defer testPool.Close() + + // Start first connection creation, this will occupy dialsInProgress + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + conn1, err := testPool.Get(ctx) + if err == nil { + defer testPool.Put(ctx, conn1) + } + close(done1) + }() + + // Wait a bit to ensure first request starts and occupies dialsInProgress + time.Sleep(50 * time.Millisecond) + + // Create a context that will be cancelled quickly + cancelCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + + // Second request should timeout while waiting for dialsInProgress + _, err := testPool.Get(cancelCtx) + Expect(err).To(Equal(context.DeadlineExceeded)) + + // Wait for first request to complete + <-done1 + + // Verify all turns are released after requests complete + Eventually(func() int { + return testPool.QueueLen() + }, "1s", "50ms").Should(Equal(0), "All turns should be released after requests complete") + }) + + It("should handle context cancellation while waiting for connection result", func() { + // This test focuses on proper error handling when context is cancelled + // during queuedNewConn execution (not testing connection reuse) + + slowDialer := func(ctx context.Context) (net.Conn, error) { + // Simulate slow dialing + time.Sleep(500 * time.Millisecond) + return newDummyConn(), nil + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: slowDialer, + PoolSize: 1, + MaxConcurrentDials: 2, + DialTimeout: 2 * time.Second, + PoolTimeout: 2 * time.Second, + }) + defer testPool.Close() + + // Get first connection to fill the pool + conn1, err := testPool.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + + // Create a context that will be cancelled during connection creation + cancelCtx, cancel := context.WithTimeout(ctx, 200*time.Millisecond) + defer cancel() + + // This request should timeout while waiting for connection creation result + // Testing the error handling path in queuedNewConn select statement + done := make(chan struct{}) + var err2 error + go func() { + defer GinkgoRecover() + _, err2 = testPool.Get(cancelCtx) + close(done) + }() + + <-done + Expect(err2).To(Equal(context.DeadlineExceeded)) + + // Verify turn state - background goroutine may still hold turn + // Note: Background connection creation will complete and release turn + Eventually(func() int { + return testPool.QueueLen() + }, "1s", "50ms").Should(Equal(1), "Only conn1's turn should be held") + + // Clean up - release the first connection + testPool.Put(ctx, conn1) + + // Verify all turns are released after cleanup + Eventually(func() int { + return testPool.QueueLen() + }, "1s", "50ms").Should(Equal(0), "All turns should be released after cleanup") + }) + + It("should handle dial failures gracefully", func() { + alwaysFailDialer := func(ctx context.Context) (net.Conn, error) { + return nil, fmt.Errorf("dial failed") + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: alwaysFailDialer, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 1 * time.Second, + PoolTimeout: 1 * time.Second, + }) + defer testPool.Close() + + // This call should fail, testing error handling branch in goroutine + _, err := testPool.Get(ctx) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("dial failed")) + + // Verify turn is released after dial failure + Eventually(func() int { + return testPool.QueueLen() + }, "1s", "50ms").Should(Equal(0), "Turn should be released after dial failure") + }) + + It("should handle connection creation success with normal delivery", func() { + // This test verifies normal case where connection creation and delivery both succeed + testPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: 1, + MaxConcurrentDials: 2, + DialTimeout: 1 * time.Second, + PoolTimeout: 2 * time.Second, + }) + defer testPool.Close() + + // Get first connection + conn1, err := testPool.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + + // Get second connection in another goroutine + done := make(chan struct{}) + var conn2 *pool.Conn + var err2 error + + go func() { + defer GinkgoRecover() + conn2, err2 = testPool.Get(ctx) + close(done) + }() + + // Wait a bit to let second Get start waiting + time.Sleep(100 * time.Millisecond) + + // Release first connection + testPool.Put(ctx, conn1) + + // Wait for second Get to complete + <-done + Expect(err2).NotTo(HaveOccurred()) + Expect(conn2).NotTo(BeNil()) + + // Clean up second connection + testPool.Put(ctx, conn2) + }) + + It("should handle MaxConcurrentDials limit", func() { + testPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: 3, + MaxConcurrentDials: 1, // Only allow 1 concurrent dial + DialTimeout: 1 * time.Second, + PoolTimeout: 1 * time.Second, + }) + defer testPool.Close() + + // Get all connections to fill the pool + var conns []*pool.Conn + for i := 0; i < 3; i++ { + conn, err := testPool.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + conns = append(conns, conn) + } + + // Now pool is full, next request needs to create new connection + // But due to MaxConcurrentDials=1, only one concurrent dial is allowed + done := make(chan struct{}) + var err4 error + go func() { + defer GinkgoRecover() + _, err4 = testPool.Get(ctx) + close(done) + }() + + // Release one connection to let the request complete + time.Sleep(100 * time.Millisecond) + testPool.Put(ctx, conns[0]) + + <-done + Expect(err4).NotTo(HaveOccurred()) + + // Clean up remaining connections + for i := 1; i < len(conns); i++ { + testPool.Put(ctx, conns[i]) + } + }) + + It("should reuse connections created in background after request timeout", func() { + // This test focuses on connection reuse mechanism: + // When a request times out but background connection creation succeeds, + // the created connection should be added to pool for future reuse + + slowDialer := func(ctx context.Context) (net.Conn, error) { + // Simulate delay for connection creation + time.Sleep(100 * time.Millisecond) + return newDummyConn(), nil + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: slowDialer, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 1 * time.Second, + PoolTimeout: 150 * time.Millisecond, // Short timeout for waiting Turn + }) + defer testPool.Close() + + // Fill the pool with one connection + conn1, err := testPool.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + // Don't put it back yet, so pool is full + + // Start a goroutine that will create a new connection but take time + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done1) + // This will trigger queuedNewConn since pool is full + conn, err := testPool.Get(ctx) + if err == nil { + // Put connection back to pool after creation + time.Sleep(50 * time.Millisecond) + testPool.Put(ctx, conn) + } + }() + + // Wait a bit to let the goroutine start and begin connection creation + time.Sleep(50 * time.Millisecond) + + // Now make a request that should timeout waiting for Turn + start := time.Now() + _, err = testPool.Get(ctx) + duration := time.Since(start) + + Expect(err).To(Equal(pool.ErrPoolTimeout)) + // Should timeout around PoolTimeout + Expect(duration).To(BeNumerically("~", 150*time.Millisecond, 50*time.Millisecond)) + + // Release the first connection to allow the background creation to complete + testPool.Put(ctx, conn1) + + // Wait for background connection creation to complete + <-done1 + time.Sleep(100 * time.Millisecond) + + // CORE TEST: Verify connection reuse mechanism + // The connection created in background should now be available in pool + start = time.Now() + conn3, err := testPool.Get(ctx) + duration = time.Since(start) + + Expect(err).NotTo(HaveOccurred()) + Expect(conn3).NotTo(BeNil()) + // Should be fast since connection is from pool (not newly created) + Expect(duration).To(BeNumerically("<", 50*time.Millisecond)) + + testPool.Put(ctx, conn3) + }) + + It("recover queuedNewConn panic", func() { + opt := &pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + panic("test panic in queuedNewConn") + }, + PoolSize: int32(10), + MaxConcurrentDials: 10, + DialTimeout: 1 * time.Second, + PoolTimeout: 1 * time.Second, + } + testPool := pool.NewConnPool(opt) + defer testPool.Close() + + // Trigger queuedNewConn - calling Get() on empty pool will trigger it + // Since dialer will panic, it should be handled by recover + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Try to get connections multiple times, each will trigger panic but should be properly recovered + for i := 0; i < 3; i++ { + conn, err := testPool.Get(ctx) + // Connection should be nil, error should exist (panic converted to error) + Expect(conn).To(BeNil()) + Expect(err).To(HaveOccurred()) + } + + // Verify state after panic recovery: + // - turn should be properly released (QueueLen() == 0) + // - connection counts should be correct (TotalConns == 0, IdleConns == 0) + Eventually(func() bool { + stats := testPool.Stats() + queueLen := testPool.QueueLen() + return stats.TotalConns == 0 && stats.IdleConns == 0 && queueLen == 0 + }, "3s", "50ms").Should(BeTrue()) + }) + + It("should handle connection creation success but delivery failure (putIdleConn path)", func() { + // This test covers the most important untested branch in queuedNewConn: + // cnErr == nil && !delivered -> putIdleConn() + + // Use slow dialer to ensure request times out before connection is ready + slowDialer := func(ctx context.Context) (net.Conn, error) { + // Delay long enough for client request to timeout first + time.Sleep(300 * time.Millisecond) + return newDummyConn(), nil + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: slowDialer, + PoolSize: 1, + MaxConcurrentDials: 2, + DialTimeout: 500 * time.Millisecond, // Long enough for dialer to complete + PoolTimeout: 100 * time.Millisecond, // Client requests will timeout quickly + }) + defer testPool.Close() + + // Record initial idle connection count + initialIdleConns := testPool.Stats().IdleConns + + // Make a request that will timeout + // This request will start queuedNewConn, create connection, but fail to deliver due to timeout + shortCtx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + defer cancel() + + conn, err := testPool.Get(shortCtx) + + // Request should fail due to timeout + Expect(err).To(HaveOccurred()) + Expect(conn).To(BeNil()) + + // However, background queuedNewConn should continue and complete connection creation + // Since it cannot deliver (request timed out), it should call putIdleConn to add connection to idle pool + Eventually(func() bool { + stats := testPool.Stats() + return stats.IdleConns > initialIdleConns + }, "1s", "50ms").Should(BeTrue()) + + // Verify the connection can indeed be used by subsequent requests + conn2, err2 := testPool.Get(context.Background()) + Expect(err2).NotTo(HaveOccurred()) + Expect(conn2).NotTo(BeNil()) + Expect(conn2.IsUsable()).To(BeTrue()) + + // Cleanup + testPool.Put(context.Background(), conn2) + + // Verify turn is released after putIdleConn path completes + // This is critical: ensures freeTurn() was called in the putIdleConn branch + Eventually(func() int { + return testPool.QueueLen() + }, "1s", "50ms").Should(Equal(0), + "Turn should be released after putIdleConn path completes") + }) + + It("should not leak turn when delivering connection via putIdleConn", func() { + // This test verifies that freeTurn() is called when putIdleConn successfully + // delivers a connection to another waiting request + // + // Scenario: + // 1. Request A: timeout 150ms, connection creation takes 200ms + // 2. Request B: timeout 500ms, connection creation takes 400ms + // 3. Both requests enter dialsQueue and start async connection creation + // 4. Request A times out at 150ms + // 5. Request A's connection completes at 200ms + // 6. putIdleConn delivers Request A's connection to Request B + // 7. queuedNewConn must call freeTurn() + // 8. Check: QueueLen should be 1 (only B holding turn), not 2 (A's turn leaked) + + callCount := int32(0) + + controlledDialer := func(ctx context.Context) (net.Conn, error) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + // Request A's connection: takes 200ms + time.Sleep(200 * time.Millisecond) + } else { + // Request B's connection: takes 400ms (longer, so A's connection is used) + time.Sleep(400 * time.Millisecond) + } + return newDummyConn(), nil + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: controlledDialer, + PoolSize: 2, // Allows both requests to get turns + MaxConcurrentDials: 2, // Allows both connections to be created simultaneously + DialTimeout: 500 * time.Millisecond, + PoolTimeout: 1 * time.Second, + }) + defer testPool.Close() + + // Verify initial state + Expect(testPool.QueueLen()).To(Equal(0)) + + // Request A: Short timeout (150ms), connection takes 200ms + reqADone := make(chan error, 1) + go func() { + defer GinkgoRecover() + shortCtx, cancel := context.WithTimeout(ctx, 150*time.Millisecond) + defer cancel() + _, err := testPool.Get(shortCtx) + reqADone <- err + }() + + // Wait for Request A to acquire turn and enter dialsQueue + time.Sleep(50 * time.Millisecond) + Expect(testPool.QueueLen()).To(Equal(1), "Request A should occupy turn") + + // Request B: Long timeout (500ms), will receive Request A's connection + reqBDone := make(chan struct{}) + var reqBConn *pool.Conn + var reqBErr error + go func() { + defer GinkgoRecover() + longCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + reqBConn, reqBErr = testPool.Get(longCtx) + close(reqBDone) + }() + + // Wait for Request B to acquire turn and enter dialsQueue + time.Sleep(50 * time.Millisecond) + Expect(testPool.QueueLen()).To(Equal(2), "Both requests should occupy turns") + + // Request A times out at 150ms + reqAErr := <-reqADone + Expect(reqAErr).To(HaveOccurred(), "Request A should timeout") + + // Request A's connection completes at 200ms + // putIdleConn delivers it to Request B via tryDeliver + // queuedNewConn MUST call freeTurn() to release Request A's turn + <-reqBDone + Expect(reqBErr).NotTo(HaveOccurred(), "Request B should receive Request A's connection") + Expect(reqBConn).NotTo(BeNil()) + + // CRITICAL CHECK: Turn leak detection + // After Request B receives connection from putIdleConn: + // - Request A's turn SHOULD be released (via freeTurn) + // - Request B's turn is still held (will release on Put) + // Expected QueueLen: 1 (only Request B) + // If Bug exists (missing freeTurn): QueueLen: 2 (Request A's turn leaked) + time.Sleep(100 * time.Millisecond) // Allow time for turn release + currentQueueLen := testPool.QueueLen() + + Expect(currentQueueLen).To(Equal(1), + "QueueLen should be 1 (only Request B holding turn). "+ + "If it's 2, Request A's turn leaked due to missing freeTurn()") + + // Cleanup + testPool.Put(ctx, reqBConn) + Eventually(func() int { return testPool.QueueLen() }, "500ms").Should(Equal(0)) + }) +}) + func init() { logging.Disable() } diff --git a/internal/pool/want_conn.go b/internal/pool/want_conn.go new file mode 100644 index 00000000..6f9e4bfa --- /dev/null +++ b/internal/pool/want_conn.go @@ -0,0 +1,93 @@ +package pool + +import ( + "context" + "sync" +) + +type wantConn struct { + mu sync.Mutex // protects ctx, done and sending of the result + ctx context.Context // context for dial, cleared after delivered or canceled + cancelCtx context.CancelFunc + done bool // true after delivered or canceled + result chan wantConnResult // channel to deliver connection or error +} + +// getCtxForDial returns context for dial or nil if connection was delivered or canceled. +func (w *wantConn) getCtxForDial() context.Context { + w.mu.Lock() + defer w.mu.Unlock() + + return w.ctx +} + +func (w *wantConn) tryDeliver(cn *Conn, err error) bool { + w.mu.Lock() + defer w.mu.Unlock() + if w.done { + return false + } + + w.done = true + w.ctx = nil + + w.result <- wantConnResult{cn: cn, err: err} + close(w.result) + + return true +} + +func (w *wantConn) cancel() *Conn { + w.mu.Lock() + var cn *Conn + if w.done { + select { + case result := <-w.result: + cn = result.cn + default: + } + } else { + close(w.result) + } + + w.done = true + w.ctx = nil + w.mu.Unlock() + + return cn +} + +type wantConnResult struct { + cn *Conn + err error +} + +type wantConnQueue struct { + mu sync.RWMutex + items []*wantConn +} + +func newWantConnQueue() *wantConnQueue { + return &wantConnQueue{ + items: make([]*wantConn, 0), + } +} + +func (q *wantConnQueue) enqueue(w *wantConn) { + q.mu.Lock() + defer q.mu.Unlock() + q.items = append(q.items, w) +} + +func (q *wantConnQueue) dequeue() (*wantConn, bool) { + q.mu.Lock() + defer q.mu.Unlock() + + if len(q.items) == 0 { + return nil, false + } + + item := q.items[0] + q.items = q.items[1:] + return item, true +} diff --git a/internal/pool/want_conn_test.go b/internal/pool/want_conn_test.go new file mode 100644 index 00000000..9526f70c --- /dev/null +++ b/internal/pool/want_conn_test.go @@ -0,0 +1,444 @@ +package pool + +import ( + "context" + "errors" + "sync" + "testing" + "time" +) + +func TestWantConn_getCtxForDial(t *testing.T) { + ctx := context.Background() + w := &wantConn{ + ctx: ctx, + result: make(chan wantConnResult, 1), + } + + // Test getting context when not done + gotCtx := w.getCtxForDial() + if gotCtx != ctx { + t.Errorf("getCtxForDial() = %v, want %v", gotCtx, ctx) + } + + // Test getting context when done + w.done = true + w.ctx = nil + gotCtx = w.getCtxForDial() + if gotCtx != nil { + t.Errorf("getCtxForDial() after done = %v, want nil", gotCtx) + } +} + +func TestWantConn_tryDeliver_Success(t *testing.T) { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + + // Create a mock connection + conn := &Conn{} + + // Test successful delivery + delivered := w.tryDeliver(conn, nil) + if !delivered { + t.Error("tryDeliver() = false, want true") + } + + // Check that wantConn is marked as done + if !w.done { + t.Error("wantConn.done = false, want true after delivery") + } + + // Check that context is cleared + if w.ctx != nil { + t.Error("wantConn.ctx should be nil after delivery") + } + + // Check that result is sent + select { + case result := <-w.result: + if result.cn != conn { + t.Errorf("result.cn = %v, want %v", result.cn, conn) + } + if result.err != nil { + t.Errorf("result.err = %v, want nil", result.err) + } + case <-time.After(time.Millisecond): + t.Error("Expected result to be sent to channel") + } +} + +func TestWantConn_tryDeliver_WithError(t *testing.T) { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + + testErr := errors.New("test error") + + // Test delivery with error + delivered := w.tryDeliver(nil, testErr) + if !delivered { + t.Error("tryDeliver() = false, want true") + } + + // Check result + select { + case result := <-w.result: + if result.cn != nil { + t.Errorf("result.cn = %v, want nil", result.cn) + } + if result.err != testErr { + t.Errorf("result.err = %v, want %v", result.err, testErr) + } + case <-time.After(time.Millisecond): + t.Error("Expected result to be sent to channel") + } +} + +func TestWantConn_tryDeliver_AlreadyDone(t *testing.T) { + w := &wantConn{ + ctx: context.Background(), + done: true, // Already done + result: make(chan wantConnResult, 1), + } + + // Test delivery when already done + delivered := w.tryDeliver(&Conn{}, nil) + if delivered { + t.Error("tryDeliver() = true, want false when already done") + } + + // Check that no result is sent + select { + case <-w.result: + t.Error("No result should be sent when already done") + case <-time.After(time.Millisecond): + // Expected + } +} + +func TestWantConn_cancel_NotDone(t *testing.T) { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + + // Test cancel when not done + cn := w.cancel() + + // Should return nil since no connection was not delivered + if cn != nil { + t.Errorf("cancel()= %v, want nil when no connection delivered", cn) + } + + // Check that wantConn is marked as done + if !w.done { + t.Error("wantConn.done = false, want true after cancel") + } + + // Check that context is cleared + if w.ctx != nil { + t.Error("wantConn.ctx should be nil after cancel") + } + + // Check that channel is closed + select { + case _, ok := <-w.result: + if ok { + t.Error("result channel should be closed after cancel") + } + case <-time.After(time.Millisecond): + t.Error("Expected channel to be closed") + } +} + +func TestWantConn_cancel_AlreadyDone(t *testing.T) { + w := &wantConn{ + ctx: context.Background(), + done: true, + result: make(chan wantConnResult, 1), + } + + // Put a result in the channel without connection (to avoid nil pointer issues) + testErr := errors.New("test error") + w.result <- wantConnResult{cn: nil, err: testErr} + + // Test cancel when already done + cn := w.cancel() + + // Should return nil since the result had no connection + if cn != nil { + t.Errorf("cancel()= %v, want nil when result had no connection", cn) + } + + // Check that wantConn remains done + if !w.done { + t.Error("wantConn.done = false, want true") + } + + // Check that context is cleared + if w.ctx != nil { + t.Error("wantConn.ctx should be nil after cancel") + } +} + +func TestWantConnQueue_newWantConnQueue(t *testing.T) { + q := newWantConnQueue() + if q == nil { + t.Fatal("newWantConnQueue() returned nil") + } + if q.items == nil { + t.Error("queue items should be initialized") + } + if len(q.items) != 0 { + t.Errorf("new queue length = %d, want 0", len(q.items)) + } +} + +func TestWantConnQueue_enqueue_dequeue(t *testing.T) { + q := newWantConnQueue() + + // Test dequeue from empty queue + item, ok := q.dequeue() + if ok { + t.Error("dequeue() from empty queue should return false") + } + if item != nil { + t.Error("dequeue() from empty queue should return nil") + } + + // Create test wantConn items + w1 := &wantConn{ctx: context.Background(), result: make(chan wantConnResult, 1)} + w2 := &wantConn{ctx: context.Background(), result: make(chan wantConnResult, 1)} + w3 := &wantConn{ctx: context.Background(), result: make(chan wantConnResult, 1)} + + // Test enqueue + q.enqueue(w1) + q.enqueue(w2) + q.enqueue(w3) + + // Test FIFO behavior + item, ok = q.dequeue() + if !ok { + t.Error("dequeue() should return true when queue has items") + } + if item != w1 { + t.Errorf("dequeue() = %v, want %v (FIFO order)", item, w1) + } + + item, ok = q.dequeue() + if !ok { + t.Error("dequeue() should return true when queue has items") + } + if item != w2 { + t.Errorf("dequeue() = %v, want %v (FIFO order)", item, w2) + } + + item, ok = q.dequeue() + if !ok { + t.Error("dequeue() should return true when queue has items") + } + if item != w3 { + t.Errorf("dequeue() = %v, want %v (FIFO order)", item, w3) + } + + // Test dequeue from empty queue again + item, ok = q.dequeue() + if ok { + t.Error("dequeue() from empty queue should return false") + } + if item != nil { + t.Error("dequeue() from empty queue should return nil") + } +} + +func TestWantConnQueue_ConcurrentAccess(t *testing.T) { + q := newWantConnQueue() + const numWorkers = 10 + const itemsPerWorker = 100 + + var wg sync.WaitGroup + + // Start enqueuers + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < itemsPerWorker; j++ { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + q.enqueue(w) + } + }() + } + + // Start dequeuers + dequeued := make(chan *wantConn, numWorkers*itemsPerWorker) + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < itemsPerWorker; j++ { + for { + if item, ok := q.dequeue(); ok { + dequeued <- item + break + } + // Small delay to avoid busy waiting + time.Sleep(time.Microsecond) + } + } + }() + } + + wg.Wait() + close(dequeued) + + // Count dequeued items + count := 0 + for range dequeued { + count++ + } + + expectedCount := numWorkers * itemsPerWorker + if count != expectedCount { + t.Errorf("dequeued %d items, want %d", count, expectedCount) + } + + // Queue should be empty + if item, ok := q.dequeue(); ok { + t.Errorf("queue should be empty but got item: %v", item) + } +} + +func TestWantConnQueue_ThreadSafety(t *testing.T) { + q := newWantConnQueue() + const numOperations = 1000 + + var wg sync.WaitGroup + errors := make(chan error, numOperations*2) + + // Concurrent enqueue operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < numOperations; i++ { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + q.enqueue(w) + } + }() + + // Concurrent dequeue operations + wg.Add(1) + go func() { + defer wg.Done() + dequeued := 0 + for dequeued < numOperations { + if _, ok := q.dequeue(); ok { + dequeued++ + } else { + // Small delay when queue is empty + time.Sleep(time.Microsecond) + } + } + }() + + // Wait for completion + wg.Wait() + close(errors) + + // Check for any errors + for err := range errors { + t.Error(err) + } + + // Final queue should be empty + if item, ok := q.dequeue(); ok { + t.Errorf("queue should be empty but got item: %v", item) + } +} + +// Benchmark tests +func BenchmarkWantConnQueue_Enqueue(b *testing.B) { + q := newWantConnQueue() + + // Pre-allocate a pool of wantConn to reuse + const poolSize = 1000 + wantConnPool := make([]*wantConn, poolSize) + for i := 0; i < poolSize; i++ { + wantConnPool[i] = &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + w := wantConnPool[i%poolSize] + q.enqueue(w) + } +} + +func BenchmarkWantConnQueue_Dequeue(b *testing.B) { + q := newWantConnQueue() + + // Use a reasonable fixed size for pre-population to avoid memory issues + const queueSize = 10000 + + // Pre-populate queue with a fixed reasonable size + for i := 0; i < queueSize; i++ { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + q.enqueue(w) + } + + b.ResetTimer() + + // Benchmark dequeue operations, refilling as needed + for i := 0; i < b.N; i++ { + if _, ok := q.dequeue(); !ok { + // Queue is empty, refill a batch + for j := 0; j < 1000; j++ { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + q.enqueue(w) + } + // Dequeue again + q.dequeue() + } + } +} + +func BenchmarkWantConnQueue_EnqueueDequeue(b *testing.B) { + q := newWantConnQueue() + + // Pre-allocate a pool of wantConn to reuse + const poolSize = 1000 + wantConnPool := make([]*wantConn, poolSize) + for i := 0; i < poolSize; i++ { + wantConnPool[i] = &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + w := wantConnPool[i%poolSize] + q.enqueue(w) + q.dequeue() + } +} diff --git a/options.go b/options.go index 79e4b6df..e0dcb5eb 100644 --- a/options.go +++ b/options.go @@ -34,7 +34,6 @@ type Limiter interface { // Options keeps the settings to set up redis connection. type Options struct { - // Network type, either tcp or unix. // // default: is tcp. @@ -176,6 +175,10 @@ type Options struct { // default: 10 * runtime.GOMAXPROCS(0) PoolSize int + // MaxConcurrentDials is the maximum number of concurrent connection creation goroutines. + // If <= 0, defaults to PoolSize. If > PoolSize, it will be capped at PoolSize. + MaxConcurrentDials int + // PoolTimeout is the amount of time client waits for connection if all connections // are busy before returning an error. // @@ -295,6 +298,11 @@ func (opt *Options) init() { if opt.PoolSize == 0 { opt.PoolSize = 10 * runtime.GOMAXPROCS(0) } + if opt.MaxConcurrentDials <= 0 { + opt.MaxConcurrentDials = opt.PoolSize + } else if opt.MaxConcurrentDials > opt.PoolSize { + opt.MaxConcurrentDials = opt.PoolSize + } if opt.ReadBufferSize == 0 { opt.ReadBufferSize = proto.DefaultBufferSize } @@ -622,6 +630,7 @@ func setupConnParams(u *url.URL, o *Options) (*Options, error) { o.MinIdleConns = q.int("min_idle_conns") o.MaxIdleConns = q.int("max_idle_conns") o.MaxActiveConns = q.int("max_active_conns") + o.MaxConcurrentDials = q.int("max_concurrent_dials") if q.has("conn_max_idle_time") { o.ConnMaxIdleTime = q.duration("conn_max_idle_time") } else { @@ -688,6 +697,7 @@ func newConnPool( }, PoolFIFO: opt.PoolFIFO, PoolSize: poolSize, + MaxConcurrentDials: opt.MaxConcurrentDials, PoolTimeout: opt.PoolTimeout, DialTimeout: opt.DialTimeout, DialerRetries: opt.DialerRetries, @@ -728,6 +738,7 @@ func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr return pool.NewPubSubPool(&pool.Options{ PoolFIFO: opt.PoolFIFO, PoolSize: poolSize, + MaxConcurrentDials: opt.MaxConcurrentDials, PoolTimeout: opt.PoolTimeout, DialTimeout: opt.DialTimeout, DialerRetries: opt.DialerRetries, diff --git a/options_test.go b/options_test.go index 8de4986b..32d75e25 100644 --- a/options_test.go +++ b/options_test.go @@ -67,6 +67,12 @@ func TestParseURL(t *testing.T) { }, { url: "redis://localhost:123/?db=2&protocol=2", // RESP Protocol o: &Options{Addr: "localhost:123", DB: 2, Protocol: 2}, + }, { + url: "redis://localhost:123/?max_concurrent_dials=5", // MaxConcurrentDials parameter + o: &Options{Addr: "localhost:123", MaxConcurrentDials: 5}, + }, { + url: "redis://localhost:123/?max_concurrent_dials=0", // MaxConcurrentDials zero value + o: &Options{Addr: "localhost:123", MaxConcurrentDials: 0}, }, { url: "unix:///tmp/redis.sock", o: &Options{Addr: "/tmp/redis.sock"}, @@ -197,6 +203,9 @@ func comprareOptions(t *testing.T, actual, expected *Options) { if actual.ConnMaxLifetime != expected.ConnMaxLifetime { t.Errorf("ConnMaxLifetime: got %v, expected %v", actual.ConnMaxLifetime, expected.ConnMaxLifetime) } + if actual.MaxConcurrentDials != expected.MaxConcurrentDials { + t.Errorf("MaxConcurrentDials: got %v, expected %v", actual.MaxConcurrentDials, expected.MaxConcurrentDials) + } } // Test ReadTimeout option initialization, including special values -1 and 0. @@ -245,3 +254,68 @@ func TestProtocolOptions(t *testing.T) { } } } + +func TestMaxConcurrentDialsOptions(t *testing.T) { + // Test cases for MaxConcurrentDials initialization logic + testCases := []struct { + name string + poolSize int + maxConcurrentDials int + expectedConcurrentDials int + }{ + // Edge cases and invalid values - negative/zero values set to PoolSize + { + name: "negative value gets set to pool size", + poolSize: 10, + maxConcurrentDials: -1, + expectedConcurrentDials: 10, // negative values are set to PoolSize + }, + // Zero value tests - MaxConcurrentDials should be set to PoolSize + { + name: "zero value with positive pool size", + poolSize: 1, + maxConcurrentDials: 0, + expectedConcurrentDials: 1, // MaxConcurrentDials = PoolSize when 0 + }, + // Explicit positive value tests + { + name: "explicit value within limit", + poolSize: 10, + maxConcurrentDials: 3, + expectedConcurrentDials: 3, // should remain unchanged when < PoolSize + }, + // Capping tests - values exceeding PoolSize should be capped + { + name: "value exceeding pool size", + poolSize: 5, + maxConcurrentDials: 10, + expectedConcurrentDials: 5, // should be capped at PoolSize + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts := &Options{ + PoolSize: tc.poolSize, + MaxConcurrentDials: tc.maxConcurrentDials, + } + opts.init() + + if opts.MaxConcurrentDials != tc.expectedConcurrentDials { + t.Errorf("MaxConcurrentDials: got %v, expected %v (PoolSize=%v)", + opts.MaxConcurrentDials, tc.expectedConcurrentDials, opts.PoolSize) + } + + // Ensure MaxConcurrentDials never exceeds PoolSize (for all inputs) + if opts.MaxConcurrentDials > opts.PoolSize { + t.Errorf("MaxConcurrentDials (%v) should not exceed PoolSize (%v)", + opts.MaxConcurrentDials, opts.PoolSize) + } + + // Ensure MaxConcurrentDials is always positive (for all inputs) + if opts.MaxConcurrentDials <= 0 { + t.Errorf("MaxConcurrentDials should be positive, got %v", opts.MaxConcurrentDials) + } + }) + } +} diff --git a/pool_pubsub_bench_test.go b/pool_pubsub_bench_test.go index 0db8ec55..d7f0f185 100644 --- a/pool_pubsub_bench_test.go +++ b/pool_pubsub_bench_test.go @@ -70,12 +70,13 @@ func BenchmarkPoolGetPut(b *testing.B) { for _, poolSize := range poolSizes { b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) { connPool := pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(poolSize), - PoolTimeout: time.Second, - DialTimeout: time.Second, - ConnMaxIdleTime: time.Hour, - MinIdleConns: int32(0), // Start with no idle connections + Dialer: dummyDialer, + PoolSize: int32(poolSize), + MaxConcurrentDials: poolSize, + PoolTimeout: time.Second, + DialTimeout: time.Second, + ConnMaxIdleTime: time.Hour, + MinIdleConns: int32(0), // Start with no idle connections }) defer connPool.Close() @@ -112,12 +113,13 @@ func BenchmarkPoolGetPutWithMinIdle(b *testing.B) { for _, config := range configs { b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) { connPool := pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(config.poolSize), - MinIdleConns: int32(config.minIdleConns), - PoolTimeout: time.Second, - DialTimeout: time.Second, - ConnMaxIdleTime: time.Hour, + Dialer: dummyDialer, + PoolSize: int32(config.poolSize), + MaxConcurrentDials: config.poolSize, + MinIdleConns: int32(config.minIdleConns), + PoolTimeout: time.Second, + DialTimeout: time.Second, + ConnMaxIdleTime: time.Hour, }) defer connPool.Close() @@ -142,12 +144,13 @@ func BenchmarkPoolConcurrentGetPut(b *testing.B) { ctx := context.Background() connPool := pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(32), - PoolTimeout: time.Second, - DialTimeout: time.Second, - ConnMaxIdleTime: time.Hour, - MinIdleConns: int32(0), + Dialer: dummyDialer, + PoolSize: int32(32), + MaxConcurrentDials: 32, + PoolTimeout: time.Second, + DialTimeout: time.Second, + ConnMaxIdleTime: time.Hour, + MinIdleConns: int32(0), }) defer connPool.Close() From 5fa97c826cd840058cb52b51bce381200b3e88b5 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 30 Oct 2025 21:10:22 +0200 Subject: [PATCH 19/20] add missed method in interface --- internal/pool/pool.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index fc4c7228..d1676499 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -86,6 +86,12 @@ type Pooler interface { AddPoolHook(hook PoolHook) RemovePoolHook(hook PoolHook) + // RemoveWithoutTurn removes a connection from the pool without freeing a turn. + // This should be used when removing a connection from a context that didn't acquire + // a turn via Get() (e.g., background workers, cleanup tasks). + // For normal removal after Get(), use Remove() instead. + RemoveWithoutTurn(context.Context, *Conn, error) + Close() error } @@ -163,10 +169,10 @@ func NewConnPool(opt *Options) *ConnPool { //semSize = opt.PoolSize p := &ConnPool{ - cfg: opt, - semaphore: internal.NewFastSemaphore(semSize), + cfg: opt, + semaphore: internal.NewFastSemaphore(semSize), queue: make(chan struct{}, opt.PoolSize), - conns: make(map[uint64]*Conn), + conns: make(map[uint64]*Conn), dialsInProgress: make(chan struct{}, opt.MaxConcurrentDials), dialsQueue: newWantConnQueue(), idleConns: make([]*Conn, 0, opt.PoolSize), From d91800d640c63c42d77ea8c6985cf5532960a987 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 30 Oct 2025 21:35:16 +0200 Subject: [PATCH 20/20] fix test assertions --- .../e2e/scenario_push_notifications_test.go | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/maintnotifications/e2e/scenario_push_notifications_test.go b/maintnotifications/e2e/scenario_push_notifications_test.go index 99139860..80511494 100644 --- a/maintnotifications/e2e/scenario_push_notifications_test.go +++ b/maintnotifications/e2e/scenario_push_notifications_test.go @@ -436,33 +436,35 @@ func TestPushNotifications(t *testing.T) { e("No logs found for connection %d", connID) } } + // checks are tracker >= logs since the tracker only tracks client1 + // logs include all clients (and some of them start logging even before all hooks are setup) + // for example for idle connections if they receive a notification before the hook is setup + // the action (i.e. relaxing timeouts) will be logged, but the notification will not be tracked and maybe wont be logged // validate number of notifications in tracker matches number of notifications in logs // allow for more moving in the logs since we started a second client if trackerAnalysis.TotalNotifications > allLogsAnalysis.TotalNotifications { - e("Expected %d or more notifications, got %d", trackerAnalysis.TotalNotifications, allLogsAnalysis.TotalNotifications) + e("Expected at least %d or more notifications, got %d", trackerAnalysis.TotalNotifications, allLogsAnalysis.TotalNotifications) } - // and per type - // allow for more moving in the logs since we started a second client if trackerAnalysis.MovingCount > allLogsAnalysis.MovingCount { - e("Expected %d or more MOVING notifications, got %d", trackerAnalysis.MovingCount, allLogsAnalysis.MovingCount) + e("Expected at least %d or more MOVING notifications, got %d", trackerAnalysis.MovingCount, allLogsAnalysis.MovingCount) } - if trackerAnalysis.MigratingCount != allLogsAnalysis.MigratingCount { - e("Expected %d MIGRATING notifications, got %d", trackerAnalysis.MigratingCount, allLogsAnalysis.MigratingCount) + if trackerAnalysis.MigratingCount > allLogsAnalysis.MigratingCount { + e("Expected at least %d MIGRATING notifications, got %d", trackerAnalysis.MigratingCount, allLogsAnalysis.MigratingCount) } - if trackerAnalysis.MigratedCount != allLogsAnalysis.MigratedCount { - e("Expected %d MIGRATED notifications, got %d", trackerAnalysis.MigratedCount, allLogsAnalysis.MigratedCount) + if trackerAnalysis.MigratedCount > allLogsAnalysis.MigratedCount { + e("Expected at least %d MIGRATED notifications, got %d", trackerAnalysis.MigratedCount, allLogsAnalysis.MigratedCount) } - if trackerAnalysis.FailingOverCount != allLogsAnalysis.FailingOverCount { - e("Expected %d FAILING_OVER notifications, got %d", trackerAnalysis.FailingOverCount, allLogsAnalysis.FailingOverCount) + if trackerAnalysis.FailingOverCount > allLogsAnalysis.FailingOverCount { + e("Expected at least %d FAILING_OVER notifications, got %d", trackerAnalysis.FailingOverCount, allLogsAnalysis.FailingOverCount) } - if trackerAnalysis.FailedOverCount != allLogsAnalysis.FailedOverCount { - e("Expected %d FAILED_OVER notifications, got %d", trackerAnalysis.FailedOverCount, allLogsAnalysis.FailedOverCount) + if trackerAnalysis.FailedOverCount > allLogsAnalysis.FailedOverCount { + e("Expected at least %d FAILED_OVER notifications, got %d", trackerAnalysis.FailedOverCount, allLogsAnalysis.FailedOverCount) } if trackerAnalysis.UnexpectedNotificationCount != allLogsAnalysis.UnexpectedCount { @@ -470,11 +472,11 @@ func TestPushNotifications(t *testing.T) { } // unrelaxed (and relaxed) after moving wont be tracked by the hook, so we have to exclude it - if trackerAnalysis.UnrelaxedTimeoutCount != allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving { - e("Expected %d unrelaxed timeouts, got %d", trackerAnalysis.UnrelaxedTimeoutCount, allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving) + if trackerAnalysis.UnrelaxedTimeoutCount > allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving { + e("Expected at least %d unrelaxed timeouts, got %d", trackerAnalysis.UnrelaxedTimeoutCount, allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving) } - if trackerAnalysis.RelaxedTimeoutCount != allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount { - e("Expected %d relaxed timeouts, got %d", trackerAnalysis.RelaxedTimeoutCount, allLogsAnalysis.RelaxedTimeoutCount) + if trackerAnalysis.RelaxedTimeoutCount > allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount { + e("Expected at least %d relaxed timeouts, got %d", trackerAnalysis.RelaxedTimeoutCount, allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount) } // validate all handoffs succeeded