diff --git a/async_handoff_integration_test.go b/async_handoff_integration_test.go index 9cf553dc..30afa71d 100644 --- a/async_handoff_integration_test.go +++ b/async_handoff_integration_test.go @@ -17,12 +17,12 @@ type mockNetConn struct { } func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil } -func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil } -func (m *mockNetConn) Close() error { return nil } -func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} } -func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} } -func (m *mockNetConn) SetDeadline(t time.Time) error { return nil } -func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (m *mockNetConn) Close() error { return nil } +func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) SetDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil } func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil } type mockAddr struct { @@ -152,8 +152,8 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { return &mockNetConn{addr: "original:6379"}, nil }, - PoolSize: int32(10), - PoolTimeout: time.Second, + PoolSize: int32(10), + PoolTimeout: time.Second, }) defer testPool.Close() @@ -175,7 +175,7 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { // Get connection conn, err := testPool.Get(ctx) if err != nil { - t.Errorf("Failed to get connection %d: %v", id, err) + t.Errorf("Failed to get conn[%d]: %v", id, err) return } @@ -224,8 +224,8 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { return &mockNetConn{addr: "original:6379"}, nil }, - PoolSize: int32(3), - PoolTimeout: time.Second, + PoolSize: int32(3), + PoolTimeout: time.Second, }) defer testPool.Close() @@ -287,8 +287,8 @@ func TestEventDrivenHandoffIntegration(t *testing.T) { return &mockNetConn{addr: "original:6379"}, nil }, - PoolSize: int32(2), - PoolTimeout: time.Second, + PoolSize: int32(2), + PoolTimeout: time.Second, }) defer testPool.Close() diff --git a/hitless/README.md b/hitless/README.md index 8d9a021f..6fc805f2 100644 --- a/hitless/README.md +++ b/hitless/README.md @@ -39,27 +39,31 @@ Config: &hitless.Config{ PostHandoffRelaxedDuration: 20 * time.Second, // Keep relaxed timeout after handoff LogLevel: logging.LogLevelWarn, // LogLevelError, LogLevelWarn, LogLevelInfo, LogLevelDebug MaxWorkers: 15, // Concurrent handoff workers - HandoffQueueSize: 50, // Handoff request queue size + HandoffQueueSize: 300, // Handoff request queue size } ``` ### Worker Scaling -- **Auto-calculated**: `min(10, PoolSize/3)` - scales with pool size, capped at 10 -- **Explicit values**: `max(10, set_value)` - enforces minimum 10 workers +- **Auto-calculated**: `min(PoolSize/2, max(10, PoolSize/3))` - balanced scaling approach +- **Explicit values**: `max(PoolSize/2, set_value)` - enforces minimum PoolSize/2 workers - **On-demand**: Workers created when needed, cleaned up when idle ### Queue Sizing -- **Auto-calculated**: `max(8 × MaxWorkers, max(50, PoolSize/2))` - hybrid scaling - - Worker-based: 8 handoffs per worker for burst processing - - Pool-based: Scales with pool size (minimum 50, up to PoolSize/2) +- **Auto-calculated**: `max(20 × MaxWorkers, PoolSize)` - hybrid scaling + - Worker-based: 20 handoffs per worker 
for burst processing
+  - Pool-based: Scales directly with pool size
   - Takes the larger of the two for optimal performance
-- **Explicit values**: `max(50, set_value)` - enforces minimum 50 when set
-- **Always capped**: Queue size never exceeds `2 × PoolSize` for memory efficiency
+- **Explicit values**: `max(200, set_value)` - enforces minimum 200 when set
+- **Capping**: Queue size capped by `MaxActiveConns+1` (if set) or `5 × PoolSize` for memory efficiency
 
-**Examples:**
-- Pool 10: Queue 50 (max(8×3, max(50, 5)) = max(24, 50) = 50)
-- Pool 100: Queue 80 (max(8×10, max(50, 50)) = max(80, 50) = 80)
-- Pool 200: Queue 100 (max(8×10, max(50, 100)) = max(80, 100) = 100)
+**Examples (without MaxActiveConns):**
+- Pool 10: Workers 5, Queue 50 (max(20×5, 10) = 100, capped at 5×10 = 50)
+- Pool 100: Workers 33, Queue 500 (max(20×33, 100) = 660, capped at 5×100 = 500)
+- Pool 200: Workers 66, Queue 1000 (max(20×66, 200) = 1320, capped at 5×200 = 1000)
+
+**Examples (with MaxActiveConns=150):**
+- Pool 100: Workers 33, Queue 151 (max(20×33, 100) = 660, capped at MaxActiveConns+1 = 151)
+- Pool 200: Workers 66, Queue 151 (max(20×66, 200) = 1320, capped at MaxActiveConns+1 = 151)
 
 ## Notification Hooks
 
@@ -94,7 +98,7 @@ type CustomHook struct{}
 func (h *CustomHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
 	// Log notification with connection details
 	if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
-		log.Printf("Processing %s on connection %d", notificationType, conn.GetID())
+		log.Printf("Processing %s on conn[%d]", notificationType, conn.GetID())
 	}
 	return notification, true // Continue processing
 }
diff --git a/hitless/config.go b/hitless/config.go
index a5a5e162..212bc77d 100644
--- a/hitless/config.go
+++ b/hitless/config.go
@@ -103,18 +103,18 @@ type Config struct {
 	// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
 	// Workers are created on-demand and automatically cleaned up when idle.
-	// If zero, defaults to min(10, PoolSize/3) to handle bursts effectively.
-	// If explicitly set, enforces minimum of 10 workers.
+	// If zero, defaults to min(PoolSize/2, max(10, PoolSize/3)) to handle bursts effectively.
+	// If explicitly set, enforces a minimum of PoolSize/2 workers.
 	//
-	// Default: min(10, PoolSize/3), Minimum when set: 10
+	// Default: min(PoolSize/2, max(10, PoolSize/3)), Minimum when set: PoolSize/2
 	MaxWorkers int
 
 	// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
 	// If the queue is full, new handoff requests will be rejected.
 	// Scales with both worker count and pool size for better burst handling.
 	//
-	// Default: max(8x workers, max(50, PoolSize/2)), capped by 2x pool size
-	// When set: min 50, capped by 2x pool size
+	// Default: max(20×MaxWorkers, PoolSize), capped by MaxActiveConns+1 (if set) or 5×PoolSize
+	// When set: minimum 200, capped by MaxActiveConns+1 (if set) or 5×PoolSize
 	HandoffQueueSize int
 
 	// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
@@ -122,11 +122,6 @@ type Config struct {
 	// Default: 2 * RelaxedTimeout
 	PostHandoffRelaxedDuration time.Duration
 
-	// ScaleDownDelay is the delay before checking if workers should be scaled down.
-	// This prevents expensive checks on every handoff completion and avoids rapid scaling cycles.
-	// Default: 2 seconds
-	ScaleDownDelay time.Duration
-
 	// LogLevel controls the verbosity of hitless upgrade logging.
// LogLevelError (0) = errors only, LogLevelWarn (1) = warnings, LogLevelInfo (2) = info, LogLevelDebug (3) = debug // Default: LogLevelWarn (warnings) @@ -152,7 +147,6 @@ func DefaultConfig() *Config { MaxWorkers: 0, // Auto-calculated based on pool size HandoffQueueSize: 0, // Auto-calculated based on max workers PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout - ScaleDownDelay: 2 * time.Second, LogLevel: LogLevelWarn, // Connection Handoff Configuration @@ -212,6 +206,13 @@ func (c *Config) ApplyDefaults() *Config { // using the provided pool size to calculate worker defaults. // This ensures that partially configured structs get sensible defaults for missing fields. func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config { + return c.ApplyDefaultsWithPoolConfig(poolSize, 0) +} + +// ApplyDefaultsWithPoolConfig applies default values to any zero-value fields in the configuration, +// using the provided pool size and max active connections to calculate worker and queue defaults. +// This ensures that partially configured structs get sensible defaults for missing fields. +func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *Config { if c == nil { return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize) } @@ -251,19 +252,25 @@ func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config { // Apply worker defaults based on pool size result.applyWorkerDefaults(poolSize) - // Apply queue size defaults with hybrid scaling approach + // Apply queue size defaults with new scaling approach if c.HandoffQueueSize <= 0 { - // Default: max(8x workers, max(50, PoolSize/2)), capped by 2x pool size - workerBasedSize := result.MaxWorkers * 8 - poolBasedSize := util.Max(50, poolSize/2) + // Default: max(20x workers, PoolSize), capped by maxActiveConns or 5x pool size + workerBasedSize := result.MaxWorkers * 20 + poolBasedSize := poolSize result.HandoffQueueSize = util.Max(workerBasedSize, poolBasedSize) } else { - // When explicitly set: enforce minimum of 50 - result.HandoffQueueSize = util.Max(50, c.HandoffQueueSize) + // When explicitly set: enforce minimum of 200 + result.HandoffQueueSize = util.Max(200, c.HandoffQueueSize) } - // Always cap queue size by 2x pool size - balances burst capacity with memory efficiency - result.HandoffQueueSize = util.Min(result.HandoffQueueSize, poolSize*2) + // Cap queue size: use maxActiveConns+1 if set, otherwise 5x pool size + var queueCap int + if maxActiveConns > 0 { + queueCap = maxActiveConns + 1 + } else { + queueCap = poolSize * 5 + } + result.HandoffQueueSize = util.Min(result.HandoffQueueSize, queueCap) // Ensure minimum queue size of 2 (fallback for very small pools) if result.HandoffQueueSize < 2 { @@ -276,12 +283,6 @@ func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config { result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration } - if c.ScaleDownDelay <= 0 { - result.ScaleDownDelay = defaults.ScaleDownDelay - } else { - result.ScaleDownDelay = c.ScaleDownDelay - } - // LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set // We'll use the provided value as-is, since 0 is valid result.LogLevel = c.LogLevel @@ -314,7 +315,6 @@ func (c *Config) Clone() *Config { MaxWorkers: c.MaxWorkers, HandoffQueueSize: c.HandoffQueueSize, PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration, - ScaleDownDelay: c.ScaleDownDelay, LogLevel: c.LogLevel, // Configuration fields @@ -330,11 +330,11 @@ func (c *Config) applyWorkerDefaults(poolSize int) { } 
if c.MaxWorkers == 0 { - // When not set: min(10, poolSize/3) - don't exceed 10 workers for small pools - c.MaxWorkers = util.Min(10, poolSize/3) + // When not set: min(poolSize/2, max(10, poolSize/3)) - balanced scaling approach + c.MaxWorkers = util.Min(poolSize/2, util.Max(10, poolSize/3)) } else { - // When explicitly set: max(10, set_value) - ensure at least 10 workers - c.MaxWorkers = util.Max(10, c.MaxWorkers) + // When explicitly set: max(poolSize/2, set_value) - ensure at least poolSize/2 workers + c.MaxWorkers = util.Max(poolSize/2, c.MaxWorkers) } // Ensure minimum of 1 worker (fallback for very small pools) diff --git a/hitless/errors.go b/hitless/errors.go index c0ae353b..38e275c4 100644 --- a/hitless/errors.go +++ b/hitless/errors.go @@ -43,3 +43,8 @@ var ( // ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff ErrConnectionInvalidHandoffState = errors.New("hitless: connection is in invalid state for handoff") ) + +// general errors +var ( + ErrShutdown = errors.New("hitless: shutdown") +) diff --git a/hitless/example_hooks.go b/hitless/example_hooks.go index 0b65f1f5..92bdddcf 100644 --- a/hitless/example_hooks.go +++ b/hitless/example_hooks.go @@ -38,7 +38,7 @@ func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.Notific // Log connection information if available if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { - internal.Logger.Printf(ctx, "hitless: metrics hook processing %s notification on connection %d", notificationType, conn.GetID()) + internal.Logger.Printf(ctx, "hitless: metrics hook processing %s notification on conn[%d]", notificationType, conn.GetID()) } // Store start time in context for duration calculation @@ -62,7 +62,7 @@ func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.Notifi // Log error details with connection information if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { - internal.Logger.Printf(ctx, "hitless: metrics hook recorded error for %s notification on connection %d: %v", notificationType, conn.GetID(), result) + internal.Logger.Printf(ctx, "hitless: metrics hook recorded error for %s notification on conn[%d]: %v", notificationType, conn.GetID(), result) } } } diff --git a/hitless/pool_hook.go b/hitless/pool_hook.go index cf924a4c..2d58fb0d 100644 --- a/hitless/pool_hook.go +++ b/hitless/pool_hook.go @@ -2,7 +2,6 @@ package hitless import ( "context" - "errors" "net" "sync" "sync/atomic" @@ -44,9 +43,10 @@ type PoolHook struct { workerWg sync.WaitGroup // Track worker goroutines // On-demand worker management - maxWorkers int - activeWorkers int32 // Atomic counter for active workers - workerTimeout time.Duration // How long workers wait for work before exiting + maxWorkers int + activeWorkers atomic.Int32 + workerTimeout time.Duration // How long workers wait for work before exiting + workersScaling atomic.Bool // Simple state tracking pending sync.Map // map[uint64]int64 (connID -> seqID) @@ -82,8 +82,9 @@ func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (n // shutdown is a channel for signaling shutdown shutdown: make(chan struct{}), maxWorkers: config.MaxWorkers, - activeWorkers: 0, // Start with no workers - create on demand - workerTimeout: 30 * time.Second, // Workers exit after 30s of inactivity + activeWorkers: atomic.Int32{}, // Start with no workers - create on demand + // NOTE: maybe we would like to make this configurable? 
+ workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity config: config, // Hitless manager for operation completion tracking hitlessManager: hitlessManager, @@ -98,15 +99,7 @@ func (ph *PoolHook) SetPool(pooler pool.Pooler) { // GetCurrentWorkers returns the current number of active workers (for testing) func (ph *PoolHook) GetCurrentWorkers() int { - return int(atomic.LoadInt32(&ph.activeWorkers)) -} - -// GetScaleLevel returns 1 if workers are active, 0 if none (for testing compatibility) -func (ph *PoolHook) GetScaleLevel() int { - if atomic.LoadInt32(&ph.activeWorkers) > 0 { - return 1 - } - return 0 + return int(ph.activeWorkers.Load()) } // IsHandoffPending returns true if the given connection has a pending handoff @@ -181,14 +174,21 @@ func (ph *PoolHook) ensureWorkerAvailable() { case <-ph.shutdown: return default: - // Check if we need a new worker - currentWorkers := atomic.LoadInt32(&ph.activeWorkers) - if currentWorkers < int32(ph.maxWorkers) { - // Try to create a new worker (atomic increment to prevent race) - if atomic.CompareAndSwapInt32(&ph.activeWorkers, currentWorkers, currentWorkers+1) { + if ph.workersScaling.CompareAndSwap(false, true) { + defer ph.workersScaling.Store(false) + // Check if we need a new worker + currentWorkers := ph.activeWorkers.Load() + workersWas := currentWorkers + for currentWorkers <= int32(ph.maxWorkers) { ph.workerWg.Add(1) go ph.onDemandWorker() + currentWorkers++ } + // workersWas is always <= currentWorkers + // currentWorkers will be maxWorkers, but if we have a worker that was closed + // while we were creating new workers, just add the difference between + // the currentWorkers and the number of workers we observed initially (i.e. the number of workers we created) + ph.activeWorkers.Add(currentWorkers - workersWas) } } } @@ -197,12 +197,21 @@ func (ph *PoolHook) ensureWorkerAvailable() { func (ph *PoolHook) onDemandWorker() { defer func() { // Decrement active worker count when exiting - atomic.AddInt32(&ph.activeWorkers, -1) + ph.activeWorkers.Add(-1) ph.workerWg.Done() }() for { select { + case <-ph.shutdown: + return + case <-time.After(ph.workerTimeout): + // Worker has been idle for too long, exit to save resources + if ph.config != nil && ph.config.LogLevel >= LogLevelDebug { // Debug level + internal.Logger.Printf(context.Background(), + "hitless: worker exiting due to inactivity timeout (%v)", ph.workerTimeout) + } + return case request := <-ph.handoffQueue: // Check for shutdown before processing select { @@ -214,17 +223,6 @@ func (ph *PoolHook) onDemandWorker() { // Process the request ph.processHandoffRequest(request) } - - case <-time.After(ph.workerTimeout): - // Worker has been idle for too long, exit to save resources - if ph.config != nil && ph.config.LogLevel >= LogLevelDebug { // Debug level - internal.Logger.Printf(context.Background(), - "hitless: worker exiting due to inactivity timeout (%v)", ph.workerTimeout) - } - return - - case <-ph.shutdown: - return } } } @@ -257,20 +255,21 @@ func (ph *PoolHook) processHandoffRequest(request HandoffRequest) { }() // Perform the handoff with cancellable context - shouldRetry, err := ph.performConnectionHandoffWithPool(shutdownCtx, request.Conn, request.Pool) + shouldRetry, err := ph.performConnectionHandoff(shutdownCtx, request.Conn) minRetryBackoff := 500 * time.Millisecond if err != nil { if shouldRetry { now := time.Now() deadline, ok := shutdownCtx.Deadline() + thirdOfTimeout := handoffTimeout / 3 if !ok || deadline.Before(now) { // wait half the timeout 
before retrying if no deadline or deadline has passed - deadline = now.Add(handoffTimeout / 2) + deadline = now.Add(thirdOfTimeout) } afterTime := deadline.Sub(now) - if afterTime > handoffTimeout/2 { - afterTime = handoffTimeout / 2 + if afterTime > thirdOfTimeout { + afterTime = thirdOfTimeout } if afterTime < minRetryBackoff { afterTime = minRetryBackoff @@ -280,12 +279,12 @@ func (ph *PoolHook) processHandoffRequest(request HandoffRequest) { time.AfterFunc(afterTime, func() { if err := ph.queueHandoff(request.Conn); err != nil { internal.Logger.Printf(context.Background(), "can't queue handoff for retry: %v", err) - ph.removeConn(context.Background(), request, err) + ph.closeConnFromRequest(context.Background(), request, err) } }) return } else { - go ph.removeConn(ctx, request, err) + go ph.closeConnFromRequest(ctx, request, err) } // Clear handoff state if not returned for retry @@ -297,21 +296,22 @@ func (ph *PoolHook) processHandoffRequest(request HandoffRequest) { } } -func (ph *PoolHook) removeConn(ctx context.Context, request HandoffRequest, err error) { +// closeConn closes the connection and logs the reason +func (ph *PoolHook) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) { pooler := request.Pool conn := request.Conn if pooler != nil { - pooler.Remove(ctx, conn, err) + pooler.CloseConn(conn) if ph.config != nil && ph.config.LogLevel >= LogLevelWarn { // Warning level internal.Logger.Printf(ctx, - "hitless: removed connection %d from pool due to max handoff retries reached", - conn.GetID()) + "hitless: removed conn[%d] from pool due to max handoff retries reached: %v", + conn.GetID(), err) } } else { conn.Close() if ph.config != nil && ph.config.LogLevel >= LogLevelWarn { // Warning level internal.Logger.Printf(ctx, - "hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v", + "hitless: no pool provided for conn[%d], cannot remove due to handoff initialization failure: %v", conn.GetID(), err) } } @@ -332,11 +332,11 @@ func (ph *PoolHook) queueHandoff(conn *pool.Conn) error { select { // priority to shutdown case <-ph.shutdown: - return errors.New("shutdown") + return ErrShutdown default: select { case <-ph.shutdown: - return errors.New("shutdown") + return ErrShutdown case ph.handoffQueue <- request: // Store in pending map ph.pending.Store(request.ConnID, request.SeqID) @@ -344,13 +344,30 @@ func (ph *PoolHook) queueHandoff(conn *pool.Conn) error { ph.ensureWorkerAvailable() return nil default: - // Queue is full - log and attempt scaling - queueLen := len(ph.handoffQueue) - queueCap := cap(ph.handoffQueue) - if ph.config != nil && ph.config.LogLevel >= LogLevelWarn { // Warning level - internal.Logger.Printf(context.Background(), - "hitless: handoff queue is full (%d/%d), attempting timeout queuing and scaling workers", - queueLen, queueCap) + select { + case <-ph.shutdown: + return ErrShutdown + case ph.handoffQueue <- request: + // Store in pending map + ph.pending.Store(request.ConnID, request.SeqID) + // Ensure we have a worker to process this request + ph.ensureWorkerAvailable() + return nil + case <-time.After(100 * time.Millisecond): // give workers a chance to process + // Queue is full - log and attempt scaling + queueLen := len(ph.handoffQueue) + queueCap := cap(ph.handoffQueue) + if ph.config != nil && ph.config.LogLevel >= LogLevelWarn { // Warning level + internal.Logger.Printf(context.Background(), + "hitless: handoff queue is full (%d/%d), cant queue handoff request for conn[%d] 
seqID[%d]", + queueLen, queueCap, request.ConnID, request.SeqID) + if ph.config.LogLevel >= LogLevelDebug { // Debug level + ph.pending.Range(func(k, v interface{}) bool { + internal.Logger.Printf(context.Background(), "hitless: pending handoff for conn[%d] seqID[%d]", k, v) + return true + }) + } + } } } } @@ -360,9 +377,9 @@ func (ph *PoolHook) queueHandoff(conn *pool.Conn) error { return ErrHandoffQueueFull } -// performConnectionHandoffWithPool performs the actual connection handoff with pool for connection removal on failure +// performConnectionHandoff performs the actual connection handoff // When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached -func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn *pool.Conn, pooler pool.Pooler) (shouldRetry bool, err error) { +func (ph *PoolHook) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) { // Clear handoff state after successful handoff connID := conn.GetID() @@ -381,7 +398,7 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * if retries > maxRetries { if ph.config != nil && ph.config.LogLevel >= LogLevelWarn { // Warning level internal.Logger.Printf(ctx, - "hitless: reached max retries (%d) for handoff of connection %d to %s", + "hitless: reached max retries (%d) for handoff of conn[%d] to %s", maxRetries, conn.GetID(), conn.GetHandoffEndpoint()) } // won't retry on ErrMaxHandoffRetriesReached @@ -403,6 +420,23 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * // Get the old connection oldConn := conn.GetNetConn() + // Apply relaxed timeout to the new connection for the configured post-handoff duration + // This gives the new connection more time to handle operations during cluster transition + // Setting this here (before initing the connection) ensures that the connection is going + // to use the relaxed timeout for the first operation (auth/ACL select) + if ph.config != nil && ph.config.PostHandoffRelaxedDuration > 0 { + relaxedTimeout := ph.config.RelaxedTimeout + // Set relaxed timeout with deadline - no background goroutine needed + deadline := time.Now().Add(ph.config.PostHandoffRelaxedDuration) + conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline) + + if ph.config.LogLevel >= 2 { // Info level + internal.Logger.Printf(context.Background(), + "hitless: conn[%d] applied post-handoff relaxed timeout (%v) until %v", + connID, relaxedTimeout, deadline.Format("15:04:05.000")) + } + } + // Replace the connection and execute initialization err = conn.SetNetConnAndInitConn(ctx, newNetConn) if err != nil { @@ -419,23 +453,6 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * conn.ClearHandoffState() internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", conn.GetID(), newEndpoint) - // Apply relaxed timeout to the new connection for the configured post-handoff duration - // This gives the new connection more time to handle operations during cluster transition - if ph.config != nil && ph.config.PostHandoffRelaxedDuration > 0 { - relaxedTimeout := ph.config.RelaxedTimeout - postHandoffDuration := ph.config.PostHandoffRelaxedDuration - - // Set relaxed timeout with deadline - no background goroutine needed - deadline := time.Now().Add(postHandoffDuration) - conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline) - - if ph.config.LogLevel >= 2 { // Info level - 
internal.Logger.Printf(context.Background(), - "hitless: conn[%d] applied post-handoff relaxed timeout (%v) until %v", - connID, relaxedTimeout, deadline.Format("15:04:05.000")) - } - } - return false, nil } diff --git a/hitless/pool_hook_test.go b/hitless/pool_hook_test.go index 1c976ad4..2a3a504a 100644 --- a/hitless/pool_hook_test.go +++ b/hitless/pool_hook_test.go @@ -267,7 +267,7 @@ func TestConnectionHook(t *testing.T) { EndpointType: EndpointTypeAuto, MaxWorkers: 2, HandoffQueueSize: 10, - MaxHandoffRetries: 2, // Reduced retries for faster test + MaxHandoffRetries: 2, // Reduced retries for faster test HandoffTimeout: 500 * time.Millisecond, // Shorter timeout for faster test LogLevel: 2, } @@ -460,7 +460,7 @@ func TestConnectionHook(t *testing.T) { for i := 0; i < 5; i++ { connections[i] = createMockPoolConnection() if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil { - t.Fatalf("Failed to mark connection %d for handoff: %v", i, err) + t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err) } // Set a mock initialization function connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { @@ -510,9 +510,6 @@ func TestConnectionHook(t *testing.T) { if processor.GetCurrentWorkers() != 0 { t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers()) } - if processor.GetScaleLevel() != 0 { - t.Errorf("Processor should be at scale level 0 initially, got %d", processor.GetScaleLevel()) - } if processor.maxWorkers != 15 { t.Errorf("Expected maxWorkers=15, got %d", processor.maxWorkers) } @@ -528,7 +525,7 @@ func TestConnectionHook(t *testing.T) { config := &Config{ MaxWorkers: 2, HandoffQueueSize: 10, - MaxHandoffRetries: 3, // Allow retries for successful handoff + MaxHandoffRetries: 3, // Allow retries for successful handoff PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing RelaxedTimeout: 5 * time.Second, LogLevel: 2, @@ -718,7 +715,7 @@ func TestConnectionHook(t *testing.T) { for i := 0; i < 10; i++ { conn := createMockPoolConnection() if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil { - t.Fatalf("Failed to mark connection %d for handoff: %v", i, err) + t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err) } // Set a mock initialization function conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { @@ -731,7 +728,7 @@ func TestConnectionHook(t *testing.T) { } if !shouldPool || shouldRemove { - t.Errorf("Connection %d should be pooled after handoff (shouldPool=%v, shouldRemove=%v)", + t.Errorf("conn[%d] should be pooled after handoff (shouldPool=%v, shouldRemove=%v)", i, shouldPool, shouldRemove) } } @@ -795,7 +792,7 @@ func TestConnectionHook(t *testing.T) { // Verify that the connection was removed from the pool if !mockPool.WasRemoved(conn.GetID()) { - t.Errorf("Connection %d should have been removed from pool after handoff failure", conn.GetID()) + t.Errorf("conn[%d] should have been removed from pool after handoff failure", conn.GetID()) } t.Logf("Connection removal on handoff failure test completed successfully") diff --git a/hitless/notification_handler.go b/hitless/push_notification_handler.go similarity index 95% rename from hitless/notification_handler.go rename to hitless/push_notification_handler.go index 27cd8d90..9e583ba4 100644 --- a/hitless/notification_handler.go +++ b/hitless/push_notification_handler.go @@ -103,6 +103,15 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, 
handlerCtx pus return ErrInvalidNotification } + // If the connection is closed or not pooled, we can ignore the notification + // this connection won't be remembered by the pool and will be garbage collected + // Keep pubsub connections around since they are not pooled but are long-lived + // and should be allowed to handoff (the pubsub instance will reconnect and change + // the underlying *pool.Conn) + if (poolConn.IsClosed() || !poolConn.IsPooled()) && !poolConn.IsPubSub() { + return nil + } + deadline := time.Now().Add(time.Duration(timeS) * time.Second) // If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds if newEndpoint == "" || newEndpoint == internal.RedisNull { @@ -133,6 +142,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error { if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil { + internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err) // Connection is already marked for handoff, which is acceptable // This can happen if multiple MOVING notifications are received for the same connection return nil diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 60875b04..d69915d8 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -45,6 +45,7 @@ type Conn struct { Inited atomic.Bool pooled bool + pubsub bool closed atomic.Bool createdAt time.Time expiresAt time.Time @@ -201,6 +202,16 @@ func (cn *Conn) IsUsable() bool { return cn.isUsable() } +// IsPooled returns true if the connection is managed by a pool and will be pooled on Put. +func (cn *Conn) IsPooled() bool { + return cn.pooled +} + +// IsPubSub returns true if the connection is used for PubSub. +func (cn *Conn) IsPubSub() bool { + return cn.pubsub +} + func (cn *Conn) IsInited() bool { return cn.Inited.Load() } @@ -223,9 +234,7 @@ func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) { // After the deadline, timeouts automatically revert to normal values. // Uses atomic operations for lock-free access. func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) { - cn.relaxedCounter.Add(1) - cn.relaxedReadTimeoutNs.Store(int64(readTimeout)) - cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout)) + cn.SetRelaxedTimeout(readTimeout, writeTimeout) cn.relaxedDeadlineNs.Store(deadline.UnixNano()) } @@ -354,15 +363,12 @@ func (cn *Conn) ExecuteInitConn(ctx context.Context) error { if cn.initConnFunc != nil { return cn.initConnFunc(ctx, cn) } - return fmt.Errorf("redis: no initConnFunc set for connection %d", cn.GetID()) + return fmt.Errorf("redis: no initConnFunc set for conn[%d]", cn.GetID()) } func (cn *Conn) SetNetConn(netConn net.Conn) { // Store the new connection atomically first (lock-free) cn.setNetConn(netConn) - // Clear relaxed timeouts when connection is replaced - cn.clearRelaxedTimeout() - // Protect reader reset operations to avoid data races // Use write lock since we're modifying the reader state cn.readerMu.Lock() diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 10745ed1..e2109591 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -28,13 +28,14 @@ var ( // when popping from the idle connection pool. This handles cases where connections // are temporarily marked as unusable (e.g., during hitless upgrades or network issues). 
// Value of 50 provides sufficient resilience without excessive overhead. + // This is capped by the idle connection count, so we won't loop excessively. popAttempts = 50 // getAttempts is the maximum number of attempts to get a connection that passes // hook validation (e.g., hitless upgrade hooks). This protects against race conditions // where hooks might temporarily reject connections during cluster transitions. - // Value of 2 balances resilience with performance - most hook rejections resolve quickly. - getAttempts = 2 + // Value of 3 balances resilience with performance - most hook rejections resolve quickly. + getAttempts = 3 minTime = time.Unix(-2208988800, 0) // Jan 1, 1900 maxTime = minTime.Add(1<<63 - 1) @@ -262,7 +263,9 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { return nil, ErrPoolExhausted } - cn, err := p.dialConn(ctx, pooled) + dialCtx, cancel := context.WithTimeout(ctx, p.cfg.DialTimeout) + defer cancel() + cn, err := p.dialConn(dialCtx, pooled) if err != nil { return nil, err } @@ -309,24 +312,55 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { return nil, p.getLastDialError() } - netConn, err := p.cfg.Dialer(ctx) - if err != nil { - p.setLastDialError(err) - if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) { - go p.tryDial() + // Retry dialing with backoff + // the context timeout is already handled by the context passed in + // so we may never reach the max retries, higher values don't hurt + const maxRetries = 10 + const backoffDuration = 100 * time.Millisecond + + var lastErr error + for attempt := 0; attempt < maxRetries; attempt++ { + // Add backoff delay for retry attempts + // (not for the first attempt, do at least one) + if attempt > 0 { + select { + case <-ctx.Done(): + // we should have lastErr set, but just in case + if lastErr == nil { + lastErr = ctx.Err() + } + break + case <-time.After(backoffDuration): + // Continue with retry + } } - return nil, err + + netConn, err := p.cfg.Dialer(ctx) + if err != nil { + lastErr = err + // Continue to next retry attempt + continue + } + + // Success - create connection + cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize) + cn.pooled = pooled + if p.cfg.ConnMaxLifetime > 0 { + cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime) + } else { + cn.expiresAt = noExpiration + } + + return cn, nil } - cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize) - cn.pooled = pooled - if p.cfg.ConnMaxLifetime > 0 { - cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime) - } else { - cn.expiresAt = noExpiration + internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr) + // All retries failed - handle error tracking + p.setLastDialError(lastErr) + if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) { + go p.tryDial() } - - return cn, nil + return nil, lastErr } func (p *ConnPool) tryDial() { @@ -386,7 +420,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { attempts := 0 for { if attempts >= getAttempts { - internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a connection accepted by hook after %d attempts", attempts) + internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts) break } attempts++ @@ -448,7 +482,6 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { return nil, err } } - return newcn, nil 
} @@ -497,6 +530,7 @@ func (p *ConnPool) popIdle() (*Conn, error) { if p.closed() { return nil, ErrClosed } + defer p.checkMinIdleConns() n := len(p.idleConns) if n == 0 { @@ -540,12 +574,11 @@ func (p *ConnPool) popIdle() (*Conn, error) { } // If we exhausted all attempts without finding a usable connection, return nil - if int32(attempts) >= p.poolSize.Load() { + if attempts > 1 && attempts >= maxAttempts && int32(attempts) >= p.poolSize.Load() { internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts) return nil, nil } - p.checkMinIdleConns() return cn, nil } @@ -631,11 +664,7 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) { p.removeConnWithLock(cn) - // Only free a turn if the connection was actually pooled - // This prevents queue imbalance when removing connections that weren't obtained through Get() - if cn.pooled { - p.freeTurn() - } + p.freeTurn() _ = p.closeConn(cn) @@ -655,12 +684,22 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) { } func (p *ConnPool) removeConn(cn *Conn) { - delete(p.conns, cn.GetID()) + cid := cn.GetID() + delete(p.conns, cid) atomic.AddUint32(&p.stats.StaleConns, 1) // Decrement pool size counter when removing a connection if cn.pooled { p.poolSize.Add(-1) + // this can be idle conn + for idx, ic := range p.idleConns { + if ic.GetID() == cid { + internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid) + p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...) + p.idleConnsLen.Add(-1) + break + } + } } } @@ -745,6 +784,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { return false } + // Check if connection has exceeded idle timeout if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { return false } diff --git a/internal/pool/pubsub.go b/internal/pool/pubsub.go index c616300f..8b12f74c 100644 --- a/internal/pool/pubsub.go +++ b/internal/pool/pubsub.go @@ -43,6 +43,7 @@ func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, c return nil, err } cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize) + cn.pubsub = true atomic.AddUint32(&p.stats.Created, 1) return cn, nil @@ -55,7 +56,7 @@ func (p *PubSubPool) TrackConn(cn *Conn) { func (p *PubSubPool) UntrackConn(cn *Conn) { if !cn.IsUsable() || cn.ShouldHandoff() { - internal.Logger.Printf(context.Background(), "pubsub: untracking connection %d [usable, handoff] = [%v, %v]", cn.GetID(), cn.IsUsable(), cn.ShouldHandoff()) + internal.Logger.Printf(context.Background(), "pubsub: untracking conn[%d] [usable, handoff] = [%v, %v]", cn.GetID(), cn.IsUsable(), cn.ShouldHandoff()) } atomic.AddUint32(&p.stats.Active, ^uint32(0)) atomic.AddUint32(&p.stats.Untracked, 1) diff --git a/options.go b/options.go index c89f4605..bb0a7dbf 100644 --- a/options.go +++ b/options.go @@ -109,7 +109,7 @@ type Options struct { // DialTimeout for establishing new connections. // - // default: 5 seconds + // default: 10 seconds DialTimeout time.Duration // ReadTimeout for socket reads. 
If reached, commands will fail @@ -275,7 +275,7 @@ func (opt *Options) init() { opt.Protocol = 3 } if opt.DialTimeout == 0 { - opt.DialTimeout = 5 * time.Second + opt.DialTimeout = 10 * time.Second } if opt.Dialer == nil { opt.Dialer = NewDialer(opt) @@ -335,7 +335,7 @@ func (opt *Options) init() { opt.MaxRetryBackoff = 512 * time.Millisecond } - opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolSize(opt.PoolSize) + opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns) // auto-detect endpoint type if not specified endpointType := opt.HitlessUpgradeConfig.EndpointType
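
The sizing rules described in the README and `hitless/config.go` are easy to sanity-check. Below is a small, self-contained Go sketch of the default calculation as documented (the helper names are illustrative and not part of the `hitless` package); running it reproduces the corrected example values for pools of 10, 100, and 200, with and without `MaxActiveConns`.

```go
package main

import "fmt"

func minInt(a, b int) int {
	if a < b {
		return a
	}
	return b
}

func maxInt(a, b int) int {
	if a > b {
		return a
	}
	return b
}

// defaultWorkers mirrors the documented default: min(PoolSize/2, max(10, PoolSize/3)),
// with a floor of 1 for very small pools.
func defaultWorkers(poolSize int) int {
	return maxInt(1, minInt(poolSize/2, maxInt(10, poolSize/3)))
}

// defaultQueueSize mirrors the documented default: max(20×MaxWorkers, PoolSize),
// capped by MaxActiveConns+1 when set, otherwise by 5×PoolSize, with a floor of 2.
func defaultQueueSize(poolSize, workers, maxActiveConns int) int {
	size := maxInt(20*workers, poolSize)
	limit := 5 * poolSize
	if maxActiveConns > 0 {
		limit = maxActiveConns + 1
	}
	return maxInt(2, minInt(size, limit))
}

func main() {
	for _, poolSize := range []int{10, 100, 200} {
		workers := defaultWorkers(poolSize)
		fmt.Printf("pool %3d: workers %2d, queue %4d (no MaxActiveConns), queue %3d (MaxActiveConns=150)\n",
			poolSize, workers,
			defaultQueueSize(poolSize, workers, 0),
			defaultQueueSize(poolSize, workers, 150))
	}
}
```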
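Worker scale-up in `ensureWorkerAvailable` is now guarded by a `CompareAndSwap` on a scaling flag, and the active-worker counter is only reconciled after all workers for the pass have been started. A minimal standalone sketch of that pattern follows (type, field, and function names are illustrative, not the pool hook's API); note the strict `<` bound on the spawn loop, since a `<=` comparison would start one worker more than the configured maximum.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type scaler struct {
	maxWorkers    int
	activeWorkers atomic.Int32
	scaling       atomic.Bool
	wg            sync.WaitGroup
	queue         chan int
}

// ensureWorkers tops the worker count up to maxWorkers, guarded so that
// only one goroutine performs scaling at a time.
func (s *scaler) ensureWorkers(work func(int)) {
	if !s.scaling.CompareAndSwap(false, true) {
		return // another goroutine is already scaling
	}
	defer s.scaling.Store(false)

	current := s.activeWorkers.Load()
	started := int32(0)
	for current+started < int32(s.maxWorkers) { // strict <: never exceed maxWorkers
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			defer s.activeWorkers.Add(-1) // decrement when the worker exits
			for job := range s.queue {
				work(job)
			}
		}()
		started++
	}
	// Account only for the workers actually started in this pass.
	s.activeWorkers.Add(started)
}

func main() {
	s := &scaler{maxWorkers: 3, queue: make(chan int, 10)}
	s.ensureWorkers(func(job int) { fmt.Println("processed", job) })
	fmt.Println("active workers:", s.activeWorkers.Load()) // 3

	for i := 0; i < 5; i++ {
		s.queue <- i
	}
	close(s.queue)
	s.wg.Wait()
}
```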
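`dialConn` now retries the dialer with a fixed 100ms backoff while the `DialTimeout` context is still live. The sketch below shows the same retry shape in isolation (the dialer, constants, and error text are placeholders, not the pool's implementation); returning from the `ctx.Done()` branch is what stops the loop, whereas a bare `break` inside the `select` would only exit the `select` statement and fall through to another dial attempt.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"time"
)

// dialWithRetry retries dial with a fixed backoff until it succeeds,
// the retry budget is exhausted, or the context is done.
func dialWithRetry(ctx context.Context, dial func(context.Context) (net.Conn, error)) (net.Conn, error) {
	const maxRetries = 10
	const backoff = 100 * time.Millisecond

	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt > 0 {
			select {
			case <-ctx.Done():
				// Return rather than break: a bare break would only leave the
				// select and the loop would keep dialing with a dead context.
				if lastErr == nil {
					lastErr = ctx.Err()
				}
				return nil, lastErr
			case <-time.After(backoff):
				// Backoff elapsed; try again.
			}
		}
		conn, err := dial(ctx)
		if err != nil {
			lastErr = err
			continue
		}
		return conn, nil
	}
	return nil, fmt.Errorf("dial failed after %d attempts: %w", maxRetries, lastErr)
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
	defer cancel()

	_, err := dialWithRetry(ctx, func(context.Context) (net.Conn, error) {
		return nil, errors.New("connection refused")
	})
	fmt.Println(err) // last dial error once the context deadline stops further retries
}
```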