From b34f8270c66fc3f95ac507e935a4b997e6ecb675 Mon Sep 17 00:00:00 2001
From: Nedyalko Dyakov
Date: Tue, 2 Sep 2025 10:47:53 +0300
Subject: [PATCH] separate worker from pool hook

---
 hitless/handoff_worker.go | 394 ++++++++++++++++++++++++++++++++++++++
 hitless/pool_hook.go      | 382 +++---------------------------------
 hitless/pool_hook_test.go |  42 ++--
 3 files changed, 442 insertions(+), 376 deletions(-)
 create mode 100644 hitless/handoff_worker.go

diff --git a/hitless/handoff_worker.go b/hitless/handoff_worker.go
new file mode 100644
index 00000000..744766f3
--- /dev/null
+++ b/hitless/handoff_worker.go
@@ -0,0 +1,394 @@
+package hitless
+
+import (
+	"context"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/redis/go-redis/v9/internal"
+	"github.com/redis/go-redis/v9/internal/pool"
+)
+
+// handoffWorkerManager manages background workers and queue for connection handoffs
+type handoffWorkerManager struct {
+	// Event-driven handoff support
+	handoffQueue chan HandoffRequest // Queue for handoff requests
+	shutdown     chan struct{}       // Shutdown signal
+	shutdownOnce sync.Once           // Ensure clean shutdown
+	workerWg     sync.WaitGroup      // Track worker goroutines
+
+	// On-demand worker management
+	maxWorkers     int
+	activeWorkers  atomic.Int32
+	workerTimeout  time.Duration // How long workers wait for work before exiting
+	workersScaling atomic.Bool
+
+	// Simple state tracking
+	pending sync.Map // map[uint64]int64 (connID -> seqID)
+
+	// Configuration for the hitless upgrade
+	config *Config
+
+	// Pool hook reference for handoff processing
+	poolHook *PoolHook
+}
+
+// newHandoffWorkerManager creates a new handoff worker manager
+func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
+	return &handoffWorkerManager{
+		handoffQueue:  make(chan HandoffRequest, config.HandoffQueueSize),
+		shutdown:      make(chan struct{}),
+		maxWorkers:    config.MaxWorkers,
+		activeWorkers: atomic.Int32{},   // Start with no workers - create on demand
+		workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
+		config:        config,
+		poolHook:      poolHook,
+	}
+}
+
+// getCurrentWorkers returns the current number of active workers (for testing)
+func (hwm *handoffWorkerManager) getCurrentWorkers() int {
+	return int(hwm.activeWorkers.Load())
+}
+
+// getPendingMap returns the pending map for testing purposes
+func (hwm *handoffWorkerManager) getPendingMap() *sync.Map {
+	return &hwm.pending
+}
+
+// getMaxWorkers returns the max workers for testing purposes
+func (hwm *handoffWorkerManager) getMaxWorkers() int {
+	return hwm.maxWorkers
+}
+
+// getHandoffQueue returns the handoff queue for testing purposes
+func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
+	return hwm.handoffQueue
+}
+
+// isHandoffPending returns true if the given connection has a pending handoff
+func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
+	_, pending := hwm.pending.Load(conn.GetID())
+	return pending
+}
+
+// ensureWorkerAvailable ensures at least one worker is available to process requests
+// Creates a new worker if needed and under the max limit
+func (hwm *handoffWorkerManager) ensureWorkerAvailable() {
+	select {
+	case <-hwm.shutdown:
+		return
+	default:
+		if hwm.workersScaling.CompareAndSwap(false, true) {
+			defer hwm.workersScaling.Store(false)
+			// Check if we need a new worker
+			currentWorkers := hwm.activeWorkers.Load()
+			workersWas := currentWorkers
+			for currentWorkers <= int32(hwm.maxWorkers) {
+				hwm.workerWg.Add(1)
+				go hwm.onDemandWorker()
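+				// Note: activeWorkers is not incremented inside the loop; the local
+				// counter is reconciled after the loop so workers that exit while
+				// scaling is in progress are not double-counted.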
+				currentWorkers++
+			}
+			// workersWas is always <= currentWorkers
+			// currentWorkers will be maxWorkers, but if we have a worker that was closed
+			// while we were creating new workers, just add the difference between
+			// the currentWorkers and the number of workers we observed initially (i.e. the number of workers we created)
+			hwm.activeWorkers.Add(currentWorkers - workersWas)
+		}
+	}
+}
+
+// onDemandWorker processes handoff requests and exits when idle
+func (hwm *handoffWorkerManager) onDemandWorker() {
+	defer func() {
+		// Decrement active worker count when exiting
+		hwm.activeWorkers.Add(-1)
+		hwm.workerWg.Done()
+	}()
+
+	for {
+		select {
+		case <-hwm.shutdown:
+			return
+		case <-time.After(hwm.workerTimeout):
+			// Worker has been idle for too long, exit to save resources
+			if hwm.config != nil && hwm.config.LogLevel.InfoOrAbove() { // Info level
+				internal.Logger.Printf(context.Background(),
+					"hitless: worker exiting due to inactivity timeout (%v)", hwm.workerTimeout)
+			}
+			return
+		case request := <-hwm.handoffQueue:
+			// Check for shutdown before processing
+			select {
+			case <-hwm.shutdown:
+				// Clean up the request before exiting
+				hwm.pending.Delete(request.ConnID)
+				return
+			default:
+				// Process the request
+				hwm.processHandoffRequest(request)
+			}
+		}
+	}
+}
+
+// processHandoffRequest processes a single handoff request
+func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
+	// Remove from pending map
+	defer hwm.pending.Delete(request.Conn.GetID())
+	internal.Logger.Printf(context.Background(), "hitless: conn[%d] Processing handoff request start", request.Conn.GetID())
+
+	// Create a context with handoff timeout from config
+	handoffTimeout := 15 * time.Second // Default timeout
+	if hwm.config != nil && hwm.config.HandoffTimeout > 0 {
+		handoffTimeout = hwm.config.HandoffTimeout
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
+	defer cancel()
+
+	// Create a context that also respects the shutdown signal
+	shutdownCtx, shutdownCancel := context.WithCancel(ctx)
+	defer shutdownCancel()
+
+	// Monitor shutdown signal in a separate goroutine
+	go func() {
+		select {
+		case <-hwm.shutdown:
+			shutdownCancel()
+		case <-shutdownCtx.Done():
+		}
+	}()
+
+	// Perform the handoff with cancellable context
+	shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn)
+	minRetryBackoff := 500 * time.Millisecond
+	if err != nil {
+		if shouldRetry {
+			now := time.Now()
+			deadline, ok := shutdownCtx.Deadline()
+			thirdOfTimeout := handoffTimeout / 3
+			if !ok || deadline.Before(now) {
+				// wait a third of the timeout before retrying if there is no deadline or it has already passed
+				deadline = now.Add(thirdOfTimeout)
+			}
+			afterTime := deadline.Sub(now)
+			if afterTime < minRetryBackoff {
+				afterTime = minRetryBackoff
+			}
+
+			internal.Logger.Printf(context.Background(), "Handoff failed for conn[%d] WILL RETRY After %v: %v", request.ConnID, afterTime, err)
+			time.AfterFunc(afterTime, func() {
+				if err := hwm.queueHandoff(request.Conn); err != nil {
+					internal.Logger.Printf(context.Background(), "can't queue handoff for retry: %v", err)
+					hwm.closeConnFromRequest(context.Background(), request, err)
+				}
+			})
+			return
+		} else {
+			go hwm.closeConnFromRequest(ctx, request, err)
+		}
+
+		// Clear handoff state if not returned for retry
+		seqID := request.Conn.GetMovingSeqID()
+		connID := request.Conn.GetID()
+		if hwm.poolHook.hitlessManager != nil {
+			hwm.poolHook.hitlessManager.UntrackOperationWithConnID(seqID, connID)
+		}
+	}
+}
+
+//
queueHandoff queues a handoff request for processing +// if err is returned, connection will be removed from pool +func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error { + // Create handoff request + request := HandoffRequest{ + Conn: conn, + ConnID: conn.GetID(), + Endpoint: conn.GetHandoffEndpoint(), + SeqID: conn.GetMovingSeqID(), + Pool: hwm.poolHook.pool, // Include pool for connection removal on failure + } + + select { + // priority to shutdown + case <-hwm.shutdown: + return ErrShutdown + default: + select { + case <-hwm.shutdown: + return ErrShutdown + case hwm.handoffQueue <- request: + // Store in pending map + hwm.pending.Store(request.ConnID, request.SeqID) + // Ensure we have a worker to process this request + hwm.ensureWorkerAvailable() + return nil + default: + select { + case <-hwm.shutdown: + return ErrShutdown + case hwm.handoffQueue <- request: + // Store in pending map + hwm.pending.Store(request.ConnID, request.SeqID) + // Ensure we have a worker to process this request + hwm.ensureWorkerAvailable() + return nil + case <-time.After(100 * time.Millisecond): // give workers a chance to process + // Queue is full - log and attempt scaling + queueLen := len(hwm.handoffQueue) + queueCap := cap(hwm.handoffQueue) + if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level + internal.Logger.Printf(context.Background(), + "hitless: handoff queue is full (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration", + queueLen, queueCap) + } + } + } + } + + // Ensure we have workers available to handle the load + hwm.ensureWorkerAvailable() + return ErrHandoffQueueFull +} + +// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete +func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error { + hwm.shutdownOnce.Do(func() { + close(hwm.shutdown) + // workers will exit when they finish their current request + }) + + // Wait for workers to complete + done := make(chan struct{}) + go func() { + hwm.workerWg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// performConnectionHandoff performs the actual connection handoff +// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached +func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) { + // Clear handoff state after successful handoff + connID := conn.GetID() + + newEndpoint := conn.GetHandoffEndpoint() + if newEndpoint == "" { + return false, ErrConnectionInvalidHandoffState + } + + retries := conn.IncrementAndGetHandoffRetries(1) + internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", conn.GetID(), retries, newEndpoint, conn.RemoteAddr().String()) + maxRetries := 3 // Default fallback + if hwm.config != nil { + maxRetries = hwm.config.MaxHandoffRetries + } + + if retries > maxRetries { + if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level + internal.Logger.Printf(ctx, + "hitless: reached max retries (%d) for handoff of conn[%d] to %s", + maxRetries, conn.GetID(), conn.GetHandoffEndpoint()) + } + // won't retry on ErrMaxHandoffRetriesReached + return false, ErrMaxHandoffRetriesReached + } + + // Create endpoint-specific dialer + endpointDialer := hwm.createEndpointDialer(newEndpoint) + + // Create new connection to the new endpoint + 
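+	// Note: the dial below runs under the context built in processHandoffRequest,
+	// which carries both the HandoffTimeout deadline and the shutdown signal, so a
+	// slow dial cannot outlive a shutdown.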
newNetConn, err := endpointDialer(ctx) + if err != nil { + internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", conn.GetID(), newEndpoint, err) + // hitless: will retry + // Maybe a network error - retry after a delay + return true, err + } + + // Get the old connection + oldConn := conn.GetNetConn() + + // Apply relaxed timeout to the new connection for the configured post-handoff duration + // This gives the new connection more time to handle operations during cluster transition + // Setting this here (before initing the connection) ensures that the connection is going + // to use the relaxed timeout for the first operation (auth/ACL select) + if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 { + relaxedTimeout := hwm.config.RelaxedTimeout + // Set relaxed timeout with deadline - no background goroutine needed + deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration) + conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline) + + if hwm.config.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), + "hitless: conn[%d] applied post-handoff relaxed timeout (%v) until %v", + connID, relaxedTimeout, deadline.Format("15:04:05.000")) + } + } + + // Replace the connection and execute initialization + err = conn.SetNetConnAndInitConn(ctx, newNetConn) + if err != nil { + // hitless: won't retry + // Initialization failed - remove the connection + return false, err + } + defer func() { + if oldConn != nil { + oldConn.Close() + } + }() + + conn.ClearHandoffState() + internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", conn.GetID(), newEndpoint) + + return false, nil +} + +// createEndpointDialer creates a dialer function that connects to a specific endpoint +func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) { + return func(ctx context.Context) (net.Conn, error) { + // Parse endpoint to extract host and port + host, port, err := net.SplitHostPort(endpoint) + if err != nil { + // If no port specified, assume default Redis port + host = endpoint + if port == "" { + port = "6379" + } + } + + // Use the base dialer to connect to the new endpoint + return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port)) + } +} + +// closeConnFromRequest closes the connection and logs the reason +func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) { + pooler := request.Pool + conn := request.Conn + if pooler != nil { + pooler.Remove(ctx, conn, err) + if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level + internal.Logger.Printf(ctx, + "hitless: removed conn[%d] from pool due to max handoff retries reached: %v", + conn.GetID(), err) + } + } else { + conn.Close() + if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level + internal.Logger.Printf(ctx, + "hitless: no pool provided for conn[%d], cannot remove due to handoff initialization failure: %v", + conn.GetID(), err) + } + } +} diff --git a/hitless/pool_hook.go b/hitless/pool_hook.go index 4966d9da..4acf92e8 100644 --- a/hitless/pool_hook.go +++ b/hitless/pool_hook.go @@ -4,7 +4,6 @@ import ( "context" "net" "sync" - "sync/atomic" "time" "github.com/redis/go-redis/v9/internal" @@ -36,20 +35,8 @@ type PoolHook struct { // Network type (e.g., "tcp", "unix") network string - // Event-driven handoff support - handoffQueue chan HandoffRequest // Queue for handoff requests - 
shutdown chan struct{} // Shutdown signal - shutdownOnce sync.Once // Ensure clean shutdown - workerWg sync.WaitGroup // Track worker goroutines - - // On-demand worker management - maxWorkers int - activeWorkers atomic.Int32 - workerTimeout time.Duration // How long workers wait for work before exiting - workersScaling atomic.Bool - - // Simple state tracking - pending sync.Map // map[uint64]int64 (connID -> seqID) + // Worker manager for background handoff processing + workerManager *handoffWorkerManager // Configuration for the hitless upgrade config *Config @@ -77,18 +64,14 @@ func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (n // baseDialer is used to create connections to new endpoints during handoffs baseDialer: baseDialer, network: network, - // handoffQueue is a buffered channel for queuing handoff requests - handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize), - // shutdown is a channel for signaling shutdown - shutdown: make(chan struct{}), - maxWorkers: config.MaxWorkers, - activeWorkers: atomic.Int32{}, // Start with no workers - create on demand - // NOTE: maybe we would like to make this configurable? - workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity - config: config, + config: config, // Hitless manager for operation completion tracking hitlessManager: hitlessManager, } + + // Create worker manager + ph.workerManager = newHandoffWorkerManager(config, ph) + return ph } @@ -99,13 +82,27 @@ func (ph *PoolHook) SetPool(pooler pool.Pooler) { // GetCurrentWorkers returns the current number of active workers (for testing) func (ph *PoolHook) GetCurrentWorkers() int { - return int(ph.activeWorkers.Load()) + return ph.workerManager.getCurrentWorkers() } // IsHandoffPending returns true if the given connection has a pending handoff func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool { - _, pending := ph.pending.Load(conn.GetID()) - return pending + return ph.workerManager.isHandoffPending(conn) +} + +// GetPendingMap returns the pending map for testing purposes +func (ph *PoolHook) GetPendingMap() *sync.Map { + return ph.workerManager.getPendingMap() +} + +// GetMaxWorkers returns the max workers for testing purposes +func (ph *PoolHook) GetMaxWorkers() int { + return ph.workerManager.getMaxWorkers() +} + +// GetHandoffQueue returns the handoff queue for testing purposes +func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest { + return ph.workerManager.getHandoffQueue() } // OnGet is called when a connection is retrieved from the pool @@ -136,13 +133,12 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool } // check pending handoff to not queue the same connection twice - _, hasPendingHandoff := ph.pending.Load(conn.GetID()) - if hasPendingHandoff { + if ph.workerManager.isHandoffPending(conn) { // Default behavior (pending handoff): pool the connection return true, false, nil } - if err := ph.queueHandoff(conn); err != nil { + if err := ph.workerManager.queueHandoff(conn); err != nil { // Failed to queue handoff, remove the connection internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err) // Don't pool, remove connection, no error to caller @@ -167,331 +163,7 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool return true, false, nil } -// ensureWorkerAvailable ensures at least one worker is available to process requests -// Creates a new worker if needed and under the max limit -func (ph *PoolHook) ensureWorkerAvailable() { - select { 
- case <-ph.shutdown: - return - default: - if ph.workersScaling.CompareAndSwap(false, true) { - defer ph.workersScaling.Store(false) - // Check if we need a new worker - currentWorkers := ph.activeWorkers.Load() - workersWas := currentWorkers - for currentWorkers <= int32(ph.maxWorkers) { - ph.workerWg.Add(1) - go ph.onDemandWorker() - currentWorkers++ - } - // workersWas is always <= currentWorkers - // currentWorkers will be maxWorkers, but if we have a worker that was closed - // while we were creating new workers, just add the difference between - // the currentWorkers and the number of workers we observed initially (i.e. the number of workers we created) - ph.activeWorkers.Add(currentWorkers - workersWas) - } - } -} - -// onDemandWorker processes handoff requests and exits when idle -func (ph *PoolHook) onDemandWorker() { - defer func() { - // Decrement active worker count when exiting - ph.activeWorkers.Add(-1) - ph.workerWg.Done() - }() - - for { - select { - case <-ph.shutdown: - return - case <-time.After(ph.workerTimeout): - // Worker has been idle for too long, exit to save resources - if ph.config != nil && ph.config.LogLevel.InfoOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), - "hitless: worker exiting due to inactivity timeout (%v)", ph.workerTimeout) - } - return - case request := <-ph.handoffQueue: - // Check for shutdown before processing - select { - case <-ph.shutdown: - // Clean up the request before exiting - ph.pending.Delete(request.ConnID) - return - default: - // Process the request - ph.processHandoffRequest(request) - } - } - } -} - -// processHandoffRequest processes a single handoff request -func (ph *PoolHook) processHandoffRequest(request HandoffRequest) { - // Remove from pending map - defer ph.pending.Delete(request.Conn.GetID()) - internal.Logger.Printf(context.Background(), "hitless: conn[%d] Processing handoff request start", request.Conn.GetID()) - - // Create a context with handoff timeout from config - handoffTimeout := 30 * time.Second // Default fallback - if ph.config != nil && ph.config.HandoffTimeout > 0 { - handoffTimeout = ph.config.HandoffTimeout - } - ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout) - defer cancel() - - // Create a context that also respects the shutdown signal - shutdownCtx, shutdownCancel := context.WithCancel(ctx) - defer shutdownCancel() - - // Monitor shutdown signal in a separate goroutine - go func() { - select { - case <-ph.shutdown: - shutdownCancel() - case <-shutdownCtx.Done(): - } - }() - - // Perform the handoff with cancellable context - shouldRetry, err := ph.performConnectionHandoff(shutdownCtx, request.Conn) - minRetryBackoff := 500 * time.Millisecond - if err != nil { - if shouldRetry { - now := time.Now() - deadline, ok := shutdownCtx.Deadline() - thirdOfTimeout := handoffTimeout / 3 - if !ok || deadline.Before(now) { - // wait half the timeout before retrying if no deadline or deadline has passed - deadline = now.Add(thirdOfTimeout) - } - - afterTime := deadline.Sub(now) - if afterTime > thirdOfTimeout { - afterTime = thirdOfTimeout - } - if afterTime < minRetryBackoff { - afterTime = minRetryBackoff - } - - internal.Logger.Printf(context.Background(), "Handoff failed for conn[%d] WILL RETRY After %v: %v", request.ConnID, afterTime, err) - time.AfterFunc(afterTime, func() { - if err := ph.queueHandoff(request.Conn); err != nil { - internal.Logger.Printf(context.Background(), "can't queue handoff for retry: %v", err) - 
ph.closeConnFromRequest(context.Background(), request, err) - } - }) - return - } else { - go ph.closeConnFromRequest(ctx, request, err) - } - - // Clear handoff state if not returned for retry - seqID := request.Conn.GetMovingSeqID() - connID := request.Conn.GetID() - if ph.hitlessManager != nil { - ph.hitlessManager.UntrackOperationWithConnID(seqID, connID) - } - } -} - -// closeConn closes the connection and logs the reason -func (ph *PoolHook) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) { - pooler := request.Pool - conn := request.Conn - if pooler != nil { - pooler.Remove(ctx, conn, err) - if ph.config != nil && ph.config.LogLevel.WarnOrAbove() { // Warning level - internal.Logger.Printf(ctx, - "hitless: removed conn[%d] from pool due to max handoff retries reached: %v", - conn.GetID(), err) - } - } else { - conn.Close() - if ph.config != nil && ph.config.LogLevel.WarnOrAbove() { // Warning level - internal.Logger.Printf(ctx, - "hitless: no pool provided for conn[%d], cannot remove due to handoff initialization failure: %v", - conn.GetID(), err) - } - } -} - -// queueHandoff queues a handoff request for processing -// if err is returned, connection will be removed from pool -func (ph *PoolHook) queueHandoff(conn *pool.Conn) error { - // Create handoff request - request := HandoffRequest{ - Conn: conn, - ConnID: conn.GetID(), - Endpoint: conn.GetHandoffEndpoint(), - SeqID: conn.GetMovingSeqID(), - Pool: ph.pool, // Include pool for connection removal on failure - } - - select { - // priority to shutdown - case <-ph.shutdown: - return ErrShutdown - default: - select { - case <-ph.shutdown: - return ErrShutdown - case ph.handoffQueue <- request: - // Store in pending map - ph.pending.Store(request.ConnID, request.SeqID) - // Ensure we have a worker to process this request - ph.ensureWorkerAvailable() - return nil - default: - select { - case <-ph.shutdown: - return ErrShutdown - case ph.handoffQueue <- request: - // Store in pending map - ph.pending.Store(request.ConnID, request.SeqID) - // Ensure we have a worker to process this request - ph.ensureWorkerAvailable() - return nil - case <-time.After(100 * time.Millisecond): // give workers a chance to process - // Queue is full - log and attempt scaling - queueLen := len(ph.handoffQueue) - queueCap := cap(ph.handoffQueue) - if ph.config != nil && ph.config.LogLevel.WarnOrAbove() { // Warning level - internal.Logger.Printf(context.Background(), - "hitless: handoff queue is full (%d/%d), cant queue handoff request for conn[%d] seqID[%d]", - queueLen, queueCap, request.ConnID, request.SeqID) - if ph.config.LogLevel.DebugOrAbove() { // Debug level - ph.pending.Range(func(k, v interface{}) bool { - internal.Logger.Printf(context.Background(), "hitless: pending handoff for conn[%d] seqID[%d]", k, v) - return true - }) - } - } - } - } - } - - // Ensure we have workers available to handle the load - ph.ensureWorkerAvailable() - return ErrHandoffQueueFull -} - -// performConnectionHandoff performs the actual connection handoff -// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached -func (ph *PoolHook) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) { - // Clear handoff state after successful handoff - connID := conn.GetID() - - newEndpoint := conn.GetHandoffEndpoint() - if newEndpoint == "" { - return false, ErrConnectionInvalidHandoffState - } - - retries := conn.IncrementAndGetHandoffRetries(1) - 
internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", conn.GetID(), retries, newEndpoint, conn.RemoteAddr().String()) - maxRetries := 3 // Default fallback - if ph.config != nil { - maxRetries = ph.config.MaxHandoffRetries - } - - if retries > maxRetries { - if ph.config != nil && ph.config.LogLevel.WarnOrAbove() { // Warning level - internal.Logger.Printf(ctx, - "hitless: reached max retries (%d) for handoff of conn[%d] to %s", - maxRetries, conn.GetID(), conn.GetHandoffEndpoint()) - } - // won't retry on ErrMaxHandoffRetriesReached - return false, ErrMaxHandoffRetriesReached - } - - // Create endpoint-specific dialer - endpointDialer := ph.createEndpointDialer(newEndpoint) - - // Create new connection to the new endpoint - newNetConn, err := endpointDialer(ctx) - if err != nil { - internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", conn.GetID(), newEndpoint, err) - // hitless: will retry - // Maybe a network error - retry after a delay - return true, err - } - - // Get the old connection - oldConn := conn.GetNetConn() - - // Apply relaxed timeout to the new connection for the configured post-handoff duration - // This gives the new connection more time to handle operations during cluster transition - // Setting this here (before initing the connection) ensures that the connection is going - // to use the relaxed timeout for the first operation (auth/ACL select) - if ph.config != nil && ph.config.PostHandoffRelaxedDuration > 0 { - relaxedTimeout := ph.config.RelaxedTimeout - // Set relaxed timeout with deadline - no background goroutine needed - deadline := time.Now().Add(ph.config.PostHandoffRelaxedDuration) - conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline) - - if ph.config.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), - "hitless: conn[%d] applied post-handoff relaxed timeout (%v) until %v", - connID, relaxedTimeout, deadline.Format("15:04:05.000")) - } - } - - // Replace the connection and execute initialization - err = conn.SetNetConnAndInitConn(ctx, newNetConn) - if err != nil { - // hitless: won't retry - // Initialization failed - remove the connection - return false, err - } - defer func() { - if oldConn != nil { - oldConn.Close() - } - }() - - conn.ClearHandoffState() - internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", conn.GetID(), newEndpoint) - - return false, nil -} - -// createEndpointDialer creates a dialer function that connects to a specific endpoint -func (ph *PoolHook) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) { - return func(ctx context.Context) (net.Conn, error) { - // Parse endpoint to extract host and port - host, port, err := net.SplitHostPort(endpoint) - if err != nil { - // If no port specified, assume default Redis port - host = endpoint - if port == "" { - port = "6379" - } - } - - // Use the base dialer to connect to the new endpoint - return ph.baseDialer(ctx, ph.network, net.JoinHostPort(host, port)) - } -} - // Shutdown gracefully shuts down the processor, waiting for workers to complete func (ph *PoolHook) Shutdown(ctx context.Context) error { - ph.shutdownOnce.Do(func() { - close(ph.shutdown) - // workers will exit when they finish their current request - }) - - // Wait for workers to complete - done := make(chan struct{}) - go func() { - ph.workerWg.Wait() - close(done) - }() - - select { - case <-done: - return nil - case <-ctx.Done(): - return ctx.Err() - } + return 
ph.workerManager.shutdownWorkers(ctx) } diff --git a/hitless/pool_hook_test.go b/hitless/pool_hook_test.go index 2a3a504a..6f84002e 100644 --- a/hitless/pool_hook_test.go +++ b/hitless/pool_hook_test.go @@ -169,7 +169,7 @@ func TestConnectionHook(t *testing.T) { } // Connection should be in pending map while initialization is blocked - if _, pending := processor.pending.Load(conn.GetID()); !pending { + if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending { t.Error("Connection should be in pending handoffs map") } @@ -187,14 +187,14 @@ func TestConnectionHook(t *testing.T) { case <-timeout: t.Fatal("Timeout waiting for handoff to complete") case <-ticker.C: - if _, pending := processor.pending.Load(conn); !pending { + if _, pending := processor.GetPendingMap().Load(conn); !pending { handoffCompleted = true } } } // Verify handoff completed (removed from pending map) - if _, pending := processor.pending.Load(conn); pending { + if _, pending := processor.GetPendingMap().Load(conn); pending { t.Error("Connection should be removed from pending map after handoff") } @@ -306,14 +306,14 @@ func TestConnectionHook(t *testing.T) { case <-timeout: t.Fatal("Timeout waiting for failed handoff to complete") case <-ticker.C: - if _, pending := processor.pending.Load(conn.GetID()); !pending { + if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending { handoffCompleted = true } } } // Connection should be removed from pending map after failed handoff - if _, pending := processor.pending.Load(conn.GetID()); pending { + if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending { t.Error("Connection should be removed from pending map after failed handoff") } @@ -380,8 +380,8 @@ func TestConnectionHook(t *testing.T) { // Simulate a pending handoff by marking for handoff and queuing conn.MarkForHandoff("new-endpoint:6379", 12345) - processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID - conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) + processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID + conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) ctx := context.Background() err := processor.OnGet(ctx, conn, false) @@ -390,7 +390,7 @@ func TestConnectionHook(t *testing.T) { } // Clean up - processor.pending.Delete(conn) + processor.GetPendingMap().Delete(conn) }) t.Run("EventDrivenStateManagement", func(t *testing.T) { @@ -400,16 +400,16 @@ func TestConnectionHook(t *testing.T) { conn := createMockPoolConnection() // Test initial state - no pending handoffs - if _, pending := processor.pending.Load(conn); pending { + if _, pending := processor.GetPendingMap().Load(conn); pending { t.Error("New connection should not have pending handoffs") } // Test adding to pending map conn.MarkForHandoff("new-endpoint:6379", 12345) - processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID - conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) + processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID + conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) - if _, pending := processor.pending.Load(conn.GetID()); !pending { + if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending { t.Error("Connection should be in pending map") } @@ -421,8 +421,8 @@ func TestConnectionHook(t *testing.T) { } // Test removing from pending map and clearing handoff state - processor.pending.Delete(conn) - if _, pending := 
processor.pending.Load(conn); pending { + processor.GetPendingMap().Delete(conn) + if _, pending := processor.GetPendingMap().Load(conn); pending { t.Error("Connection should be removed from pending map") } @@ -510,14 +510,14 @@ func TestConnectionHook(t *testing.T) { if processor.GetCurrentWorkers() != 0 { t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers()) } - if processor.maxWorkers != 15 { - t.Errorf("Expected maxWorkers=15, got %d", processor.maxWorkers) + if processor.GetMaxWorkers() != 15 { + t.Errorf("Expected maxWorkers=15, got %d", processor.GetMaxWorkers()) } // The on-demand worker behavior creates workers only when needed // This test just verifies the basic configuration is correct t.Logf("On-demand worker configuration verified - Max: %d, Current: %d", - processor.maxWorkers, processor.GetCurrentWorkers()) + processor.GetMaxWorkers(), processor.GetCurrentWorkers()) }) t.Run("PassiveTimeoutRestoration", func(t *testing.T) { @@ -567,7 +567,7 @@ func TestConnectionHook(t *testing.T) { case <-timeout: t.Fatal("Timeout waiting for handoff to complete") case <-ticker.C: - if _, pending := processor.pending.Load(conn); !pending { + if _, pending := processor.GetPendingMap().Load(conn); !pending { handoffCompleted = true } } @@ -701,7 +701,7 @@ func TestConnectionHook(t *testing.T) { defer processor.Shutdown(context.Background()) // Verify queue capacity matches configured size - queueCapacity := cap(processor.handoffQueue) + queueCapacity := cap(processor.GetHandoffQueue()) if queueCapacity != 50 { t.Errorf("Expected queue capacity 50, got %d", queueCapacity) } @@ -734,7 +734,7 @@ func TestConnectionHook(t *testing.T) { } // Verify queue capacity remains static (the main purpose of this test) - finalCapacity := cap(processor.handoffQueue) + finalCapacity := cap(processor.GetHandoffQueue()) if finalCapacity != 50 { t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity) @@ -851,7 +851,7 @@ func TestConnectionHook(t *testing.T) { case <-timeout: t.Fatal("Timeout waiting for handoff to complete") case <-ticker.C: - if _, pending := processor.pending.Load(conn); !pending { + if _, pending := processor.GetPendingMap().Load(conn); !pending { handoffCompleted = true } }
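
A minimal usage sketch of the unchanged public entry point: Shutdown still takes a context and now delegates to handoffWorkerManager.shutdownWorkers, returning ctx.Err() if workers are still busy at the deadline. The sketch assumes the package is importable as github.com/redis/go-redis/v9/hitless (inferred from the file layout) and that `hook` was built by the package's constructor; drainHandoffs is an illustrative name, not part of this patch.

package main

import (
	"context"
	"log"
	"time"

	"github.com/redis/go-redis/v9/hitless"
)

// drainHandoffs gives queued handoffs up to five seconds to finish before giving up.
func drainHandoffs(hook *hitless.PoolHook) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	// Shutdown waits for worker goroutines to drain; on timeout it returns ctx.Err().
	if err := hook.Shutdown(ctx); err != nil {
		log.Printf("hitless: worker shutdown incomplete: %v", err)
	}
}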