1
0
mirror of https://github.com/redis/go-redis.git synced 2025-09-02 22:01:16 +03:00

separate worker from pool hook

This commit is contained in:
Nedyalko Dyakov
2025-09-02 10:47:53 +03:00
parent 73ff2734d7
commit b34f8270c6
3 changed files with 442 additions and 376 deletions

394
hitless/handoff_worker.go Normal file
View File

@@ -0,0 +1,394 @@
package hitless
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
)
// handoffWorkerManager manages background workers and queue for connection handoffs
type handoffWorkerManager struct {
// Event-driven handoff support
handoffQueue chan HandoffRequest // Queue for handoff requests
shutdown chan struct{} // Shutdown signal
shutdownOnce sync.Once // Ensure clean shutdown
workerWg sync.WaitGroup // Track worker goroutines
// On-demand worker management
maxWorkers int
activeWorkers atomic.Int32
workerTimeout time.Duration // How long workers wait for work before exiting
workersScaling atomic.Bool
// Simple state tracking
pending sync.Map // map[uint64]int64 (connID -> seqID)
// Configuration for the hitless upgrade
config *Config
// Pool hook reference for handoff processing
poolHook *PoolHook
}
// newHandoffWorkerManager creates a new handoff worker manager
func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
return &handoffWorkerManager{
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
shutdown: make(chan struct{}),
maxWorkers: config.MaxWorkers,
activeWorkers: atomic.Int32{}, // Start with no workers - create on demand
workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
config: config,
poolHook: poolHook,
}
}
// getCurrentWorkers returns the current number of active workers (for testing)
func (hwm *handoffWorkerManager) getCurrentWorkers() int {
return int(hwm.activeWorkers.Load())
}
// getPendingMap returns the pending map for testing purposes
func (hwm *handoffWorkerManager) getPendingMap() *sync.Map {
return &hwm.pending
}
// getMaxWorkers returns the max workers for testing purposes
func (hwm *handoffWorkerManager) getMaxWorkers() int {
return hwm.maxWorkers
}
// getHandoffQueue returns the handoff queue for testing purposes
func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
return hwm.handoffQueue
}
// isHandoffPending returns true if the given connection has a pending handoff
func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
_, pending := hwm.pending.Load(conn.GetID())
return pending
}
// ensureWorkerAvailable ensures at least one worker is available to process requests
// Creates a new worker if needed and under the max limit
func (hwm *handoffWorkerManager) ensureWorkerAvailable() {
select {
case <-hwm.shutdown:
return
default:
if hwm.workersScaling.CompareAndSwap(false, true) {
defer hwm.workersScaling.Store(false)
// Check if we need a new worker
currentWorkers := hwm.activeWorkers.Load()
workersWas := currentWorkers
for currentWorkers <= int32(hwm.maxWorkers) {
hwm.workerWg.Add(1)
go hwm.onDemandWorker()
currentWorkers++
}
// workersWas is always <= currentWorkers
// currentWorkers will be maxWorkers, but if we have a worker that was closed
// while we were creating new workers, just add the difference between
// the currentWorkers and the number of workers we observed initially (i.e. the number of workers we created)
hwm.activeWorkers.Add(currentWorkers - workersWas)
}
}
}
// onDemandWorker processes handoff requests and exits when idle
func (hwm *handoffWorkerManager) onDemandWorker() {
defer func() {
// Decrement active worker count when exiting
hwm.activeWorkers.Add(-1)
hwm.workerWg.Done()
}()
for {
select {
case <-hwm.shutdown:
return
case <-time.After(hwm.workerTimeout):
// Worker has been idle for too long, exit to save resources
if hwm.config != nil && hwm.config.LogLevel.InfoOrAbove() { // Debug level
internal.Logger.Printf(context.Background(),
"hitless: worker exiting due to inactivity timeout (%v)", hwm.workerTimeout)
}
return
case request := <-hwm.handoffQueue:
// Check for shutdown before processing
select {
case <-hwm.shutdown:
// Clean up the request before exiting
hwm.pending.Delete(request.ConnID)
return
default:
// Process the request
hwm.processHandoffRequest(request)
}
}
}
}
// processHandoffRequest processes a single handoff request
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
// Remove from pending map
defer hwm.pending.Delete(request.Conn.GetID())
internal.Logger.Printf(context.Background(), "hitless: conn[%d] Processing handoff request start", request.Conn.GetID())
// Create a context with handoff timeout from config
handoffTimeout := 15 * time.Second // Default timeout
if hwm.config != nil && hwm.config.HandoffTimeout > 0 {
handoffTimeout = hwm.config.HandoffTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
defer cancel()
// Create a context that also respects the shutdown signal
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
defer shutdownCancel()
// Monitor shutdown signal in a separate goroutine
go func() {
select {
case <-hwm.shutdown:
shutdownCancel()
case <-shutdownCtx.Done():
}
}()
// Perform the handoff with cancellable context
shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn)
minRetryBackoff := 500 * time.Millisecond
if err != nil {
if shouldRetry {
now := time.Now()
deadline, ok := shutdownCtx.Deadline()
thirdOfTimeout := handoffTimeout / 3
if !ok || deadline.Before(now) {
// wait half the timeout before retrying if no deadline or deadline has passed
deadline = now.Add(thirdOfTimeout)
}
afterTime := deadline.Sub(now)
if afterTime < minRetryBackoff {
afterTime = minRetryBackoff
}
internal.Logger.Printf(context.Background(), "Handoff failed for conn[%d] WILL RETRY After %v: %v", request.ConnID, afterTime, err)
time.AfterFunc(afterTime, func() {
if err := hwm.queueHandoff(request.Conn); err != nil {
internal.Logger.Printf(context.Background(), "can't queue handoff for retry: %v", err)
hwm.closeConnFromRequest(context.Background(), request, err)
}
})
return
} else {
go hwm.closeConnFromRequest(ctx, request, err)
}
// Clear handoff state if not returned for retry
seqID := request.Conn.GetMovingSeqID()
connID := request.Conn.GetID()
if hwm.poolHook.hitlessManager != nil {
hwm.poolHook.hitlessManager.UntrackOperationWithConnID(seqID, connID)
}
}
}
// queueHandoff queues a handoff request for processing
// if err is returned, connection will be removed from pool
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
// Create handoff request
request := HandoffRequest{
Conn: conn,
ConnID: conn.GetID(),
Endpoint: conn.GetHandoffEndpoint(),
SeqID: conn.GetMovingSeqID(),
Pool: hwm.poolHook.pool, // Include pool for connection removal on failure
}
select {
// priority to shutdown
case <-hwm.shutdown:
return ErrShutdown
default:
select {
case <-hwm.shutdown:
return ErrShutdown
case hwm.handoffQueue <- request:
// Store in pending map
hwm.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
hwm.ensureWorkerAvailable()
return nil
default:
select {
case <-hwm.shutdown:
return ErrShutdown
case hwm.handoffQueue <- request:
// Store in pending map
hwm.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
hwm.ensureWorkerAvailable()
return nil
case <-time.After(100 * time.Millisecond): // give workers a chance to process
// Queue is full - log and attempt scaling
queueLen := len(hwm.handoffQueue)
queueCap := cap(hwm.handoffQueue)
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(context.Background(),
"hitless: handoff queue is full (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration",
queueLen, queueCap)
}
}
}
}
// Ensure we have workers available to handle the load
hwm.ensureWorkerAvailable()
return ErrHandoffQueueFull
}
// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete
func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error {
hwm.shutdownOnce.Do(func() {
close(hwm.shutdown)
// workers will exit when they finish their current request
})
// Wait for workers to complete
done := make(chan struct{})
go func() {
hwm.workerWg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// performConnectionHandoff performs the actual connection handoff
// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached
func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) {
// Clear handoff state after successful handoff
connID := conn.GetID()
newEndpoint := conn.GetHandoffEndpoint()
if newEndpoint == "" {
return false, ErrConnectionInvalidHandoffState
}
retries := conn.IncrementAndGetHandoffRetries(1)
internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", conn.GetID(), retries, newEndpoint, conn.RemoteAddr().String())
maxRetries := 3 // Default fallback
if hwm.config != nil {
maxRetries = hwm.config.MaxHandoffRetries
}
if retries > maxRetries {
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: reached max retries (%d) for handoff of conn[%d] to %s",
maxRetries, conn.GetID(), conn.GetHandoffEndpoint())
}
// won't retry on ErrMaxHandoffRetriesReached
return false, ErrMaxHandoffRetriesReached
}
// Create endpoint-specific dialer
endpointDialer := hwm.createEndpointDialer(newEndpoint)
// Create new connection to the new endpoint
newNetConn, err := endpointDialer(ctx)
if err != nil {
internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", conn.GetID(), newEndpoint, err)
// hitless: will retry
// Maybe a network error - retry after a delay
return true, err
}
// Get the old connection
oldConn := conn.GetNetConn()
// Apply relaxed timeout to the new connection for the configured post-handoff duration
// This gives the new connection more time to handle operations during cluster transition
// Setting this here (before initing the connection) ensures that the connection is going
// to use the relaxed timeout for the first operation (auth/ACL select)
if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 {
relaxedTimeout := hwm.config.RelaxedTimeout
// Set relaxed timeout with deadline - no background goroutine needed
deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration)
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
if hwm.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: conn[%d] applied post-handoff relaxed timeout (%v) until %v",
connID, relaxedTimeout, deadline.Format("15:04:05.000"))
}
}
// Replace the connection and execute initialization
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
if err != nil {
// hitless: won't retry
// Initialization failed - remove the connection
return false, err
}
defer func() {
if oldConn != nil {
oldConn.Close()
}
}()
conn.ClearHandoffState()
internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", conn.GetID(), newEndpoint)
return false, nil
}
// createEndpointDialer creates a dialer function that connects to a specific endpoint
func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
// Parse endpoint to extract host and port
host, port, err := net.SplitHostPort(endpoint)
if err != nil {
// If no port specified, assume default Redis port
host = endpoint
if port == "" {
port = "6379"
}
}
// Use the base dialer to connect to the new endpoint
return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port))
}
}
// closeConnFromRequest closes the connection and logs the reason
func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
pooler := request.Pool
conn := request.Conn
if pooler != nil {
pooler.Remove(ctx, conn, err)
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: removed conn[%d] from pool due to max handoff retries reached: %v",
conn.GetID(), err)
}
} else {
conn.Close()
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: no pool provided for conn[%d], cannot remove due to handoff initialization failure: %v",
conn.GetID(), err)
}
}
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
@@ -36,20 +35,8 @@ type PoolHook struct {
// Network type (e.g., "tcp", "unix")
network string
// Event-driven handoff support
handoffQueue chan HandoffRequest // Queue for handoff requests
shutdown chan struct{} // Shutdown signal
shutdownOnce sync.Once // Ensure clean shutdown
workerWg sync.WaitGroup // Track worker goroutines
// On-demand worker management
maxWorkers int
activeWorkers atomic.Int32
workerTimeout time.Duration // How long workers wait for work before exiting
workersScaling atomic.Bool
// Simple state tracking
pending sync.Map // map[uint64]int64 (connID -> seqID)
// Worker manager for background handoff processing
workerManager *handoffWorkerManager
// Configuration for the hitless upgrade
config *Config
@@ -77,18 +64,14 @@ func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (n
// baseDialer is used to create connections to new endpoints during handoffs
baseDialer: baseDialer,
network: network,
// handoffQueue is a buffered channel for queuing handoff requests
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
// shutdown is a channel for signaling shutdown
shutdown: make(chan struct{}),
maxWorkers: config.MaxWorkers,
activeWorkers: atomic.Int32{}, // Start with no workers - create on demand
// NOTE: maybe we would like to make this configurable?
workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
config: config,
config: config,
// Hitless manager for operation completion tracking
hitlessManager: hitlessManager,
}
// Create worker manager
ph.workerManager = newHandoffWorkerManager(config, ph)
return ph
}
@@ -99,13 +82,27 @@ func (ph *PoolHook) SetPool(pooler pool.Pooler) {
// GetCurrentWorkers returns the current number of active workers (for testing)
func (ph *PoolHook) GetCurrentWorkers() int {
return int(ph.activeWorkers.Load())
return ph.workerManager.getCurrentWorkers()
}
// IsHandoffPending returns true if the given connection has a pending handoff
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
_, pending := ph.pending.Load(conn.GetID())
return pending
return ph.workerManager.isHandoffPending(conn)
}
// GetPendingMap returns the pending map for testing purposes
func (ph *PoolHook) GetPendingMap() *sync.Map {
return ph.workerManager.getPendingMap()
}
// GetMaxWorkers returns the max workers for testing purposes
func (ph *PoolHook) GetMaxWorkers() int {
return ph.workerManager.getMaxWorkers()
}
// GetHandoffQueue returns the handoff queue for testing purposes
func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest {
return ph.workerManager.getHandoffQueue()
}
// OnGet is called when a connection is retrieved from the pool
@@ -136,13 +133,12 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool
}
// check pending handoff to not queue the same connection twice
_, hasPendingHandoff := ph.pending.Load(conn.GetID())
if hasPendingHandoff {
if ph.workerManager.isHandoffPending(conn) {
// Default behavior (pending handoff): pool the connection
return true, false, nil
}
if err := ph.queueHandoff(conn); err != nil {
if err := ph.workerManager.queueHandoff(conn); err != nil {
// Failed to queue handoff, remove the connection
internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err)
// Don't pool, remove connection, no error to caller
@@ -167,331 +163,7 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool
return true, false, nil
}
// ensureWorkerAvailable ensures at least one worker is available to process requests
// Creates a new worker if needed and under the max limit
func (ph *PoolHook) ensureWorkerAvailable() {
select {
case <-ph.shutdown:
return
default:
if ph.workersScaling.CompareAndSwap(false, true) {
defer ph.workersScaling.Store(false)
// Check if we need a new worker
currentWorkers := ph.activeWorkers.Load()
workersWas := currentWorkers
for currentWorkers <= int32(ph.maxWorkers) {
ph.workerWg.Add(1)
go ph.onDemandWorker()
currentWorkers++
}
// workersWas is always <= currentWorkers
// currentWorkers will be maxWorkers, but if we have a worker that was closed
// while we were creating new workers, just add the difference between
// the currentWorkers and the number of workers we observed initially (i.e. the number of workers we created)
ph.activeWorkers.Add(currentWorkers - workersWas)
}
}
}
// onDemandWorker processes handoff requests and exits when idle
func (ph *PoolHook) onDemandWorker() {
defer func() {
// Decrement active worker count when exiting
ph.activeWorkers.Add(-1)
ph.workerWg.Done()
}()
for {
select {
case <-ph.shutdown:
return
case <-time.After(ph.workerTimeout):
// Worker has been idle for too long, exit to save resources
if ph.config != nil && ph.config.LogLevel.InfoOrAbove() { // Debug level
internal.Logger.Printf(context.Background(),
"hitless: worker exiting due to inactivity timeout (%v)", ph.workerTimeout)
}
return
case request := <-ph.handoffQueue:
// Check for shutdown before processing
select {
case <-ph.shutdown:
// Clean up the request before exiting
ph.pending.Delete(request.ConnID)
return
default:
// Process the request
ph.processHandoffRequest(request)
}
}
}
}
// processHandoffRequest processes a single handoff request
func (ph *PoolHook) processHandoffRequest(request HandoffRequest) {
// Remove from pending map
defer ph.pending.Delete(request.Conn.GetID())
internal.Logger.Printf(context.Background(), "hitless: conn[%d] Processing handoff request start", request.Conn.GetID())
// Create a context with handoff timeout from config
handoffTimeout := 30 * time.Second // Default fallback
if ph.config != nil && ph.config.HandoffTimeout > 0 {
handoffTimeout = ph.config.HandoffTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
defer cancel()
// Create a context that also respects the shutdown signal
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
defer shutdownCancel()
// Monitor shutdown signal in a separate goroutine
go func() {
select {
case <-ph.shutdown:
shutdownCancel()
case <-shutdownCtx.Done():
}
}()
// Perform the handoff with cancellable context
shouldRetry, err := ph.performConnectionHandoff(shutdownCtx, request.Conn)
minRetryBackoff := 500 * time.Millisecond
if err != nil {
if shouldRetry {
now := time.Now()
deadline, ok := shutdownCtx.Deadline()
thirdOfTimeout := handoffTimeout / 3
if !ok || deadline.Before(now) {
// wait half the timeout before retrying if no deadline or deadline has passed
deadline = now.Add(thirdOfTimeout)
}
afterTime := deadline.Sub(now)
if afterTime > thirdOfTimeout {
afterTime = thirdOfTimeout
}
if afterTime < minRetryBackoff {
afterTime = minRetryBackoff
}
internal.Logger.Printf(context.Background(), "Handoff failed for conn[%d] WILL RETRY After %v: %v", request.ConnID, afterTime, err)
time.AfterFunc(afterTime, func() {
if err := ph.queueHandoff(request.Conn); err != nil {
internal.Logger.Printf(context.Background(), "can't queue handoff for retry: %v", err)
ph.closeConnFromRequest(context.Background(), request, err)
}
})
return
} else {
go ph.closeConnFromRequest(ctx, request, err)
}
// Clear handoff state if not returned for retry
seqID := request.Conn.GetMovingSeqID()
connID := request.Conn.GetID()
if ph.hitlessManager != nil {
ph.hitlessManager.UntrackOperationWithConnID(seqID, connID)
}
}
}
// closeConn closes the connection and logs the reason
func (ph *PoolHook) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
pooler := request.Pool
conn := request.Conn
if pooler != nil {
pooler.Remove(ctx, conn, err)
if ph.config != nil && ph.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: removed conn[%d] from pool due to max handoff retries reached: %v",
conn.GetID(), err)
}
} else {
conn.Close()
if ph.config != nil && ph.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: no pool provided for conn[%d], cannot remove due to handoff initialization failure: %v",
conn.GetID(), err)
}
}
}
// queueHandoff queues a handoff request for processing
// if err is returned, connection will be removed from pool
func (ph *PoolHook) queueHandoff(conn *pool.Conn) error {
// Create handoff request
request := HandoffRequest{
Conn: conn,
ConnID: conn.GetID(),
Endpoint: conn.GetHandoffEndpoint(),
SeqID: conn.GetMovingSeqID(),
Pool: ph.pool, // Include pool for connection removal on failure
}
select {
// priority to shutdown
case <-ph.shutdown:
return ErrShutdown
default:
select {
case <-ph.shutdown:
return ErrShutdown
case ph.handoffQueue <- request:
// Store in pending map
ph.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
ph.ensureWorkerAvailable()
return nil
default:
select {
case <-ph.shutdown:
return ErrShutdown
case ph.handoffQueue <- request:
// Store in pending map
ph.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
ph.ensureWorkerAvailable()
return nil
case <-time.After(100 * time.Millisecond): // give workers a chance to process
// Queue is full - log and attempt scaling
queueLen := len(ph.handoffQueue)
queueCap := cap(ph.handoffQueue)
if ph.config != nil && ph.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(context.Background(),
"hitless: handoff queue is full (%d/%d), cant queue handoff request for conn[%d] seqID[%d]",
queueLen, queueCap, request.ConnID, request.SeqID)
if ph.config.LogLevel.DebugOrAbove() { // Debug level
ph.pending.Range(func(k, v interface{}) bool {
internal.Logger.Printf(context.Background(), "hitless: pending handoff for conn[%d] seqID[%d]", k, v)
return true
})
}
}
}
}
}
// Ensure we have workers available to handle the load
ph.ensureWorkerAvailable()
return ErrHandoffQueueFull
}
// performConnectionHandoff performs the actual connection handoff
// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached
func (ph *PoolHook) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) {
// Clear handoff state after successful handoff
connID := conn.GetID()
newEndpoint := conn.GetHandoffEndpoint()
if newEndpoint == "" {
return false, ErrConnectionInvalidHandoffState
}
retries := conn.IncrementAndGetHandoffRetries(1)
internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", conn.GetID(), retries, newEndpoint, conn.RemoteAddr().String())
maxRetries := 3 // Default fallback
if ph.config != nil {
maxRetries = ph.config.MaxHandoffRetries
}
if retries > maxRetries {
if ph.config != nil && ph.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: reached max retries (%d) for handoff of conn[%d] to %s",
maxRetries, conn.GetID(), conn.GetHandoffEndpoint())
}
// won't retry on ErrMaxHandoffRetriesReached
return false, ErrMaxHandoffRetriesReached
}
// Create endpoint-specific dialer
endpointDialer := ph.createEndpointDialer(newEndpoint)
// Create new connection to the new endpoint
newNetConn, err := endpointDialer(ctx)
if err != nil {
internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", conn.GetID(), newEndpoint, err)
// hitless: will retry
// Maybe a network error - retry after a delay
return true, err
}
// Get the old connection
oldConn := conn.GetNetConn()
// Apply relaxed timeout to the new connection for the configured post-handoff duration
// This gives the new connection more time to handle operations during cluster transition
// Setting this here (before initing the connection) ensures that the connection is going
// to use the relaxed timeout for the first operation (auth/ACL select)
if ph.config != nil && ph.config.PostHandoffRelaxedDuration > 0 {
relaxedTimeout := ph.config.RelaxedTimeout
// Set relaxed timeout with deadline - no background goroutine needed
deadline := time.Now().Add(ph.config.PostHandoffRelaxedDuration)
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
if ph.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: conn[%d] applied post-handoff relaxed timeout (%v) until %v",
connID, relaxedTimeout, deadline.Format("15:04:05.000"))
}
}
// Replace the connection and execute initialization
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
if err != nil {
// hitless: won't retry
// Initialization failed - remove the connection
return false, err
}
defer func() {
if oldConn != nil {
oldConn.Close()
}
}()
conn.ClearHandoffState()
internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", conn.GetID(), newEndpoint)
return false, nil
}
// createEndpointDialer creates a dialer function that connects to a specific endpoint
func (ph *PoolHook) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
// Parse endpoint to extract host and port
host, port, err := net.SplitHostPort(endpoint)
if err != nil {
// If no port specified, assume default Redis port
host = endpoint
if port == "" {
port = "6379"
}
}
// Use the base dialer to connect to the new endpoint
return ph.baseDialer(ctx, ph.network, net.JoinHostPort(host, port))
}
}
// Shutdown gracefully shuts down the processor, waiting for workers to complete
func (ph *PoolHook) Shutdown(ctx context.Context) error {
ph.shutdownOnce.Do(func() {
close(ph.shutdown)
// workers will exit when they finish their current request
})
// Wait for workers to complete
done := make(chan struct{})
go func() {
ph.workerWg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
return ph.workerManager.shutdownWorkers(ctx)
}

View File

@@ -169,7 +169,7 @@ func TestConnectionHook(t *testing.T) {
}
// Connection should be in pending map while initialization is blocked
if _, pending := processor.pending.Load(conn.GetID()); !pending {
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
t.Error("Connection should be in pending handoffs map")
}
@@ -187,14 +187,14 @@ func TestConnectionHook(t *testing.T) {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn); !pending {
if _, pending := processor.GetPendingMap().Load(conn); !pending {
handoffCompleted = true
}
}
}
// Verify handoff completed (removed from pending map)
if _, pending := processor.pending.Load(conn); pending {
if _, pending := processor.GetPendingMap().Load(conn); pending {
t.Error("Connection should be removed from pending map after handoff")
}
@@ -306,14 +306,14 @@ func TestConnectionHook(t *testing.T) {
case <-timeout:
t.Fatal("Timeout waiting for failed handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn.GetID()); !pending {
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Connection should be removed from pending map after failed handoff
if _, pending := processor.pending.Load(conn.GetID()); pending {
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map after failed handoff")
}
@@ -380,8 +380,8 @@ func TestConnectionHook(t *testing.T) {
// Simulate a pending handoff by marking for handoff and queuing
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
@@ -390,7 +390,7 @@ func TestConnectionHook(t *testing.T) {
}
// Clean up
processor.pending.Delete(conn)
processor.GetPendingMap().Delete(conn)
})
t.Run("EventDrivenStateManagement", func(t *testing.T) {
@@ -400,16 +400,16 @@ func TestConnectionHook(t *testing.T) {
conn := createMockPoolConnection()
// Test initial state - no pending handoffs
if _, pending := processor.pending.Load(conn); pending {
if _, pending := processor.GetPendingMap().Load(conn); pending {
t.Error("New connection should not have pending handoffs")
}
// Test adding to pending map
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
if _, pending := processor.pending.Load(conn.GetID()); !pending {
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
t.Error("Connection should be in pending map")
}
@@ -421,8 +421,8 @@ func TestConnectionHook(t *testing.T) {
}
// Test removing from pending map and clearing handoff state
processor.pending.Delete(conn)
if _, pending := processor.pending.Load(conn); pending {
processor.GetPendingMap().Delete(conn)
if _, pending := processor.GetPendingMap().Load(conn); pending {
t.Error("Connection should be removed from pending map")
}
@@ -510,14 +510,14 @@ func TestConnectionHook(t *testing.T) {
if processor.GetCurrentWorkers() != 0 {
t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers())
}
if processor.maxWorkers != 15 {
t.Errorf("Expected maxWorkers=15, got %d", processor.maxWorkers)
if processor.GetMaxWorkers() != 15 {
t.Errorf("Expected maxWorkers=15, got %d", processor.GetMaxWorkers())
}
// The on-demand worker behavior creates workers only when needed
// This test just verifies the basic configuration is correct
t.Logf("On-demand worker configuration verified - Max: %d, Current: %d",
processor.maxWorkers, processor.GetCurrentWorkers())
processor.GetMaxWorkers(), processor.GetCurrentWorkers())
})
t.Run("PassiveTimeoutRestoration", func(t *testing.T) {
@@ -567,7 +567,7 @@ func TestConnectionHook(t *testing.T) {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn); !pending {
if _, pending := processor.GetPendingMap().Load(conn); !pending {
handoffCompleted = true
}
}
@@ -701,7 +701,7 @@ func TestConnectionHook(t *testing.T) {
defer processor.Shutdown(context.Background())
// Verify queue capacity matches configured size
queueCapacity := cap(processor.handoffQueue)
queueCapacity := cap(processor.GetHandoffQueue())
if queueCapacity != 50 {
t.Errorf("Expected queue capacity 50, got %d", queueCapacity)
}
@@ -734,7 +734,7 @@ func TestConnectionHook(t *testing.T) {
}
// Verify queue capacity remains static (the main purpose of this test)
finalCapacity := cap(processor.handoffQueue)
finalCapacity := cap(processor.GetHandoffQueue())
if finalCapacity != 50 {
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
@@ -851,7 +851,7 @@ func TestConnectionHook(t *testing.T) {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn); !pending {
if _, pending := processor.GetPendingMap().Load(conn); !pending {
handoffCompleted = true
}
}