mirror of
https://github.com/redis/go-redis.git
synced 2025-07-16 13:21:51 +03:00
wip.
This commit is contained in:
416
hitless.go
Normal file
416
hitless.go
Normal file
@ -0,0 +1,416 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/hitless"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// HitlessUpgradeConfig provides configuration for hitless upgrades
|
||||
type HitlessUpgradeConfig struct {
|
||||
// Enabled controls whether hitless upgrades are active
|
||||
Enabled bool
|
||||
|
||||
// TransitionTimeout is the increased timeout for connections during transitions
|
||||
// (MIGRATING/FAILING_OVER). This should be longer than normal operation timeouts
|
||||
// to account for the time needed to complete the transition.
|
||||
// Default: 60 seconds
|
||||
TransitionTimeout time.Duration
|
||||
|
||||
// CleanupInterval controls how often expired states are cleaned up
|
||||
// Default: 30 seconds
|
||||
CleanupInterval time.Duration
|
||||
}
|
||||
|
||||
// DefaultHitlessUpgradeConfig returns the default configuration for hitless upgrades
|
||||
func DefaultHitlessUpgradeConfig() *HitlessUpgradeConfig {
|
||||
return &HitlessUpgradeConfig{
|
||||
Enabled: true,
|
||||
TransitionTimeout: 60 * time.Second, // Longer timeout for transitioning connections
|
||||
CleanupInterval: 30 * time.Second, // How often to clean up expired states
|
||||
}
|
||||
}
|
||||
|
||||
// HitlessUpgradeStatistics provides statistics about ongoing upgrade operations
|
||||
type HitlessUpgradeStatistics struct {
|
||||
ActiveConnections int // Total connections in transition
|
||||
IsMoving bool // Whether pool is currently moving
|
||||
MigratingConnections int // Connections in MIGRATING state
|
||||
FailingOverConnections int // Connections in FAILING_OVER state
|
||||
Timestamp time.Time // When these statistics were collected
|
||||
}
|
||||
|
||||
// HitlessUpgradeStatus provides detailed status of all ongoing upgrades
|
||||
type HitlessUpgradeStatus struct {
|
||||
ConnectionStates map[interface{}]interface{}
|
||||
IsMoving bool
|
||||
NewEndpoint string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// HitlessIntegration provides the interface for hitless upgrade functionality
|
||||
type HitlessIntegration interface {
|
||||
// IsEnabled returns whether hitless upgrades are currently enabled
|
||||
IsEnabled() bool
|
||||
|
||||
// EnableHitlessUpgrades enables hitless upgrade functionality
|
||||
EnableHitlessUpgrades()
|
||||
|
||||
// DisableHitlessUpgrades disables hitless upgrade functionality
|
||||
DisableHitlessUpgrades()
|
||||
|
||||
// GetConnectionTimeout returns the appropriate timeout for a connection
|
||||
// If the connection is transitioning, returns the longer TransitionTimeout
|
||||
GetConnectionTimeout(conn interface{}, defaultTimeout time.Duration) time.Duration
|
||||
|
||||
// GetConnectionTimeouts returns both read and write timeouts for a connection
|
||||
// If the connection is transitioning, returns increased timeouts
|
||||
GetConnectionTimeouts(conn interface{}, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, time.Duration)
|
||||
|
||||
// MarkConnectionAsBlocking marks a connection as having blocking commands
|
||||
MarkConnectionAsBlocking(conn interface{}, isBlocking bool)
|
||||
|
||||
// IsConnectionMarkedForClosing checks if a connection should be closed
|
||||
IsConnectionMarkedForClosing(conn interface{}) bool
|
||||
|
||||
// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected
|
||||
ShouldRedirectBlockingConnection(conn interface{}) (bool, string)
|
||||
|
||||
// GetUpgradeStatistics returns current upgrade statistics
|
||||
GetUpgradeStatistics() *HitlessUpgradeStatistics
|
||||
|
||||
// GetUpgradeStatus returns detailed upgrade status
|
||||
GetUpgradeStatus() *HitlessUpgradeStatus
|
||||
|
||||
// UpdateConfig updates the hitless upgrade configuration
|
||||
UpdateConfig(config *HitlessUpgradeConfig) error
|
||||
|
||||
// GetConfig returns the current configuration
|
||||
GetConfig() *HitlessUpgradeConfig
|
||||
|
||||
// Close shuts down the hitless integration
|
||||
Close() error
|
||||
}
|
||||
|
||||
// hitlessIntegrationImpl implements the HitlessIntegration interface
|
||||
type hitlessIntegrationImpl struct {
|
||||
integration *hitless.RedisClientIntegration
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// newHitlessIntegration creates a new hitless integration instance
|
||||
func newHitlessIntegration(config *HitlessUpgradeConfig) *hitlessIntegrationImpl {
|
||||
if config == nil {
|
||||
config = DefaultHitlessUpgradeConfig()
|
||||
}
|
||||
|
||||
// Convert to internal config format
|
||||
internalConfig := &hitless.HitlessUpgradeConfig{
|
||||
Enabled: config.Enabled,
|
||||
TransitionTimeout: config.TransitionTimeout,
|
||||
CleanupInterval: config.CleanupInterval,
|
||||
}
|
||||
|
||||
integration := hitless.NewRedisClientIntegration(internalConfig, 3*time.Second, 3*time.Second)
|
||||
|
||||
return &hitlessIntegrationImpl{
|
||||
integration: integration,
|
||||
}
|
||||
}
|
||||
|
||||
// newHitlessIntegrationWithTimeouts creates a new hitless integration instance with timeout configuration
|
||||
func newHitlessIntegrationWithTimeouts(config *HitlessUpgradeConfig, defaultReadTimeout, defaultWriteTimeout time.Duration) *hitlessIntegrationImpl {
|
||||
// Start with defaults
|
||||
defaults := DefaultHitlessUpgradeConfig()
|
||||
|
||||
// If config is nil, use all defaults
|
||||
if config == nil {
|
||||
config = defaults
|
||||
}
|
||||
|
||||
// Ensure all fields are set with defaults if they are zero values
|
||||
enabled := config.Enabled
|
||||
transitionTimeout := config.TransitionTimeout
|
||||
cleanupInterval := config.CleanupInterval
|
||||
|
||||
// Apply defaults for zero values
|
||||
if transitionTimeout == 0 {
|
||||
transitionTimeout = defaults.TransitionTimeout
|
||||
}
|
||||
if cleanupInterval == 0 {
|
||||
cleanupInterval = defaults.CleanupInterval
|
||||
}
|
||||
|
||||
// Convert to internal config format with all fields properly set
|
||||
internalConfig := &hitless.HitlessUpgradeConfig{
|
||||
Enabled: enabled,
|
||||
TransitionTimeout: transitionTimeout,
|
||||
CleanupInterval: cleanupInterval,
|
||||
}
|
||||
|
||||
integration := hitless.NewRedisClientIntegration(internalConfig, defaultReadTimeout, defaultWriteTimeout)
|
||||
|
||||
return &hitlessIntegrationImpl{
|
||||
integration: integration,
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled returns whether hitless upgrades are currently enabled
|
||||
func (h *hitlessIntegrationImpl) IsEnabled() bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.integration.IsEnabled()
|
||||
}
|
||||
|
||||
// EnableHitlessUpgrades enables hitless upgrade functionality
|
||||
func (h *hitlessIntegrationImpl) EnableHitlessUpgrades() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.integration.EnableHitlessUpgrades()
|
||||
}
|
||||
|
||||
// DisableHitlessUpgrades disables hitless upgrade functionality
|
||||
func (h *hitlessIntegrationImpl) DisableHitlessUpgrades() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.integration.DisableHitlessUpgrades()
|
||||
}
|
||||
|
||||
// GetConnectionTimeout returns the appropriate timeout for a connection
|
||||
func (h *hitlessIntegrationImpl) GetConnectionTimeout(conn interface{}, defaultTimeout time.Duration) time.Duration {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// Convert interface{} to *pool.Conn
|
||||
if poolConn, ok := conn.(*pool.Conn); ok {
|
||||
return h.integration.GetConnectionTimeout(poolConn, defaultTimeout)
|
||||
}
|
||||
|
||||
// If not a pool connection, return default timeout
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
// GetConnectionTimeouts returns both read and write timeouts for a connection
|
||||
func (h *hitlessIntegrationImpl) GetConnectionTimeouts(conn interface{}, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, time.Duration) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// Convert interface{} to *pool.Conn
|
||||
if poolConn, ok := conn.(*pool.Conn); ok {
|
||||
return h.integration.GetConnectionTimeouts(poolConn, defaultReadTimeout, defaultWriteTimeout)
|
||||
}
|
||||
|
||||
// If not a pool connection, return default timeouts
|
||||
return defaultReadTimeout, defaultWriteTimeout
|
||||
}
|
||||
|
||||
// MarkConnectionAsBlocking marks a connection as having blocking commands
|
||||
func (h *hitlessIntegrationImpl) MarkConnectionAsBlocking(conn interface{}, isBlocking bool) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// Convert interface{} to *pool.Conn
|
||||
if poolConn, ok := conn.(*pool.Conn); ok {
|
||||
h.integration.MarkConnectionAsBlocking(poolConn, isBlocking)
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnectionMarkedForClosing checks if a connection should be closed
|
||||
func (h *hitlessIntegrationImpl) IsConnectionMarkedForClosing(conn interface{}) bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// Convert interface{} to *pool.Conn
|
||||
if poolConn, ok := conn.(*pool.Conn); ok {
|
||||
return h.integration.IsConnectionMarkedForClosing(poolConn)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected
|
||||
func (h *hitlessIntegrationImpl) ShouldRedirectBlockingConnection(conn interface{}) (bool, string) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// Convert interface{} to *pool.Conn (can be nil for checking pool state)
|
||||
var poolConn *pool.Conn
|
||||
if conn != nil {
|
||||
if pc, ok := conn.(*pool.Conn); ok {
|
||||
poolConn = pc
|
||||
}
|
||||
}
|
||||
|
||||
return h.integration.ShouldRedirectBlockingConnection(poolConn)
|
||||
}
|
||||
|
||||
// GetUpgradeStatistics returns current upgrade statistics
|
||||
func (h *hitlessIntegrationImpl) GetUpgradeStatistics() *HitlessUpgradeStatistics {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
stats := h.integration.GetUpgradeStatistics()
|
||||
if stats == nil {
|
||||
return &HitlessUpgradeStatistics{Timestamp: time.Now()}
|
||||
}
|
||||
|
||||
return &HitlessUpgradeStatistics{
|
||||
ActiveConnections: stats.ActiveConnections,
|
||||
IsMoving: stats.IsMoving,
|
||||
MigratingConnections: stats.MigratingConnections,
|
||||
FailingOverConnections: stats.FailingOverConnections,
|
||||
Timestamp: stats.Timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// GetUpgradeStatus returns detailed upgrade status
|
||||
func (h *hitlessIntegrationImpl) GetUpgradeStatus() *HitlessUpgradeStatus {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
status := h.integration.GetUpgradeStatus()
|
||||
if status == nil {
|
||||
return &HitlessUpgradeStatus{
|
||||
ConnectionStates: make(map[interface{}]interface{}),
|
||||
IsMoving: false,
|
||||
NewEndpoint: "",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
return &HitlessUpgradeStatus{
|
||||
ConnectionStates: convertToInterfaceMap(status.ConnectionStates),
|
||||
IsMoving: status.IsMoving,
|
||||
NewEndpoint: status.NewEndpoint,
|
||||
Timestamp: status.Timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig updates the hitless upgrade configuration
|
||||
func (h *hitlessIntegrationImpl) UpdateConfig(config *HitlessUpgradeConfig) error {
|
||||
if config == nil {
|
||||
return fmt.Errorf("config cannot be nil")
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// Start with defaults for any zero values
|
||||
defaults := DefaultHitlessUpgradeConfig()
|
||||
|
||||
// Ensure all fields are set with defaults if they are zero values
|
||||
enabled := config.Enabled
|
||||
transitionTimeout := config.TransitionTimeout
|
||||
cleanupInterval := config.CleanupInterval
|
||||
|
||||
// Apply defaults for zero values
|
||||
if transitionTimeout == 0 {
|
||||
transitionTimeout = defaults.TransitionTimeout
|
||||
}
|
||||
if cleanupInterval == 0 {
|
||||
cleanupInterval = defaults.CleanupInterval
|
||||
}
|
||||
|
||||
// Convert to internal config format with all fields properly set
|
||||
internalConfig := &hitless.HitlessUpgradeConfig{
|
||||
Enabled: enabled,
|
||||
TransitionTimeout: transitionTimeout,
|
||||
CleanupInterval: cleanupInterval,
|
||||
}
|
||||
|
||||
return h.integration.UpdateConfig(internalConfig)
|
||||
}
|
||||
|
||||
// GetConfig returns the current configuration
|
||||
func (h *hitlessIntegrationImpl) GetConfig() *HitlessUpgradeConfig {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
internalConfig := h.integration.GetConfig()
|
||||
if internalConfig == nil {
|
||||
return DefaultHitlessUpgradeConfig()
|
||||
}
|
||||
|
||||
return &HitlessUpgradeConfig{
|
||||
Enabled: internalConfig.Enabled,
|
||||
TransitionTimeout: internalConfig.TransitionTimeout,
|
||||
CleanupInterval: internalConfig.CleanupInterval,
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the hitless integration
|
||||
func (h *hitlessIntegrationImpl) Close() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return h.integration.Close()
|
||||
}
|
||||
|
||||
// getInternalIntegration returns the internal integration for use by Redis clients
|
||||
func (h *hitlessIntegrationImpl) getInternalIntegration() *hitless.RedisClientIntegration {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.integration
|
||||
}
|
||||
|
||||
// ClientTimeoutProvider interface for extracting timeout configuration from client options
|
||||
type ClientTimeoutProvider interface {
|
||||
GetReadTimeout() time.Duration
|
||||
GetWriteTimeout() time.Duration
|
||||
}
|
||||
|
||||
// optionsTimeoutProvider implements ClientTimeoutProvider for Options struct
|
||||
type optionsTimeoutProvider struct {
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
}
|
||||
|
||||
func (p *optionsTimeoutProvider) GetReadTimeout() time.Duration {
|
||||
return p.readTimeout
|
||||
}
|
||||
|
||||
func (p *optionsTimeoutProvider) GetWriteTimeout() time.Duration {
|
||||
return p.writeTimeout
|
||||
}
|
||||
|
||||
// newOptionsTimeoutProvider creates a timeout provider from Options
|
||||
func newOptionsTimeoutProvider(readTimeout, writeTimeout time.Duration) ClientTimeoutProvider {
|
||||
return &optionsTimeoutProvider{
|
||||
readTimeout: readTimeout,
|
||||
writeTimeout: writeTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// initializeHitlessIntegration initializes hitless integration for a client
|
||||
func initializeHitlessIntegration(client interface{}, config *HitlessUpgradeConfig, timeoutProvider ClientTimeoutProvider) (*hitlessIntegrationImpl, error) {
|
||||
if config == nil || !config.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Extract timeout configuration from client options
|
||||
defaultReadTimeout := timeoutProvider.GetReadTimeout()
|
||||
defaultWriteTimeout := timeoutProvider.GetWriteTimeout()
|
||||
|
||||
// Create hitless integration - each client gets its own instance
|
||||
integration := newHitlessIntegrationWithTimeouts(config, defaultReadTimeout, defaultWriteTimeout)
|
||||
|
||||
// Push notification handlers are registered directly by the client
|
||||
// No separate registration needed in simplified implementation
|
||||
|
||||
internal.Logger.Printf(context.Background(), "hitless: initialized hitless upgrades for client")
|
||||
|
||||
return integration, nil
|
||||
}
|
||||
|
||||
// convertToInterfaceMap converts a typed map to interface{} map for public API
|
||||
func convertToInterfaceMap(input map[*pool.Conn]*hitless.ConnectionState) map[interface{}]interface{} {
|
||||
result := make(map[interface{}]interface{})
|
||||
for k, v := range input {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
197
hitless_config_defaults_test.go
Normal file
197
hitless_config_defaults_test.go
Normal file
@ -0,0 +1,197 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHitlessUpgradeConfig_DefaultValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputConfig *HitlessUpgradeConfig
|
||||
expectedConfig *HitlessUpgradeConfig
|
||||
}{
|
||||
{
|
||||
name: "nil config should use all defaults",
|
||||
inputConfig: nil,
|
||||
expectedConfig: &HitlessUpgradeConfig{
|
||||
Enabled: true,
|
||||
TransitionTimeout: 60 * time.Second,
|
||||
CleanupInterval: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "zero TransitionTimeout should use default",
|
||||
inputConfig: &HitlessUpgradeConfig{
|
||||
Enabled: false,
|
||||
TransitionTimeout: 0, // Zero value
|
||||
CleanupInterval: 45 * time.Second,
|
||||
},
|
||||
expectedConfig: &HitlessUpgradeConfig{
|
||||
Enabled: false,
|
||||
TransitionTimeout: 60 * time.Second, // Should use default
|
||||
CleanupInterval: 45 * time.Second,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "zero CleanupInterval should use default",
|
||||
inputConfig: &HitlessUpgradeConfig{
|
||||
Enabled: true,
|
||||
TransitionTimeout: 90 * time.Second,
|
||||
CleanupInterval: 0, // Zero value
|
||||
},
|
||||
expectedConfig: &HitlessUpgradeConfig{
|
||||
Enabled: true,
|
||||
TransitionTimeout: 90 * time.Second,
|
||||
CleanupInterval: 30 * time.Second, // Should use default
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "both timeouts zero should use defaults",
|
||||
inputConfig: &HitlessUpgradeConfig{
|
||||
Enabled: true,
|
||||
TransitionTimeout: 0, // Zero value
|
||||
CleanupInterval: 0, // Zero value
|
||||
},
|
||||
expectedConfig: &HitlessUpgradeConfig{
|
||||
Enabled: true,
|
||||
TransitionTimeout: 60 * time.Second, // Should use default
|
||||
CleanupInterval: 30 * time.Second, // Should use default
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all values set should be preserved",
|
||||
inputConfig: &HitlessUpgradeConfig{
|
||||
Enabled: false,
|
||||
TransitionTimeout: 120 * time.Second,
|
||||
CleanupInterval: 60 * time.Second,
|
||||
},
|
||||
expectedConfig: &HitlessUpgradeConfig{
|
||||
Enabled: false,
|
||||
TransitionTimeout: 120 * time.Second,
|
||||
CleanupInterval: 60 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test with a mock client that has hitless upgrades enabled
|
||||
opt := &Options{
|
||||
Addr: "127.0.0.1:6379",
|
||||
Protocol: 3,
|
||||
HitlessUpgrades: true,
|
||||
HitlessUpgradeConfig: tt.inputConfig,
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
// Test the integration creation using the internal method directly
|
||||
// since initializeHitlessIntegration requires a push processor
|
||||
integration := newHitlessIntegrationWithTimeouts(tt.inputConfig, opt.ReadTimeout, opt.WriteTimeout)
|
||||
if integration == nil {
|
||||
t.Fatal("Integration should not be nil")
|
||||
}
|
||||
|
||||
// Get the config from the integration
|
||||
actualConfig := integration.GetConfig()
|
||||
if actualConfig == nil {
|
||||
t.Fatal("Config should not be nil")
|
||||
}
|
||||
|
||||
// Verify all fields match expected values
|
||||
if actualConfig.Enabled != tt.expectedConfig.Enabled {
|
||||
t.Errorf("Enabled: expected %v, got %v", tt.expectedConfig.Enabled, actualConfig.Enabled)
|
||||
}
|
||||
if actualConfig.TransitionTimeout != tt.expectedConfig.TransitionTimeout {
|
||||
t.Errorf("TransitionTimeout: expected %v, got %v", tt.expectedConfig.TransitionTimeout, actualConfig.TransitionTimeout)
|
||||
}
|
||||
if actualConfig.CleanupInterval != tt.expectedConfig.CleanupInterval {
|
||||
t.Errorf("CleanupInterval: expected %v, got %v", tt.expectedConfig.CleanupInterval, actualConfig.CleanupInterval)
|
||||
}
|
||||
|
||||
// Test UpdateConfig as well
|
||||
newConfig := &HitlessUpgradeConfig{
|
||||
Enabled: !tt.expectedConfig.Enabled,
|
||||
TransitionTimeout: 0, // Zero value should use default
|
||||
CleanupInterval: 0, // Zero value should use default
|
||||
}
|
||||
|
||||
err := integration.UpdateConfig(newConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update config: %v", err)
|
||||
}
|
||||
|
||||
// Verify updated config has defaults applied
|
||||
updatedConfig := integration.GetConfig()
|
||||
if updatedConfig.Enabled == tt.expectedConfig.Enabled {
|
||||
t.Error("Enabled should have been toggled")
|
||||
}
|
||||
if updatedConfig.TransitionTimeout != 60*time.Second {
|
||||
t.Errorf("TransitionTimeout should use default (60s), got %v", updatedConfig.TransitionTimeout)
|
||||
}
|
||||
if updatedConfig.CleanupInterval != 30*time.Second {
|
||||
t.Errorf("CleanupInterval should use default (30s), got %v", updatedConfig.CleanupInterval)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultHitlessUpgradeConfig(t *testing.T) {
|
||||
config := DefaultHitlessUpgradeConfig()
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("Default config should not be nil")
|
||||
}
|
||||
|
||||
if !config.Enabled {
|
||||
t.Error("Default config should have Enabled=true")
|
||||
}
|
||||
|
||||
if config.TransitionTimeout != 60*time.Second {
|
||||
t.Errorf("Default TransitionTimeout should be 60s, got %v", config.TransitionTimeout)
|
||||
}
|
||||
|
||||
if config.CleanupInterval != 30*time.Second {
|
||||
t.Errorf("Default CleanupInterval should be 30s, got %v", config.CleanupInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHitlessUpgradeConfig_ZeroValueHandling(t *testing.T) {
|
||||
// Test that zero values are properly handled in various scenarios
|
||||
|
||||
// Test 1: Partial config with some zero values
|
||||
partialConfig := &HitlessUpgradeConfig{
|
||||
Enabled: true,
|
||||
// TransitionTimeout and CleanupInterval are zero values
|
||||
}
|
||||
|
||||
integration := newHitlessIntegrationWithTimeouts(partialConfig, 3*time.Second, 3*time.Second)
|
||||
if integration == nil {
|
||||
t.Fatal("Integration should not be nil")
|
||||
}
|
||||
|
||||
config := integration.GetConfig()
|
||||
if config.TransitionTimeout == 0 {
|
||||
t.Error("Zero TransitionTimeout should have been replaced with default")
|
||||
}
|
||||
if config.CleanupInterval == 0 {
|
||||
t.Error("Zero CleanupInterval should have been replaced with default")
|
||||
}
|
||||
|
||||
// Test 2: Empty struct
|
||||
emptyConfig := &HitlessUpgradeConfig{}
|
||||
|
||||
integration2 := newHitlessIntegrationWithTimeouts(emptyConfig, 3*time.Second, 3*time.Second)
|
||||
if integration2 == nil {
|
||||
t.Fatal("Integration should not be nil")
|
||||
}
|
||||
|
||||
config2 := integration2.GetConfig()
|
||||
if config2.TransitionTimeout == 0 {
|
||||
t.Error("Zero TransitionTimeout in empty config should have been replaced with default")
|
||||
}
|
||||
if config2.CleanupInterval == 0 {
|
||||
t.Error("Zero CleanupInterval in empty config should have been replaced with default")
|
||||
}
|
||||
}
|
23
internal/hitless/README.md
Normal file
23
internal/hitless/README.md
Normal file
@ -0,0 +1,23 @@
|
||||
# Hitless Upgrade Package
|
||||
|
||||
This package implements hitless upgrade functionality for Redis cluster clients using the push notification architecture. It provides handlers for managing connection and pool state during Redis cluster upgrades.
|
||||
|
||||
## Quick Start
|
||||
|
||||
To enable hitless upgrades in your Redis client, simply set the configuration option:
|
||||
|
||||
```go
|
||||
import "github.com/redis/go-redis/v9"
|
||||
|
||||
// Enable hitless upgrades with a simple configuration option
|
||||
client := redis.NewClusterClient(&redis.ClusterOptions{
|
||||
Addrs: []string{"127.0.0.1:7000", "127.0.0.1:7001", "127.0.0.1:7002"},
|
||||
Protocol: 3, // RESP3 required for push notifications
|
||||
HitlessUpgrades: true, // Enable hitless upgrades
|
||||
})
|
||||
defer client.Close()
|
||||
|
||||
// That's it! Use your client normally - hitless upgrades work automatically
|
||||
ctx := context.Background()
|
||||
client.Set(ctx, "key", "value", 0)
|
||||
```
|
200
internal/hitless/client_integration.go
Normal file
200
internal/hitless/client_integration.go
Normal file
@ -0,0 +1,200 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// ClientIntegrator provides integration between hitless upgrade handlers and Redis clients
|
||||
type ClientIntegrator struct {
|
||||
upgradeHandler *UpgradeHandler
|
||||
mu sync.RWMutex
|
||||
|
||||
// Simple atomic state for pool redirection
|
||||
isMoving int32 // atomic: 0 = not moving, 1 = moving
|
||||
newEndpoint string // only written during MOVING, read-only after
|
||||
}
|
||||
|
||||
// NewClientIntegrator creates a new client integrator with client timeout configuration
|
||||
func NewClientIntegrator(defaultReadTimeout, defaultWriteTimeout time.Duration) *ClientIntegrator {
|
||||
return &ClientIntegrator{
|
||||
upgradeHandler: NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout),
|
||||
}
|
||||
}
|
||||
|
||||
// GetUpgradeHandler returns the upgrade handler for direct access
|
||||
func (ci *ClientIntegrator) GetUpgradeHandler() *UpgradeHandler {
|
||||
return ci.upgradeHandler
|
||||
}
|
||||
|
||||
// HandlePushNotification is the main entry point for processing upgrade notifications
|
||||
func (ci *ClientIntegrator) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// Handle MOVING notifications for pool redirection
|
||||
if len(notification) > 0 {
|
||||
if notificationType, ok := notification[0].(string); ok && notificationType == "MOVING" {
|
||||
if len(notification) >= 3 {
|
||||
if newEndpoint, ok := notification[2].(string); ok {
|
||||
// Simple atomic state update - no locks needed
|
||||
ci.newEndpoint = newEndpoint
|
||||
atomic.StoreInt32(&ci.isMoving, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ci.upgradeHandler.HandlePushNotification(ctx, handlerCtx, notification)
|
||||
}
|
||||
|
||||
// Close shuts down the client integrator
|
||||
func (ci *ClientIntegrator) Close() error {
|
||||
ci.mu.Lock()
|
||||
defer ci.mu.Unlock()
|
||||
|
||||
// Reset atomic state
|
||||
atomic.StoreInt32(&ci.isMoving, 0)
|
||||
ci.newEndpoint = ""
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsMoving returns true if the pool is currently moving to a new endpoint
|
||||
// Uses atomic read - no locks needed
|
||||
func (ci *ClientIntegrator) IsMoving() bool {
|
||||
return atomic.LoadInt32(&ci.isMoving) == 1
|
||||
}
|
||||
|
||||
// GetNewEndpoint returns the new endpoint if moving, empty string otherwise
|
||||
// Safe to read without locks since it's only written during MOVING
|
||||
func (ci *ClientIntegrator) GetNewEndpoint() string {
|
||||
if ci.IsMoving() {
|
||||
return ci.newEndpoint
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// PushNotificationHandlerInterface defines the interface for push notification handlers
|
||||
// This implements the interface expected by the push notification system
|
||||
type PushNotificationHandlerInterface interface {
|
||||
HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error
|
||||
}
|
||||
|
||||
// Ensure ClientIntegrator implements the interface
|
||||
var _ PushNotificationHandlerInterface = (*ClientIntegrator)(nil)
|
||||
|
||||
// PoolRedirector provides pool redirection functionality for hitless upgrades
|
||||
type PoolRedirector struct {
|
||||
poolManager *PoolEndpointManager
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewPoolRedirector creates a new pool redirector
|
||||
func NewPoolRedirector() *PoolRedirector {
|
||||
return &PoolRedirector{
|
||||
poolManager: NewPoolEndpointManager(),
|
||||
}
|
||||
}
|
||||
|
||||
// RedirectPool redirects a connection pool to a new endpoint
|
||||
func (pr *PoolRedirector) RedirectPool(ctx context.Context, pooler pool.Pooler, newEndpoint string, timeout time.Duration) error {
|
||||
pr.mu.Lock()
|
||||
defer pr.mu.Unlock()
|
||||
|
||||
return pr.poolManager.RedirectPool(ctx, pooler, newEndpoint, timeout)
|
||||
}
|
||||
|
||||
// IsPoolRedirected checks if a pool is currently redirected
|
||||
func (pr *PoolRedirector) IsPoolRedirected(pooler pool.Pooler) bool {
|
||||
pr.mu.RLock()
|
||||
defer pr.mu.RUnlock()
|
||||
|
||||
return pr.poolManager.IsPoolRedirected(pooler)
|
||||
}
|
||||
|
||||
// GetRedirection returns redirection information for a pool
|
||||
func (pr *PoolRedirector) GetRedirection(pooler pool.Pooler) (*EndpointRedirection, bool) {
|
||||
pr.mu.RLock()
|
||||
defer pr.mu.RUnlock()
|
||||
|
||||
return pr.poolManager.GetRedirection(pooler)
|
||||
}
|
||||
|
||||
// Close shuts down the pool redirector
|
||||
func (pr *PoolRedirector) Close() error {
|
||||
pr.mu.Lock()
|
||||
defer pr.mu.Unlock()
|
||||
|
||||
// Clean up all redirections
|
||||
ctx := context.Background()
|
||||
pr.poolManager.CleanupExpiredRedirections(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectionStateTracker tracks connection states during upgrades
|
||||
type ConnectionStateTracker struct {
|
||||
upgradeHandler *UpgradeHandler
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewConnectionStateTracker creates a new connection state tracker with timeout configuration
|
||||
func NewConnectionStateTracker(defaultReadTimeout, defaultWriteTimeout time.Duration) *ConnectionStateTracker {
|
||||
return &ConnectionStateTracker{
|
||||
upgradeHandler: NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout),
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnectionTransitioning checks if a connection is currently transitioning
|
||||
func (cst *ConnectionStateTracker) IsConnectionTransitioning(conn *pool.Conn) bool {
|
||||
cst.mu.RLock()
|
||||
defer cst.mu.RUnlock()
|
||||
|
||||
return cst.upgradeHandler.IsConnectionTransitioning(conn)
|
||||
}
|
||||
|
||||
// GetConnectionState returns the current state of a connection
|
||||
func (cst *ConnectionStateTracker) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) {
|
||||
cst.mu.RLock()
|
||||
defer cst.mu.RUnlock()
|
||||
|
||||
return cst.upgradeHandler.GetConnectionState(conn)
|
||||
}
|
||||
|
||||
// CleanupConnection removes tracking for a connection
|
||||
func (cst *ConnectionStateTracker) CleanupConnection(conn *pool.Conn) {
|
||||
cst.mu.Lock()
|
||||
defer cst.mu.Unlock()
|
||||
|
||||
cst.upgradeHandler.CleanupConnection(conn)
|
||||
}
|
||||
|
||||
// Close shuts down the connection state tracker
|
||||
func (cst *ConnectionStateTracker) Close() error {
|
||||
cst.mu.Lock()
|
||||
defer cst.mu.Unlock()
|
||||
|
||||
// Clean up all expired states
|
||||
cst.upgradeHandler.CleanupExpiredStates()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HitlessUpgradeConfig provides configuration for hitless upgrades
|
||||
type HitlessUpgradeConfig struct {
|
||||
Enabled bool
|
||||
TransitionTimeout time.Duration
|
||||
CleanupInterval time.Duration
|
||||
}
|
||||
|
||||
// DefaultHitlessUpgradeConfig returns default configuration for hitless upgrades
|
||||
func DefaultHitlessUpgradeConfig() *HitlessUpgradeConfig {
|
||||
return &HitlessUpgradeConfig{
|
||||
Enabled: true,
|
||||
TransitionTimeout: 60 * time.Second, // Longer timeout for transitioning connections
|
||||
CleanupInterval: 30 * time.Second, // How often to clean up expired states
|
||||
}
|
||||
}
|
205
internal/hitless/pool_manager.go
Normal file
205
internal/hitless/pool_manager.go
Normal file
@ -0,0 +1,205 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// PoolEndpointManager manages endpoint transitions for connection pools during hitless upgrades.
|
||||
// It provides functionality to redirect new connections to new endpoints while maintaining
|
||||
// existing connections until they can be gracefully transitioned.
|
||||
type PoolEndpointManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Map of pools to their endpoint redirections
|
||||
redirections map[interface{}]*EndpointRedirection
|
||||
|
||||
// Original dialers for pools (to restore after transition)
|
||||
originalDialers map[interface{}]func(context.Context) (net.Conn, error)
|
||||
}
|
||||
|
||||
// EndpointRedirection represents an active endpoint redirection
|
||||
type EndpointRedirection struct {
|
||||
OriginalEndpoint string
|
||||
NewEndpoint string
|
||||
StartTime time.Time
|
||||
Timeout time.Duration
|
||||
|
||||
// Statistics
|
||||
NewConnections int64
|
||||
FailedConnections int64
|
||||
}
|
||||
|
||||
// NewPoolEndpointManager creates a new pool endpoint manager
|
||||
func NewPoolEndpointManager() *PoolEndpointManager {
|
||||
return &PoolEndpointManager{
|
||||
redirections: make(map[interface{}]*EndpointRedirection),
|
||||
originalDialers: make(map[interface{}]func(context.Context) (net.Conn, error)),
|
||||
}
|
||||
}
|
||||
|
||||
// RedirectPool redirects new connections from a pool to a new endpoint
|
||||
func (m *PoolEndpointManager) RedirectPool(ctx context.Context, pooler pool.Pooler, newEndpoint string, timeout time.Duration) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Check if pool is already being redirected
|
||||
if _, exists := m.redirections[pooler]; exists {
|
||||
return fmt.Errorf("pool is already being redirected")
|
||||
}
|
||||
|
||||
// Get the current dialer from the pool
|
||||
connPool, ok := pooler.(*pool.ConnPool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported pool type: %T", pooler)
|
||||
}
|
||||
|
||||
// Store original dialer
|
||||
originalDialer := m.getPoolDialer(connPool)
|
||||
if originalDialer == nil {
|
||||
return fmt.Errorf("could not get original dialer from pool")
|
||||
}
|
||||
|
||||
m.originalDialers[pooler] = originalDialer
|
||||
|
||||
// Create new dialer that connects to the new endpoint
|
||||
newDialer := m.createRedirectDialer(ctx, newEndpoint, originalDialer)
|
||||
|
||||
// Replace the pool's dialer
|
||||
if err := m.setPoolDialer(connPool, newDialer); err != nil {
|
||||
delete(m.originalDialers, pooler)
|
||||
return fmt.Errorf("failed to set new dialer: %w", err)
|
||||
}
|
||||
|
||||
// Record the redirection
|
||||
m.redirections[pooler] = &EndpointRedirection{
|
||||
OriginalEndpoint: m.extractEndpointFromDialer(originalDialer),
|
||||
NewEndpoint: newEndpoint,
|
||||
StartTime: time.Now(),
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: redirected pool to new endpoint %s", newEndpoint)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsPoolRedirected checks if a pool is currently being redirected
|
||||
func (m *PoolEndpointManager) IsPoolRedirected(pooler pool.Pooler) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
_, exists := m.redirections[pooler]
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetRedirection returns redirection information for a pool
|
||||
func (m *PoolEndpointManager) GetRedirection(pooler pool.Pooler) (*EndpointRedirection, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
redirection, exists := m.redirections[pooler]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Return a copy to avoid race conditions
|
||||
redirectionCopy := *redirection
|
||||
return &redirectionCopy, true
|
||||
}
|
||||
|
||||
// CleanupExpiredRedirections removes expired redirections
|
||||
func (m *PoolEndpointManager) CleanupExpiredRedirections(ctx context.Context) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
for pooler, redirection := range m.redirections {
|
||||
if now.Sub(redirection.StartTime) > redirection.Timeout {
|
||||
// TODO: Here we should decide if we need to failback to the original dialer,
|
||||
// i.e. if the new endpoint did not produce any active connections.
|
||||
delete(m.redirections, pooler)
|
||||
delete(m.originalDialers, pooler)
|
||||
internal.Logger.Printf(ctx, "hitless: cleaned up expired redirection for pool")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createRedirectDialer creates a dialer that connects to the new endpoint
|
||||
func (m *PoolEndpointManager) createRedirectDialer(ctx context.Context, newEndpoint string, originalDialer func(context.Context) (net.Conn, error)) func(context.Context) (net.Conn, error) {
|
||||
return func(dialCtx context.Context) (net.Conn, error) {
|
||||
// Try to connect to the new endpoint
|
||||
conn, err := net.DialTimeout("tcp", newEndpoint, 10*time.Second)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "hitless: failed to connect to new endpoint %s: %v", newEndpoint, err)
|
||||
|
||||
// Fallback to original dialer
|
||||
return originalDialer(dialCtx)
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: successfully connected to new endpoint %s", newEndpoint)
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// getPoolDialer extracts the dialer from a connection pool
|
||||
func (m *PoolEndpointManager) getPoolDialer(connPool *pool.ConnPool) func(context.Context) (net.Conn, error) {
|
||||
return connPool.GetDialer()
|
||||
}
|
||||
|
||||
// setPoolDialer sets a new dialer for a connection pool
|
||||
func (m *PoolEndpointManager) setPoolDialer(connPool *pool.ConnPool, dialer func(context.Context) (net.Conn, error)) error {
|
||||
return connPool.SetDialer(dialer)
|
||||
}
|
||||
|
||||
// extractEndpointFromDialer extracts the endpoint address from a dialer
|
||||
func (m *PoolEndpointManager) extractEndpointFromDialer(dialer func(context.Context) (net.Conn, error)) string {
|
||||
// Try to extract endpoint by making a test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
conn, err := dialer(ctx)
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if conn.RemoteAddr() != nil {
|
||||
return conn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// GetActiveRedirections returns all active redirections
|
||||
func (m *PoolEndpointManager) GetActiveRedirections() map[interface{}]*EndpointRedirection {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// Create copies to avoid race conditions
|
||||
redirections := make(map[interface{}]*EndpointRedirection)
|
||||
for pooler, redirection := range m.redirections {
|
||||
redirectionCopy := *redirection
|
||||
redirections[pooler] = &redirectionCopy
|
||||
}
|
||||
|
||||
return redirections
|
||||
}
|
||||
|
||||
// UpdateRedirectionStats updates statistics for a redirection
|
||||
func (m *PoolEndpointManager) UpdateRedirectionStats(pooler pool.Pooler, newConnections, failedConnections int64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if redirection, exists := m.redirections[pooler]; exists {
|
||||
redirection.NewConnections += newConnections
|
||||
redirection.FailedConnections += failedConnections
|
||||
}
|
||||
}
|
309
internal/hitless/redis_integration.go
Normal file
309
internal/hitless/redis_integration.go
Normal file
@ -0,0 +1,309 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// UpgradeStatus represents the current status of all upgrade operations
|
||||
type UpgradeStatus struct {
|
||||
ConnectionStates map[*pool.Conn]*ConnectionState
|
||||
IsMoving bool
|
||||
NewEndpoint string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// UpgradeStatistics provides statistics about upgrade operations
|
||||
type UpgradeStatistics struct {
|
||||
ActiveConnections int
|
||||
IsMoving bool
|
||||
MigratingConnections int
|
||||
FailingOverConnections int
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// RedisClientIntegration provides complete hitless upgrade integration for Redis clients
|
||||
type RedisClientIntegration struct {
|
||||
clientIntegrator *ClientIntegrator
|
||||
connectionStateTracker *ConnectionStateTracker
|
||||
config *HitlessUpgradeConfig
|
||||
|
||||
mu sync.RWMutex
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// NewRedisClientIntegration creates a new Redis client integration for hitless upgrades
|
||||
// This is used internally by the main hitless.go package
|
||||
func NewRedisClientIntegration(config *HitlessUpgradeConfig, defaultReadTimeout, defaultWriteTimeout time.Duration) *RedisClientIntegration {
|
||||
// Start with defaults
|
||||
defaults := DefaultHitlessUpgradeConfig()
|
||||
|
||||
// If config is nil, use all defaults
|
||||
if config == nil {
|
||||
config = defaults
|
||||
} else {
|
||||
// Ensure all fields are set with defaults if they are zero values
|
||||
if config.TransitionTimeout == 0 {
|
||||
config = &HitlessUpgradeConfig{
|
||||
Enabled: config.Enabled,
|
||||
TransitionTimeout: defaults.TransitionTimeout,
|
||||
CleanupInterval: config.CleanupInterval,
|
||||
}
|
||||
}
|
||||
if config.CleanupInterval == 0 {
|
||||
config = &HitlessUpgradeConfig{
|
||||
Enabled: config.Enabled,
|
||||
TransitionTimeout: config.TransitionTimeout,
|
||||
CleanupInterval: defaults.CleanupInterval,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &RedisClientIntegration{
|
||||
clientIntegrator: NewClientIntegrator(defaultReadTimeout, defaultWriteTimeout),
|
||||
connectionStateTracker: NewConnectionStateTracker(defaultReadTimeout, defaultWriteTimeout),
|
||||
config: config,
|
||||
enabled: config.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
// EnableHitlessUpgrades enables hitless upgrade functionality
|
||||
func (rci *RedisClientIntegration) EnableHitlessUpgrades() {
|
||||
rci.mu.Lock()
|
||||
defer rci.mu.Unlock()
|
||||
rci.enabled = true
|
||||
}
|
||||
|
||||
// DisableHitlessUpgrades disables hitless upgrade functionality
|
||||
func (rci *RedisClientIntegration) DisableHitlessUpgrades() {
|
||||
rci.mu.Lock()
|
||||
defer rci.mu.Unlock()
|
||||
rci.enabled = false
|
||||
}
|
||||
|
||||
// IsEnabled returns whether hitless upgrades are enabled
|
||||
func (rci *RedisClientIntegration) IsEnabled() bool {
|
||||
rci.mu.RLock()
|
||||
defer rci.mu.RUnlock()
|
||||
return rci.enabled
|
||||
}
|
||||
|
||||
// No client registration needed - each client has its own hitless integration instance
|
||||
|
||||
// HandlePushNotification processes push notifications for hitless upgrades
|
||||
func (rci *RedisClientIntegration) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if !rci.IsEnabled() {
|
||||
// If disabled, just log and return without processing
|
||||
internal.Logger.Printf(ctx, "hitless: received notification but hitless upgrades are disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
return rci.clientIntegrator.HandlePushNotification(ctx, handlerCtx, notification)
|
||||
}
|
||||
|
||||
// IsConnectionTransitioning checks if a connection is currently transitioning
|
||||
func (rci *RedisClientIntegration) IsConnectionTransitioning(conn *pool.Conn) bool {
|
||||
if !rci.IsEnabled() {
|
||||
return false
|
||||
}
|
||||
|
||||
return rci.connectionStateTracker.IsConnectionTransitioning(conn)
|
||||
}
|
||||
|
||||
// GetConnectionState returns the current state of a connection
|
||||
func (rci *RedisClientIntegration) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) {
|
||||
if !rci.IsEnabled() {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return rci.connectionStateTracker.GetConnectionState(conn)
|
||||
}
|
||||
|
||||
// GetUpgradeStatus returns comprehensive status of all ongoing upgrades
|
||||
func (rci *RedisClientIntegration) GetUpgradeStatus() *UpgradeStatus {
|
||||
connStates := rci.clientIntegrator.GetUpgradeHandler().GetActiveTransitions()
|
||||
|
||||
return &UpgradeStatus{
|
||||
ConnectionStates: connStates,
|
||||
IsMoving: rci.clientIntegrator.IsMoving(),
|
||||
NewEndpoint: rci.clientIntegrator.GetNewEndpoint(),
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetUpgradeStatistics returns statistics about upgrade operations
|
||||
func (rci *RedisClientIntegration) GetUpgradeStatistics() *UpgradeStatistics {
|
||||
connStates := rci.clientIntegrator.GetUpgradeHandler().GetActiveTransitions()
|
||||
|
||||
stats := &UpgradeStatistics{
|
||||
ActiveConnections: len(connStates),
|
||||
IsMoving: rci.clientIntegrator.IsMoving(),
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
// Count by type
|
||||
stats.MigratingConnections = 0
|
||||
stats.FailingOverConnections = 0
|
||||
for _, state := range connStates {
|
||||
switch state.TransitionType {
|
||||
case "MIGRATING":
|
||||
stats.MigratingConnections++
|
||||
case "FAILING_OVER":
|
||||
stats.FailingOverConnections++
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// GetConnectionTimeout returns the appropriate timeout for a connection
|
||||
// If the connection is transitioning (MIGRATING/FAILING_OVER), returns the longer TransitionTimeout
|
||||
// Otherwise returns the provided defaultTimeout
|
||||
func (rci *RedisClientIntegration) GetConnectionTimeout(conn *pool.Conn, defaultTimeout time.Duration) time.Duration {
|
||||
if !rci.IsEnabled() {
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
// Check if connection is transitioning
|
||||
if rci.connectionStateTracker.IsConnectionTransitioning(conn) {
|
||||
// Use longer timeout for transitioning connections
|
||||
return rci.config.TransitionTimeout
|
||||
}
|
||||
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
// GetConnectionTimeouts returns both read and write timeouts for a connection
|
||||
func (rci *RedisClientIntegration) GetConnectionTimeouts(conn *pool.Conn, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, time.Duration) {
|
||||
if !rci.IsEnabled() {
|
||||
return defaultReadTimeout, defaultWriteTimeout
|
||||
}
|
||||
|
||||
// Use the upgrade handler to get appropriate timeouts
|
||||
upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
|
||||
return upgradeHandler.GetConnectionTimeouts(conn, defaultReadTimeout, defaultWriteTimeout, rci.config.TransitionTimeout)
|
||||
}
|
||||
|
||||
// MarkConnectionAsBlocking marks a connection as having blocking commands
|
||||
func (rci *RedisClientIntegration) MarkConnectionAsBlocking(conn *pool.Conn, isBlocking bool) {
|
||||
if !rci.IsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
// Use the upgrade handler to mark connection as blocking
|
||||
upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
|
||||
upgradeHandler.MarkConnectionAsBlocking(conn, isBlocking)
|
||||
}
|
||||
|
||||
// IsConnectionMarkedForClosing checks if a connection should be closed
|
||||
func (rci *RedisClientIntegration) IsConnectionMarkedForClosing(conn *pool.Conn) bool {
|
||||
if !rci.IsEnabled() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Use the upgrade handler to check if connection is marked for closing
|
||||
upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
|
||||
return upgradeHandler.IsConnectionMarkedForClosing(conn)
|
||||
}
|
||||
|
||||
// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected
|
||||
func (rci *RedisClientIntegration) ShouldRedirectBlockingConnection(conn *pool.Conn) (bool, string) {
|
||||
if !rci.IsEnabled() {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Check client integrator's atomic state for pool-level redirection
|
||||
if rci.clientIntegrator.IsMoving() {
|
||||
return true, rci.clientIntegrator.GetNewEndpoint()
|
||||
}
|
||||
|
||||
// Check specific connection state
|
||||
upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
|
||||
return upgradeHandler.ShouldRedirectBlockingConnection(conn, rci.clientIntegrator)
|
||||
}
|
||||
|
||||
// CleanupConnection removes tracking for a connection (called when connection is closed)
|
||||
func (rci *RedisClientIntegration) CleanupConnection(conn *pool.Conn) {
|
||||
if rci.IsEnabled() {
|
||||
rci.connectionStateTracker.CleanupConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupPool removed - no pool state to clean up
|
||||
|
||||
// Close shuts down the Redis client integration
|
||||
func (rci *RedisClientIntegration) Close() error {
|
||||
rci.mu.Lock()
|
||||
defer rci.mu.Unlock()
|
||||
|
||||
var firstErr error
|
||||
|
||||
// Close all components
|
||||
if err := rci.clientIntegrator.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
|
||||
// poolRedirector removed in simplified implementation
|
||||
|
||||
if err := rci.connectionStateTracker.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
|
||||
rci.enabled = false
|
||||
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// GetConfig returns the current configuration
|
||||
func (rci *RedisClientIntegration) GetConfig() *HitlessUpgradeConfig {
|
||||
rci.mu.RLock()
|
||||
defer rci.mu.RUnlock()
|
||||
|
||||
// Return a copy to prevent modification
|
||||
configCopy := *rci.config
|
||||
return &configCopy
|
||||
}
|
||||
|
||||
// UpdateConfig updates the configuration
|
||||
func (rci *RedisClientIntegration) UpdateConfig(config *HitlessUpgradeConfig) error {
|
||||
if config == nil {
|
||||
return fmt.Errorf("config cannot be nil")
|
||||
}
|
||||
|
||||
rci.mu.Lock()
|
||||
defer rci.mu.Unlock()
|
||||
|
||||
// Start with defaults for any zero values
|
||||
defaults := DefaultHitlessUpgradeConfig()
|
||||
|
||||
// Ensure all fields are set with defaults if they are zero values
|
||||
enabled := config.Enabled
|
||||
transitionTimeout := config.TransitionTimeout
|
||||
cleanupInterval := config.CleanupInterval
|
||||
|
||||
// Apply defaults for zero values
|
||||
if transitionTimeout == 0 {
|
||||
transitionTimeout = defaults.TransitionTimeout
|
||||
}
|
||||
if cleanupInterval == 0 {
|
||||
cleanupInterval = defaults.CleanupInterval
|
||||
}
|
||||
|
||||
// Create properly configured config
|
||||
finalConfig := &HitlessUpgradeConfig{
|
||||
Enabled: enabled,
|
||||
TransitionTimeout: transitionTimeout,
|
||||
CleanupInterval: cleanupInterval,
|
||||
}
|
||||
|
||||
rci.config = finalConfig
|
||||
rci.enabled = finalConfig.Enabled
|
||||
|
||||
return nil
|
||||
}
|
500
internal/hitless/upgrade_handler.go
Normal file
500
internal/hitless/upgrade_handler.go
Normal file
@ -0,0 +1,500 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// UpgradeHandler handles hitless upgrade push notifications for Redis cluster upgrades.
|
||||
// It implements different strategies based on notification type:
|
||||
// - MOVING: Changes pool state to use new endpoint for future connections
|
||||
// - mark existing connections for closing, handle pubsub to change underlying connection, change pool dialer with the new endpoint
|
||||
//
|
||||
// - MIGRATING/FAILING_OVER: Marks specific connection as in transition
|
||||
// - relaxing timeouts
|
||||
//
|
||||
// - MIGRATED/FAILED_OVER: Clears transition state for specific connection
|
||||
// - return to original timeouts
|
||||
type UpgradeHandler struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Connection-specific state for MIGRATING/FAILING_OVER notifications
|
||||
connectionStates map[*pool.Conn]*ConnectionState
|
||||
|
||||
// Pool-level state removed - using atomic state in ClusterUpgradeManager instead
|
||||
|
||||
// Client configuration for getting default timeouts
|
||||
defaultReadTimeout time.Duration
|
||||
defaultWriteTimeout time.Duration
|
||||
}
|
||||
|
||||
// ConnectionState tracks the state of a specific connection during upgrades
|
||||
type ConnectionState struct {
|
||||
IsTransitioning bool
|
||||
TransitionType string
|
||||
StartTime time.Time
|
||||
ShardID string
|
||||
TimeoutSeconds int
|
||||
|
||||
// Timeout management
|
||||
OriginalReadTimeout time.Duration // Original read timeout
|
||||
OriginalWriteTimeout time.Duration // Original write timeout
|
||||
|
||||
// MOVING state specific
|
||||
MarkedForClosing bool // Connection should be closed after current commands
|
||||
IsBlocking bool // Connection has blocking commands
|
||||
NewEndpoint string // New endpoint for blocking commands
|
||||
LastCommandTime time.Time // When the last command was sent
|
||||
}
|
||||
|
||||
// PoolState removed - using atomic state in ClusterUpgradeManager instead
|
||||
|
||||
// NewUpgradeHandler creates a new hitless upgrade handler with client timeout configuration
|
||||
func NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout time.Duration) *UpgradeHandler {
|
||||
return &UpgradeHandler{
|
||||
connectionStates: make(map[*pool.Conn]*ConnectionState),
|
||||
defaultReadTimeout: defaultReadTimeout,
|
||||
defaultWriteTimeout: defaultWriteTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnectionTimeouts returns the appropriate read and write timeouts for a connection
|
||||
// If the connection is transitioning, returns increased timeouts
|
||||
func (h *UpgradeHandler) GetConnectionTimeouts(conn *pool.Conn, defaultReadTimeout, defaultWriteTimeout, transitionTimeout time.Duration) (time.Duration, time.Duration) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
state, exists := h.connectionStates[conn]
|
||||
if !exists || !state.IsTransitioning {
|
||||
return defaultReadTimeout, defaultWriteTimeout
|
||||
}
|
||||
|
||||
// For transitioning connections (MIGRATING/FAILING_OVER), use longer timeouts
|
||||
switch state.TransitionType {
|
||||
case "MIGRATING", "FAILING_OVER":
|
||||
return transitionTimeout, transitionTimeout
|
||||
case "MOVING":
|
||||
// For MOVING connections, use default timeouts but mark for special handling
|
||||
return defaultReadTimeout, defaultWriteTimeout
|
||||
default:
|
||||
return defaultReadTimeout, defaultWriteTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// MarkConnectionForClosing marks a connection to be closed after current commands complete
|
||||
func (h *UpgradeHandler) MarkConnectionForClosing(conn *pool.Conn, newEndpoint string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
state, exists := h.connectionStates[conn]
|
||||
if !exists {
|
||||
state = &ConnectionState{
|
||||
IsTransitioning: true,
|
||||
TransitionType: "MOVING",
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
h.connectionStates[conn] = state
|
||||
}
|
||||
|
||||
state.MarkedForClosing = true
|
||||
state.NewEndpoint = newEndpoint
|
||||
state.LastCommandTime = time.Now()
|
||||
}
|
||||
|
||||
// IsConnectionMarkedForClosing checks if a connection should be closed
|
||||
func (h *UpgradeHandler) IsConnectionMarkedForClosing(conn *pool.Conn) bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
state, exists := h.connectionStates[conn]
|
||||
return exists && state.MarkedForClosing
|
||||
}
|
||||
|
||||
// MarkConnectionAsBlocking marks a connection as having blocking commands
|
||||
func (h *UpgradeHandler) MarkConnectionAsBlocking(conn *pool.Conn, isBlocking bool) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
state, exists := h.connectionStates[conn]
|
||||
if !exists {
|
||||
state = &ConnectionState{
|
||||
IsTransitioning: false,
|
||||
}
|
||||
h.connectionStates[conn] = state
|
||||
}
|
||||
|
||||
state.IsBlocking = isBlocking
|
||||
}
|
||||
|
||||
// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected
|
||||
// Uses client integrator's atomic state for pool-level checks - minimal locking
|
||||
func (h *UpgradeHandler) ShouldRedirectBlockingConnection(conn *pool.Conn, clientIntegrator interface{}) (bool, string) {
|
||||
if conn != nil {
|
||||
// Check specific connection - need lock only for connection state
|
||||
h.mu.RLock()
|
||||
state, exists := h.connectionStates[conn]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if exists && state.IsBlocking && state.TransitionType == "MOVING" && state.NewEndpoint != "" {
|
||||
return true, state.NewEndpoint
|
||||
}
|
||||
}
|
||||
|
||||
// Check client integrator's atomic state - no locks needed
|
||||
if ci, ok := clientIntegrator.(*ClientIntegrator); ok && ci != nil && ci.IsMoving() {
|
||||
return true, ci.GetNewEndpoint()
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// ShouldRedirectNewBlockingConnection removed - functionality merged into ShouldRedirectBlockingConnection
|
||||
|
||||
// HandlePushNotification processes hitless upgrade push notifications
|
||||
func (h *UpgradeHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) == 0 {
|
||||
return fmt.Errorf("hitless: empty notification received")
|
||||
}
|
||||
|
||||
notificationType, ok := notification[0].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("hitless: notification type is not a string: %T", notification[0])
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: processing %s notification with %d elements", notificationType, len(notification))
|
||||
|
||||
switch notificationType {
|
||||
case "MOVING":
|
||||
return h.handleMovingNotification(ctx, handlerCtx, notification)
|
||||
case "MIGRATING":
|
||||
return h.handleMigratingNotification(ctx, handlerCtx, notification)
|
||||
case "MIGRATED":
|
||||
return h.handleMigratedNotification(ctx, handlerCtx, notification)
|
||||
case "FAILING_OVER":
|
||||
return h.handleFailingOverNotification(ctx, handlerCtx, notification)
|
||||
case "FAILED_OVER":
|
||||
return h.handleFailedOverNotification(ctx, handlerCtx, notification)
|
||||
default:
|
||||
internal.Logger.Printf(ctx, "hitless: unknown notification type: %s", notificationType)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleMovingNotification processes MOVING notifications that affect the entire pool
|
||||
// Format: ["MOVING", time_seconds, "new_endpoint"]
|
||||
func (h *UpgradeHandler) handleMovingNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) < 3 {
|
||||
return fmt.Errorf("hitless: MOVING notification requires at least 3 elements, got %d", len(notification))
|
||||
}
|
||||
|
||||
// Parse timeout
|
||||
timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("hitless: failed to parse timeout for MOVING notification: %w", err)
|
||||
}
|
||||
|
||||
// Parse new endpoint
|
||||
newEndpoint, ok := notification[2].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("hitless: new endpoint is not a string: %T", notification[2])
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: MOVING notification - endpoint will move to %s in %d seconds", newEndpoint, timeoutSeconds)
|
||||
h.mu.Lock()
|
||||
// Mark all existing connections for closing after current commands complete
|
||||
for _, state := range h.connectionStates {
|
||||
state.MarkedForClosing = true
|
||||
state.NewEndpoint = newEndpoint
|
||||
state.TransitionType = "MOVING"
|
||||
state.LastCommandTime = time.Now()
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: marked existing connections for closing, new blocking commands will use %s", newEndpoint)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Removed complex helper methods - simplified to direct inline logic
|
||||
|
||||
// handleMigratingNotification processes MIGRATING notifications for specific connections
|
||||
// Format: ["MIGRATING", time_seconds, shard_id]
|
||||
func (h *UpgradeHandler) handleMigratingNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) < 3 {
|
||||
return fmt.Errorf("hitless: MIGRATING notification requires at least 3 elements, got %d", len(notification))
|
||||
}
|
||||
|
||||
timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("hitless: failed to parse timeout for MIGRATING notification: %w", err)
|
||||
}
|
||||
|
||||
shardID, err := h.parseShardID(notification[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("hitless: failed to parse shard ID for MIGRATING notification: %w", err)
|
||||
}
|
||||
|
||||
conn := handlerCtx.Conn
|
||||
if conn == nil {
|
||||
return fmt.Errorf("hitless: no connection available in handler context")
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: MIGRATING notification - shard %s will migrate in %d seconds on connection %p", shardID, timeoutSeconds, conn)
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// Store original timeouts if not already stored
|
||||
var originalReadTimeout, originalWriteTimeout time.Duration
|
||||
if existingState, exists := h.connectionStates[conn]; exists {
|
||||
originalReadTimeout = existingState.OriginalReadTimeout
|
||||
originalWriteTimeout = existingState.OriginalWriteTimeout
|
||||
} else {
|
||||
// Get default timeouts from client configuration
|
||||
originalReadTimeout = h.defaultReadTimeout
|
||||
originalWriteTimeout = h.defaultWriteTimeout
|
||||
}
|
||||
|
||||
h.connectionStates[conn] = &ConnectionState{
|
||||
IsTransitioning: true,
|
||||
TransitionType: "MIGRATING",
|
||||
StartTime: time.Now(),
|
||||
ShardID: shardID,
|
||||
TimeoutSeconds: timeoutSeconds,
|
||||
OriginalReadTimeout: originalReadTimeout,
|
||||
OriginalWriteTimeout: originalWriteTimeout,
|
||||
LastCommandTime: time.Now(),
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: connection %p marked as MIGRATING with increased timeouts", conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMigratedNotification processes MIGRATED notifications for specific connections
|
||||
// Format: ["MIGRATED", shard_id]
|
||||
func (h *UpgradeHandler) handleMigratedNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) < 2 {
|
||||
return fmt.Errorf("hitless: MIGRATED notification requires at least 2 elements, got %d", len(notification))
|
||||
}
|
||||
|
||||
shardID, err := h.parseShardID(notification[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("hitless: failed to parse shard ID for MIGRATED notification: %w", err)
|
||||
}
|
||||
|
||||
conn := handlerCtx.Conn
|
||||
if conn == nil {
|
||||
return fmt.Errorf("hitless: no connection available in handler context")
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: MIGRATED notification - shard %s migration completed on connection %p", shardID, conn)
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// Clear the transitioning state for this connection and restore original timeouts
|
||||
if state, exists := h.connectionStates[conn]; exists && state.TransitionType == "MIGRATING" && state.ShardID == shardID {
|
||||
internal.Logger.Printf(ctx, "hitless: restoring original timeouts for connection %p (read: %v, write: %v)",
|
||||
conn, state.OriginalReadTimeout, state.OriginalWriteTimeout)
|
||||
|
||||
// In a real implementation, this would restore the connection's original timeouts
|
||||
// For now, we'll just log and delete the state
|
||||
delete(h.connectionStates, conn)
|
||||
internal.Logger.Printf(ctx, "hitless: cleared MIGRATING state and restored timeouts for connection %p", conn)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailingOverNotification processes FAILING_OVER notifications for specific connections
|
||||
// Format: ["FAILING_OVER", time_seconds, shard_id]
|
||||
func (h *UpgradeHandler) handleFailingOverNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) < 3 {
|
||||
return fmt.Errorf("hitless: FAILING_OVER notification requires at least 3 elements, got %d", len(notification))
|
||||
}
|
||||
|
||||
timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("hitless: failed to parse timeout for FAILING_OVER notification: %w", err)
|
||||
}
|
||||
|
||||
shardID, err := h.parseShardID(notification[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("hitless: failed to parse shard ID for FAILING_OVER notification: %w", err)
|
||||
}
|
||||
|
||||
conn := handlerCtx.Conn
|
||||
if conn == nil {
|
||||
return fmt.Errorf("hitless: no connection available in handler context")
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: FAILING_OVER notification - shard %s will failover in %d seconds on connection %p", shardID, timeoutSeconds, conn)
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// Store original timeouts if not already stored
|
||||
var originalReadTimeout, originalWriteTimeout time.Duration
|
||||
if existingState, exists := h.connectionStates[conn]; exists {
|
||||
originalReadTimeout = existingState.OriginalReadTimeout
|
||||
originalWriteTimeout = existingState.OriginalWriteTimeout
|
||||
} else {
|
||||
// Get default timeouts from client configuration
|
||||
originalReadTimeout = h.defaultReadTimeout
|
||||
originalWriteTimeout = h.defaultWriteTimeout
|
||||
}
|
||||
|
||||
h.connectionStates[conn] = &ConnectionState{
|
||||
IsTransitioning: true,
|
||||
TransitionType: "FAILING_OVER",
|
||||
StartTime: time.Now(),
|
||||
ShardID: shardID,
|
||||
TimeoutSeconds: timeoutSeconds,
|
||||
OriginalReadTimeout: originalReadTimeout,
|
||||
OriginalWriteTimeout: originalWriteTimeout,
|
||||
LastCommandTime: time.Now(),
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: connection %p marked as FAILING_OVER with increased timeouts", conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailedOverNotification processes FAILED_OVER notifications for specific connections
|
||||
// Format: ["FAILED_OVER", shard_id]
|
||||
func (h *UpgradeHandler) handleFailedOverNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) < 2 {
|
||||
return fmt.Errorf("hitless: FAILED_OVER notification requires at least 2 elements, got %d", len(notification))
|
||||
}
|
||||
|
||||
shardID, err := h.parseShardID(notification[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("hitless: failed to parse shard ID for FAILED_OVER notification: %w", err)
|
||||
}
|
||||
|
||||
conn := handlerCtx.Conn
|
||||
if conn == nil {
|
||||
return fmt.Errorf("hitless: no connection available in handler context")
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: FAILED_OVER notification - shard %s failover completed on connection %p", shardID, conn)
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// Clear the transitioning state for this connection and restore original timeouts
|
||||
if state, exists := h.connectionStates[conn]; exists && state.TransitionType == "FAILING_OVER" && state.ShardID == shardID {
|
||||
internal.Logger.Printf(ctx, "hitless: restoring original timeouts for connection %p (read: %v, write: %v)",
|
||||
conn, state.OriginalReadTimeout, state.OriginalWriteTimeout)
|
||||
|
||||
// In a real implementation, this would restore the connection's original timeouts
|
||||
// For now, we'll just log and delete the state
|
||||
delete(h.connectionStates, conn)
|
||||
internal.Logger.Printf(ctx, "hitless: cleared FAILING_OVER state and restored timeouts for connection %p", conn)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseTimeoutSeconds parses timeout value from notification
|
||||
func (h *UpgradeHandler) parseTimeoutSeconds(value interface{}) (int, error) {
|
||||
switch v := value.(type) {
|
||||
case int64:
|
||||
return int(v), nil
|
||||
case int:
|
||||
return v, nil
|
||||
case string:
|
||||
return strconv.Atoi(v)
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported timeout type: %T", value)
|
||||
}
|
||||
}
|
||||
|
||||
// parseShardID parses shard ID from notification
|
||||
func (h *UpgradeHandler) parseShardID(value interface{}) (string, error) {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v, nil
|
||||
case int64:
|
||||
return strconv.FormatInt(v, 10), nil
|
||||
case int:
|
||||
return strconv.Itoa(v), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported shard ID type: %T", value)
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnectionState returns the current state of a connection
|
||||
func (h *UpgradeHandler) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
state, exists := h.connectionStates[conn]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Return a copy to avoid race conditions
|
||||
stateCopy := *state
|
||||
return &stateCopy, true
|
||||
}
|
||||
|
||||
// IsConnectionTransitioning checks if a connection is currently transitioning
|
||||
func (h *UpgradeHandler) IsConnectionTransitioning(conn *pool.Conn) bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
state, exists := h.connectionStates[conn]
|
||||
return exists && state.IsTransitioning
|
||||
}
|
||||
|
||||
// IsPoolMoving and GetNewEndpoint removed - using atomic state in ClusterUpgradeManager instead
|
||||
|
||||
// CleanupExpiredStates removes expired connection and pool states
|
||||
func (h *UpgradeHandler) CleanupExpiredStates() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Cleanup expired connection states
|
||||
for conn, state := range h.connectionStates {
|
||||
timeout := time.Duration(state.TimeoutSeconds) * time.Second
|
||||
if now.Sub(state.StartTime) > timeout {
|
||||
delete(h.connectionStates, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// Pool state cleanup removed - using atomic state in ClusterUpgradeManager instead
|
||||
}
|
||||
|
||||
// CleanupConnection removes state for a specific connection (called when connection is closed)
|
||||
func (h *UpgradeHandler) CleanupConnection(conn *pool.Conn) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
delete(h.connectionStates, conn)
|
||||
}
|
||||
|
||||
// GetActiveTransitions returns information about all active connection transitions
|
||||
func (h *UpgradeHandler) GetActiveTransitions() map[*pool.Conn]*ConnectionState {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// Create copies to avoid race conditions
|
||||
connStates := make(map[*pool.Conn]*ConnectionState)
|
||||
for conn, state := range h.connectionStates {
|
||||
stateCopy := *state
|
||||
connStates[conn] = &stateCopy
|
||||
}
|
||||
|
||||
return connStates
|
||||
}
|
@ -565,3 +565,24 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
|
||||
cn.SetUsedAt(now)
|
||||
return true
|
||||
}
|
||||
|
||||
// GetDialer returns the current dialer function for the pool
|
||||
func (p *ConnPool) GetDialer() func(context.Context) (net.Conn, error) {
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
return p.cfg.Dialer
|
||||
}
|
||||
|
||||
// SetDialer sets a new dialer function for the pool
|
||||
// This is used for hitless upgrades to redirect new connections to new endpoints
|
||||
func (p *ConnPool) SetDialer(dialer func(context.Context) (net.Conn, error)) error {
|
||||
if p.closed() {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
|
||||
p.cfg.Dialer = dialer
|
||||
return nil
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package proto
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@ -215,9 +216,9 @@ func TestPeekPushNotificationName(t *testing.T) {
|
||||
// This is acceptable behavior for malformed input
|
||||
name, err := reader.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
t.Logf("PeekPushNotificationName errored for corrupted data %s: %v", tc.name, err)
|
||||
t.Logf("PeekPushNotificationName errored for corrupted data %s: %v (DATA: %s)", tc.name, err, tc.data)
|
||||
} else {
|
||||
t.Logf("PeekPushNotificationName returned '%s' for corrupted data %s", name, tc.name)
|
||||
t.Logf("PeekPushNotificationName returned '%s' for corrupted data NAME: %s, DATA: %s", name, tc.name, tc.data)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -293,15 +294,27 @@ func TestPeekPushNotificationName(t *testing.T) {
|
||||
func createValidPushNotification(notificationName, data string) *bytes.Buffer {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
simpleOrString := rand.Intn(2) == 0
|
||||
|
||||
if data == "" {
|
||||
|
||||
// Single element notification
|
||||
buf.WriteString(">1\r\n")
|
||||
if simpleOrString {
|
||||
buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
|
||||
} else {
|
||||
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
|
||||
}
|
||||
} else {
|
||||
// Two element notification
|
||||
buf.WriteString(">2\r\n")
|
||||
if simpleOrString {
|
||||
buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
|
||||
buf.WriteString(fmt.Sprintf("+%s\r\n", data))
|
||||
} else {
|
||||
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
|
||||
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(data), data))
|
||||
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
|
||||
}
|
||||
}
|
||||
|
||||
return buf
|
||||
|
@ -116,26 +116,55 @@ func (r *Reader) PeekPushNotificationName() (string, error) {
|
||||
if buf[0] != RespPush {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
// remove push notification type and length
|
||||
buf = buf[2:]
|
||||
|
||||
if len(buf) < 3 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
|
||||
// remove push notification type
|
||||
buf = buf[1:]
|
||||
// remove first line - e.g. >2\r\n
|
||||
for i := 0; i < len(buf)-1; i++ {
|
||||
if buf[i] == '\r' && buf[i+1] == '\n' {
|
||||
buf = buf[i+2:]
|
||||
break
|
||||
} else {
|
||||
if buf[i] < '0' || buf[i] > '9' {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(buf) < 2 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
// next line should be $<length><string>\r\n or +<length><string>\r\n
|
||||
// should have the type of the push notification name and it's length
|
||||
if buf[0] != RespString {
|
||||
if buf[0] != RespString && buf[0] != RespStatus {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
typeOfName := buf[0]
|
||||
// remove the type of the push notification name
|
||||
buf = buf[1:]
|
||||
if typeOfName == RespString {
|
||||
// remove the length of the string
|
||||
if len(buf) < 2 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
// skip the length of the string
|
||||
for i := 0; i < len(buf)-1; i++ {
|
||||
if buf[i] == '\r' && buf[i+1] == '\n' {
|
||||
buf = buf[i+2:]
|
||||
break
|
||||
} else {
|
||||
if buf[i] < '0' || buf[i] > '9' {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(buf) < 2 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
// keep only the notification name
|
||||
for i := 0; i < len(buf)-1; i++ {
|
||||
if buf[i] == '\r' && buf[i+1] == '\n' {
|
||||
@ -143,6 +172,7 @@ func (r *Reader) PeekPushNotificationName() (string, error) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return util.BytesToString(buf), nil
|
||||
}
|
||||
|
||||
|
12
options.go
12
options.go
@ -224,6 +224,18 @@ type Options struct {
|
||||
// PushNotificationProcessor is the processor for handling push notifications.
|
||||
// If nil, a default processor will be created for RESP3 connections.
|
||||
PushNotificationProcessor push.NotificationProcessor
|
||||
|
||||
// HitlessUpgrades enables hitless upgrade functionality for cluster upgrades.
|
||||
// Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// When enabled, the client will automatically handle cluster upgrade notifications
|
||||
// and manage connection/pool state transitions seamlessly.
|
||||
//
|
||||
// default: false
|
||||
HitlessUpgrades bool
|
||||
|
||||
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
|
||||
// If nil, default configuration will be used when HitlessUpgrades is true.
|
||||
HitlessUpgradeConfig *HitlessUpgradeConfig
|
||||
}
|
||||
|
||||
func (opt *Options) init() {
|
||||
|
@ -110,6 +110,18 @@ type ClusterOptions struct {
|
||||
|
||||
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
|
||||
UnstableResp3 bool
|
||||
|
||||
// HitlessUpgrades enables hitless upgrade functionality for cluster upgrades.
|
||||
// Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// When enabled, the client will automatically handle cluster upgrade notifications
|
||||
// and manage connection/pool state transitions seamlessly.
|
||||
//
|
||||
// default: false
|
||||
HitlessUpgrades bool
|
||||
|
||||
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
|
||||
// If nil, default configuration will be used when HitlessUpgrades is true.
|
||||
HitlessUpgradeConfig *HitlessUpgradeConfig
|
||||
}
|
||||
|
||||
func (opt *ClusterOptions) init() {
|
||||
@ -329,6 +341,8 @@ func (opt *ClusterOptions) clientOptions() *Options {
|
||||
// situations in the options below will prevent that from happening.
|
||||
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
|
||||
UnstableResp3: opt.UnstableResp3,
|
||||
HitlessUpgrades: opt.HitlessUpgrades,
|
||||
HitlessUpgradeConfig: opt.HitlessUpgradeConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@ -943,6 +957,9 @@ type ClusterClient struct {
|
||||
cmdsInfoCache *cmdsInfoCache
|
||||
cmdable
|
||||
hooksMixin
|
||||
|
||||
// hitlessIntegration provides hitless upgrade functionality
|
||||
hitlessIntegration HitlessIntegration
|
||||
}
|
||||
|
||||
// NewClusterClient returns a Redis Cluster client as described in
|
||||
@ -969,6 +986,22 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
|
||||
txPipeline: c.processTxPipeline,
|
||||
})
|
||||
|
||||
// Initialize hitless upgrades if enabled
|
||||
if opt.HitlessUpgrades {
|
||||
if opt.Protocol != 3 {
|
||||
internal.Logger.Printf(context.Background(), "hitless: RESP3 protocol required for hitless upgrades, but Protocol is %d", opt.Protocol)
|
||||
} else {
|
||||
timeoutProvider := newOptionsTimeoutProvider(opt.ReadTimeout, opt.WriteTimeout)
|
||||
integration, err := initializeHitlessIntegration(c, opt.HitlessUpgradeConfig, timeoutProvider)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
|
||||
} else {
|
||||
c.hitlessIntegration = integration
|
||||
internal.Logger.Printf(context.Background(), "hitless: successfully initialized hitless upgrades for cluster client")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
@ -977,6 +1010,14 @@ func (c *ClusterClient) Options() *ClusterOptions {
|
||||
return c.opt
|
||||
}
|
||||
|
||||
// GetHitlessIntegration returns the hitless integration instance for monitoring and control.
|
||||
// Returns nil if hitless upgrades are not enabled.
|
||||
func (c *ClusterClient) GetHitlessIntegration() HitlessIntegration {
|
||||
return c.hitlessIntegration
|
||||
}
|
||||
|
||||
// getPushNotificationProcessor removed - not needed in simplified implementation
|
||||
|
||||
// ReloadState reloads cluster state. If available it calls ClusterSlots func
|
||||
// to get cluster slots information.
|
||||
func (c *ClusterClient) ReloadState(ctx context.Context) {
|
||||
@ -988,6 +1029,13 @@ func (c *ClusterClient) ReloadState(ctx context.Context) {
|
||||
// It is rare to Close a ClusterClient, as the ClusterClient is meant
|
||||
// to be long-lived and shared between many goroutines.
|
||||
func (c *ClusterClient) Close() error {
|
||||
// Close hitless integration first
|
||||
if c.hitlessIntegration != nil {
|
||||
if err := c.hitlessIntegration.Close(); err != nil {
|
||||
internal.Logger.Printf(context.Background(), "hitless: error closing hitless integration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return c.nodes.Close()
|
||||
}
|
||||
|
||||
|
@ -13,9 +13,6 @@ type NotificationHandlerContext struct {
|
||||
// circular dependencies. The developer is responsible for type assertion.
|
||||
// It can be one of the following types:
|
||||
// - *redis.baseClient
|
||||
// - *redis.Client
|
||||
// - *redis.ClusterClient
|
||||
// - *redis.Conn
|
||||
Client interface{}
|
||||
|
||||
// ConnPool is the connection pool from which the connection was obtained.
|
||||
@ -25,7 +22,7 @@ type NotificationHandlerContext struct {
|
||||
// - *pool.ConnPool
|
||||
// - *pool.SingleConnPool
|
||||
// - *pool.StickyConnPool
|
||||
ConnPool interface{}
|
||||
ConnPool pool.Pooler
|
||||
|
||||
// PubSub is the PubSub instance that received the notification.
|
||||
// It is interface to both allow for future expansion and to avoid
|
||||
|
191
redis.go
191
redis.go
@ -2,6 +2,7 @@ package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -211,6 +212,9 @@ type baseClient struct {
|
||||
|
||||
// Push notification processing
|
||||
pushProcessor push.NotificationProcessor
|
||||
|
||||
// hitlessIntegration provides hitless upgrade functionality
|
||||
hitlessIntegration HitlessIntegration
|
||||
}
|
||||
|
||||
func (c *baseClient) clone() *baseClient {
|
||||
@ -466,8 +470,17 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
|
||||
// Push notification processing errors shouldn't break normal Redis operations
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err)
|
||||
}
|
||||
|
||||
// Check if connection is marked for closing due to hitless upgrades
|
||||
if c.hitlessIntegration != nil && c.hitlessIntegration.IsConnectionMarkedForClosing(cn) {
|
||||
// Connection is marked for closing (e.g., during MOVING state)
|
||||
// Remove it instead of putting it back in the pool
|
||||
internal.Logger.Printf(ctx, "hitless: closing connection marked for closure during upgrade")
|
||||
c.connPool.Remove(ctx, cn, nil)
|
||||
} else {
|
||||
c.connPool.Put(ctx, cn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseClient) withConn(
|
||||
@ -528,7 +541,33 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
|
||||
}
|
||||
|
||||
retryTimeout := uint32(0)
|
||||
if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
|
||||
|
||||
// Check if this is a blocking command that needs redirection
|
||||
isBlockingCommand := cmd.readTimeout() != nil
|
||||
var redirectedConn *pool.Conn
|
||||
var shouldRedirect bool
|
||||
var newEndpoint string
|
||||
|
||||
if c.hitlessIntegration != nil && isBlockingCommand {
|
||||
// For blocking commands, check if we need to redirect to a new endpoint
|
||||
// This happens during MOVING state when the endpoint is changing
|
||||
shouldRedirect, newEndpoint = c.hitlessIntegration.ShouldRedirectBlockingConnection(nil)
|
||||
if shouldRedirect {
|
||||
// Create a new connection to the new endpoint
|
||||
var err error
|
||||
redirectedConn, err = c.createConnectionToEndpoint(ctx, newEndpoint)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "hitless: failed to create redirected connection to %s: %v", newEndpoint, err)
|
||||
// Fall back to normal connection if redirection fails
|
||||
shouldRedirect = false
|
||||
} else {
|
||||
internal.Logger.Printf(ctx, "hitless: redirecting blocking command %s to new endpoint %s", cmd.Name(), newEndpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use redirected connection if available, otherwise use normal connection
|
||||
connFunc := func(ctx context.Context, cn *pool.Conn) error {
|
||||
// Process any pending push notifications before executing the command
|
||||
if err := c.processPushNotifications(ctx, cn); err != nil {
|
||||
// Log the error but don't fail the command execution
|
||||
@ -536,7 +575,19 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err)
|
||||
}
|
||||
|
||||
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
|
||||
// Mark connection as blocking if this is a blocking command
|
||||
if c.hitlessIntegration != nil && isBlockingCommand {
|
||||
c.hitlessIntegration.MarkConnectionAsBlocking(cn, true)
|
||||
internal.Logger.Printf(ctx, "hitless: marked connection as blocking for command %s", cmd.Name())
|
||||
}
|
||||
|
||||
// Get appropriate write timeout for this connection
|
||||
writeTimeout := c.opt.WriteTimeout
|
||||
if c.hitlessIntegration != nil {
|
||||
_, writeTimeout = c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
|
||||
}
|
||||
|
||||
if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
|
||||
return writeCmd(wr, cmd)
|
||||
}); err != nil {
|
||||
atomic.StoreUint32(&retryTimeout, 1)
|
||||
@ -547,7 +598,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
|
||||
if c.opt.Protocol != 2 && c.assertUnstableCommand(cmd) {
|
||||
readReplyFunc = cmd.readRawReply
|
||||
}
|
||||
if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error {
|
||||
if err := cn.WithReader(c.context(ctx), c.cmdTimeoutForConnection(cmd, cn), func(rd *proto.Reader) error {
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
// Log the error but don't fail the command execution
|
||||
@ -564,8 +615,31 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
|
||||
return err
|
||||
}
|
||||
|
||||
// Unmark connection as blocking after command completes
|
||||
if c.hitlessIntegration != nil && isBlockingCommand {
|
||||
c.hitlessIntegration.MarkConnectionAsBlocking(cn, false)
|
||||
internal.Logger.Printf(ctx, "hitless: unmarked connection as blocking after command %s completed", cmd.Name())
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
}
|
||||
|
||||
// Execute the command with either redirected or normal connection
|
||||
var err error
|
||||
if shouldRedirect && redirectedConn != nil {
|
||||
// Use the redirected connection for blocking command
|
||||
err = connFunc(ctx, redirectedConn)
|
||||
// Close the redirected connection after use
|
||||
defer func() {
|
||||
redirectedConn.Close()
|
||||
internal.Logger.Printf(ctx, "hitless: closed redirected connection to %s after blocking command completed", newEndpoint)
|
||||
}()
|
||||
} else {
|
||||
// Use normal connection pool
|
||||
err = c.withConn(ctx, connFunc)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
|
||||
return retry, err
|
||||
}
|
||||
@ -588,6 +662,70 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
|
||||
return c.opt.ReadTimeout
|
||||
}
|
||||
|
||||
// cmdTimeoutForConnection returns the appropriate read timeout for a specific connection
|
||||
// taking into account hitless upgrade state
|
||||
func (c *baseClient) cmdTimeoutForConnection(cmd Cmder, cn *pool.Conn) time.Duration {
|
||||
baseTimeout := c.cmdTimeout(cmd)
|
||||
|
||||
// If hitless upgrades are enabled, get dynamic timeout based on connection state
|
||||
if c.hitlessIntegration != nil {
|
||||
// For blocking commands, use the command's timeout but check if connection needs increased timeout
|
||||
if cmd.readTimeout() != nil {
|
||||
// For blocking commands, use the base timeout but apply hitless upgrade adjustments
|
||||
adjustedTimeout := c.hitlessIntegration.GetConnectionTimeout(cn, baseTimeout)
|
||||
return adjustedTimeout
|
||||
} else {
|
||||
// For regular commands, get both read and write timeouts (use read timeout for command)
|
||||
readTimeout, _ := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
|
||||
return readTimeout
|
||||
}
|
||||
}
|
||||
|
||||
return baseTimeout
|
||||
}
|
||||
|
||||
// createConnectionToEndpoint creates a new connection to a specific endpoint
|
||||
// This is used for redirecting blocking commands during MOVING state
|
||||
func (c *baseClient) createConnectionToEndpoint(ctx context.Context, endpoint string) (*pool.Conn, error) {
|
||||
// Parse the endpoint to get host and port
|
||||
addr := endpoint
|
||||
if addr == "" {
|
||||
return nil, fmt.Errorf("empty endpoint provided")
|
||||
}
|
||||
|
||||
// Create a temporary dialer for the new endpoint
|
||||
dialer := func(ctx context.Context) (net.Conn, error) {
|
||||
netDialer := &net.Dialer{
|
||||
Timeout: c.opt.DialTimeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
|
||||
if c.opt.TLSConfig == nil {
|
||||
return netDialer.DialContext(ctx, c.opt.Network, addr)
|
||||
}
|
||||
|
||||
return tls.DialWithDialer(netDialer, c.opt.Network, addr, c.opt.TLSConfig)
|
||||
}
|
||||
|
||||
// Create a new connection using the dialer
|
||||
netConn, err := dialer(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial new endpoint %s: %w", endpoint, err)
|
||||
}
|
||||
|
||||
// Wrap in pool.Conn
|
||||
cn := pool.NewConn(netConn)
|
||||
|
||||
// Initialize the connection (auth, select db, etc.)
|
||||
if err := c.initConn(ctx, cn); err != nil {
|
||||
cn.Close()
|
||||
return nil, fmt.Errorf("failed to initialize connection to %s: %w", endpoint, err)
|
||||
}
|
||||
|
||||
internal.Logger.Printf(ctx, "hitless: created new connection to endpoint %s for blocking command redirection", endpoint)
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
// context returns the context for the current connection.
|
||||
// If the context timeout is enabled, it returns the original context.
|
||||
// Otherwise, it returns a new background context.
|
||||
@ -604,8 +742,19 @@ func (c *baseClient) context(ctx context.Context) context.Context {
|
||||
// long-lived and shared between many goroutines.
|
||||
func (c *baseClient) Close() error {
|
||||
var firstErr error
|
||||
|
||||
// Close hitless integration first
|
||||
if c.hitlessIntegration != nil {
|
||||
if err := c.hitlessIntegration.Close(); err != nil {
|
||||
internal.Logger.Printf(context.Background(), "hitless: error closing hitless integration: %v", err)
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.onClose != nil {
|
||||
if err := c.onClose(); err != nil {
|
||||
if err := c.onClose(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
@ -678,14 +827,15 @@ func (c *baseClient) pipelineProcessCmds(
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before pipeline: %v", err)
|
||||
}
|
||||
|
||||
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
|
||||
readTimeout, writeTimeout := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
|
||||
if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
|
||||
return writeCmds(wr, cmds)
|
||||
}); err != nil {
|
||||
setCmdsErr(cmds, err)
|
||||
return true, err
|
||||
}
|
||||
|
||||
if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
|
||||
if err := cn.WithReader(c.context(ctx), readTimeout, func(rd *proto.Reader) error {
|
||||
// read all replies
|
||||
return c.pipelineReadCmds(ctx, cn, rd, cmds)
|
||||
}); err != nil {
|
||||
@ -725,14 +875,15 @@ func (c *baseClient) txPipelineProcessCmds(
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err)
|
||||
}
|
||||
|
||||
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
|
||||
readTimeout, writeTimeout := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
|
||||
if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
|
||||
return writeCmds(wr, cmds)
|
||||
}); err != nil {
|
||||
setCmdsErr(cmds, err)
|
||||
return true, err
|
||||
}
|
||||
|
||||
if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
|
||||
if err := cn.WithReader(c.context(ctx), readTimeout, func(rd *proto.Reader) error {
|
||||
statusCmd := cmds[0].(*StatusCmd)
|
||||
// Trim multi and exec.
|
||||
trimmedCmds := cmds[1 : len(cmds)-1]
|
||||
@ -837,6 +988,22 @@ func NewClient(opt *Options) *Client {
|
||||
|
||||
c.connPool = newConnPool(opt, c.dialHook)
|
||||
|
||||
// Initialize hitless upgrades if enabled
|
||||
if opt.HitlessUpgrades {
|
||||
if opt.Protocol != 3 {
|
||||
internal.Logger.Printf(context.Background(), "hitless: RESP3 protocol required for hitless upgrades, but Protocol is %d", opt.Protocol)
|
||||
} else {
|
||||
timeoutProvider := newOptionsTimeoutProvider(opt.ReadTimeout, opt.WriteTimeout)
|
||||
integration, err := initializeHitlessIntegration(&c, opt.HitlessUpgradeConfig, timeoutProvider)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
|
||||
} else {
|
||||
c.hitlessIntegration = integration
|
||||
internal.Logger.Printf(context.Background(), "hitless: successfully initialized hitless upgrades for client")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
@ -857,6 +1024,12 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
|
||||
return &clone
|
||||
}
|
||||
|
||||
// GetHitlessIntegration returns the hitless integration instance for monitoring and control.
|
||||
// Returns nil if hitless upgrades are not enabled.
|
||||
func (c *Client) GetHitlessIntegration() HitlessIntegration {
|
||||
return c.hitlessIntegration
|
||||
}
|
||||
|
||||
func (c *Client) Conn() *Conn {
|
||||
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin)
|
||||
}
|
||||
|
Reference in New Issue
Block a user