Mirror of https://github.com/redis/go-redis.git (synced 2025-07-16 13:21:51 +03:00)

Commit: wip.

hitless.go (new file, 416 lines)
```go
package redis

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/hitless"
	"github.com/redis/go-redis/v9/internal/pool"
)

// HitlessUpgradeConfig provides configuration for hitless upgrades
type HitlessUpgradeConfig struct {
	// Enabled controls whether hitless upgrades are active
	Enabled bool

	// TransitionTimeout is the increased timeout for connections during transitions
	// (MIGRATING/FAILING_OVER). This should be longer than normal operation timeouts
	// to account for the time needed to complete the transition.
	// Default: 60 seconds
	TransitionTimeout time.Duration

	// CleanupInterval controls how often expired states are cleaned up.
	// Default: 30 seconds
	CleanupInterval time.Duration
}

// DefaultHitlessUpgradeConfig returns the default configuration for hitless upgrades
func DefaultHitlessUpgradeConfig() *HitlessUpgradeConfig {
	return &HitlessUpgradeConfig{
		Enabled:           true,
		TransitionTimeout: 60 * time.Second, // Longer timeout for transitioning connections
		CleanupInterval:   30 * time.Second, // How often to clean up expired states
	}
}

// HitlessUpgradeStatistics provides statistics about ongoing upgrade operations
type HitlessUpgradeStatistics struct {
	ActiveConnections      int       // Total connections in transition
	IsMoving               bool      // Whether the pool is currently moving
	MigratingConnections   int       // Connections in MIGRATING state
	FailingOverConnections int       // Connections in FAILING_OVER state
	Timestamp              time.Time // When these statistics were collected
}

// HitlessUpgradeStatus provides detailed status of all ongoing upgrades
type HitlessUpgradeStatus struct {
	ConnectionStates map[interface{}]interface{}
	IsMoving         bool
	NewEndpoint      string
	Timestamp        time.Time
}

// HitlessIntegration provides the interface for hitless upgrade functionality
type HitlessIntegration interface {
	// IsEnabled returns whether hitless upgrades are currently enabled
	IsEnabled() bool

	// EnableHitlessUpgrades enables hitless upgrade functionality
	EnableHitlessUpgrades()

	// DisableHitlessUpgrades disables hitless upgrade functionality
	DisableHitlessUpgrades()

	// GetConnectionTimeout returns the appropriate timeout for a connection.
	// If the connection is transitioning, it returns the longer TransitionTimeout.
	GetConnectionTimeout(conn interface{}, defaultTimeout time.Duration) time.Duration

	// GetConnectionTimeouts returns both read and write timeouts for a connection.
	// If the connection is transitioning, it returns increased timeouts.
	GetConnectionTimeouts(conn interface{}, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, time.Duration)

	// MarkConnectionAsBlocking marks a connection as having blocking commands
	MarkConnectionAsBlocking(conn interface{}, isBlocking bool)

	// IsConnectionMarkedForClosing checks if a connection should be closed
	IsConnectionMarkedForClosing(conn interface{}) bool

	// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected
	ShouldRedirectBlockingConnection(conn interface{}) (bool, string)

	// GetUpgradeStatistics returns current upgrade statistics
	GetUpgradeStatistics() *HitlessUpgradeStatistics

	// GetUpgradeStatus returns detailed upgrade status
	GetUpgradeStatus() *HitlessUpgradeStatus

	// UpdateConfig updates the hitless upgrade configuration
	UpdateConfig(config *HitlessUpgradeConfig) error

	// GetConfig returns the current configuration
	GetConfig() *HitlessUpgradeConfig

	// Close shuts down the hitless integration
	Close() error
}

// hitlessIntegrationImpl implements the HitlessIntegration interface
type hitlessIntegrationImpl struct {
	integration *hitless.RedisClientIntegration
	mu          sync.RWMutex
}

// newHitlessIntegration creates a new hitless integration instance
func newHitlessIntegration(config *HitlessUpgradeConfig) *hitlessIntegrationImpl {
	if config == nil {
		config = DefaultHitlessUpgradeConfig()
	}

	// Convert to the internal config format
	internalConfig := &hitless.HitlessUpgradeConfig{
		Enabled:           config.Enabled,
		TransitionTimeout: config.TransitionTimeout,
		CleanupInterval:   config.CleanupInterval,
	}

	integration := hitless.NewRedisClientIntegration(internalConfig, 3*time.Second, 3*time.Second)

	return &hitlessIntegrationImpl{
		integration: integration,
	}
}

// newHitlessIntegrationWithTimeouts creates a new hitless integration instance with timeout configuration
func newHitlessIntegrationWithTimeouts(config *HitlessUpgradeConfig, defaultReadTimeout, defaultWriteTimeout time.Duration) *hitlessIntegrationImpl {
	// Start with defaults
	defaults := DefaultHitlessUpgradeConfig()

	// If config is nil, use all defaults
	if config == nil {
		config = defaults
	}

	// Ensure all fields are set, falling back to defaults for zero values
	enabled := config.Enabled
	transitionTimeout := config.TransitionTimeout
	cleanupInterval := config.CleanupInterval

	// Apply defaults for zero values
	if transitionTimeout == 0 {
		transitionTimeout = defaults.TransitionTimeout
	}
	if cleanupInterval == 0 {
		cleanupInterval = defaults.CleanupInterval
	}

	// Convert to the internal config format with all fields properly set
	internalConfig := &hitless.HitlessUpgradeConfig{
		Enabled:           enabled,
		TransitionTimeout: transitionTimeout,
		CleanupInterval:   cleanupInterval,
	}

	integration := hitless.NewRedisClientIntegration(internalConfig, defaultReadTimeout, defaultWriteTimeout)

	return &hitlessIntegrationImpl{
		integration: integration,
	}
}

// IsEnabled returns whether hitless upgrades are currently enabled
func (h *hitlessIntegrationImpl) IsEnabled() bool {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.integration.IsEnabled()
}

// EnableHitlessUpgrades enables hitless upgrade functionality
func (h *hitlessIntegrationImpl) EnableHitlessUpgrades() {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.integration.EnableHitlessUpgrades()
}

// DisableHitlessUpgrades disables hitless upgrade functionality
func (h *hitlessIntegrationImpl) DisableHitlessUpgrades() {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.integration.DisableHitlessUpgrades()
}

// GetConnectionTimeout returns the appropriate timeout for a connection
func (h *hitlessIntegrationImpl) GetConnectionTimeout(conn interface{}, defaultTimeout time.Duration) time.Duration {
	h.mu.RLock()
	defer h.mu.RUnlock()

	// Convert interface{} to *pool.Conn
	if poolConn, ok := conn.(*pool.Conn); ok {
		return h.integration.GetConnectionTimeout(poolConn, defaultTimeout)
	}

	// If not a pool connection, return the default timeout
	return defaultTimeout
}

// GetConnectionTimeouts returns both read and write timeouts for a connection
func (h *hitlessIntegrationImpl) GetConnectionTimeouts(conn interface{}, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, time.Duration) {
	h.mu.RLock()
	defer h.mu.RUnlock()

	// Convert interface{} to *pool.Conn
	if poolConn, ok := conn.(*pool.Conn); ok {
		return h.integration.GetConnectionTimeouts(poolConn, defaultReadTimeout, defaultWriteTimeout)
	}

	// If not a pool connection, return the default timeouts
	return defaultReadTimeout, defaultWriteTimeout
}

// MarkConnectionAsBlocking marks a connection as having blocking commands
func (h *hitlessIntegrationImpl) MarkConnectionAsBlocking(conn interface{}, isBlocking bool) {
	h.mu.Lock()
	defer h.mu.Unlock()

	// Convert interface{} to *pool.Conn
	if poolConn, ok := conn.(*pool.Conn); ok {
		h.integration.MarkConnectionAsBlocking(poolConn, isBlocking)
	}
}

// IsConnectionMarkedForClosing checks if a connection should be closed
func (h *hitlessIntegrationImpl) IsConnectionMarkedForClosing(conn interface{}) bool {
	h.mu.RLock()
	defer h.mu.RUnlock()

	// Convert interface{} to *pool.Conn
	if poolConn, ok := conn.(*pool.Conn); ok {
		return h.integration.IsConnectionMarkedForClosing(poolConn)
	}

	return false
}

// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected
func (h *hitlessIntegrationImpl) ShouldRedirectBlockingConnection(conn interface{}) (bool, string) {
	h.mu.RLock()
	defer h.mu.RUnlock()

	// Convert interface{} to *pool.Conn (can be nil when only checking pool state)
	var poolConn *pool.Conn
	if conn != nil {
		if pc, ok := conn.(*pool.Conn); ok {
			poolConn = pc
		}
	}

	return h.integration.ShouldRedirectBlockingConnection(poolConn)
}

// GetUpgradeStatistics returns current upgrade statistics
func (h *hitlessIntegrationImpl) GetUpgradeStatistics() *HitlessUpgradeStatistics {
	h.mu.RLock()
	defer h.mu.RUnlock()

	stats := h.integration.GetUpgradeStatistics()
	if stats == nil {
		return &HitlessUpgradeStatistics{Timestamp: time.Now()}
	}

	return &HitlessUpgradeStatistics{
		ActiveConnections:      stats.ActiveConnections,
		IsMoving:               stats.IsMoving,
		MigratingConnections:   stats.MigratingConnections,
		FailingOverConnections: stats.FailingOverConnections,
		Timestamp:              stats.Timestamp,
	}
}

// GetUpgradeStatus returns detailed upgrade status
func (h *hitlessIntegrationImpl) GetUpgradeStatus() *HitlessUpgradeStatus {
	h.mu.RLock()
	defer h.mu.RUnlock()

	status := h.integration.GetUpgradeStatus()
	if status == nil {
		return &HitlessUpgradeStatus{
			ConnectionStates: make(map[interface{}]interface{}),
			IsMoving:         false,
			NewEndpoint:      "",
			Timestamp:        time.Now(),
		}
	}

	return &HitlessUpgradeStatus{
		ConnectionStates: convertToInterfaceMap(status.ConnectionStates),
		IsMoving:         status.IsMoving,
		NewEndpoint:      status.NewEndpoint,
		Timestamp:        status.Timestamp,
	}
}

// UpdateConfig updates the hitless upgrade configuration
func (h *hitlessIntegrationImpl) UpdateConfig(config *HitlessUpgradeConfig) error {
	if config == nil {
		return fmt.Errorf("config cannot be nil")
	}

	h.mu.Lock()
	defer h.mu.Unlock()

	// Start with defaults for any zero values
	defaults := DefaultHitlessUpgradeConfig()

	// Ensure all fields are set, falling back to defaults for zero values
	enabled := config.Enabled
	transitionTimeout := config.TransitionTimeout
	cleanupInterval := config.CleanupInterval

	// Apply defaults for zero values
	if transitionTimeout == 0 {
		transitionTimeout = defaults.TransitionTimeout
	}
	if cleanupInterval == 0 {
		cleanupInterval = defaults.CleanupInterval
	}

	// Convert to the internal config format with all fields properly set
	internalConfig := &hitless.HitlessUpgradeConfig{
		Enabled:           enabled,
		TransitionTimeout: transitionTimeout,
		CleanupInterval:   cleanupInterval,
	}

	return h.integration.UpdateConfig(internalConfig)
}

// GetConfig returns the current configuration
func (h *hitlessIntegrationImpl) GetConfig() *HitlessUpgradeConfig {
	h.mu.RLock()
	defer h.mu.RUnlock()

	internalConfig := h.integration.GetConfig()
	if internalConfig == nil {
		return DefaultHitlessUpgradeConfig()
	}

	return &HitlessUpgradeConfig{
		Enabled:           internalConfig.Enabled,
		TransitionTimeout: internalConfig.TransitionTimeout,
		CleanupInterval:   internalConfig.CleanupInterval,
	}
}

// Close shuts down the hitless integration
func (h *hitlessIntegrationImpl) Close() error {
	h.mu.Lock()
	defer h.mu.Unlock()
	return h.integration.Close()
}

// getInternalIntegration returns the internal integration for use by Redis clients
func (h *hitlessIntegrationImpl) getInternalIntegration() *hitless.RedisClientIntegration {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.integration
}

// ClientTimeoutProvider is the interface for extracting timeout configuration from client options
type ClientTimeoutProvider interface {
	GetReadTimeout() time.Duration
	GetWriteTimeout() time.Duration
}

// optionsTimeoutProvider implements ClientTimeoutProvider for the Options struct
type optionsTimeoutProvider struct {
	readTimeout  time.Duration
	writeTimeout time.Duration
}

func (p *optionsTimeoutProvider) GetReadTimeout() time.Duration {
	return p.readTimeout
}

func (p *optionsTimeoutProvider) GetWriteTimeout() time.Duration {
	return p.writeTimeout
}

// newOptionsTimeoutProvider creates a timeout provider from Options
func newOptionsTimeoutProvider(readTimeout, writeTimeout time.Duration) ClientTimeoutProvider {
	return &optionsTimeoutProvider{
		readTimeout:  readTimeout,
		writeTimeout: writeTimeout,
	}
}

// initializeHitlessIntegration initializes hitless integration for a client
func initializeHitlessIntegration(client interface{}, config *HitlessUpgradeConfig, timeoutProvider ClientTimeoutProvider) (*hitlessIntegrationImpl, error) {
	if config == nil || !config.Enabled {
		return nil, nil
	}

	// Extract timeout configuration from the client options
	defaultReadTimeout := timeoutProvider.GetReadTimeout()
	defaultWriteTimeout := timeoutProvider.GetWriteTimeout()

	// Create the hitless integration - each client gets its own instance
	integration := newHitlessIntegrationWithTimeouts(config, defaultReadTimeout, defaultWriteTimeout)

	// Push notification handlers are registered directly by the client;
	// no separate registration is needed in the simplified implementation.

	internal.Logger.Printf(context.Background(), "hitless: initialized hitless upgrades for client")

	return integration, nil
}

// convertToInterfaceMap converts a typed map to an interface{} map for the public API
func convertToInterfaceMap(input map[*pool.Conn]*hitless.ConnectionState) map[interface{}]interface{} {
	result := make(map[interface{}]interface{})
	for k, v := range input {
		result[k] = v
	}
	return result
}
```
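The zero-value defaulting used by `newHitlessIntegrationWithTimeouts` and `UpdateConfig` above means callers only need to set the fields they care about. A minimal sketch of that merge rule in isolation (the helper name is illustrative and not part of this commit):

```go
// applyHitlessDefaults fills zero-valued fields of cfg from the package defaults,
// mirroring the merge logic in newHitlessIntegrationWithTimeouts and UpdateConfig.
func applyHitlessDefaults(cfg *HitlessUpgradeConfig) *HitlessUpgradeConfig {
	defaults := DefaultHitlessUpgradeConfig()
	if cfg == nil {
		return defaults
	}
	out := *cfg
	if out.TransitionTimeout == 0 {
		out.TransitionTimeout = defaults.TransitionTimeout // 60s
	}
	if out.CleanupInterval == 0 {
		out.CleanupInterval = defaults.CleanupInterval // 30s
	}
	return &out
}
```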
hitless_config_defaults_test.go (new file, 197 lines)
```go
package redis

import (
	"testing"
	"time"
)

func TestHitlessUpgradeConfig_DefaultValues(t *testing.T) {
	tests := []struct {
		name           string
		inputConfig    *HitlessUpgradeConfig
		expectedConfig *HitlessUpgradeConfig
	}{
		{
			name:        "nil config should use all defaults",
			inputConfig: nil,
			expectedConfig: &HitlessUpgradeConfig{
				Enabled:           true,
				TransitionTimeout: 60 * time.Second,
				CleanupInterval:   30 * time.Second,
			},
		},
		{
			name: "zero TransitionTimeout should use default",
			inputConfig: &HitlessUpgradeConfig{
				Enabled:           false,
				TransitionTimeout: 0, // Zero value
				CleanupInterval:   45 * time.Second,
			},
			expectedConfig: &HitlessUpgradeConfig{
				Enabled:           false,
				TransitionTimeout: 60 * time.Second, // Should use default
				CleanupInterval:   45 * time.Second,
			},
		},
		{
			name: "zero CleanupInterval should use default",
			inputConfig: &HitlessUpgradeConfig{
				Enabled:           true,
				TransitionTimeout: 90 * time.Second,
				CleanupInterval:   0, // Zero value
			},
			expectedConfig: &HitlessUpgradeConfig{
				Enabled:           true,
				TransitionTimeout: 90 * time.Second,
				CleanupInterval:   30 * time.Second, // Should use default
			},
		},
		{
			name: "both timeouts zero should use defaults",
			inputConfig: &HitlessUpgradeConfig{
				Enabled:           true,
				TransitionTimeout: 0, // Zero value
				CleanupInterval:   0, // Zero value
			},
			expectedConfig: &HitlessUpgradeConfig{
				Enabled:           true,
				TransitionTimeout: 60 * time.Second, // Should use default
				CleanupInterval:   30 * time.Second, // Should use default
			},
		},
		{
			name: "all values set should be preserved",
			inputConfig: &HitlessUpgradeConfig{
				Enabled:           false,
				TransitionTimeout: 120 * time.Second,
				CleanupInterval:   60 * time.Second,
			},
			expectedConfig: &HitlessUpgradeConfig{
				Enabled:           false,
				TransitionTimeout: 120 * time.Second,
				CleanupInterval:   60 * time.Second,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Test with a mock client that has hitless upgrades enabled
			opt := &Options{
				Addr:                 "127.0.0.1:6379",
				Protocol:             3,
				HitlessUpgrades:      true,
				HitlessUpgradeConfig: tt.inputConfig,
				ReadTimeout:          5 * time.Second,
				WriteTimeout:         5 * time.Second,
			}

			// Test the integration creation using the internal method directly,
			// since initializeHitlessIntegration requires a push processor.
			integration := newHitlessIntegrationWithTimeouts(tt.inputConfig, opt.ReadTimeout, opt.WriteTimeout)
			if integration == nil {
				t.Fatal("Integration should not be nil")
			}

			// Get the config from the integration
			actualConfig := integration.GetConfig()
			if actualConfig == nil {
				t.Fatal("Config should not be nil")
			}

			// Verify all fields match expected values
			if actualConfig.Enabled != tt.expectedConfig.Enabled {
				t.Errorf("Enabled: expected %v, got %v", tt.expectedConfig.Enabled, actualConfig.Enabled)
			}
			if actualConfig.TransitionTimeout != tt.expectedConfig.TransitionTimeout {
				t.Errorf("TransitionTimeout: expected %v, got %v", tt.expectedConfig.TransitionTimeout, actualConfig.TransitionTimeout)
			}
			if actualConfig.CleanupInterval != tt.expectedConfig.CleanupInterval {
				t.Errorf("CleanupInterval: expected %v, got %v", tt.expectedConfig.CleanupInterval, actualConfig.CleanupInterval)
			}

			// Test UpdateConfig as well
			newConfig := &HitlessUpgradeConfig{
				Enabled:           !tt.expectedConfig.Enabled,
				TransitionTimeout: 0, // Zero value should use default
				CleanupInterval:   0, // Zero value should use default
			}

			err := integration.UpdateConfig(newConfig)
			if err != nil {
				t.Fatalf("Failed to update config: %v", err)
			}

			// Verify the updated config has defaults applied
			updatedConfig := integration.GetConfig()
			if updatedConfig.Enabled == tt.expectedConfig.Enabled {
				t.Error("Enabled should have been toggled")
			}
			if updatedConfig.TransitionTimeout != 60*time.Second {
				t.Errorf("TransitionTimeout should use default (60s), got %v", updatedConfig.TransitionTimeout)
			}
			if updatedConfig.CleanupInterval != 30*time.Second {
				t.Errorf("CleanupInterval should use default (30s), got %v", updatedConfig.CleanupInterval)
			}
		})
	}
}

func TestDefaultHitlessUpgradeConfig(t *testing.T) {
	config := DefaultHitlessUpgradeConfig()

	if config == nil {
		t.Fatal("Default config should not be nil")
	}

	if !config.Enabled {
		t.Error("Default config should have Enabled=true")
	}

	if config.TransitionTimeout != 60*time.Second {
		t.Errorf("Default TransitionTimeout should be 60s, got %v", config.TransitionTimeout)
	}

	if config.CleanupInterval != 30*time.Second {
		t.Errorf("Default CleanupInterval should be 30s, got %v", config.CleanupInterval)
	}
}

func TestHitlessUpgradeConfig_ZeroValueHandling(t *testing.T) {
	// Test that zero values are properly handled in various scenarios

	// Test 1: Partial config with some zero values
	partialConfig := &HitlessUpgradeConfig{
		Enabled: true,
		// TransitionTimeout and CleanupInterval are zero values
	}

	integration := newHitlessIntegrationWithTimeouts(partialConfig, 3*time.Second, 3*time.Second)
	if integration == nil {
		t.Fatal("Integration should not be nil")
	}

	config := integration.GetConfig()
	if config.TransitionTimeout == 0 {
		t.Error("Zero TransitionTimeout should have been replaced with default")
	}
	if config.CleanupInterval == 0 {
		t.Error("Zero CleanupInterval should have been replaced with default")
	}

	// Test 2: Empty struct
	emptyConfig := &HitlessUpgradeConfig{}

	integration2 := newHitlessIntegrationWithTimeouts(emptyConfig, 3*time.Second, 3*time.Second)
	if integration2 == nil {
		t.Fatal("Integration should not be nil")
	}

	config2 := integration2.GetConfig()
	if config2.TransitionTimeout == 0 {
		t.Error("Zero TransitionTimeout in empty config should have been replaced with default")
	}
	if config2.CleanupInterval == 0 {
		t.Error("Zero CleanupInterval in empty config should have been replaced with default")
	}
}
```
internal/hitless/README.md (new file, 23 lines)
# Hitless Upgrade Package

This package implements hitless upgrade functionality for Redis cluster clients using the push notification architecture. It provides handlers for managing connection and pool state during Redis cluster upgrades.

## Quick Start

To enable hitless upgrades in your Redis client, simply set the configuration option:

```go
import (
	"context"

	"github.com/redis/go-redis/v9"
)

// Enable hitless upgrades with a single configuration option
client := redis.NewClusterClient(&redis.ClusterOptions{
	Addrs:           []string{"127.0.0.1:7000", "127.0.0.1:7001", "127.0.0.1:7002"},
	Protocol:        3,    // RESP3 is required for push notifications
	HitlessUpgrades: true, // Enable hitless upgrades
})
defer client.Close()

// That's it! Use your client normally - hitless upgrades work automatically
ctx := context.Background()
client.Set(ctx, "key", "value", 0)
```
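For a standalone (non-cluster) client, the `Options` fields exercised in `hitless_config_defaults_test.go` above suggest an analogous setup. A hedged sketch, assuming `Options.HitlessUpgrades` and `Options.HitlessUpgradeConfig` are wired the same way as the cluster options in the README example:

```go
import (
	"time"

	"github.com/redis/go-redis/v9"
)

// Sketch only: field names taken from the test file in this commit.
client := redis.NewClient(&redis.Options{
	Addr:            "127.0.0.1:6379",
	Protocol:        3,    // RESP3 required for push notifications
	HitlessUpgrades: true, // enable hitless upgrades
	HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
		Enabled:           true,
		TransitionTimeout: 90 * time.Second, // zero-valued fields fall back to defaults
	},
})
defer client.Close()
```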
internal/hitless/client_integration.go (new file, 200 lines)
```go
package hitless

import (
	"context"
	"sync"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/push"
)

// ClientIntegrator provides integration between hitless upgrade handlers and Redis clients
type ClientIntegrator struct {
	upgradeHandler *UpgradeHandler
	mu             sync.RWMutex

	// Simple atomic state for pool redirection
	isMoving    int32  // atomic: 0 = not moving, 1 = moving
	newEndpoint string // only written during MOVING, read-only after
}

// NewClientIntegrator creates a new client integrator with client timeout configuration
func NewClientIntegrator(defaultReadTimeout, defaultWriteTimeout time.Duration) *ClientIntegrator {
	return &ClientIntegrator{
		upgradeHandler: NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout),
	}
}

// GetUpgradeHandler returns the upgrade handler for direct access
func (ci *ClientIntegrator) GetUpgradeHandler() *UpgradeHandler {
	return ci.upgradeHandler
}

// HandlePushNotification is the main entry point for processing upgrade notifications
func (ci *ClientIntegrator) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	// Handle MOVING notifications for pool redirection
	if len(notification) > 0 {
		if notificationType, ok := notification[0].(string); ok && notificationType == "MOVING" {
			if len(notification) >= 3 {
				if newEndpoint, ok := notification[2].(string); ok {
					// Simple atomic state update - no locks needed
					ci.newEndpoint = newEndpoint
					atomic.StoreInt32(&ci.isMoving, 1)
				}
			}
		}
	}

	return ci.upgradeHandler.HandlePushNotification(ctx, handlerCtx, notification)
}

// Close shuts down the client integrator
func (ci *ClientIntegrator) Close() error {
	ci.mu.Lock()
	defer ci.mu.Unlock()

	// Reset atomic state
	atomic.StoreInt32(&ci.isMoving, 0)
	ci.newEndpoint = ""

	return nil
}

// IsMoving returns true if the pool is currently moving to a new endpoint.
// Uses an atomic read - no locks needed.
func (ci *ClientIntegrator) IsMoving() bool {
	return atomic.LoadInt32(&ci.isMoving) == 1
}

// GetNewEndpoint returns the new endpoint if moving, or an empty string otherwise.
// Safe to read without locks since it is only written during MOVING.
func (ci *ClientIntegrator) GetNewEndpoint() string {
	if ci.IsMoving() {
		return ci.newEndpoint
	}
	return ""
}

// PushNotificationHandlerInterface defines the interface for push notification handlers.
// It matches the interface expected by the push notification system.
type PushNotificationHandlerInterface interface {
	HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error
}

// Ensure ClientIntegrator implements the interface
var _ PushNotificationHandlerInterface = (*ClientIntegrator)(nil)

// PoolRedirector provides pool redirection functionality for hitless upgrades
type PoolRedirector struct {
	poolManager *PoolEndpointManager
	mu          sync.RWMutex
}

// NewPoolRedirector creates a new pool redirector
func NewPoolRedirector() *PoolRedirector {
	return &PoolRedirector{
		poolManager: NewPoolEndpointManager(),
	}
}

// RedirectPool redirects a connection pool to a new endpoint
func (pr *PoolRedirector) RedirectPool(ctx context.Context, pooler pool.Pooler, newEndpoint string, timeout time.Duration) error {
	pr.mu.Lock()
	defer pr.mu.Unlock()

	return pr.poolManager.RedirectPool(ctx, pooler, newEndpoint, timeout)
}

// IsPoolRedirected checks if a pool is currently redirected
func (pr *PoolRedirector) IsPoolRedirected(pooler pool.Pooler) bool {
	pr.mu.RLock()
	defer pr.mu.RUnlock()

	return pr.poolManager.IsPoolRedirected(pooler)
}

// GetRedirection returns redirection information for a pool
func (pr *PoolRedirector) GetRedirection(pooler pool.Pooler) (*EndpointRedirection, bool) {
	pr.mu.RLock()
	defer pr.mu.RUnlock()

	return pr.poolManager.GetRedirection(pooler)
}

// Close shuts down the pool redirector
func (pr *PoolRedirector) Close() error {
	pr.mu.Lock()
	defer pr.mu.Unlock()

	// Clean up all redirections
	ctx := context.Background()
	pr.poolManager.CleanupExpiredRedirections(ctx)

	return nil
}

// ConnectionStateTracker tracks connection states during upgrades
type ConnectionStateTracker struct {
	upgradeHandler *UpgradeHandler
	mu             sync.RWMutex
}

// NewConnectionStateTracker creates a new connection state tracker with timeout configuration
func NewConnectionStateTracker(defaultReadTimeout, defaultWriteTimeout time.Duration) *ConnectionStateTracker {
	return &ConnectionStateTracker{
		upgradeHandler: NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout),
	}
}

// IsConnectionTransitioning checks if a connection is currently transitioning
func (cst *ConnectionStateTracker) IsConnectionTransitioning(conn *pool.Conn) bool {
	cst.mu.RLock()
	defer cst.mu.RUnlock()

	return cst.upgradeHandler.IsConnectionTransitioning(conn)
}

// GetConnectionState returns the current state of a connection
func (cst *ConnectionStateTracker) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) {
	cst.mu.RLock()
	defer cst.mu.RUnlock()

	return cst.upgradeHandler.GetConnectionState(conn)
}

// CleanupConnection removes tracking for a connection
func (cst *ConnectionStateTracker) CleanupConnection(conn *pool.Conn) {
	cst.mu.Lock()
	defer cst.mu.Unlock()

	cst.upgradeHandler.CleanupConnection(conn)
}

// Close shuts down the connection state tracker
func (cst *ConnectionStateTracker) Close() error {
	cst.mu.Lock()
	defer cst.mu.Unlock()

	// Clean up all expired states
	cst.upgradeHandler.CleanupExpiredStates()

	return nil
}

// HitlessUpgradeConfig provides configuration for hitless upgrades
type HitlessUpgradeConfig struct {
	Enabled           bool
	TransitionTimeout time.Duration
	CleanupInterval   time.Duration
}

// DefaultHitlessUpgradeConfig returns the default configuration for hitless upgrades
func DefaultHitlessUpgradeConfig() *HitlessUpgradeConfig {
	return &HitlessUpgradeConfig{
		Enabled:           true,
		TransitionTimeout: 60 * time.Second, // Longer timeout for transitioning connections
		CleanupInterval:   30 * time.Second, // How often to clean up expired states
	}
}
```
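`HandlePushNotification` above only inspects the first and third elements of a MOVING notification before flipping the atomic state. A small standalone sketch of that parsing step (the full payload layout, e.g. what sits at index 1, is an assumption here and not confirmed by this diff):

```go
// parseMoving extracts the new endpoint from a MOVING push notification payload,
// applying the same checks as ClientIntegrator.HandlePushNotification.
func parseMoving(notification []interface{}) (endpoint string, ok bool) {
	if len(notification) < 3 {
		return "", false
	}
	if t, isStr := notification[0].(string); !isStr || t != "MOVING" {
		return "", false
	}
	endpoint, ok = notification[2].(string)
	return endpoint, ok
}

// Example (index 1 is presumably a timeout in seconds; assumption):
// parseMoving([]interface{}{"MOVING", 15, "10.0.0.5:6379"}) -> ("10.0.0.5:6379", true)
```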
internal/hitless/pool_manager.go (new file, 205 lines)
```go
package hitless

import (
	"context"
	"fmt"
	"net"
	"sync"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/pool"
)

// PoolEndpointManager manages endpoint transitions for connection pools during hitless upgrades.
// It provides functionality to redirect new connections to new endpoints while maintaining
// existing connections until they can be gracefully transitioned.
type PoolEndpointManager struct {
	mu sync.RWMutex

	// Map of pools to their endpoint redirections
	redirections map[interface{}]*EndpointRedirection

	// Original dialers for pools (to restore after the transition)
	originalDialers map[interface{}]func(context.Context) (net.Conn, error)
}

// EndpointRedirection represents an active endpoint redirection
type EndpointRedirection struct {
	OriginalEndpoint string
	NewEndpoint      string
	StartTime        time.Time
	Timeout          time.Duration

	// Statistics
	NewConnections    int64
	FailedConnections int64
}

// NewPoolEndpointManager creates a new pool endpoint manager
func NewPoolEndpointManager() *PoolEndpointManager {
	return &PoolEndpointManager{
		redirections:    make(map[interface{}]*EndpointRedirection),
		originalDialers: make(map[interface{}]func(context.Context) (net.Conn, error)),
	}
}

// RedirectPool redirects new connections from a pool to a new endpoint
func (m *PoolEndpointManager) RedirectPool(ctx context.Context, pooler pool.Pooler, newEndpoint string, timeout time.Duration) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Check if the pool is already being redirected
	if _, exists := m.redirections[pooler]; exists {
		return fmt.Errorf("pool is already being redirected")
	}

	// Get the current dialer from the pool
	connPool, ok := pooler.(*pool.ConnPool)
	if !ok {
		return fmt.Errorf("unsupported pool type: %T", pooler)
	}

	// Store the original dialer
	originalDialer := m.getPoolDialer(connPool)
	if originalDialer == nil {
		return fmt.Errorf("could not get original dialer from pool")
	}

	m.originalDialers[pooler] = originalDialer

	// Create a new dialer that connects to the new endpoint
	newDialer := m.createRedirectDialer(ctx, newEndpoint, originalDialer)

	// Replace the pool's dialer
	if err := m.setPoolDialer(connPool, newDialer); err != nil {
		delete(m.originalDialers, pooler)
		return fmt.Errorf("failed to set new dialer: %w", err)
	}

	// Record the redirection
	m.redirections[pooler] = &EndpointRedirection{
		OriginalEndpoint: m.extractEndpointFromDialer(originalDialer),
		NewEndpoint:      newEndpoint,
		StartTime:        time.Now(),
		Timeout:          timeout,
	}

	internal.Logger.Printf(ctx, "hitless: redirected pool to new endpoint %s", newEndpoint)

	return nil
}

// IsPoolRedirected checks if a pool is currently being redirected
func (m *PoolEndpointManager) IsPoolRedirected(pooler pool.Pooler) bool {
	m.mu.RLock()
	defer m.mu.RUnlock()

	_, exists := m.redirections[pooler]
	return exists
}

// GetRedirection returns redirection information for a pool
func (m *PoolEndpointManager) GetRedirection(pooler pool.Pooler) (*EndpointRedirection, bool) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	redirection, exists := m.redirections[pooler]
	if !exists {
		return nil, false
	}

	// Return a copy to avoid race conditions
	redirectionCopy := *redirection
	return &redirectionCopy, true
}

// CleanupExpiredRedirections removes expired redirections
func (m *PoolEndpointManager) CleanupExpiredRedirections(ctx context.Context) {
	m.mu.Lock()
	defer m.mu.Unlock()

	now := time.Now()

	for pooler, redirection := range m.redirections {
		if now.Sub(redirection.StartTime) > redirection.Timeout {
			// TODO: Decide here whether to fail back to the original dialer,
			// i.e. if the new endpoint did not produce any active connections.
			delete(m.redirections, pooler)
			delete(m.originalDialers, pooler)
			internal.Logger.Printf(ctx, "hitless: cleaned up expired redirection for pool")
		}
	}
}

// createRedirectDialer creates a dialer that connects to the new endpoint
func (m *PoolEndpointManager) createRedirectDialer(ctx context.Context, newEndpoint string, originalDialer func(context.Context) (net.Conn, error)) func(context.Context) (net.Conn, error) {
	return func(dialCtx context.Context) (net.Conn, error) {
		// Try to connect to the new endpoint
		conn, err := net.DialTimeout("tcp", newEndpoint, 10*time.Second)
		if err != nil {
			internal.Logger.Printf(ctx, "hitless: failed to connect to new endpoint %s: %v", newEndpoint, err)

			// Fall back to the original dialer
			return originalDialer(dialCtx)
		}

		internal.Logger.Printf(ctx, "hitless: successfully connected to new endpoint %s", newEndpoint)
		return conn, nil
	}
}

// getPoolDialer extracts the dialer from a connection pool
func (m *PoolEndpointManager) getPoolDialer(connPool *pool.ConnPool) func(context.Context) (net.Conn, error) {
	return connPool.GetDialer()
}

// setPoolDialer sets a new dialer for a connection pool
func (m *PoolEndpointManager) setPoolDialer(connPool *pool.ConnPool, dialer func(context.Context) (net.Conn, error)) error {
	return connPool.SetDialer(dialer)
}

// extractEndpointFromDialer extracts the endpoint address from a dialer
func (m *PoolEndpointManager) extractEndpointFromDialer(dialer func(context.Context) (net.Conn, error)) string {
	// Try to extract the endpoint by making a test connection
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	conn, err := dialer(ctx)
	if err != nil {
		return "unknown"
	}
	defer conn.Close()

	if conn.RemoteAddr() != nil {
		return conn.RemoteAddr().String()
	}

	return "unknown"
}

// GetActiveRedirections returns all active redirections
func (m *PoolEndpointManager) GetActiveRedirections() map[interface{}]*EndpointRedirection {
	m.mu.RLock()
	defer m.mu.RUnlock()

	// Create copies to avoid race conditions
	redirections := make(map[interface{}]*EndpointRedirection)
	for pooler, redirection := range m.redirections {
		redirectionCopy := *redirection
		redirections[pooler] = &redirectionCopy
	}

	return redirections
}

// UpdateRedirectionStats updates statistics for a redirection
func (m *PoolEndpointManager) UpdateRedirectionStats(pooler pool.Pooler, newConnections, failedConnections int64) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if redirection, exists := m.redirections[pooler]; exists {
		redirection.NewConnections += newConnections
		redirection.FailedConnections += failedConnections
	}
}
```
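The dialer swap in `RedirectPool` relies on the try-new-endpoint-then-fall-back pattern from `createRedirectDialer`. A self-contained sketch of just that wrapper, using only the standard library (the package name, endpoint, and timeout values are illustrative):

```go
package sketch

import (
	"context"
	"net"
	"time"
)

// redirectDialer returns a dialer that prefers newEndpoint and falls back to the
// original dialer when the new endpoint is unreachable - the same shape as
// PoolEndpointManager.createRedirectDialer above.
func redirectDialer(newEndpoint string, original func(context.Context) (net.Conn, error)) func(context.Context) (net.Conn, error) {
	return func(ctx context.Context) (net.Conn, error) {
		d := net.Dialer{Timeout: 10 * time.Second}
		if conn, err := d.DialContext(ctx, "tcp", newEndpoint); err == nil {
			return conn, nil // new endpoint reachable
		}
		return original(ctx) // fall back to the old endpoint
	}
}
```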
internal/hitless/redis_integration.go (new file, 309 lines)
```go
package hitless

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/push"
)

// UpgradeStatus represents the current status of all upgrade operations
type UpgradeStatus struct {
	ConnectionStates map[*pool.Conn]*ConnectionState
	IsMoving         bool
	NewEndpoint      string
	Timestamp        time.Time
}

// UpgradeStatistics provides statistics about upgrade operations
type UpgradeStatistics struct {
	ActiveConnections      int
	IsMoving               bool
	MigratingConnections   int
	FailingOverConnections int
	Timestamp              time.Time
}

// RedisClientIntegration provides complete hitless upgrade integration for Redis clients
type RedisClientIntegration struct {
	clientIntegrator       *ClientIntegrator
	connectionStateTracker *ConnectionStateTracker
	config                 *HitlessUpgradeConfig

	mu      sync.RWMutex
	enabled bool
}

// NewRedisClientIntegration creates a new Redis client integration for hitless upgrades.
// It is used internally by the main hitless.go package.
func NewRedisClientIntegration(config *HitlessUpgradeConfig, defaultReadTimeout, defaultWriteTimeout time.Duration) *RedisClientIntegration {
	// Start with defaults
	defaults := DefaultHitlessUpgradeConfig()

	// If config is nil, use all defaults
	if config == nil {
		config = defaults
	} else {
		// Ensure all fields are set, falling back to defaults for zero values
		if config.TransitionTimeout == 0 {
			config = &HitlessUpgradeConfig{
				Enabled:           config.Enabled,
				TransitionTimeout: defaults.TransitionTimeout,
				CleanupInterval:   config.CleanupInterval,
			}
		}
		if config.CleanupInterval == 0 {
			config = &HitlessUpgradeConfig{
				Enabled:           config.Enabled,
				TransitionTimeout: config.TransitionTimeout,
				CleanupInterval:   defaults.CleanupInterval,
			}
		}
	}

	return &RedisClientIntegration{
		clientIntegrator:       NewClientIntegrator(defaultReadTimeout, defaultWriteTimeout),
		connectionStateTracker: NewConnectionStateTracker(defaultReadTimeout, defaultWriteTimeout),
		config:                 config,
		enabled:                config.Enabled,
	}
}

// EnableHitlessUpgrades enables hitless upgrade functionality
func (rci *RedisClientIntegration) EnableHitlessUpgrades() {
	rci.mu.Lock()
	defer rci.mu.Unlock()
	rci.enabled = true
}

// DisableHitlessUpgrades disables hitless upgrade functionality
func (rci *RedisClientIntegration) DisableHitlessUpgrades() {
	rci.mu.Lock()
	defer rci.mu.Unlock()
	rci.enabled = false
}

// IsEnabled returns whether hitless upgrades are enabled
func (rci *RedisClientIntegration) IsEnabled() bool {
	rci.mu.RLock()
	defer rci.mu.RUnlock()
	return rci.enabled
}

// No client registration is needed - each client has its own hitless integration instance.

// HandlePushNotification processes push notifications for hitless upgrades
func (rci *RedisClientIntegration) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	if !rci.IsEnabled() {
		// If disabled, just log and return without processing
		internal.Logger.Printf(ctx, "hitless: received notification but hitless upgrades are disabled")
		return nil
	}

	return rci.clientIntegrator.HandlePushNotification(ctx, handlerCtx, notification)
}

// IsConnectionTransitioning checks if a connection is currently transitioning
func (rci *RedisClientIntegration) IsConnectionTransitioning(conn *pool.Conn) bool {
	if !rci.IsEnabled() {
		return false
	}

	return rci.connectionStateTracker.IsConnectionTransitioning(conn)
}

// GetConnectionState returns the current state of a connection
func (rci *RedisClientIntegration) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) {
	if !rci.IsEnabled() {
		return nil, false
	}

	return rci.connectionStateTracker.GetConnectionState(conn)
}

// GetUpgradeStatus returns comprehensive status of all ongoing upgrades
func (rci *RedisClientIntegration) GetUpgradeStatus() *UpgradeStatus {
	connStates := rci.clientIntegrator.GetUpgradeHandler().GetActiveTransitions()

	return &UpgradeStatus{
		ConnectionStates: connStates,
		IsMoving:         rci.clientIntegrator.IsMoving(),
		NewEndpoint:      rci.clientIntegrator.GetNewEndpoint(),
		Timestamp:        time.Now(),
	}
}

// GetUpgradeStatistics returns statistics about upgrade operations
func (rci *RedisClientIntegration) GetUpgradeStatistics() *UpgradeStatistics {
	connStates := rci.clientIntegrator.GetUpgradeHandler().GetActiveTransitions()

	stats := &UpgradeStatistics{
		ActiveConnections: len(connStates),
		IsMoving:          rci.clientIntegrator.IsMoving(),
		Timestamp:         time.Now(),
	}

	// Count by transition type
	stats.MigratingConnections = 0
	stats.FailingOverConnections = 0
	for _, state := range connStates {
		switch state.TransitionType {
		case "MIGRATING":
			stats.MigratingConnections++
		case "FAILING_OVER":
			stats.FailingOverConnections++
		}
	}

	return stats
}

// GetConnectionTimeout returns the appropriate timeout for a connection.
// If the connection is transitioning (MIGRATING/FAILING_OVER), it returns the longer
// TransitionTimeout; otherwise it returns the provided defaultTimeout.
func (rci *RedisClientIntegration) GetConnectionTimeout(conn *pool.Conn, defaultTimeout time.Duration) time.Duration {
	if !rci.IsEnabled() {
		return defaultTimeout
	}

	// Check if the connection is transitioning
	if rci.connectionStateTracker.IsConnectionTransitioning(conn) {
		// Use the longer timeout for transitioning connections
		return rci.config.TransitionTimeout
	}

	return defaultTimeout
}

// GetConnectionTimeouts returns both read and write timeouts for a connection
func (rci *RedisClientIntegration) GetConnectionTimeouts(conn *pool.Conn, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, time.Duration) {
	if !rci.IsEnabled() {
		return defaultReadTimeout, defaultWriteTimeout
	}

	// Use the upgrade handler to get appropriate timeouts
	upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
	return upgradeHandler.GetConnectionTimeouts(conn, defaultReadTimeout, defaultWriteTimeout, rci.config.TransitionTimeout)
}

// MarkConnectionAsBlocking marks a connection as having blocking commands
func (rci *RedisClientIntegration) MarkConnectionAsBlocking(conn *pool.Conn, isBlocking bool) {
	if !rci.IsEnabled() {
		return
	}

	// Use the upgrade handler to mark the connection as blocking
	upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
	upgradeHandler.MarkConnectionAsBlocking(conn, isBlocking)
}

// IsConnectionMarkedForClosing checks if a connection should be closed
func (rci *RedisClientIntegration) IsConnectionMarkedForClosing(conn *pool.Conn) bool {
	if !rci.IsEnabled() {
		return false
	}

	// Use the upgrade handler to check if the connection is marked for closing
	upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
	return upgradeHandler.IsConnectionMarkedForClosing(conn)
}

// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected
func (rci *RedisClientIntegration) ShouldRedirectBlockingConnection(conn *pool.Conn) (bool, string) {
	if !rci.IsEnabled() {
		return false, ""
	}

	// Check the client integrator's atomic state for pool-level redirection
	if rci.clientIntegrator.IsMoving() {
		return true, rci.clientIntegrator.GetNewEndpoint()
	}

	// Check the specific connection state
	upgradeHandler := rci.clientIntegrator.GetUpgradeHandler()
	return upgradeHandler.ShouldRedirectBlockingConnection(conn, rci.clientIntegrator)
}

// CleanupConnection removes tracking for a connection (called when the connection is closed)
func (rci *RedisClientIntegration) CleanupConnection(conn *pool.Conn) {
	if rci.IsEnabled() {
		rci.connectionStateTracker.CleanupConnection(conn)
	}
}

// CleanupPool removed - no pool state to clean up

// Close shuts down the Redis client integration
func (rci *RedisClientIntegration) Close() error {
	rci.mu.Lock()
	defer rci.mu.Unlock()

	var firstErr error

	// Close all components
	if err := rci.clientIntegrator.Close(); err != nil && firstErr == nil {
		firstErr = err
	}

	// poolRedirector removed in the simplified implementation

	if err := rci.connectionStateTracker.Close(); err != nil && firstErr == nil {
		firstErr = err
	}

	rci.enabled = false

	return firstErr
}

// GetConfig returns the current configuration
func (rci *RedisClientIntegration) GetConfig() *HitlessUpgradeConfig {
	rci.mu.RLock()
	defer rci.mu.RUnlock()

	// Return a copy to prevent modification
	configCopy := *rci.config
	return &configCopy
}

// UpdateConfig updates the configuration
func (rci *RedisClientIntegration) UpdateConfig(config *HitlessUpgradeConfig) error {
	if config == nil {
		return fmt.Errorf("config cannot be nil")
	}

	rci.mu.Lock()
	defer rci.mu.Unlock()

	// Start with defaults for any zero values
	defaults := DefaultHitlessUpgradeConfig()

	// Ensure all fields are set, falling back to defaults for zero values
	enabled := config.Enabled
	transitionTimeout := config.TransitionTimeout
	cleanupInterval := config.CleanupInterval

	// Apply defaults for zero values
	if transitionTimeout == 0 {
		transitionTimeout = defaults.TransitionTimeout
	}
	if cleanupInterval == 0 {
		cleanupInterval = defaults.CleanupInterval
	}

	// Create a properly configured config
	finalConfig := &HitlessUpgradeConfig{
		Enabled:           enabled,
		TransitionTimeout: transitionTimeout,
		CleanupInterval:   cleanupInterval,
	}

	rci.config = finalConfig
	rci.enabled = finalConfig.Enabled

	return nil
}
```
internal/hitless/upgrade_handler.go (new file, 500 lines)
package hitless

import (
	"context"
	"fmt"
	"strconv"
	"sync"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/push"
)

// UpgradeHandler handles hitless upgrade push notifications for Redis cluster upgrades.
// It implements different strategies based on notification type:
//   - MOVING: Changes pool state to use new endpoint for future connections
//     - mark existing connections for closing, handle pubsub to change underlying connection, change pool dialer with the new endpoint
//
//   - MIGRATING/FAILING_OVER: Marks specific connection as in transition
//     - relaxing timeouts
//
//   - MIGRATED/FAILED_OVER: Clears transition state for specific connection
//     - return to original timeouts
type UpgradeHandler struct {
	mu sync.RWMutex

	// Connection-specific state for MIGRATING/FAILING_OVER notifications
	connectionStates map[*pool.Conn]*ConnectionState

	// Pool-level state removed - using atomic state in ClusterUpgradeManager instead

	// Client configuration for getting default timeouts
	defaultReadTimeout  time.Duration
	defaultWriteTimeout time.Duration
}

// ConnectionState tracks the state of a specific connection during upgrades
type ConnectionState struct {
	IsTransitioning bool
	TransitionType  string
	StartTime       time.Time
	ShardID         string
	TimeoutSeconds  int

	// Timeout management
	OriginalReadTimeout  time.Duration // Original read timeout
	OriginalWriteTimeout time.Duration // Original write timeout

	// MOVING state specific
	MarkedForClosing bool      // Connection should be closed after current commands
	IsBlocking       bool      // Connection has blocking commands
	NewEndpoint      string    // New endpoint for blocking commands
	LastCommandTime  time.Time // When the last command was sent
}

// PoolState removed - using atomic state in ClusterUpgradeManager instead

// NewUpgradeHandler creates a new hitless upgrade handler with client timeout configuration
func NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout time.Duration) *UpgradeHandler {
	return &UpgradeHandler{
		connectionStates:    make(map[*pool.Conn]*ConnectionState),
		defaultReadTimeout:  defaultReadTimeout,
		defaultWriteTimeout: defaultWriteTimeout,
	}
}

// GetConnectionTimeouts returns the appropriate read and write timeouts for a connection.
// If the connection is transitioning, returns increased timeouts.
func (h *UpgradeHandler) GetConnectionTimeouts(conn *pool.Conn, defaultReadTimeout, defaultWriteTimeout, transitionTimeout time.Duration) (time.Duration, time.Duration) {
	h.mu.RLock()
	defer h.mu.RUnlock()

	state, exists := h.connectionStates[conn]
	if !exists || !state.IsTransitioning {
		return defaultReadTimeout, defaultWriteTimeout
	}

	// For transitioning connections (MIGRATING/FAILING_OVER), use longer timeouts
	switch state.TransitionType {
	case "MIGRATING", "FAILING_OVER":
		return transitionTimeout, transitionTimeout
	case "MOVING":
		// For MOVING connections, use default timeouts but mark for special handling
		return defaultReadTimeout, defaultWriteTimeout
	default:
		return defaultReadTimeout, defaultWriteTimeout
	}
}
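A brief caller-side sketch of how the relaxed timeouts might be applied before a read; the concrete durations are placeholders and the deadline handling is only indicative.

// cn is a *pool.Conn, h the UpgradeHandler above; values are placeholders.
readTimeout, writeTimeout := h.GetConnectionTimeouts(cn,
	3*time.Second,  // normal read timeout
	3*time.Second,  // normal write timeout
	60*time.Second, // relaxed timeout while MIGRATING/FAILING_OVER
)
_ = writeTimeout
_ = readTimeout // e.g. used for the next read deadline on the connection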
// MarkConnectionForClosing marks a connection to be closed after current commands complete
func (h *UpgradeHandler) MarkConnectionForClosing(conn *pool.Conn, newEndpoint string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	state, exists := h.connectionStates[conn]
	if !exists {
		state = &ConnectionState{
			IsTransitioning: true,
			TransitionType:  "MOVING",
			StartTime:       time.Now(),
		}
		h.connectionStates[conn] = state
	}

	state.MarkedForClosing = true
	state.NewEndpoint = newEndpoint
	state.LastCommandTime = time.Now()
}

// IsConnectionMarkedForClosing checks if a connection should be closed
func (h *UpgradeHandler) IsConnectionMarkedForClosing(conn *pool.Conn) bool {
	h.mu.RLock()
	defer h.mu.RUnlock()

	state, exists := h.connectionStates[conn]
	return exists && state.MarkedForClosing
}

// MarkConnectionAsBlocking marks a connection as having blocking commands
func (h *UpgradeHandler) MarkConnectionAsBlocking(conn *pool.Conn, isBlocking bool) {
	h.mu.Lock()
	defer h.mu.Unlock()

	state, exists := h.connectionStates[conn]
	if !exists {
		state = &ConnectionState{
			IsTransitioning: false,
		}
		h.connectionStates[conn] = state
	}

	state.IsBlocking = isBlocking
}

// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected.
// Uses client integrator's atomic state for pool-level checks - minimal locking.
func (h *UpgradeHandler) ShouldRedirectBlockingConnection(conn *pool.Conn, clientIntegrator interface{}) (bool, string) {
	if conn != nil {
		// Check specific connection - need lock only for connection state
		h.mu.RLock()
		state, exists := h.connectionStates[conn]
		h.mu.RUnlock()

		if exists && state.IsBlocking && state.TransitionType == "MOVING" && state.NewEndpoint != "" {
			return true, state.NewEndpoint
		}
	}

	// Check client integrator's atomic state - no locks needed
	if ci, ok := clientIntegrator.(*ClientIntegrator); ok && ci != nil && ci.IsMoving() {
		return true, ci.GetNewEndpoint()
	}

	return false, ""
}

// ShouldRedirectNewBlockingConnection removed - functionality merged into ShouldRedirectBlockingConnection

// HandlePushNotification processes hitless upgrade push notifications
func (h *UpgradeHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	if len(notification) == 0 {
		return fmt.Errorf("hitless: empty notification received")
	}

	notificationType, ok := notification[0].(string)
	if !ok {
		return fmt.Errorf("hitless: notification type is not a string: %T", notification[0])
	}

	internal.Logger.Printf(ctx, "hitless: processing %s notification with %d elements", notificationType, len(notification))

	switch notificationType {
	case "MOVING":
		return h.handleMovingNotification(ctx, handlerCtx, notification)
	case "MIGRATING":
		return h.handleMigratingNotification(ctx, handlerCtx, notification)
	case "MIGRATED":
		return h.handleMigratedNotification(ctx, handlerCtx, notification)
	case "FAILING_OVER":
		return h.handleFailingOverNotification(ctx, handlerCtx, notification)
	case "FAILED_OVER":
		return h.handleFailedOverNotification(ctx, handlerCtx, notification)
	default:
		internal.Logger.Printf(ctx, "hitless: unknown notification type: %s", notificationType)
		return nil
	}
}
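For orientation, the decoded payload shapes this switch expects, written as Go literals; the endpoints, shard IDs and timings are invented for illustration.

examples := [][]interface{}{
	{"MOVING", int64(30), "10.0.0.5:6379"}, // pool-level: endpoint moves in ~30s
	{"MIGRATING", int64(15), "shard-1"},    // connection-level: relax timeouts
	{"MIGRATED", "shard-1"},                // connection-level: restore timeouts
	{"FAILING_OVER", int64(10), "shard-2"}, // connection-level: relax timeouts
	{"FAILED_OVER", "shard-2"},             // connection-level: restore timeouts
}
_ = examples // each slice would arrive as the notification argument above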
// handleMovingNotification processes MOVING notifications that affect the entire pool
// Format: ["MOVING", time_seconds, "new_endpoint"]
func (h *UpgradeHandler) handleMovingNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	if len(notification) < 3 {
		return fmt.Errorf("hitless: MOVING notification requires at least 3 elements, got %d", len(notification))
	}

	// Parse timeout
	timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
	if err != nil {
		return fmt.Errorf("hitless: failed to parse timeout for MOVING notification: %w", err)
	}

	// Parse new endpoint
	newEndpoint, ok := notification[2].(string)
	if !ok {
		return fmt.Errorf("hitless: new endpoint is not a string: %T", notification[2])
	}

	internal.Logger.Printf(ctx, "hitless: MOVING notification - endpoint will move to %s in %d seconds", newEndpoint, timeoutSeconds)
	h.mu.Lock()
	// Mark all existing connections for closing after current commands complete
	for _, state := range h.connectionStates {
		state.MarkedForClosing = true
		state.NewEndpoint = newEndpoint
		state.TransitionType = "MOVING"
		state.LastCommandTime = time.Now()
	}
	h.mu.Unlock()

	internal.Logger.Printf(ctx, "hitless: marked existing connections for closing, new blocking commands will use %s", newEndpoint)

	return nil
}

// Removed complex helper methods - simplified to direct inline logic

// handleMigratingNotification processes MIGRATING notifications for specific connections
// Format: ["MIGRATING", time_seconds, shard_id]
func (h *UpgradeHandler) handleMigratingNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	if len(notification) < 3 {
		return fmt.Errorf("hitless: MIGRATING notification requires at least 3 elements, got %d", len(notification))
	}

	timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
	if err != nil {
		return fmt.Errorf("hitless: failed to parse timeout for MIGRATING notification: %w", err)
	}

	shardID, err := h.parseShardID(notification[2])
	if err != nil {
		return fmt.Errorf("hitless: failed to parse shard ID for MIGRATING notification: %w", err)
	}

	conn := handlerCtx.Conn
	if conn == nil {
		return fmt.Errorf("hitless: no connection available in handler context")
	}

	internal.Logger.Printf(ctx, "hitless: MIGRATING notification - shard %s will migrate in %d seconds on connection %p", shardID, timeoutSeconds, conn)

	h.mu.Lock()
	defer h.mu.Unlock()

	// Store original timeouts if not already stored
	var originalReadTimeout, originalWriteTimeout time.Duration
	if existingState, exists := h.connectionStates[conn]; exists {
		originalReadTimeout = existingState.OriginalReadTimeout
		originalWriteTimeout = existingState.OriginalWriteTimeout
	} else {
		// Get default timeouts from client configuration
		originalReadTimeout = h.defaultReadTimeout
		originalWriteTimeout = h.defaultWriteTimeout
	}

	h.connectionStates[conn] = &ConnectionState{
		IsTransitioning:      true,
		TransitionType:       "MIGRATING",
		StartTime:            time.Now(),
		ShardID:              shardID,
		TimeoutSeconds:       timeoutSeconds,
		OriginalReadTimeout:  originalReadTimeout,
		OriginalWriteTimeout: originalWriteTimeout,
		LastCommandTime:      time.Now(),
	}

	internal.Logger.Printf(ctx, "hitless: connection %p marked as MIGRATING with increased timeouts", conn)

	return nil
}

// handleMigratedNotification processes MIGRATED notifications for specific connections
// Format: ["MIGRATED", shard_id]
func (h *UpgradeHandler) handleMigratedNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	if len(notification) < 2 {
		return fmt.Errorf("hitless: MIGRATED notification requires at least 2 elements, got %d", len(notification))
	}

	shardID, err := h.parseShardID(notification[1])
	if err != nil {
		return fmt.Errorf("hitless: failed to parse shard ID for MIGRATED notification: %w", err)
	}

	conn := handlerCtx.Conn
	if conn == nil {
		return fmt.Errorf("hitless: no connection available in handler context")
	}

	internal.Logger.Printf(ctx, "hitless: MIGRATED notification - shard %s migration completed on connection %p", shardID, conn)

	h.mu.Lock()
	defer h.mu.Unlock()

	// Clear the transitioning state for this connection and restore original timeouts
	if state, exists := h.connectionStates[conn]; exists && state.TransitionType == "MIGRATING" && state.ShardID == shardID {
		internal.Logger.Printf(ctx, "hitless: restoring original timeouts for connection %p (read: %v, write: %v)",
			conn, state.OriginalReadTimeout, state.OriginalWriteTimeout)

		// In a real implementation, this would restore the connection's original timeouts
		// For now, we'll just log and delete the state
		delete(h.connectionStates, conn)
		internal.Logger.Printf(ctx, "hitless: cleared MIGRATING state and restored timeouts for connection %p", conn)
	}

	return nil
}

// handleFailingOverNotification processes FAILING_OVER notifications for specific connections
// Format: ["FAILING_OVER", time_seconds, shard_id]
func (h *UpgradeHandler) handleFailingOverNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	if len(notification) < 3 {
		return fmt.Errorf("hitless: FAILING_OVER notification requires at least 3 elements, got %d", len(notification))
	}

	timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
	if err != nil {
		return fmt.Errorf("hitless: failed to parse timeout for FAILING_OVER notification: %w", err)
	}

	shardID, err := h.parseShardID(notification[2])
	if err != nil {
		return fmt.Errorf("hitless: failed to parse shard ID for FAILING_OVER notification: %w", err)
	}

	conn := handlerCtx.Conn
	if conn == nil {
		return fmt.Errorf("hitless: no connection available in handler context")
	}

	internal.Logger.Printf(ctx, "hitless: FAILING_OVER notification - shard %s will failover in %d seconds on connection %p", shardID, timeoutSeconds, conn)

	h.mu.Lock()
	defer h.mu.Unlock()

	// Store original timeouts if not already stored
	var originalReadTimeout, originalWriteTimeout time.Duration
	if existingState, exists := h.connectionStates[conn]; exists {
		originalReadTimeout = existingState.OriginalReadTimeout
		originalWriteTimeout = existingState.OriginalWriteTimeout
	} else {
		// Get default timeouts from client configuration
		originalReadTimeout = h.defaultReadTimeout
		originalWriteTimeout = h.defaultWriteTimeout
	}

	h.connectionStates[conn] = &ConnectionState{
		IsTransitioning:      true,
		TransitionType:       "FAILING_OVER",
		StartTime:            time.Now(),
		ShardID:              shardID,
		TimeoutSeconds:       timeoutSeconds,
		OriginalReadTimeout:  originalReadTimeout,
		OriginalWriteTimeout: originalWriteTimeout,
		LastCommandTime:      time.Now(),
	}

	internal.Logger.Printf(ctx, "hitless: connection %p marked as FAILING_OVER with increased timeouts", conn)

	return nil
}

// handleFailedOverNotification processes FAILED_OVER notifications for specific connections
// Format: ["FAILED_OVER", shard_id]
func (h *UpgradeHandler) handleFailedOverNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	if len(notification) < 2 {
		return fmt.Errorf("hitless: FAILED_OVER notification requires at least 2 elements, got %d", len(notification))
	}

	shardID, err := h.parseShardID(notification[1])
	if err != nil {
		return fmt.Errorf("hitless: failed to parse shard ID for FAILED_OVER notification: %w", err)
	}

	conn := handlerCtx.Conn
	if conn == nil {
		return fmt.Errorf("hitless: no connection available in handler context")
	}

	internal.Logger.Printf(ctx, "hitless: FAILED_OVER notification - shard %s failover completed on connection %p", shardID, conn)

	h.mu.Lock()
	defer h.mu.Unlock()

	// Clear the transitioning state for this connection and restore original timeouts
	if state, exists := h.connectionStates[conn]; exists && state.TransitionType == "FAILING_OVER" && state.ShardID == shardID {
		internal.Logger.Printf(ctx, "hitless: restoring original timeouts for connection %p (read: %v, write: %v)",
			conn, state.OriginalReadTimeout, state.OriginalWriteTimeout)

		// In a real implementation, this would restore the connection's original timeouts
		// For now, we'll just log and delete the state
		delete(h.connectionStates, conn)
		internal.Logger.Printf(ctx, "hitless: cleared FAILING_OVER state and restored timeouts for connection %p", conn)
	}

	return nil
}

// parseTimeoutSeconds parses timeout value from notification
func (h *UpgradeHandler) parseTimeoutSeconds(value interface{}) (int, error) {
	switch v := value.(type) {
	case int64:
		return int(v), nil
	case int:
		return v, nil
	case string:
		return strconv.Atoi(v)
	default:
		return 0, fmt.Errorf("unsupported timeout type: %T", value)
	}
}

// parseShardID parses shard ID from notification
func (h *UpgradeHandler) parseShardID(value interface{}) (string, error) {
	switch v := value.(type) {
	case string:
		return v, nil
	case int64:
		return strconv.FormatInt(v, 10), nil
	case int:
		return strconv.Itoa(v), nil
	default:
		return "", fmt.Errorf("unsupported shard ID type: %T", value)
	}
}

// GetConnectionState returns the current state of a connection
func (h *UpgradeHandler) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) {
	h.mu.RLock()
	defer h.mu.RUnlock()

	state, exists := h.connectionStates[conn]
	if !exists {
		return nil, false
	}

	// Return a copy to avoid race conditions
	stateCopy := *state
	return &stateCopy, true
}

// IsConnectionTransitioning checks if a connection is currently transitioning
func (h *UpgradeHandler) IsConnectionTransitioning(conn *pool.Conn) bool {
	h.mu.RLock()
	defer h.mu.RUnlock()

	state, exists := h.connectionStates[conn]
	return exists && state.IsTransitioning
}

// IsPoolMoving and GetNewEndpoint removed - using atomic state in ClusterUpgradeManager instead

// CleanupExpiredStates removes expired connection and pool states
func (h *UpgradeHandler) CleanupExpiredStates() {
	h.mu.Lock()
	defer h.mu.Unlock()

	now := time.Now()

	// Cleanup expired connection states
	for conn, state := range h.connectionStates {
		timeout := time.Duration(state.TimeoutSeconds) * time.Second
		if now.Sub(state.StartTime) > timeout {
			delete(h.connectionStates, conn)
		}
	}

	// Pool state cleanup removed - using atomic state in ClusterUpgradeManager instead
}

// CleanupConnection removes state for a specific connection (called when connection is closed)
func (h *UpgradeHandler) CleanupConnection(conn *pool.Conn) {
	h.mu.Lock()
	defer h.mu.Unlock()

	delete(h.connectionStates, conn)
}

// GetActiveTransitions returns information about all active connection transitions
func (h *UpgradeHandler) GetActiveTransitions() map[*pool.Conn]*ConnectionState {
	h.mu.RLock()
	defer h.mu.RUnlock()

	// Create copies to avoid race conditions
	connStates := make(map[*pool.Conn]*ConnectionState)
	for conn, state := range h.connectionStates {
		stateCopy := *state
		connStates[conn] = &stateCopy
	}

	return connStates
}
@ -565,3 +565,24 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
 	cn.SetUsedAt(now)
 	return true
 }
+
+// GetDialer returns the current dialer function for the pool
+func (p *ConnPool) GetDialer() func(context.Context) (net.Conn, error) {
+	p.connsMu.Lock()
+	defer p.connsMu.Unlock()
+	return p.cfg.Dialer
+}
+
+// SetDialer sets a new dialer function for the pool
+// This is used for hitless upgrades to redirect new connections to new endpoints
+func (p *ConnPool) SetDialer(dialer func(context.Context) (net.Conn, error)) error {
+	if p.closed() {
+		return ErrClosed
+	}
+
+	p.connsMu.Lock()
+	defer p.connsMu.Unlock()
+
+	p.cfg.Dialer = dialer
+	return nil
+}
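A sketch of how an upgrade manager could use the new setter when a MOVING notification announces a new endpoint; the address and the wrapper function are hypothetical, only GetDialer/SetDialer come from this change.

// Hypothetical helper: point future dials at the endpoint from a MOVING notification.
func redirectNewConnections(p *pool.ConnPool, newEndpoint string, dialTimeout time.Duration) error {
	newDialer := func(ctx context.Context) (net.Conn, error) {
		d := &net.Dialer{Timeout: dialTimeout}
		return d.DialContext(ctx, "tcp", newEndpoint) // e.g. "10.0.0.2:6379" (made up)
	}
	// Existing connections are unaffected; only connections dialed after this
	// call go to the new endpoint.
	return p.SetDialer(newDialer)
}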
@ -3,6 +3,7 @@ package proto
 import (
 	"bytes"
 	"fmt"
+	"math/rand"
 	"strings"
 	"testing"
 )
@ -215,9 +216,9 @@ func TestPeekPushNotificationName(t *testing.T) {
 			// This is acceptable behavior for malformed input
 			name, err := reader.PeekPushNotificationName()
 			if err != nil {
-				t.Logf("PeekPushNotificationName errored for corrupted data %s: %v", tc.name, err)
+				t.Logf("PeekPushNotificationName errored for corrupted data %s: %v (DATA: %s)", tc.name, err, tc.data)
 			} else {
-				t.Logf("PeekPushNotificationName returned '%s' for corrupted data %s", name, tc.name)
+				t.Logf("PeekPushNotificationName returned '%s' for corrupted data NAME: %s, DATA: %s", name, tc.name, tc.data)
 			}
 		})
 	}
@ -293,15 +294,27 @@ func TestPeekPushNotificationName(t *testing.T) {
 func createValidPushNotification(notificationName, data string) *bytes.Buffer {
 	buf := &bytes.Buffer{}
+
+	simpleOrString := rand.Intn(2) == 0
+
 	if data == "" {
 		// Single element notification
 		buf.WriteString(">1\r\n")
-		buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
+		if simpleOrString {
+			buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
+		} else {
+			buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
+		}
 	} else {
 		// Two element notification
 		buf.WriteString(">2\r\n")
-		buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
-		buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(data), data))
+		if simpleOrString {
+			buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
+			buf.WriteString(fmt.Sprintf("+%s\r\n", data))
+		} else {
+			buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
+			buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(data), data))
+		}
 	}
 
 	return buf
@ -116,26 +116,55 @@ func (r *Reader) PeekPushNotificationName() (string, error) {
 	if buf[0] != RespPush {
 		return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
 	}
-	// remove push notification type and length
-	buf = buf[2:]
+	if len(buf) < 3 {
+		return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
+	}
+
+	// remove push notification type
+	buf = buf[1:]
+	// remove first line - e.g. >2\r\n
 	for i := 0; i < len(buf)-1; i++ {
 		if buf[i] == '\r' && buf[i+1] == '\n' {
 			buf = buf[i+2:]
 			break
+		} else {
+			if buf[i] < '0' || buf[i] > '9' {
+				return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
+			}
 		}
 	}
+	if len(buf) < 2 {
+		return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
+	}
+	// next line should be $<length><string>\r\n or +<length><string>\r\n
 	// should have the type of the push notification name and it's length
-	if buf[0] != RespString {
+	if buf[0] != RespString && buf[0] != RespStatus {
 		return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
 	}
-	// skip the length of the string
-	for i := 0; i < len(buf)-1; i++ {
-		if buf[i] == '\r' && buf[i+1] == '\n' {
-			buf = buf[i+2:]
-			break
+	typeOfName := buf[0]
+	// remove the type of the push notification name
+	buf = buf[1:]
+	if typeOfName == RespString {
+		// remove the length of the string
+		if len(buf) < 2 {
+			return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
+		}
+		for i := 0; i < len(buf)-1; i++ {
+			if buf[i] == '\r' && buf[i+1] == '\n' {
+				buf = buf[i+2:]
+				break
+			} else {
+				if buf[i] < '0' || buf[i] > '9' {
+					return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
+				}
+			}
 		}
 	}
+
+	if len(buf) < 2 {
+		return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
+	}
 	// keep only the notification name
 	for i := 0; i < len(buf)-1; i++ {
 		if buf[i] == '\r' && buf[i+1] == '\n' {
@ -143,6 +172,7 @@ func (r *Reader) PeekPushNotificationName() (string, error) {
 			break
 		}
 	}
+
 	return util.BytesToString(buf), nil
 }
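To tie the reader change to the test helper above, a small sketch of the two frame encodings the peek now accepts; the endpoint is made up and the expected result is inferred from the parsing code rather than verified output.

// In package proto (test scope). Both frames should peek to the name "MOVING".
bulk := ">2\r\n$6\r\nMOVING\r\n$13\r\n10.0.0.5:6379\r\n" // name as bulk string ($)
simple := ">2\r\n+MOVING\r\n+10.0.0.5:6379\r\n"          // name as simple string (+)

name, err := NewReader(strings.NewReader(simple)).PeekPushNotificationName()
_ = bulk
_ = name // expected "MOVING" for well-formed input
_ = err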
12
options.go
@ -224,6 +224,18 @@ type Options struct {
 	// PushNotificationProcessor is the processor for handling push notifications.
 	// If nil, a default processor will be created for RESP3 connections.
 	PushNotificationProcessor push.NotificationProcessor
+
+	// HitlessUpgrades enables hitless upgrade functionality for cluster upgrades.
+	// Requires Protocol: 3 (RESP3) for push notifications.
+	// When enabled, the client will automatically handle cluster upgrade notifications
+	// and manage connection/pool state transitions seamlessly.
+	//
+	// default: false
+	HitlessUpgrades bool
+
+	// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
+	// If nil, default configuration will be used when HitlessUpgrades is true.
+	HitlessUpgradeConfig *HitlessUpgradeConfig
 }
 
 func (opt *Options) init() {
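A minimal sketch of opting in from application code, assuming the defaults defined in hitless.go; the address and the 90-second override are placeholders.

rdb := redis.NewClient(&redis.Options{
	Addr:            "localhost:6379", // placeholder
	Protocol:        3,                // RESP3 is required for push notifications
	HitlessUpgrades: true,
	HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
		Enabled:           true,
		TransitionTimeout: 90 * time.Second, // zero fields fall back to defaults
	},
})
defer rdb.Close()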
@ -110,6 +110,18 @@ type ClusterOptions struct {
 
 	// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
 	UnstableResp3 bool
+
+	// HitlessUpgrades enables hitless upgrade functionality for cluster upgrades.
+	// Requires Protocol: 3 (RESP3) for push notifications.
+	// When enabled, the client will automatically handle cluster upgrade notifications
+	// and manage connection/pool state transitions seamlessly.
+	//
+	// default: false
+	HitlessUpgrades bool
+
+	// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
+	// If nil, default configuration will be used when HitlessUpgrades is true.
+	HitlessUpgradeConfig *HitlessUpgradeConfig
 }
 
 func (opt *ClusterOptions) init() {
@ -327,8 +339,10 @@ func (opt *ClusterOptions) clientOptions() *Options {
 		// much use for ClusterSlots config). This means we cannot execute the
 		// READONLY command against that node -- setting readOnly to false in such
 		// situations in the options below will prevent that from happening.
 		readOnly:      opt.ReadOnly && opt.ClusterSlots == nil,
 		UnstableResp3: opt.UnstableResp3,
+
+		HitlessUpgrades:      opt.HitlessUpgrades,
+		HitlessUpgradeConfig: opt.HitlessUpgradeConfig,
 	}
 }
@ -943,6 +957,9 @@ type ClusterClient struct {
 	cmdsInfoCache *cmdsInfoCache
 	cmdable
 	hooksMixin
+
+	// hitlessIntegration provides hitless upgrade functionality
+	hitlessIntegration HitlessIntegration
 }
 
 // NewClusterClient returns a Redis Cluster client as described in
@ -969,6 +986,22 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
 		txPipeline: c.processTxPipeline,
 	})
+
+	// Initialize hitless upgrades if enabled
+	if opt.HitlessUpgrades {
+		if opt.Protocol != 3 {
+			internal.Logger.Printf(context.Background(), "hitless: RESP3 protocol required for hitless upgrades, but Protocol is %d", opt.Protocol)
+		} else {
+			timeoutProvider := newOptionsTimeoutProvider(opt.ReadTimeout, opt.WriteTimeout)
+			integration, err := initializeHitlessIntegration(c, opt.HitlessUpgradeConfig, timeoutProvider)
+			if err != nil {
+				internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
+			} else {
+				c.hitlessIntegration = integration
+				internal.Logger.Printf(context.Background(), "hitless: successfully initialized hitless upgrades for cluster client")
+			}
+		}
+	}
 
 	return c
 }
@ -977,6 +1010,14 @@ func (c *ClusterClient) Options() *ClusterOptions {
 	return c.opt
 }
 
+// GetHitlessIntegration returns the hitless integration instance for monitoring and control.
+// Returns nil if hitless upgrades are not enabled.
+func (c *ClusterClient) GetHitlessIntegration() HitlessIntegration {
+	return c.hitlessIntegration
+}
+
+// getPushNotificationProcessor removed - not needed in simplified implementation
+
 // ReloadState reloads cluster state. If available it calls ClusterSlots func
 // to get cluster slots information.
 func (c *ClusterClient) ReloadState(ctx context.Context) {
@ -988,6 +1029,13 @@ func (c *ClusterClient) ReloadState(ctx context.Context) {
 // It is rare to Close a ClusterClient, as the ClusterClient is meant
 // to be long-lived and shared between many goroutines.
 func (c *ClusterClient) Close() error {
+	// Close hitless integration first
+	if c.hitlessIntegration != nil {
+		if err := c.hitlessIntegration.Close(); err != nil {
+			internal.Logger.Printf(context.Background(), "hitless: error closing hitless integration: %v", err)
+		}
+	}
+
 	return c.nodes.Close()
 }
@ -13,9 +13,6 @@ type NotificationHandlerContext struct {
 	// circular dependencies. The developer is responsible for type assertion.
 	// It can be one of the following types:
 	// - *redis.baseClient
-	// - *redis.Client
-	// - *redis.ClusterClient
-	// - *redis.Conn
 	Client interface{}
 
 	// ConnPool is the connection pool from which the connection was obtained.
@ -25,7 +22,7 @@ type NotificationHandlerContext struct {
 	// - *pool.ConnPool
 	// - *pool.SingleConnPool
 	// - *pool.StickyConnPool
-	ConnPool interface{}
+	ConnPool pool.Pooler
 
 	// PubSub is the PubSub instance that received the notification.
 	// It is interface to both allow for future expansion and to avoid
193
redis.go
@ -2,6 +2,7 @@ package redis
 
 import (
 	"context"
+	"crypto/tls"
 	"errors"
 	"fmt"
 	"net"
@ -211,6 +212,9 @@ type baseClient struct {
 
 	// Push notification processing
 	pushProcessor push.NotificationProcessor
+
+	// hitlessIntegration provides hitless upgrade functionality
+	hitlessIntegration HitlessIntegration
 }
 
 func (c *baseClient) clone() *baseClient {
@ -466,7 +470,16 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
 			// Push notification processing errors shouldn't break normal Redis operations
 			internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err)
 		}
-		c.connPool.Put(ctx, cn)
+
+		// Check if connection is marked for closing due to hitless upgrades
+		if c.hitlessIntegration != nil && c.hitlessIntegration.IsConnectionMarkedForClosing(cn) {
+			// Connection is marked for closing (e.g., during MOVING state)
+			// Remove it instead of putting it back in the pool
+			internal.Logger.Printf(ctx, "hitless: closing connection marked for closure during upgrade")
+			c.connPool.Remove(ctx, cn, nil)
+		} else {
+			c.connPool.Put(ctx, cn)
+		}
 	}
 }
 
@ -528,7 +541,33 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
 	}
 
 	retryTimeout := uint32(0)
-	if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
+	// Check if this is a blocking command that needs redirection
+	isBlockingCommand := cmd.readTimeout() != nil
+	var redirectedConn *pool.Conn
+	var shouldRedirect bool
+	var newEndpoint string
+
+	if c.hitlessIntegration != nil && isBlockingCommand {
+		// For blocking commands, check if we need to redirect to a new endpoint
+		// This happens during MOVING state when the endpoint is changing
+		shouldRedirect, newEndpoint = c.hitlessIntegration.ShouldRedirectBlockingConnection(nil)
+		if shouldRedirect {
+			// Create a new connection to the new endpoint
+			var err error
+			redirectedConn, err = c.createConnectionToEndpoint(ctx, newEndpoint)
+			if err != nil {
+				internal.Logger.Printf(ctx, "hitless: failed to create redirected connection to %s: %v", newEndpoint, err)
+				// Fall back to normal connection if redirection fails
+				shouldRedirect = false
+			} else {
+				internal.Logger.Printf(ctx, "hitless: redirecting blocking command %s to new endpoint %s", cmd.Name(), newEndpoint)
+			}
+		}
+	}
+
+	// Use redirected connection if available, otherwise use normal connection
+	connFunc := func(ctx context.Context, cn *pool.Conn) error {
 		// Process any pending push notifications before executing the command
 		if err := c.processPushNotifications(ctx, cn); err != nil {
 			// Log the error but don't fail the command execution
@ -536,7 +575,19 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
 			internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err)
 		}
 
-		if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
+		// Mark connection as blocking if this is a blocking command
+		if c.hitlessIntegration != nil && isBlockingCommand {
+			c.hitlessIntegration.MarkConnectionAsBlocking(cn, true)
+			internal.Logger.Printf(ctx, "hitless: marked connection as blocking for command %s", cmd.Name())
+		}
+
+		// Get appropriate write timeout for this connection
+		writeTimeout := c.opt.WriteTimeout
+		if c.hitlessIntegration != nil {
+			_, writeTimeout = c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
+		}
+
+		if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
 			return writeCmd(wr, cmd)
 		}); err != nil {
 			atomic.StoreUint32(&retryTimeout, 1)
@ -547,7 +598,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
 		if c.opt.Protocol != 2 && c.assertUnstableCommand(cmd) {
 			readReplyFunc = cmd.readRawReply
 		}
-		if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error {
+		if err := cn.WithReader(c.context(ctx), c.cmdTimeoutForConnection(cmd, cn), func(rd *proto.Reader) error {
 			// To be sure there are no buffered push notifications, we process them before reading the reply
 			if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
 				// Log the error but don't fail the command execution
@ -564,8 +615,31 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
 			return err
 		}
+
+		// Unmark connection as blocking after command completes
+		if c.hitlessIntegration != nil && isBlockingCommand {
+			c.hitlessIntegration.MarkConnectionAsBlocking(cn, false)
+			internal.Logger.Printf(ctx, "hitless: unmarked connection as blocking after command %s completed", cmd.Name())
+		}
+
 		return nil
-	}); err != nil {
+	}
+
+	// Execute the command with either redirected or normal connection
+	var err error
+	if shouldRedirect && redirectedConn != nil {
+		// Use the redirected connection for blocking command
+		err = connFunc(ctx, redirectedConn)
+		// Close the redirected connection after use
+		defer func() {
+			redirectedConn.Close()
+			internal.Logger.Printf(ctx, "hitless: closed redirected connection to %s after blocking command completed", newEndpoint)
+		}()
+	} else {
+		// Use normal connection pool
+		err = c.withConn(ctx, connFunc)
+	}
+
+	if err != nil {
 		retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
 		return retry, err
 	}
@ -588,6 +662,70 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
 	return c.opt.ReadTimeout
 }
 
+// cmdTimeoutForConnection returns the appropriate read timeout for a specific connection
+// taking into account hitless upgrade state
+func (c *baseClient) cmdTimeoutForConnection(cmd Cmder, cn *pool.Conn) time.Duration {
+	baseTimeout := c.cmdTimeout(cmd)
+
+	// If hitless upgrades are enabled, get dynamic timeout based on connection state
+	if c.hitlessIntegration != nil {
+		// For blocking commands, use the command's timeout but check if connection needs increased timeout
+		if cmd.readTimeout() != nil {
+			// For blocking commands, use the base timeout but apply hitless upgrade adjustments
+			adjustedTimeout := c.hitlessIntegration.GetConnectionTimeout(cn, baseTimeout)
+			return adjustedTimeout
+		} else {
+			// For regular commands, get both read and write timeouts (use read timeout for command)
+			readTimeout, _ := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
+			return readTimeout
+		}
+	}
+
+	return baseTimeout
+}
+
+// createConnectionToEndpoint creates a new connection to a specific endpoint
+// This is used for redirecting blocking commands during MOVING state
+func (c *baseClient) createConnectionToEndpoint(ctx context.Context, endpoint string) (*pool.Conn, error) {
+	// Parse the endpoint to get host and port
+	addr := endpoint
+	if addr == "" {
+		return nil, fmt.Errorf("empty endpoint provided")
+	}
+
+	// Create a temporary dialer for the new endpoint
+	dialer := func(ctx context.Context) (net.Conn, error) {
+		netDialer := &net.Dialer{
+			Timeout:   c.opt.DialTimeout,
+			KeepAlive: 30 * time.Second,
+		}
+
+		if c.opt.TLSConfig == nil {
+			return netDialer.DialContext(ctx, c.opt.Network, addr)
+		}
+
+		return tls.DialWithDialer(netDialer, c.opt.Network, addr, c.opt.TLSConfig)
+	}
+
+	// Create a new connection using the dialer
+	netConn, err := dialer(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to dial new endpoint %s: %w", endpoint, err)
+	}
+
+	// Wrap in pool.Conn
+	cn := pool.NewConn(netConn)
+
+	// Initialize the connection (auth, select db, etc.)
+	if err := c.initConn(ctx, cn); err != nil {
+		cn.Close()
+		return nil, fmt.Errorf("failed to initialize connection to %s: %w", endpoint, err)
+	}
+
+	internal.Logger.Printf(ctx, "hitless: created new connection to endpoint %s for blocking command redirection", endpoint)
+	return cn, nil
+}
+
 // context returns the context for the current connection.
 // If the context timeout is enabled, it returns the original context.
 // Otherwise, it returns a new background context.
@ -604,8 +742,19 @@ func (c *baseClient) context(ctx context.Context) context.Context {
 // long-lived and shared between many goroutines.
 func (c *baseClient) Close() error {
 	var firstErr error
+
+	// Close hitless integration first
+	if c.hitlessIntegration != nil {
+		if err := c.hitlessIntegration.Close(); err != nil {
+			internal.Logger.Printf(context.Background(), "hitless: error closing hitless integration: %v", err)
+			if firstErr == nil {
+				firstErr = err
+			}
+		}
+	}
+
 	if c.onClose != nil {
-		if err := c.onClose(); err != nil {
+		if err := c.onClose(); err != nil && firstErr == nil {
 			firstErr = err
 		}
 	}
@ -678,14 +827,15 @@ func (c *baseClient) pipelineProcessCmds(
 		internal.Logger.Printf(ctx, "push: error processing pending notifications before pipeline: %v", err)
 	}
 
-	if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
+	readTimeout, writeTimeout := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
+	if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
 		return writeCmds(wr, cmds)
 	}); err != nil {
 		setCmdsErr(cmds, err)
 		return true, err
 	}
 
-	if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
+	if err := cn.WithReader(c.context(ctx), readTimeout, func(rd *proto.Reader) error {
 		// read all replies
 		return c.pipelineReadCmds(ctx, cn, rd, cmds)
 	}); err != nil {
@ -725,14 +875,15 @@ func (c *baseClient) txPipelineProcessCmds(
 		internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err)
 	}
 
-	if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
+	readTimeout, writeTimeout := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
+	if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
 		return writeCmds(wr, cmds)
 	}); err != nil {
 		setCmdsErr(cmds, err)
 		return true, err
 	}
 
-	if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
+	if err := cn.WithReader(c.context(ctx), readTimeout, func(rd *proto.Reader) error {
 		statusCmd := cmds[0].(*StatusCmd)
 		// Trim multi and exec.
 		trimmedCmds := cmds[1 : len(cmds)-1]
@ -837,6 +988,22 @@ func NewClient(opt *Options) *Client {
 
 	c.connPool = newConnPool(opt, c.dialHook)
+
+	// Initialize hitless upgrades if enabled
+	if opt.HitlessUpgrades {
+		if opt.Protocol != 3 {
+			internal.Logger.Printf(context.Background(), "hitless: RESP3 protocol required for hitless upgrades, but Protocol is %d", opt.Protocol)
+		} else {
+			timeoutProvider := newOptionsTimeoutProvider(opt.ReadTimeout, opt.WriteTimeout)
+			integration, err := initializeHitlessIntegration(&c, opt.HitlessUpgradeConfig, timeoutProvider)
+			if err != nil {
+				internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
+			} else {
+				c.hitlessIntegration = integration
+				internal.Logger.Printf(context.Background(), "hitless: successfully initialized hitless upgrades for client")
+			}
+		}
+	}
 
 	return &c
 }
@ -857,6 +1024,12 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
 	return &clone
 }
 
+// GetHitlessIntegration returns the hitless integration instance for monitoring and control.
+// Returns nil if hitless upgrades are not enabled.
+func (c *Client) GetHitlessIntegration() HitlessIntegration {
+	return c.hitlessIntegration
+}
+
 func (c *Client) Conn() *Conn {
 	return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin)
 }