mirror of https://github.com/redis/go-redis.git synced 2025-07-16 13:21:51 +03:00
Nedyalko Dyakov
2025-07-07 18:18:37 +03:00
parent 225c0bf5b2
commit e697fcc76b
14 changed files with 2173 additions and 29 deletions

hitless.go Normal file

@ -0,0 +1,416 @@
package redis
import (
"context"
"fmt"
"sync"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/hitless"
"github.com/redis/go-redis/v9/internal/pool"
)
// HitlessUpgradeConfig provides configuration for hitless upgrades
type HitlessUpgradeConfig struct {
// Enabled controls whether hitless upgrades are active
Enabled bool
// TransitionTimeout is the increased timeout for connections during transitions
// (MIGRATING/FAILING_OVER). This should be longer than normal operation timeouts
// to account for the time needed to complete the transition.
// Default: 60 seconds
TransitionTimeout time.Duration
// CleanupInterval controls how often expired states are cleaned up
// Default: 30 seconds
CleanupInterval time.Duration
}
// DefaultHitlessUpgradeConfig returns the default configuration for hitless upgrades
func DefaultHitlessUpgradeConfig() *HitlessUpgradeConfig {
return &HitlessUpgradeConfig{
Enabled: true,
TransitionTimeout: 60 * time.Second, // Longer timeout for transitioning connections
CleanupInterval: 30 * time.Second, // How often to clean up expired states
}
}
// HitlessUpgradeStatistics provides statistics about ongoing upgrade operations
type HitlessUpgradeStatistics struct {
ActiveConnections int // Total connections in transition
IsMoving bool // Whether pool is currently moving
MigratingConnections int // Connections in MIGRATING state
FailingOverConnections int // Connections in FAILING_OVER state
Timestamp time.Time // When these statistics were collected
}
// HitlessUpgradeStatus provides detailed status of all ongoing upgrades
type HitlessUpgradeStatus struct {
ConnectionStates map[interface{}]interface{}
IsMoving bool
NewEndpoint string
Timestamp time.Time
}
// HitlessIntegration provides the interface for hitless upgrade functionality
type HitlessIntegration interface {
// IsEnabled returns whether hitless upgrades are currently enabled
IsEnabled() bool
// EnableHitlessUpgrades enables hitless upgrade functionality
EnableHitlessUpgrades()
// DisableHitlessUpgrades disables hitless upgrade functionality
DisableHitlessUpgrades()
// GetConnectionTimeout returns the appropriate timeout for a connection
// If the connection is transitioning, returns the longer TransitionTimeout
GetConnectionTimeout(conn interface{}, defaultTimeout time.Duration) time.Duration
// GetConnectionTimeouts returns both read and write timeouts for a connection
// If the connection is transitioning, returns increased timeouts
GetConnectionTimeouts(conn interface{}, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, time.Duration)
// MarkConnectionAsBlocking marks a connection as having blocking commands
MarkConnectionAsBlocking(conn interface{}, isBlocking bool)
// IsConnectionMarkedForClosing checks if a connection should be closed
IsConnectionMarkedForClosing(conn interface{}) bool
// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected
ShouldRedirectBlockingConnection(conn interface{}) (bool, string)
// GetUpgradeStatistics returns current upgrade statistics
GetUpgradeStatistics() *HitlessUpgradeStatistics
// GetUpgradeStatus returns detailed upgrade status
GetUpgradeStatus() *HitlessUpgradeStatus
// UpdateConfig updates the hitless upgrade configuration
UpdateConfig(config *HitlessUpgradeConfig) error
// GetConfig returns the current configuration
GetConfig() *HitlessUpgradeConfig
// Close shuts down the hitless integration
Close() error
}
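// Illustrative usage sketch (variable names here are assumptions, not part of
// this change): a caller holding a HitlessIntegration and a pool connection can
// pick per-request timeouts and react to connections marked for closing like this:
//
//	readTimeout, writeTimeout := integration.GetConnectionTimeouts(cn, 3*time.Second, 3*time.Second)
//	if integration.IsConnectionMarkedForClosing(cn) {
//		// let in-flight commands finish, then close cn and dial a fresh connection
//	}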
// hitlessIntegrationImpl implements the HitlessIntegration interface
type hitlessIntegrationImpl struct {
integration *hitless.RedisClientIntegration
mu sync.RWMutex
}
// newHitlessIntegration creates a new hitless integration instance
func newHitlessIntegration(config *HitlessUpgradeConfig) *hitlessIntegrationImpl {
if config == nil {
config = DefaultHitlessUpgradeConfig()
}
// Convert to internal config format
internalConfig := &hitless.HitlessUpgradeConfig{
Enabled: config.Enabled,
TransitionTimeout: config.TransitionTimeout,
CleanupInterval: config.CleanupInterval,
}
integration := hitless.NewRedisClientIntegration(internalConfig, 3*time.Second, 3*time.Second)
return &hitlessIntegrationImpl{
integration: integration,
}
}
// newHitlessIntegrationWithTimeouts creates a new hitless integration instance with timeout configuration
func newHitlessIntegrationWithTimeouts(config *HitlessUpgradeConfig, defaultReadTimeout, defaultWriteTimeout time.Duration) *hitlessIntegrationImpl {
// Start with defaults
defaults := DefaultHitlessUpgradeConfig()
// If config is nil, use all defaults
if config == nil {
config = defaults
}
// Ensure all fields are set with defaults if they are zero values
enabled := config.Enabled
transitionTimeout := config.TransitionTimeout
cleanupInterval := config.CleanupInterval
// Apply defaults for zero values
if transitionTimeout == 0 {
transitionTimeout = defaults.TransitionTimeout
}
if cleanupInterval == 0 {
cleanupInterval = defaults.CleanupInterval
}
// Convert to internal config format with all fields properly set
internalConfig := &hitless.HitlessUpgradeConfig{
Enabled: enabled,
TransitionTimeout: transitionTimeout,
CleanupInterval: cleanupInterval,
}
integration := hitless.NewRedisClientIntegration(internalConfig, defaultReadTimeout, defaultWriteTimeout)
return &hitlessIntegrationImpl{
integration: integration,
}
}
// IsEnabled returns whether hitless upgrades are currently enabled
func (h *hitlessIntegrationImpl) IsEnabled() bool {
h.mu.RLock()
defer h.mu.RUnlock()
return h.integration.IsEnabled()
}
// EnableHitlessUpgrades enables hitless upgrade functionality
func (h *hitlessIntegrationImpl) EnableHitlessUpgrades() {
h.mu.Lock()
defer h.mu.Unlock()
h.integration.EnableHitlessUpgrades()
}
// DisableHitlessUpgrades disables hitless upgrade functionality
func (h *hitlessIntegrationImpl) DisableHitlessUpgrades() {
h.mu.Lock()
defer h.mu.Unlock()
h.integration.DisableHitlessUpgrades()
}
// GetConnectionTimeout returns the appropriate timeout for a connection
func (h *hitlessIntegrationImpl) GetConnectionTimeout(conn interface{}, defaultTimeout time.Duration) time.Duration {
h.mu.RLock()
defer h.mu.RUnlock()
// Convert interface{} to *pool.Conn
if poolConn, ok := conn.(*pool.Conn); ok {
return h.integration.GetConnectionTimeout(poolConn, defaultTimeout)
}
// If not a pool connection, return default timeout
return defaultTimeout
}
// GetConnectionTimeouts returns both read and write timeouts for a connection
func (h *hitlessIntegrationImpl) GetConnectionTimeouts(conn interface{}, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, time.Duration) {
h.mu.RLock()
defer h.mu.RUnlock()
// Convert interface{} to *pool.Conn
if poolConn, ok := conn.(*pool.Conn); ok {
return h.integration.GetConnectionTimeouts(poolConn, defaultReadTimeout, defaultWriteTimeout)
}
// If not a pool connection, return default timeouts
return defaultReadTimeout, defaultWriteTimeout
}
// MarkConnectionAsBlocking marks a connection as having blocking commands
func (h *hitlessIntegrationImpl) MarkConnectionAsBlocking(conn interface{}, isBlocking bool) {
h.mu.Lock()
defer h.mu.Unlock()
// Convert interface{} to *pool.Conn
if poolConn, ok := conn.(*pool.Conn); ok {
h.integration.MarkConnectionAsBlocking(poolConn, isBlocking)
}
}
// IsConnectionMarkedForClosing checks if a connection should be closed
func (h *hitlessIntegrationImpl) IsConnectionMarkedForClosing(conn interface{}) bool {
h.mu.RLock()
defer h.mu.RUnlock()
// Convert interface{} to *pool.Conn
if poolConn, ok := conn.(*pool.Conn); ok {
return h.integration.IsConnectionMarkedForClosing(poolConn)
}
return false
}
// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected
func (h *hitlessIntegrationImpl) ShouldRedirectBlockingConnection(conn interface{}) (bool, string) {
h.mu.RLock()
defer h.mu.RUnlock()
// Convert interface{} to *pool.Conn (can be nil for checking pool state)
var poolConn *pool.Conn
if conn != nil {
if pc, ok := conn.(*pool.Conn); ok {
poolConn = pc
}
}
return h.integration.ShouldRedirectBlockingConnection(poolConn)
}
// GetUpgradeStatistics returns current upgrade statistics
func (h *hitlessIntegrationImpl) GetUpgradeStatistics() *HitlessUpgradeStatistics {
h.mu.RLock()
defer h.mu.RUnlock()
stats := h.integration.GetUpgradeStatistics()
if stats == nil {
return &HitlessUpgradeStatistics{Timestamp: time.Now()}
}
return &HitlessUpgradeStatistics{
ActiveConnections: stats.ActiveConnections,
IsMoving: stats.IsMoving,
MigratingConnections: stats.MigratingConnections,
FailingOverConnections: stats.FailingOverConnections,
Timestamp: stats.Timestamp,
}
}
// GetUpgradeStatus returns detailed upgrade status
func (h *hitlessIntegrationImpl) GetUpgradeStatus() *HitlessUpgradeStatus {
h.mu.RLock()
defer h.mu.RUnlock()
status := h.integration.GetUpgradeStatus()
if status == nil {
return &HitlessUpgradeStatus{
ConnectionStates: make(map[interface{}]interface{}),
IsMoving: false,
NewEndpoint: "",
Timestamp: time.Now(),
}
}
return &HitlessUpgradeStatus{
ConnectionStates: convertToInterfaceMap(status.ConnectionStates),
IsMoving: status.IsMoving,
NewEndpoint: status.NewEndpoint,
Timestamp: status.Timestamp,
}
}
// UpdateConfig updates the hitless upgrade configuration
func (h *hitlessIntegrationImpl) UpdateConfig(config *HitlessUpgradeConfig) error {
if config == nil {
return fmt.Errorf("config cannot be nil")
}
h.mu.Lock()
defer h.mu.Unlock()
// Start with defaults for any zero values
defaults := DefaultHitlessUpgradeConfig()
// Ensure all fields are set with defaults if they are zero values
enabled := config.Enabled
transitionTimeout := config.TransitionTimeout
cleanupInterval := config.CleanupInterval
// Apply defaults for zero values
if transitionTimeout == 0 {
transitionTimeout = defaults.TransitionTimeout
}
if cleanupInterval == 0 {
cleanupInterval = defaults.CleanupInterval
}
// Convert to internal config format with all fields properly set
internalConfig := &hitless.HitlessUpgradeConfig{
Enabled: enabled,
TransitionTimeout: transitionTimeout,
CleanupInterval: cleanupInterval,
}
return h.integration.UpdateConfig(internalConfig)
}
// GetConfig returns the current configuration
func (h *hitlessIntegrationImpl) GetConfig() *HitlessUpgradeConfig {
h.mu.RLock()
defer h.mu.RUnlock()
internalConfig := h.integration.GetConfig()
if internalConfig == nil {
return DefaultHitlessUpgradeConfig()
}
return &HitlessUpgradeConfig{
Enabled: internalConfig.Enabled,
TransitionTimeout: internalConfig.TransitionTimeout,
CleanupInterval: internalConfig.CleanupInterval,
}
}
// Close shuts down the hitless integration
func (h *hitlessIntegrationImpl) Close() error {
h.mu.Lock()
defer h.mu.Unlock()
return h.integration.Close()
}
// getInternalIntegration returns the internal integration for use by Redis clients
func (h *hitlessIntegrationImpl) getInternalIntegration() *hitless.RedisClientIntegration {
h.mu.RLock()
defer h.mu.RUnlock()
return h.integration
}
// ClientTimeoutProvider interface for extracting timeout configuration from client options
type ClientTimeoutProvider interface {
GetReadTimeout() time.Duration
GetWriteTimeout() time.Duration
}
// optionsTimeoutProvider implements ClientTimeoutProvider for Options struct
type optionsTimeoutProvider struct {
readTimeout time.Duration
writeTimeout time.Duration
}
func (p *optionsTimeoutProvider) GetReadTimeout() time.Duration {
return p.readTimeout
}
func (p *optionsTimeoutProvider) GetWriteTimeout() time.Duration {
return p.writeTimeout
}
// newOptionsTimeoutProvider creates a timeout provider from Options
func newOptionsTimeoutProvider(readTimeout, writeTimeout time.Duration) ClientTimeoutProvider {
return &optionsTimeoutProvider{
readTimeout: readTimeout,
writeTimeout: writeTimeout,
}
}
// initializeHitlessIntegration initializes hitless integration for a client
func initializeHitlessIntegration(client interface{}, config *HitlessUpgradeConfig, timeoutProvider ClientTimeoutProvider) (*hitlessIntegrationImpl, error) {
if config == nil || !config.Enabled {
return nil, nil
}
// Extract timeout configuration from client options
defaultReadTimeout := timeoutProvider.GetReadTimeout()
defaultWriteTimeout := timeoutProvider.GetWriteTimeout()
// Create hitless integration - each client gets its own instance
integration := newHitlessIntegrationWithTimeouts(config, defaultReadTimeout, defaultWriteTimeout)
// Push notification handlers are registered directly by the client
// No separate registration needed in simplified implementation
internal.Logger.Printf(context.Background(), "hitless: initialized hitless upgrades for client")
return integration, nil
}
// convertToInterfaceMap converts a typed map to an interface{} map for the public API
func convertToInterfaceMap(input map[*pool.Conn]*hitless.ConnectionState) map[interface{}]interface{} {
result := make(map[interface{}]interface{})
for k, v := range input {
result[k] = v
}
return result
}


@ -0,0 +1,197 @@
package redis
import (
"testing"
"time"
)
func TestHitlessUpgradeConfig_DefaultValues(t *testing.T) {
tests := []struct {
name string
inputConfig *HitlessUpgradeConfig
expectedConfig *HitlessUpgradeConfig
}{
{
name: "nil config should use all defaults",
inputConfig: nil,
expectedConfig: &HitlessUpgradeConfig{
Enabled: true,
TransitionTimeout: 60 * time.Second,
CleanupInterval: 30 * time.Second,
},
},
{
name: "zero TransitionTimeout should use default",
inputConfig: &HitlessUpgradeConfig{
Enabled: false,
TransitionTimeout: 0, // Zero value
CleanupInterval: 45 * time.Second,
},
expectedConfig: &HitlessUpgradeConfig{
Enabled: false,
TransitionTimeout: 60 * time.Second, // Should use default
CleanupInterval: 45 * time.Second,
},
},
{
name: "zero CleanupInterval should use default",
inputConfig: &HitlessUpgradeConfig{
Enabled: true,
TransitionTimeout: 90 * time.Second,
CleanupInterval: 0, // Zero value
},
expectedConfig: &HitlessUpgradeConfig{
Enabled: true,
TransitionTimeout: 90 * time.Second,
CleanupInterval: 30 * time.Second, // Should use default
},
},
{
name: "both timeouts zero should use defaults",
inputConfig: &HitlessUpgradeConfig{
Enabled: true,
TransitionTimeout: 0, // Zero value
CleanupInterval: 0, // Zero value
},
expectedConfig: &HitlessUpgradeConfig{
Enabled: true,
TransitionTimeout: 60 * time.Second, // Should use default
CleanupInterval: 30 * time.Second, // Should use default
},
},
{
name: "all values set should be preserved",
inputConfig: &HitlessUpgradeConfig{
Enabled: false,
TransitionTimeout: 120 * time.Second,
CleanupInterval: 60 * time.Second,
},
expectedConfig: &HitlessUpgradeConfig{
Enabled: false,
TransitionTimeout: 120 * time.Second,
CleanupInterval: 60 * time.Second,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test with a mock client that has hitless upgrades enabled
opt := &Options{
Addr: "127.0.0.1:6379",
Protocol: 3,
HitlessUpgrades: true,
HitlessUpgradeConfig: tt.inputConfig,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
}
// Test the integration creation using the internal method directly
// since initializeHitlessIntegration requires a push processor
integration := newHitlessIntegrationWithTimeouts(tt.inputConfig, opt.ReadTimeout, opt.WriteTimeout)
if integration == nil {
t.Fatal("Integration should not be nil")
}
// Get the config from the integration
actualConfig := integration.GetConfig()
if actualConfig == nil {
t.Fatal("Config should not be nil")
}
// Verify all fields match expected values
if actualConfig.Enabled != tt.expectedConfig.Enabled {
t.Errorf("Enabled: expected %v, got %v", tt.expectedConfig.Enabled, actualConfig.Enabled)
}
if actualConfig.TransitionTimeout != tt.expectedConfig.TransitionTimeout {
t.Errorf("TransitionTimeout: expected %v, got %v", tt.expectedConfig.TransitionTimeout, actualConfig.TransitionTimeout)
}
if actualConfig.CleanupInterval != tt.expectedConfig.CleanupInterval {
t.Errorf("CleanupInterval: expected %v, got %v", tt.expectedConfig.CleanupInterval, actualConfig.CleanupInterval)
}
// Test UpdateConfig as well
newConfig := &HitlessUpgradeConfig{
Enabled: !tt.expectedConfig.Enabled,
TransitionTimeout: 0, // Zero value should use default
CleanupInterval: 0, // Zero value should use default
}
err := integration.UpdateConfig(newConfig)
if err != nil {
t.Fatalf("Failed to update config: %v", err)
}
// Verify updated config has defaults applied
updatedConfig := integration.GetConfig()
if updatedConfig.Enabled == tt.expectedConfig.Enabled {
t.Error("Enabled should have been toggled")
}
if updatedConfig.TransitionTimeout != 60*time.Second {
t.Errorf("TransitionTimeout should use default (60s), got %v", updatedConfig.TransitionTimeout)
}
if updatedConfig.CleanupInterval != 30*time.Second {
t.Errorf("CleanupInterval should use default (30s), got %v", updatedConfig.CleanupInterval)
}
})
}
}
func TestDefaultHitlessUpgradeConfig(t *testing.T) {
config := DefaultHitlessUpgradeConfig()
if config == nil {
t.Fatal("Default config should not be nil")
}
if !config.Enabled {
t.Error("Default config should have Enabled=true")
}
if config.TransitionTimeout != 60*time.Second {
t.Errorf("Default TransitionTimeout should be 60s, got %v", config.TransitionTimeout)
}
if config.CleanupInterval != 30*time.Second {
t.Errorf("Default CleanupInterval should be 30s, got %v", config.CleanupInterval)
}
}
func TestHitlessUpgradeConfig_ZeroValueHandling(t *testing.T) {
// Test that zero values are properly handled in various scenarios
// Test 1: Partial config with some zero values
partialConfig := &HitlessUpgradeConfig{
Enabled: true,
// TransitionTimeout and CleanupInterval are zero values
}
integration := newHitlessIntegrationWithTimeouts(partialConfig, 3*time.Second, 3*time.Second)
if integration == nil {
t.Fatal("Integration should not be nil")
}
config := integration.GetConfig()
if config.TransitionTimeout == 0 {
t.Error("Zero TransitionTimeout should have been replaced with default")
}
if config.CleanupInterval == 0 {
t.Error("Zero CleanupInterval should have been replaced with default")
}
// Test 2: Empty struct
emptyConfig := &HitlessUpgradeConfig{}
integration2 := newHitlessIntegrationWithTimeouts(emptyConfig, 3*time.Second, 3*time.Second)
if integration2 == nil {
t.Fatal("Integration should not be nil")
}
config2 := integration2.GetConfig()
if config2.TransitionTimeout == 0 {
t.Error("Zero TransitionTimeout in empty config should have been replaced with default")
}
if config2.CleanupInterval == 0 {
t.Error("Zero CleanupInterval in empty config should have been replaced with default")
}
}


@ -0,0 +1,23 @@
# Hitless Upgrade Package
This package implements hitless upgrade functionality for Redis cluster clients using the push notification architecture. It provides handlers for managing connection and pool state during Redis cluster upgrades.
## Quick Start
To enable hitless upgrades in your Redis client, simply set the configuration option:
```go
import "github.com/redis/go-redis/v9"
// Enable hitless upgrades with a simple configuration option
client := redis.NewClusterClient(&redis.ClusterOptions{
Addrs: []string{"127.0.0.1:7000", "127.0.0.1:7001", "127.0.0.1:7002"},
Protocol: 3, // RESP3 required for push notifications
HitlessUpgrades: true, // Enable hitless upgrades
})
defer client.Close()
// That's it! Use your client normally - hitless upgrades work automatically
ctx := context.Background()
client.Set(ctx, "key", "value", 0)
```
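
Defaults can be tuned through `HitlessUpgradeConfig`. The snippet below is an illustrative sketch: it assumes the standalone client exposes the same `HitlessUpgrades` / `HitlessUpgradeConfig` options that the tests in this change use; zero values fall back to the defaults (60s `TransitionTimeout`, 30s `CleanupInterval`).

```go
import (
	"time"

	"github.com/redis/go-redis/v9"
)

client := redis.NewClient(&redis.Options{
	Addr:            "127.0.0.1:6379",
	Protocol:        3,    // RESP3 required for push notifications
	HitlessUpgrades: true, // enable hitless upgrades
	HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
		Enabled:           true,
		TransitionTimeout: 90 * time.Second, // relaxed timeouts for connections that are MIGRATING/FAILING_OVER
		CleanupInterval:   45 * time.Second, // how often expired transition state is cleaned up
	},
})
defer client.Close()
```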


@ -0,0 +1,200 @@
package hitless
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// ClientIntegrator provides integration between hitless upgrade handlers and Redis clients
type ClientIntegrator struct {
upgradeHandler *UpgradeHandler
mu sync.RWMutex
// Simple atomic state for pool redirection
isMoving int32 // atomic: 0 = not moving, 1 = moving
newEndpoint string // only written during MOVING, read-only after
}
// NewClientIntegrator creates a new client integrator with client timeout configuration
func NewClientIntegrator(defaultReadTimeout, defaultWriteTimeout time.Duration) *ClientIntegrator {
return &ClientIntegrator{
upgradeHandler: NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout),
}
}
// GetUpgradeHandler returns the upgrade handler for direct access
func (ci *ClientIntegrator) GetUpgradeHandler() *UpgradeHandler {
return ci.upgradeHandler
}
// HandlePushNotification is the main entry point for processing upgrade notifications
func (ci *ClientIntegrator) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// Handle MOVING notifications for pool redirection
if len(notification) > 0 {
if notificationType, ok := notification[0].(string); ok && notificationType == "MOVING" {
if len(notification) >= 3 {
if newEndpoint, ok := notification[2].(string); ok {
// Simple atomic state update - no locks needed
ci.newEndpoint = newEndpoint
atomic.StoreInt32(&ci.isMoving, 1)
}
}
}
}
return ci.upgradeHandler.HandlePushNotification(ctx, handlerCtx, notification)
}
// Close shuts down the client integrator
func (ci *ClientIntegrator) Close() error {
ci.mu.Lock()
defer ci.mu.Unlock()
// Reset atomic state
atomic.StoreInt32(&ci.isMoving, 0)
ci.newEndpoint = ""
return nil
}
// IsMoving returns true if the pool is currently moving to a new endpoint
// Uses atomic read - no locks needed
func (ci *ClientIntegrator) IsMoving() bool {
return atomic.LoadInt32(&ci.isMoving) == 1
}
// GetNewEndpoint returns the new endpoint if moving, empty string otherwise
// Safe to read without locks since it's only written during MOVING
func (ci *ClientIntegrator) GetNewEndpoint() string {
if ci.IsMoving() {
return ci.newEndpoint
}
return ""
}
// PushNotificationHandlerInterface defines the interface for push notification handlers
// This implements the interface expected by the push notification system
type PushNotificationHandlerInterface interface {
HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error
}
// Ensure ClientIntegrator implements the interface
var _ PushNotificationHandlerInterface = (*ClientIntegrator)(nil)
// PoolRedirector provides pool redirection functionality for hitless upgrades
type PoolRedirector struct {
poolManager *PoolEndpointManager
mu sync.RWMutex
}
// NewPoolRedirector creates a new pool redirector
func NewPoolRedirector() *PoolRedirector {
return &PoolRedirector{
poolManager: NewPoolEndpointManager(),
}
}
// RedirectPool redirects a connection pool to a new endpoint
func (pr *PoolRedirector) RedirectPool(ctx context.Context, pooler pool.Pooler, newEndpoint string, timeout time.Duration) error {
pr.mu.Lock()
defer pr.mu.Unlock()
return pr.poolManager.RedirectPool(ctx, pooler, newEndpoint, timeout)
}
// IsPoolRedirected checks if a pool is currently redirected
func (pr *PoolRedirector) IsPoolRedirected(pooler pool.Pooler) bool {
pr.mu.RLock()
defer pr.mu.RUnlock()
return pr.poolManager.IsPoolRedirected(pooler)
}
// GetRedirection returns redirection information for a pool
func (pr *PoolRedirector) GetRedirection(pooler pool.Pooler) (*EndpointRedirection, bool) {
pr.mu.RLock()
defer pr.mu.RUnlock()
return pr.poolManager.GetRedirection(pooler)
}
// Close shuts down the pool redirector
func (pr *PoolRedirector) Close() error {
pr.mu.Lock()
defer pr.mu.Unlock()
// Clean up all redirections
ctx := context.Background()
pr.poolManager.CleanupExpiredRedirections(ctx)
return nil
}
// ConnectionStateTracker tracks connection states during upgrades
type ConnectionStateTracker struct {
upgradeHandler *UpgradeHandler
mu sync.RWMutex
}
// NewConnectionStateTracker creates a new connection state tracker with timeout configuration
func NewConnectionStateTracker(defaultReadTimeout, defaultWriteTimeout time.Duration) *ConnectionStateTracker {
return &ConnectionStateTracker{
upgradeHandler: NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout),
}
}
// IsConnectionTransitioning checks if a connection is currently transitioning
func (cst *ConnectionStateTracker) IsConnectionTransitioning(conn *pool.Conn) bool {
cst.mu.RLock()
defer cst.mu.RUnlock()
return cst.upgradeHandler.IsConnectionTransitioning(conn)
}
// GetConnectionState returns the current state of a connection
func (cst *ConnectionStateTracker) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) {
cst.mu.RLock()
defer cst.mu.RUnlock()
return cst.upgradeHandler.GetConnectionState(conn)
}
// CleanupConnection removes tracking for a connection
func (cst *ConnectionStateTracker) CleanupConnection(conn *pool.Conn) {
cst.mu.Lock()
defer cst.mu.Unlock()
cst.upgradeHandler.CleanupConnection(conn)
}
// Close shuts down the connection state tracker
func (cst *ConnectionStateTracker) Close() error {
cst.mu.Lock()
defer cst.mu.Unlock()
// Clean up all expired states
cst.upgradeHandler.CleanupExpiredStates()
return nil
}
// HitlessUpgradeConfig provides configuration for hitless upgrades
type HitlessUpgradeConfig struct {
Enabled bool
TransitionTimeout time.Duration
CleanupInterval time.Duration
}
// DefaultHitlessUpgradeConfig returns default configuration for hitless upgrades
func DefaultHitlessUpgradeConfig() *HitlessUpgradeConfig {
return &HitlessUpgradeConfig{
Enabled: true,
TransitionTimeout: 60 * time.Second, // Longer timeout for transitioning connections
CleanupInterval: 30 * time.Second, // How often to clean up expired states
}
}


@ -0,0 +1,205 @@
package hitless
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
)
// PoolEndpointManager manages endpoint transitions for connection pools during hitless upgrades.
// It provides functionality to redirect new connections to new endpoints while maintaining
// existing connections until they can be gracefully transitioned.
type PoolEndpointManager struct {
mu sync.RWMutex
// Map of pools to their endpoint redirections
redirections map[interface{}]*EndpointRedirection
// Original dialers for pools (to restore after transition)
originalDialers map[interface{}]func(context.Context) (net.Conn, error)
}
// EndpointRedirection represents an active endpoint redirection
type EndpointRedirection struct {
OriginalEndpoint string
NewEndpoint string
StartTime time.Time
Timeout time.Duration
// Statistics
NewConnections int64
FailedConnections int64
}
// NewPoolEndpointManager creates a new pool endpoint manager
func NewPoolEndpointManager() *PoolEndpointManager {
return &PoolEndpointManager{
redirections: make(map[interface{}]*EndpointRedirection),
originalDialers: make(map[interface{}]func(context.Context) (net.Conn, error)),
}
}
// RedirectPool redirects new connections from a pool to a new endpoint
func (m *PoolEndpointManager) RedirectPool(ctx context.Context, pooler pool.Pooler, newEndpoint string, timeout time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
// Check if pool is already being redirected
if _, exists := m.redirections[pooler]; exists {
return fmt.Errorf("pool is already being redirected")
}
// Get the current dialer from the pool
connPool, ok := pooler.(*pool.ConnPool)
if !ok {
return fmt.Errorf("unsupported pool type: %T", pooler)
}
// Store original dialer
originalDialer := m.getPoolDialer(connPool)
if originalDialer == nil {
return fmt.Errorf("could not get original dialer from pool")
}
m.originalDialers[pooler] = originalDialer
// Create new dialer that connects to the new endpoint
newDialer := m.createRedirectDialer(ctx, newEndpoint, originalDialer)
// Replace the pool's dialer
if err := m.setPoolDialer(connPool, newDialer); err != nil {
delete(m.originalDialers, pooler)
return fmt.Errorf("failed to set new dialer: %w", err)
}
// Record the redirection
m.redirections[pooler] = &EndpointRedirection{
OriginalEndpoint: m.extractEndpointFromDialer(originalDialer),
NewEndpoint: newEndpoint,
StartTime: time.Now(),
Timeout: timeout,
}
internal.Logger.Printf(ctx, "hitless: redirected pool to new endpoint %s", newEndpoint)
return nil
}
// IsPoolRedirected checks if a pool is currently being redirected
func (m *PoolEndpointManager) IsPoolRedirected(pooler pool.Pooler) bool {
m.mu.RLock()
defer m.mu.RUnlock()
_, exists := m.redirections[pooler]
return exists
}
// GetRedirection returns redirection information for a pool
func (m *PoolEndpointManager) GetRedirection(pooler pool.Pooler) (*EndpointRedirection, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
redirection, exists := m.redirections[pooler]
if !exists {
return nil, false
}
// Return a copy to avoid race conditions
redirectionCopy := *redirection
return &redirectionCopy, true
}
// CleanupExpiredRedirections removes expired redirections
func (m *PoolEndpointManager) CleanupExpiredRedirections(ctx context.Context) {
m.mu.Lock()
defer m.mu.Unlock()
now := time.Now()
for pooler, redirection := range m.redirections {
if now.Sub(redirection.StartTime) > redirection.Timeout {
// TODO: Here we should decide if we need to failback to the original dialer,
// i.e. if the new endpoint did not produce any active connections.
delete(m.redirections, pooler)
delete(m.originalDialers, pooler)
internal.Logger.Printf(ctx, "hitless: cleaned up expired redirection for pool")
}
}
}
// createRedirectDialer creates a dialer that connects to the new endpoint
func (m *PoolEndpointManager) createRedirectDialer(ctx context.Context, newEndpoint string, originalDialer func(context.Context) (net.Conn, error)) func(context.Context) (net.Conn, error) {
return func(dialCtx context.Context) (net.Conn, error) {
// Try to connect to the new endpoint
conn, err := net.DialTimeout("tcp", newEndpoint, 10*time.Second)
if err != nil {
internal.Logger.Printf(ctx, "hitless: failed to connect to new endpoint %s: %v", newEndpoint, err)
// Fallback to original dialer
return originalDialer(dialCtx)
}
internal.Logger.Printf(ctx, "hitless: successfully connected to new endpoint %s", newEndpoint)
return conn, nil
}
}
// getPoolDialer extracts the dialer from a connection pool
func (m *PoolEndpointManager) getPoolDialer(connPool *pool.ConnPool) func(context.Context) (net.Conn, error) {
return connPool.GetDialer()
}
// setPoolDialer sets a new dialer for a connection pool
func (m *PoolEndpointManager) setPoolDialer(connPool *pool.ConnPool, dialer func(context.Context) (net.Conn, error)) error {
return connPool.SetDialer(dialer)
}
// extractEndpointFromDialer extracts the endpoint address from a dialer
func (m *PoolEndpointManager) extractEndpointFromDialer(dialer func(context.Context) (net.Conn, error)) string {
// Try to extract endpoint by making a test connection
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
conn, err := dialer(ctx)
if err != nil {
return "unknown"
}
defer conn.Close()
if conn.RemoteAddr() != nil {
return conn.RemoteAddr().String()
}
return "unknown"
}
// GetActiveRedirections returns all active redirections
func (m *PoolEndpointManager) GetActiveRedirections() map[interface{}]*EndpointRedirection {
m.mu.RLock()
defer m.mu.RUnlock()
// Create copies to avoid race conditions
redirections := make(map[interface{}]*EndpointRedirection)
for pooler, redirection := range m.redirections {
redirectionCopy := *redirection
redirections[pooler] = &redirectionCopy
}
return redirections
}
// UpdateRedirectionStats updates statistics for a redirection
func (m *PoolEndpointManager) UpdateRedirectionStats(pooler pool.Pooler, newConnections, failedConnections int64) {
m.mu.Lock()
defer m.mu.Unlock()
if redirection, exists := m.redirections[pooler]; exists {
redirection.NewConnections += newConnections
redirection.FailedConnections += failedConnections
}
}


@ -0,0 +1,309 @@
package hitless
import (
"context"
"fmt"
"sync"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// UpgradeStatus represents the current status of all upgrade operations
type UpgradeStatus struct {
ConnectionStates map[*pool.Conn]*ConnectionState
IsMoving bool
NewEndpoint string
Timestamp time.Time
}
// UpgradeStatistics provides statistics about upgrade operations
type UpgradeStatistics struct {
ActiveConnections int
IsMoving bool
MigratingConnections int
FailingOverConnections int
Timestamp time.Time
}
// RedisClientIntegration provides complete hitless upgrade integration for Redis clients
type RedisClientIntegration struct {
clientIntegrator *ClientIntegrator
connectionStateTracker *ConnectionStateTracker
config *HitlessUpgradeConfig
mu sync.RWMutex
enabled bool
}
// NewRedisClientIntegration creates a new Redis client integration for hitless upgrades
// This is used internally by the main hitless.go package
func NewRedisClientIntegration(config *HitlessUpgradeConfig, defaultReadTimeout, defaultWriteTimeout time.Duration) *RedisClientIntegration {
// Start with defaults
defaults := DefaultHitlessUpgradeConfig()
// If config is nil, use all defaults
if config == nil {
config = defaults
} else {
// Ensure all fields are set with defaults if they are zero values
if config.TransitionTimeout == 0 {
config = &HitlessUpgradeConfig{
Enabled: config.Enabled,
TransitionTimeout: defaults.TransitionTimeout,
CleanupInterval: config.CleanupInterval,
}
}
if config.CleanupInterval == 0 {
config = &HitlessUpgradeConfig{
Enabled: config.Enabled,
TransitionTimeout: config.TransitionTimeout,
CleanupInterval: defaults.CleanupInterval,
}
}
}
return &RedisClientIntegration{
clientIntegrator: NewClientIntegrator(defaultReadTimeout, defaultWriteTimeout),
connectionStateTracker: NewConnectionStateTracker(defaultReadTimeout, defaultWriteTimeout),
config: config,
enabled: config.Enabled,
}
}
// EnableHitlessUpgrades enables hitless upgrade functionality
func (rci *RedisClientIntegration) EnableHitlessUpgrades() {
rci.mu.Lock()
defer rci.mu.Unlock()
rci.enabled = true
}
// DisableHitlessUpgrades disables hitless upgrade functionality
func (rci *RedisClientIntegration) DisableHitlessUpgrades() {
rci.mu.Lock()
defer rci.mu.Unlock()
rci.enabled = false
}
// IsEnabled returns whether hitless upgrades are enabled
func (rci *RedisClientIntegration) IsEnabled() bool {
rci.mu.RLock()
defer rci.mu.RUnlock()
return rci.enabled
}
// No client registration needed - each client has its own hitless integration instance
// HandlePushNotification processes push notifications for hitless upgrades
func (rci *RedisClientIntegration) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if !rci.IsEnabled() {
// If disabled, just log and return without processing
internal.Logger.Printf(ctx, "hitless: received notification but hitless upgrades are disabled")
return nil
}
return rci.clientIntegrator.HandlePushNotification(ctx, handlerCtx, notification)
}
// IsConnectionTransitioning checks if a connection is currently transitioning
func (rci *RedisClientIntegration) IsConnectionTransitioning(conn *pool.Conn) bool {
if !rci.IsEnabled() {
return false
}
return rci.connectionStateTracker.IsConnectionTransitioning(conn)
}
// GetConnectionState returns the current state of a connection
func (rci *RedisClientIntegration) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) {
if !rci.IsEnabled() {
return nil, false
}
return rci.connectionStateTracker.GetConnectionState(conn)
}
// GetUpgradeStatus returns comprehensive status of all ongoing upgrades
func (rci *RedisClientIntegration) GetUpgradeStatus() *UpgradeStatus {
connStates := rci.clientIntegrator.GetUpgradeHandler().GetActiveTransitions()
return &UpgradeStatus{
ConnectionStates: connStates,
IsMoving: rci.clientIntegrator.IsMoving(),
NewEndpoint: rci.clientIntegrator.GetNewEndpoint(),
Timestamp: time.Now(),
}
}
// GetUpgradeStatistics returns statistics about upgrade operations
func (rci *RedisClientIntegration) GetUpgradeStatistics() *UpgradeStatistics {
connStates := rci.clientIntegrator.GetUpgradeHandler().GetActiveTransitions()
stats := &UpgradeStatistics{
ActiveConnections: len(connStates),
IsMoving: rci.clientIntegrator.IsMoving(),
Timestamp: time.Now(),
}
// Count by type
stats.MigratingConnections = 0
stats.FailingOverConnections = 0
for _, state := range connStates {
switch state.TransitionType {
case "MIGRATING":
stats.MigratingConnections++
case "FAILING_OVER":
stats.FailingOverConnections++
}
}
return stats
}
// GetConnectionTimeout returns the appropriate timeout for a connection
// If the connection is transitioning (MIGRATING/FAILING_OVER), returns the longer TransitionTimeout
// Otherwise returns the provided defaultTimeout
func (rci *RedisClientIntegration) GetConnectionTimeout(conn *pool.Conn, defaultTimeout time.Duration) time.Duration {
if !rci.IsEnabled() {
return defaultTimeout
}
// Check if connection is transitioning
if rci.connectionStateTracker.IsConnectionTransitioning(conn) {
// Use longer timeout for transitioning connections
return rci.config.TransitionTimeout
}
return defaultTimeout
}
// GetConnectionTimeouts returns both read and write timeouts for a connection
func (rci *RedisClientIntegration) GetConnectionTimeouts(conn *pool.Conn, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, time.Duration) {
if !rci.IsEnabled() {
return defaultReadTimeout, defaultWriteTimeout
}
// Use the upgrade handler to get appropriate timeouts
upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
return upgradeHandler.GetConnectionTimeouts(conn, defaultReadTimeout, defaultWriteTimeout, rci.config.TransitionTimeout)
}
// MarkConnectionAsBlocking marks a connection as having blocking commands
func (rci *RedisClientIntegration) MarkConnectionAsBlocking(conn *pool.Conn, isBlocking bool) {
if !rci.IsEnabled() {
return
}
// Use the upgrade handler to mark connection as blocking
upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
upgradeHandler.MarkConnectionAsBlocking(conn, isBlocking)
}
// IsConnectionMarkedForClosing checks if a connection should be closed
func (rci *RedisClientIntegration) IsConnectionMarkedForClosing(conn *pool.Conn) bool {
if !rci.IsEnabled() {
return false
}
// Use the upgrade handler to check if connection is marked for closing
upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
return upgradeHandler.IsConnectionMarkedForClosing(conn)
}
// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected
func (rci *RedisClientIntegration) ShouldRedirectBlockingConnection(conn *pool.Conn) (bool, string) {
if !rci.IsEnabled() {
return false, ""
}
// Check client integrator's atomic state for pool-level redirection
if rci.clientIntegrator.IsMoving() {
return true, rci.clientIntegrator.GetNewEndpoint()
}
// Check specific connection state
upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
return upgradeHandler.ShouldRedirectBlockingConnection(conn, rci.clientIntegrator)
}
// CleanupConnection removes tracking for a connection (called when connection is closed)
func (rci *RedisClientIntegration) CleanupConnection(conn *pool.Conn) {
if rci.IsEnabled() {
rci.connectionStateTracker.CleanupConnection(conn)
}
}
// CleanupPool removed - no pool state to clean up
// Close shuts down the Redis client integration
func (rci *RedisClientIntegration) Close() error {
rci.mu.Lock()
defer rci.mu.Unlock()
var firstErr error
// Close all components
if err := rci.clientIntegrator.Close(); err != nil && firstErr == nil {
firstErr = err
}
// poolRedirector removed in simplified implementation
if err := rci.connectionStateTracker.Close(); err != nil && firstErr == nil {
firstErr = err
}
rci.enabled = false
return firstErr
}
// GetConfig returns the current configuration
func (rci *RedisClientIntegration) GetConfig() *HitlessUpgradeConfig {
rci.mu.RLock()
defer rci.mu.RUnlock()
// Return a copy to prevent modification
configCopy := *rci.config
return &configCopy
}
// UpdateConfig updates the configuration
func (rci *RedisClientIntegration) UpdateConfig(config *HitlessUpgradeConfig) error {
if config == nil {
return fmt.Errorf("config cannot be nil")
}
rci.mu.Lock()
defer rci.mu.Unlock()
// Start with defaults for any zero values
defaults := DefaultHitlessUpgradeConfig()
// Ensure all fields are set with defaults if they are zero values
enabled := config.Enabled
transitionTimeout := config.TransitionTimeout
cleanupInterval := config.CleanupInterval
// Apply defaults for zero values
if transitionTimeout == 0 {
transitionTimeout = defaults.TransitionTimeout
}
if cleanupInterval == 0 {
cleanupInterval = defaults.CleanupInterval
}
// Create properly configured config
finalConfig := &HitlessUpgradeConfig{
Enabled: enabled,
TransitionTimeout: transitionTimeout,
CleanupInterval: cleanupInterval,
}
rci.config = finalConfig
rci.enabled = finalConfig.Enabled
return nil
}


@ -0,0 +1,500 @@
package hitless
import (
"context"
"fmt"
"strconv"
"sync"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// UpgradeHandler handles hitless upgrade push notifications for Redis cluster upgrades.
// It implements different strategies based on notification type:
// - MOVING: Changes pool state to use new endpoint for future connections
// - marks existing connections for closing, moves pub/sub onto a new underlying connection, and swaps the pool dialer to the new endpoint
//
// - MIGRATING/FAILING_OVER: Marks specific connection as in transition
// - relaxing timeouts
//
// - MIGRATED/FAILED_OVER: Clears transition state for specific connection
// - return to original timeouts
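//
// Illustrative notification payloads (values are examples only; the shapes follow
// the Format comments on the individual handlers below):
//
//	["MOVING", 30, "10.0.0.1:6379"]  // endpoint moves in ~30s; redirect new connections
//	["MIGRATING", 15, "shard-1"]     // relax timeouts on this connection
//	["MIGRATED", "shard-1"]          // restore original timeouts
//	["FAILING_OVER", 10, "shard-2"]  // relax timeouts on this connection
//	["FAILED_OVER", "shard-2"]       // restore original timeouts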
type UpgradeHandler struct {
mu sync.RWMutex
// Connection-specific state for MIGRATING/FAILING_OVER notifications
connectionStates map[*pool.Conn]*ConnectionState
// Pool-level state removed - using atomic state in ClusterUpgradeManager instead
// Client configuration for getting default timeouts
defaultReadTimeout time.Duration
defaultWriteTimeout time.Duration
}
// ConnectionState tracks the state of a specific connection during upgrades
type ConnectionState struct {
IsTransitioning bool
TransitionType string
StartTime time.Time
ShardID string
TimeoutSeconds int
// Timeout management
OriginalReadTimeout time.Duration // Original read timeout
OriginalWriteTimeout time.Duration // Original write timeout
// MOVING state specific
MarkedForClosing bool // Connection should be closed after current commands
IsBlocking bool // Connection has blocking commands
NewEndpoint string // New endpoint for blocking commands
LastCommandTime time.Time // When the last command was sent
}
// PoolState removed - using atomic state in ClusterUpgradeManager instead
// NewUpgradeHandler creates a new hitless upgrade handler with client timeout configuration
func NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout time.Duration) *UpgradeHandler {
return &UpgradeHandler{
connectionStates: make(map[*pool.Conn]*ConnectionState),
defaultReadTimeout: defaultReadTimeout,
defaultWriteTimeout: defaultWriteTimeout,
}
}
// GetConnectionTimeouts returns the appropriate read and write timeouts for a connection
// If the connection is transitioning, returns increased timeouts
func (h *UpgradeHandler) GetConnectionTimeouts(conn *pool.Conn, defaultReadTimeout, defaultWriteTimeout, transitionTimeout time.Duration) (time.Duration, time.Duration) {
h.mu.RLock()
defer h.mu.RUnlock()
state, exists := h.connectionStates[conn]
if !exists || !state.IsTransitioning {
return defaultReadTimeout, defaultWriteTimeout
}
// For transitioning connections (MIGRATING/FAILING_OVER), use longer timeouts
switch state.TransitionType {
case "MIGRATING", "FAILING_OVER":
return transitionTimeout, transitionTimeout
case "MOVING":
// For MOVING connections, use default timeouts but mark for special handling
return defaultReadTimeout, defaultWriteTimeout
default:
return defaultReadTimeout, defaultWriteTimeout
}
}
// MarkConnectionForClosing marks a connection to be closed after current commands complete
func (h *UpgradeHandler) MarkConnectionForClosing(conn *pool.Conn, newEndpoint string) {
h.mu.Lock()
defer h.mu.Unlock()
state, exists := h.connectionStates[conn]
if !exists {
state = &ConnectionState{
IsTransitioning: true,
TransitionType: "MOVING",
StartTime: time.Now(),
}
h.connectionStates[conn] = state
}
state.MarkedForClosing = true
state.NewEndpoint = newEndpoint
state.LastCommandTime = time.Now()
}
// IsConnectionMarkedForClosing checks if a connection should be closed
func (h *UpgradeHandler) IsConnectionMarkedForClosing(conn *pool.Conn) bool {
h.mu.RLock()
defer h.mu.RUnlock()
state, exists := h.connectionStates[conn]
return exists && state.MarkedForClosing
}
// MarkConnectionAsBlocking marks a connection as having blocking commands
func (h *UpgradeHandler) MarkConnectionAsBlocking(conn *pool.Conn, isBlocking bool) {
h.mu.Lock()
defer h.mu.Unlock()
state, exists := h.connectionStates[conn]
if !exists {
state = &ConnectionState{
IsTransitioning: false,
}
h.connectionStates[conn] = state
}
state.IsBlocking = isBlocking
}
// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected
// Uses client integrator's atomic state for pool-level checks - minimal locking
func (h *UpgradeHandler) ShouldRedirectBlockingConnection(conn *pool.Conn, clientIntegrator interface{}) (bool, string) {
if conn != nil {
// Check specific connection - need lock only for connection state
h.mu.RLock()
state, exists := h.connectionStates[conn]
h.mu.RUnlock()
if exists && state.IsBlocking && state.TransitionType == "MOVING" && state.NewEndpoint != "" {
return true, state.NewEndpoint
}
}
// Check client integrator's atomic state - no locks needed
if ci, ok := clientIntegrator.(*ClientIntegrator); ok && ci != nil && ci.IsMoving() {
return true, ci.GetNewEndpoint()
}
return false, ""
}
// ShouldRedirectNewBlockingConnection removed - functionality merged into ShouldRedirectBlockingConnection
// HandlePushNotification processes hitless upgrade push notifications
func (h *UpgradeHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) == 0 {
return fmt.Errorf("hitless: empty notification received")
}
notificationType, ok := notification[0].(string)
if !ok {
return fmt.Errorf("hitless: notification type is not a string: %T", notification[0])
}
internal.Logger.Printf(ctx, "hitless: processing %s notification with %d elements", notificationType, len(notification))
switch notificationType {
case "MOVING":
return h.handleMovingNotification(ctx, handlerCtx, notification)
case "MIGRATING":
return h.handleMigratingNotification(ctx, handlerCtx, notification)
case "MIGRATED":
return h.handleMigratedNotification(ctx, handlerCtx, notification)
case "FAILING_OVER":
return h.handleFailingOverNotification(ctx, handlerCtx, notification)
case "FAILED_OVER":
return h.handleFailedOverNotification(ctx, handlerCtx, notification)
default:
internal.Logger.Printf(ctx, "hitless: unknown notification type: %s", notificationType)
return nil
}
}
// handleMovingNotification processes MOVING notifications that affect the entire pool
// Format: ["MOVING", time_seconds, "new_endpoint"]
func (h *UpgradeHandler) handleMovingNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 3 {
return fmt.Errorf("hitless: MOVING notification requires at least 3 elements, got %d", len(notification))
}
// Parse timeout
timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
if err != nil {
return fmt.Errorf("hitless: failed to parse timeout for MOVING notification: %w", err)
}
// Parse new endpoint
newEndpoint, ok := notification[2].(string)
if !ok {
return fmt.Errorf("hitless: new endpoint is not a string: %T", notification[2])
}
internal.Logger.Printf(ctx, "hitless: MOVING notification - endpoint will move to %s in %d seconds", newEndpoint, timeoutSeconds)
h.mu.Lock()
// Mark all existing connections for closing after current commands complete
for _, state := range h.connectionStates {
state.MarkedForClosing = true
state.NewEndpoint = newEndpoint
state.TransitionType = "MOVING"
state.LastCommandTime = time.Now()
}
h.mu.Unlock()
internal.Logger.Printf(ctx, "hitless: marked existing connections for closing, new blocking commands will use %s", newEndpoint)
return nil
}
// Removed complex helper methods - simplified to direct inline logic
// handleMigratingNotification processes MIGRATING notifications for specific connections
// Format: ["MIGRATING", time_seconds, shard_id]
func (h *UpgradeHandler) handleMigratingNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 3 {
return fmt.Errorf("hitless: MIGRATING notification requires at least 3 elements, got %d", len(notification))
}
timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
if err != nil {
return fmt.Errorf("hitless: failed to parse timeout for MIGRATING notification: %w", err)
}
shardID, err := h.parseShardID(notification[2])
if err != nil {
return fmt.Errorf("hitless: failed to parse shard ID for MIGRATING notification: %w", err)
}
conn := handlerCtx.Conn
if conn == nil {
return fmt.Errorf("hitless: no connection available in handler context")
}
internal.Logger.Printf(ctx, "hitless: MIGRATING notification - shard %s will migrate in %d seconds on connection %p", shardID, timeoutSeconds, conn)
h.mu.Lock()
defer h.mu.Unlock()
// Store original timeouts if not already stored
var originalReadTimeout, originalWriteTimeout time.Duration
if existingState, exists := h.connectionStates[conn]; exists {
originalReadTimeout = existingState.OriginalReadTimeout
originalWriteTimeout = existingState.OriginalWriteTimeout
} else {
// Get default timeouts from client configuration
originalReadTimeout = h.defaultReadTimeout
originalWriteTimeout = h.defaultWriteTimeout
}
h.connectionStates[conn] = &ConnectionState{
IsTransitioning: true,
TransitionType: "MIGRATING",
StartTime: time.Now(),
ShardID: shardID,
TimeoutSeconds: timeoutSeconds,
OriginalReadTimeout: originalReadTimeout,
OriginalWriteTimeout: originalWriteTimeout,
LastCommandTime: time.Now(),
}
internal.Logger.Printf(ctx, "hitless: connection %p marked as MIGRATING with increased timeouts", conn)
return nil
}
// handleMigratedNotification processes MIGRATED notifications for specific connections
// Format: ["MIGRATED", shard_id]
func (h *UpgradeHandler) handleMigratedNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 2 {
return fmt.Errorf("hitless: MIGRATED notification requires at least 2 elements, got %d", len(notification))
}
shardID, err := h.parseShardID(notification[1])
if err != nil {
return fmt.Errorf("hitless: failed to parse shard ID for MIGRATED notification: %w", err)
}
conn := handlerCtx.Conn
if conn == nil {
return fmt.Errorf("hitless: no connection available in handler context")
}
internal.Logger.Printf(ctx, "hitless: MIGRATED notification - shard %s migration completed on connection %p", shardID, conn)
h.mu.Lock()
defer h.mu.Unlock()
// Clear the transitioning state for this connection and restore original timeouts
if state, exists := h.connectionStates[conn]; exists && state.TransitionType == "MIGRATING" && state.ShardID == shardID {
internal.Logger.Printf(ctx, "hitless: restoring original timeouts for connection %p (read: %v, write: %v)",
conn, state.OriginalReadTimeout, state.OriginalWriteTimeout)
// In a real implementation, this would restore the connection's original timeouts
// For now, we'll just log and delete the state
delete(h.connectionStates, conn)
internal.Logger.Printf(ctx, "hitless: cleared MIGRATING state and restored timeouts for connection %p", conn)
}
return nil
}
// handleFailingOverNotification processes FAILING_OVER notifications for specific connections
// Format: ["FAILING_OVER", time_seconds, shard_id]
func (h *UpgradeHandler) handleFailingOverNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 3 {
return fmt.Errorf("hitless: FAILING_OVER notification requires at least 3 elements, got %d", len(notification))
}
timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
if err != nil {
return fmt.Errorf("hitless: failed to parse timeout for FAILING_OVER notification: %w", err)
}
shardID, err := h.parseShardID(notification[2])
if err != nil {
return fmt.Errorf("hitless: failed to parse shard ID for FAILING_OVER notification: %w", err)
}
conn := handlerCtx.Conn
if conn == nil {
return fmt.Errorf("hitless: no connection available in handler context")
}
internal.Logger.Printf(ctx, "hitless: FAILING_OVER notification - shard %s will failover in %d seconds on connection %p", shardID, timeoutSeconds, conn)
h.mu.Lock()
defer h.mu.Unlock()
// Store original timeouts if not already stored
var originalReadTimeout, originalWriteTimeout time.Duration
if existingState, exists := h.connectionStates[conn]; exists {
originalReadTimeout = existingState.OriginalReadTimeout
originalWriteTimeout = existingState.OriginalWriteTimeout
} else {
// Get default timeouts from client configuration
originalReadTimeout = h.defaultReadTimeout
originalWriteTimeout = h.defaultWriteTimeout
}
h.connectionStates[conn] = &ConnectionState{
IsTransitioning: true,
TransitionType: "FAILING_OVER",
StartTime: time.Now(),
ShardID: shardID,
TimeoutSeconds: timeoutSeconds,
OriginalReadTimeout: originalReadTimeout,
OriginalWriteTimeout: originalWriteTimeout,
LastCommandTime: time.Now(),
}
internal.Logger.Printf(ctx, "hitless: connection %p marked as FAILING_OVER with increased timeouts", conn)
return nil
}
// handleFailedOverNotification processes FAILED_OVER notifications for specific connections
// Format: ["FAILED_OVER", shard_id]
func (h *UpgradeHandler) handleFailedOverNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 2 {
return fmt.Errorf("hitless: FAILED_OVER notification requires at least 2 elements, got %d", len(notification))
}
shardID, err := h.parseShardID(notification[1])
if err != nil {
return fmt.Errorf("hitless: failed to parse shard ID for FAILED_OVER notification: %w", err)
}
conn := handlerCtx.Conn
if conn == nil {
return fmt.Errorf("hitless: no connection available in handler context")
}
internal.Logger.Printf(ctx, "hitless: FAILED_OVER notification - shard %s failover completed on connection %p", shardID, conn)
h.mu.Lock()
defer h.mu.Unlock()
// Clear the transitioning state for this connection and restore original timeouts
if state, exists := h.connectionStates[conn]; exists && state.TransitionType == "FAILING_OVER" && state.ShardID == shardID {
internal.Logger.Printf(ctx, "hitless: restoring original timeouts for connection %p (read: %v, write: %v)",
conn, state.OriginalReadTimeout, state.OriginalWriteTimeout)
// In a real implementation, this would restore the connection's original timeouts
// For now, we'll just log and delete the state
delete(h.connectionStates, conn)
internal.Logger.Printf(ctx, "hitless: cleared FAILING_OVER state and restored timeouts for connection %p", conn)
}
return nil
}
// parseTimeoutSeconds parses timeout value from notification
func (h *UpgradeHandler) parseTimeoutSeconds(value interface{}) (int, error) {
switch v := value.(type) {
case int64:
return int(v), nil
case int:
return v, nil
case string:
return strconv.Atoi(v)
default:
return 0, fmt.Errorf("unsupported timeout type: %T", value)
}
}
// parseShardID parses shard ID from notification
func (h *UpgradeHandler) parseShardID(value interface{}) (string, error) {
switch v := value.(type) {
case string:
return v, nil
case int64:
return strconv.FormatInt(v, 10), nil
case int:
return strconv.Itoa(v), nil
default:
return "", fmt.Errorf("unsupported shard ID type: %T", value)
}
}
// GetConnectionState returns the current state of a connection
func (h *UpgradeHandler) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) {
h.mu.RLock()
defer h.mu.RUnlock()
state, exists := h.connectionStates[conn]
if !exists {
return nil, false
}
// Return a copy to avoid race conditions
stateCopy := *state
return &stateCopy, true
}
// IsConnectionTransitioning checks if a connection is currently transitioning
func (h *UpgradeHandler) IsConnectionTransitioning(conn *pool.Conn) bool {
h.mu.RLock()
defer h.mu.RUnlock()
state, exists := h.connectionStates[conn]
return exists && state.IsTransitioning
}
// IsPoolMoving and GetNewEndpoint removed - using atomic state in ClusterUpgradeManager instead
// CleanupExpiredStates removes expired connection and pool states
func (h *UpgradeHandler) CleanupExpiredStates() {
h.mu.Lock()
defer h.mu.Unlock()
now := time.Now()
// Cleanup expired connection states
for conn, state := range h.connectionStates {
timeout := time.Duration(state.TimeoutSeconds) * time.Second
if now.Sub(state.StartTime) > timeout {
delete(h.connectionStates, conn)
}
}
// Pool state cleanup removed - using atomic state in ClusterUpgradeManager instead
}
// CleanupConnection removes state for a specific connection (called when connection is closed)
func (h *UpgradeHandler) CleanupConnection(conn *pool.Conn) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.connectionStates, conn)
}
// GetActiveTransitions returns information about all active connection transitions
func (h *UpgradeHandler) GetActiveTransitions() map[*pool.Conn]*ConnectionState {
h.mu.RLock()
defer h.mu.RUnlock()
// Create copies to avoid race conditions
connStates := make(map[*pool.Conn]*ConnectionState)
for conn, state := range h.connectionStates {
stateCopy := *state
connStates[conn] = &stateCopy
}
return connStates
}
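
The cleanup methods above are meant to be driven by a periodic background loop using the CleanupInterval from the upgrade configuration. A minimal sketch of such a loop, assuming the caller owns the handler and a stop channel (runCleanupLoop, interval, and stopCh are illustrative names, not part of this commit):

// runCleanupLoop periodically drops expired connection states until stopCh is closed.
// interval would typically come from HitlessUpgradeConfig.CleanupInterval.
func runCleanupLoop(handler *UpgradeHandler, interval time.Duration, stopCh <-chan struct{}) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			// Remove connection states whose transition timeout has elapsed.
			handler.CleanupExpiredStates()
		case <-stopCh:
			return
		}
	}
}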

View File

@ -565,3 +565,24 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
cn.SetUsedAt(now)
return true
}
// GetDialer returns the current dialer function for the pool
func (p *ConnPool) GetDialer() func(context.Context) (net.Conn, error) {
p.connsMu.Lock()
defer p.connsMu.Unlock()
return p.cfg.Dialer
}
// SetDialer sets a new dialer function for the pool
// This is used for hitless upgrades to redirect new connections to new endpoints
func (p *ConnPool) SetDialer(dialer func(context.Context) (net.Conn, error)) error {
if p.closed() {
return ErrClosed
}
p.connsMu.Lock()
defer p.connsMu.Unlock()
p.cfg.Dialer = dialer
return nil
}
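
SetDialer is what lets a MOVING notification redirect new connections to the new endpoint while existing connections keep draining. A rough sketch of how the integration might swap the dialer (the redirectPool helper and newAddr value are illustrative assumptions, not code from this commit):

// redirectPool points all future dials of an existing pool at newAddr.
// Connections already in the pool are left untouched; only new dials use the new endpoint.
func redirectPool(p *pool.ConnPool, newAddr string, dialTimeout time.Duration) error {
	return p.SetDialer(func(ctx context.Context) (net.Conn, error) {
		d := &net.Dialer{Timeout: dialTimeout}
		return d.DialContext(ctx, "tcp", newAddr)
	})
}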

View File

@ -3,6 +3,7 @@ package proto
import (
"bytes"
"fmt"
"math/rand"
"strings"
"testing"
)
@ -215,9 +216,9 @@ func TestPeekPushNotificationName(t *testing.T) {
// This is acceptable behavior for malformed input
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Logf("PeekPushNotificationName errored for corrupted data %s: %v (DATA: %s)", tc.name, err, tc.data)
} else {
t.Logf("PeekPushNotificationName returned '%s' for corrupted data NAME: %s, DATA: %s", name, tc.name, tc.data)
}
})
}
@ -293,15 +294,27 @@ func TestPeekPushNotificationName(t *testing.T) {
func createValidPushNotification(notificationName, data string) *bytes.Buffer {
buf := &bytes.Buffer{}
simpleOrString := rand.Intn(2) == 0
if data == "" {
// Single element notification
buf.WriteString(">1\r\n")
if simpleOrString {
buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
} else {
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
}
} else {
// Two element notification
buf.WriteString(">2\r\n")
if simpleOrString {
buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
buf.WriteString(fmt.Sprintf("+%s\r\n", data))
} else {
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(data), data))
}
}
return buf

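For reference, the two encodings the helper can emit for a two-element notification such as ("MOVING", "30") look roughly like this on the wire; which one is produced depends on the random choice above:

const (
	// $-encoded bulk strings carry an explicit byte length before the payload.
	bulkForm = ">2\r\n$6\r\nMOVING\r\n$2\r\n30\r\n"
	// +-encoded simple strings are terminated only by CRLF.
	simpleForm = ">2\r\n+MOVING\r\n+30\r\n"
)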
View File

@ -116,26 +116,55 @@ func (r *Reader) PeekPushNotificationName() (string, error) {
if buf[0] != RespPush {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
if len(buf) < 3 {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
// remove push notification type
buf = buf[1:]
// remove first line - e.g. >2\r\n
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
buf = buf[i+2:]
break
} else {
if buf[i] < '0' || buf[i] > '9' {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
}
}
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
// next line should be $<length><string>\r\n or +<string>\r\n
// should have the type of the push notification name and it's length
if buf[0] != RespString && buf[0] != RespStatus {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
typeOfName := buf[0]
// remove the type of the push notification name
buf = buf[1:]
if typeOfName == RespString {
// remove the length of the string
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
buf = buf[i+2:]
break
} else {
if buf[i] < '0' || buf[i] > '9' {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
}
}
}
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
// keep only the notification name
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
@ -143,6 +172,7 @@ func (r *Reader) PeekPushNotificationName() (string, error) {
break
}
}
return util.BytesToString(buf), nil
}
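
A minimal sketch of how the peek is typically used: inspect the buffered frame without consuming it, and only route it to the push-notification path when the name is one the client handles (peekAndDispatch and handlesPushNotification are illustrative, not part of this commit):

// peekAndDispatch reports whether the next buffered reply is a push notification
// that should be handed to the push-notification processor.
func peekAndDispatch(rd *proto.Reader, handlesPushNotification func(name string) bool) bool {
	name, err := rd.PeekPushNotificationName()
	if err != nil {
		// Not a push frame (or malformed); let the normal reply path read it.
		return false
	}
	return handlesPushNotification(name)
}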

View File

@ -224,6 +224,18 @@ type Options struct {
// PushNotificationProcessor is the processor for handling push notifications.
// If nil, a default processor will be created for RESP3 connections.
PushNotificationProcessor push.NotificationProcessor
// HitlessUpgrades enables hitless upgrade functionality for cluster upgrades.
// Requires Protocol: 3 (RESP3) for push notifications.
// When enabled, the client will automatically handle cluster upgrade notifications
// and manage connection/pool state transitions seamlessly.
//
// default: false
HitlessUpgrades bool
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// If nil, default configuration will be used when HitlessUpgrades is true.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
func (opt *Options) init() {

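Taken together with the configuration type added earlier in this commit, enabling the feature on a standalone client only requires setting these options. A small sketch (the address and timeout values are illustrative):

rdb := redis.NewClient(&redis.Options{
	Addr:            "localhost:6379",
	Protocol:        3, // RESP3 is required for push notifications
	HitlessUpgrades: true,
	HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
		Enabled:           true,
		TransitionTimeout: 90 * time.Second,
		CleanupInterval:   30 * time.Second,
	},
})
defer rdb.Close()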
View File

@ -110,6 +110,18 @@ type ClusterOptions struct {
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
UnstableResp3 bool
// HitlessUpgrades enables hitless upgrade functionality for cluster upgrades.
// Requires Protocol: 3 (RESP3) for push notifications.
// When enabled, the client will automatically handle cluster upgrade notifications
// and manage connection/pool state transitions seamlessly.
//
// default: false
HitlessUpgrades bool
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// If nil, default configuration will be used when HitlessUpgrades is true.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
func (opt *ClusterOptions) init() {
@ -327,8 +339,10 @@ func (opt *ClusterOptions) clientOptions() *Options {
// much use for ClusterSlots config). This means we cannot execute the
// READONLY command against that node -- setting readOnly to false in such
// situations in the options below will prevent that from happening.
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
HitlessUpgrades: opt.HitlessUpgrades,
HitlessUpgradeConfig: opt.HitlessUpgradeConfig,
}
}
@ -943,6 +957,9 @@ type ClusterClient struct {
cmdsInfoCache *cmdsInfoCache
cmdable
hooksMixin
// hitlessIntegration provides hitless upgrade functionality
hitlessIntegration HitlessIntegration
}
// NewClusterClient returns a Redis Cluster client as described in
@ -969,6 +986,22 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
txPipeline: c.processTxPipeline,
})
// Initialize hitless upgrades if enabled
if opt.HitlessUpgrades {
if opt.Protocol != 3 {
internal.Logger.Printf(context.Background(), "hitless: RESP3 protocol required for hitless upgrades, but Protocol is %d", opt.Protocol)
} else {
timeoutProvider := newOptionsTimeoutProvider(opt.ReadTimeout, opt.WriteTimeout)
integration, err := initializeHitlessIntegration(c, opt.HitlessUpgradeConfig, timeoutProvider)
if err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
} else {
c.hitlessIntegration = integration
internal.Logger.Printf(context.Background(), "hitless: successfully initialized hitless upgrades for cluster client")
}
}
}
return c
}
@ -977,6 +1010,14 @@ func (c *ClusterClient) Options() *ClusterOptions {
return c.opt
}
// GetHitlessIntegration returns the hitless integration instance for monitoring and control.
// Returns nil if hitless upgrades are not enabled.
func (c *ClusterClient) GetHitlessIntegration() HitlessIntegration {
return c.hitlessIntegration
}
// getPushNotificationProcessor removed - not needed in simplified implementation
// ReloadState reloads cluster state. If available it calls ClusterSlots func
// to get cluster slots information.
func (c *ClusterClient) ReloadState(ctx context.Context) {
@ -988,6 +1029,13 @@ func (c *ClusterClient) ReloadState(ctx context.Context) {
// It is rare to Close a ClusterClient, as the ClusterClient is meant
// to be long-lived and shared between many goroutines.
func (c *ClusterClient) Close() error {
// Close hitless integration first
if c.hitlessIntegration != nil {
if err := c.hitlessIntegration.Close(); err != nil {
internal.Logger.Printf(context.Background(), "hitless: error closing hitless integration: %v", err)
}
}
return c.nodes.Close()
}
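
A sketch of the cluster-side wiring, combining the new ClusterOptions fields with the monitoring accessor added above (the addresses and the act of printing statistics are illustrative):

cc := redis.NewClusterClient(&redis.ClusterOptions{
	Addrs:           []string{"127.0.0.1:7000", "127.0.0.1:7001"},
	Protocol:        3,
	HitlessUpgrades: true,
})
defer cc.Close()

if hi := cc.GetHitlessIntegration(); hi != nil {
	stats := hi.GetUpgradeStatistics()
	fmt.Printf("moving=%v migrating=%d failing_over=%d\n",
		stats.IsMoving, stats.MigratingConnections, stats.FailingOverConnections)
}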

View File

@ -13,9 +13,6 @@ type NotificationHandlerContext struct {
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *redis.baseClient
// - *redis.Client
// - *redis.ClusterClient
// - *redis.Conn
Client interface{}
// ConnPool is the connection pool from which the connection was obtained.
@ -25,7 +22,7 @@ type NotificationHandlerContext struct {
// - *pool.ConnPool
// - *pool.SingleConnPool
// - *pool.StickyConnPool
ConnPool pool.Pooler
// PubSub is the PubSub instance that received the notification.
// It is interface to both allow for future expansion and to avoid

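With ConnPool now typed as pool.Pooler, a handler can reach pool operations without a type switch. A rough sketch of a handler, assuming the handler interface follows the (ctx, handlerCtx, notification) error shape used by the hitless handlers above (loggingHandler itself is illustrative, not part of this commit):

// loggingHandler is an illustrative push notification handler.
type loggingHandler struct{}

func (loggingHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	if len(notification) == 0 {
		return fmt.Errorf("push: empty notification")
	}
	name, _ := notification[0].(string)
	// ConnPool is a pool.Pooler, so pool statistics are available directly.
	internal.Logger.Printf(ctx, "push: %q received (idle conns: %d)", name, handlerCtx.ConnPool.IdleLen())
	return nil
}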
redis.go
View File

@ -2,6 +2,7 @@ package redis
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
@ -211,6 +212,9 @@ type baseClient struct {
// Push notification processing
pushProcessor push.NotificationProcessor
// hitlessIntegration provides hitless upgrade functionality
hitlessIntegration HitlessIntegration
}
func (c *baseClient) clone() *baseClient {
@ -466,7 +470,16 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err)
}
c.connPool.Put(ctx, cn)
// Check if connection is marked for closing due to hitless upgrades
if c.hitlessIntegration != nil && c.hitlessIntegration.IsConnectionMarkedForClosing(cn) {
// Connection is marked for closing (e.g., during MOVING state)
// Remove it instead of putting it back in the pool
internal.Logger.Printf(ctx, "hitless: closing connection marked for closure during upgrade")
c.connPool.Remove(ctx, cn, nil)
} else {
c.connPool.Put(ctx, cn)
}
}
}
@ -528,7 +541,33 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
}
retryTimeout := uint32(0)
if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
// Check if this is a blocking command that needs redirection
isBlockingCommand := cmd.readTimeout() != nil
var redirectedConn *pool.Conn
var shouldRedirect bool
var newEndpoint string
if c.hitlessIntegration != nil && isBlockingCommand {
// For blocking commands, check if we need to redirect to a new endpoint
// This happens during MOVING state when the endpoint is changing
shouldRedirect, newEndpoint = c.hitlessIntegration.ShouldRedirectBlockingConnection(nil)
if shouldRedirect {
// Create a new connection to the new endpoint
var err error
redirectedConn, err = c.createConnectionToEndpoint(ctx, newEndpoint)
if err != nil {
internal.Logger.Printf(ctx, "hitless: failed to create redirected connection to %s: %v", newEndpoint, err)
// Fall back to normal connection if redirection fails
shouldRedirect = false
} else {
internal.Logger.Printf(ctx, "hitless: redirecting blocking command %s to new endpoint %s", cmd.Name(), newEndpoint)
}
}
}
// Use redirected connection if available, otherwise use normal connection
connFunc := func(ctx context.Context, cn *pool.Conn) error {
// Process any pending push notifications before executing the command
if err := c.processPushNotifications(ctx, cn); err != nil {
// Log the error but don't fail the command execution
@ -536,7 +575,19 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err) internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err)
} }
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { // Mark connection as blocking if this is a blocking command
if c.hitlessIntegration != nil && isBlockingCommand {
c.hitlessIntegration.MarkConnectionAsBlocking(cn, true)
internal.Logger.Printf(ctx, "hitless: marked connection as blocking for command %s", cmd.Name())
}
// Get appropriate write timeout for this connection
writeTimeout := c.opt.WriteTimeout
if c.hitlessIntegration != nil {
_, writeTimeout = c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
}
if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
}); err != nil {
atomic.StoreUint32(&retryTimeout, 1)
@ -547,7 +598,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
if c.opt.Protocol != 2 && c.assertUnstableCommand(cmd) {
readReplyFunc = cmd.readRawReply
}
if err := cn.WithReader(c.context(ctx), c.cmdTimeoutForConnection(cmd, cn), func(rd *proto.Reader) error {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
@ -564,8 +615,31 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
return err
}
// Unmark connection as blocking after command completes
if c.hitlessIntegration != nil && isBlockingCommand {
c.hitlessIntegration.MarkConnectionAsBlocking(cn, false)
internal.Logger.Printf(ctx, "hitless: unmarked connection as blocking after command %s completed", cmd.Name())
}
return nil
}
// Execute the command with either redirected or normal connection
var err error
if shouldRedirect && redirectedConn != nil {
// Use the redirected connection for blocking command
err = connFunc(ctx, redirectedConn)
// Close the redirected connection after use
defer func() {
redirectedConn.Close()
internal.Logger.Printf(ctx, "hitless: closed redirected connection to %s after blocking command completed", newEndpoint)
}()
} else {
// Use normal connection pool
err = c.withConn(ctx, connFunc)
}
if err != nil {
retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
return retry, err
}
@ -588,6 +662,70 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
return c.opt.ReadTimeout
}
// cmdTimeoutForConnection returns the appropriate read timeout for a specific connection
// taking into account hitless upgrade state
func (c *baseClient) cmdTimeoutForConnection(cmd Cmder, cn *pool.Conn) time.Duration {
baseTimeout := c.cmdTimeout(cmd)
// If hitless upgrades are enabled, get dynamic timeout based on connection state
if c.hitlessIntegration != nil {
// For blocking commands, use the command's timeout but check if connection needs increased timeout
if cmd.readTimeout() != nil {
// For blocking commands, use the base timeout but apply hitless upgrade adjustments
adjustedTimeout := c.hitlessIntegration.GetConnectionTimeout(cn, baseTimeout)
return adjustedTimeout
} else {
// For regular commands, get both read and write timeouts (use read timeout for command)
readTimeout, _ := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
return readTimeout
}
}
return baseTimeout
}
// createConnectionToEndpoint creates a new connection to a specific endpoint
// This is used for redirecting blocking commands during MOVING state
func (c *baseClient) createConnectionToEndpoint(ctx context.Context, endpoint string) (*pool.Conn, error) {
// Parse the endpoint to get host and port
addr := endpoint
if addr == "" {
return nil, fmt.Errorf("empty endpoint provided")
}
// Create a temporary dialer for the new endpoint
dialer := func(ctx context.Context) (net.Conn, error) {
netDialer := &net.Dialer{
Timeout: c.opt.DialTimeout,
KeepAlive: 30 * time.Second,
}
if c.opt.TLSConfig == nil {
return netDialer.DialContext(ctx, c.opt.Network, addr)
}
return tls.DialWithDialer(netDialer, c.opt.Network, addr, c.opt.TLSConfig)
}
// Create a new connection using the dialer
netConn, err := dialer(ctx)
if err != nil {
return nil, fmt.Errorf("failed to dial new endpoint %s: %w", endpoint, err)
}
// Wrap in pool.Conn
cn := pool.NewConn(netConn)
// Initialize the connection (auth, select db, etc.)
if err := c.initConn(ctx, cn); err != nil {
cn.Close()
return nil, fmt.Errorf("failed to initialize connection to %s: %w", endpoint, err)
}
internal.Logger.Printf(ctx, "hitless: created new connection to endpoint %s for blocking command redirection", endpoint)
return cn, nil
}
// context returns the context for the current connection.
// If the context timeout is enabled, it returns the original context.
// Otherwise, it returns a new background context.
@ -604,8 +742,19 @@ func (c *baseClient) context(ctx context.Context) context.Context {
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
var firstErr error
// Close hitless integration first
if c.hitlessIntegration != nil {
if err := c.hitlessIntegration.Close(); err != nil {
internal.Logger.Printf(context.Background(), "hitless: error closing hitless integration: %v", err)
if firstErr == nil {
firstErr = err
}
}
}
if c.onClose != nil {
if err := c.onClose(); err != nil && firstErr == nil {
firstErr = err
}
}
@ -678,14 +827,15 @@ func (c *baseClient) pipelineProcessCmds(
internal.Logger.Printf(ctx, "push: error processing pending notifications before pipeline: %v", err) internal.Logger.Printf(ctx, "push: error processing pending notifications before pipeline: %v", err)
} }
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { readTimeout, writeTimeout := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds) return writeCmds(wr, cmds)
}); err != nil { }); err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
return true, err return true, err
} }
if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { if err := cn.WithReader(c.context(ctx), readTimeout, func(rd *proto.Reader) error {
// read all replies // read all replies
return c.pipelineReadCmds(ctx, cn, rd, cmds) return c.pipelineReadCmds(ctx, cn, rd, cmds)
}); err != nil { }); err != nil {
@ -725,14 +875,15 @@ func (c *baseClient) txPipelineProcessCmds(
internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err) internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err)
} }
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { readTimeout, writeTimeout := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds) return writeCmds(wr, cmds)
}); err != nil { }); err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
return true, err return true, err
} }
if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { if err := cn.WithReader(c.context(ctx), readTimeout, func(rd *proto.Reader) error {
statusCmd := cmds[0].(*StatusCmd) statusCmd := cmds[0].(*StatusCmd)
// Trim multi and exec. // Trim multi and exec.
trimmedCmds := cmds[1 : len(cmds)-1] trimmedCmds := cmds[1 : len(cmds)-1]
@ -837,6 +988,22 @@ func NewClient(opt *Options) *Client {
c.connPool = newConnPool(opt, c.dialHook)
// Initialize hitless upgrades if enabled
if opt.HitlessUpgrades {
if opt.Protocol != 3 {
internal.Logger.Printf(context.Background(), "hitless: RESP3 protocol required for hitless upgrades, but Protocol is %d", opt.Protocol)
} else {
timeoutProvider := newOptionsTimeoutProvider(opt.ReadTimeout, opt.WriteTimeout)
integration, err := initializeHitlessIntegration(&c, opt.HitlessUpgradeConfig, timeoutProvider)
if err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
} else {
c.hitlessIntegration = integration
internal.Logger.Printf(context.Background(), "hitless: successfully initialized hitless upgrades for client")
}
}
}
return &c
}
@ -857,6 +1024,12 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
return &clone
}
// GetHitlessIntegration returns the hitless integration instance for monitoring and control.
// Returns nil if hitless upgrades are not enabled.
func (c *Client) GetHitlessIntegration() HitlessIntegration {
return c.hitlessIntegration
}
func (c *Client) Conn() *Conn {
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin)
}
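
The same accessor on the standalone client makes it possible to inspect the upgrade state at runtime. A brief sketch, reusing the rdb client from the earlier options example (the act of printing is illustrative):

if hi := rdb.GetHitlessIntegration(); hi != nil && hi.IsEnabled() {
	status := hi.GetUpgradeStatus()
	if status.IsMoving {
		fmt.Printf("hitless: pool is moving to %s (as of %v)\n", status.NewEndpoint, status.Timestamp)
	}
}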