diff --git a/async_handoff_integration_test.go b/async_handoff_integration_test.go index 3bb9819a..673a6224 100644 --- a/async_handoff_integration_test.go +++ b/async_handoff_integration_test.go @@ -57,8 +57,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { Dialer: func(ctx context.Context) (net.Conn, error) { return &mockNetConn{addr: "original:6379"}, nil }, - PoolSize: int32(5), - PoolTimeout: time.Second, + PoolSize: int32(5), + MaxConcurrentDials: 5, + PoolTimeout: time.Second, }) // Add the hook to the pool after creation @@ -190,8 +191,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { return &mockNetConn{addr: "original:6379"}, nil }, - PoolSize: int32(10), - PoolTimeout: time.Second, + PoolSize: int32(10), + MaxConcurrentDials: 10, + PoolTimeout: time.Second, }) defer testPool.Close() @@ -262,8 +264,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { return &mockNetConn{addr: "original:6379"}, nil }, - PoolSize: int32(3), - PoolTimeout: time.Second, + PoolSize: int32(3), + MaxConcurrentDials: 3, + PoolTimeout: time.Second, }) defer testPool.Close() @@ -333,8 +336,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { return &mockNetConn{addr: "original:6379"}, nil }, - PoolSize: int32(2), - PoolTimeout: time.Second, + PoolSize: int32(2), + MaxConcurrentDials: 2, + PoolTimeout: time.Second, }) defer testPool.Close() diff --git a/extra/redisotel/metrics.go b/extra/redisotel/metrics.go index 7fe55452..77aa5d14 100644 --- a/extra/redisotel/metrics.go +++ b/extra/redisotel/metrics.go @@ -2,6 +2,7 @@ package redisotel import ( "context" + "errors" "fmt" "net" "sync" @@ -271,9 +272,10 @@ func (mh *metricsHook) DialHook(hook redis.DialHook) redis.DialHook { dur := time.Since(start) - attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+1) + attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+2) attrs = append(attrs, mh.attrs...) attrs = append(attrs, statusAttr(err)) + attrs = append(attrs, errorTypeAttribute(err)) mh.createTime.Record(ctx, milliseconds(dur), metric.WithAttributeSet(attribute.NewSet(attrs...))) return conn, err @@ -288,10 +290,11 @@ func (mh *metricsHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { dur := time.Since(start) - attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+2) + attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+3) attrs = append(attrs, mh.attrs...) attrs = append(attrs, attribute.String("type", "command")) attrs = append(attrs, statusAttr(err)) + attrs = append(attrs, errorTypeAttribute(err)) mh.useTime.Record(ctx, milliseconds(dur), metric.WithAttributeSet(attribute.NewSet(attrs...))) @@ -309,10 +312,11 @@ func (mh *metricsHook) ProcessPipelineHook( dur := time.Since(start) - attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+2) + attrs := make([]attribute.KeyValue, 0, len(mh.attrs)+3) attrs = append(attrs, mh.attrs...) 
attrs = append(attrs, attribute.String("type", "pipeline")) attrs = append(attrs, statusAttr(err)) + attrs = append(attrs, errorTypeAttribute(err)) mh.useTime.Record(ctx, milliseconds(dur), metric.WithAttributeSet(attribute.NewSet(attrs...))) @@ -330,3 +334,16 @@ func statusAttr(err error) attribute.KeyValue { } return attribute.String("status", "ok") } + +func errorTypeAttribute(err error) attribute.KeyValue { + switch { + case err == nil: + return attribute.String("error_type", "none") + case errors.Is(err, context.Canceled): + return attribute.String("error_type", "context_canceled") + case errors.Is(err, context.DeadlineExceeded): + return attribute.String("error_type", "context_timeout") + default: + return attribute.String("error_type", "other") + } +} diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index fc37b821..b0ef72db 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -31,11 +31,12 @@ func BenchmarkPoolGetPut(b *testing.B) { for _, bm := range benchmarks { b.Run(bm.String(), func(b *testing.B) { connPool := pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(bm.poolSize), - PoolTimeout: time.Second, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: time.Hour, + Dialer: dummyDialer, + PoolSize: int32(bm.poolSize), + MaxConcurrentDials: bm.poolSize, + PoolTimeout: time.Second, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, }) b.ResetTimer() @@ -75,22 +76,23 @@ func BenchmarkPoolGetRemove(b *testing.B) { for _, bm := range benchmarks { b.Run(bm.String(), func(b *testing.B) { connPool := pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(bm.poolSize), - PoolTimeout: time.Second, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: time.Hour, + Dialer: dummyDialer, + PoolSize: int32(bm.poolSize), + MaxConcurrentDials: bm.poolSize, + PoolTimeout: time.Second, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Hour, }) b.ResetTimer() - + rmvErr := errors.New("Bench test remove") b.RunParallel(func(pb *testing.PB) { for pb.Next() { cn, err := connPool.Get(ctx) if err != nil { b.Fatal(err) } - connPool.Remove(ctx, cn, errors.New("Bench test remove")) + connPool.Remove(ctx, cn, rmvErr) } }) }) diff --git a/internal/pool/buffer_size_test.go b/internal/pool/buffer_size_test.go index 58db9574..278836ec 100644 --- a/internal/pool/buffer_size_test.go +++ b/internal/pool/buffer_size_test.go @@ -3,6 +3,7 @@ package pool_test import ( "bufio" "context" + "sync/atomic" "unsafe" . 
"github.com/bsm/ginkgo/v2" @@ -24,9 +25,10 @@ var _ = Describe("Buffer Size Configuration", func() { It("should use default buffer sizes when not specified", func() { connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(1), - PoolTimeout: 1000, + Dialer: dummyDialer, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: 1000, }) cn, err := connPool.NewConn(ctx) @@ -46,11 +48,12 @@ var _ = Describe("Buffer Size Configuration", func() { customWriteSize := 64 * 1024 // 64KB connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(1), - PoolTimeout: 1000, - ReadBufferSize: customReadSize, - WriteBufferSize: customWriteSize, + Dialer: dummyDialer, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: 1000, + ReadBufferSize: customReadSize, + WriteBufferSize: customWriteSize, }) cn, err := connPool.NewConn(ctx) @@ -67,11 +70,12 @@ var _ = Describe("Buffer Size Configuration", func() { It("should handle zero buffer sizes by using defaults", func() { connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(1), - PoolTimeout: 1000, - ReadBufferSize: 0, // Should use default - WriteBufferSize: 0, // Should use default + Dialer: dummyDialer, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: 1000, + ReadBufferSize: 0, // Should use default + WriteBufferSize: 0, // Should use default }) cn, err := connPool.NewConn(ctx) @@ -103,9 +107,10 @@ var _ = Describe("Buffer Size Configuration", func() { // Test the scenario where someone creates a pool directly (like in tests) // without setting ReadBufferSize and WriteBufferSize connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(1), - PoolTimeout: 1000, + Dialer: dummyDialer, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: 1000, // ReadBufferSize and WriteBufferSize are not set (will be 0) }) @@ -129,9 +134,10 @@ var _ = Describe("Buffer Size Configuration", func() { // cause runtime panics or incorrect memory access due to invalid pointer dereferencing. func getWriterBufSizeUnsafe(cn *pool.Conn) int { cnPtr := (*struct { - id uint64 // First field in pool.Conn - usedAt int64 // Second field (atomic) - netConnAtomic interface{} // atomic.Value (interface{} has same size) + id uint64 // First field in pool.Conn + usedAt atomic.Int64 // Second field (atomic) + lastPutAt atomic.Int64 // Third field (atomic) + netConnAtomic interface{} // atomic.Value (interface{} has same size) rd *proto.Reader bw *bufio.Writer wr *proto.Writer @@ -155,9 +161,10 @@ func getWriterBufSizeUnsafe(cn *pool.Conn) int { func getReaderBufSizeUnsafe(cn *pool.Conn) int { cnPtr := (*struct { - id uint64 // First field in pool.Conn - usedAt int64 // Second field (atomic) - netConnAtomic interface{} // atomic.Value (interface{} has same size) + id uint64 // First field in pool.Conn + usedAt atomic.Int64 // Second field (atomic) + lastPutAt atomic.Int64 // Third field (atomic) + netConnAtomic interface{} // atomic.Value (interface{} has same size) rd *proto.Reader bw *bufio.Writer wr *proto.Writer diff --git a/internal/pool/conn.go b/internal/pool/conn.go index e504dfbc..57fbfe17 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -55,6 +55,43 @@ func GetCachedTimeNs() int64 { return getCachedTimeNs() } +// Global time cache updated every 50ms by background goroutine. +// This avoids expensive time.Now() syscalls in hot paths like getEffectiveReadTimeout. 
+// Max staleness: 50ms, which is acceptable for timeout deadline checks (timeouts are typically 3-30 seconds). +var globalTimeCache struct { + nowNs atomic.Int64 +} + +func init() { + // Initialize immediately + globalTimeCache.nowNs.Store(time.Now().UnixNano()) + + // Start background updater + go func() { + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for range ticker.C { + globalTimeCache.nowNs.Store(time.Now().UnixNano()) + } + }() +} + +// getCachedTimeNs returns the current time in nanoseconds from the global cache. +// This is updated every 50ms by a background goroutine, avoiding expensive syscalls. +// Max staleness: 50ms. +func getCachedTimeNs() int64 { + return globalTimeCache.nowNs.Load() +} + +// GetCachedTimeNs returns the current time in nanoseconds from the global cache. +// This is updated every 50ms by a background goroutine, avoiding expensive syscalls. +// Max staleness: 50ms. +// Exported for use by other packages that need fast time access. +func GetCachedTimeNs() int64 { + return getCachedTimeNs() +} + // Global atomic counter for connection IDs var connIDCounter uint64 @@ -81,7 +118,8 @@ type Conn struct { // Connection identifier for unique tracking id uint64 - usedAt int64 // atomic + usedAt atomic.Int64 + lastPutAt atomic.Int64 // Lock-free netConn access using atomic.Value // Contains *atomicNetConn wrapper, accessed atomically for better performance @@ -175,15 +213,24 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con } func (cn *Conn) UsedAt() time.Time { - unixNano := atomic.LoadInt64(&cn.usedAt) - return time.Unix(0, unixNano) + return time.Unix(0, cn.usedAt.Load()) } -func (cn *Conn) UsedAtNs() int64 { - return atomic.LoadInt64(&cn.usedAt) +func (cn *Conn) SetUsedAt(tm time.Time) { + cn.usedAt.Store(tm.UnixNano()) } -func (cn *Conn) SetUsedAt(tm time.Time) { - atomic.StoreInt64(&cn.usedAt, tm.UnixNano()) +func (cn *Conn) UsedAtNs() int64 { + return cn.usedAt.Load() +} +func (cn *Conn) SetUsedAtNs(ns int64) { + cn.usedAt.Store(ns) +} + +func (cn *Conn) LastPutAtNs() int64 { + return cn.lastPutAt.Load() +} +func (cn *Conn) SetLastPutAtNs(ns int64) { + cn.lastPutAt.Store(ns) } // Backward-compatible wrapper methods for state machine @@ -499,7 +546,7 @@ func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Durati return time.Duration(readTimeoutNs) } - // Use cached time to avoid expensive syscall (max 100ms staleness is acceptable for timeout checks) + // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) nowNs := getCachedTimeNs() // Check if deadline has passed if nowNs < deadlineNs { @@ -533,7 +580,7 @@ func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Durat return time.Duration(writeTimeoutNs) } - // Use cached time to avoid expensive syscall (max 100ms staleness is acceptable for timeout checks) + // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) nowNs := getCachedTimeNs() // Check if deadline has passed if nowNs < deadlineNs { @@ -725,7 +772,7 @@ func (cn *Conn) GetStateMachine() *ConnStateMachine { func (cn *Conn) TryAcquire() bool { // The || operator short-circuits, so only 1 CAS in the common case return cn.stateMachine.state.CompareAndSwap(uint32(StateIdle), uint32(StateInUse)) || - cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateCreated)) + cn.stateMachine.state.Load() == uint32(StateCreated) } // Release releases the 
connection back to the pool. @@ -829,19 +876,18 @@ func (cn *Conn) WithWriter( // Use relaxed timeout if set, otherwise use provided timeout effectiveTimeout := cn.getEffectiveWriteTimeout(timeout) - // Always set write deadline, even if getNetConn() returns nil - // This prevents write operations from hanging indefinitely + // Set write deadline on the connection if netConn := cn.getNetConn(); netConn != nil { if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil { return err } } else { - // If getNetConn() returns nil, we still need to respect the timeout - // Return an error to prevent indefinite blocking + // Connection is not available - return error return fmt.Errorf("redis: conn[%d] not available for write operation", cn.GetID()) } } + // Reset the buffered writer if it still holds data (should not normally happen) if cn.bw.Buffered() > 0 { if netConn := cn.getNetConn(); netConn != nil { cn.bw.Reset(netConn) } } @@ -890,11 +936,12 @@ func (cn *Conn) MaybeHasData() bool { // deadline computes the effective deadline time based on context and timeout. // It updates the usedAt timestamp to now. -// Uses cached time to avoid expensive syscall (max 100ms staleness is acceptable for deadline calculation). +// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation). func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { // Use cached time for deadline calculation (called 2x per command: read + write) - tm := time.Unix(0, getCachedTimeNs()) - cn.SetUsedAt(tm) + nowNs := getCachedTimeNs() + cn.SetUsedAtNs(nowNs) + tm := time.Unix(0, nowNs) if timeout > 0 { tm = tm.Add(timeout) diff --git a/internal/pool/conn_used_at_test.go b/internal/pool/conn_used_at_test.go index 74b447f2..d6dd27a0 100644 --- a/internal/pool/conn_used_at_test.go +++ b/internal/pool/conn_used_at_test.go @@ -22,7 +22,7 @@ func TestConn_UsedAtUpdatedOnRead(t *testing.T) { // Get initial usedAt time initialUsedAt := cn.UsedAt() - // Wait at least 100ms to ensure time difference (usedAt has ~50ms precision from cached time) + // Wait 100ms to ensure a measurable time difference (usedAt has ~50ms precision from cached time) time.Sleep(100 * time.Millisecond) // Simulate a read operation by calling WithReader @@ -45,10 +45,10 @@ func TestConn_UsedAtUpdatedOnRead(t *testing.T) { initialUsedAt, updatedUsedAt) } - // Verify the difference is reasonable (should be around 100ms, accounting for ~50ms cache precision) + // Verify the difference is reasonable (should be around the 100ms sleep, accounting for ~50ms cache precision) diff := updatedUsedAt.Sub(initialUsedAt) if diff < 50*time.Millisecond || diff > 200*time.Millisecond { - t.Errorf("Expected usedAt difference to be around 100ms (±50ms for cache), got %v", diff) + t.Errorf("Expected usedAt difference to be roughly 100ms (±50ms cache precision), got %v", diff) } } @@ -90,8 +90,10 @@ func TestConn_UsedAtUpdatedOnWrite(t *testing.T) { // Verify the difference is reasonable (should be around 100ms, accounting for ~50ms cache precision) diff := updatedUsedAt.Sub(initialUsedAt) - if diff < 50*time.Millisecond || diff > 200*time.Millisecond { - t.Errorf("Expected usedAt difference to be around 100ms (±50ms for cache), got %v", diff) + + // 50ms cache precision plus a few ms of sleep jitter, so allow 45ms to 155ms + if diff < 45*time.Millisecond || diff > 155*time.Millisecond { + t.Errorf("Expected usedAt difference to be around 100ms (±50ms for cache, ±5ms for sleep precision), got %v", diff) + } } diff --git a/internal/pool/hooks_test.go
b/internal/pool/hooks_test.go index b8f504df..f4be12a3 100644 --- a/internal/pool/hooks_test.go +++ b/internal/pool/hooks_test.go @@ -191,8 +191,9 @@ func TestPoolWithHooks(t *testing.T) { Dialer: func(ctx context.Context) (net.Conn, error) { return &net.TCPConn{}, nil // Mock connection }, - PoolSize: 1, - DialTimeout: time.Second, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: time.Second, } pool := NewConnPool(opt) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index dcb6213d..d1676499 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -75,12 +75,6 @@ type Pooler interface { Put(context.Context, *Conn) Remove(context.Context, *Conn, error) - // RemoveWithoutTurn removes a connection from the pool without freeing a turn. - // This should be used when removing a connection from a context that didn't acquire - // a turn via Get() (e.g., background workers, cleanup tasks). - // For normal removal after Get(), use Remove() instead. - RemoveWithoutTurn(context.Context, *Conn, error) - Len() int IdleLen() int Stats() *Stats @@ -92,6 +86,12 @@ type Pooler interface { AddPoolHook(hook PoolHook) RemovePoolHook(hook PoolHook) + // RemoveWithoutTurn removes a connection from the pool without freeing a turn. + // This should be used when removing a connection from a context that didn't acquire + // a turn via Get() (e.g., background workers, cleanup tasks). + // For normal removal after Get(), use Remove() instead. + RemoveWithoutTurn(context.Context, *Conn, error) + Close() error } @@ -102,6 +102,7 @@ type Options struct { PoolFIFO bool PoolSize int32 + MaxConcurrentDials int DialTimeout time.Duration PoolTimeout time.Duration MinIdleConns int32 @@ -130,6 +131,9 @@ type ConnPool struct { dialErrorsNum uint32 // atomic lastDialError atomic.Value + queue chan struct{} + dialsInProgress chan struct{} + dialsQueue *wantConnQueue // Fast atomic semaphore for connection limiting // Replaces the old channel-based queue for better performance semaphore *internal.FastSemaphore @@ -165,10 +169,13 @@ func NewConnPool(opt *Options) *ConnPool { //semSize = opt.PoolSize p := &ConnPool{ - cfg: opt, - semaphore: internal.NewFastSemaphore(semSize), - conns: make(map[uint64]*Conn), - idleConns: make([]*Conn, 0, opt.PoolSize), + cfg: opt, + semaphore: internal.NewFastSemaphore(semSize), + queue: make(chan struct{}, opt.PoolSize), + conns: make(map[uint64]*Conn), + dialsInProgress: make(chan struct{}, opt.MaxConcurrentDials), + dialsQueue: newWantConnQueue(), + idleConns: make([]*Conn, 0, opt.PoolSize), } // Only create MinIdleConns if explicitly requested (> 0) @@ -461,7 +468,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } // Use cached time for health checks (max 50ms staleness is acceptable) - now := time.Unix(0, getCachedTimeNs()) + nowNs := getCachedTimeNs() attempts := 0 // Lock-free atomic read - no mutex overhead! 
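For reference, the cached-clock pattern behind getCachedTimeNs boils down to the following standalone sketch, assuming nothing beyond the standard library (coarseClock and startCoarseClock are illustrative names, not go-redis API); it shows why health checks and deadline math in this hunk can become plain int64 comparisons with at most ~50ms of staleness:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// coarseClock caches time.Now().UnixNano() and refreshes it on a fixed
// interval, so hot paths can read "now" with a single atomic load instead
// of a syscall. Readers tolerate up to one interval of staleness.
type coarseClock struct {
	nowNs atomic.Int64
}

func startCoarseClock(interval time.Duration) *coarseClock {
	c := &coarseClock{}
	c.nowNs.Store(time.Now().UnixNano())
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for range ticker.C {
			c.nowNs.Store(time.Now().UnixNano())
		}
	}()
	return c
}

// NowNs returns the cached timestamp, at most one interval behind real time.
func (c *coarseClock) NowNs() int64 { return c.nowNs.Load() }

func main() {
	clock := startCoarseClock(50 * time.Millisecond)

	// An idleness check in the style of the pool's health check: pure
	// integer math on nanoseconds, no time.Time allocation or syscall.
	lastUsedNs := clock.NowNs()
	time.Sleep(120 * time.Millisecond)
	maxIdle := 100 * time.Millisecond
	idleTooLong := clock.NowNs()-lastUsedNs >= int64(maxIdle)
	fmt.Println("idle too long:", idleTooLong) // true, within ~50ms of precision
}

The trade-off is the one the patch's comments spell out: readings may lag real time by up to one refresh interval, which is negligible against multi-second timeouts and idle limits.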
@@ -487,7 +494,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { break } - if !p.isHealthyConn(cn, now) { + if !p.isHealthyConn(cn, nowNs) { _ = p.CloseConn(cn) continue } @@ -517,9 +524,8 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { atomic.AddUint32(&p.stats.Misses, 1) - newcn, err := p.newConn(ctx, true) + newcn, err := p.queuedNewConn(ctx) if err != nil { - p.freeTurn() return nil, err } @@ -538,6 +544,97 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { return newcn, nil } +func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) { + select { + case p.dialsInProgress <- struct{}{}: + // Got permission, proceed to create connection + case <-ctx.Done(): + p.freeTurn() + return nil, ctx.Err() + } + + dialCtx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout) + + w := &wantConn{ + ctx: dialCtx, + cancelCtx: cancel, + result: make(chan wantConnResult, 1), + } + var err error + defer func() { + if err != nil { + if cn := w.cancel(); cn != nil { + p.putIdleConn(ctx, cn) + p.freeTurn() + } + } + }() + + p.dialsQueue.enqueue(w) + + go func(w *wantConn) { + var freeTurnCalled bool + defer func() { + if err := recover(); err != nil { + if !freeTurnCalled { + p.freeTurn() + } + internal.Logger.Printf(context.Background(), "queuedNewConn panic: %+v", err) + } + }() + + defer w.cancelCtx() + defer func() { <-p.dialsInProgress }() // Release connection creation permission + + dialCtx := w.getCtxForDial() + cn, cnErr := p.newConn(dialCtx, true) + delivered := w.tryDeliver(cn, cnErr) + if cnErr == nil && delivered { + return + } else if cnErr == nil && !delivered { + p.putIdleConn(dialCtx, cn) + p.freeTurn() + freeTurnCalled = true + } else { + p.freeTurn() + freeTurnCalled = true + } + }(w) + + select { + case <-ctx.Done(): + err = ctx.Err() + return nil, err + case result := <-w.result: + err = result.err + return result.cn, err + } +} + +func (p *ConnPool) putIdleConn(ctx context.Context, cn *Conn) { + for { + w, ok := p.dialsQueue.dequeue() + if !ok { + break + } + if w.tryDeliver(cn, nil) { + return + } + } + + p.connsMu.Lock() + defer p.connsMu.Unlock() + + if p.closed() { + _ = cn.Close() + return + } + + // poolSize is increased in newConn + p.idleConns = append(p.idleConns, cn) + p.idleConnsLen.Add(1) +} + func (p *ConnPool) waitTurn(ctx context.Context) error { // Fast path: check context first select { @@ -742,6 +839,8 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { if shouldCloseConn { _ = p.closeConn(cn) } + + cn.SetLastPutAtNs(getCachedTimeNs()) } func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { @@ -798,8 +897,7 @@ func (p *ConnPool) removeConn(cn *Conn) { p.poolSize.Add(-1) // this can be idle conn for idx, ic := range p.idleConns { - if ic.GetID() == cid { - internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid) + if ic == cn { p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...) p.idleConnsLen.Add(-1) break @@ -891,14 +989,14 @@ func (p *ConnPool) Close() error { return firstErr } -func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { +func (p *ConnPool) isHealthyConn(cn *Conn, nowNs int64) bool { // Performance optimization: check conditions from cheapest to most expensive, // and from most likely to fail to least likely to fail. // Only fails if ConnMaxLifetime is set AND connection is old. // Most pools don't set ConnMaxLifetime, so this rarely fails. 
if p.cfg.ConnMaxLifetime > 0 { - if cn.expiresAt.Before(now) { + if cn.expiresAt.UnixNano() < nowNs { return false // Connection has exceeded max lifetime } } @@ -906,7 +1004,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { // Most pools set ConnMaxIdleTime, and idle connections are common. // Checking this first allows us to fail fast without expensive syscalls. if p.cfg.ConnMaxIdleTime > 0 { - if now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { + if nowNs-cn.UsedAtNs() >= int64(p.cfg.ConnMaxIdleTime) { return false // Connection has been idle too long } } @@ -926,7 +1024,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { ) // Update timestamp for healthy connection - cn.SetUsedAt(now) + cn.SetUsedAtNs(nowNs) // Connection is healthy, client will handle notifications return true @@ -939,6 +1037,6 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { } // Only update UsedAt if connection is healthy (avoids unnecessary atomic store) - cn.SetUsedAt(now) + cn.SetUsedAtNs(nowNs) return true } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 6aa6dc09..680370a7 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -3,6 +3,7 @@ package pool_test import ( "context" "errors" + "fmt" "net" "sync" "sync/atomic" @@ -21,11 +22,12 @@ var _ = Describe("ConnPool", func() { BeforeEach(func() { connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(10), - PoolTimeout: time.Hour, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: time.Millisecond, + Dialer: dummyDialer, + PoolSize: int32(10), + MaxConcurrentDials: 10, + PoolTimeout: time.Hour, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Millisecond, }) }) @@ -47,17 +49,18 @@ var _ = Describe("ConnPool", func() { <-closedChan return &net.TCPConn{}, nil }, - PoolSize: int32(10), - PoolTimeout: time.Hour, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: time.Millisecond, - MinIdleConns: int32(minIdleConns), + PoolSize: int32(10), + MaxConcurrentDials: 10, + PoolTimeout: time.Hour, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Millisecond, + MinIdleConns: int32(minIdleConns), }) wg.Wait() Expect(connPool.Close()).NotTo(HaveOccurred()) close(closedChan) - // We wait for 1 second and believe that checkMinIdleConns has been executed. + // We wait for 1 second and believe that checkIdleConns has been executed. 
time.Sleep(time.Second) Expect(connPool.Stats()).To(Equal(&pool.Stats{ @@ -131,12 +134,13 @@ var _ = Describe("MinIdleConns", func() { newConnPool := func() *pool.ConnPool { connPool := pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(poolSize), - MinIdleConns: int32(minIdleConns), - PoolTimeout: 100 * time.Millisecond, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: -1, + Dialer: dummyDialer, + PoolSize: int32(poolSize), + MaxConcurrentDials: poolSize, + MinIdleConns: int32(minIdleConns), + PoolTimeout: 100 * time.Millisecond, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: -1, }) Eventually(func() int { return connPool.Len() @@ -310,11 +314,12 @@ var _ = Describe("race", func() { It("does not happen on Get, Put, and Remove", func() { connPool = pool.NewConnPool(&pool.Options{ - Dialer: dummyDialer, - PoolSize: int32(10), - PoolTimeout: time.Minute, - DialTimeout: 1 * time.Second, - ConnMaxIdleTime: time.Millisecond, + Dialer: dummyDialer, + PoolSize: int32(10), + MaxConcurrentDials: 10, + PoolTimeout: time.Minute, + DialTimeout: 1 * time.Second, + ConnMaxIdleTime: time.Millisecond, }) perform(C, func(id int) { @@ -341,10 +346,11 @@ var _ = Describe("race", func() { Dialer: func(ctx context.Context) (net.Conn, error) { return &net.TCPConn{}, nil }, - PoolSize: int32(1000), - MinIdleConns: int32(50), - PoolTimeout: 3 * time.Second, - DialTimeout: 1 * time.Second, + PoolSize: int32(1000), + MaxConcurrentDials: 1000, + MinIdleConns: int32(50), + PoolTimeout: 3 * time.Second, + DialTimeout: 1 * time.Second, } p := pool.NewConnPool(opt) @@ -368,8 +374,9 @@ var _ = Describe("race", func() { Dialer: func(ctx context.Context) (net.Conn, error) { panic("test panic") }, - PoolSize: int32(100), - MinIdleConns: int32(30), + PoolSize: int32(100), + MaxConcurrentDials: 100, + MinIdleConns: int32(30), } p := pool.NewConnPool(opt) @@ -386,8 +393,9 @@ var _ = Describe("race", func() { Dialer: func(ctx context.Context) (net.Conn, error) { return &net.TCPConn{}, nil }, - PoolSize: int32(1), - PoolTimeout: 3 * time.Second, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: 3 * time.Second, } p := pool.NewConnPool(opt) @@ -417,8 +425,9 @@ var _ = Describe("race", func() { return &net.TCPConn{}, nil }, - PoolSize: int32(1), - PoolTimeout: testPoolTimeout, + PoolSize: int32(1), + MaxConcurrentDials: 1, + PoolTimeout: testPoolTimeout, } p := pool.NewConnPool(opt) @@ -452,6 +461,7 @@ func TestDialerRetryConfiguration(t *testing.T) { connPool := pool.NewConnPool(&pool.Options{ Dialer: failingDialer, PoolSize: 1, + MaxConcurrentDials: 1, PoolTimeout: time.Second, DialTimeout: time.Second, DialerRetries: 3, // Custom retry count @@ -483,10 +493,11 @@ func TestDialerRetryConfiguration(t *testing.T) { } connPool := pool.NewConnPool(&pool.Options{ - Dialer: failingDialer, - PoolSize: 1, - PoolTimeout: time.Second, - DialTimeout: time.Second, + Dialer: failingDialer, + PoolSize: 1, + MaxConcurrentDials: 1, + PoolTimeout: time.Second, + DialTimeout: time.Second, // DialerRetries and DialerRetryTimeout not set - should use defaults }) defer connPool.Close() @@ -509,6 +520,525 @@ func TestDialerRetryConfiguration(t *testing.T) { }) } +var _ = Describe("queuedNewConn", func() { + ctx := context.Background() + + It("should successfully create connection when pool is exhausted", func() { + testPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: 1, + MaxConcurrentDials: 2, + DialTimeout: 1 * time.Second, + PoolTimeout: 2 * time.Second, + }) + defer 
testPool.Close() + + // Fill the pool + conn1, err := testPool.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + Expect(conn1).NotTo(BeNil()) + + // Get second connection in another goroutine + done := make(chan struct{}) + var conn2 *pool.Conn + var err2 error + + go func() { + defer GinkgoRecover() + conn2, err2 = testPool.Get(ctx) + close(done) + }() + + // Wait a bit to let the second Get start waiting + time.Sleep(100 * time.Millisecond) + + // Release first connection to let second Get acquire Turn + testPool.Put(ctx, conn1) + + // Wait for second Get to complete + <-done + Expect(err2).NotTo(HaveOccurred()) + Expect(conn2).NotTo(BeNil()) + + // Clean up second connection + testPool.Put(ctx, conn2) + }) + + It("should handle context cancellation before acquiring dialsInProgress", func() { + slowDialer := func(ctx context.Context) (net.Conn, error) { + // Simulate slow dialing to let first connection creation occupy dialsInProgress + time.Sleep(200 * time.Millisecond) + return newDummyConn(), nil + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: slowDialer, + PoolSize: 2, + MaxConcurrentDials: 1, // Limit to 1 so second request cannot get dialsInProgress permission + DialTimeout: 1 * time.Second, + PoolTimeout: 1 * time.Second, + }) + defer testPool.Close() + + // Start first connection creation, this will occupy dialsInProgress + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + conn1, err := testPool.Get(ctx) + if err == nil { + defer testPool.Put(ctx, conn1) + } + close(done1) + }() + + // Wait a bit to ensure first request starts and occupies dialsInProgress + time.Sleep(50 * time.Millisecond) + + // Create a context that will be cancelled quickly + cancelCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + + // Second request should timeout while waiting for dialsInProgress + _, err := testPool.Get(cancelCtx) + Expect(err).To(Equal(context.DeadlineExceeded)) + + // Wait for first request to complete + <-done1 + + // Verify all turns are released after requests complete + Eventually(func() int { + return testPool.QueueLen() + }, "1s", "50ms").Should(Equal(0), "All turns should be released after requests complete") + }) + + It("should handle context cancellation while waiting for connection result", func() { + // This test focuses on proper error handling when context is cancelled + // during queuedNewConn execution (not testing connection reuse) + + slowDialer := func(ctx context.Context) (net.Conn, error) { + // Simulate slow dialing + time.Sleep(500 * time.Millisecond) + return newDummyConn(), nil + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: slowDialer, + PoolSize: 1, + MaxConcurrentDials: 2, + DialTimeout: 2 * time.Second, + PoolTimeout: 2 * time.Second, + }) + defer testPool.Close() + + // Get first connection to fill the pool + conn1, err := testPool.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + + // Create a context that will be cancelled during connection creation + cancelCtx, cancel := context.WithTimeout(ctx, 200*time.Millisecond) + defer cancel() + + // This request should timeout while waiting for connection creation result + // Testing the error handling path in queuedNewConn select statement + done := make(chan struct{}) + var err2 error + go func() { + defer GinkgoRecover() + _, err2 = testPool.Get(cancelCtx) + close(done) + }() + + <-done + Expect(err2).To(Equal(context.DeadlineExceeded)) + + // Verify turn state - background goroutine may still hold turn + // Note: Background connection 
creation will complete and release turn + Eventually(func() int { + return testPool.QueueLen() + }, "1s", "50ms").Should(Equal(1), "Only conn1's turn should be held") + + // Clean up - release the first connection + testPool.Put(ctx, conn1) + + // Verify all turns are released after cleanup + Eventually(func() int { + return testPool.QueueLen() + }, "1s", "50ms").Should(Equal(0), "All turns should be released after cleanup") + }) + + It("should handle dial failures gracefully", func() { + alwaysFailDialer := func(ctx context.Context) (net.Conn, error) { + return nil, fmt.Errorf("dial failed") + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: alwaysFailDialer, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 1 * time.Second, + PoolTimeout: 1 * time.Second, + }) + defer testPool.Close() + + // This call should fail, testing error handling branch in goroutine + _, err := testPool.Get(ctx) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("dial failed")) + + // Verify turn is released after dial failure + Eventually(func() int { + return testPool.QueueLen() + }, "1s", "50ms").Should(Equal(0), "Turn should be released after dial failure") + }) + + It("should handle connection creation success with normal delivery", func() { + // This test verifies normal case where connection creation and delivery both succeed + testPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: 1, + MaxConcurrentDials: 2, + DialTimeout: 1 * time.Second, + PoolTimeout: 2 * time.Second, + }) + defer testPool.Close() + + // Get first connection + conn1, err := testPool.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + + // Get second connection in another goroutine + done := make(chan struct{}) + var conn2 *pool.Conn + var err2 error + + go func() { + defer GinkgoRecover() + conn2, err2 = testPool.Get(ctx) + close(done) + }() + + // Wait a bit to let second Get start waiting + time.Sleep(100 * time.Millisecond) + + // Release first connection + testPool.Put(ctx, conn1) + + // Wait for second Get to complete + <-done + Expect(err2).NotTo(HaveOccurred()) + Expect(conn2).NotTo(BeNil()) + + // Clean up second connection + testPool.Put(ctx, conn2) + }) + + It("should handle MaxConcurrentDials limit", func() { + testPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: 3, + MaxConcurrentDials: 1, // Only allow 1 concurrent dial + DialTimeout: 1 * time.Second, + PoolTimeout: 1 * time.Second, + }) + defer testPool.Close() + + // Get all connections to fill the pool + var conns []*pool.Conn + for i := 0; i < 3; i++ { + conn, err := testPool.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + conns = append(conns, conn) + } + + // Now pool is full, next request needs to create new connection + // But due to MaxConcurrentDials=1, only one concurrent dial is allowed + done := make(chan struct{}) + var err4 error + go func() { + defer GinkgoRecover() + _, err4 = testPool.Get(ctx) + close(done) + }() + + // Release one connection to let the request complete + time.Sleep(100 * time.Millisecond) + testPool.Put(ctx, conns[0]) + + <-done + Expect(err4).NotTo(HaveOccurred()) + + // Clean up remaining connections + for i := 1; i < len(conns); i++ { + testPool.Put(ctx, conns[i]) + } + }) + + It("should reuse connections created in background after request timeout", func() { + // This test focuses on connection reuse mechanism: + // When a request times out but background connection creation succeeds, + // the created connection should be added to pool for future reuse 
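The reuse behaviour this test describes comes from the wantConn/dialsQueue hand-off introduced in pool.go and want_conn.go further down: the dialing goroutine tries to deliver its result to the waiter, and if the waiter has already timed out, the fresh connection is parked as idle instead of being discarded. Below is a deliberately simplified, standalone sketch of that hand-off; waiter, tryDeliver, and the string stand-in for a connection are illustrative, not the pool's actual types.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// waiter is a stripped-down analogue of the patch's wantConn: one waiting
// Get() call, a buffered one-slot result channel, and a done flag so that
// delivery and cancellation race safely.
type waiter struct {
	mu     sync.Mutex
	done   bool
	result chan string // stands in for wantConnResult{cn, err}
}

func newWaiter() *waiter { return &waiter{result: make(chan string, 1)} }

// tryDeliver hands the dialed "connection" to the waiter unless it has
// already given up; it reports whether the delivery was accepted.
func (w *waiter) tryDeliver(conn string) bool {
	w.mu.Lock()
	defer w.mu.Unlock()
	if w.done {
		return false
	}
	w.done = true
	w.result <- conn
	close(w.result)
	return true
}

// cancel marks the waiter as abandoned so a later delivery is rejected.
// (The real wantConn.cancel additionally recovers an already-delivered
// connection; that is omitted here for brevity.)
func (w *waiter) cancel() {
	w.mu.Lock()
	defer w.mu.Unlock()
	if !w.done {
		w.done = true
		close(w.result)
	}
}

func main() {
	idleCh := make(chan string, 1) // stand-in for the pool's idle list

	w := newWaiter()

	// Background dial, as in queuedNewConn: it finishes after 200ms.
	go func() {
		time.Sleep(200 * time.Millisecond)
		conn := "conn-1"
		if !w.tryDeliver(conn) {
			// The waiter timed out first: keep the connection for reuse,
			// mirroring the putIdleConn branch in the patch.
			idleCh <- conn
		}
	}()

	// The caller only waits 100ms, so it gives up before the dial finishes.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	select {
	case conn := <-w.result:
		fmt.Println("got", conn)
	case <-ctx.Done():
		w.cancel()
		fmt.Println("gave up:", ctx.Err())
	}

	// The abandoned dial's connection ends up parked for the next caller.
	select {
	case cn := <-idleCh:
		fmt.Println("parked for reuse:", cn)
	case <-time.After(300 * time.Millisecond):
		fmt.Println("connection was delivered directly instead")
	}
}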
+ + slowDialer := func(ctx context.Context) (net.Conn, error) { + // Simulate delay for connection creation + time.Sleep(100 * time.Millisecond) + return newDummyConn(), nil + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: slowDialer, + PoolSize: 1, + MaxConcurrentDials: 1, + DialTimeout: 1 * time.Second, + PoolTimeout: 150 * time.Millisecond, // Short timeout for waiting Turn + }) + defer testPool.Close() + + // Fill the pool with one connection + conn1, err := testPool.Get(ctx) + Expect(err).NotTo(HaveOccurred()) + // Don't put it back yet, so pool is full + + // Start a goroutine that will create a new connection but take time + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done1) + // This will trigger queuedNewConn since pool is full + conn, err := testPool.Get(ctx) + if err == nil { + // Put connection back to pool after creation + time.Sleep(50 * time.Millisecond) + testPool.Put(ctx, conn) + } + }() + + // Wait a bit to let the goroutine start and begin connection creation + time.Sleep(50 * time.Millisecond) + + // Now make a request that should timeout waiting for Turn + start := time.Now() + _, err = testPool.Get(ctx) + duration := time.Since(start) + + Expect(err).To(Equal(pool.ErrPoolTimeout)) + // Should timeout around PoolTimeout + Expect(duration).To(BeNumerically("~", 150*time.Millisecond, 50*time.Millisecond)) + + // Release the first connection to allow the background creation to complete + testPool.Put(ctx, conn1) + + // Wait for background connection creation to complete + <-done1 + time.Sleep(100 * time.Millisecond) + + // CORE TEST: Verify connection reuse mechanism + // The connection created in background should now be available in pool + start = time.Now() + conn3, err := testPool.Get(ctx) + duration = time.Since(start) + + Expect(err).NotTo(HaveOccurred()) + Expect(conn3).NotTo(BeNil()) + // Should be fast since connection is from pool (not newly created) + Expect(duration).To(BeNumerically("<", 50*time.Millisecond)) + + testPool.Put(ctx, conn3) + }) + + It("recover queuedNewConn panic", func() { + opt := &pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + panic("test panic in queuedNewConn") + }, + PoolSize: int32(10), + MaxConcurrentDials: 10, + DialTimeout: 1 * time.Second, + PoolTimeout: 1 * time.Second, + } + testPool := pool.NewConnPool(opt) + defer testPool.Close() + + // Trigger queuedNewConn - calling Get() on empty pool will trigger it + // Since dialer will panic, it should be handled by recover + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Try to get connections multiple times, each will trigger panic but should be properly recovered + for i := 0; i < 3; i++ { + conn, err := testPool.Get(ctx) + // Connection should be nil, error should exist (panic converted to error) + Expect(conn).To(BeNil()) + Expect(err).To(HaveOccurred()) + } + + // Verify state after panic recovery: + // - turn should be properly released (QueueLen() == 0) + // - connection counts should be correct (TotalConns == 0, IdleConns == 0) + Eventually(func() bool { + stats := testPool.Stats() + queueLen := testPool.QueueLen() + return stats.TotalConns == 0 && stats.IdleConns == 0 && queueLen == 0 + }, "3s", "50ms").Should(BeTrue()) + }) + + It("should handle connection creation success but delivery failure (putIdleConn path)", func() { + // This test covers the most important untested branch in queuedNewConn: + // cnErr == nil && !delivered -> putIdleConn() + + // Use 
slow dialer to ensure request times out before connection is ready + slowDialer := func(ctx context.Context) (net.Conn, error) { + // Delay long enough for client request to timeout first + time.Sleep(300 * time.Millisecond) + return newDummyConn(), nil + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: slowDialer, + PoolSize: 1, + MaxConcurrentDials: 2, + DialTimeout: 500 * time.Millisecond, // Long enough for dialer to complete + PoolTimeout: 100 * time.Millisecond, // Client requests will timeout quickly + }) + defer testPool.Close() + + // Record initial idle connection count + initialIdleConns := testPool.Stats().IdleConns + + // Make a request that will timeout + // This request will start queuedNewConn, create connection, but fail to deliver due to timeout + shortCtx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + defer cancel() + + conn, err := testPool.Get(shortCtx) + + // Request should fail due to timeout + Expect(err).To(HaveOccurred()) + Expect(conn).To(BeNil()) + + // However, background queuedNewConn should continue and complete connection creation + // Since it cannot deliver (request timed out), it should call putIdleConn to add connection to idle pool + Eventually(func() bool { + stats := testPool.Stats() + return stats.IdleConns > initialIdleConns + }, "1s", "50ms").Should(BeTrue()) + + // Verify the connection can indeed be used by subsequent requests + conn2, err2 := testPool.Get(context.Background()) + Expect(err2).NotTo(HaveOccurred()) + Expect(conn2).NotTo(BeNil()) + Expect(conn2.IsUsable()).To(BeTrue()) + + // Cleanup + testPool.Put(context.Background(), conn2) + + // Verify turn is released after putIdleConn path completes + // This is critical: ensures freeTurn() was called in the putIdleConn branch + Eventually(func() int { + return testPool.QueueLen() + }, "1s", "50ms").Should(Equal(0), + "Turn should be released after putIdleConn path completes") + }) + + It("should not leak turn when delivering connection via putIdleConn", func() { + // This test verifies that freeTurn() is called when putIdleConn successfully + // delivers a connection to another waiting request + // + // Scenario: + // 1. Request A: timeout 150ms, connection creation takes 200ms + // 2. Request B: timeout 500ms, connection creation takes 400ms + // 3. Both requests enter dialsQueue and start async connection creation + // 4. Request A times out at 150ms + // 5. Request A's connection completes at 200ms + // 6. putIdleConn delivers Request A's connection to Request B + // 7. queuedNewConn must call freeTurn() + // 8. 
Check: QueueLen should be 1 (only B holding turn), not 2 (A's turn leaked) + + callCount := int32(0) + + controlledDialer := func(ctx context.Context) (net.Conn, error) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + // Request A's connection: takes 200ms + time.Sleep(200 * time.Millisecond) + } else { + // Request B's connection: takes 400ms (longer, so A's connection is used) + time.Sleep(400 * time.Millisecond) + } + return newDummyConn(), nil + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: controlledDialer, + PoolSize: 2, // Allows both requests to get turns + MaxConcurrentDials: 2, // Allows both connections to be created simultaneously + DialTimeout: 500 * time.Millisecond, + PoolTimeout: 1 * time.Second, + }) + defer testPool.Close() + + // Verify initial state + Expect(testPool.QueueLen()).To(Equal(0)) + + // Request A: Short timeout (150ms), connection takes 200ms + reqADone := make(chan error, 1) + go func() { + defer GinkgoRecover() + shortCtx, cancel := context.WithTimeout(ctx, 150*time.Millisecond) + defer cancel() + _, err := testPool.Get(shortCtx) + reqADone <- err + }() + + // Wait for Request A to acquire turn and enter dialsQueue + time.Sleep(50 * time.Millisecond) + Expect(testPool.QueueLen()).To(Equal(1), "Request A should occupy turn") + + // Request B: Long timeout (500ms), will receive Request A's connection + reqBDone := make(chan struct{}) + var reqBConn *pool.Conn + var reqBErr error + go func() { + defer GinkgoRecover() + longCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + reqBConn, reqBErr = testPool.Get(longCtx) + close(reqBDone) + }() + + // Wait for Request B to acquire turn and enter dialsQueue + time.Sleep(50 * time.Millisecond) + Expect(testPool.QueueLen()).To(Equal(2), "Both requests should occupy turns") + + // Request A times out at 150ms + reqAErr := <-reqADone + Expect(reqAErr).To(HaveOccurred(), "Request A should timeout") + + // Request A's connection completes at 200ms + // putIdleConn delivers it to Request B via tryDeliver + // queuedNewConn MUST call freeTurn() to release Request A's turn + <-reqBDone + Expect(reqBErr).NotTo(HaveOccurred(), "Request B should receive Request A's connection") + Expect(reqBConn).NotTo(BeNil()) + + // CRITICAL CHECK: Turn leak detection + // After Request B receives connection from putIdleConn: + // - Request A's turn SHOULD be released (via freeTurn) + // - Request B's turn is still held (will release on Put) + // Expected QueueLen: 1 (only Request B) + // If Bug exists (missing freeTurn): QueueLen: 2 (Request A's turn leaked) + time.Sleep(100 * time.Millisecond) // Allow time for turn release + currentQueueLen := testPool.QueueLen() + + Expect(currentQueueLen).To(Equal(1), + "QueueLen should be 1 (only Request B holding turn). 
"+ + "If it's 2, Request A's turn leaked due to missing freeTurn()") + + // Cleanup + testPool.Put(ctx, reqBConn) + Eventually(func() int { return testPool.QueueLen() }, "500ms").Should(Equal(0)) + }) +}) + func init() { logging.Disable() } diff --git a/internal/pool/want_conn.go b/internal/pool/want_conn.go new file mode 100644 index 00000000..6f9e4bfa --- /dev/null +++ b/internal/pool/want_conn.go @@ -0,0 +1,93 @@ +package pool + +import ( + "context" + "sync" +) + +type wantConn struct { + mu sync.Mutex // protects ctx, done and sending of the result + ctx context.Context // context for dial, cleared after delivered or canceled + cancelCtx context.CancelFunc + done bool // true after delivered or canceled + result chan wantConnResult // channel to deliver connection or error +} + +// getCtxForDial returns context for dial or nil if connection was delivered or canceled. +func (w *wantConn) getCtxForDial() context.Context { + w.mu.Lock() + defer w.mu.Unlock() + + return w.ctx +} + +func (w *wantConn) tryDeliver(cn *Conn, err error) bool { + w.mu.Lock() + defer w.mu.Unlock() + if w.done { + return false + } + + w.done = true + w.ctx = nil + + w.result <- wantConnResult{cn: cn, err: err} + close(w.result) + + return true +} + +func (w *wantConn) cancel() *Conn { + w.mu.Lock() + var cn *Conn + if w.done { + select { + case result := <-w.result: + cn = result.cn + default: + } + } else { + close(w.result) + } + + w.done = true + w.ctx = nil + w.mu.Unlock() + + return cn +} + +type wantConnResult struct { + cn *Conn + err error +} + +type wantConnQueue struct { + mu sync.RWMutex + items []*wantConn +} + +func newWantConnQueue() *wantConnQueue { + return &wantConnQueue{ + items: make([]*wantConn, 0), + } +} + +func (q *wantConnQueue) enqueue(w *wantConn) { + q.mu.Lock() + defer q.mu.Unlock() + q.items = append(q.items, w) +} + +func (q *wantConnQueue) dequeue() (*wantConn, bool) { + q.mu.Lock() + defer q.mu.Unlock() + + if len(q.items) == 0 { + return nil, false + } + + item := q.items[0] + q.items = q.items[1:] + return item, true +} diff --git a/internal/pool/want_conn_test.go b/internal/pool/want_conn_test.go new file mode 100644 index 00000000..9526f70c --- /dev/null +++ b/internal/pool/want_conn_test.go @@ -0,0 +1,444 @@ +package pool + +import ( + "context" + "errors" + "sync" + "testing" + "time" +) + +func TestWantConn_getCtxForDial(t *testing.T) { + ctx := context.Background() + w := &wantConn{ + ctx: ctx, + result: make(chan wantConnResult, 1), + } + + // Test getting context when not done + gotCtx := w.getCtxForDial() + if gotCtx != ctx { + t.Errorf("getCtxForDial() = %v, want %v", gotCtx, ctx) + } + + // Test getting context when done + w.done = true + w.ctx = nil + gotCtx = w.getCtxForDial() + if gotCtx != nil { + t.Errorf("getCtxForDial() after done = %v, want nil", gotCtx) + } +} + +func TestWantConn_tryDeliver_Success(t *testing.T) { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + + // Create a mock connection + conn := &Conn{} + + // Test successful delivery + delivered := w.tryDeliver(conn, nil) + if !delivered { + t.Error("tryDeliver() = false, want true") + } + + // Check that wantConn is marked as done + if !w.done { + t.Error("wantConn.done = false, want true after delivery") + } + + // Check that context is cleared + if w.ctx != nil { + t.Error("wantConn.ctx should be nil after delivery") + } + + // Check that result is sent + select { + case result := <-w.result: + if result.cn != conn { + t.Errorf("result.cn = %v, want %v", 
result.cn, conn) + } + if result.err != nil { + t.Errorf("result.err = %v, want nil", result.err) + } + case <-time.After(time.Millisecond): + t.Error("Expected result to be sent to channel") + } +} + +func TestWantConn_tryDeliver_WithError(t *testing.T) { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + + testErr := errors.New("test error") + + // Test delivery with error + delivered := w.tryDeliver(nil, testErr) + if !delivered { + t.Error("tryDeliver() = false, want true") + } + + // Check result + select { + case result := <-w.result: + if result.cn != nil { + t.Errorf("result.cn = %v, want nil", result.cn) + } + if result.err != testErr { + t.Errorf("result.err = %v, want %v", result.err, testErr) + } + case <-time.After(time.Millisecond): + t.Error("Expected result to be sent to channel") + } +} + +func TestWantConn_tryDeliver_AlreadyDone(t *testing.T) { + w := &wantConn{ + ctx: context.Background(), + done: true, // Already done + result: make(chan wantConnResult, 1), + } + + // Test delivery when already done + delivered := w.tryDeliver(&Conn{}, nil) + if delivered { + t.Error("tryDeliver() = true, want false when already done") + } + + // Check that no result is sent + select { + case <-w.result: + t.Error("No result should be sent when already done") + case <-time.After(time.Millisecond): + // Expected + } +} + +func TestWantConn_cancel_NotDone(t *testing.T) { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + + // Test cancel when not done + cn := w.cancel() + + // Should return nil since no connection was delivered + if cn != nil { + t.Errorf("cancel() = %v, want nil when no connection was delivered", cn) + } + + // Check that wantConn is marked as done + if !w.done { + t.Error("wantConn.done = false, want true after cancel") + } + + // Check that context is cleared + if w.ctx != nil { + t.Error("wantConn.ctx should be nil after cancel") + } + + // Check that channel is closed + select { + case _, ok := <-w.result: + if ok { + t.Error("result channel should be closed after cancel") + } + case <-time.After(time.Millisecond): + t.Error("Expected channel to be closed") + } +} + +func TestWantConn_cancel_AlreadyDone(t *testing.T) { + w := &wantConn{ + ctx: context.Background(), + done: true, + result: make(chan wantConnResult, 1), + } + + // Put a result in the channel without connection (to avoid nil pointer issues) + testErr := errors.New("test error") + w.result <- wantConnResult{cn: nil, err: testErr} + + // Test cancel when already done + cn := w.cancel() + + // Should return nil since the result had no connection + if cn != nil { + t.Errorf("cancel() = %v, want nil when result had no connection", cn) + } + + // Check that wantConn remains done + if !w.done { + t.Error("wantConn.done = false, want true") + } + + // Check that context is cleared + if w.ctx != nil { + t.Error("wantConn.ctx should be nil after cancel") + } +} + +func TestWantConnQueue_newWantConnQueue(t *testing.T) { + q := newWantConnQueue() + if q == nil { + t.Fatal("newWantConnQueue() returned nil") + } + if q.items == nil { + t.Error("queue items should be initialized") + } + if len(q.items) != 0 { + t.Errorf("new queue length = %d, want 0", len(q.items)) + } +} + +func TestWantConnQueue_enqueue_dequeue(t *testing.T) { + q := newWantConnQueue() + + // Test dequeue from empty queue + item, ok := q.dequeue() + if ok { + t.Error("dequeue() from empty queue should return false") + } + if item != nil { + t.Error("dequeue() from
empty queue should return nil") + } + + // Create test wantConn items + w1 := &wantConn{ctx: context.Background(), result: make(chan wantConnResult, 1)} + w2 := &wantConn{ctx: context.Background(), result: make(chan wantConnResult, 1)} + w3 := &wantConn{ctx: context.Background(), result: make(chan wantConnResult, 1)} + + // Test enqueue + q.enqueue(w1) + q.enqueue(w2) + q.enqueue(w3) + + // Test FIFO behavior + item, ok = q.dequeue() + if !ok { + t.Error("dequeue() should return true when queue has items") + } + if item != w1 { + t.Errorf("dequeue() = %v, want %v (FIFO order)", item, w1) + } + + item, ok = q.dequeue() + if !ok { + t.Error("dequeue() should return true when queue has items") + } + if item != w2 { + t.Errorf("dequeue() = %v, want %v (FIFO order)", item, w2) + } + + item, ok = q.dequeue() + if !ok { + t.Error("dequeue() should return true when queue has items") + } + if item != w3 { + t.Errorf("dequeue() = %v, want %v (FIFO order)", item, w3) + } + + // Test dequeue from empty queue again + item, ok = q.dequeue() + if ok { + t.Error("dequeue() from empty queue should return false") + } + if item != nil { + t.Error("dequeue() from empty queue should return nil") + } +} + +func TestWantConnQueue_ConcurrentAccess(t *testing.T) { + q := newWantConnQueue() + const numWorkers = 10 + const itemsPerWorker = 100 + + var wg sync.WaitGroup + + // Start enqueuers + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < itemsPerWorker; j++ { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + q.enqueue(w) + } + }() + } + + // Start dequeuers + dequeued := make(chan *wantConn, numWorkers*itemsPerWorker) + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < itemsPerWorker; j++ { + for { + if item, ok := q.dequeue(); ok { + dequeued <- item + break + } + // Small delay to avoid busy waiting + time.Sleep(time.Microsecond) + } + } + }() + } + + wg.Wait() + close(dequeued) + + // Count dequeued items + count := 0 + for range dequeued { + count++ + } + + expectedCount := numWorkers * itemsPerWorker + if count != expectedCount { + t.Errorf("dequeued %d items, want %d", count, expectedCount) + } + + // Queue should be empty + if item, ok := q.dequeue(); ok { + t.Errorf("queue should be empty but got item: %v", item) + } +} + +func TestWantConnQueue_ThreadSafety(t *testing.T) { + q := newWantConnQueue() + const numOperations = 1000 + + var wg sync.WaitGroup + errors := make(chan error, numOperations*2) + + // Concurrent enqueue operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < numOperations; i++ { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + q.enqueue(w) + } + }() + + // Concurrent dequeue operations + wg.Add(1) + go func() { + defer wg.Done() + dequeued := 0 + for dequeued < numOperations { + if _, ok := q.dequeue(); ok { + dequeued++ + } else { + // Small delay when queue is empty + time.Sleep(time.Microsecond) + } + } + }() + + // Wait for completion + wg.Wait() + close(errors) + + // Check for any errors + for err := range errors { + t.Error(err) + } + + // Final queue should be empty + if item, ok := q.dequeue(); ok { + t.Errorf("queue should be empty but got item: %v", item) + } +} + +// Benchmark tests +func BenchmarkWantConnQueue_Enqueue(b *testing.B) { + q := newWantConnQueue() + + // Pre-allocate a pool of wantConn to reuse + const poolSize = 1000 + wantConnPool := make([]*wantConn, poolSize) + 
for i := 0; i < poolSize; i++ { + wantConnPool[i] = &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + w := wantConnPool[i%poolSize] + q.enqueue(w) + } +} + +func BenchmarkWantConnQueue_Dequeue(b *testing.B) { + q := newWantConnQueue() + + // Use a reasonable fixed size for pre-population to avoid memory issues + const queueSize = 10000 + + // Pre-populate queue with a fixed reasonable size + for i := 0; i < queueSize; i++ { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + q.enqueue(w) + } + + b.ResetTimer() + + // Benchmark dequeue operations, refilling as needed + for i := 0; i < b.N; i++ { + if _, ok := q.dequeue(); !ok { + // Queue is empty, refill a batch + for j := 0; j < 1000; j++ { + w := &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + q.enqueue(w) + } + // Dequeue again + q.dequeue() + } + } +} + +func BenchmarkWantConnQueue_EnqueueDequeue(b *testing.B) { + q := newWantConnQueue() + + // Pre-allocate a pool of wantConn to reuse + const poolSize = 1000 + wantConnPool := make([]*wantConn, poolSize) + for i := 0; i < poolSize; i++ { + wantConnPool[i] = &wantConn{ + ctx: context.Background(), + result: make(chan wantConnResult, 1), + } + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + w := wantConnPool[i%poolSize] + q.enqueue(w) + q.dequeue() + } +} diff --git a/maintnotifications/e2e/command_runner_test.go b/maintnotifications/e2e/command_runner_test.go index b80a434b..27c19c3a 100644 --- a/maintnotifications/e2e/command_runner_test.go +++ b/maintnotifications/e2e/command_runner_test.go @@ -20,6 +20,7 @@ type CommandRunnerStats struct { // CommandRunner provides utilities for running commands during tests type CommandRunner struct { + executing atomic.Bool client redis.UniversalClient stopCh chan struct{} operationCount atomic.Int64 @@ -56,6 +57,10 @@ func (cr *CommandRunner) Close() { // FireCommandsUntilStop runs commands continuously until stop signal func (cr *CommandRunner) FireCommandsUntilStop(ctx context.Context) { + if !cr.executing.CompareAndSwap(false, true) { + return + } + defer cr.executing.Store(false) fmt.Printf("[CR] Starting command runner...\n") defer fmt.Printf("[CR] Command runner stopped\n") // High frequency for timeout testing diff --git a/maintnotifications/e2e/scenario_push_notifications_test.go b/maintnotifications/e2e/scenario_push_notifications_test.go index ccc648b0..80511494 100644 --- a/maintnotifications/e2e/scenario_push_notifications_test.go +++ b/maintnotifications/e2e/scenario_push_notifications_test.go @@ -297,12 +297,6 @@ func TestPushNotifications(t *testing.T) { // once moving is received, start a second client commands runner p("Starting commands on second client") go commandsRunner2.FireCommandsUntilStop(ctx) - defer func() { - // stop the second runner - commandsRunner2.Stop() - // destroy the second client - factory.Destroy("push-notification-client-2") - }() p("Waiting for MOVING notification on second client") matchNotif, fnd := tracker2.FindOrWaitForNotification("MOVING", 3*time.Minute) @@ -393,11 +387,15 @@ func TestPushNotifications(t *testing.T) { p("MOVING notification test completed successfully") - p("Executing commands and collecting logs for analysis... This will take 30 seconds...") + p("Executing commands and collecting logs for analysis... 
") go commandsRunner.FireCommandsUntilStop(ctx) - time.Sleep(time.Minute) + go commandsRunner2.FireCommandsUntilStop(ctx) + go commandsRunner3.FireCommandsUntilStop(ctx) + time.Sleep(2 * time.Minute) commandsRunner.Stop() - time.Sleep(time.Minute) + commandsRunner2.Stop() + commandsRunner3.Stop() + time.Sleep(1 * time.Minute) allLogsAnalysis := logCollector.GetAnalysis() trackerAnalysis := tracker.GetAnalysis() @@ -438,33 +436,35 @@ func TestPushNotifications(t *testing.T) { e("No logs found for connection %d", connID) } } + // checks are tracker >= logs since the tracker only tracks client1 + // logs include all clients (and some of them start logging even before all hooks are setup) + // for example for idle connections if they receive a notification before the hook is setup + // the action (i.e. relaxing timeouts) will be logged, but the notification will not be tracked and maybe wont be logged // validate number of notifications in tracker matches number of notifications in logs // allow for more moving in the logs since we started a second client if trackerAnalysis.TotalNotifications > allLogsAnalysis.TotalNotifications { - e("Expected %d or more notifications, got %d", trackerAnalysis.TotalNotifications, allLogsAnalysis.TotalNotifications) + e("Expected at least %d or more notifications, got %d", trackerAnalysis.TotalNotifications, allLogsAnalysis.TotalNotifications) } - // and per type - // allow for more moving in the logs since we started a second client if trackerAnalysis.MovingCount > allLogsAnalysis.MovingCount { - e("Expected %d or more MOVING notifications, got %d", trackerAnalysis.MovingCount, allLogsAnalysis.MovingCount) + e("Expected at least %d or more MOVING notifications, got %d", trackerAnalysis.MovingCount, allLogsAnalysis.MovingCount) } - if trackerAnalysis.MigratingCount != allLogsAnalysis.MigratingCount { - e("Expected %d MIGRATING notifications, got %d", trackerAnalysis.MigratingCount, allLogsAnalysis.MigratingCount) + if trackerAnalysis.MigratingCount > allLogsAnalysis.MigratingCount { + e("Expected at least %d MIGRATING notifications, got %d", trackerAnalysis.MigratingCount, allLogsAnalysis.MigratingCount) } - if trackerAnalysis.MigratedCount != allLogsAnalysis.MigratedCount { - e("Expected %d MIGRATED notifications, got %d", trackerAnalysis.MigratedCount, allLogsAnalysis.MigratedCount) + if trackerAnalysis.MigratedCount > allLogsAnalysis.MigratedCount { + e("Expected at least %d MIGRATED notifications, got %d", trackerAnalysis.MigratedCount, allLogsAnalysis.MigratedCount) } - if trackerAnalysis.FailingOverCount != allLogsAnalysis.FailingOverCount { - e("Expected %d FAILING_OVER notifications, got %d", trackerAnalysis.FailingOverCount, allLogsAnalysis.FailingOverCount) + if trackerAnalysis.FailingOverCount > allLogsAnalysis.FailingOverCount { + e("Expected at least %d FAILING_OVER notifications, got %d", trackerAnalysis.FailingOverCount, allLogsAnalysis.FailingOverCount) } - if trackerAnalysis.FailedOverCount != allLogsAnalysis.FailedOverCount { - e("Expected %d FAILED_OVER notifications, got %d", trackerAnalysis.FailedOverCount, allLogsAnalysis.FailedOverCount) + if trackerAnalysis.FailedOverCount > allLogsAnalysis.FailedOverCount { + e("Expected at least %d FAILED_OVER notifications, got %d", trackerAnalysis.FailedOverCount, allLogsAnalysis.FailedOverCount) } if trackerAnalysis.UnexpectedNotificationCount != allLogsAnalysis.UnexpectedCount { @@ -472,11 +472,11 @@ func TestPushNotifications(t *testing.T) { } // unrelaxed (and relaxed) after moving wont be 
-    if trackerAnalysis.UnrelaxedTimeoutCount != allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving {
-        e("Expected %d unrelaxed timeouts, got %d", trackerAnalysis.UnrelaxedTimeoutCount, allLogsAnalysis.UnrelaxedTimeoutCount)
+    if trackerAnalysis.UnrelaxedTimeoutCount > allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving {
+        e("Expected at least %d unrelaxed timeouts, got %d", trackerAnalysis.UnrelaxedTimeoutCount, allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving)
     }
 
-    if trackerAnalysis.RelaxedTimeoutCount != allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount {
-        e("Expected %d relaxed timeouts, got %d", trackerAnalysis.RelaxedTimeoutCount, allLogsAnalysis.RelaxedTimeoutCount)
+    if trackerAnalysis.RelaxedTimeoutCount > allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount {
+        e("Expected at least %d relaxed timeouts, got %d", trackerAnalysis.RelaxedTimeoutCount, allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount)
     }
 
     // validate all handoffs succeeded
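All of the count comparisons above encode the same direction of inequality: the per-client tracker can only ever see a subset of what the shared log collector records. A hypothetical helper (not part of this change; names are illustrative) would make that invariant explicit and collapse the repeated blocks:

    // expectAtLeast fails the test when the log-derived count is lower than the
    // tracker-derived count; name is only used in the error message.
    func expectAtLeast(t *testing.T, name string, trackerCount, logCount int) {
        t.Helper()
        if trackerCount > logCount {
            t.Errorf("Expected at least %d %s, got %d", trackerCount, name, logCount)
        }
    }

    // e.g. expectAtLeast(t, "MOVING notifications", trackerAnalysis.MovingCount, allLogsAnalysis.MovingCount)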
diff --git a/maintnotifications/pool_hook_test.go b/maintnotifications/pool_hook_test.go
index 41120af2..6ec61eed 100644
--- a/maintnotifications/pool_hook_test.go
+++ b/maintnotifications/pool_hook_test.go
@@ -174,7 +174,7 @@ func TestConnectionHook(t *testing.T) {
     select {
     case <-initConnCalled:
         // Good, initialization was called
-    case <-time.After(1 * time.Second):
+    case <-time.After(5 * time.Second):
         t.Fatal("Timeout waiting for initialization function to be called")
     }
 
diff --git a/options.go b/options.go
index a55beed2..ea5f4fa5 100644
--- a/options.go
+++ b/options.go
@@ -34,7 +34,6 @@ type Limiter interface {
 
 // Options keeps the settings to set up redis connection.
 type Options struct {
-
     // Network type, either tcp or unix.
     //
     // default: is tcp.
@@ -184,6 +183,10 @@ type Options struct {
     // default: 10 * runtime.GOMAXPROCS(0)
     PoolSize int
 
+    // MaxConcurrentDials is the maximum number of concurrent connection creation goroutines.
+    // If <= 0, defaults to PoolSize. If > PoolSize, it will be capped at PoolSize.
+    MaxConcurrentDials int
+
     // PoolTimeout is the amount of time client waits for connection if all connections
     // are busy before returning an error.
     //
@@ -309,6 +312,11 @@ func (opt *Options) init() {
     if opt.PoolSize == 0 {
         opt.PoolSize = 10 * runtime.GOMAXPROCS(0)
     }
+    if opt.MaxConcurrentDials <= 0 {
+        opt.MaxConcurrentDials = opt.PoolSize
+    } else if opt.MaxConcurrentDials > opt.PoolSize {
+        opt.MaxConcurrentDials = opt.PoolSize
+    }
     if opt.ReadBufferSize == 0 {
         opt.ReadBufferSize = proto.DefaultBufferSize
     }
@@ -636,6 +644,7 @@ func setupConnParams(u *url.URL, o *Options) (*Options, error) {
     o.MinIdleConns = q.int("min_idle_conns")
     o.MaxIdleConns = q.int("max_idle_conns")
     o.MaxActiveConns = q.int("max_active_conns")
+    o.MaxConcurrentDials = q.int("max_concurrent_dials")
     if q.has("conn_max_idle_time") {
         o.ConnMaxIdleTime = q.duration("conn_max_idle_time")
     } else {
@@ -702,6 +711,7 @@ func newConnPool(
         },
         PoolFIFO:           opt.PoolFIFO,
         PoolSize:           poolSize,
+        MaxConcurrentDials: opt.MaxConcurrentDials,
         PoolTimeout:        opt.PoolTimeout,
         DialTimeout:        opt.DialTimeout,
         DialerRetries:      opt.DialerRetries,
@@ -742,6 +752,7 @@ func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr
     return pool.NewPubSubPool(&pool.Options{
         PoolFIFO:           opt.PoolFIFO,
         PoolSize:           poolSize,
+        MaxConcurrentDials: opt.MaxConcurrentDials,
        PoolTimeout:        opt.PoolTimeout,
        DialTimeout:        opt.DialTimeout,
        DialerRetries:      opt.DialerRetries,
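For illustration, the new option can be set either directly on Options or through the max_concurrent_dials URL parameter wired up above. The sketch below shows both paths; the address and values are placeholders, and the import path assumes the v9 module.

    package main

    import "github.com/redis/go-redis/v9"

    func main() {
        // Direct configuration: cap concurrent dials below the pool size.
        rdb := redis.NewClient(&redis.Options{
            Addr:               "localhost:6379",
            PoolSize:           20,
            MaxConcurrentDials: 5, // <= 0 falls back to PoolSize; values above PoolSize are capped
        })
        defer rdb.Close()

        // URL configuration: parsed by setupConnParams into Options.MaxConcurrentDials.
        opt, err := redis.ParseURL("redis://localhost:6379?max_concurrent_dials=5")
        if err != nil {
            panic(err)
        }
        _ = redis.NewClient(opt)
    }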
diff --git a/options_test.go b/options_test.go
index 8de4986b..32d75e25 100644
--- a/options_test.go
+++ b/options_test.go
@@ -67,6 +67,12 @@ func TestParseURL(t *testing.T) {
     }, {
         url: "redis://localhost:123/?db=2&protocol=2", // RESP Protocol
         o:   &Options{Addr: "localhost:123", DB: 2, Protocol: 2},
+    }, {
+        url: "redis://localhost:123/?max_concurrent_dials=5", // MaxConcurrentDials parameter
+        o:   &Options{Addr: "localhost:123", MaxConcurrentDials: 5},
+    }, {
+        url: "redis://localhost:123/?max_concurrent_dials=0", // MaxConcurrentDials zero value
+        o:   &Options{Addr: "localhost:123", MaxConcurrentDials: 0},
     }, {
         url: "unix:///tmp/redis.sock",
         o:   &Options{Addr: "/tmp/redis.sock"},
@@ -197,6 +203,9 @@ func comprareOptions(t *testing.T, actual, expected *Options) {
     if actual.ConnMaxLifetime != expected.ConnMaxLifetime {
         t.Errorf("ConnMaxLifetime: got %v, expected %v", actual.ConnMaxLifetime, expected.ConnMaxLifetime)
     }
+    if actual.MaxConcurrentDials != expected.MaxConcurrentDials {
+        t.Errorf("MaxConcurrentDials: got %v, expected %v", actual.MaxConcurrentDials, expected.MaxConcurrentDials)
+    }
 }
 
 // Test ReadTimeout option initialization, including special values -1 and 0.
@@ -245,3 +254,68 @@ func TestProtocolOptions(t *testing.T) {
         }
     }
 }
+
+func TestMaxConcurrentDialsOptions(t *testing.T) {
+    // Test cases for MaxConcurrentDials initialization logic
+    testCases := []struct {
+        name                    string
+        poolSize                int
+        maxConcurrentDials      int
+        expectedConcurrentDials int
+    }{
+        // Edge cases and invalid values - negative/zero values set to PoolSize
+        {
+            name:                    "negative value gets set to pool size",
+            poolSize:                10,
+            maxConcurrentDials:      -1,
+            expectedConcurrentDials: 10, // negative values are set to PoolSize
+        },
+        // Zero value tests - MaxConcurrentDials should be set to PoolSize
+        {
+            name:                    "zero value with positive pool size",
+            poolSize:                1,
+            maxConcurrentDials:      0,
+            expectedConcurrentDials: 1, // MaxConcurrentDials = PoolSize when 0
+        },
+        // Explicit positive value tests
+        {
+            name:                    "explicit value within limit",
+            poolSize:                10,
+            maxConcurrentDials:      3,
+            expectedConcurrentDials: 3, // should remain unchanged when < PoolSize
+        },
+        // Capping tests - values exceeding PoolSize should be capped
+        {
+            name:                    "value exceeding pool size",
+            poolSize:                5,
+            maxConcurrentDials:      10,
+            expectedConcurrentDials: 5, // should be capped at PoolSize
+        },
+    }
+
+    for _, tc := range testCases {
+        t.Run(tc.name, func(t *testing.T) {
+            opts := &Options{
+                PoolSize:           tc.poolSize,
+                MaxConcurrentDials: tc.maxConcurrentDials,
+            }
+            opts.init()
+
+            if opts.MaxConcurrentDials != tc.expectedConcurrentDials {
+                t.Errorf("MaxConcurrentDials: got %v, expected %v (PoolSize=%v)",
+                    opts.MaxConcurrentDials, tc.expectedConcurrentDials, opts.PoolSize)
+            }
+
+            // Ensure MaxConcurrentDials never exceeds PoolSize (for all inputs)
+            if opts.MaxConcurrentDials > opts.PoolSize {
+                t.Errorf("MaxConcurrentDials (%v) should not exceed PoolSize (%v)",
+                    opts.MaxConcurrentDials, opts.PoolSize)
+            }
+
+            // Ensure MaxConcurrentDials is always positive (for all inputs)
+            if opts.MaxConcurrentDials <= 0 {
+                t.Errorf("MaxConcurrentDials should be positive, got %v", opts.MaxConcurrentDials)
+            }
+        })
+    }
+}
diff --git a/pool_pubsub_bench_test.go b/pool_pubsub_bench_test.go
index 0db8ec55..d7f0f185 100644
--- a/pool_pubsub_bench_test.go
+++ b/pool_pubsub_bench_test.go
@@ -70,12 +70,13 @@ func BenchmarkPoolGetPut(b *testing.B) {
     for _, poolSize := range poolSizes {
         b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
             connPool := pool.NewConnPool(&pool.Options{
-                Dialer:          dummyDialer,
-                PoolSize:        int32(poolSize),
-                PoolTimeout:     time.Second,
-                DialTimeout:     time.Second,
-                ConnMaxIdleTime: time.Hour,
-                MinIdleConns:    int32(0), // Start with no idle connections
+                Dialer:             dummyDialer,
+                PoolSize:           int32(poolSize),
+                MaxConcurrentDials: poolSize,
+                PoolTimeout:        time.Second,
+                DialTimeout:        time.Second,
+                ConnMaxIdleTime:    time.Hour,
+                MinIdleConns:       int32(0), // Start with no idle connections
             })
             defer connPool.Close()
 
@@ -112,12 +113,13 @@ func BenchmarkPoolGetPutWithMinIdle(b *testing.B) {
     for _, config := range configs {
         b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) {
             connPool := pool.NewConnPool(&pool.Options{
-                Dialer:          dummyDialer,
-                PoolSize:        int32(config.poolSize),
-                MinIdleConns:    int32(config.minIdleConns),
-                PoolTimeout:     time.Second,
-                DialTimeout:     time.Second,
-                ConnMaxIdleTime: time.Hour,
+                Dialer:             dummyDialer,
+                PoolSize:           int32(config.poolSize),
+                MaxConcurrentDials: config.poolSize,
+                MinIdleConns:       int32(config.minIdleConns),
+                PoolTimeout:        time.Second,
+                DialTimeout:        time.Second,
+                ConnMaxIdleTime:    time.Hour,
             })
             defer connPool.Close()
@@ -142,12 +144,13 @@ func BenchmarkPoolConcurrentGetPut(b *testing.B) {
     ctx := context.Background()
 
     connPool := pool.NewConnPool(&pool.Options{
-        Dialer:          dummyDialer,
-        PoolSize:        int32(32),
-        PoolTimeout:     time.Second,
-        DialTimeout:     time.Second,
-        ConnMaxIdleTime: time.Hour,
-        MinIdleConns:    int32(0),
+        Dialer:             dummyDialer,
+        PoolSize:           int32(32),
+        MaxConcurrentDials: 32,
+        PoolTimeout:        time.Second,
+        DialTimeout:        time.Second,
+        ConnMaxIdleTime:    time.Hour,
+        MinIdleConns:       int32(0),
     })
     defer connPool.Close()
 
diff --git a/redis.go b/redis.go
index fdd3027b..ac97c2ca 100644
--- a/redis.go
+++ b/redis.go
@@ -1181,28 +1181,6 @@ func (c *Client) TxPipeline() Pipeliner {
     return &pipe
 }
 
-// AutoPipeline creates a new autopipeliner that automatically batches commands.
-// Commands are automatically flushed based on batch size and time interval.
-// The autopipeliner must be closed when done to flush pending commands.
-//
-// Example:
-//
-//    ap := client.AutoPipeline()
-//    defer ap.Close()
-//
-//    for i := 0; i < 1000; i++ {
-//        ap.Do(ctx, "SET", fmt.Sprintf("key%d", i), i)
-//    }
-//
-// Note: AutoPipeline requires AutoPipelineConfig to be set in Options.
-// If not set, this will panic.
-func (c *Client) AutoPipeline() *AutoPipeliner {
-    if c.opt.AutoPipelineConfig == nil {
-        c.opt.AutoPipelineConfig = DefaultAutoPipelineConfig()
-    }
-    return NewAutoPipeliner(c, c.opt.AutoPipelineConfig)
-}
-
 func (c *Client) pubSub() *PubSub {
     pubsub := &PubSub{
         opt: c.opt,
@@ -1388,11 +1366,13 @@ func (c *baseClient) processPushNotifications(ctx context.Context, cn *pool.Conn
     // If the connection was health-checked within the last 5 seconds, we can skip the
     // expensive syscall since the health check already verified no unexpected data.
     // This is safe because:
+    // 0. lastHealthCheckNs is set in pool/conn.go:putConn() after a successful health check
     // 1. Health check (connCheck) uses the same syscall (Recvfrom with MSG_PEEK)
     // 2. If push notifications arrived, they would have been detected by health check
     // 3. 5 seconds is short enough that connection state is still fresh
     // 4. Push notifications will be processed by the next WithReader call
-    lastHealthCheckNs := cn.UsedAtNs()
+    // UsedAt is refreshed on getConn, so rely on the last put timestamp instead
+    lastHealthCheckNs := cn.LastPutAtNs()
     if lastHealthCheckNs > 0 {
         // Use pool's cached time to avoid expensive time.Now() syscall
         nowNs := pool.GetCachedTimeNs()
diff --git a/redis_test.go b/redis_test.go
index 9dd00f19..bc0db6ad 100644
--- a/redis_test.go
+++ b/redis_test.go
@@ -245,11 +245,21 @@ var _ = Describe("Client", func() {
         Expect(val).Should(HaveKeyWithValue("proto", int64(3)))
     })
 
-    It("should initialize idle connections created by MinIdleConns", func() {
+    It("should initialize idle connections created by MinIdleConns", Label("NonRedisEnterprise"), func() {
         opt := redisOptions()
+        passwrd := "asdf"
+        db0 := redis.NewClient(opt)
+        // set password
+        err := db0.Do(ctx, "CONFIG", "SET", "requirepass", passwrd).Err()
+        Expect(err).NotTo(HaveOccurred())
+        defer func() {
+            err = db0.Do(ctx, "CONFIG", "SET", "requirepass", "").Err()
+            Expect(err).NotTo(HaveOccurred())
+            Expect(db0.Close()).NotTo(HaveOccurred())
+        }()
         opt.MinIdleConns = 5
-        opt.Password = "asdf" // Set password to require AUTH
-        opt.DB = 1            // Set DB to require SELECT
+        opt.Password = passwrd
+        opt.DB = 1 // Set DB to require SELECT
         db := redis.NewClient(opt)
         defer func() {