package maintnotifications

import (
	"context"
	"errors"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
	"github.com/redis/go-redis/v9/internal/pool"
)

// handoffWorkerManager manages background workers and queue for connection handoffs
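//
// Typical lifecycle (illustrative sketch based on the functions in this file;
// the real call sites live in PoolHook and may differ):
//
//	hwm := newHandoffWorkerManager(config, poolHook)
//	if err := hwm.queueHandoff(conn); err != nil {
//		// queue full or shutting down; the caller decides what to do with conn
//	}
//	// ... workers are created on demand and exit when idle ...
//	_ = hwm.shutdownWorkers(ctx) // waits for in-flight handoffs, bounded by ctx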
type handoffWorkerManager struct {
	// Event-driven handoff support
	handoffQueue chan HandoffRequest // Queue for handoff requests
	shutdown     chan struct{}       // Shutdown signal
	shutdownOnce sync.Once           // Ensure clean shutdown
	workerWg     sync.WaitGroup      // Track worker goroutines

	// On-demand worker management
	maxWorkers     int
	activeWorkers  atomic.Int32
	workerTimeout  time.Duration // How long workers wait for work before exiting
	workersScaling atomic.Bool

	// Simple state tracking
	pending sync.Map // map[uint64]int64 (connID -> seqID)

	// Configuration for the maintenance notifications
	config *Config

	// Pool hook reference for handoff processing
	poolHook *PoolHook

	// Circuit breaker manager for endpoint failure handling
	circuitBreakerManager *CircuitBreakerManager
}

// newHandoffWorkerManager creates a new handoff worker manager
func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
	return &handoffWorkerManager{
		handoffQueue:          make(chan HandoffRequest, config.HandoffQueueSize),
		shutdown:              make(chan struct{}),
		maxWorkers:            config.MaxWorkers,
		activeWorkers:         atomic.Int32{},   // Start with no workers - create on demand
		workerTimeout:         15 * time.Second, // Workers exit after 15s of inactivity
		config:                config,
		poolHook:              poolHook,
		circuitBreakerManager: newCircuitBreakerManager(config),
	}
}

// getCurrentWorkers returns the current number of active workers (for testing)
func (hwm *handoffWorkerManager) getCurrentWorkers() int {
	return int(hwm.activeWorkers.Load())
}

// getPendingMap returns the pending map for testing purposes
func (hwm *handoffWorkerManager) getPendingMap() *sync.Map {
	return &hwm.pending
}

// getMaxWorkers returns the max workers for testing purposes
func (hwm *handoffWorkerManager) getMaxWorkers() int {
	return hwm.maxWorkers
}

// getHandoffQueue returns the handoff queue for testing purposes
func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
	return hwm.handoffQueue
}

// getCircuitBreakerStats returns circuit breaker statistics for monitoring
func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats {
	return hwm.circuitBreakerManager.GetAllStats()
}

// resetCircuitBreakers resets all circuit breakers (useful for testing)
func (hwm *handoffWorkerManager) resetCircuitBreakers() {
	hwm.circuitBreakerManager.Reset()
}

// isHandoffPending returns true if the given connection has a pending handoff
func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
	_, pending := hwm.pending.Load(conn.GetID())
	return pending
}

// ensureWorkerAvailable ensures at least one worker is available to process requests.
// Creates new workers as needed, up to the configured max limit.
func (hwm *handoffWorkerManager) ensureWorkerAvailable() {
	select {
	case <-hwm.shutdown:
		return
	default:
		if hwm.workersScaling.CompareAndSwap(false, true) {
			defer hwm.workersScaling.Store(false)
			// Scale up to maxWorkers if needed
			currentWorkers := hwm.activeWorkers.Load()
			workersWas := currentWorkers
			for currentWorkers < int32(hwm.maxWorkers) {
				hwm.workerWg.Add(1)
				go hwm.onDemandWorker()
				currentWorkers++
			}
			// workersWas is always <= currentWorkers.
			// Add only the number of workers created here (currentWorkers - workersWas);
			// workers that exited while we were scaling decrement the counter themselves,
			// so they are not double-counted.
			hwm.activeWorkers.Add(currentWorkers - workersWas)
		}
	}
}

// onDemandWorker processes handoff requests and exits when idle
func (hwm *handoffWorkerManager) onDemandWorker() {
	defer func() {
		// Handle panics to ensure proper cleanup
		if r := recover(); r != nil {
			internal.Logger.Printf(context.Background(), logs.WorkerPanicRecovered(r))
		}

		// Decrement active worker count when exiting
		hwm.activeWorkers.Add(-1)
		hwm.workerWg.Done()
	}()

	// Create reusable timer to prevent timer leaks
	timer := time.NewTimer(hwm.workerTimeout)
	defer timer.Stop()

	for {
		// Reset timer for next iteration
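		// Note (added for clarity): timer.Stop returns false when the timer has
		// already fired, in which case the expired value may still be sitting in
		// timer.C. Draining it before Reset prevents a stale tick from waking the
		// worker immediately on the next loop iteration.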
		if !timer.Stop() {
			select {
			case <-timer.C:
			default:
			}
		}
		timer.Reset(hwm.workerTimeout)

		select {
		case <-hwm.shutdown:
			if internal.LogLevel.InfoOrAbove() {
				internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdown())
			}
			return
		case <-timer.C:
			// Worker has been idle for too long, exit to save resources
			if internal.LogLevel.InfoOrAbove() {
				internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout))
			}
			return
		case request := <-hwm.handoffQueue:
			// Check for shutdown before processing
			select {
			case <-hwm.shutdown:
				if internal.LogLevel.InfoOrAbove() {
					internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing())
				}
				// Clean up the request before exiting
				hwm.pending.Delete(request.ConnID)
				return
			default:
				// Process the request
				hwm.processHandoffRequest(request)
			}
		}
	}
}

// processHandoffRequest processes a single handoff request
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
	// Remove from pending map
	defer hwm.pending.Delete(request.Conn.GetID())
	if internal.LogLevel.InfoOrAbove() {
		internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint))
	}

	// Create a context with handoff timeout from config
	handoffTimeout := 15 * time.Second // Default timeout
	if hwm.config != nil && hwm.config.HandoffTimeout > 0 {
		handoffTimeout = hwm.config.HandoffTimeout
	}
	ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
	defer cancel()

	// Create a context that also respects the shutdown signal
	shutdownCtx, shutdownCancel := context.WithCancel(ctx)
	defer shutdownCancel()

	// Monitor shutdown signal in a separate goroutine
	go func() {
		select {
		case <-hwm.shutdown:
			shutdownCancel()
		case <-shutdownCtx.Done():
		}
	}()

	// Perform the handoff with cancellable context
	shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn)
	minRetryBackoff := 500 * time.Millisecond
	if err != nil {
		if shouldRetry {
			now := time.Now()
			deadline, ok := shutdownCtx.Deadline()
			thirdOfTimeout := handoffTimeout / 3
			if !ok || deadline.Before(now) {
				// wait a third of the timeout before retrying if there is no deadline or it has already passed
				deadline = now.Add(thirdOfTimeout)
			}
			afterTime := deadline.Sub(now)
			if afterTime < minRetryBackoff {
				afterTime = minRetryBackoff
			}
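			// Worked example (added for clarity): with the default 15s handoff timeout,
			// thirdOfTimeout is 5s. If the context deadline is still 2s away, the retry
			// fires in 2s; if the deadline is missing or already passed, it fires in 5s;
			// and anything below minRetryBackoff is rounded up to 500ms.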

			if internal.LogLevel.InfoOrAbove() {
				// Get current retry count for better logging
				currentRetries := request.Conn.HandoffRetries()
				maxRetries := 3 // Default fallback
				if hwm.config != nil {
					maxRetries = hwm.config.MaxHandoffRetries
				}
				internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err))
			}
			time.AfterFunc(afterTime, func() {
				if err := hwm.queueHandoff(request.Conn); err != nil {
					if internal.LogLevel.WarnOrAbove() {
						internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err))
					}
					hwm.closeConnFromRequest(context.Background(), request, err)
				}
			})
			return
		} else {
			go hwm.closeConnFromRequest(ctx, request, err)
		}

		// Handoff will not be retried: stop tracking the operation for this connection
		seqID := request.Conn.GetMovingSeqID()
		connID := request.Conn.GetID()
		if hwm.poolHook.operationsManager != nil {
			hwm.poolHook.operationsManager.UntrackOperationWithConnID(seqID, connID)
		}
	}
}

// queueHandoff queues a handoff request for processing.
// If an error is returned, the connection will be removed from the pool.
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
	// Get handoff info atomically to prevent race conditions
	shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
	// On retries the connection will not be marked for handoff, but it will have retries > 0.
	// If shouldHandoff is false and retries is 0, we are neither retrying nor starting a handoff.
	if !shouldHandoff && conn.HandoffRetries() == 0 {
		if internal.LogLevel.InfoOrAbove() {
			internal.Logger.Printf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID()))
		}
		return errors.New(logs.ConnectionNotMarkedForHandoffError(conn.GetID()))
	}

	// Create handoff request with atomically retrieved data
	request := HandoffRequest{
		Conn:     conn,
		ConnID:   conn.GetID(),
		Endpoint: endpoint,
		SeqID:    seqID,
		Pool:     hwm.poolHook.pool, // Include pool for connection removal on failure
	}

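	// Enqueue in three stages (comment added for clarity): shutdown always wins;
	// then a non-blocking send is attempted; if the queue is full, wait up to
	// 100ms for workers to drain it before giving up and reporting the queue as full.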
	select {
	// priority to shutdown
	case <-hwm.shutdown:
		return ErrShutdown
	default:
		select {
		case <-hwm.shutdown:
			return ErrShutdown
		case hwm.handoffQueue <- request:
			// Store in pending map
			hwm.pending.Store(request.ConnID, request.SeqID)
			// Ensure we have a worker to process this request
			hwm.ensureWorkerAvailable()
			return nil
		default:
			select {
			case <-hwm.shutdown:
				return ErrShutdown
			case hwm.handoffQueue <- request:
				// Store in pending map
				hwm.pending.Store(request.ConnID, request.SeqID)
				// Ensure we have a worker to process this request
				hwm.ensureWorkerAvailable()
				return nil
			case <-time.After(100 * time.Millisecond): // give workers a chance to process
				// Queue is full - log and attempt scaling
				queueLen := len(hwm.handoffQueue)
				queueCap := cap(hwm.handoffQueue)
				if internal.LogLevel.WarnOrAbove() {
					internal.Logger.Printf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap))
				}
			}
		}
	}

	// Ensure we have workers available to handle the load
	hwm.ensureWorkerAvailable()
	return ErrHandoffQueueFull
}

// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete
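//
// Illustrative use (not taken from a real call site): bound the wait with a context.
//
//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//	defer cancel()
//	if err := hwm.shutdownWorkers(ctx); err != nil {
//		// workers did not finish in time; err is ctx.Err()
//	}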
func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error {
	hwm.shutdownOnce.Do(func() {
		close(hwm.shutdown)
		// workers will exit when they finish their current request

		// Shutdown circuit breaker manager cleanup goroutine
		if hwm.circuitBreakerManager != nil {
			hwm.circuitBreakerManager.Shutdown()
		}
	})

	// Wait for workers to complete
	done := make(chan struct{})
	go func() {
		hwm.workerWg.Wait()
		close(done)
	}()

	select {
	case <-done:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// performConnectionHandoff performs the actual connection handoff.
// When an error is returned, the handoff should be retried only if shouldRetry is true
// (it is false for terminal failures such as ErrMaxHandoffRetriesReached).
func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) {
	// Connection ID is used for logging throughout the handoff
	connID := conn.GetID()

	newEndpoint := conn.GetHandoffEndpoint()
	if newEndpoint == "" {
		return false, ErrConnectionInvalidHandoffState
	}

	// Use circuit breaker to protect against failing endpoints
	circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint)

	// Check if circuit breaker is open before attempting handoff
	if circuitBreaker.IsOpen() {
		internal.Logger.Printf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint))
		return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open
	}

	// Perform the handoff
	shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID)

	// Update circuit breaker based on result
	if err != nil {
		// Only track dial/network errors in circuit breaker, not initialization errors
		if shouldRetry {
			circuitBreaker.recordFailure()
		}
		return shouldRetry, err
	}

	// Success - record in circuit breaker
	circuitBreaker.recordSuccess()
	return false, nil
}

// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
func (hwm *handoffWorkerManager) performHandoffInternal(
	ctx context.Context,
	conn *pool.Conn,
	newEndpoint string,
	connID uint64,
) (shouldRetry bool, err error) {
	retries := conn.IncrementAndGetHandoffRetries(1)
	internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String()))
	maxRetries := 3 // Default fallback
	if hwm.config != nil {
		maxRetries = hwm.config.MaxHandoffRetries
	}

	if retries > maxRetries {
		if internal.LogLevel.WarnOrAbove() {
			internal.Logger.Printf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries))
		}
		// won't retry on ErrMaxHandoffRetriesReached
		return false, ErrMaxHandoffRetriesReached
	}

	// Create endpoint-specific dialer
	endpointDialer := hwm.createEndpointDialer(newEndpoint)

	// Create new connection to the new endpoint
	newNetConn, err := endpointDialer(ctx)
	if err != nil {
		internal.Logger.Printf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err))
		// will retry
		// Maybe a network error - retry after a delay
		return true, err
	}

	// Get the old connection
	oldConn := conn.GetNetConn()

	// Apply relaxed timeout to the new connection for the configured post-handoff duration.
	// This gives the new connection more time to handle operations during the cluster transition.
	// Setting this here (before initializing the connection) ensures that the connection
	// will use the relaxed timeout for the first operation (auth/ACL select).
	if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 {
		relaxedTimeout := hwm.config.RelaxedTimeout
		// Set relaxed timeout with deadline - no background goroutine needed
		deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration)
		conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)

		if internal.LogLevel.InfoOrAbove() {
			internal.Logger.Printf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000")))
		}
	}

	// Replace the connection and execute initialization
	err = conn.SetNetConnAndInitConn(ctx, newNetConn)
	if err != nil {
		// won't retry
		// Initialization failed - remove the connection
		return false, err
	}
	defer func() {
		if oldConn != nil {
			oldConn.Close()
		}
	}()

	// Clearing the handoff state will:
	// - set the connection as usable again
	// - clear the handoff state (shouldHandoff, endpoint, seqID)
	// - reset the handoff retries to 0
	conn.ClearHandoffState()
	internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint))

	// successfully completed the handoff, no retry needed and no error
	return false, nil
}

// createEndpointDialer creates a dialer function that connects to a specific endpoint
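// For example (illustrative addresses): "10.0.0.1:6380" dials that exact address,
// while an endpoint without a port such as "10.0.0.1" falls back to the default
// Redis port 6379.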
func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
	return func(ctx context.Context) (net.Conn, error) {
		// Parse endpoint to extract host and port
		host, port, err := net.SplitHostPort(endpoint)
		if err != nil {
			// If no port specified, assume default Redis port
			host = endpoint
			if port == "" {
				port = "6379"
			}
		}

		// Use the base dialer to connect to the new endpoint
		return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port))
	}
}

// closeConnFromRequest closes the connection and logs the reason
func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
	pooler := request.Pool
	conn := request.Conn
	if pooler != nil {
		pooler.Remove(ctx, conn, err)
		if internal.LogLevel.WarnOrAbove() {
			internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err))
		}
	} else {
		conn.Close()
		if internal.LogLevel.WarnOrAbove() {
			internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err))
		}
	}
}