wip.

internal/hitless/upgrade_handler.go (new file, 500 lines)
@@ -0,0 +1,500 @@
package hitless

import (
    "context"
    "fmt"
    "strconv"
    "sync"
    "time"

    "github.com/redis/go-redis/v9/internal"
    "github.com/redis/go-redis/v9/internal/pool"
    "github.com/redis/go-redis/v9/push"
)

// UpgradeHandler handles hitless upgrade push notifications for Redis cluster upgrades.
// It applies a different strategy per notification type:
//   - MOVING: changes pool state so future connections use the new endpoint
//     (marks existing connections for closing, moves pubsub onto a new underlying
//     connection, and points the pool dialer at the new endpoint)
//   - MIGRATING/FAILING_OVER: marks the specific connection as in transition
//     (relaxes its timeouts)
//   - MIGRATED/FAILED_OVER: clears the transition state for the specific connection
//     (restores its original timeouts)
type UpgradeHandler struct {
    mu sync.RWMutex

    // Connection-specific state for MIGRATING/FAILING_OVER notifications
    connectionStates map[*pool.Conn]*ConnectionState

    // Pool-level state removed - using atomic state in ClusterUpgradeManager instead

    // Client configuration for getting default timeouts
    defaultReadTimeout  time.Duration
    defaultWriteTimeout time.Duration
}
// ConnectionState tracks the state of a specific connection during upgrades
type ConnectionState struct {
    IsTransitioning bool
    TransitionType  string
    StartTime       time.Time
    ShardID         string
    TimeoutSeconds  int

    // Timeout management
    OriginalReadTimeout  time.Duration // Original read timeout
    OriginalWriteTimeout time.Duration // Original write timeout

    // MOVING state specific
    MarkedForClosing bool      // Connection should be closed after current commands
    IsBlocking       bool      // Connection has blocking commands
    NewEndpoint      string    // New endpoint for blocking commands
    LastCommandTime  time.Time // When the last command was sent
}

// PoolState removed - using atomic state in ClusterUpgradeManager instead

// NewUpgradeHandler creates a new hitless upgrade handler with client timeout configuration
func NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout time.Duration) *UpgradeHandler {
    return &UpgradeHandler{
        connectionStates:    make(map[*pool.Conn]*ConnectionState),
        defaultReadTimeout:  defaultReadTimeout,
        defaultWriteTimeout: defaultWriteTimeout,
    }
}
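// A minimal construction sketch (illustrative only; the 3-second values are
// placeholders, not defaults defined by this package):
//
//     handler := NewUpgradeHandler(3*time.Second, 3*time.Second)
//     // register handler as the processor for hitless upgrade push notifications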
// GetConnectionTimeouts returns the appropriate read and write timeouts for a connection.
// If the connection is transitioning, returns increased timeouts.
func (h *UpgradeHandler) GetConnectionTimeouts(conn *pool.Conn, defaultReadTimeout, defaultWriteTimeout, transitionTimeout time.Duration) (time.Duration, time.Duration) {
    h.mu.RLock()
    defer h.mu.RUnlock()

    state, exists := h.connectionStates[conn]
    if !exists || !state.IsTransitioning {
        return defaultReadTimeout, defaultWriteTimeout
    }

    // For transitioning connections (MIGRATING/FAILING_OVER), use longer timeouts
    switch state.TransitionType {
    case "MIGRATING", "FAILING_OVER":
        return transitionTimeout, transitionTimeout
    case "MOVING":
        // For MOVING connections, use default timeouts but mark for special handling
        return defaultReadTimeout, defaultWriteTimeout
    default:
        return defaultReadTimeout, defaultWriteTimeout
    }
}
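// Illustrative call site (not part of this commit): before issuing a command,
// a client could ask the handler which deadlines to apply. The durations here
// are placeholders.
//
//     readTimeout, writeTimeout := handler.GetConnectionTimeouts(
//         cn, 3*time.Second, 3*time.Second, 30*time.Second)
//     // apply readTimeout/writeTimeout as the deadlines for the next command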
// MarkConnectionForClosing marks a connection to be closed after current commands complete
func (h *UpgradeHandler) MarkConnectionForClosing(conn *pool.Conn, newEndpoint string) {
    h.mu.Lock()
    defer h.mu.Unlock()

    state, exists := h.connectionStates[conn]
    if !exists {
        state = &ConnectionState{
            IsTransitioning: true,
            TransitionType:  "MOVING",
            StartTime:       time.Now(),
        }
        h.connectionStates[conn] = state
    }

    state.MarkedForClosing = true
    state.NewEndpoint = newEndpoint
    state.LastCommandTime = time.Now()
}

// IsConnectionMarkedForClosing checks if a connection should be closed
func (h *UpgradeHandler) IsConnectionMarkedForClosing(conn *pool.Conn) bool {
    h.mu.RLock()
    defer h.mu.RUnlock()

    state, exists := h.connectionStates[conn]
    return exists && state.MarkedForClosing
}

// MarkConnectionAsBlocking marks a connection as having blocking commands
func (h *UpgradeHandler) MarkConnectionAsBlocking(conn *pool.Conn, isBlocking bool) {
    h.mu.Lock()
    defer h.mu.Unlock()

    state, exists := h.connectionStates[conn]
    if !exists {
        state = &ConnectionState{
            IsTransitioning: false,
        }
        h.connectionStates[conn] = state
    }

    state.IsBlocking = isBlocking
}

// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected.
// Uses the client integrator's atomic state for pool-level checks - minimal locking.
func (h *UpgradeHandler) ShouldRedirectBlockingConnection(conn *pool.Conn, clientIntegrator interface{}) (bool, string) {
    if conn != nil {
        // Check specific connection - need lock only for connection state
        h.mu.RLock()
        state, exists := h.connectionStates[conn]
        h.mu.RUnlock()

        if exists && state.IsBlocking && state.TransitionType == "MOVING" && state.NewEndpoint != "" {
            return true, state.NewEndpoint
        }
    }

    // Check client integrator's atomic state - no locks needed
    if ci, ok := clientIntegrator.(*ClientIntegrator); ok && ci != nil && ci.IsMoving() {
        return true, ci.GetNewEndpoint()
    }

    return false, ""
}
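// Sketch of how a caller might route a new blocking command (illustrative;
// the surrounding dial logic is assumed, not defined here):
//
//     if redirect, endpoint := handler.ShouldRedirectBlockingConnection(cn, integrator); redirect {
//         // dial endpoint instead of the current address for this blocking command
//         _ = endpoint
//     }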
// ShouldRedirectNewBlockingConnection removed - functionality merged into ShouldRedirectBlockingConnection

// HandlePushNotification processes hitless upgrade push notifications
func (h *UpgradeHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
    if len(notification) == 0 {
        return fmt.Errorf("hitless: empty notification received")
    }

    notificationType, ok := notification[0].(string)
    if !ok {
        return fmt.Errorf("hitless: notification type is not a string: %T", notification[0])
    }

    internal.Logger.Printf(ctx, "hitless: processing %s notification with %d elements", notificationType, len(notification))

    switch notificationType {
    case "MOVING":
        return h.handleMovingNotification(ctx, handlerCtx, notification)
    case "MIGRATING":
        return h.handleMigratingNotification(ctx, handlerCtx, notification)
    case "MIGRATED":
        return h.handleMigratedNotification(ctx, handlerCtx, notification)
    case "FAILING_OVER":
        return h.handleFailingOverNotification(ctx, handlerCtx, notification)
    case "FAILED_OVER":
        return h.handleFailedOverNotification(ctx, handlerCtx, notification)
    default:
        internal.Logger.Printf(ctx, "hitless: unknown notification type: %s", notificationType)
        return nil
    }
}
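// Example payloads this dispatcher expects, based on the Format comments below
// (the concrete values are illustrative):
//
//     ["MOVING", 30, "10.0.0.5:6379"] -> handleMovingNotification
//     ["MIGRATING", 15, "shard-1"]    -> handleMigratingNotification
//     ["MIGRATED", "shard-1"]         -> handleMigratedNotification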
// handleMovingNotification processes MOVING notifications that affect the entire pool.
// Format: ["MOVING", time_seconds, "new_endpoint"]
func (h *UpgradeHandler) handleMovingNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
    if len(notification) < 3 {
        return fmt.Errorf("hitless: MOVING notification requires at least 3 elements, got %d", len(notification))
    }

    // Parse timeout
    timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
    if err != nil {
        return fmt.Errorf("hitless: failed to parse timeout for MOVING notification: %w", err)
    }

    // Parse new endpoint
    newEndpoint, ok := notification[2].(string)
    if !ok {
        return fmt.Errorf("hitless: new endpoint is not a string: %T", notification[2])
    }

    internal.Logger.Printf(ctx, "hitless: MOVING notification - endpoint will move to %s in %d seconds", newEndpoint, timeoutSeconds)

    h.mu.Lock()
    // Mark all tracked connections for closing after their current commands complete
    for _, state := range h.connectionStates {
        state.MarkedForClosing = true
        state.NewEndpoint = newEndpoint
        state.TransitionType = "MOVING"
        state.LastCommandTime = time.Now()
    }
    h.mu.Unlock()

    internal.Logger.Printf(ctx, "hitless: marked existing connections for closing, new blocking commands will use %s", newEndpoint)

    return nil
}
// Removed complex helper methods - simplified to direct inline logic

// handleMigratingNotification processes MIGRATING notifications for specific connections.
// Format: ["MIGRATING", time_seconds, shard_id]
func (h *UpgradeHandler) handleMigratingNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
    if len(notification) < 3 {
        return fmt.Errorf("hitless: MIGRATING notification requires at least 3 elements, got %d", len(notification))
    }

    timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
    if err != nil {
        return fmt.Errorf("hitless: failed to parse timeout for MIGRATING notification: %w", err)
    }

    shardID, err := h.parseShardID(notification[2])
    if err != nil {
        return fmt.Errorf("hitless: failed to parse shard ID for MIGRATING notification: %w", err)
    }

    conn := handlerCtx.Conn
    if conn == nil {
        return fmt.Errorf("hitless: no connection available in handler context")
    }

    internal.Logger.Printf(ctx, "hitless: MIGRATING notification - shard %s will migrate in %d seconds on connection %p", shardID, timeoutSeconds, conn)

    h.mu.Lock()
    defer h.mu.Unlock()

    // Store original timeouts if not already stored
    var originalReadTimeout, originalWriteTimeout time.Duration
    if existingState, exists := h.connectionStates[conn]; exists {
        originalReadTimeout = existingState.OriginalReadTimeout
        originalWriteTimeout = existingState.OriginalWriteTimeout
    } else {
        // Get default timeouts from client configuration
        originalReadTimeout = h.defaultReadTimeout
        originalWriteTimeout = h.defaultWriteTimeout
    }

    h.connectionStates[conn] = &ConnectionState{
        IsTransitioning:      true,
        TransitionType:       "MIGRATING",
        StartTime:            time.Now(),
        ShardID:              shardID,
        TimeoutSeconds:       timeoutSeconds,
        OriginalReadTimeout:  originalReadTimeout,
        OriginalWriteTimeout: originalWriteTimeout,
        LastCommandTime:      time.Now(),
    }

    internal.Logger.Printf(ctx, "hitless: connection %p marked as MIGRATING with increased timeouts", conn)

    return nil
}
// handleMigratedNotification processes MIGRATED notifications for specific connections.
// Format: ["MIGRATED", shard_id]
func (h *UpgradeHandler) handleMigratedNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
    if len(notification) < 2 {
        return fmt.Errorf("hitless: MIGRATED notification requires at least 2 elements, got %d", len(notification))
    }

    shardID, err := h.parseShardID(notification[1])
    if err != nil {
        return fmt.Errorf("hitless: failed to parse shard ID for MIGRATED notification: %w", err)
    }

    conn := handlerCtx.Conn
    if conn == nil {
        return fmt.Errorf("hitless: no connection available in handler context")
    }

    internal.Logger.Printf(ctx, "hitless: MIGRATED notification - shard %s migration completed on connection %p", shardID, conn)

    h.mu.Lock()
    defer h.mu.Unlock()

    // Clear the transitioning state for this connection and restore original timeouts
    if state, exists := h.connectionStates[conn]; exists && state.TransitionType == "MIGRATING" && state.ShardID == shardID {
        internal.Logger.Printf(ctx, "hitless: restoring original timeouts for connection %p (read: %v, write: %v)",
            conn, state.OriginalReadTimeout, state.OriginalWriteTimeout)

        // In a real implementation, this would restore the connection's original timeouts
        // For now, we'll just log and delete the state
        delete(h.connectionStates, conn)
        internal.Logger.Printf(ctx, "hitless: cleared MIGRATING state and restored timeouts for connection %p", conn)
    }

    return nil
}
// handleFailingOverNotification processes FAILING_OVER notifications for specific connections.
// Format: ["FAILING_OVER", time_seconds, shard_id]
func (h *UpgradeHandler) handleFailingOverNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
    if len(notification) < 3 {
        return fmt.Errorf("hitless: FAILING_OVER notification requires at least 3 elements, got %d", len(notification))
    }

    timeoutSeconds, err := h.parseTimeoutSeconds(notification[1])
    if err != nil {
        return fmt.Errorf("hitless: failed to parse timeout for FAILING_OVER notification: %w", err)
    }

    shardID, err := h.parseShardID(notification[2])
    if err != nil {
        return fmt.Errorf("hitless: failed to parse shard ID for FAILING_OVER notification: %w", err)
    }

    conn := handlerCtx.Conn
    if conn == nil {
        return fmt.Errorf("hitless: no connection available in handler context")
    }

    internal.Logger.Printf(ctx, "hitless: FAILING_OVER notification - shard %s will failover in %d seconds on connection %p", shardID, timeoutSeconds, conn)

    h.mu.Lock()
    defer h.mu.Unlock()

    // Store original timeouts if not already stored
    var originalReadTimeout, originalWriteTimeout time.Duration
    if existingState, exists := h.connectionStates[conn]; exists {
        originalReadTimeout = existingState.OriginalReadTimeout
        originalWriteTimeout = existingState.OriginalWriteTimeout
    } else {
        // Get default timeouts from client configuration
        originalReadTimeout = h.defaultReadTimeout
        originalWriteTimeout = h.defaultWriteTimeout
    }

    h.connectionStates[conn] = &ConnectionState{
        IsTransitioning:      true,
        TransitionType:       "FAILING_OVER",
        StartTime:            time.Now(),
        ShardID:              shardID,
        TimeoutSeconds:       timeoutSeconds,
        OriginalReadTimeout:  originalReadTimeout,
        OriginalWriteTimeout: originalWriteTimeout,
        LastCommandTime:      time.Now(),
    }

    internal.Logger.Printf(ctx, "hitless: connection %p marked as FAILING_OVER with increased timeouts", conn)

    return nil
}
// handleFailedOverNotification processes FAILED_OVER notifications for specific connections.
// Format: ["FAILED_OVER", shard_id]
func (h *UpgradeHandler) handleFailedOverNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
    if len(notification) < 2 {
        return fmt.Errorf("hitless: FAILED_OVER notification requires at least 2 elements, got %d", len(notification))
    }

    shardID, err := h.parseShardID(notification[1])
    if err != nil {
        return fmt.Errorf("hitless: failed to parse shard ID for FAILED_OVER notification: %w", err)
    }

    conn := handlerCtx.Conn
    if conn == nil {
        return fmt.Errorf("hitless: no connection available in handler context")
    }

    internal.Logger.Printf(ctx, "hitless: FAILED_OVER notification - shard %s failover completed on connection %p", shardID, conn)

    h.mu.Lock()
    defer h.mu.Unlock()

    // Clear the transitioning state for this connection and restore original timeouts
    if state, exists := h.connectionStates[conn]; exists && state.TransitionType == "FAILING_OVER" && state.ShardID == shardID {
        internal.Logger.Printf(ctx, "hitless: restoring original timeouts for connection %p (read: %v, write: %v)",
            conn, state.OriginalReadTimeout, state.OriginalWriteTimeout)

        // In a real implementation, this would restore the connection's original timeouts
        // For now, we'll just log and delete the state
        delete(h.connectionStates, conn)
        internal.Logger.Printf(ctx, "hitless: cleared FAILING_OVER state and restored timeouts for connection %p", conn)
    }

    return nil
}
// parseTimeoutSeconds parses the timeout value from a notification
func (h *UpgradeHandler) parseTimeoutSeconds(value interface{}) (int, error) {
    switch v := value.(type) {
    case int64:
        return int(v), nil
    case int:
        return v, nil
    case string:
        return strconv.Atoi(v)
    default:
        return 0, fmt.Errorf("unsupported timeout type: %T", value)
    }
}

// parseShardID parses the shard ID from a notification
func (h *UpgradeHandler) parseShardID(value interface{}) (string, error) {
    switch v := value.(type) {
    case string:
        return v, nil
    case int64:
        return strconv.FormatInt(v, 10), nil
    case int:
        return strconv.Itoa(v), nil
    default:
        return "", fmt.Errorf("unsupported shard ID type: %T", value)
    }
}
// GetConnectionState returns the current state of a connection
func (h *UpgradeHandler) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) {
    h.mu.RLock()
    defer h.mu.RUnlock()

    state, exists := h.connectionStates[conn]
    if !exists {
        return nil, false
    }

    // Return a copy to avoid race conditions
    stateCopy := *state
    return &stateCopy, true
}

// IsConnectionTransitioning checks if a connection is currently transitioning
func (h *UpgradeHandler) IsConnectionTransitioning(conn *pool.Conn) bool {
    h.mu.RLock()
    defer h.mu.RUnlock()

    state, exists := h.connectionStates[conn]
    return exists && state.IsTransitioning
}
// IsPoolMoving and GetNewEndpoint removed - using atomic state in ClusterUpgradeManager instead

// CleanupExpiredStates removes expired connection and pool states
func (h *UpgradeHandler) CleanupExpiredStates() {
    h.mu.Lock()
    defer h.mu.Unlock()

    now := time.Now()

    // Cleanup expired connection states
    for conn, state := range h.connectionStates {
        timeout := time.Duration(state.TimeoutSeconds) * time.Second
        if now.Sub(state.StartTime) > timeout {
            delete(h.connectionStates, conn)
        }
    }

    // Pool state cleanup removed - using atomic state in ClusterUpgradeManager instead
}
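// Note: states created without TimeoutSeconds (e.g. via MarkConnectionAsBlocking)
// effectively have a zero timeout and may be dropped on the next cleanup pass.
// A caller could invoke this periodically, e.g. (illustrative cadence only):
//
//     go func() {
//         for range time.Tick(30 * time.Second) {
//             handler.CleanupExpiredStates()
//         }
//     }()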
// CleanupConnection removes state for a specific connection (called when connection is closed)
func (h *UpgradeHandler) CleanupConnection(conn *pool.Conn) {
    h.mu.Lock()
    defer h.mu.Unlock()

    delete(h.connectionStates, conn)
}

// GetActiveTransitions returns information about all active connection transitions
func (h *UpgradeHandler) GetActiveTransitions() map[*pool.Conn]*ConnectionState {
    h.mu.RLock()
    defer h.mu.RUnlock()

    // Create copies to avoid race conditions
    connStates := make(map[*pool.Conn]*ConnectionState)
    for conn, state := range h.connectionStates {
        stateCopy := *state
        connStates[conn] = &stateCopy
    }

    return connStates
}
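// Observability sketch (illustrative): the returned map holds copies, so it can
// be read without further locking.
//
//     for conn, st := range handler.GetActiveTransitions() {
//         fmt.Printf("conn %p: %s since %s\n", conn, st.TransitionType, st.StartTime)
//     }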