
feat(e2e-testing): maintnotifications e2e and refactor (#3526)

* e2e wip

* cleanup

* remove unused fault injector mock

* errChan in test

* remove log messages tests

* cleanup log messages

* s/hitless/maintnotifications/

* fix moving when none

* better logs

* test with second client after action has started

* Fixes

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Test fix

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* feat(e2e-test): Extended e2e tests

* improved e2e test resiliency

---------

Signed-off-by: Elena Kolevska <elena@kolevska.com>
Co-authored-by: Elena Kolevska <elena@kolevska.com>
Co-authored-by: Elena Kolevska <elena-kolevska@users.noreply.github.com>
Co-authored-by: Hristo Temelski <hristo.temelski@redis.com>
authored by Nedyalko Dyakov on 2025-09-26 19:17:09 +03:00
committed by GitHub
parent e6e52bc735
commit 75ddeb3d5a
52 changed files with 5848 additions and 570 deletions

.gitignore

@@ -9,6 +9,7 @@ coverage.txt
**/coverage.txt
.vscode
tmp/*
*.test
# Hitless upgrade documentation (temporary)
hitless/docs/
# maintenanceNotifications upgrade documentation (temporary)
maintenanceNotifications/docs/


@@ -7,7 +7,7 @@ import (
"testing"
"time"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/logging"
)
@@ -42,7 +42,7 @@ func TestEventDrivenHandoffIntegration(t *testing.T) {
}
// Create processor with event-driven handoff support
processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
processor := maintnotifications.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create a test pool with hooks
@@ -141,7 +141,7 @@ func TestEventDrivenHandoffIntegration(t *testing.T) {
return &mockNetConn{addr: addr}, nil
}
processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
processor := maintnotifications.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
@@ -213,7 +213,7 @@ func TestEventDrivenHandoffIntegration(t *testing.T) {
return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}}
}
processor := hitless.NewPoolHook(failingDialer, "tcp", nil, nil)
processor := maintnotifications.NewPoolHook(failingDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
@@ -276,7 +276,7 @@ func TestEventDrivenHandoffIntegration(t *testing.T) {
return &mockNetConn{addr: addr}, nil
}
processor := hitless.NewPoolHook(slowDialer, "tcp", nil, nil)
processor := maintnotifications.NewPoolHook(slowDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook


@@ -520,7 +520,7 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
return cmd
}
// ClientMaintNotifications enables or disables maintenance notifications for hitless upgrades.
// ClientMaintNotifications enables or disables maintenance notifications for maintenance upgrades.
// When enabled, the client will receive push notifications about Redis maintenance events.
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
args := []interface{}{"client", "maint_notifications"}
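For reference, a minimal usage sketch of the command documented above (not part of this commit); the endpoint type is taken from the exported constant rather than a hard-coded string:

```go
package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/maintnotifications"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Protocol: 3})
	defer rdb.Close()

	// Enable maintenance notifications on this connection, requesting
	// external-IP endpoints in MOVING notifications.
	endpointType := maintnotifications.EndpointTypeExternalIP.String()
	if err := rdb.ClientMaintNotifications(ctx, true, endpointType).Err(); err != nil {
		fmt.Println("maint notifications not enabled:", err)
	}
}
```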


@@ -9,8 +9,8 @@ import (
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
var ctx = context.Background()
@@ -19,24 +19,28 @@ var cntSuccess atomic.Int64
var startTime = time.Now()
// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management.
// It was used to find regressions in pool management in hitless mode.
// It was used to find regressions in pool management in maintnotifications mode.
// Please don't use it as a reference for how to use pubsub.
func main() {
startTime = time.Now()
wg := &sync.WaitGroup{}
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
Mode: hitless.MaintNotificationsEnabled,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
EndpointType: maintnotifications.EndpointTypeExternalIP,
HandoffTimeout: 10 * time.Second,
RelaxedTimeout: 10 * time.Second,
PostHandoffRelaxedDuration: 10 * time.Second,
},
})
_ = rdb.FlushDB(ctx).Err()
hitlessManager := rdb.GetHitlessManager()
if hitlessManager == nil {
panic("hitless manager is nil")
maintnotificationsManager := rdb.GetMaintNotificationsManager()
if maintnotificationsManager == nil {
panic("maintnotifications manager is nil")
}
loggingHook := hitless.NewLoggingHook(logging.LogLevelDebug)
hitlessManager.AddNotificationHook(loggingHook)
loggingHook := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
maintnotificationsManager.AddNotificationHook(loggingHook)
go func() {
for {


@@ -1,105 +0,0 @@
package hitless
import (
"errors"
"fmt"
"time"
)
// Configuration errors
var (
ErrInvalidRelaxedTimeout = errors.New("hitless: relaxed timeout must be greater than 0")
ErrInvalidHandoffTimeout = errors.New("hitless: handoff timeout must be greater than 0")
ErrInvalidHandoffWorkers = errors.New("hitless: MaxWorkers must be greater than or equal to 0")
ErrInvalidHandoffQueueSize = errors.New("hitless: handoff queue size must be greater than 0")
ErrInvalidPostHandoffRelaxedDuration = errors.New("hitless: post-handoff relaxed duration must be greater than or equal to 0")
ErrInvalidLogLevel = errors.New("hitless: log level must be LogLevelError (0), LogLevelWarn (1), LogLevelInfo (2), or LogLevelDebug (3)")
ErrInvalidEndpointType = errors.New("hitless: invalid endpoint type")
ErrInvalidMaintNotifications = errors.New("hitless: invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')")
ErrMaxHandoffRetriesReached = errors.New("hitless: max handoff retries reached")
// Configuration validation errors
ErrInvalidHandoffRetries = errors.New("hitless: MaxHandoffRetries must be between 1 and 10")
)
// Integration errors
var (
ErrInvalidClient = errors.New("hitless: invalid client type")
)
// Handoff errors
var (
ErrHandoffQueueFull = errors.New("hitless: handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration")
)
// Notification errors
var (
ErrInvalidNotification = errors.New("hitless: invalid notification format")
)
// connection handoff errors
var (
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
// and should not be used until the handoff is complete
ErrConnectionMarkedForHandoff = errors.New("hitless: connection marked for handoff")
// ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff
ErrConnectionInvalidHandoffState = errors.New("hitless: connection is in invalid state for handoff")
)
// general errors
var (
ErrShutdown = errors.New("hitless: shutdown")
)
// circuit breaker errors
var (
ErrCircuitBreakerOpen = errors.New("hitless: circuit breaker is open, failing fast")
)
// CircuitBreakerError provides detailed context for circuit breaker failures
type CircuitBreakerError struct {
Endpoint string
State string
Failures int64
LastFailure time.Time
NextAttempt time.Time
Message string
}
func (e *CircuitBreakerError) Error() string {
if e.NextAttempt.IsZero() {
return fmt.Sprintf("hitless: circuit breaker %s for %s (failures: %d, last: %v): %s",
e.State, e.Endpoint, e.Failures, e.LastFailure, e.Message)
}
return fmt.Sprintf("hitless: circuit breaker %s for %s (failures: %d, last: %v, next attempt: %v): %s",
e.State, e.Endpoint, e.Failures, e.LastFailure, e.NextAttempt, e.Message)
}
// HandoffError provides detailed context for connection handoff failures
type HandoffError struct {
ConnectionID uint64
SourceEndpoint string
TargetEndpoint string
Attempt int
MaxAttempts int
Duration time.Duration
FinalError error
Message string
}
func (e *HandoffError) Error() string {
return fmt.Sprintf("hitless: handoff failed for conn[%d] %s→%s (attempt %d/%d, duration: %v): %s",
e.ConnectionID, e.SourceEndpoint, e.TargetEndpoint,
e.Attempt, e.MaxAttempts, e.Duration, e.Message)
}
func (e *HandoffError) Unwrap() error {
return e.FinalError
}
// circuit breaker configuration errors
var (
ErrInvalidCircuitBreakerFailureThreshold = errors.New("hitless: circuit breaker failure threshold must be >= 1")
ErrInvalidCircuitBreakerResetTimeout = errors.New("hitless: circuit breaker reset timeout must be >= 0")
ErrInvalidCircuitBreakerMaxRequests = errors.New("hitless: circuit breaker max requests must be >= 1")
)


@@ -1,5 +1,5 @@
// Package interfaces provides shared interfaces used by both the main redis package
// and the hitless upgrade package to avoid circular dependencies.
// and the maintnotifications upgrade package to avoid circular dependencies.
package interfaces
import (
@@ -16,7 +16,7 @@ type NotificationProcessor interface {
GetHandler(pushNotificationName string) interface{}
}
// ClientInterface defines the interface that clients must implement for hitless upgrades.
// ClientInterface defines the interface that clients must implement for maintnotifications upgrades.
type ClientInterface interface {
// GetOptions returns the client options.
GetOptions() OptionsInterface


@@ -31,3 +31,49 @@ func NewDefaultLogger() Logging {
// Logger calls Output to print to the stderr.
// Arguments are handled in the manner of fmt.Print.
var Logger Logging = NewDefaultLogger()
var LogLevel LogLevelT = LogLevelError
// LogLevelT represents the logging level
type LogLevelT int
// Log level constants for the entire go-redis library
const (
LogLevelError LogLevelT = iota // 0 - errors only
LogLevelWarn // 1 - warnings and errors
LogLevelInfo // 2 - info, warnings, and errors
LogLevelDebug // 3 - debug, info, warnings, and errors
)
// String returns the string representation of the log level
func (l LogLevelT) String() string {
switch l {
case LogLevelError:
return "ERROR"
case LogLevelWarn:
return "WARN"
case LogLevelInfo:
return "INFO"
case LogLevelDebug:
return "DEBUG"
default:
return "UNKNOWN"
}
}
// IsValid returns true if the log level is valid
func (l LogLevelT) IsValid() bool {
return l >= LogLevelError && l <= LogLevelDebug
}
func (l LogLevelT) WarnOrAbove() bool {
return l >= LogLevelWarn
}
func (l LogLevelT) InfoOrAbove() bool {
return l >= LogLevelInfo
}
func (l LogLevelT) DebugOrAbove() bool {
return l >= LogLevelDebug
}


@@ -0,0 +1,625 @@
package logs
import (
"encoding/json"
"fmt"
"regexp"
"github.com/redis/go-redis/v9/internal"
)
// appendJSONIfDebug appends JSON data to a message only if the global log level is Debug
func appendJSONIfDebug(message string, data map[string]interface{}) string {
if internal.LogLevel.DebugOrAbove() {
jsonData, _ := json.Marshal(data)
return fmt.Sprintf("%s %s", message, string(jsonData))
}
return message
}
const (
// ========================================
// CIRCUIT_BREAKER.GO - Circuit breaker management
// ========================================
CircuitBreakerTransitioningToHalfOpenMessage = "circuit breaker transitioning to half-open"
CircuitBreakerOpenedMessage = "circuit breaker opened"
CircuitBreakerReopenedMessage = "circuit breaker reopened"
CircuitBreakerClosedMessage = "circuit breaker closed"
CircuitBreakerCleanupMessage = "circuit breaker cleanup"
CircuitBreakerOpenMessage = "circuit breaker is open, failing fast"
// ========================================
// CONFIG.GO - Configuration and debug
// ========================================
DebugLoggingEnabledMessage = "debug logging enabled"
ConfigDebugMessage = "config debug"
// ========================================
// ERRORS.GO - Error message constants
// ========================================
InvalidRelaxedTimeoutErrorMessage = "relaxed timeout must be greater than 0"
InvalidHandoffTimeoutErrorMessage = "handoff timeout must be greater than 0"
InvalidHandoffWorkersErrorMessage = "MaxWorkers must be greater than or equal to 0"
InvalidHandoffQueueSizeErrorMessage = "handoff queue size must be greater than 0"
InvalidPostHandoffRelaxedDurationErrorMessage = "post-handoff relaxed duration must be greater than or equal to 0"
InvalidEndpointTypeErrorMessage = "invalid endpoint type"
InvalidMaintNotificationsErrorMessage = "invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')"
InvalidHandoffRetriesErrorMessage = "MaxHandoffRetries must be between 1 and 10"
InvalidClientErrorMessage = "invalid client type"
InvalidNotificationErrorMessage = "invalid notification format"
MaxHandoffRetriesReachedErrorMessage = "max handoff retries reached"
HandoffQueueFullErrorMessage = "handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration"
InvalidCircuitBreakerFailureThresholdErrorMessage = "circuit breaker failure threshold must be >= 1"
InvalidCircuitBreakerResetTimeoutErrorMessage = "circuit breaker reset timeout must be >= 0"
InvalidCircuitBreakerMaxRequestsErrorMessage = "circuit breaker max requests must be >= 1"
ConnectionMarkedForHandoffErrorMessage = "connection marked for handoff"
ConnectionInvalidHandoffStateErrorMessage = "connection is in invalid state for handoff"
ShutdownErrorMessage = "shutdown"
CircuitBreakerOpenErrorMessage = "circuit breaker is open, failing fast"
// ========================================
// EXAMPLE_HOOKS.GO - Example metrics hooks
// ========================================
MetricsHookProcessingNotificationMessage = "metrics hook processing"
MetricsHookRecordedErrorMessage = "metrics hook recorded error"
// ========================================
// HANDOFF_WORKER.GO - Connection handoff processing
// ========================================
HandoffStartedMessage = "handoff started"
HandoffFailedMessage = "handoff failed"
ConnectionNotMarkedForHandoffMessage = "is not marked for handoff and has no retries"
ConnectionNotMarkedForHandoffErrorMessage = "is not marked for handoff"
HandoffRetryAttemptMessage = "Performing handoff"
CannotQueueHandoffForRetryMessage = "can't queue handoff for retry"
HandoffQueueFullMessage = "handoff queue is full"
FailedToDialNewEndpointMessage = "failed to dial new endpoint"
ApplyingRelaxedTimeoutDueToPostHandoffMessage = "applying relaxed timeout due to post-handoff"
HandoffSuccessMessage = "handoff succeeded"
RemovingConnectionFromPoolMessage = "removing connection from pool"
NoPoolProvidedMessageCannotRemoveMessage = "no pool provided, cannot remove connection, closing it"
WorkerExitingDueToShutdownMessage = "worker exiting due to shutdown"
WorkerExitingDueToShutdownWhileProcessingMessage = "worker exiting due to shutdown while processing request"
WorkerPanicRecoveredMessage = "worker panic recovered"
WorkerExitingDueToInactivityTimeoutMessage = "worker exiting due to inactivity timeout"
ReachedMaxHandoffRetriesMessage = "reached max handoff retries"
// ========================================
// MANAGER.GO - Moving operation tracking and handler registration
// ========================================
DuplicateMovingOperationMessage = "duplicate MOVING operation ignored"
TrackingMovingOperationMessage = "tracking MOVING operation"
UntrackingMovingOperationMessage = "untracking MOVING operation"
OperationNotTrackedMessage = "operation not tracked"
FailedToRegisterHandlerMessage = "failed to register handler"
// ========================================
// HOOKS.GO - Notification processing hooks
// ========================================
ProcessingNotificationMessage = "processing notification started"
ProcessingNotificationFailedMessage = "proccessing notification failed"
ProcessingNotificationSucceededMessage = "processing notification succeeded"
// ========================================
// POOL_HOOK.GO - Pool connection management
// ========================================
FailedToQueueHandoffMessage = "failed to queue handoff"
MarkedForHandoffMessage = "connection marked for handoff"
// ========================================
// PUSH_NOTIFICATION_HANDLER.GO - Push notification validation and processing
// ========================================
InvalidNotificationFormatMessage = "invalid notification format"
InvalidNotificationTypeFormatMessage = "invalid notification type format"
InvalidSeqIDInMovingNotificationMessage = "invalid seqID in MOVING notification"
InvalidTimeSInMovingNotificationMessage = "invalid timeS in MOVING notification"
InvalidNewEndpointInMovingNotificationMessage = "invalid newEndpoint in MOVING notification"
NoConnectionInHandlerContextMessage = "no connection in handler context"
InvalidConnectionTypeInHandlerContextMessage = "invalid connection type in handler context"
SchedulingHandoffToCurrentEndpointMessage = "scheduling handoff to current endpoint"
RelaxedTimeoutDueToNotificationMessage = "applying relaxed timeout due to notification"
UnrelaxedTimeoutMessage = "clearing relaxed timeout"
ManagerNotInitializedMessage = "manager not initialized"
FailedToMarkForHandoffMessage = "failed to mark connection for handoff"
// ========================================
// used in pool/conn
// ========================================
UnrelaxedTimeoutAfterDeadlineMessage = "clearing relaxed timeout after deadline"
)
func HandoffStarted(connID uint64, newEndpoint string) string {
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffStartedMessage, newEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": newEndpoint,
})
}
func HandoffFailed(connID uint64, newEndpoint string, attempt int, maxAttempts int, err error) string {
message := fmt.Sprintf("conn[%d] %s to %s (attempt %d/%d): %v", connID, HandoffFailedMessage, newEndpoint, attempt, maxAttempts, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": newEndpoint,
"attempt": attempt,
"maxAttempts": maxAttempts,
"error": err.Error(),
})
}
func HandoffSucceeded(connID uint64, newEndpoint string) string {
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffSuccessMessage, newEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": newEndpoint,
})
}
// Timeout-related log functions
func RelaxedTimeoutDueToNotification(connID uint64, notificationType string, timeout interface{}) string {
message := fmt.Sprintf("conn[%d] %s %s (%v)", connID, RelaxedTimeoutDueToNotificationMessage, notificationType, timeout)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"notificationType": notificationType,
"timeout": fmt.Sprintf("%v", timeout),
})
}
func UnrelaxedTimeout(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
func UnrelaxedTimeoutAfterDeadline(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutAfterDeadlineMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
// Handoff queue and marking functions
func HandoffQueueFull(queueLen, queueCap int) string {
message := fmt.Sprintf("%s (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration", HandoffQueueFullMessage, queueLen, queueCap)
return appendJSONIfDebug(message, map[string]interface{}{
"queueLen": queueLen,
"queueCap": queueCap,
})
}
func FailedToQueueHandoff(connID uint64, err error) string {
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToQueueHandoffMessage, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"error": err.Error(),
})
}
func FailedToMarkForHandoff(connID uint64, err error) string {
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToMarkForHandoffMessage, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"error": err.Error(),
})
}
func FailedToDialNewEndpoint(connID uint64, endpoint string, err error) string {
message := fmt.Sprintf("conn[%d] %s %s: %v", connID, FailedToDialNewEndpointMessage, endpoint, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"error": err.Error(),
})
}
func ReachedMaxHandoffRetries(connID uint64, endpoint string, maxRetries int) string {
message := fmt.Sprintf("conn[%d] %s to %s (max retries: %d)", connID, ReachedMaxHandoffRetriesMessage, endpoint, maxRetries)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"maxRetries": maxRetries,
})
}
// Notification processing functions
func ProcessingNotification(connID uint64, seqID int64, notificationType string, notification interface{}) string {
message := fmt.Sprintf("conn[%d] seqID[%d] %s %s: %v", connID, seqID, ProcessingNotificationMessage, notificationType, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seqID": seqID,
"notificationType": notificationType,
"notification": fmt.Sprintf("%v", notification),
})
}
func ProcessingNotificationFailed(connID uint64, notificationType string, err error, notification interface{}) string {
message := fmt.Sprintf("conn[%d] %s %s: %v - %v", connID, ProcessingNotificationFailedMessage, notificationType, err, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"notificationType": notificationType,
"error": err.Error(),
"notification": fmt.Sprintf("%v", notification),
})
}
func ProcessingNotificationSucceeded(connID uint64, notificationType string) string {
message := fmt.Sprintf("conn[%d] %s %s", connID, ProcessingNotificationSucceededMessage, notificationType)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"notificationType": notificationType,
})
}
// Moving operation tracking functions
func DuplicateMovingOperation(connID uint64, endpoint string, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, DuplicateMovingOperationMessage, endpoint, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"seqID": seqID,
})
}
func TrackingMovingOperation(connID uint64, endpoint string, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, TrackingMovingOperationMessage, endpoint, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"seqID": seqID,
})
}
func UntrackingMovingOperation(connID uint64, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, UntrackingMovingOperationMessage, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seqID": seqID,
})
}
func OperationNotTracked(connID uint64, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, OperationNotTrackedMessage, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seqID": seqID,
})
}
// Connection pool functions
func RemovingConnectionFromPool(connID uint64, reason error) string {
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, RemovingConnectionFromPoolMessage, reason)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"reason": reason.Error(),
})
}
func NoPoolProvidedCannotRemove(connID uint64, reason error) string {
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, NoPoolProvidedMessageCannotRemoveMessage, reason)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"reason": reason.Error(),
})
}
// Circuit breaker functions
func CircuitBreakerOpen(connID uint64, endpoint string) string {
message := fmt.Sprintf("conn[%d] %s for %s", connID, CircuitBreakerOpenMessage, endpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
})
}
// Additional handoff functions for specific cases
func ConnectionNotMarkedForHandoff(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
func ConnectionNotMarkedForHandoffError(connID uint64) string {
return fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffErrorMessage)
}
func HandoffRetryAttempt(connID uint64, retries int, newEndpoint string, oldEndpoint string) string {
message := fmt.Sprintf("conn[%d] Retry %d: %s to %s(was %s)", connID, retries, HandoffRetryAttemptMessage, newEndpoint, oldEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"retries": retries,
"newEndpoint": newEndpoint,
"oldEndpoint": oldEndpoint,
})
}
func CannotQueueHandoffForRetry(err error) string {
message := fmt.Sprintf("%s: %v", CannotQueueHandoffForRetryMessage, err)
return appendJSONIfDebug(message, map[string]interface{}{
"error": err.Error(),
})
}
// Validation and error functions
func InvalidNotificationFormat(notification interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidNotificationFormatMessage, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"notification": fmt.Sprintf("%v", notification),
})
}
func InvalidNotificationTypeFormat(notificationType interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidNotificationTypeFormatMessage, notificationType)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": fmt.Sprintf("%v", notificationType),
})
}
// InvalidNotification creates a log message for invalid notifications of any type
func InvalidNotification(notificationType string, notification interface{}) string {
message := fmt.Sprintf("invalid %s notification: %v", notificationType, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"notification": fmt.Sprintf("%v", notification),
})
}
func InvalidSeqIDInMovingNotification(seqID interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidSeqIDInMovingNotificationMessage, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"seqID": fmt.Sprintf("%v", seqID),
})
}
func InvalidTimeSInMovingNotification(timeS interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidTimeSInMovingNotificationMessage, timeS)
return appendJSONIfDebug(message, map[string]interface{}{
"timeS": fmt.Sprintf("%v", timeS),
})
}
func InvalidNewEndpointInMovingNotification(newEndpoint interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidNewEndpointInMovingNotificationMessage, newEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"newEndpoint": fmt.Sprintf("%v", newEndpoint),
})
}
func NoConnectionInHandlerContext(notificationType string) string {
message := fmt.Sprintf("%s for %s notification", NoConnectionInHandlerContextMessage, notificationType)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
})
}
func InvalidConnectionTypeInHandlerContext(notificationType string, conn interface{}, handlerCtx interface{}) string {
message := fmt.Sprintf("%s for %s notification - %T %#v", InvalidConnectionTypeInHandlerContextMessage, notificationType, conn, handlerCtx)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"connType": fmt.Sprintf("%T", conn),
})
}
func SchedulingHandoffToCurrentEndpoint(connID uint64, seconds float64) string {
message := fmt.Sprintf("conn[%d] %s in %v seconds", connID, SchedulingHandoffToCurrentEndpointMessage, seconds)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seconds": seconds,
})
}
func ManagerNotInitialized() string {
return appendJSONIfDebug(ManagerNotInitializedMessage, map[string]interface{}{})
}
func FailedToRegisterHandler(notificationType string, err error) string {
message := fmt.Sprintf("%s for %s: %v", FailedToRegisterHandlerMessage, notificationType, err)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"error": err.Error(),
})
}
func ShutdownError() string {
return appendJSONIfDebug(ShutdownErrorMessage, map[string]interface{}{})
}
// Configuration validation error functions
func InvalidRelaxedTimeoutError() string {
return appendJSONIfDebug(InvalidRelaxedTimeoutErrorMessage, map[string]interface{}{})
}
func InvalidHandoffTimeoutError() string {
return appendJSONIfDebug(InvalidHandoffTimeoutErrorMessage, map[string]interface{}{})
}
func InvalidHandoffWorkersError() string {
return appendJSONIfDebug(InvalidHandoffWorkersErrorMessage, map[string]interface{}{})
}
func InvalidHandoffQueueSizeError() string {
return appendJSONIfDebug(InvalidHandoffQueueSizeErrorMessage, map[string]interface{}{})
}
func InvalidPostHandoffRelaxedDurationError() string {
return appendJSONIfDebug(InvalidPostHandoffRelaxedDurationErrorMessage, map[string]interface{}{})
}
func InvalidEndpointTypeError() string {
return appendJSONIfDebug(InvalidEndpointTypeErrorMessage, map[string]interface{}{})
}
func InvalidMaintNotificationsError() string {
return appendJSONIfDebug(InvalidMaintNotificationsErrorMessage, map[string]interface{}{})
}
func InvalidHandoffRetriesError() string {
return appendJSONIfDebug(InvalidHandoffRetriesErrorMessage, map[string]interface{}{})
}
func InvalidClientError() string {
return appendJSONIfDebug(InvalidClientErrorMessage, map[string]interface{}{})
}
func InvalidNotificationError() string {
return appendJSONIfDebug(InvalidNotificationErrorMessage, map[string]interface{}{})
}
func MaxHandoffRetriesReachedError() string {
return appendJSONIfDebug(MaxHandoffRetriesReachedErrorMessage, map[string]interface{}{})
}
func HandoffQueueFullError() string {
return appendJSONIfDebug(HandoffQueueFullErrorMessage, map[string]interface{}{})
}
func InvalidCircuitBreakerFailureThresholdError() string {
return appendJSONIfDebug(InvalidCircuitBreakerFailureThresholdErrorMessage, map[string]interface{}{})
}
func InvalidCircuitBreakerResetTimeoutError() string {
return appendJSONIfDebug(InvalidCircuitBreakerResetTimeoutErrorMessage, map[string]interface{}{})
}
func InvalidCircuitBreakerMaxRequestsError() string {
return appendJSONIfDebug(InvalidCircuitBreakerMaxRequestsErrorMessage, map[string]interface{}{})
}
// Configuration and debug functions
func DebugLoggingEnabled() string {
return appendJSONIfDebug(DebugLoggingEnabledMessage, map[string]interface{}{})
}
func ConfigDebug(config interface{}) string {
message := fmt.Sprintf("%s: %+v", ConfigDebugMessage, config)
return appendJSONIfDebug(message, map[string]interface{}{
"config": fmt.Sprintf("%+v", config),
})
}
// Handoff worker functions
func WorkerExitingDueToShutdown() string {
return appendJSONIfDebug(WorkerExitingDueToShutdownMessage, map[string]interface{}{})
}
func WorkerExitingDueToShutdownWhileProcessing() string {
return appendJSONIfDebug(WorkerExitingDueToShutdownWhileProcessingMessage, map[string]interface{}{})
}
func WorkerPanicRecovered(panicValue interface{}) string {
message := fmt.Sprintf("%s: %v", WorkerPanicRecoveredMessage, panicValue)
return appendJSONIfDebug(message, map[string]interface{}{
"panic": fmt.Sprintf("%v", panicValue),
})
}
func WorkerExitingDueToInactivityTimeout(timeout interface{}) string {
message := fmt.Sprintf("%s (%v)", WorkerExitingDueToInactivityTimeoutMessage, timeout)
return appendJSONIfDebug(message, map[string]interface{}{
"timeout": fmt.Sprintf("%v", timeout),
})
}
func ApplyingRelaxedTimeoutDueToPostHandoff(connID uint64, timeout interface{}, until string) string {
message := fmt.Sprintf("conn[%d] %s (%v) until %s", connID, ApplyingRelaxedTimeoutDueToPostHandoffMessage, timeout, until)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"timeout": fmt.Sprintf("%v", timeout),
"until": until,
})
}
// Example hooks functions
func MetricsHookProcessingNotification(notificationType string, connID uint64) string {
message := fmt.Sprintf("%s %s notification on conn[%d]", MetricsHookProcessingNotificationMessage, notificationType, connID)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"connID": connID,
})
}
func MetricsHookRecordedError(notificationType string, connID uint64, err error) string {
message := fmt.Sprintf("%s for %s notification on conn[%d]: %v", MetricsHookRecordedErrorMessage, notificationType, connID, err)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"connID": connID,
"error": err.Error(),
})
}
// Pool hook functions
func MarkedForHandoff(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, MarkedForHandoffMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
// Circuit breaker additional functions
func CircuitBreakerTransitioningToHalfOpen(endpoint string) string {
message := fmt.Sprintf("%s for %s", CircuitBreakerTransitioningToHalfOpenMessage, endpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
})
}
func CircuitBreakerOpened(endpoint string, failures int64) string {
message := fmt.Sprintf("%s for endpoint %s after %d failures", CircuitBreakerOpenedMessage, endpoint, failures)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
"failures": failures,
})
}
func CircuitBreakerReopened(endpoint string) string {
message := fmt.Sprintf("%s for endpoint %s due to failure in half-open state", CircuitBreakerReopenedMessage, endpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
})
}
func CircuitBreakerClosed(endpoint string, successes int64) string {
message := fmt.Sprintf("%s for endpoint %s after %d successful requests", CircuitBreakerClosedMessage, endpoint, successes)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
"successes": successes,
})
}
func CircuitBreakerCleanup(removed int, total int) string {
message := fmt.Sprintf("%s removed %d/%d entries", CircuitBreakerCleanupMessage, removed, total)
return appendJSONIfDebug(message, map[string]interface{}{
"removed": removed,
"total": total,
})
}
// ExtractDataFromLogMessage extracts structured data from maintnotifications log messages
// Returns a map containing the parsed key-value pairs from the structured data section
// Example: "conn[123] handoff started to localhost:6379 {"connID":123,"endpoint":"localhost:6379"}"
// Returns: map[string]interface{}{"connID": 123, "endpoint": "localhost:6379"}
func ExtractDataFromLogMessage(logMessage string) map[string]interface{} {
result := make(map[string]interface{})
// Find the JSON data section at the end of the message
re := regexp.MustCompile(`(\{.*\})$`)
matches := re.FindStringSubmatch(logMessage)
if len(matches) < 2 {
return result
}
jsonStr := matches[1]
if jsonStr == "" {
return result
}
// Parse the JSON directly
var jsonResult map[string]interface{}
if err := json.Unmarshal([]byte(jsonStr), &jsonResult); err == nil {
return jsonResult
}
// If JSON parsing fails, return empty map
return result
}
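A short sketch of driving the extractor above from a test; note that the logs package sits under internal/, so it is only importable from within the go-redis module itself (for example in its own e2e tests):

```go
package main

import (
	"fmt"

	"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
)

func main() {
	// A message in the documented format: readable prefix plus trailing JSON data.
	msg := `conn[123] handoff started to localhost:6379 {"connID":123,"endpoint":"localhost:6379"}`

	data := logs.ExtractDataFromLogMessage(msg)
	fmt.Println(data["connID"], data["endpoint"]) // 123 localhost:6379 (numbers decode as float64)
}
```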


@@ -10,6 +10,8 @@ import (
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/proto"
)
@@ -59,7 +61,7 @@ type Conn struct {
createdAt time.Time
expiresAt time.Time
// Hitless upgrade support: relaxed timeouts during migrations/failovers
// maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers
// Using atomic operations for lock-free access to avoid mutex contention
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
@@ -73,7 +75,7 @@ type Conn struct {
// Connection initialization function for reconnections
initConnFunc func(context.Context, *Conn) error
// Connection identifier for unique tracking across handoffs
// Connection identifier for unique tracking
id uint64 // Unique numeric identifier for this connection
// Handoff state - using atomic operations for lock-free access
@@ -236,7 +238,7 @@ func (cn *Conn) SetUsable(usable bool) {
cn.setUsable(usable)
}
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
// SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades.
// These timeouts will be used for all subsequent commands until the deadline expires.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
@@ -258,7 +260,8 @@ func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Dur
func (cn *Conn) ClearRelaxedTimeout() {
// Atomically decrement counter and check if we should clear
newCount := cn.relaxedCounter.Add(-1)
if newCount <= 0 {
deadlineNs := cn.relaxedDeadlineNs.Load()
if newCount <= 0 && (deadlineNs == 0 || time.Now().UnixNano() >= deadlineNs) {
// Use atomic load to get current value for CAS to avoid stale value race
current := cn.relaxedCounter.Load()
if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) {
@@ -325,8 +328,9 @@ func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Durati
return time.Duration(readTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
cn.relaxedCounter.Add(-1)
if cn.relaxedCounter.Load() <= 0 {
newCount := cn.relaxedCounter.Add(-1)
if newCount <= 0 {
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
cn.clearRelaxedTimeout()
}
return normalTimeout
@@ -357,8 +361,9 @@ func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Durat
return time.Duration(writeTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
cn.relaxedCounter.Add(-1)
if cn.relaxedCounter.Load() <= 0 {
newCount := cn.relaxedCounter.Add(-1)
if newCount <= 0 {
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
cn.clearRelaxedTimeout()
}
return normalTimeout
@@ -472,6 +477,7 @@ func (cn *Conn) MarkQueuedForHandoff() error {
}
// If CAS failed, add exponential backoff to reduce contention
// the delay will be 1, 2, 4... up to 512 microseconds
if attempt < maxRetries-1 {
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
@@ -529,6 +535,11 @@ func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
return cn.incrementHandoffRetries(n)
}
// GetHandoffRetries returns the current handoff retry count (lock-free).
func (cn *Conn) HandoffRetries() int {
return int(cn.handoffRetriesAtomic.Load())
}
// HasBufferedData safely checks if the connection has buffered data.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) HasBufferedData() bool {
@@ -603,7 +614,7 @@ func (cn *Conn) WithWriter(
} else {
// If getNetConn() returns nil, we still need to respect the timeout
// Return an error to prevent indefinite blocking
return fmt.Errorf("redis: connection not available for write operation")
return fmt.Errorf("redis: conn[%d] not available for write operation", cn.GetID())
}
}


@@ -26,13 +26,13 @@ var (
// popAttempts is the maximum number of attempts to find a usable connection
// when popping from the idle connection pool. This handles cases where connections
// are temporarily marked as unusable (e.g., during hitless upgrades or network issues).
// are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues).
// Value of 50 provides sufficient resilience without excessive overhead.
// This is capped by the idle connection count, so we won't loop excessively.
popAttempts = 50
// getAttempts is the maximum number of attempts to get a connection that passes
// hook validation (e.g., hitless upgrade hooks). This protects against race conditions
// hook validation (e.g., maintenanceNotifications upgrade hooks). This protects against race conditions
// where hooks might temporarily reject connections during cluster transitions.
// Value of 3 balances resilience with performance - most hook rejections resolve quickly.
getAttempts = 3
@@ -257,7 +257,7 @@ func (p *ConnPool) addIdleConn() error {
// NewConn creates a new connection and returns it to the user.
// This will still obey MaxActiveConns but will not include it in the pool and won't increase the pool size.
//
// NOTE: If you directly get a connection from the pool, it won't be pooled and won't support hitless upgrades.
// NOTE: If you directly get a connection from the pool, it won't be pooled and won't support maintnotifications upgrades.
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.newConn(ctx, false)
}
@@ -812,7 +812,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
// For RESP3 connections with push notifications, we allow some buffered data
// The client will process these notifications before using the connection
internal.Logger.Printf(context.Background(), "push: connection has buffered data, likely push notifications - will be processed by client")
internal.Logger.Printf(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID())
return true // Connection is healthy, client will handle notifications
}
return false // Unexpected data, not push notifications, connection is unhealthy


@@ -1,3 +1,3 @@
package internal
const RedisNull = "null"
const RedisNull = "<nil>"


@@ -10,50 +10,15 @@ import (
"github.com/redis/go-redis/v9/internal"
)
// LogLevel represents the logging level
type LogLevel int
type LogLevelT = internal.LogLevelT
// Log level constants for the entire go-redis library
const (
LogLevelError LogLevel = iota // 0 - errors only
LogLevelWarn // 1 - warnings and errors
LogLevelInfo // 2 - info, warnings, and errors
LogLevelDebug // 3 - debug, info, warnings, and errors
LogLevelError = internal.LogLevelError
LogLevelWarn = internal.LogLevelWarn
LogLevelInfo = internal.LogLevelInfo
LogLevelDebug = internal.LogLevelDebug
)
// String returns the string representation of the log level
func (l LogLevel) String() string {
switch l {
case LogLevelError:
return "ERROR"
case LogLevelWarn:
return "WARN"
case LogLevelInfo:
return "INFO"
case LogLevelDebug:
return "DEBUG"
default:
return "UNKNOWN"
}
}
// IsValid returns true if the log level is valid
func (l LogLevel) IsValid() bool {
return l >= LogLevelError && l <= LogLevelDebug
}
func (l LogLevel) WarnOrAbove() bool {
return l >= LogLevelWarn
}
func (l LogLevel) InfoOrAbove() bool {
return l >= LogLevelInfo
}
func (l LogLevel) DebugOrAbove() bool {
return l >= LogLevelDebug
}
// VoidLogger is a logger that does nothing.
// Used to disable logging and thus speed up the library.
type VoidLogger struct{}
@@ -79,6 +44,11 @@ func Enable() {
internal.Logger = internal.NewDefaultLogger()
}
// SetLogLevel sets the log level for the library.
func SetLogLevel(logLevel LogLevelT) {
internal.LogLevel = logLevel
}
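A minimal sketch of the new setter in use; LogLevelError remains the default, and raising the level to debug also makes the maintnotifications log helpers append their JSON payloads:

```go
package main

import (
	"context"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/logging"
)

func main() {
	// Library-wide log level; affects internal.LogLevel checks such as DebugOrAbove().
	logging.SetLogLevel(logging.LogLevelDebug)

	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
	defer rdb.Close()
	_ = rdb.Ping(context.Background()).Err()
}
```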
// NewBlacklistLogger returns a new logger that filters out messages containing any of the substrings.
// This can be used to filter out messages containing sensitive information.
func NewBlacklistLogger(substr []string) internal.Logging {


@@ -4,14 +4,14 @@ import "testing"
func TestLogLevel_String(t *testing.T) {
tests := []struct {
level LogLevel
level LogLevelT
expected string
}{
{LogLevelError, "ERROR"},
{LogLevelWarn, "WARN"},
{LogLevelInfo, "INFO"},
{LogLevelDebug, "DEBUG"},
{LogLevel(99), "UNKNOWN"},
{LogLevelT(99), "UNKNOWN"},
}
for _, test := range tests {
@@ -23,16 +23,16 @@ func TestLogLevel_String(t *testing.T) {
func TestLogLevel_IsValid(t *testing.T) {
tests := []struct {
level LogLevel
level LogLevelT
expected bool
}{
{LogLevelError, true},
{LogLevelWarn, true},
{LogLevelInfo, true},
{LogLevelDebug, true},
{LogLevel(-1), false},
{LogLevel(4), false},
{LogLevel(99), false},
{LogLevelT(-1), false},
{LogLevelT(4), false},
{LogLevelT(99), false},
}
for _, test := range tests {


@@ -1,6 +1,9 @@
# Hitless Upgrades
# Maintenance Notifications
Seamless Redis connection handoffs during cluster changes without dropping connections.
Seamless Redis connection handoffs during cluster maintenance operations without dropping connections.
## ⚠️ **Important Note**
**Maintenance notifications are currently supported only in standalone Redis clients.** Cluster clients (ClusterClient, FailoverClient, etc.) do not yet support this functionality.
## Quick Start
@@ -8,31 +11,30 @@ Seamless Redis connection handoffs during cluster changes without dropping conne
client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Protocol: 3, // RESP3 required
HitlessUpgrades: &hitless.Config{
Mode: hitless.MaintNotificationsEnabled,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
},
})
```
## Modes
- **`MaintNotificationsDisabled`** - Hitless upgrades disabled
- **`MaintNotificationsEnabled`** - Forcefully enabled (fails if server doesn't support)
- **`MaintNotificationsAuto`** - Auto-detect server support (default)
- **`ModeDisabled`** - Maintenance notifications disabled
- **`ModeEnabled`** - Forcefully enabled (fails if server doesn't support)
- **`ModeAuto`** - Auto-detect server support (default)
## Configuration
```go
&hitless.Config{
Mode: hitless.MaintNotificationsAuto,
EndpointType: hitless.EndpointTypeAuto,
&maintnotifications.Config{
Mode: maintnotifications.ModeAuto,
EndpointType: maintnotifications.EndpointTypeAuto,
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxHandoffRetries: 3,
MaxWorkers: 0, // Auto-calculated
HandoffQueueSize: 0, // Auto-calculated
PostHandoffRelaxedDuration: 0, // 2 * RelaxedTimeout
LogLevel: logging.LogLevelError,
}
```
@@ -56,7 +58,7 @@ client := redis.NewClient(&redis.Options{
## How It Works
1. Redis sends push notifications about cluster changes
1. Redis sends push notifications about cluster maintenance operations
2. Client creates new connections to updated endpoints
3. Active operations transfer to new connections
4. Old connections close gracefully
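A sketch tying these steps to the API shown elsewhere in this change (the config field, manager accessor, and logging hook are taken from the example program updated above); treat it as illustrative rather than canonical:

```go
package main

import (
	"time"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/logging"
	"github.com/redis/go-redis/v9/maintnotifications"
)

func main() {
	rdb := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379",
		Protocol: 3, // RESP3 is required for push notifications
		MaintNotificationsConfig: &maintnotifications.Config{
			Mode:           maintnotifications.ModeAuto,
			RelaxedTimeout: 10 * time.Second,
		},
	})
	defer rdb.Close()

	// Observe notifications and handoffs via the built-in logging hook.
	if m := rdb.GetMaintNotificationsManager(); m != nil {
		m.AddNotificationHook(maintnotifications.NewLoggingHook(int(logging.LogLevelDebug)))
	}
}
```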
@@ -71,7 +73,7 @@ client := redis.NewClient(&redis.Options{
## Hooks (Optional)
Monitor and customize hitless operations:
Monitor and customize maintenance notification operations:
```go
type NotificationHook interface {
@@ -87,7 +89,7 @@ manager.AddNotificationHook(&MyHook{})
```go
// Create metrics hook
metricsHook := hitless.NewMetricsHook()
metricsHook := maintnotifications.NewMetricsHook()
manager.AddNotificationHook(metricsHook)
// Access collected metrics


@@ -1,4 +1,4 @@
package hitless
package maintnotifications
import (
"context"
@@ -7,6 +7,7 @@ import (
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
)
// CircuitBreakerState represents the state of a circuit breaker
@@ -101,9 +102,8 @@ func (cb *CircuitBreaker) Execute(fn func() error) error {
if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
cb.requests.Store(0)
cb.successes.Store(0)
if cb.config != nil && cb.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker for %s transitioning to half-open", cb.endpoint)
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint))
}
// Fall through to half-open logic
} else {
@@ -144,20 +144,16 @@ func (cb *CircuitBreaker) recordFailure() {
case CircuitBreakerClosed:
if failures >= int64(cb.failureThreshold) {
if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) {
if cb.config != nil && cb.config.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker opened for endpoint %s after %d failures",
cb.endpoint, failures)
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures))
}
}
}
case CircuitBreakerHalfOpen:
// Any failure in half-open state immediately opens the circuit
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) {
if cb.config != nil && cb.config.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker reopened for endpoint %s due to failure in half-open state",
cb.endpoint)
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint))
}
}
}
@@ -180,10 +176,8 @@ func (cb *CircuitBreaker) recordSuccess() {
if successes >= int64(cb.maxRequests) {
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
cb.failures.Store(0)
if cb.config != nil && cb.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker closed for endpoint %s after %d successful requests",
cb.endpoint, successes)
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes))
}
}
}
@@ -331,9 +325,8 @@ func (cbm *CircuitBreakerManager) cleanup() {
}
// Log cleanup results
if len(toDelete) > 0 && cbm.config != nil && cbm.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker cleanup removed %d/%d entries", len(toDelete), count)
if len(toDelete) > 0 && internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count))
}
cbm.lastCleanup.Store(now.Unix())


@@ -1,16 +1,13 @@
package hitless
package maintnotifications
import (
"errors"
"testing"
"time"
"github.com/redis/go-redis/v9/logging"
)
func TestCircuitBreaker(t *testing.T) {
config := &Config{
LogLevel: logging.LogLevelError, // Reduce noise in tests
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
@@ -96,7 +93,6 @@ func TestCircuitBreaker(t *testing.T) {
t.Run("HalfOpenTransition", func(t *testing.T) {
testConfig := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 100 * time.Millisecond, // Short timeout for testing
CircuitBreakerMaxRequests: 3,
@@ -134,7 +130,6 @@ func TestCircuitBreaker(t *testing.T) {
t.Run("HalfOpenToClosedTransition", func(t *testing.T) {
testConfig := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 50 * time.Millisecond,
CircuitBreakerMaxRequests: 3,
@@ -168,7 +163,6 @@ func TestCircuitBreaker(t *testing.T) {
t.Run("HalfOpenToOpenOnFailure", func(t *testing.T) {
testConfig := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 50 * time.Millisecond,
CircuitBreakerMaxRequests: 3,
@@ -233,7 +227,6 @@ func TestCircuitBreaker(t *testing.T) {
func TestCircuitBreakerManager(t *testing.T) {
config := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
@@ -312,7 +305,6 @@ func TestCircuitBreakerManager(t *testing.T) {
t.Run("ConfigurableParameters", func(t *testing.T) {
config := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 10,
CircuitBreakerResetTimeout: 30 * time.Second,
CircuitBreakerMaxRequests: 5,


@@ -1,4 +1,4 @@
package hitless
package maintnotifications
import (
"context"
@@ -8,24 +8,24 @@ import (
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/logging"
)
// MaintNotificationsMode represents the maintenance notifications mode
type MaintNotificationsMode string
// Mode represents the maintenance notifications mode
type Mode string
// Constants for maintenance push notifications modes
const (
MaintNotificationsDisabled MaintNotificationsMode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
MaintNotificationsEnabled MaintNotificationsMode = "enabled" // Client forcefully sends command, interrupts connection on error
MaintNotificationsAuto MaintNotificationsMode = "auto" // Client tries to send command, disables feature on error
ModeDisabled Mode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
ModeEnabled Mode = "enabled" // Client forcefully sends command, interrupts connection on error
ModeAuto Mode = "auto" // Client tries to send command, disables feature on error
)
// IsValid returns true if the maintenance notifications mode is valid
func (m MaintNotificationsMode) IsValid() bool {
func (m Mode) IsValid() bool {
switch m {
case MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto:
case ModeDisabled, ModeEnabled, ModeAuto:
return true
default:
return false
@@ -33,7 +33,7 @@ func (m MaintNotificationsMode) IsValid() bool {
}
// String returns the string representation of the mode
func (m MaintNotificationsMode) String() string {
func (m Mode) String() string {
return string(m)
}
@@ -66,12 +66,12 @@ func (e EndpointType) String() string {
return string(e)
}
// Config provides configuration options for hitless upgrades.
// Config provides configuration options for maintenance notifications
type Config struct {
// Mode controls how client maintenance notifications are handled.
// Valid values: MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto
// Default: MaintNotificationsAuto
Mode MaintNotificationsMode
// Valid values: ModeDisabled, ModeEnabled, ModeAuto
// Default: ModeAuto
Mode Mode
// EndpointType specifies the type of endpoint to request in MOVING notifications.
// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
@@ -111,11 +111,6 @@ type Config struct {
// Default: 2 * RelaxedTimeout
PostHandoffRelaxedDuration time.Duration
// LogLevel controls the verbosity of hitless upgrade logging.
// LogLevelError (0) = errors only, LogLevelWarn (1) = warnings, LogLevelInfo (2) = info, LogLevelDebug (3) = debug
// Default: logging.LogLevelError(0)
LogLevel logging.LogLevel
// Circuit breaker configuration for endpoint failure handling
// CircuitBreakerFailureThreshold is the number of failures before opening the circuit.
// Default: 5
@@ -136,20 +131,19 @@ type Config struct {
}
func (c *Config) IsEnabled() bool {
return c != nil && c.Mode != MaintNotificationsDisabled
return c != nil && c.Mode != ModeDisabled
}
// DefaultConfig returns a Config with sensible defaults.
func DefaultConfig() *Config {
return &Config{
Mode: MaintNotificationsAuto, // Enable by default for Redis Cloud
Mode: ModeAuto, // Enable by default for Redis Cloud
EndpointType: EndpointTypeAuto, // Auto-detect based on connection
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: 0, // Auto-calculated based on pool size
HandoffQueueSize: 0, // Auto-calculated based on max workers
PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
LogLevel: logging.LogLevelError,
// Circuit breaker configuration
CircuitBreakerFailureThreshold: 5,
@@ -181,9 +175,6 @@ func (c *Config) Validate() error {
if c.PostHandoffRelaxedDuration < 0 {
return ErrInvalidPostHandoffRelaxedDuration
}
if !c.LogLevel.IsValid() {
return ErrInvalidLogLevel
}
// Circuit breaker validation
if c.CircuitBreakerFailureThreshold < 1 {
@@ -299,10 +290,6 @@ func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *
result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
}
// LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set
// We'll use the provided value as-is, since 0 is valid
result.LogLevel = c.LogLevel
// Apply defaults for configuration fields
result.MaxHandoffRetries = defaults.MaxHandoffRetries
if c.MaxHandoffRetries > 0 {
@@ -325,9 +312,9 @@ func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *
result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests
}
if result.LogLevel.DebugOrAbove() {
internal.Logger.Printf(context.Background(), "hitless: debug logging enabled")
internal.Logger.Printf(context.Background(), "hitless: config: %+v", result)
if internal.LogLevel.DebugOrAbove() {
internal.Logger.Printf(context.Background(), logs.DebugLoggingEnabled())
internal.Logger.Printf(context.Background(), logs.ConfigDebug(result))
}
return result
}
@@ -346,7 +333,6 @@ func (c *Config) Clone() *Config {
MaxWorkers: c.MaxWorkers,
HandoffQueueSize: c.HandoffQueueSize,
PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
LogLevel: c.LogLevel,
// Circuit breaker configuration
CircuitBreakerFailureThreshold: c.CircuitBreakerFailureThreshold,

View File

@@ -1,4 +1,4 @@
package hitless
package maintnotifications
import (
"context"
@@ -7,7 +7,6 @@ import (
"time"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/logging"
)
func TestConfig(t *testing.T) {
@@ -73,7 +72,6 @@ func TestConfig(t *testing.T) {
MaxWorkers: -1, // This should be invalid
HandoffQueueSize: 100,
PostHandoffRelaxedDuration: 10 * time.Second,
LogLevel: 1,
MaxHandoffRetries: 3, // Add required field
}
if err := config.Validate(); err != ErrInvalidHandoffWorkers {
@@ -213,7 +211,6 @@ func TestApplyDefaults(t *testing.T) {
MaxWorkers: 0, // Zero value should get auto-calculated defaults
HandoffQueueSize: 0, // Zero value should get default
RelaxedTimeout: 0, // Zero value should get default
LogLevel: 0, // Zero is valid for LogLevel (errors only)
}
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
@@ -238,10 +235,7 @@ func TestApplyDefaults(t *testing.T) {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
}
// LogLevel 0 should be preserved (it's a valid value)
if result.LogLevel != 0 {
t.Errorf("Expected LogLevel to be 0 (preserved), got %d", result.LogLevel)
}
})
}
@@ -306,7 +300,6 @@ func TestIntegrationWithApplyDefaults(t *testing.T) {
// Create a partial config with only some fields set
partialConfig := &Config{
MaxWorkers: 15, // Custom value (>= 10 to test preservation)
LogLevel: logging.LogLevelInfo, // Custom value
// Other fields left as zero values - should get defaults
}
@@ -332,9 +325,7 @@ func TestIntegrationWithApplyDefaults(t *testing.T) {
t.Errorf("Expected MaxWorkers to be 50, got %d", expectedConfig.MaxWorkers)
}
if expectedConfig.LogLevel != 2 {
t.Errorf("Expected LogLevel to be 2, got %d", expectedConfig.LogLevel)
}
// Should apply defaults for missing fields (auto-calculated queue size with hybrid scaling)
workerBasedSize := expectedConfig.MaxWorkers * 20

30
maintnotifications/e2e/.gitignore vendored Normal file
View File

@@ -0,0 +1,30 @@
# E2E test artifacts
*.log
*.out
test-results/
coverage/
profiles/
# Test data
test-data/
temp/
*.tmp
# CI artifacts
artifacts/
reports/
# Redis data files (if running local Redis for testing)
dump.rdb
appendonly.aof
redis.conf.local
# Performance test results
*.prof
*.trace
benchmarks/
# Docker compose files for local testing
docker-compose.override.yml
.env.local
infra/

View File

@@ -0,0 +1,141 @@
# E2E Test Scenarios for Push Notifications
This directory contains comprehensive end-to-end test scenarios for Redis push notifications and maintenance notifications functionality. Each scenario tests different aspects of the system under various conditions.
## ⚠️ **Important Note**
**Maintenance notifications are currently supported only in standalone Redis clients.** Cluster clients (ClusterClient, FailoverClient, etc.) do not yet support maintenance notifications functionality.
## Introduction
To run these tests you need a fault injector service; review the fault injector client in this package and feel free to implement one for the Redis distribution of your choice. The tests are tailored for Redis Enterprise, but they can be adapted to other Redis distributions where a fault injector is available.
Once the fault injector service is up and running, execute the tests with the `run-e2e-tests.sh` script.
Three environment variables must be set before running the tests (an example follows the list):
- `REDIS_ENDPOINTS_CONFIG_PATH`: Path to Redis endpoints configuration
- `FAULT_INJECTION_API_URL`: URL of the fault injector server
- `E2E_SCENARIO_TESTS`: Set to `true` to enable scenario tests
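For example (the configuration path below is a placeholder; point `REDIS_ENDPOINTS_CONFIG_PATH` at your own endpoints file):
```bash
# Placeholder paths/URLs - adjust for your environment
export REDIS_ENDPOINTS_CONFIG_PATH="/path/to/endpoints.json"
export FAULT_INJECTION_API_URL="http://localhost:8080"   # default used by the tests
export E2E_SCENARIO_TESTS=true

./scripts/run-e2e-tests.sh
```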
## Test Scenarios Overview
### 1. Basic Push Notifications (`scenario_push_notifications_test.go`)
**Original template scenario**
- **Purpose**: Basic functionality test for Redis Enterprise push notifications
- **Features Tested**: FAILING_OVER, FAILED_OVER, MIGRATING, MIGRATED, MOVING notifications
- **Configuration**: Standard enterprise cluster setup
- **Duration**: ~10 minutes
- **Key Validations**:
- All notification types received
- Timeout behavior (relaxed/unrelaxed)
- Handoff success rates
- Connection pool management
### 2. Endpoint Types Scenario (`scenario_endpoint_types_test.go`)
**Different endpoint resolution strategies**
- **Purpose**: Test push notifications with different endpoint types
- **Features Tested**: ExternalIP, InternalIP, InternalFQDN, ExternalFQDN endpoint types
- **Configuration**: Standard setup with varying endpoint types
- **Duration**: ~5 minutes
- **Key Validations**:
- Functionality with each endpoint type
- Proper endpoint resolution
- Notification delivery consistency
- Handoff behavior per endpoint type
### 3. Timeout Configurations Scenario (`scenario_timeout_configs_test.go`)
**Various timeout strategies**
- **Purpose**: Test different timeout configurations and their impact
- **Features Tested**: Conservative, Aggressive, HighLatency timeouts
- **Configuration**:
- Conservative: 60s handoff, 20s relaxed, 5s post-handoff
- Aggressive: 5s handoff, 3s relaxed, 1s post-handoff
- HighLatency: 90s handoff, 30s relaxed, 10m post-handoff
- **Duration**: ~10 minutes (3 sub-tests)
- **Key Validations**:
- Timeout behavior matches configuration
- Recovery times appropriate for each strategy
- Error rates correlate with timeout aggressiveness
### 4. TLS Configurations Scenario (`scenario_tls_configs_test.go`)
**Security and encryption testing framework**
- **Purpose**: Test push notifications with different TLS configurations
- **Features Tested**: NoTLS, TLSInsecure, TLSSecure, TLSMinimal, TLSStrict
- **Configuration**: Framework for testing various TLS settings (TLS config handled at connection level)
- **Duration**: ~10 minutes (multiple sub-tests)
- **Key Validations**:
- Functionality with each TLS configuration
- Performance impact of encryption
- Certificate handling (where applicable)
- Security compliance
- **Note**: TLS configuration is handled at the Redis connection config level, not client options level
### 5. Stress Test Scenario (`scenario_stress_test.go`)
**Extreme load and concurrent operations**
- **Purpose**: Test system limits and behavior under extreme stress
- **Features Tested**: Maximum concurrent operations, multiple clients
- **Configuration** (see the sketch after this list):
- 4 clients with 150 pool size each
- 200 max connections per client
- 50 workers, 1000 queue size
- Concurrent failover/migration actions
- **Duration**: ~15 minutes
- **Key Validations**:
- System stability under extreme load
- Error rates within stress limits (<20%)
- Resource utilization and limits
- Concurrent fault injection handling
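As an illustration, a client configured along the lines of the stress scenario might look like the following sketch; the address is a placeholder and the exact values used by the test may differ:
```go
package main

import (
	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/maintnotifications"
)

func main() {
	// Sketch only: placeholder address, values approximate the stress scenario.
	client := redis.NewClient(&redis.Options{
		Addr:           "localhost:6379", // placeholder
		Protocol:       3,                // RESP3 is required for push notifications
		PoolSize:       150,
		MaxActiveConns: 200,
		MaintNotificationsConfig: &maintnotifications.Config{
			Mode:             maintnotifications.ModeEnabled,
			MaxWorkers:       50,
			HandoffQueueSize: 1000,
		},
	})
	defer client.Close()
}
```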
## Running the Scenarios
### Prerequisites
- Set environment variable: `E2E_SCENARIO_TESTS=true`
- Redis Enterprise cluster available
- Fault injection service available
- Appropriate network access and permissions
- **Note**: Tests use standalone Redis clients only (cluster clients not supported)
### Individual Scenario Execution
```bash
# Run a specific scenario
E2E_SCENARIO_TESTS=true go test -v ./maintnotifications/e2e -run TestEndpointTypesPushNotifications
# Run with timeout
E2E_SCENARIO_TESTS=true go test -v -timeout 30m ./maintnotifications/e2e -run TestStressPushNotifications
```
### All Scenarios Execution
```bash
./scripts/run-e2e-tests.sh
```
## Expected Outcomes
### Success Criteria
- All notifications received and processed correctly
- Error rates within acceptable limits for each scenario
- No notification processing errors
- Proper timeout behavior
- Successful handoffs
- Connection pool management within limits
### Performance Benchmarks
- **Basic**: >1000 operations, <1% errors
- **Stress**: >10000 operations, <20% errors
- **Others**: Functionality over performance
## Troubleshooting
### Common Issues
1. **Enterprise cluster not available**: Most scenarios require Redis Enterprise
2. **Fault injector unavailable**: Some scenarios need fault injection service
3. **Network timeouts**: Increase test timeouts for slow networks
4. **TLS certificate issues**: Some TLS scenarios may fail without proper certs
5. **Resource limits**: Stress scenarios may hit system limits
### Debug Options
- Enable detailed logging in scenarios
- Use `dump = true` to see full log analysis
- Check pool statistics for connection issues (see the sketch below)
- Monitor client resources during stress tests
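For pool statistics, `ClientFactory.PrintPoolStats` dumps stats for every client the factory created; a minimal standalone sketch (the `logPoolStats` helper below is hypothetical, not part of the suite) could look like this:
```go
package e2e

import (
	"testing"

	"github.com/redis/go-redis/v9"
)

// logPoolStats is a hypothetical debugging helper: it logs pool statistics for a
// client, which helps spot connection leaks or churn during a scenario.
func logPoolStats(t *testing.T, client redis.UniversalClient) {
	stats := client.PoolStats()
	t.Logf("pool: hits=%d misses=%d timeouts=%d total=%d idle=%d stale=%d",
		stats.Hits, stats.Misses, stats.Timeouts, stats.TotalConns, stats.IdleConns, stats.StaleConns)
}
```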

View File

@@ -0,0 +1,127 @@
package e2e
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
)
type CommandRunnerStats struct {
Operations int64
Errors int64
TimeoutErrors int64
ErrorsList []error
}
// CommandRunner provides utilities for running commands during tests
type CommandRunner struct {
client redis.UniversalClient
stopCh chan struct{}
operationCount atomic.Int64
errorCount atomic.Int64
timeoutErrors atomic.Int64
errors []error
errorsMutex sync.Mutex
}
// NewCommandRunner creates a new command runner
func NewCommandRunner(client redis.UniversalClient) (*CommandRunner, func()) {
stopCh := make(chan struct{})
return &CommandRunner{
client: client,
stopCh: stopCh,
errors: make([]error, 0),
}, func() {
stopCh <- struct{}{}
}
}
func (cr *CommandRunner) Stop() {
select {
case cr.stopCh <- struct{}{}:
return
case <-time.After(500 * time.Millisecond):
return
}
}
func (cr *CommandRunner) Close() {
close(cr.stopCh)
}
// FireCommandsUntilStop runs commands continuously until stop signal
func (cr *CommandRunner) FireCommandsUntilStop(ctx context.Context) {
fmt.Printf("[CR] Starting command runner...\n")
defer fmt.Printf("[CR] Command runner stopped\n")
// High frequency for timeout testing
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
counter := 0
for {
select {
case <-cr.stopCh:
return
case <-ctx.Done():
return
case <-ticker.C:
poolSize := cr.client.PoolStats().IdleConns
if poolSize == 0 {
poolSize = 1
}
wg := sync.WaitGroup{}
for i := 0; i < int(poolSize); i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
key := fmt.Sprintf("timeout-test-key-%d-%d", counter, i)
value := fmt.Sprintf("timeout-test-value-%d-%d", counter, i)
// Use a short timeout context for individual operations
opCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
err := cr.client.Set(opCtx, key, value, time.Minute).Err()
cancel()
cr.operationCount.Add(1)
if err != nil {
fmt.Printf("Error: %v\n", err)
cr.errorCount.Add(1)
// Check if it's a timeout error
if isTimeoutError(err) {
cr.timeoutErrors.Add(1)
}
cr.errorsMutex.Lock()
cr.errors = append(cr.errors, err)
cr.errorsMutex.Unlock()
}
}(i)
}
wg.Wait()
counter++
}
}
}
// GetStats returns operation statistics
func (cr *CommandRunner) GetStats() CommandRunnerStats {
cr.errorsMutex.Lock()
defer cr.errorsMutex.Unlock()
errorList := make([]error, len(cr.errors))
copy(errorList, cr.errors)
stats := CommandRunnerStats{
Operations: cr.operationCount.Load(),
Errors: cr.errorCount.Load(),
TimeoutErrors: cr.timeoutErrors.Load(),
ErrorsList: errorList,
}
return stats
}

View File

@@ -0,0 +1,463 @@
package e2e
import (
"crypto/tls"
"encoding/json"
"fmt"
"net/url"
"os"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/maintnotifications"
)
// DatabaseEndpoint represents a single database endpoint configuration
type DatabaseEndpoint struct {
Addr []string `json:"addr"`
AddrType string `json:"addr_type"`
DNSName string `json:"dns_name"`
OSSClusterAPIPreferredEndpointType string `json:"oss_cluster_api_preferred_endpoint_type"`
OSSClusterAPIPreferredIPType string `json:"oss_cluster_api_preferred_ip_type"`
Port int `json:"port"`
ProxyPolicy string `json:"proxy_policy"`
UID string `json:"uid"`
}
// DatabaseConfig represents the configuration for a single database
type DatabaseConfig struct {
BdbID int `json:"bdb_id,omitempty"`
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
TLS bool `json:"tls"`
CertificatesLocation string `json:"certificatesLocation,omitempty"`
RawEndpoints []DatabaseEndpoint `json:"raw_endpoints,omitempty"`
Endpoints []string `json:"endpoints"`
}
// DatabasesConfig represents the complete configuration file structure
type DatabasesConfig map[string]DatabaseConfig
// EnvConfig represents environment configuration for test scenarios
type EnvConfig struct {
RedisEndpointsConfigPath string
FaultInjectorURL string
}
// RedisConnectionConfig represents Redis connection parameters
type RedisConnectionConfig struct {
Host string
Port int
Username string
Password string
TLS bool
BdbID int
CertificatesLocation string
Endpoints []string
}
// GetEnvConfig reads environment variables required for the test scenario
func GetEnvConfig() (*EnvConfig, error) {
redisConfigPath := os.Getenv("REDIS_ENDPOINTS_CONFIG_PATH")
if redisConfigPath == "" {
return nil, fmt.Errorf("REDIS_ENDPOINTS_CONFIG_PATH environment variable must be set")
}
faultInjectorURL := os.Getenv("FAULT_INJECTION_API_URL")
if faultInjectorURL == "" {
// Default to localhost if not set
faultInjectorURL = "http://localhost:8080"
}
return &EnvConfig{
RedisEndpointsConfigPath: redisConfigPath,
FaultInjectorURL: faultInjectorURL,
}, nil
}
// GetDatabaseConfigFromEnv reads database configuration from a file
func GetDatabaseConfigFromEnv(filePath string) (DatabasesConfig, error) {
fileContent, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read database config from %s: %w", filePath, err)
}
var config DatabasesConfig
if err := json.Unmarshal(fileContent, &config); err != nil {
return nil, fmt.Errorf("failed to parse database config from %s: %w", filePath, err)
}
return config, nil
}
// GetDatabaseConfig gets Redis connection parameters for a specific database
func GetDatabaseConfig(databasesConfig DatabasesConfig, databaseName string) (*RedisConnectionConfig, error) {
var dbConfig DatabaseConfig
var exists bool
if databaseName == "" {
// Get the first database if no name is provided
for _, config := range databasesConfig {
dbConfig = config
exists = true
break
}
} else {
dbConfig, exists = databasesConfig[databaseName]
}
if !exists {
return nil, fmt.Errorf("database %s not found in configuration", databaseName)
}
// Parse connection details from endpoints or raw_endpoints
var host string
var port int
if len(dbConfig.RawEndpoints) > 0 {
// Use raw_endpoints if available (for more complex configurations)
endpoint := dbConfig.RawEndpoints[0] // Use the first endpoint
host = endpoint.DNSName
port = endpoint.Port
} else if len(dbConfig.Endpoints) > 0 {
// Parse from endpoints URLs
endpointURL, err := url.Parse(dbConfig.Endpoints[0])
if err != nil {
return nil, fmt.Errorf("failed to parse endpoint URL %s: %w", dbConfig.Endpoints[0], err)
}
host = endpointURL.Hostname()
portStr := endpointURL.Port()
if portStr == "" {
// Default ports based on scheme
switch endpointURL.Scheme {
case "redis":
port = 6379
case "rediss":
port = 6380
default:
port = 6379
}
} else {
port, err = strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("invalid port in endpoint URL %s: %w", dbConfig.Endpoints[0], err)
}
}
// Override TLS setting based on scheme if not explicitly set
if endpointURL.Scheme == "rediss" {
dbConfig.TLS = true
}
} else {
return nil, fmt.Errorf("no endpoints found in database configuration")
}
return &RedisConnectionConfig{
Host: host,
Port: port,
Username: dbConfig.Username,
Password: dbConfig.Password,
TLS: dbConfig.TLS,
BdbID: dbConfig.BdbID,
CertificatesLocation: dbConfig.CertificatesLocation,
Endpoints: dbConfig.Endpoints,
}, nil
}
// ClientFactory manages Redis client creation and lifecycle
type ClientFactory struct {
config *RedisConnectionConfig
clients map[string]redis.UniversalClient
mutex sync.RWMutex
}
// NewClientFactory creates a new client factory with the specified configuration
func NewClientFactory(config *RedisConnectionConfig) *ClientFactory {
return &ClientFactory{
config: config,
clients: make(map[string]redis.UniversalClient),
}
}
// CreateClientOptions represents options for creating Redis clients
type CreateClientOptions struct {
Protocol int
MaintNotificationsConfig *maintnotifications.Config
MaxRetries int
PoolSize int
MinIdleConns int
MaxActiveConns int
ClientName string
DB int
ReadTimeout time.Duration
WriteTimeout time.Duration
}
// DefaultCreateClientOptions returns default options for creating Redis clients
func DefaultCreateClientOptions() *CreateClientOptions {
return &CreateClientOptions{
Protocol: 3, // RESP3 by default for push notifications
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 30 * time.Second,
RelaxedTimeout: 10 * time.Second,
MaxWorkers: 20,
},
MaxRetries: 3,
PoolSize: 10,
MinIdleConns: 10,
MaxActiveConns: 10,
}
}
func (cf *ClientFactory) PrintPoolStats(t *testing.T) {
cf.mutex.RLock()
defer cf.mutex.RUnlock()
for key, client := range cf.clients {
stats := client.PoolStats()
t.Logf("Pool stats for client %s: %+v", key, stats)
}
}
// Create creates a new Redis client with the specified options and connects it
func (cf *ClientFactory) Create(key string, options *CreateClientOptions) (redis.UniversalClient, error) {
if options == nil {
options = DefaultCreateClientOptions()
}
cf.mutex.Lock()
defer cf.mutex.Unlock()
// Check if client already exists
if client, exists := cf.clients[key]; exists {
return client, nil
}
var client redis.UniversalClient
// Determine if this is a cluster configuration
if len(cf.config.Endpoints) > 1 || cf.isClusterEndpoint() {
// Create cluster client
clusterOptions := &redis.ClusterOptions{
Addrs: cf.getAddresses(),
Username: cf.config.Username,
Password: cf.config.Password,
Protocol: options.Protocol,
MaintNotificationsConfig: options.MaintNotificationsConfig,
MaxRetries: options.MaxRetries,
PoolSize: options.PoolSize,
MinIdleConns: options.MinIdleConns,
MaxActiveConns: options.MaxActiveConns,
ClientName: options.ClientName,
}
if options.ReadTimeout > 0 {
clusterOptions.ReadTimeout = options.ReadTimeout
}
if options.WriteTimeout > 0 {
clusterOptions.WriteTimeout = options.WriteTimeout
}
if cf.config.TLS {
clusterOptions.TLSConfig = &tls.Config{
InsecureSkipVerify: true, // For testing purposes
}
}
client = redis.NewClusterClient(clusterOptions)
} else {
// Create single client
clientOptions := &redis.Options{
Addr: fmt.Sprintf("%s:%d", cf.config.Host, cf.config.Port),
Username: cf.config.Username,
Password: cf.config.Password,
DB: options.DB,
Protocol: options.Protocol,
MaintNotificationsConfig: options.MaintNotificationsConfig,
MaxRetries: options.MaxRetries,
PoolSize: options.PoolSize,
MinIdleConns: options.MinIdleConns,
MaxActiveConns: options.MaxActiveConns,
ClientName: options.ClientName,
}
if options.ReadTimeout > 0 {
clientOptions.ReadTimeout = options.ReadTimeout
}
if options.WriteTimeout > 0 {
clientOptions.WriteTimeout = options.WriteTimeout
}
if cf.config.TLS {
clientOptions.TLSConfig = &tls.Config{
InsecureSkipVerify: true, // For testing purposes
}
}
client = redis.NewClient(clientOptions)
}
// Store the client
cf.clients[key] = client
return client, nil
}
// Get retrieves an existing client by key or the first one if no key is provided
func (cf *ClientFactory) Get(key string) redis.UniversalClient {
cf.mutex.RLock()
defer cf.mutex.RUnlock()
if key != "" {
return cf.clients[key]
}
// Return the first client if no key is provided
for _, client := range cf.clients {
return client
}
return nil
}
// GetAll returns all created clients
func (cf *ClientFactory) GetAll() map[string]redis.UniversalClient {
cf.mutex.RLock()
defer cf.mutex.RUnlock()
result := make(map[string]redis.UniversalClient)
for key, client := range cf.clients {
result[key] = client
}
return result
}
// DestroyAll closes and removes all created clients
func (cf *ClientFactory) DestroyAll() error {
cf.mutex.Lock()
defer cf.mutex.Unlock()
var lastErr error
for key, client := range cf.clients {
if err := client.Close(); err != nil {
lastErr = err
}
delete(cf.clients, key)
}
return lastErr
}
// Destroy closes and removes a specific client
func (cf *ClientFactory) Destroy(key string) error {
cf.mutex.Lock()
defer cf.mutex.Unlock()
client, exists := cf.clients[key]
if !exists {
return fmt.Errorf("client %s not found", key)
}
err := client.Close()
delete(cf.clients, key)
return err
}
// GetConfig returns the connection configuration
func (cf *ClientFactory) GetConfig() *RedisConnectionConfig {
return cf.config
}
// Helper methods
// isClusterEndpoint determines if the configuration represents a cluster
func (cf *ClientFactory) isClusterEndpoint() bool {
// Check if any endpoint contains cluster-related keywords
for _, endpoint := range cf.config.Endpoints {
if strings.Contains(strings.ToLower(endpoint), "cluster") {
return true
}
}
// Check if we have multiple endpoints configured
if len(cf.config.Endpoints) > 1 {
return true
}
return false
}
// getAddresses returns a list of addresses for cluster configuration
func (cf *ClientFactory) getAddresses() []string {
if len(cf.config.Endpoints) > 0 {
addresses := make([]string, 0, len(cf.config.Endpoints))
for _, endpoint := range cf.config.Endpoints {
if parsedURL, err := url.Parse(endpoint); err == nil {
addr := parsedURL.Host
if addr != "" {
addresses = append(addresses, addr)
}
}
}
if len(addresses) > 0 {
return addresses
}
}
// Fallback to single address
return []string{fmt.Sprintf("%s:%d", cf.config.Host, cf.config.Port)}
}
// Utility functions for common test scenarios
// CreateTestClientFactory creates a client factory from environment configuration
func CreateTestClientFactory(databaseName string) (*ClientFactory, error) {
envConfig, err := GetEnvConfig()
if err != nil {
return nil, fmt.Errorf("failed to get environment config: %w", err)
}
databasesConfig, err := GetDatabaseConfigFromEnv(envConfig.RedisEndpointsConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to get database config: %w", err)
}
dbConfig, err := GetDatabaseConfig(databasesConfig, databaseName)
if err != nil {
return nil, fmt.Errorf("failed to get database config for %s: %w", databaseName, err)
}
return NewClientFactory(dbConfig), nil
}
// CreateTestFaultInjector creates a fault injector client from environment configuration
func CreateTestFaultInjector() (*FaultInjectorClient, error) {
envConfig, err := GetEnvConfig()
if err != nil {
return nil, fmt.Errorf("failed to get environment config: %w", err)
}
return NewFaultInjectorClient(envConfig.FaultInjectorURL), nil
}
// GetAvailableDatabases returns a list of available database names from the configuration
func GetAvailableDatabases(configPath string) ([]string, error) {
databasesConfig, err := GetDatabaseConfigFromEnv(configPath)
if err != nil {
return nil, err
}
databases := make([]string, 0, len(databasesConfig))
for name := range databasesConfig {
databases = append(databases, name)
}
return databases, nil
}

View File

@@ -0,0 +1,21 @@
// Package e2e provides end-to-end testing scenarios for the maintenance notifications system.
//
// This package contains comprehensive test scenarios that validate the maintenance notifications
// functionality in realistic environments. The tests are designed to work with Redis Enterprise
// clusters and require specific environment configuration.
//
// Environment Variables:
// - E2E_SCENARIO_TESTS: Set to "true" to enable scenario tests
// - REDIS_ENDPOINTS_CONFIG_PATH: Path to endpoints configuration file
// - FAULT_INJECTION_API_URL: URL for fault injection API (optional)
//
// Test Scenarios:
// - Basic Push Notifications: Core functionality testing
// - Endpoint Types: Different endpoint resolution strategies
// - Timeout Configurations: Various timeout strategies
// - TLS Configurations: Different TLS setups
// - Stress Testing: Extreme load and concurrent operations
//
// Note: Maintenance notifications are currently supported only in standalone Redis clients.
// Cluster clients (ClusterClient, FailoverClient, etc.) do not yet support this functionality.
package e2e

View File

@@ -0,0 +1,110 @@
{
"standalone0": {
"password": "foobared",
"tls": false,
"endpoints": [
"redis://localhost:6379"
]
},
"standalone0-tls": {
"username": "default",
"password": "foobared",
"tls": true,
"certificatesLocation": "redis1-2-5-8-sentinel/work/tls",
"endpoints": [
"rediss://localhost:6390"
]
},
"standalone0-acl": {
"username": "acljedis",
"password": "fizzbuzz",
"tls": false,
"endpoints": [
"redis://localhost:6379"
]
},
"standalone0-acl-tls": {
"username": "acljedis",
"password": "fizzbuzz",
"tls": true,
"certificatesLocation": "redis1-2-5-8-sentinel/work/tls",
"endpoints": [
"rediss://localhost:6390"
]
},
"cluster0": {
"username": "default",
"password": "foobared",
"tls": false,
"endpoints": [
"redis://localhost:7001",
"redis://localhost:7002",
"redis://localhost:7003",
"redis://localhost:7004",
"redis://localhost:7005",
"redis://localhost:7006"
]
},
"cluster0-tls": {
"username": "default",
"password": "foobared",
"tls": true,
"certificatesLocation": "redis1-2-5-8-sentinel/work/tls",
"endpoints": [
"rediss://localhost:7011",
"rediss://localhost:7012",
"rediss://localhost:7013",
"rediss://localhost:7014",
"rediss://localhost:7015",
"rediss://localhost:7016"
]
},
"sentinel0": {
"username": "default",
"password": "foobared",
"tls": false,
"endpoints": [
"redis://localhost:26379",
"redis://localhost:26380",
"redis://localhost:26381"
]
},
"modules-docker": {
"tls": false,
"endpoints": [
"redis://localhost:6479"
]
},
"enterprise-cluster": {
"bdb_id": 1,
"username": "default",
"password": "enterprise-password",
"tls": true,
"raw_endpoints": [
{
"addr": ["10.0.0.1"],
"addr_type": "ipv4",
"dns_name": "redis-enterprise-cluster.example.com",
"oss_cluster_api_preferred_endpoint_type": "internal",
"oss_cluster_api_preferred_ip_type": "ipv4",
"port": 12000,
"proxy_policy": "single",
"uid": "endpoint-1"
},
{
"addr": ["10.0.0.2"],
"addr_type": "ipv4",
"dns_name": "redis-enterprise-cluster-2.example.com",
"oss_cluster_api_preferred_endpoint_type": "internal",
"oss_cluster_api_preferred_ip_type": "ipv4",
"port": 12000,
"proxy_policy": "single",
"uid": "endpoint-2"
}
],
"endpoints": [
"rediss://redis-enterprise-cluster.example.com:12000",
"rediss://redis-enterprise-cluster-2.example.com:12000"
]
}
}

View File

@@ -0,0 +1,565 @@
package e2e
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
)
// ActionType represents the type of fault injection action
type ActionType string
const (
// Redis cluster actions
ActionClusterFailover ActionType = "cluster_failover"
ActionClusterReshard ActionType = "cluster_reshard"
ActionClusterAddNode ActionType = "cluster_add_node"
ActionClusterRemoveNode ActionType = "cluster_remove_node"
ActionClusterMigrate ActionType = "cluster_migrate"
// Node-level actions
ActionNodeRestart ActionType = "node_restart"
ActionNodeStop ActionType = "node_stop"
ActionNodeStart ActionType = "node_start"
ActionNodeKill ActionType = "node_kill"
// Network simulation actions
ActionNetworkPartition ActionType = "network_partition"
ActionNetworkLatency ActionType = "network_latency"
ActionNetworkPacketLoss ActionType = "network_packet_loss"
ActionNetworkBandwidth ActionType = "network_bandwidth"
ActionNetworkRestore ActionType = "network_restore"
// Redis configuration actions
ActionConfigChange ActionType = "config_change"
ActionMaintenanceMode ActionType = "maintenance_mode"
ActionSlotMigration ActionType = "slot_migration"
// Sequence and complex actions
ActionSequence ActionType = "sequence_of_actions"
ActionExecuteCommand ActionType = "execute_command"
)
// ActionStatus represents the status of an action
type ActionStatus string
const (
StatusPending ActionStatus = "pending"
StatusRunning ActionStatus = "running"
StatusFinished ActionStatus = "finished"
StatusFailed ActionStatus = "failed"
StatusSuccess ActionStatus = "success"
StatusCancelled ActionStatus = "cancelled"
)
// ActionRequest represents a request to trigger an action
type ActionRequest struct {
Type ActionType `json:"type"`
Parameters map[string]interface{} `json:"parameters,omitempty"`
}
// ActionResponse represents the response from triggering an action
type ActionResponse struct {
ActionID string `json:"action_id"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
}
// ActionStatusResponse represents the status of an action
type ActionStatusResponse struct {
ActionID string `json:"action_id"`
Status ActionStatus `json:"status"`
Error interface{} `json:"error,omitempty"`
Output map[string]interface{} `json:"output,omitempty"`
Progress float64 `json:"progress,omitempty"`
StartTime time.Time `json:"start_time,omitempty"`
EndTime time.Time `json:"end_time,omitempty"`
}
// SequenceAction represents an action in a sequence
type SequenceAction struct {
Type ActionType `json:"type"`
Parameters map[string]interface{} `json:"params,omitempty"`
Delay time.Duration `json:"delay,omitempty"`
}
// FaultInjectorClient provides programmatic control over test infrastructure
type FaultInjectorClient struct {
baseURL string
httpClient *http.Client
}
// NewFaultInjectorClient creates a new fault injector client
func NewFaultInjectorClient(baseURL string) *FaultInjectorClient {
return &FaultInjectorClient{
baseURL: strings.TrimSuffix(baseURL, "/"),
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// GetBaseURL returns the base URL of the fault injector server
func (c *FaultInjectorClient) GetBaseURL() string {
return c.baseURL
}
// ListActions lists all available actions
func (c *FaultInjectorClient) ListActions(ctx context.Context) ([]ActionType, error) {
var actions []ActionType
err := c.request(ctx, "GET", "/actions", nil, &actions)
return actions, err
}
// TriggerAction triggers a specific action
func (c *FaultInjectorClient) TriggerAction(ctx context.Context, action ActionRequest) (*ActionResponse, error) {
var response ActionResponse
err := c.request(ctx, "POST", "/action", action, &response)
return &response, err
}
func (c *FaultInjectorClient) TriggerSequence(ctx context.Context, bdbID int, actions []SequenceAction) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionSequence,
Parameters: map[string]interface{}{
"bdb_id": bdbID,
"actions": actions,
},
})
}
// GetActionStatus gets the status of a specific action
func (c *FaultInjectorClient) GetActionStatus(ctx context.Context, actionID string) (*ActionStatusResponse, error) {
var status ActionStatusResponse
err := c.request(ctx, "GET", fmt.Sprintf("/action/%s", actionID), nil, &status)
return &status, err
}
// WaitForAction waits for an action to complete
func (c *FaultInjectorClient) WaitForAction(ctx context.Context, actionID string, options ...WaitOption) (*ActionStatusResponse, error) {
config := &waitConfig{
pollInterval: 1 * time.Second,
maxWaitTime: 60 * time.Second,
}
for _, opt := range options {
opt(config)
}
deadline := time.Now().Add(config.maxWaitTime)
ticker := time.NewTicker(config.pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(time.Until(deadline)):
return nil, fmt.Errorf("timeout waiting for action %s after %v", actionID, config.maxWaitTime)
case <-ticker.C:
status, err := c.GetActionStatus(ctx, actionID)
if err != nil {
return nil, fmt.Errorf("failed to get action status: %w", err)
}
switch status.Status {
case StatusFinished, StatusSuccess, StatusFailed, StatusCancelled:
return status, nil
}
}
}
}
// Cluster Management Actions
// TriggerClusterFailover triggers a cluster failover
func (c *FaultInjectorClient) TriggerClusterFailover(ctx context.Context, nodeID string, force bool) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionClusterFailover,
Parameters: map[string]interface{}{
"node_id": nodeID,
"force": force,
},
})
}
// TriggerClusterReshard triggers cluster resharding
func (c *FaultInjectorClient) TriggerClusterReshard(ctx context.Context, slots []int, sourceNode, targetNode string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionClusterReshard,
Parameters: map[string]interface{}{
"slots": slots,
"source_node": sourceNode,
"target_node": targetNode,
},
})
}
// TriggerSlotMigration triggers migration of specific slots
func (c *FaultInjectorClient) TriggerSlotMigration(ctx context.Context, startSlot, endSlot int, sourceNode, targetNode string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionSlotMigration,
Parameters: map[string]interface{}{
"start_slot": startSlot,
"end_slot": endSlot,
"source_node": sourceNode,
"target_node": targetNode,
},
})
}
// Node Management Actions
// RestartNode restarts a specific Redis node
func (c *FaultInjectorClient) RestartNode(ctx context.Context, nodeID string, graceful bool) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNodeRestart,
Parameters: map[string]interface{}{
"node_id": nodeID,
"graceful": graceful,
},
})
}
// StopNode stops a specific Redis node
func (c *FaultInjectorClient) StopNode(ctx context.Context, nodeID string, graceful bool) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNodeStop,
Parameters: map[string]interface{}{
"node_id": nodeID,
"graceful": graceful,
},
})
}
// StartNode starts a specific Redis node
func (c *FaultInjectorClient) StartNode(ctx context.Context, nodeID string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNodeStart,
Parameters: map[string]interface{}{
"node_id": nodeID,
},
})
}
// KillNode forcefully kills a Redis node
func (c *FaultInjectorClient) KillNode(ctx context.Context, nodeID string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNodeKill,
Parameters: map[string]interface{}{
"node_id": nodeID,
},
})
}
// Network Simulation Actions
// SimulateNetworkPartition simulates a network partition
func (c *FaultInjectorClient) SimulateNetworkPartition(ctx context.Context, nodes []string, duration time.Duration) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNetworkPartition,
Parameters: map[string]interface{}{
"nodes": nodes,
"duration": duration.String(),
},
})
}
// SimulateNetworkLatency adds network latency
func (c *FaultInjectorClient) SimulateNetworkLatency(ctx context.Context, nodes []string, latency time.Duration, jitter time.Duration) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNetworkLatency,
Parameters: map[string]interface{}{
"nodes": nodes,
"latency": latency.String(),
"jitter": jitter.String(),
},
})
}
// SimulatePacketLoss simulates packet loss
func (c *FaultInjectorClient) SimulatePacketLoss(ctx context.Context, nodes []string, lossPercent float64) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNetworkPacketLoss,
Parameters: map[string]interface{}{
"nodes": nodes,
"loss_percent": lossPercent,
},
})
}
// LimitBandwidth limits network bandwidth
func (c *FaultInjectorClient) LimitBandwidth(ctx context.Context, nodes []string, bandwidth string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNetworkBandwidth,
Parameters: map[string]interface{}{
"nodes": nodes,
"bandwidth": bandwidth,
},
})
}
// RestoreNetwork restores normal network conditions
func (c *FaultInjectorClient) RestoreNetwork(ctx context.Context, nodes []string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNetworkRestore,
Parameters: map[string]interface{}{
"nodes": nodes,
},
})
}
// Configuration Actions
// ChangeConfig changes Redis configuration
func (c *FaultInjectorClient) ChangeConfig(ctx context.Context, nodeID string, config map[string]string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionConfigChange,
Parameters: map[string]interface{}{
"node_id": nodeID,
"config": config,
},
})
}
// EnableMaintenanceMode enables maintenance mode
func (c *FaultInjectorClient) EnableMaintenanceMode(ctx context.Context, nodeID string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionMaintenanceMode,
Parameters: map[string]interface{}{
"node_id": nodeID,
"enabled": true,
},
})
}
// DisableMaintenanceMode disables maintenance mode
func (c *FaultInjectorClient) DisableMaintenanceMode(ctx context.Context, nodeID string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionMaintenanceMode,
Parameters: map[string]interface{}{
"node_id": nodeID,
"enabled": false,
},
})
}
// Complex Actions
// ExecuteSequence executes a sequence of actions
func (c *FaultInjectorClient) ExecuteSequence(ctx context.Context, actions []SequenceAction) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionSequence,
Parameters: map[string]interface{}{
"actions": actions,
},
})
}
// ExecuteCommand executes a custom command
func (c *FaultInjectorClient) ExecuteCommand(ctx context.Context, nodeID, command string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionExecuteCommand,
Parameters: map[string]interface{}{
"node_id": nodeID,
"command": command,
},
})
}
// Convenience Methods
// SimulateClusterUpgrade simulates a complete cluster upgrade scenario
func (c *FaultInjectorClient) SimulateClusterUpgrade(ctx context.Context, nodes []string) (*ActionResponse, error) {
actions := make([]SequenceAction, 0, len(nodes)*2)
// Rolling restart of all nodes
for i, nodeID := range nodes {
actions = append(actions, SequenceAction{
Type: ActionNodeRestart,
Parameters: map[string]interface{}{
"node_id": nodeID,
"graceful": true,
},
Delay: time.Duration(i*10) * time.Second, // Stagger restarts
})
}
return c.ExecuteSequence(ctx, actions)
}
// SimulateNetworkIssues simulates various network issues
func (c *FaultInjectorClient) SimulateNetworkIssues(ctx context.Context, nodes []string) (*ActionResponse, error) {
actions := []SequenceAction{
{
Type: ActionNetworkLatency,
Parameters: map[string]interface{}{
"nodes": nodes,
"latency": "100ms",
"jitter": "20ms",
},
},
{
Type: ActionNetworkPacketLoss,
Parameters: map[string]interface{}{
"nodes": nodes,
"loss_percent": 2.0,
},
Delay: 30 * time.Second,
},
{
Type: ActionNetworkRestore,
Parameters: map[string]interface{}{
"nodes": nodes,
},
Delay: 60 * time.Second,
},
}
return c.ExecuteSequence(ctx, actions)
}
// Helper types and functions
type waitConfig struct {
pollInterval time.Duration
maxWaitTime time.Duration
}
type WaitOption func(*waitConfig)
// WithPollInterval sets the polling interval for waiting
func WithPollInterval(interval time.Duration) WaitOption {
return func(c *waitConfig) {
c.pollInterval = interval
}
}
// WithMaxWaitTime sets the maximum wait time
func WithMaxWaitTime(maxWait time.Duration) WaitOption {
return func(c *waitConfig) {
c.maxWaitTime = maxWait
}
}
// Internal HTTP request method
func (c *FaultInjectorClient) request(ctx context.Context, method, path string, body interface{}, result interface{}) error {
url := c.baseURL + path
var reqBody io.Reader
if body != nil {
jsonData, err := json.Marshal(body)
if err != nil {
return fmt.Errorf("failed to marshal request body: %w", err)
}
reqBody = bytes.NewReader(jsonData)
}
req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode >= 400 {
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
}
if result != nil {
if err := json.Unmarshal(respBody, result); err != nil {
// This can happen when the API changes its response structure:
// the action status output is sometimes a map and sometimes plain JSON.
// Since there is no stable response schema, handle both shapes here.
if result, ok := result.(*ActionStatusResponse); ok {
mapResult := map[string]interface{}{}
err = json.Unmarshal(respBody, &mapResult)
if err != nil {
fmt.Println("Failed to unmarshal response:", string(respBody))
panic(err)
}
result.Error = mapResult["error"]
result.Output = map[string]interface{}{"result": mapResult["output"]}
if status, ok := mapResult["status"].(string); ok {
result.Status = ActionStatus(status)
}
if result.Status == StatusSuccess || result.Status == StatusFailed || result.Status == StatusCancelled {
result.EndTime = time.Now()
}
if progress, ok := mapResult["progress"].(float64); ok {
result.Progress = progress
}
if actionID, ok := mapResult["action_id"].(string); ok {
result.ActionID = actionID
}
return nil
}
fmt.Println("Failed to unmarshal response:", string(respBody))
panic(err)
}
}
return nil
}
// Utility functions for common scenarios
// GetClusterNodes returns a list of cluster node IDs
func GetClusterNodes() []string {
// TODO Implement
// This would typically be configured via environment or discovery
return []string{"node-1", "node-2", "node-3", "node-4", "node-5", "node-6"}
}
// GetMasterNodes returns a list of master node IDs
func GetMasterNodes() []string {
// TODO Implement
return []string{"node-1", "node-2", "node-3"}
}
// GetSlaveNodes returns a list of slave node IDs
func GetSlaveNodes() []string {
// TODO Implement
return []string{"node-4", "node-5", "node-6"}
}
// ParseNodeID extracts node ID from various formats
func ParseNodeID(nodeAddr string) string {
// Extract node ID from address like "redis-node-1:7001" -> "node-1"
parts := strings.Split(nodeAddr, ":")
if len(parts) > 0 {
addr := parts[0]
if strings.Contains(addr, "redis-") {
return strings.TrimPrefix(addr, "redis-")
}
return addr
}
return nodeAddr
}
// FormatSlotRange formats a slot range for Redis commands
func FormatSlotRange(start, end int) string {
if start == end {
return strconv.Itoa(start)
}
return fmt.Sprintf("%d-%d", start, end)
}

View File

@@ -0,0 +1,434 @@
package e2e
import (
"context"
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
)
// logs is a slice of strings that provides additional functionality
// for filtering and analysis
type logs []string
func (l logs) Contains(searchString string) bool {
for _, log := range l {
if log == searchString {
return true
}
}
return false
}
func (l logs) GetCount() int {
return len(l)
}
func (l logs) GetCountThatContain(searchString string) int {
count := 0
for _, log := range l {
if strings.Contains(log, searchString) {
count++
}
}
return count
}
func (l logs) GetLogsFiltered(filter func(string) bool) []string {
filteredLogs := make([]string, 0, len(l))
for _, log := range l {
if filter(log) {
filteredLogs = append(filteredLogs, log)
}
}
return filteredLogs
}
func (l logs) GetTimedOutLogs() logs {
return l.GetLogsFiltered(isTimeout)
}
func (l logs) GetLogsPerConn(connID uint64) logs {
return l.GetLogsFiltered(func(log string) bool {
return strings.Contains(log, fmt.Sprintf("conn[%d]", connID))
})
}
func (l logs) GetAnalysis() *LogAnalisis {
return NewLogAnalysis(l)
}
// TestLogCollector is a simple logger that captures logs for analysis
// It is thread safe and can be used to capture logs from multiple clients
// It uses type logs to provide additional functionality like filtering
// and analysis
type TestLogCollector struct {
l logs
doPrint bool
matchFuncs []*MatchFunc
matchFuncsMutex sync.Mutex
mu sync.Mutex
}
func (tlc *TestLogCollector) DontPrint() {
tlc.mu.Lock()
defer tlc.mu.Unlock()
tlc.doPrint = false
}
func (tlc *TestLogCollector) DoPrint() {
tlc.mu.Lock()
defer tlc.mu.Unlock()
tlc.l = make([]string, 0)
tlc.doPrint = true
}
// MatchFunc wraps a predicate that checks log lines for a specific condition.
// It is used by WaitForLogMatchFunc.
type MatchFunc struct {
completed atomic.Bool
F func(lstring string) bool
matches []string
found chan struct{} // channel to notify when match is found, will be closed
done func()
}
func (tlc *TestLogCollector) Printf(_ context.Context, format string, v ...interface{}) {
tlc.mu.Lock()
defer tlc.mu.Unlock()
lstr := fmt.Sprintf(format, v...)
if len(tlc.matchFuncs) > 0 {
go func(lstr string) {
for _, matchFunc := range tlc.matchFuncs {
if matchFunc.F(lstr) {
matchFunc.matches = append(matchFunc.matches, lstr)
matchFunc.done()
return
}
}
}(lstr)
}
if tlc.doPrint {
fmt.Println(lstr)
}
tlc.l = append(tlc.l, lstr)
}
func (tlc *TestLogCollector) WaitForLogContaining(searchString string, timeout time.Duration) bool {
timeoutCh := time.After(timeout)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-timeoutCh:
return false
case <-ticker.C:
if tlc.Contains(searchString) {
return true
}
}
}
}
func (tlc *TestLogCollector) MatchOrWaitForLogMatchFunc(mf func(string) bool, timeout time.Duration) (string, bool) {
if logs := tlc.GetLogsFiltered(mf); len(logs) > 0 {
return logs[0], true
}
return tlc.WaitForLogMatchFunc(mf, timeout)
}
func (tlc *TestLogCollector) WaitForLogMatchFunc(mf func(string) bool, timeout time.Duration) (string, bool) {
matchFunc := &MatchFunc{
completed: atomic.Bool{},
F: mf,
found: make(chan struct{}),
matches: make([]string, 0),
}
matchFunc.done = func() {
if !matchFunc.completed.CompareAndSwap(false, true) {
return
}
close(matchFunc.found)
tlc.matchFuncsMutex.Lock()
defer tlc.matchFuncsMutex.Unlock()
for i, mf := range tlc.matchFuncs {
if mf == matchFunc {
tlc.matchFuncs = append(tlc.matchFuncs[:i], tlc.matchFuncs[i+1:]...)
return
}
}
}
tlc.matchFuncsMutex.Lock()
tlc.matchFuncs = append(tlc.matchFuncs, matchFunc)
tlc.matchFuncsMutex.Unlock()
select {
case <-matchFunc.found:
return matchFunc.matches[0], true
case <-time.After(timeout):
return "", false
}
}
func (tlc *TestLogCollector) GetLogs() logs {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l
}
func (tlc *TestLogCollector) DumpLogs() {
tlc.mu.Lock()
defer tlc.mu.Unlock()
fmt.Println("Dumping logs:")
fmt.Println("===================================================")
for _, log := range tlc.l {
fmt.Println(log)
}
}
func (tlc *TestLogCollector) ClearLogs() {
tlc.mu.Lock()
defer tlc.mu.Unlock()
tlc.l = make([]string, 0)
}
func (tlc *TestLogCollector) Contains(searchString string) bool {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l.Contains(searchString)
}
func (tlc *TestLogCollector) MatchContainsAll(searchStrings []string) []string {
// match a log that contains all
return tlc.GetLogsFiltered(func(log string) bool {
for _, searchString := range searchStrings {
if !strings.Contains(log, searchString) {
return false
}
}
return true
})
}
func (tlc *TestLogCollector) GetLogCount() int {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l.GetCount()
}
func (tlc *TestLogCollector) GetLogCountThatContain(searchString string) int {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l.GetCountThatContain(searchString)
}
func (tlc *TestLogCollector) GetLogsFiltered(filter func(string) bool) logs {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l.GetLogsFiltered(filter)
}
func (tlc *TestLogCollector) GetTimedOutLogs() []string {
return tlc.GetLogsFiltered(isTimeout)
}
func (tlc *TestLogCollector) GetLogsPerConn(connID uint64) logs {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l.GetLogsPerConn(connID)
}
func (tlc *TestLogCollector) GetAnalysisForConn(connID uint64) *LogAnalisis {
return NewLogAnalysis(tlc.GetLogsPerConn(connID))
}
func NewTestLogCollector() *TestLogCollector {
return &TestLogCollector{
l: make([]string, 0),
}
}
func (tlc *TestLogCollector) GetAnalysis() *LogAnalisis {
return NewLogAnalysis(tlc.GetLogs())
}
func (tlc *TestLogCollector) Clear() {
tlc.mu.Lock()
defer tlc.mu.Unlock()
tlc.matchFuncs = make([]*MatchFunc, 0)
tlc.l = make([]string, 0)
}
// LogAnalisis provides analysis of logs captured by TestLogCollector
type LogAnalisis struct {
logs []string
TimeoutErrorsCount int64
RelaxedTimeoutCount int64
RelaxedPostHandoffCount int64
UnrelaxedTimeoutCount int64
UnrelaxedAfterMoving int64
ConnectionCount int64
connLogs map[uint64][]string
connIds map[uint64]bool
TotalNotifications int64
MovingCount int64
MigratingCount int64
MigratedCount int64
FailingOverCount int64
FailedOverCount int64
UnexpectedCount int64
TotalHandoffCount int64
FailedHandoffCount int64
SucceededHandoffCount int64
TotalHandoffRetries int64
TotalHandoffToCurrentEndpoint int64
}
func NewLogAnalysis(logs []string) *LogAnalisis {
la := &LogAnalisis{
logs: logs,
connLogs: make(map[uint64][]string),
connIds: make(map[uint64]bool),
}
la.Analyze()
return la
}
func (la *LogAnalisis) Analyze() {
hasMoving := false
for _, log := range la.logs {
if isTimeout(log) {
la.TimeoutErrorsCount++
}
if strings.Contains(log, "MOVING") {
hasMoving = true
}
if strings.Contains(log, logs2.RelaxedTimeoutDueToNotificationMessage) {
la.RelaxedTimeoutCount++
}
if strings.Contains(log, logs2.ApplyingRelaxedTimeoutDueToPostHandoffMessage) {
la.RelaxedTimeoutCount++
la.RelaxedPostHandoffCount++
}
if strings.Contains(log, logs2.UnrelaxedTimeoutMessage) {
la.UnrelaxedTimeoutCount++
}
if strings.Contains(log, logs2.UnrelaxedTimeoutAfterDeadlineMessage) {
if hasMoving {
la.UnrelaxedAfterMoving++
} else {
fmt.Printf("Unrelaxed after deadline but no MOVING: %s\n", log)
}
}
if strings.Contains(log, logs2.ProcessingNotificationMessage) {
la.TotalNotifications++
switch {
case notificationType(log, "MOVING"):
la.MovingCount++
case notificationType(log, "MIGRATING"):
la.MigratingCount++
case notificationType(log, "MIGRATED"):
la.MigratedCount++
case notificationType(log, "FAILING_OVER"):
la.FailingOverCount++
case notificationType(log, "FAILED_OVER"):
la.FailedOverCount++
default:
fmt.Printf("[ERROR] Unexpected notification: %s\n", log)
la.UnexpectedCount++
}
}
if strings.Contains(log, "conn[") {
connID := extractConnID(log)
if _, ok := la.connIds[connID]; !ok {
la.connIds[connID] = true
la.ConnectionCount++
}
la.connLogs[connID] = append(la.connLogs[connID], log)
}
if strings.Contains(log, logs2.SchedulingHandoffToCurrentEndpointMessage) {
la.TotalHandoffToCurrentEndpoint++
}
if strings.Contains(log, logs2.HandoffSuccessMessage) {
la.SucceededHandoffCount++
}
if strings.Contains(log, logs2.HandoffFailedMessage) {
la.FailedHandoffCount++
}
if strings.Contains(log, logs2.HandoffStartedMessage) {
la.TotalHandoffCount++
}
if strings.Contains(log, logs2.HandoffRetryAttemptMessage) {
la.TotalHandoffRetries++
}
}
}
func (la *LogAnalisis) Print(t *testing.T) {
t.Logf("Log Analysis results for %d logs and %d connections:", len(la.logs), len(la.connIds))
t.Logf("Connection Count: %d", la.ConnectionCount)
t.Logf("-------------")
t.Logf("-Timeout Analysis-")
t.Logf("-------------")
t.Logf("Timeout Errors: %d", la.TimeoutErrorsCount)
t.Logf("Relaxed Timeout Count: %d", la.RelaxedTimeoutCount)
t.Logf(" - Relaxed Timeout After Post-Handoff: %d", la.RelaxedPostHandoffCount)
t.Logf("Unrelaxed Timeout Count: %d", la.UnrelaxedTimeoutCount)
t.Logf(" - Unrelaxed Timeout After Moving: %d", la.UnrelaxedAfterMoving)
t.Logf("-------------")
t.Logf("-Handoff Analysis-")
t.Logf("-------------")
t.Logf("Total Handoffs: %d", la.TotalHandoffCount)
t.Logf(" - Succeeded: %d", la.SucceededHandoffCount)
t.Logf(" - Failed: %d", la.FailedHandoffCount)
t.Logf(" - Retries: %d", la.TotalHandoffRetries)
t.Logf(" - Handoffs to current endpoint: %d", la.TotalHandoffToCurrentEndpoint)
t.Logf("-------------")
t.Logf("-Notification Analysis-")
t.Logf("-------------")
t.Logf("Total Notifications: %d", la.TotalNotifications)
t.Logf(" - MOVING: %d", la.MovingCount)
t.Logf(" - MIGRATING: %d", la.MigratingCount)
t.Logf(" - MIGRATED: %d", la.MigratedCount)
t.Logf(" - FAILING_OVER: %d", la.FailingOverCount)
t.Logf(" - FAILED_OVER: %d", la.FailedOverCount)
t.Logf(" - Unexpected: %d", la.UnexpectedCount)
t.Logf("-------------")
t.Logf("Log Analysis completed successfully")
}
func extractConnID(log string) uint64 {
logParts := strings.Split(log, "conn[")
if len(logParts) < 2 {
return 0
}
connIDStr := strings.Split(logParts[1], "]")[0]
connID, err := strconv.ParseUint(connIDStr, 10, 64)
if err != nil {
return 0
}
return connID
}
func notificationType(log string, nt string) bool {
return strings.Contains(log, nt)
}
func connID(log string, connID uint64) bool {
return strings.Contains(log, fmt.Sprintf("conn[%d]", connID))
}
func seqID(log string, seqID int64) bool {
return strings.Contains(log, fmt.Sprintf("seqID[%d]", seqID))
}

View File

@@ -0,0 +1,39 @@
package e2e
import (
"log"
"os"
"testing"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/logging"
)
// Global log collector
var logCollector *TestLogCollector
// Global fault injector client
var faultInjector *FaultInjectorClient
func TestMain(m *testing.M) {
var err error
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
log.Println("Skipping scenario tests, E2E_SCENARIO_TESTS is not set")
return
}
faultInjector, err = CreateTestFaultInjector()
if err != nil {
panic("Failed to create fault injector: " + err.Error())
}
// use log collector to capture logs from redis clients
logCollector = NewTestLogCollector()
redis.SetLogger(logCollector)
redis.SetLogLevel(logging.LogLevelDebug)
logCollector.Clear()
defer logCollector.Clear()
log.Println("Running scenario tests...")
status := m.Run()
os.Exit(status)
}

View File

@@ -0,0 +1,404 @@
package e2e
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/push"
)
// DiagnosticsEvent represents a notification event. It may be a push
// notification or an error raised while processing a push notification.
type DiagnosticsEvent struct {
// is this pre or post hook
Type string `json:"type"`
ConnID uint64 `json:"connID"`
SeqID int64 `json:"seqID"`
Error error `json:"error"`
Pre bool `json:"pre"`
Timestamp time.Time `json:"timestamp"`
Details map[string]interface{} `json:"details"`
}
// TrackingNotificationsHook is a notification hook that tracks notifications
type TrackingNotificationsHook struct {
// unique connection count
connectionCount atomic.Int64
// timeouts
relaxedTimeoutCount atomic.Int64
unrelaxedTimeoutCount atomic.Int64
notificationProcessingErrors atomic.Int64
// notification types
totalNotifications atomic.Int64
migratingCount atomic.Int64
migratedCount atomic.Int64
failingOverCount atomic.Int64
failedOverCount atomic.Int64
movingCount atomic.Int64
unexpectedNotificationCount atomic.Int64
diagnosticsLog []DiagnosticsEvent
connIds map[uint64]bool
connLogs map[uint64][]DiagnosticsEvent
mutex sync.RWMutex
}
// NewTrackingNotificationsHook creates a new notification hook with counters
func NewTrackingNotificationsHook() *TrackingNotificationsHook {
return &TrackingNotificationsHook{
diagnosticsLog: make([]DiagnosticsEvent, 0),
connIds: make(map[uint64]bool),
connLogs: make(map[uint64][]DiagnosticsEvent),
}
}
// Clear resets all counters and recorded events. The hook is not designed to be
// reused, but Clear is provided for consistency with the log collector.
func (tnh *TrackingNotificationsHook) Clear() {
tnh.mutex.Lock()
defer tnh.mutex.Unlock()
tnh.diagnosticsLog = make([]DiagnosticsEvent, 0)
tnh.connIds = make(map[uint64]bool)
tnh.connLogs = make(map[uint64][]DiagnosticsEvent)
tnh.relaxedTimeoutCount.Store(0)
tnh.unrelaxedTimeoutCount.Store(0)
tnh.notificationProcessingErrors.Store(0)
tnh.totalNotifications.Store(0)
tnh.migratingCount.Store(0)
tnh.migratedCount.Store(0)
tnh.failingOverCount.Store(0)
tnh.failedOverCount.Store(0)
tnh.movingCount.Store(0)
tnh.unexpectedNotificationCount.Store(0)
tnh.connectionCount.Store(0)
}
// PreHook captures timeout-related events before processing
func (tnh *TrackingNotificationsHook) PreHook(_ context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
tnh.increaseNotificationCount(notificationType)
tnh.storeDiagnosticsEvent(notificationType, notification, notificationCtx)
tnh.increaseRelaxedTimeoutCount(notificationType)
return notification, true
}
func (tnh *TrackingNotificationsHook) getConnID(notificationCtx push.NotificationHandlerContext) uint64 {
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
return conn.GetID()
}
return 0
}
func (tnh *TrackingNotificationsHook) getSeqID(notification []interface{}) int64 {
if len(notification) < 2 {
return 0
}
seqID, ok := notification[1].(int64)
if !ok {
return 0
}
return seqID
}
// PostHook captures the result after processing push notification
func (tnh *TrackingNotificationsHook) PostHook(_ context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, err error) {
if err != nil {
event := DiagnosticsEvent{
Type: notificationType + "_ERROR",
ConnID: tnh.getConnID(notificationCtx),
SeqID: tnh.getSeqID(notification),
Error: err,
Timestamp: time.Now(),
Details: map[string]interface{}{
"notification": notification,
"context": "post-hook",
},
}
tnh.notificationProcessingErrors.Add(1)
tnh.mutex.Lock()
tnh.diagnosticsLog = append(tnh.diagnosticsLog, event)
tnh.mutex.Unlock()
}
}
func (tnh *TrackingNotificationsHook) storeDiagnosticsEvent(notificationType string, notification []interface{}, notificationCtx push.NotificationHandlerContext) {
connID := tnh.getConnID(notificationCtx)
event := DiagnosticsEvent{
Type: notificationType,
ConnID: connID,
SeqID: tnh.getSeqID(notification),
Pre: true,
Timestamp: time.Now(),
Details: map[string]interface{}{
"notification": notification,
"context": "pre-hook",
},
}
tnh.mutex.Lock()
if v, ok := tnh.connIds[connID]; !ok || !v {
tnh.connIds[connID] = true
tnh.connectionCount.Add(1)
}
tnh.connLogs[connID] = append(tnh.connLogs[connID], event)
tnh.diagnosticsLog = append(tnh.diagnosticsLog, event)
tnh.mutex.Unlock()
}
// GetRelaxedTimeoutCount returns the count of relaxed timeout events
func (tnh *TrackingNotificationsHook) GetRelaxedTimeoutCount() int64 {
return tnh.relaxedTimeoutCount.Load()
}
// GetUnrelaxedTimeoutCount returns the count of unrelaxed timeout events
func (tnh *TrackingNotificationsHook) GetUnrelaxedTimeoutCount() int64 {
return tnh.unrelaxedTimeoutCount.Load()
}
// GetNotificationProcessingErrors returns the count of notification processing errors
func (tnh *TrackingNotificationsHook) GetNotificationProcessingErrors() int64 {
return tnh.notificationProcessingErrors.Load()
}
// GetTotalNotifications returns the total number of notifications processed
func (tnh *TrackingNotificationsHook) GetTotalNotifications() int64 {
return tnh.totalNotifications.Load()
}
// GetConnectionCount returns the current connection count
func (tnh *TrackingNotificationsHook) GetConnectionCount() int64 {
return tnh.connectionCount.Load()
}
// GetMovingCount returns the count of MOVING notifications
func (tnh *TrackingNotificationsHook) GetMovingCount() int64 {
return tnh.movingCount.Load()
}
// GetDiagnosticsLog returns a copy of the diagnostics log
func (tnh *TrackingNotificationsHook) GetDiagnosticsLog() []DiagnosticsEvent {
tnh.mutex.RLock()
defer tnh.mutex.RUnlock()
logCopy := make([]DiagnosticsEvent, len(tnh.diagnosticsLog))
copy(logCopy, tnh.diagnosticsLog)
return logCopy
}
func (tnh *TrackingNotificationsHook) increaseNotificationCount(notificationType string) {
tnh.totalNotifications.Add(1)
switch notificationType {
case "MOVING":
tnh.movingCount.Add(1)
case "MIGRATING":
tnh.migratingCount.Add(1)
case "MIGRATED":
tnh.migratedCount.Add(1)
case "FAILING_OVER":
tnh.failingOverCount.Add(1)
case "FAILED_OVER":
tnh.failedOverCount.Add(1)
default:
tnh.unexpectedNotificationCount.Add(1)
}
}
func (tnh *TrackingNotificationsHook) increaseRelaxedTimeoutCount(notificationType string) {
switch notificationType {
case "MIGRATING", "FAILING_OVER":
tnh.relaxedTimeoutCount.Add(1)
case "MIGRATED", "FAILED_OVER":
tnh.unrelaxedTimeoutCount.Add(1)
}
}
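// Aside (not part of this diff): a minimal sketch of exercising the tracking hook in isolation.
// The payload shape ["MIGRATING", seqID] and the zero-value handler context are assumptions for
// illustration; in the tests above, notifications arrive through the push handler instead.
func ExampleTrackingNotificationsHook() {
	hook := NewTrackingNotificationsHook()
	// Assumed payload shape: notification type followed by a sequence ID.
	notification := []interface{}{"MIGRATING", int64(1)}
	handlerCtx := push.NotificationHandlerContext{} // no real *pool.Conn, so the conn ID resolves to 0
	hook.PreHook(context.Background(), handlerCtx, "MIGRATING", notification)
	hook.PostHook(context.Background(), handlerCtx, "MIGRATING", notification, nil)
	fmt.Println(hook.GetTotalNotifications(), hook.GetRelaxedTimeoutCount())
	// Output: 1 1
}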
// setupNotificationHook sets up tracking for both regular and cluster clients with notification hooks
func setupNotificationHook(client redis.UniversalClient, hook maintnotifications.NotificationHook) {
if clusterClient, ok := client.(*redis.ClusterClient); ok {
setupClusterClientNotificationHook(clusterClient, hook)
} else if regularClient, ok := client.(*redis.Client); ok {
setupRegularClientNotificationHook(regularClient, hook)
}
}
// setupNotificationHooks sets up tracking for both regular and cluster clients with notification hooks
func setupNotificationHooks(client redis.UniversalClient, hooks ...maintnotifications.NotificationHook) {
for _, hook := range hooks {
setupNotificationHook(client, hook)
}
}
// setupRegularClientNotificationHook sets up notification hook for regular clients
func setupRegularClientNotificationHook(client *redis.Client, hook maintnotifications.NotificationHook) {
maintnotificationsManager := client.GetMaintNotificationsManager()
if maintnotificationsManager != nil {
maintnotificationsManager.AddNotificationHook(hook)
} else {
fmt.Printf("[TNH] Warning: Maintenance notifications manager not available for tracking\n")
}
}
// setupClusterClientNotificationHook sets up notification hook for cluster clients
func setupClusterClientNotificationHook(client *redis.ClusterClient, hook maintnotifications.NotificationHook) {
ctx := context.Background()
// Register hook on existing nodes
err := client.ForEachShard(ctx, func(ctx context.Context, nodeClient *redis.Client) error {
maintnotificationsManager := nodeClient.GetMaintNotificationsManager()
if maintnotificationsManager != nil {
maintnotificationsManager.AddNotificationHook(hook)
} else {
fmt.Printf("[TNH] Warning: Maintenance notifications manager not available for tracking on node: %s\n", nodeClient.Options().Addr)
}
return nil
})
if err != nil {
fmt.Printf("[TNH] Warning: Failed to register timeout tracking hooks on existing cluster nodes: %v\n", err)
}
// Register hook on new nodes
client.OnNewNode(func(nodeClient *redis.Client) {
maintnotificationsManager := nodeClient.GetMaintNotificationsManager()
if maintnotificationsManager != nil {
maintnotificationsManager.AddNotificationHook(hook)
} else {
fmt.Printf("[TNH] Warning: Maintenance notifications manager not available for tracking on new node: %s\n", nodeClient.Options().Addr)
}
})
}
// filterPushNotificationLogs filters the diagnostics log for push notification events
func filterPushNotificationLogs(diagnosticsLog []DiagnosticsEvent) []DiagnosticsEvent {
var pushNotificationLogs []DiagnosticsEvent
for _, log := range diagnosticsLog {
switch log.Type {
case "MOVING", "MIGRATING", "MIGRATED":
pushNotificationLogs = append(pushNotificationLogs, log)
}
}
return pushNotificationLogs
}
func (tnh *TrackingNotificationsHook) GetAnalysis() *DiagnosticsAnalysis {
return NewDiagnosticsAnalysis(tnh.GetDiagnosticsLog())
}
func (tnh *TrackingNotificationsHook) GetDiagnosticsLogForConn(connID uint64) []DiagnosticsEvent {
tnh.mutex.RLock()
defer tnh.mutex.RUnlock()
var connLogs []DiagnosticsEvent
for _, log := range tnh.diagnosticsLog {
if log.ConnID == connID {
connLogs = append(connLogs, log)
}
}
return connLogs
}
func (tnh *TrackingNotificationsHook) GetAnalysisForConn(connID uint64) *DiagnosticsAnalysis {
return NewDiagnosticsAnalysis(tnh.GetDiagnosticsLogForConn(connID))
}
type DiagnosticsAnalysis struct {
RelaxedTimeoutCount int64
UnrelaxedTimeoutCount int64
NotificationProcessingErrors int64
ConnectionCount int64
MovingCount int64
MigratingCount int64
MigratedCount int64
FailingOverCount int64
FailedOverCount int64
UnexpectedNotificationCount int64
TotalNotifications int64
diagnosticsLog []DiagnosticsEvent
connLogs map[uint64][]DiagnosticsEvent
connIds map[uint64]bool
}
func NewDiagnosticsAnalysis(diagnosticsLog []DiagnosticsEvent) *DiagnosticsAnalysis {
da := &DiagnosticsAnalysis{
diagnosticsLog: diagnosticsLog,
connLogs: make(map[uint64][]DiagnosticsEvent),
connIds: make(map[uint64]bool),
}
da.Analyze()
return da
}
func (da *DiagnosticsAnalysis) Analyze() {
for _, log := range da.diagnosticsLog {
da.TotalNotifications++
switch log.Type {
case "MOVING":
da.MovingCount++
case "MIGRATING":
da.MigratingCount++
case "MIGRATED":
da.MigratedCount++
case "FAILING_OVER":
da.FailingOverCount++
case "FAILED_OVER":
da.FailedOverCount++
default:
da.UnexpectedNotificationCount++
}
if log.Error != nil {
fmt.Printf("[ERROR] Notification processing error: %v\n", log.Error)
fmt.Printf("[ERROR] Notification: %v\n", log.Details["notification"])
fmt.Printf("[ERROR] Context: %v\n", log.Details["context"])
da.NotificationProcessingErrors++
}
if log.Type == "MIGRATING" || log.Type == "FAILING_OVER" {
da.RelaxedTimeoutCount++
} else if log.Type == "MIGRATED" || log.Type == "FAILED_OVER" {
da.UnrelaxedTimeoutCount++
}
if log.ConnID != 0 {
if v, ok := da.connIds[log.ConnID]; !ok || !v {
da.connIds[log.ConnID] = true
da.connLogs[log.ConnID] = make([]DiagnosticsEvent, 0)
da.ConnectionCount++
}
da.connLogs[log.ConnID] = append(da.connLogs[log.ConnID], log)
}
}
}
func (a *DiagnosticsAnalysis) Print(t *testing.T) {
t.Logf("Notification Analysis results for %d events and %d connections:", len(a.diagnosticsLog), len(a.connIds))
t.Logf("-------------")
t.Logf("-Timeout Analysis based on type of notification-")
t.Logf("Note: MIGRATED and FAILED_OVER notifications are not tracked by the hook, so they are not included in the relaxed/unrelaxed count")
t.Logf("Note: The hook only tracks timeouts that occur after the notification is processed, so timeouts that occur during processing are not included")
t.Logf("-------------")
t.Logf(" - Relaxed Timeout Count: %d", a.RelaxedTimeoutCount)
t.Logf(" - Unrelaxed Timeout Count: %d", a.UnrelaxedTimeoutCount)
t.Logf("-------------")
t.Logf("-Notification Analysis-")
t.Logf("-------------")
t.Logf(" - MOVING: %d", a.MovingCount)
t.Logf(" - MIGRATING: %d", a.MigratingCount)
t.Logf(" - MIGRATED: %d", a.MigratedCount)
t.Logf(" - FAILING_OVER: %d", a.FailingOverCount)
t.Logf(" - FAILED_OVER: %d", a.FailedOverCount)
t.Logf(" - Unexpected: %d", a.UnexpectedNotificationCount)
t.Logf("-------------")
t.Logf(" - Total Notifications: %d", a.TotalNotifications)
t.Logf(" - Notification Processing Errors: %d", a.NotificationProcessingErrors)
t.Logf(" - Connection Count: %d", a.ConnectionCount)
t.Logf("-------------")
t.Logf("Diagnostics Analysis completed successfully")
}
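// Aside (not part of this diff): a minimal sketch of how DiagnosticsAnalysis derives its counts
// from a few hand-written events; the event values below are made up for illustration.
func TestDiagnosticsAnalysisSketch(t *testing.T) {
	events := []DiagnosticsEvent{
		{Type: "MIGRATING", ConnID: 1, SeqID: 1, Timestamp: time.Now()},
		{Type: "MIGRATED", ConnID: 1, SeqID: 2, Timestamp: time.Now()},
		{Type: "MOVING", ConnID: 2, SeqID: 3, Timestamp: time.Now()},
	}
	a := NewDiagnosticsAnalysis(events)
	// MIGRATING implies a relaxed timeout, MIGRATED an unrelaxed one; two distinct connection IDs.
	if a.TotalNotifications != 3 || a.ConnectionCount != 2 ||
		a.RelaxedTimeoutCount != 1 || a.UnrelaxedTimeoutCount != 1 {
		t.Fatalf("unexpected analysis: %+v", a)
	}
}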

View File

@@ -0,0 +1,377 @@
package e2e
import (
"context"
"fmt"
"net"
"os"
"strings"
"testing"
"time"
"github.com/redis/go-redis/v9/internal"
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
// TestEndpointTypesPushNotifications tests push notifications with different endpoint types
func TestEndpointTypesPushNotifications(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute)
defer cancel()
var dump = true
var errorsDetected = false
var p = func(format string, args ...interface{}) {
format = "[%s][ENDPOINT-TYPES] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Logf(format, args...)
}
// Test different endpoint types
endpointTypes := []struct {
name string
endpointType maintnotifications.EndpointType
description string
}{
{
name: "ExternalIP",
endpointType: maintnotifications.EndpointTypeExternalIP,
description: "External IP endpoint type for enterprise clusters",
},
{
name: "ExternalFQDN",
endpointType: maintnotifications.EndpointTypeExternalFQDN,
description: "External FQDN endpoint type for DNS-based routing",
},
{
name: "None",
endpointType: maintnotifications.EndpointTypeNone,
description: "No endpoint type - reconnect with current config",
},
}
defer func() {
logCollector.Clear()
}()
// Create client factory from configuration
factory, err := CreateTestClientFactory("standalone")
if err != nil {
t.Skipf("Enterprise cluster not available, skipping endpoint types test: %v", err)
}
endpointConfig := factory.GetConfig()
// Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
t.Fatalf("Failed to create fault injector: %v", err)
}
defer func() {
if dump {
p("Pool stats:")
factory.PrintPoolStats(t)
}
factory.DestroyAll()
}()
// Test each endpoint type
for _, endpointTest := range endpointTypes {
t.Run(endpointTest.name, func(t *testing.T) {
// Clear logs and reset flags between endpoint type tests
logCollector.Clear()
dump = true // reset dump flag
errorsDetected = false // reset error flag from any previous subtest
// redefine p and e for each test to get
// proper test name in logs and proper test failures
var p = func(format string, args ...interface{}) {
format = "[%s][ENDPOINT-TYPES] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Logf(format, args...)
}
var e = func(format string, args ...interface{}) {
errorsDetected = true
format = "[%s][ENDPOINT-TYPES][ERROR] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Errorf(format, args...)
}
p("Testing endpoint type: %s - %s", endpointTest.name, endpointTest.description)
minIdleConns := 3
poolSize := 8
maxConnections := 12
// Create Redis client with specific endpoint type
client, err := factory.Create(fmt.Sprintf("endpoint-test-%s", endpointTest.name), &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 30 * time.Second,
RelaxedTimeout: 8 * time.Second,
PostHandoffRelaxedDuration: 2 * time.Second,
MaxWorkers: 15,
EndpointType: endpointTest.endpointType, // Test specific endpoint type
},
ClientName: fmt.Sprintf("endpoint-test-%s", endpointTest.name),
})
if err != nil {
t.Fatalf("Failed to create client for %s: %v", endpointTest.name, err)
}
// Create timeout tracker
tracker := NewTrackingNotificationsHook()
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
setupNotificationHooks(client, tracker, logger)
defer func() {
if dump {
p("Tracker analysis for %s:", endpointTest.name)
tracker.GetAnalysis().Print(t)
}
tracker.Clear()
}()
// Verify initial connectivity
err = client.Ping(ctx).Err()
if err != nil {
t.Fatalf("Failed to ping Redis with %s endpoint type: %v", endpointTest.name, err)
}
p("Client connected successfully with %s endpoint type", endpointTest.name)
commandsRunner, _ := NewCommandRunner(client)
defer func() {
if dump {
stats := commandsRunner.GetStats()
p("%s endpoint stats: Operations: %d, Errors: %d, Timeout Errors: %d",
endpointTest.name, stats.Operations, stats.Errors, stats.TimeoutErrors)
}
commandsRunner.Stop()
}()
// Test failover with this endpoint type
p("Testing failover with %s endpoint type...", endpointTest.name)
failoverResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "failover",
Parameters: map[string]interface{}{
"cluster_index": "0",
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
t.Fatalf("Failed to trigger failover action for %s: %v", endpointTest.name, err)
}
// Start command traffic
go func() {
commandsRunner.FireCommandsUntilStop(ctx)
}()
// Wait for FAILING_OVER notification
match, found := logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "FAILING_OVER")
}, 2*time.Minute)
if !found {
t.Fatalf("FAILING_OVER notification was not received for %s endpoint type", endpointTest.name)
}
failingOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILING_OVER notification received for %s. %v", endpointTest.name, failingOverData)
// Wait for FAILED_OVER notification
seqIDToObserve := int64(failingOverData["seqID"].(float64))
connIDToObserve := uint64(failingOverData["connID"].(float64))
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return notificationType(s, "FAILED_OVER") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
}, 2*time.Minute)
if !found {
t.Fatalf("FAILED_OVER notification was not received for %s endpoint type", endpointTest.name)
}
failedOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILED_OVER notification received for %s. %v", endpointTest.name, failedOverData)
// Wait for failover to complete
status, err := faultInjector.WaitForAction(ctx, failoverResp.ActionID,
WithMaxWaitTime(120*time.Second),
WithPollInterval(1*time.Second),
)
if err != nil {
t.Fatalf("[FI] Failover action failed for %s: %v", endpointTest.name, err)
}
p("[FI] Failover action completed for %s: %s", endpointTest.name, status.Status)
// Test migration with this endpoint type
p("Testing migration with %s endpoint type...", endpointTest.name)
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "migrate",
Parameters: map[string]interface{}{
"cluster_index": "0",
},
})
if err != nil {
t.Fatalf("Failed to trigger migrate action for %s: %v", endpointTest.name, err)
}
// Wait for MIGRATING notification
match, found = logCollector.WaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
}, 30*time.Second)
if !found {
t.Fatalf("MIGRATING notification was not received for %s endpoint type", endpointTest.name)
}
migrateData := logs2.ExtractDataFromLogMessage(match)
p("MIGRATING notification received for %s: %v", endpointTest.name, migrateData)
// Wait for migration to complete
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
WithMaxWaitTime(120*time.Second),
WithPollInterval(1*time.Second),
)
if err != nil {
t.Fatalf("[FI] Migrate action failed for %s: %v", endpointTest.name, err)
}
p("[FI] Migrate action completed for %s: %s", endpointTest.name, status.Status)
// Wait for MIGRATED notification
seqIDToObserve = int64(migrateData["seqID"].(float64))
connIDToObserve = uint64(migrateData["connID"].(float64))
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return notificationType(s, "MIGRATED") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
}, 2*time.Minute)
if !found {
t.Fatalf("MIGRATED notification was not received for %s endpoint type", endpointTest.name)
}
migratedData := logs2.ExtractDataFromLogMessage(match)
p("MIGRATED notification received for %s. %v", endpointTest.name, migratedData)
// Complete migration with bind action
bindResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "bind",
Parameters: map[string]interface{}{
"cluster_index": "0",
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
t.Fatalf("Failed to trigger bind action for %s: %v", endpointTest.name, err)
}
// Wait for MOVING notification
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "MOVING")
}, 2*time.Minute)
if !found {
t.Fatalf("MOVING notification was not received for %s endpoint type", endpointTest.name)
}
movingData := logs2.ExtractDataFromLogMessage(match)
p("MOVING notification received for %s. %v", endpointTest.name, movingData)
notification, ok := movingData["notification"].(string)
if !ok {
e("invalid notification message")
}
notification = notification[:len(notification)-1]
notificationParts := strings.Split(notification, " ")
address := notificationParts[len(notificationParts)-1]
switch endpointTest.endpointType {
case maintnotifications.EndpointTypeExternalFQDN:
address = strings.Split(address, ":")[0]
addressParts := strings.SplitN(address, ".", 2)
if len(addressParts) != 2 {
e("invalid address %s", address)
} else {
address = addressParts[1]
}
var expectedAddress string
hostParts := strings.SplitN(endpointConfig.Host, ".", 2)
if len(hostParts) != 2 {
e("invalid host %s", endpointConfig.Host)
} else {
expectedAddress = hostParts[1]
}
if address != expectedAddress {
e("invalid fqdn, expected: %s, got: %s", expectedAddress, address)
}
case maintnotifications.EndpointTypeExternalIP:
address = strings.Split(address, ":")[0]
ip := net.ParseIP(address)
if ip == nil {
e("invalid message format, expected valid IP, got: %s", address)
}
case maintnotifications.EndpointTypeNone:
if address != internal.RedisNull {
e("invalid endpoint type, expected: %s, got: %s", internal.RedisNull, address)
}
}
// Wait for bind to complete
bindStatus, err := faultInjector.WaitForAction(ctx, bindResp.ActionID,
WithMaxWaitTime(120*time.Second),
WithPollInterval(2*time.Second))
if err != nil {
t.Fatalf("Bind action failed for %s: %v", endpointTest.name, err)
}
p("Bind action completed for %s: %s", endpointTest.name, bindStatus.Status)
// Continue traffic for analysis
time.Sleep(30 * time.Second)
commandsRunner.Stop()
// Analyze results for this endpoint type
trackerAnalysis := tracker.GetAnalysis()
if trackerAnalysis.NotificationProcessingErrors > 0 {
e("Notification processing errors with %s endpoint type: %d", endpointTest.name, trackerAnalysis.NotificationProcessingErrors)
}
if trackerAnalysis.UnexpectedNotificationCount > 0 {
e("Unexpected notifications with %s endpoint type: %d", endpointTest.name, trackerAnalysis.UnexpectedNotificationCount)
}
// Validate we received all expected notification types
if trackerAnalysis.FailingOverCount == 0 {
e("Expected FAILING_OVER notifications with %s endpoint type, got none", endpointTest.name)
}
if trackerAnalysis.FailedOverCount == 0 {
e("Expected FAILED_OVER notifications with %s endpoint type, got none", endpointTest.name)
}
if trackerAnalysis.MigratingCount == 0 {
e("Expected MIGRATING notifications with %s endpoint type, got none", endpointTest.name)
}
if trackerAnalysis.MigratedCount == 0 {
e("Expected MIGRATED notifications with %s endpoint type, got none", endpointTest.name)
}
if trackerAnalysis.MovingCount == 0 {
e("Expected MOVING notifications with %s endpoint type, got none", endpointTest.name)
}
if errorsDetected {
logCollector.DumpLogs()
trackerAnalysis.Print(t)
logCollector.Clear()
tracker.Clear()
t.Fatalf("[FAIL] Errors detected with %s endpoint type", endpointTest.name)
}
dump = false
p("Endpoint type %s test completed successfully", endpointTest.name)
logCollector.GetAnalysis().Print(t)
trackerAnalysis.Print(t)
logCollector.Clear()
tracker.Clear()
})
}
p("All endpoint types tested successfully")
}
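// Aside (not part of this diff): the MOVING endpoint check above, sketched in isolation. It assumes,
// as the test body does, that the logged notification ends with the new endpoint followed by a
// closing bracket; strings.Fields is used here instead of a single-space split purely for robustness.
func movingEndpoint(raw string) string {
	raw = strings.TrimSuffix(raw, "]")
	parts := strings.Fields(raw)
	if len(parts) == 0 {
		return ""
	}
	return parts[len(parts)-1]
}
// Mirroring the switch above: with EndpointTypeExternalIP the host part of the result should parse
// with net.ParseIP, with EndpointTypeExternalFQDN it should share the cluster's DNS suffix, and with
// EndpointTypeNone it should be the literal null sentinel.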

View File

@@ -0,0 +1,473 @@
package e2e
import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
// TestPushNotifications tests Redis Enterprise push notifications (MOVING, MIGRATING, MIGRATED)
func TestPushNotifications(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
var dump = true
var seqIDToObserve int64
var connIDToObserve uint64
var match string
var found bool
var status *ActionStatusResponse
var p = func(format string, args ...interface{}) {
format = "[%s] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Logf(format, args...)
}
var errorsDetected = false
var e = func(format string, args ...interface{}) {
errorsDetected = true
format = "[%s][ERROR] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Errorf(format, args...)
}
logCollector.ClearLogs()
defer func() {
if dump {
p("Dumping logs...")
logCollector.DumpLogs()
p("Log Analysis:")
logCollector.GetAnalysis().Print(t)
}
logCollector.Clear()
}()
// Create client factory from configuration
factory, err := CreateTestClientFactory("standalone")
if err != nil {
t.Skipf("Enterprise cluster not available, skipping push notification tests: %v", err)
}
endpointConfig := factory.GetConfig()
// Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
t.Fatalf("Failed to create fault injector: %v", err)
}
minIdleConns := 5
poolSize := 10
maxConnections := 15
// Create Redis client with push notifications enabled
client, err := factory.Create("push-notification-client", &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 40 * time.Second, // 40 seconds handoff timeout
RelaxedTimeout: 10 * time.Second, // 10 seconds relaxed timeout
PostHandoffRelaxedDuration: 2 * time.Second, // 2 seconds post-handoff relaxed duration
MaxWorkers: 20,
EndpointType: maintnotifications.EndpointTypeExternalIP, // Use external IP for enterprise
},
ClientName: "push-notification-test-client",
})
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
defer func() {
if dump {
p("Pool stats:")
factory.PrintPoolStats(t)
}
factory.DestroyAll()
}()
// Create timeout tracker
tracker := NewTrackingNotificationsHook()
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
setupNotificationHooks(client, tracker, logger)
defer func() {
if dump {
tracker.GetAnalysis().Print(t)
}
tracker.Clear()
}()
// Verify initial connectivity
err = client.Ping(ctx).Err()
if err != nil {
t.Fatalf("Failed to ping Redis: %v", err)
}
p("Client connected successfully, starting push notification test")
commandsRunner, _ := NewCommandRunner(client)
defer func() {
if dump {
p("Command runner stats:")
p("Operations: %d, Errors: %d, Timeout Errors: %d",
commandsRunner.GetStats().Operations, commandsRunner.GetStats().Errors, commandsRunner.GetStats().TimeoutErrors)
}
p("Stopping command runner...")
commandsRunner.Stop()
}()
p("Starting FAILING_OVER / FAILED_OVER notifications test...")
// Test: Trigger failover action to generate FAILING_OVER, FAILED_OVER notifications
p("Triggering failover action to generate push notifications...")
failoverResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "failover",
Parameters: map[string]interface{}{
"cluster_index": "0",
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
t.Fatalf("Failed to trigger failover action: %v", err)
}
go func() {
p("Waiting for FAILING_OVER notification")
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "FAILING_OVER")
}, 2*time.Minute)
commandsRunner.Stop()
}()
commandsRunner.FireCommandsUntilStop(ctx)
if !found {
t.Fatal("FAILING_OVER notification was not received within 2 minutes")
}
failingOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILING_OVER notification received. %v", failingOverData)
seqIDToObserve = int64(failingOverData["seqID"].(float64))
connIDToObserve = uint64(failingOverData["connID"].(float64))
go func() {
p("Waiting for FAILED_OVER notification on conn %d with seqID %d...", connIDToObserve, seqIDToObserve+1)
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return notificationType(s, "FAILED_OVER") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
}, 2*time.Minute)
commandsRunner.Stop()
}()
commandsRunner.FireCommandsUntilStop(ctx)
if !found {
t.Fatal("FAILED_OVER notification was not received within 2 minutes")
}
failedOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILED_OVER notification received. %v", failedOverData)
status, err = faultInjector.WaitForAction(ctx, failoverResp.ActionID,
WithMaxWaitTime(120*time.Second),
WithPollInterval(1*time.Second),
)
if err != nil {
t.Fatalf("[FI] Failover action failed: %v", err)
}
fmt.Printf("[FI] Failover action completed: %s\n", status.Status)
p("FAILING_OVER / FAILED_OVER notifications test completed successfully")
// Test: Trigger migrate action to generate MOVING, MIGRATING, MIGRATED notifications
p("Triggering migrate action to generate push notifications...")
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "migrate",
Parameters: map[string]interface{}{
"cluster_index": "0",
},
})
if err != nil {
t.Fatalf("Failed to trigger migrate action: %v", err)
}
go func() {
match, found = logCollector.WaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
}, 20*time.Second)
commandsRunner.Stop()
}()
commandsRunner.FireCommandsUntilStop(ctx)
if !found {
t.Fatal("MIGRATING notification for migrate action was not received within 20 seconds")
}
migrateData := logs2.ExtractDataFromLogMessage(match)
seqIDToObserve = int64(migrateData["seqID"].(float64))
connIDToObserve = uint64(migrateData["connID"].(float64))
p("MIGRATING notification received: seqID: %d, connID: %d", seqIDToObserve, connIDToObserve)
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
WithMaxWaitTime(120*time.Second),
WithPollInterval(1*time.Second),
)
if err != nil {
t.Fatalf("[FI] Migrate action failed: %v", err)
}
fmt.Printf("[FI] Migrate action completed: %s\n", status.Status)
go func() {
p("Waiting for MIGRATED notification on conn %d with seqID %d...", connIDToObserve, seqIDToObserve+1)
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return notificationType(s, "MIGRATED") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
}, 2*time.Minute)
commandsRunner.Stop()
}()
commandsRunner.FireCommandsUntilStop(ctx)
if !found {
t.Fatal("MIGRATED notification was not received within 2 minutes")
}
migratedData := logs2.ExtractDataFromLogMessage(match)
p("MIGRATED notification received. %v", migratedData)
p("MIGRATING / MIGRATED notifications test completed successfully")
// Trigger bind action to complete the migration process
p("Triggering bind action to complete migration...")
bindResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "bind",
Parameters: map[string]interface{}{
"cluster_index": "0",
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
t.Fatalf("Failed to trigger bind action: %v", err)
}
// start a second client but don't execute any commands on it
p("Starting a second client to observe notification during moving...")
client2, err := factory.Create("push-notification-client-2", &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 40 * time.Second, // 40 seconds handoff timeout
RelaxedTimeout: 30 * time.Minute, // 30 minutes relaxed timeout for second client
PostHandoffRelaxedDuration: 2 * time.Second, // 2 seconds post-handoff relaxed duration
MaxWorkers: 20,
EndpointType: maintnotifications.EndpointTypeExternalIP, // Use external IP for enterprise
},
ClientName: "push-notification-test-client-2",
})
if err != nil {
t.Fatalf("failed to create client: %v", err)
}
// setup tracking for second client
tracker2 := NewTrackingNotificationsHook()
logger2 := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
setupNotificationHooks(client2, tracker2, logger2)
commandsRunner2, _ := NewCommandRunner(client2)
t.Log("Second client created")
// Use a channel to communicate errors from the goroutine
errChan := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("goroutine panic: %v", r)
}
}()
p("Waiting for MOVING notification on second client")
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "MOVING")
}, 2*time.Minute)
commandsRunner.Stop()
// once moving is received, start a second client commands runner
p("Starting commands on second client")
go commandsRunner2.FireCommandsUntilStop(ctx)
defer func() {
// stop the second runner
commandsRunner2.Stop()
// destroy the second client
factory.Destroy("push-notification-client-2")
}()
// wait for moving on second client
// the first client is capped at MaxActiveConns (15) connections and connIDs 16/17 are assumed to be
// taken while initializing the second client, so connID 18 should belong to the second client
// also validate that the much larger relaxed timeout (30m) is applied
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "MOVING") && connID(s, 18)
}, 2*time.Minute)
if !found {
errChan <- fmt.Errorf("MOVING notification was not received within 2 minutes ON A SECOND CLIENT")
return
} else {
p("MOVING notification received on second client %v", logs2.ExtractDataFromLogMessage(match))
}
// wait for relaxation of 30m
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ApplyingRelaxedTimeoutDueToPostHandoffMessage) && strings.Contains(s, "30m")
}, 2*time.Minute)
if !found {
errChan <- fmt.Errorf("relaxed timeout was not applied within 2 minutes ON A SECOND CLIENT")
return
} else {
p("Relaxed timeout applied on second client")
}
// Signal success
errChan <- nil
}()
commandsRunner.FireCommandsUntilStop(ctx)
movingData := logs2.ExtractDataFromLogMessage(match)
p("MOVING notification received. %v", movingData)
seqIDToObserve = int64(movingData["seqID"].(float64))
connIDToObserve = uint64(movingData["connID"].(float64))
// Wait for the goroutine to complete and check for errors
if err := <-errChan; err != nil {
t.Fatalf("Second client goroutine error: %v", err)
}
// Wait for bind action to complete
bindStatus, err := faultInjector.WaitForAction(ctx, bindResp.ActionID,
WithMaxWaitTime(120*time.Second),
WithPollInterval(2*time.Second))
if err != nil {
t.Fatalf("Bind action failed: %v", err)
}
p("Bind action completed: %s", bindStatus.Status)
p("MOVING notification test completed successfully")
p("Executing commands and collecting logs for analysis... This will take 30 seconds...")
go commandsRunner.FireCommandsUntilStop(ctx)
time.Sleep(30 * time.Second)
commandsRunner.Stop()
allLogsAnalysis := logCollector.GetAnalysis()
trackerAnalysis := tracker.GetAnalysis()
if allLogsAnalysis.TimeoutErrorsCount > 0 {
e("Unexpected timeout errors: %d", allLogsAnalysis.TimeoutErrorsCount)
}
if trackerAnalysis.UnexpectedNotificationCount > 0 {
e("Unexpected notifications: %d", trackerAnalysis.UnexpectedNotificationCount)
}
if trackerAnalysis.NotificationProcessingErrors > 0 {
e("Notification processing errors: %d", trackerAnalysis.NotificationProcessingErrors)
}
if allLogsAnalysis.RelaxedTimeoutCount == 0 {
e("Expected relaxed timeouts, got none")
}
if allLogsAnalysis.UnrelaxedTimeoutCount == 0 {
e("Expected unrelaxed timeouts, got none")
}
if allLogsAnalysis.UnrelaxedAfterMoving == 0 {
e("Expected unrelaxed timeouts after moving, got none")
}
if allLogsAnalysis.RelaxedPostHandoffCount == 0 {
e("Expected relaxed timeouts after post-handoff, got none")
}
// validate number of connections we do not exceed max connections
// we started a second client, so we expect 2x the connections
if allLogsAnalysis.ConnectionCount > int64(maxConnections)*2 {
e("Expected no more than %d connections, got %d", maxConnections, allLogsAnalysis.ConnectionCount)
}
if allLogsAnalysis.ConnectionCount < int64(minIdleConns) {
e("Expected at least %d connections, got %d", minIdleConns, allLogsAnalysis.ConnectionCount)
}
// validate logs are present for all connections
for connID := range trackerAnalysis.connIds {
if len(allLogsAnalysis.connLogs[connID]) == 0 {
e("No logs found for connection %d", connID)
}
}
// validate number of notifications in tracker matches number of notifications in logs
// allow for more moving in the logs since we started a second client
if trackerAnalysis.TotalNotifications > allLogsAnalysis.TotalNotifications {
e("Expected %d or more notifications, got %d", trackerAnalysis.TotalNotifications, allLogsAnalysis.TotalNotifications)
}
// and per type
// allow for more moving in the logs since we started a second client
if trackerAnalysis.MovingCount > allLogsAnalysis.MovingCount {
e("Expected %d or more MOVING notifications, got %d", trackerAnalysis.MovingCount, allLogsAnalysis.MovingCount)
}
if trackerAnalysis.MigratingCount != allLogsAnalysis.MigratingCount {
e("Expected %d MIGRATING notifications, got %d", trackerAnalysis.MigratingCount, allLogsAnalysis.MigratingCount)
}
if trackerAnalysis.MigratedCount != allLogsAnalysis.MigratedCount {
e("Expected %d MIGRATED notifications, got %d", trackerAnalysis.MigratedCount, allLogsAnalysis.MigratedCount)
}
if trackerAnalysis.FailingOverCount != allLogsAnalysis.FailingOverCount {
e("Expected %d FAILING_OVER notifications, got %d", trackerAnalysis.FailingOverCount, allLogsAnalysis.FailingOverCount)
}
if trackerAnalysis.FailedOverCount != allLogsAnalysis.FailedOverCount {
e("Expected %d FAILED_OVER notifications, got %d", trackerAnalysis.FailedOverCount, allLogsAnalysis.FailedOverCount)
}
if trackerAnalysis.UnexpectedNotificationCount != allLogsAnalysis.UnexpectedCount {
e("Expected %d unexpected notifications, got %d", trackerAnalysis.UnexpectedNotificationCount, allLogsAnalysis.UnexpectedCount)
}
// unrelaxed (and relaxed) timeouts applied after MOVING are not tracked by the hook, so exclude them from the comparison
if trackerAnalysis.UnrelaxedTimeoutCount != allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving {
e("Expected %d unrelaxed timeouts, got %d", trackerAnalysis.UnrelaxedTimeoutCount, allLogsAnalysis.UnrelaxedTimeoutCount)
}
if trackerAnalysis.RelaxedTimeoutCount != allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount {
e("Expected %d relaxed timeouts, got %d", trackerAnalysis.RelaxedTimeoutCount, allLogsAnalysis.RelaxedTimeoutCount)
}
// validate all handoffs succeeded
if allLogsAnalysis.FailedHandoffCount > 0 {
e("Expected no failed handoffs, got %d", allLogsAnalysis.FailedHandoffCount)
}
if allLogsAnalysis.SucceededHandoffCount == 0 {
e("Expected at least one successful handoff, got none")
}
if allLogsAnalysis.TotalHandoffCount != allLogsAnalysis.SucceededHandoffCount {
e("Expected total handoffs to match successful handoffs, got %d != %d", allLogsAnalysis.TotalHandoffCount, allLogsAnalysis.SucceededHandoffCount)
}
// no additional retries
if allLogsAnalysis.TotalHandoffRetries != allLogsAnalysis.TotalHandoffCount {
e("Expected no additional handoff retries, got %d", allLogsAnalysis.TotalHandoffRetries-allLogsAnalysis.TotalHandoffCount)
}
if errorsDetected {
logCollector.DumpLogs()
trackerAnalysis.Print(t)
logCollector.Clear()
tracker.Clear()
t.Fatalf("[FAIL] Errors detected in push notification test")
}
p("Analysis complete, no errors found")
// print analysis here, don't dump logs later
dump = false
allLogsAnalysis.Print(t)
trackerAnalysis.Print(t)
p("Command runner stats:")
p("Operations: %d, Errors: %d, Timeout Errors: %d",
commandsRunner.GetStats().Operations, commandsRunner.GetStats().Errors, commandsRunner.GetStats().TimeoutErrors)
p("Push notification test completed successfully")
}
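// Aside (not part of this diff): one pattern this test repeats is waiting for the follow-up
// notification on the same connection with the next sequence ID. A sketch of that matcher,
// factored out over the notificationType/connID/seqID log helpers already used above:
func followUpMatcher(kind string, prevConnID uint64, prevSeqID int64) func(string) bool {
	return func(s string) bool {
		return notificationType(s, kind) && connID(s, prevConnID) && seqID(s, prevSeqID+1)
	}
}
// e.g. logCollector.MatchOrWaitForLogMatchFunc(followUpMatcher("FAILED_OVER", connIDToObserve, seqIDToObserve), 2*time.Minute)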

View File

@@ -0,0 +1,303 @@
package e2e
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
// TestStressPushNotifications tests push notifications under extreme stress conditions
func TestStressPushNotifications(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
var dump = true
var p = func(format string, args ...interface{}) {
format = "[%s][STRESS] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Logf(format, args...)
}
var e = func(format string, args ...interface{}) {
format = "[%s][STRESS][ERROR] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Errorf(format, args...)
}
logCollector.ClearLogs()
defer func() {
if dump {
p("Dumping logs...")
logCollector.DumpLogs()
p("Log Analysis:")
logCollector.GetAnalysis().Print(t)
}
logCollector.Clear()
}()
// Create client factory from configuration
factory, err := CreateTestClientFactory("standalone")
if err != nil {
t.Skipf("Enterprise cluster not available, skipping stress test: %v", err)
}
endpointConfig := factory.GetConfig()
// Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
t.Fatalf("Failed to create fault injector: %v", err)
}
// Extreme stress configuration
minIdleConns := 50
poolSize := 150
maxConnections := 200
numClients := 4
var clients []redis.UniversalClient
var trackers []*TrackingNotificationsHook
var commandRunners []*CommandRunner
// Create multiple clients for extreme stress
for i := 0; i < numClients; i++ {
client, err := factory.Create(fmt.Sprintf("stress-client-%d", i), &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 60 * time.Second, // Longer timeout for stress
RelaxedTimeout: 20 * time.Second, // Longer relaxed timeout
PostHandoffRelaxedDuration: 5 * time.Second, // Longer post-handoff duration
MaxWorkers: 50, // Maximum workers for stress
HandoffQueueSize: 1000, // Large queue for stress
EndpointType: maintnotifications.EndpointTypeExternalIP,
},
ClientName: fmt.Sprintf("stress-test-client-%d", i),
})
if err != nil {
t.Fatalf("Failed to create stress client %d: %v", i, err)
}
clients = append(clients, client)
// Setup tracking for each client
tracker := NewTrackingNotificationsHook()
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelWarn)) // Minimal logging for stress
setupNotificationHooks(client, tracker, logger)
trackers = append(trackers, tracker)
// Create command runner for each client
commandRunner, _ := NewCommandRunner(client)
commandRunners = append(commandRunners, commandRunner)
}
defer func() {
if dump {
p("Pool stats:")
factory.PrintPoolStats(t)
for i, tracker := range trackers {
p("Stress client %d analysis:", i)
tracker.GetAnalysis().Print(t)
}
}
for _, runner := range commandRunners {
runner.Stop()
}
factory.DestroyAll()
}()
// Verify initial connectivity for all clients
for i, client := range clients {
err = client.Ping(ctx).Err()
if err != nil {
t.Fatalf("Failed to ping Redis with stress client %d: %v", i, err)
}
}
p("All %d stress clients connected successfully", numClients)
// Start extreme traffic load on all clients
var trafficWg sync.WaitGroup
for i, runner := range commandRunners {
trafficWg.Add(1)
go func(clientID int, r *CommandRunner) {
defer trafficWg.Done()
p("Starting extreme traffic load on stress client %d", clientID)
r.FireCommandsUntilStop(ctx)
}(i, runner)
}
// Wait for traffic to stabilize
time.Sleep(10 * time.Second)
// Trigger multiple concurrent fault injection actions
var actionWg sync.WaitGroup
var actionResults []string
var actionMutex sync.Mutex
actions := []struct {
name string
action string
delay time.Duration
}{
{"failover-1", "failover", 0},
{"migrate-1", "migrate", 5 * time.Second},
{"failover-2", "failover", 10 * time.Second},
}
p("Starting %d concurrent fault injection actions under extreme stress...", len(actions))
for _, action := range actions {
actionWg.Add(1)
go func(actionName, actionType string, delay time.Duration) {
defer actionWg.Done()
if delay > 0 {
time.Sleep(delay)
}
p("Triggering %s action under extreme stress...", actionName)
var resp *ActionResponse
var err error
switch actionType {
case "failover":
resp, err = faultInjector.TriggerAction(ctx, ActionRequest{
Type: "failover",
Parameters: map[string]interface{}{
"cluster_index": "0",
"bdb_id": endpointConfig.BdbID,
},
})
case "migrate":
resp, err = faultInjector.TriggerAction(ctx, ActionRequest{
Type: "migrate",
Parameters: map[string]interface{}{
"cluster_index": "0",
},
})
}
if err != nil {
e("Failed to trigger %s action: %v", actionName, err)
return
}
// Wait for action to complete
status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
WithMaxWaitTime(300*time.Second), // Very long wait for stress
WithPollInterval(2*time.Second),
)
if err != nil {
e("[FI] %s action failed: %v", actionName, err)
return
}
actionMutex.Lock()
actionResults = append(actionResults, fmt.Sprintf("%s: %s", actionName, status.Status))
actionMutex.Unlock()
p("[FI] %s action completed: %s", actionName, status.Status)
}(action.name, action.action, action.delay)
}
// Wait for all actions to complete
actionWg.Wait()
// Continue stress for a bit longer
p("All fault injection actions completed, continuing stress for 2 more minutes...")
time.Sleep(2 * time.Minute)
// Stop all command runners
for _, runner := range commandRunners {
runner.Stop()
}
trafficWg.Wait()
// Analyze stress test results
allLogsAnalysis := logCollector.GetAnalysis()
totalOperations := int64(0)
totalErrors := int64(0)
totalTimeoutErrors := int64(0)
for i, runner := range commandRunners {
stats := runner.GetStats()
p("Stress client %d stats: Operations: %d, Errors: %d, Timeout Errors: %d",
i, stats.Operations, stats.Errors, stats.TimeoutErrors)
totalOperations += stats.Operations
totalErrors += stats.Errors
totalTimeoutErrors += stats.TimeoutErrors
}
p("STRESS TEST RESULTS:")
p("Total operations across all clients: %d", totalOperations)
p("Total errors: %d (%.2f%%)", totalErrors, float64(totalErrors)/float64(totalOperations)*100)
p("Total timeout errors: %d (%.2f%%)", totalTimeoutErrors, float64(totalTimeoutErrors)/float64(totalOperations)*100)
p("Total connections used: %d", allLogsAnalysis.ConnectionCount)
// Print action results
actionMutex.Lock()
p("Fault injection action results:")
for _, result := range actionResults {
p(" %s", result)
}
actionMutex.Unlock()
// Validate stress test results
if totalOperations < 1000 {
e("Expected at least 1000 operations under stress, got %d", totalOperations)
}
// Allow higher error rates under extreme stress (up to 20%)
errorRate := float64(totalErrors) / float64(totalOperations) * 100
if errorRate > 20.0 {
e("Error rate too high under stress: %.2f%% (max allowed: 20%%)", errorRate)
}
// Validate connection limits weren't exceeded
expectedMaxConnections := int64(numClients * maxConnections)
if allLogsAnalysis.ConnectionCount > expectedMaxConnections {
e("Connection count exceeded limit: %d > %d", allLogsAnalysis.ConnectionCount, expectedMaxConnections)
}
// Validate notifications were processed
totalTrackerNotifications := int64(0)
totalProcessingErrors := int64(0)
for _, tracker := range trackers {
analysis := tracker.GetAnalysis()
totalTrackerNotifications += analysis.TotalNotifications
totalProcessingErrors += analysis.NotificationProcessingErrors
}
if totalProcessingErrors > totalTrackerNotifications/10 { // Allow up to 10% processing errors under stress
e("Too many notification processing errors under stress: %d/%d", totalProcessingErrors, totalTrackerNotifications)
}
p("Stress test completed successfully!")
p("Processed %d operations across %d clients with %d connections",
totalOperations, numClients, allLogsAnalysis.ConnectionCount)
p("Error rate: %.2f%%, Notification processing errors: %d/%d",
errorRate, totalProcessingErrors, totalTrackerNotifications)
// Print final analysis
dump = false
allLogsAnalysis.Print(t)
for i, tracker := range trackers {
p("=== Stress Client %d Analysis ===", i)
tracker.GetAnalysis().Print(t)
}
}
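// Aside (not part of this diff): the error-rate percentage above is computed before the
// minimum-operations check, so a sketch of the same computation with a zero guard:
func errorRatePercent(errors, operations int64) float64 {
	if operations == 0 {
		return 0
	}
	return float64(errors) / float64(operations) * 100
}
// With it, the stress threshold check reads: errorRatePercent(totalErrors, totalOperations) > 20.0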

View File

@@ -0,0 +1,245 @@
package e2e
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/hitless"
)
// TestScenarioTemplate is a template for writing scenario tests
// Copy this file and rename it to scenario_your_test_name.go
func TestScenarioTemplate(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
// Step 1: Create client factory from configuration
factory, err := CreateTestClientFactory("enterprise-cluster") // or "standalone0"
if err != nil {
t.Fatalf("Failed to create client factory: %v", err)
}
defer factory.DestroyAll()
// Step 2: Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
t.Fatalf("Failed to create fault injector: %v", err)
}
// Step 3: Create Redis client with maintenance notifications enabled
client, err := factory.Create("scenario-client", &CreateClientOptions{
Protocol: 3,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 30 * time.Second,
RelaxedTimeout: 10 * time.Second,
MaxWorkers: 20,
},
ClientName: "scenario-test-client",
})
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Step 4: Verify initial connectivity
err = client.Ping(ctx).Err()
if err != nil {
t.Fatalf("Failed to ping Redis: %v", err)
}
t.Log("Initial setup completed successfully")
// Step 5: Start background operations (optional)
stopCh := make(chan struct{})
defer close(stopCh)
go func() {
counter := 0
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-stopCh:
return
case <-ticker.C:
key := fmt.Sprintf("test-key-%d", counter)
value := fmt.Sprintf("test-value-%d", counter)
err := client.Set(ctx, key, value, time.Minute).Err()
if err != nil {
t.Logf("Background operation failed: %v", err)
}
counter++
}
}
}()
// Step 6: Wait for baseline operations
time.Sleep(5 * time.Second)
// Step 7: Trigger fault injection scenario
t.Log("Triggering fault injection scenario...")
// Example: Cluster failover
// resp, err := faultInjector.TriggerClusterFailover(ctx, "node-1", false)
// if err != nil {
// t.Fatalf("Failed to trigger failover: %v", err)
// }
// Example: Network latency
// nodes := []string{"localhost:7001", "localhost:7002"}
// resp, err := faultInjector.SimulateNetworkLatency(ctx, nodes, 100*time.Millisecond, 20*time.Millisecond)
// if err != nil {
// t.Fatalf("Failed to simulate latency: %v", err)
// }
// Example: Complex sequence
// sequence := []SequenceAction{
// {
// Type: ActionNetworkLatency,
// Parameters: map[string]interface{}{
// "nodes": []string{"localhost:7001"},
// "latency": "50ms",
// },
// },
// {
// Type: ActionClusterFailover,
// Parameters: map[string]interface{}{
// "node_id": "node-1",
// "force": false,
// },
// Delay: 10 * time.Second,
// },
// }
// resp, err := faultInjector.ExecuteSequence(ctx, sequence)
// if err != nil {
// t.Fatalf("Failed to execute sequence: %v", err)
// }
// Step 8: Wait for fault injection to complete
// status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
// WithMaxWaitTime(120*time.Second),
// WithPollInterval(2*time.Second))
// if err != nil {
// t.Fatalf("Fault injection failed: %v", err)
// }
// t.Logf("Fault injection completed: %s", status.Status)
// Step 9: Verify client remains operational during and after fault injection
time.Sleep(10 * time.Second)
err = client.Ping(ctx).Err()
if err != nil {
t.Errorf("Client not responsive after fault injection: %v", err)
}
// Step 10: Perform additional validation
testKey := "validation-key"
testValue := "validation-value"
err = client.Set(ctx, testKey, testValue, time.Minute).Err()
if err != nil {
t.Errorf("Failed to set validation key: %v", err)
}
retrievedValue, err := client.Get(ctx, testKey).Result()
if err != nil {
t.Errorf("Failed to get validation key: %v", err)
} else if retrievedValue != testValue {
t.Errorf("Validation failed: expected %s, got %s", testValue, retrievedValue)
}
t.Log("Scenario test completed successfully")
}
// Helper functions for common scenario patterns
func performContinuousOperations(ctx context.Context, client redis.UniversalClient, workerID int, stopCh <-chan struct{}, errorCh chan<- error) {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
counter := 0
for {
select {
case <-stopCh:
return
case <-ticker.C:
key := fmt.Sprintf("worker_%d_key_%d", workerID, counter)
value := fmt.Sprintf("value_%d", counter)
// Perform operation with timeout
opCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
err := client.Set(opCtx, key, value, time.Minute).Err()
cancel()
if err != nil {
select {
case errorCh <- err:
default:
}
}
counter++
}
}
}
func validateClusterHealth(ctx context.Context, client redis.UniversalClient) error {
// Basic connectivity test
if err := client.Ping(ctx).Err(); err != nil {
return fmt.Errorf("ping failed: %w", err)
}
// Test basic operations
testKey := "health-check-key"
testValue := "health-check-value"
if err := client.Set(ctx, testKey, testValue, time.Minute).Err(); err != nil {
return fmt.Errorf("set operation failed: %w", err)
}
retrievedValue, err := client.Get(ctx, testKey).Result()
if err != nil {
return fmt.Errorf("get operation failed: %w", err)
}
if retrievedValue != testValue {
return fmt.Errorf("value mismatch: expected %s, got %s", testValue, retrievedValue)
}
// Clean up
client.Del(ctx, testKey)
return nil
}
func waitForStableOperations(ctx context.Context, client redis.UniversalClient, duration time.Duration) error {
deadline := time.Now().Add(duration)
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if err := validateClusterHealth(ctx, client); err != nil {
return fmt.Errorf("cluster health check failed: %w", err)
}
}
}
return nil
}
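// Aside (not part of this diff): a sketch of how the helpers above might be combined to run
// background load during a fault-injection step. The worker count, channel capacity, and the
// extra "sync" import are assumptions made for the sketch.
func runBackgroundLoad(ctx context.Context, client redis.UniversalClient, workers int) (stop func(), errs <-chan error) {
	stopCh := make(chan struct{})
	errorCh := make(chan error, workers*10)
	var wg sync.WaitGroup // requires importing "sync"
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			performContinuousOperations(ctx, client, id, stopCh, errorCh)
		}(i)
	}
	return func() {
		close(stopCh)
		wg.Wait()
		close(errorCh)
	}, errorCh
}
// A test would call stop() after waitForStableOperations returns and then drain errs to decide
// whether the scenario stayed healthy.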

View File

@@ -0,0 +1,365 @@
package e2e
import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
// TestTimeoutConfigurationsPushNotifications tests push notifications with different timeout configurations
func TestTimeoutConfigurationsPushNotifications(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Minute)
defer cancel()
var dump = true
var p = func(format string, args ...interface{}) {
format = "[%s][TIMEOUT-CONFIGS] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Logf(format, args...)
}
// Test different timeout configurations
timeoutConfigs := []struct {
name string
handoffTimeout time.Duration
relaxedTimeout time.Duration
postHandoffRelaxedDuration time.Duration
description string
expectedBehavior string
}{
{
name: "Conservative",
handoffTimeout: 60 * time.Second,
relaxedTimeout: 20 * time.Second,
postHandoffRelaxedDuration: 5 * time.Second,
description: "Conservative timeouts for stable environments",
expectedBehavior: "Longer timeouts, fewer timeout errors",
},
{
name: "Aggressive",
handoffTimeout: 5 * time.Second,
relaxedTimeout: 3 * time.Second,
postHandoffRelaxedDuration: 1 * time.Second,
description: "Aggressive timeouts for fast failover",
expectedBehavior: "Shorter timeouts, faster recovery",
},
{
name: "HighLatency",
handoffTimeout: 90 * time.Second,
relaxedTimeout: 30 * time.Second,
postHandoffRelaxedDuration: 10 * time.Minute,
description: "High latency environment timeouts",
expectedBehavior: "Very long timeouts for high latency networks",
},
}
logCollector.ClearLogs()
defer func() {
if dump {
p("Dumping logs...")
logCollector.DumpLogs()
p("Log Analysis:")
logCollector.GetAnalysis().Print(t)
}
logCollector.Clear()
}()
// Create client factory from configuration
factory, err := CreateTestClientFactory("standalone")
if err != nil {
t.Skipf("Enterprise cluster not available, skipping timeout configs test: %v", err)
}
endpointConfig := factory.GetConfig()
// Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
t.Fatalf("Failed to create fault injector: %v", err)
}
defer func() {
if dump {
p("Pool stats:")
factory.PrintPoolStats(t)
}
factory.DestroyAll()
}()
// Test each timeout configuration
for _, timeoutTest := range timeoutConfigs {
t.Run(timeoutTest.name, func(t *testing.T) {
// redefine p and e for each test to get
// proper test name in logs and proper test failures
var p = func(format string, args ...interface{}) {
format = "[%s][ENDPOINT-TYPES] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Logf(format, args...)
}
var e = func(format string, args ...interface{}) {
format = "[%s][ENDPOINT-TYPES][ERROR] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Errorf(format, args...)
}
p("Testing timeout configuration: %s - %s", timeoutTest.name, timeoutTest.description)
p("Expected behavior: %s", timeoutTest.expectedBehavior)
p("Handoff timeout: %v, Relaxed timeout: %v, Post-handoff duration: %v",
timeoutTest.handoffTimeout, timeoutTest.relaxedTimeout, timeoutTest.postHandoffRelaxedDuration)
minIdleConns := 4
poolSize := 10
maxConnections := 15
// Create Redis client with specific timeout configuration
client, err := factory.Create(fmt.Sprintf("timeout-test-%s", timeoutTest.name), &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: timeoutTest.handoffTimeout,
RelaxedTimeout: timeoutTest.relaxedTimeout,
PostHandoffRelaxedDuration: timeoutTest.postHandoffRelaxedDuration,
MaxWorkers: 20,
EndpointType: maintnotifications.EndpointTypeExternalIP,
},
ClientName: fmt.Sprintf("timeout-test-%s", timeoutTest.name),
})
if err != nil {
t.Fatalf("Failed to create client for %s: %v", timeoutTest.name, err)
}
// Create timeout tracker
tracker := NewTrackingNotificationsHook()
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
setupNotificationHooks(client, tracker, logger)
defer func() {
if dump {
p("Tracker analysis for %s:", timeoutTest.name)
tracker.GetAnalysis().Print(t)
}
tracker.Clear()
}()
// Verify initial connectivity
err = client.Ping(ctx).Err()
if err != nil {
t.Fatalf("Failed to ping Redis with %s timeout config: %v", timeoutTest.name, err)
}
p("Client connected successfully with %s timeout configuration", timeoutTest.name)
commandsRunner, _ := NewCommandRunner(client)
defer func() {
if dump {
stats := commandsRunner.GetStats()
p("%s timeout config stats: Operations: %d, Errors: %d, Timeout Errors: %d",
timeoutTest.name, stats.Operations, stats.Errors, stats.TimeoutErrors)
}
commandsRunner.Stop()
}()
// Start command traffic
go func() {
commandsRunner.FireCommandsUntilStop(ctx)
}()
// Record start time for timeout analysis
testStartTime := time.Now()
// Test failover with this timeout configuration
p("Testing failover with %s timeout configuration...", timeoutTest.name)
failoverResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "failover",
Parameters: map[string]interface{}{
"cluster_index": "0",
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
t.Fatalf("Failed to trigger failover action for %s: %v", timeoutTest.name, err)
}
// Wait for FAILING_OVER notification
match, found := logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "FAILING_OVER")
}, 3*time.Minute)
if !found {
t.Fatalf("FAILING_OVER notification was not received for %s timeout config", timeoutTest.name)
}
failingOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILING_OVER notification received for %s. %v", timeoutTest.name, failingOverData)
// Wait for FAILED_OVER notification (same connection, next sequence ID)
seqIDToObserve := int64(failingOverData["seqID"].(float64))
connIDToObserve := uint64(failingOverData["connID"].(float64))
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return notificationType(s, "FAILED_OVER") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
}, 3*time.Minute)
if !found {
t.Fatalf("FAILED_OVER notification was not received for %s timeout config", timeoutTest.name)
}
failedOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILED_OVER notification received for %s. %v", timeoutTest.name, failedOverData)
// Wait for failover to complete
status, err := faultInjector.WaitForAction(ctx, failoverResp.ActionID,
WithMaxWaitTime(180*time.Second),
WithPollInterval(1*time.Second),
)
if err != nil {
t.Fatalf("[FI] Failover action failed for %s: %v", timeoutTest.name, err)
}
p("[FI] Failover action completed for %s: %s", timeoutTest.name, status.Status)
// Continue traffic to observe timeout behavior
p("Continuing traffic for %v to observe timeout behavior...", timeoutTest.relaxedTimeout*2)
time.Sleep(timeoutTest.relaxedTimeout * 2)
// Test migration to trigger more timeout scenarios
p("Testing migration with %s timeout configuration...", timeoutTest.name)
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "migrate",
Parameters: map[string]interface{}{
"cluster_index": "0",
},
})
if err != nil {
t.Fatalf("Failed to trigger migrate action for %s: %v", timeoutTest.name, err)
}
// Wait for MIGRATING notification
match, found = logCollector.WaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
}, 30*time.Second)
if !found {
t.Fatalf("MIGRATING notification was not received for %s timeout config", timeoutTest.name)
}
migrateData := logs2.ExtractDataFromLogMessage(match)
p("MIGRATING notification received for %s: %v", timeoutTest.name, migrateData)
// Wait for migration to complete
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
WithMaxWaitTime(120*time.Second),
WithPollInterval(1*time.Second),
)
if err != nil {
t.Fatalf("[FI] Migrate action failed for %s: %v", timeoutTest.name, err)
}
p("[FI] Migrate action completed for %s: %s", timeoutTest.name, status.Status)
// do a bind action
bindResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "bind",
Parameters: map[string]interface{}{
"cluster_index": "0",
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
t.Fatalf("Failed to trigger bind action for %s: %v", timeoutTest.name, err)
}
status, err = faultInjector.WaitForAction(ctx, bindResp.ActionID,
WithMaxWaitTime(120*time.Second),
WithPollInterval(1*time.Second),
)
if err != nil {
t.Fatalf("[FI] Bind action failed for %s: %v", timeoutTest.name, err)
}
p("[FI] Bind action completed for %s: %s", timeoutTest.name, status.Status)
// waiting for moving notification
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "MOVING")
}, 2*time.Minute)
if !found {
t.Fatalf("MOVING notification was not received for %s timeout config", timeoutTest.name)
}
movingData := logs2.ExtractDataFromLogMessage(match)
p("MOVING notification received for %s. %v", timeoutTest.name, movingData)
// Continue traffic for post-handoff timeout observation
p("Continuing traffic for %v to observe post-handoff timeout behavior...", 1*time.Minute)
time.Sleep(1 * time.Minute)
commandsRunner.Stop()
testDuration := time.Since(testStartTime)
// Analyze timeout behavior
trackerAnalysis := tracker.GetAnalysis()
logAnalysis := logCollector.GetAnalysis()
if trackerAnalysis.NotificationProcessingErrors > 0 {
e("Notification processing errors with %s timeout config: %d", timeoutTest.name, trackerAnalysis.NotificationProcessingErrors)
}
// Validate timeout-specific behavior
switch timeoutTest.name {
case "Conservative":
if trackerAnalysis.UnrelaxedTimeoutCount > trackerAnalysis.RelaxedTimeoutCount {
e("Conservative config should have more relaxed than unrelaxed timeouts, got relaxed=%d, unrelaxed=%d",
trackerAnalysis.RelaxedTimeoutCount, trackerAnalysis.UnrelaxedTimeoutCount)
}
case "Aggressive":
// Aggressive timeouts should complete faster
if testDuration > 5*time.Minute {
e("Aggressive config took too long: %v", testDuration)
}
if logAnalysis.TotalHandoffRetries > logAnalysis.TotalHandoffCount {
e("Expect handoff retries since aggressive timeouts are shorter, got %d retries for %d handoffs",
logAnalysis.TotalHandoffRetries, logAnalysis.TotalHandoffCount)
}
case "HighLatency":
// High latency config should have very few unrelaxed after moving
if logAnalysis.UnrelaxedAfterMoving > 2 {
e("High latency config should have minimal unrelaxed timeouts after moving, got %d", logAnalysis.UnrelaxedAfterMoving)
}
}
// Validate we received expected notifications
if trackerAnalysis.FailingOverCount == 0 {
e("Expected FAILING_OVER notifications with %s timeout config, got none", timeoutTest.name)
}
if trackerAnalysis.FailedOverCount == 0 {
e("Expected FAILED_OVER notifications with %s timeout config, got none", timeoutTest.name)
}
if trackerAnalysis.MigratingCount == 0 {
e("Expected MIGRATING notifications with %s timeout config, got none", timeoutTest.name)
}
// Validate timeout counts are reasonable
if trackerAnalysis.RelaxedTimeoutCount == 0 {
e("Expected relaxed timeouts with %s config, got none", timeoutTest.name)
}
if logAnalysis.SucceededHandoffCount == 0 {
e("Expected successful handoffs with %s config, got none", timeoutTest.name)
}
p("Timeout configuration %s test completed successfully in %v", timeoutTest.name, testDuration)
p("Command runner stats:")
p("Operations: %d, Errors: %d, Timeout Errors: %d",
commandsRunner.GetStats().Operations, commandsRunner.GetStats().Errors, commandsRunner.GetStats().TimeoutErrors)
p("Relaxed timeouts: %d, Unrelaxed timeouts: %d", trackerAnalysis.RelaxedTimeoutCount, trackerAnalysis.UnrelaxedTimeoutCount)
})
// Clear logs between timeout configuration tests
logCollector.ClearLogs()
}
p("All timeout configurations tested successfully")
}

View File

@@ -0,0 +1,315 @@
package e2e
import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
// TODO ADD TLS CONFIGS
// TestTLSConfigurationsPushNotifications tests push notifications with different TLS configurations
func TestTLSConfigurationsPushNotifications(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute)
defer cancel()
var dump = true
var p = func(format string, args ...interface{}) {
format = "[%s][TLS-CONFIGS] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Logf(format, args...)
}
// Test different TLS configurations
// Note: TLS configuration is typically handled at the Redis connection config level
// This scenario demonstrates the testing pattern for different TLS setups
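// Illustrative sketch only (assumption: the client factory accepts a *tls.Config; the exact
// wiring is not shown here). The cases below could map to crypto/tls settings roughly like:
//   NoTLS:       nil
//   TLSInsecure: &tls.Config{InsecureSkipVerify: true}
//   TLSMinimal:  &tls.Config{MinVersion: tls.VersionTLS12}
//   TLSStrict:   &tls.Config{MinVersion: tls.VersionTLS13}
//   TLSSecure:   &tls.Config{RootCAs: caPool, ServerName: serverName} // requires valid certificates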
tlsConfigs := []struct {
name string
description string
skipReason string
}{
{
name: "NoTLS",
description: "No TLS encryption (plain text)",
},
{
name: "TLSInsecure",
description: "TLS with insecure skip verify (testing only)",
},
{
name: "TLSSecure",
description: "Secure TLS with certificate verification",
skipReason: "Requires valid certificates in test environment",
},
{
name: "TLSMinimal",
description: "TLS with minimal version requirements",
},
{
name: "TLSStrict",
description: "Strict TLS with TLS 1.3 and specific cipher suites",
},
}
logCollector.ClearLogs()
defer func() {
if dump {
p("Dumping logs...")
logCollector.DumpLogs()
p("Log Analysis:")
logCollector.GetAnalysis().Print(t)
}
logCollector.Clear()
}()
// Create client factory from configuration
factory, err := CreateTestClientFactory("standalone")
if err != nil {
t.Skipf("Enterprise cluster not available, skipping TLS configs test: %v", err)
}
endpointConfig := factory.GetConfig()
// Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
t.Fatalf("Failed to create fault injector: %v", err)
}
defer func() {
if dump {
p("Pool stats:")
factory.PrintPoolStats(t)
}
factory.DestroyAll()
}()
// Test each TLS configuration
for _, tlsTest := range tlsConfigs {
t.Run(tlsTest.name, func(t *testing.T) {
// redefine p and e for each test to get
// proper test name in logs and proper test failures
var p = func(format string, args ...interface{}) {
format = "[%s][ENDPOINT-TYPES] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Logf(format, args...)
}
var e = func(format string, args ...interface{}) {
format = "[%s][ENDPOINT-TYPES][ERROR] " + format
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{ts}, args...)
t.Errorf(format, args...)
}
if tlsTest.skipReason != "" {
t.Skipf("Skipping %s: %s", tlsTest.name, tlsTest.skipReason)
}
p("Testing TLS configuration: %s - %s", tlsTest.name, tlsTest.description)
minIdleConns := 3
poolSize := 8
maxConnections := 12
// Create Redis client with specific TLS configuration
// Note: TLS configuration is handled at the factory/connection level
client, err := factory.Create(fmt.Sprintf("tls-test-%s", tlsTest.name), &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 30 * time.Second,
RelaxedTimeout: 10 * time.Second,
PostHandoffRelaxedDuration: 2 * time.Second,
MaxWorkers: 15,
EndpointType: maintnotifications.EndpointTypeExternalIP,
},
ClientName: fmt.Sprintf("tls-test-%s", tlsTest.name),
})
if err != nil {
// Some TLS configurations might fail in test environments
if tlsTest.name == "TLSSecure" || tlsTest.name == "TLSStrict" {
t.Skipf("TLS configuration %s failed (expected in test environment): %v", tlsTest.name, err)
}
t.Fatalf("Failed to create client for %s: %v", tlsTest.name, err)
}
// Create timeout tracker
tracker := NewTrackingNotificationsHook()
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
setupNotificationHooks(client, tracker, logger)
defer func() {
if dump {
p("Tracker analysis for %s:", tlsTest.name)
tracker.GetAnalysis().Print(t)
}
tracker.Clear()
}()
// Verify initial connectivity
err = client.Ping(ctx).Err()
if err != nil {
if tlsTest.name == "TLSSecure" || tlsTest.name == "TLSStrict" {
t.Skipf("TLS configuration %s ping failed (expected in test environment): %v", tlsTest.name, err)
}
t.Fatalf("Failed to ping Redis with %s TLS config: %v", tlsTest.name, err)
}
p("Client connected successfully with %s TLS configuration", tlsTest.name)
commandsRunner, _ := NewCommandRunner(client)
defer func() {
if dump {
stats := commandsRunner.GetStats()
p("%s TLS config stats: Operations: %d, Errors: %d, Timeout Errors: %d",
tlsTest.name, stats.Operations, stats.Errors, stats.TimeoutErrors)
}
commandsRunner.Stop()
}()
// Start command traffic
go func() {
commandsRunner.FireCommandsUntilStop(ctx)
}()
// Test failover with this TLS configuration
p("Testing failover with %s TLS configuration...", tlsTest.name)
failoverResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "failover",
Parameters: map[string]interface{}{
"cluster_index": "0",
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
t.Fatalf("Failed to trigger failover action for %s: %v", tlsTest.name, err)
}
// Wait for FAILING_OVER notification
match, found := logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "FAILING_OVER")
}, 2*time.Minute)
if !found {
t.Fatalf("FAILING_OVER notification was not received for %s TLS config", tlsTest.name)
}
failingOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILING_OVER notification received for %s. %v", tlsTest.name, failingOverData)
// Wait for FAILED_OVER notification (same connection, next sequence ID)
seqIDToObserve := int64(failingOverData["seqID"].(float64))
connIDToObserve := uint64(failingOverData["connID"].(float64))
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return notificationType(s, "FAILED_OVER") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
}, 2*time.Minute)
if !found {
t.Fatalf("FAILED_OVER notification was not received for %s TLS config", tlsTest.name)
}
failedOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILED_OVER notification received for %s. %v", tlsTest.name, failedOverData)
// Wait for failover to complete
status, err := faultInjector.WaitForAction(ctx, failoverResp.ActionID,
WithMaxWaitTime(120*time.Second),
WithPollInterval(1*time.Second),
)
if err != nil {
t.Fatalf("[FI] Failover action failed for %s: %v", tlsTest.name, err)
}
p("[FI] Failover action completed for %s: %s", tlsTest.name, status.Status)
// Test migration with this TLS configuration
p("Testing migration with %s TLS configuration...", tlsTest.name)
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "migrate",
Parameters: map[string]interface{}{
"cluster_index": "0",
},
})
if err != nil {
t.Fatalf("Failed to trigger migrate action for %s: %v", tlsTest.name, err)
}
// Wait for MIGRATING notification
match, found = logCollector.WaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
}, 30*time.Second)
if !found {
t.Fatalf("MIGRATING notification was not received for %s TLS config", tlsTest.name)
}
migrateData := logs2.ExtractDataFromLogMessage(match)
p("MIGRATING notification received for %s: %v", tlsTest.name, migrateData)
// Wait for migration to complete
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
WithMaxWaitTime(120*time.Second),
WithPollInterval(1*time.Second),
)
if err != nil {
t.Fatalf("[FI] Migrate action failed for %s: %v", tlsTest.name, err)
}
p("[FI] Migrate action completed for %s: %s", tlsTest.name, status.Status)
// Continue traffic for a bit to observe TLS behavior
time.Sleep(5 * time.Second)
commandsRunner.Stop()
// Analyze results for this TLS configuration
trackerAnalysis := tracker.GetAnalysis()
if trackerAnalysis.NotificationProcessingErrors > 0 {
e("Notification processing errors with %s TLS config: %d", tlsTest.name, trackerAnalysis.NotificationProcessingErrors)
}
if trackerAnalysis.UnexpectedNotificationCount > 0 {
e("Unexpected notifications with %s TLS config: %d", tlsTest.name, trackerAnalysis.UnexpectedNotificationCount)
}
// Validate we received expected notifications
if trackerAnalysis.FailingOverCount == 0 {
e("Expected FAILING_OVER notifications with %s TLS config, got none", tlsTest.name)
}
if trackerAnalysis.FailedOverCount == 0 {
e("Expected FAILED_OVER notifications with %s TLS config, got none", tlsTest.name)
}
if trackerAnalysis.MigratingCount == 0 {
e("Expected MIGRATING notifications with %s TLS config, got none", tlsTest.name)
}
// TLS-specific validations
stats := commandsRunner.GetStats()
switch tlsTest.name {
case "NoTLS":
// Plain text should work fine
p("Plain text connection processed %d operations", stats.Operations)
case "TLSInsecure", "TLSMinimal":
// Insecure TLS should work in test environments
p("Insecure TLS connection processed %d operations", stats.Operations)
if stats.Operations == 0 {
e("Expected operations with %s TLS config, got none", tlsTest.name)
}
case "TLSStrict":
// Strict TLS might have different performance characteristics
p("Strict TLS connection processed %d operations", stats.Operations)
}
p("TLS configuration %s test completed successfully", tlsTest.name)
})
// Clear logs between TLS configuration tests
logCollector.ClearLogs()
}
p("All TLS configurations tested successfully")
}

View File

@@ -0,0 +1,214 @@
#!/bin/bash
# Maintenance Notifications E2E Tests Runner
# This script sets up the environment and runs the maintnotifications upgrade E2E tests
set -euo pipefail
# Script directory and repository root
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
E2E_DIR="${REPO_ROOT}/maintnotifications/e2e"
# Configuration
FAULT_INJECTOR_URL="http://127.0.0.1:20324"
CONFIG_PATH="${REPO_ROOT}/maintnotifications/e2e/infra/cae-client-testing/endpoints.json"
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Logging functions
log_info() {
echo -e "${BLUE}[INFO]${NC} $1"
}
log_success() {
echo -e "${GREEN}[SUCCESS]${NC} $1"
}
log_warning() {
echo -e "${YELLOW}[WARNING]${NC} $1"
}
log_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
# Help function
show_help() {
cat << EOF
Maintenance Notifications E2E Tests Runner
Usage: $0 [OPTIONS]
OPTIONS:
-h, --help Show this help message
-v, --verbose Enable verbose test output
-t, --timeout DURATION Test timeout (default: 30m)
-r, --run PATTERN Run only tests matching pattern
--dry-run Show what would be executed without running
--list List available tests
--config PATH Override config path (default: infra/cae-client-testing/endpoints.json)
--fault-injector URL Override fault injector URL (default: http://127.0.0.1:20324)
EXAMPLES:
$0 # Run all E2E tests
$0 -v # Run with verbose output
$0 -r TestPushNotifications # Run only push notification tests
$0 -t 45m # Run with 45 minute timeout
$0 --dry-run # Show what would be executed
$0 --list # List available tests
ENVIRONMENT:
The script automatically sets up the required environment variables:
- REDIS_ENDPOINTS_CONFIG_PATH: Path to Redis endpoints configuration
- FAULT_INJECTION_API_URL: URL of the fault injector server
- E2E_SCENARIO_TESTS: Enables scenario tests
EOF
}
# Parse command line arguments
VERBOSE=""
TIMEOUT="30m"
RUN_PATTERN=""
DRY_RUN=false
LIST_TESTS=false
while [[ $# -gt 0 ]]; do
case $1 in
-h|--help)
show_help
exit 0
;;
-v|--verbose)
VERBOSE="-v"
shift
;;
-t|--timeout)
TIMEOUT="$2"
shift 2
;;
-r|--run)
RUN_PATTERN="$2"
shift 2
;;
--dry-run)
DRY_RUN=true
shift
;;
--list)
LIST_TESTS=true
shift
;;
--config)
CONFIG_PATH="$2"
shift 2
;;
--fault-injector)
FAULT_INJECTOR_URL="$2"
shift 2
;;
*)
log_error "Unknown option: $1"
show_help
exit 1
;;
esac
done
# Validate configuration file exists
if [[ ! -f "$CONFIG_PATH" ]]; then
log_error "Configuration file not found: $CONFIG_PATH"
log_info "Please ensure the endpoints.json file exists at the specified path"
exit 1
fi
# Set up environment variables
export REDIS_ENDPOINTS_CONFIG_PATH="$CONFIG_PATH"
export FAULT_INJECTION_API_URL="$FAULT_INJECTOR_URL"
export E2E_SCENARIO_TESTS="true"
# Build test command
TEST_CMD="go test -tags=e2e -v"
if [[ -n "$TIMEOUT" ]]; then
TEST_CMD="$TEST_CMD -timeout=$TIMEOUT"
fi
if [[ -n "$VERBOSE" ]]; then
TEST_CMD="$TEST_CMD $VERBOSE"
fi
if [[ -n "$RUN_PATTERN" ]]; then
TEST_CMD="$TEST_CMD -run $RUN_PATTERN"
fi
TEST_CMD="$TEST_CMD ./maintnotifications/e2e/ "
# List tests if requested
if [[ "$LIST_TESTS" == true ]]; then
log_info "Available E2E tests:"
cd "$REPO_ROOT"
go test -tags=e2e ./maintnotifications/e2e/ -list=. | grep -E "^Test" | sort
exit 0
fi
# Show configuration
log_info "Maintenance notifications E2E Tests Configuration:"
echo " Repository Root: $REPO_ROOT"
echo " E2E Directory: $E2E_DIR"
echo " Config Path: $CONFIG_PATH"
echo " Fault Injector URL: $FAULT_INJECTOR_URL"
echo " Test Timeout: $TIMEOUT"
if [[ -n "$RUN_PATTERN" ]]; then
echo " Test Pattern: $RUN_PATTERN"
fi
echo ""
# Validate fault injector connectivity
log_info "Checking fault injector connectivity..."
if command -v curl >/dev/null 2>&1; then
if curl -s --connect-timeout 5 "$FAULT_INJECTOR_URL/health" >/dev/null 2>&1; then
log_success "Fault injector is accessible at $FAULT_INJECTOR_URL"
else
log_warning "Cannot connect to fault injector at $FAULT_INJECTOR_URL"
log_warning "Tests may fail if fault injection is required"
fi
else
log_warning "curl not available, skipping fault injector connectivity check"
fi
# Show what would be executed in dry-run mode
if [[ "$DRY_RUN" == true ]]; then
log_info "Dry run mode - would execute:"
echo " cd $REPO_ROOT"
echo " export REDIS_ENDPOINTS_CONFIG_PATH=\"$CONFIG_PATH\""
echo " export FAULT_INJECTION_API_URL=\"$FAULT_INJECTOR_URL\""
echo " export E2E_SCENARIO_TESTS=\"true\""
echo " $TEST_CMD"
exit 0
fi
# Change to repository root
cd "$REPO_ROOT"
# Run the tests
log_info "Starting E2E tests..."
log_info "Command: $TEST_CMD"
echo ""
if eval "$TEST_CMD"; then
echo ""
log_success "All E2E tests completed successfully!"
exit 0
else
echo ""
log_error "E2E tests failed!"
log_info "Check the test output above for details"
exit 1
fi

View File

@@ -0,0 +1,44 @@
package e2e
// isTimeout reports whether an error message looks like a timeout.
func isTimeout(errMsg string) bool {
return contains(errMsg, "i/o timeout") ||
contains(errMsg, "deadline exceeded") ||
contains(errMsg, "context deadline exceeded")
}
// isTimeoutError checks if an error is a timeout error
func isTimeoutError(err error) bool {
if err == nil {
return false
}
// Check for various timeout error types
errStr := err.Error()
return isTimeout(errStr)
}
// contains checks if a string contains a substring (case-sensitive)
func contains(s, substr string) bool {
return len(s) >= len(substr) &&
(s == substr ||
(len(s) > len(substr) &&
(s[:len(substr)] == substr ||
s[len(s)-len(substr):] == substr ||
containsSubstring(s, substr))))
}
// containsSubstring reports whether substr occurs anywhere within s using a simple scan.
func containsSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
// min returns the smaller of two ints.
func min(a, b int) int {
if a < b {
return a
}
return b
}
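// Illustrative usage (not part of the harness): callers might classify command errors like
//   if err := client.Ping(ctx).Err(); isTimeoutError(err) {
//       stats.TimeoutErrors++ // hypothetical counter, mirroring the command runner stats
//   }
// Note that contains/containsSubstring behave like strings.Contains from the standard library.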

View File

@@ -0,0 +1,63 @@
package maintnotifications
import (
"errors"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
)
// Configuration errors
var (
ErrInvalidRelaxedTimeout = errors.New(logs.InvalidRelaxedTimeoutError())
ErrInvalidHandoffTimeout = errors.New(logs.InvalidHandoffTimeoutError())
ErrInvalidHandoffWorkers = errors.New(logs.InvalidHandoffWorkersError())
ErrInvalidHandoffQueueSize = errors.New(logs.InvalidHandoffQueueSizeError())
ErrInvalidPostHandoffRelaxedDuration = errors.New(logs.InvalidPostHandoffRelaxedDurationError())
ErrInvalidEndpointType = errors.New(logs.InvalidEndpointTypeError())
ErrInvalidMaintNotifications = errors.New(logs.InvalidMaintNotificationsError())
ErrMaxHandoffRetriesReached = errors.New(logs.MaxHandoffRetriesReachedError())
// Configuration validation errors
ErrInvalidHandoffRetries = errors.New(logs.InvalidHandoffRetriesError())
)
// Integration errors
var (
ErrInvalidClient = errors.New(logs.InvalidClientError())
)
// Handoff errors
var (
ErrHandoffQueueFull = errors.New(logs.HandoffQueueFullError())
)
// Notification errors
var (
ErrInvalidNotification = errors.New(logs.InvalidNotificationError())
)
// connection handoff errors
var (
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
// and should not be used until the handoff is complete
ErrConnectionMarkedForHandoff = errors.New("" + logs.ConnectionMarkedForHandoffErrorMessage)
// ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff
ErrConnectionInvalidHandoffState = errors.New("" + logs.ConnectionInvalidHandoffStateErrorMessage)
)
// general errors
var (
ErrShutdown = errors.New(logs.ShutdownError())
)
// circuit breaker errors
var (
ErrCircuitBreakerOpen = errors.New("" + logs.CircuitBreakerOpenErrorMessage)
)
// circuit breaker configuration errors
var (
ErrInvalidCircuitBreakerFailureThreshold = errors.New(logs.InvalidCircuitBreakerFailureThresholdError())
ErrInvalidCircuitBreakerResetTimeout = errors.New(logs.InvalidCircuitBreakerResetTimeoutError())
ErrInvalidCircuitBreakerMaxRequests = errors.New(logs.InvalidCircuitBreakerMaxRequestsError())
)
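// Illustrative usage (assumption, not part of this file): these sentinel errors are meant to be
// matched with errors.Is, e.g.
//   if errors.Is(err, ErrHandoffQueueFull) {
//       // back off and let the queue drain before retrying the handoff
//   }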

View File

@@ -1,4 +1,4 @@
package hitless
package maintnotifications
import (
"context"
@@ -6,6 +6,7 @@ import (
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
@@ -14,7 +15,7 @@ import (
type contextKey string
const (
startTimeKey contextKey = "notif_hitless_start_time"
startTimeKey contextKey = "maint_notif_start_time"
)
// MetricsHook collects metrics about notification processing.
@@ -42,7 +43,7 @@ func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.Notific
// Log connection information if available
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
internal.Logger.Printf(ctx, "hitless: metrics hook processing %s notification on conn[%d]", notificationType, conn.GetID())
internal.Logger.Printf(ctx, logs.MetricsHookProcessingNotification(notificationType, conn.GetID()))
}
// Store start time in context for duration calculation
@@ -66,7 +67,7 @@ func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.Notifi
// Log error details with connection information
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
internal.Logger.Printf(ctx, "hitless: metrics hook recorded error for %s notification on conn[%d]: %v", notificationType, conn.GetID(), result)
internal.Logger.Printf(ctx, logs.MetricsHookRecordedError(notificationType, conn.GetID(), result))
}
}
}

View File

@@ -1,4 +1,4 @@
package hitless
package maintnotifications
import (
"context"
@@ -9,6 +9,7 @@ import (
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
)
@@ -29,7 +30,7 @@ type handoffWorkerManager struct {
// Simple state tracking
pending sync.Map // map[uint64]int64 (connID -> seqID)
// Configuration for the hitless upgrade
// Configuration for the maintenance notifications
config *Config
// Pool hook reference for handoff processing
@@ -120,8 +121,7 @@ func (hwm *handoffWorkerManager) onDemandWorker() {
defer func() {
// Handle panics to ensure proper cleanup
if r := recover(); r != nil {
internal.Logger.Printf(context.Background(),
"hitless: worker panic recovered: %v", r)
internal.Logger.Printf(context.Background(), logs.WorkerPanicRecovered(r))
}
// Decrement active worker count when exiting
@@ -145,18 +145,23 @@ func (hwm *handoffWorkerManager) onDemandWorker() {
select {
case <-hwm.shutdown:
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdown())
}
return
case <-timer.C:
// Worker has been idle for too long, exit to save resources
if hwm.config != nil && hwm.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: worker exiting due to inactivity timeout (%v)", hwm.workerTimeout)
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout))
}
return
case request := <-hwm.handoffQueue:
// Check for shutdown before processing
select {
case <-hwm.shutdown:
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing())
}
// Clean up the request before exiting
hwm.pending.Delete(request.ConnID)
return
@@ -172,7 +177,9 @@ func (hwm *handoffWorkerManager) onDemandWorker() {
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
// Remove from pending map
defer hwm.pending.Delete(request.Conn.GetID())
internal.Logger.Printf(context.Background(), "hitless: conn[%d] Processing handoff request start", request.Conn.GetID())
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint))
}
// Create a context with handoff timeout from config
handoffTimeout := 15 * time.Second // Default timeout
@@ -212,10 +219,20 @@ func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
afterTime = minRetryBackoff
}
internal.Logger.Printf(context.Background(), "Handoff failed for conn[%d] WILL RETRY After %v: %v", request.ConnID, afterTime, err)
if internal.LogLevel.InfoOrAbove() {
// Get current retry count for better logging
currentRetries := request.Conn.HandoffRetries()
maxRetries := 3 // Default fallback
if hwm.config != nil {
maxRetries = hwm.config.MaxHandoffRetries
}
internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err))
}
time.AfterFunc(afterTime, func() {
if err := hwm.queueHandoff(request.Conn); err != nil {
internal.Logger.Printf(context.Background(), "can't queue handoff for retry: %v", err)
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err))
}
hwm.closeConnFromRequest(context.Background(), request, err)
}
})
@@ -227,8 +244,8 @@ func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
// Clear handoff state if not returned for retry
seqID := request.Conn.GetMovingSeqID()
connID := request.Conn.GetID()
if hwm.poolHook.hitlessManager != nil {
hwm.poolHook.hitlessManager.UntrackOperationWithConnID(seqID, connID)
if hwm.poolHook.operationsManager != nil {
hwm.poolHook.operationsManager.UntrackOperationWithConnID(seqID, connID)
}
}
}
@@ -238,8 +255,13 @@ func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
// Get handoff info atomically to prevent race conditions
shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
if !shouldHandoff {
return errors.New("connection is not marked for handoff")
// On retries the connection will not be marked for handoff, but it will have retries > 0.
// If shouldHandoff is false and retries is 0, we are neither retrying nor performing a handoff.
if !shouldHandoff && conn.HandoffRetries() == 0 {
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID()))
}
return errors.New(logs.ConnectionNotMarkedForHandoffError(conn.GetID()))
}
// Create handoff request with atomically retrieved data
@@ -279,10 +301,8 @@ func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
// Queue is full - log and attempt scaling
queueLen := len(hwm.handoffQueue)
queueCap := cap(hwm.handoffQueue)
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(context.Background(),
"hitless: handoff queue is full (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration",
queueLen, queueCap)
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap))
}
}
}
@@ -336,7 +356,7 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c
// Check if circuit breaker is open before attempting handoff
if circuitBreaker.IsOpen() {
internal.Logger.Printf(ctx, "hitless: conn[%d] handoff to %s failed fast due to circuit breaker", connID, newEndpoint)
internal.Logger.Printf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint))
return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open
}
@@ -361,17 +381,15 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c
func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) {
retries := conn.IncrementAndGetHandoffRetries(1)
internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", connID, retries, newEndpoint, conn.RemoteAddr().String())
internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String()))
maxRetries := 3 // Default fallback
if hwm.config != nil {
maxRetries = hwm.config.MaxHandoffRetries
}
if retries > maxRetries {
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: reached max retries (%d) for handoff of conn[%d] to %s",
maxRetries, connID, newEndpoint)
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries))
}
// won't retry on ErrMaxHandoffRetriesReached
return false, ErrMaxHandoffRetriesReached
@@ -383,8 +401,8 @@ func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, con
// Create new connection to the new endpoint
newNetConn, err := endpointDialer(ctx)
if err != nil {
internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", connID, newEndpoint, err)
// hitless: will retry
internal.Logger.Printf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err))
// will retry
// Maybe a network error - retry after a delay
return true, err
}
@@ -402,17 +420,15 @@ func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, con
deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration)
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
if hwm.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: conn[%d] applied post-handoff relaxed timeout (%v) until %v",
connID, relaxedTimeout, deadline.Format("15:04:05.000"))
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000")))
}
}
// Replace the connection and execute initialization
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
if err != nil {
// hitless: won't retry
// won't retry
// Initialization failed - remove the connection
return false, err
}
@@ -423,7 +439,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, con
}()
conn.ClearHandoffState()
internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", connID, newEndpoint)
internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint))
return false, nil
}
@@ -452,17 +468,13 @@ func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, reque
conn := request.Conn
if pooler != nil {
pooler.Remove(ctx, conn, err)
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: removed conn[%d] from pool due: %v",
conn.GetID(), err)
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err))
}
} else {
conn.Close()
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: no pool provided for conn[%d], cannot remove due to: %v",
conn.GetID(), err)
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err))
}
}
}

View File

@@ -1,28 +1,41 @@
package hitless
package maintnotifications
import (
"context"
"slices"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/push"
)
// LoggingHook is an example hook implementation that logs all notifications.
type LoggingHook struct {
LogLevel logging.LogLevel
LogLevel int // 0=Error, 1=Warn, 2=Info, 3=Debug
}
// PreHook logs the notification before processing and allows modification.
func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
if lh.LogLevel.InfoOrAbove() { // Info level
if lh.LogLevel >= 2 { // Info level
// Log the notification type and content
connID := uint64(0)
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
connID = conn.GetID()
}
internal.Logger.Printf(ctx, "hitless: conn[%d] processing %s notification: %v", connID, notificationType, notification)
seqID := int64(0)
if slices.Contains(maintenanceNotificationTypes, notificationType) {
// seqID is the second element in the notification array
if len(notification) > 1 {
if parsedSeqID, ok := notification[1].(int64); !ok {
seqID = 0
} else {
seqID = parsedSeqID
}
}
}
internal.Logger.Printf(ctx, logs.ProcessingNotification(connID, seqID, notificationType, notification))
}
return notification, true // Continue processing with unmodified notification
}
@@ -33,15 +46,15 @@ func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.Notifi
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
connID = conn.GetID()
}
if result != nil && lh.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processing failed: %v - %v", connID, notificationType, result, notification)
} else if lh.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processed successfully", connID, notificationType)
if result != nil && lh.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx, logs.ProcessingNotificationFailed(connID, notificationType, result, notification))
} else if lh.LogLevel >= 3 { // Debug level
internal.Logger.Printf(ctx, logs.ProcessingNotificationSucceeded(connID, notificationType))
}
}
// NewLoggingHook creates a new logging hook with the specified log level.
// Log levels: LogLevelError=errors, LogLevelWarn=warnings, LogLevelInfo=info, LogLevelDebug=debug
func NewLoggingHook(logLevel logging.LogLevel) *LoggingHook {
// Log levels: 0=Error, 1=Warn, 2=Info, 3=Debug
func NewLoggingHook(logLevel int) *LoggingHook {
return &LoggingHook{LogLevel: logLevel}
}

View File

@@ -1,7 +1,8 @@
package hitless
package maintnotifications
import (
"context"
"errors"
"fmt"
"net"
"sync"
@@ -10,11 +11,12 @@ import (
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// Push notification type constants for hitless upgrades
// Push notification type constants for maintenance
const (
NotificationMoving = "MOVING"
NotificationMigrating = "MIGRATING"
@@ -23,8 +25,8 @@ const (
NotificationFailedOver = "FAILED_OVER"
)
// hitlessNotificationTypes contains all notification types that hitless upgrades handles
var hitlessNotificationTypes = []string{
// maintenanceNotificationTypes contains all notification types that maintenance handles
var maintenanceNotificationTypes = []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
@@ -53,8 +55,8 @@ func (k MovingOperationKey) String() string {
return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
}
// HitlessManager provides a simplified hitless upgrade functionality with hooks and atomic state.
type HitlessManager struct {
// Manager provides simplified upgrade functionality with hooks and atomic state.
type Manager struct {
client interfaces.ClientInterface
config *Config
options interfaces.OptionsInterface
@@ -81,13 +83,13 @@ type MovingOperation struct {
Deadline time.Time
}
// NewHitlessManager creates a new simplified hitless manager.
func NewHitlessManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*HitlessManager, error) {
// NewManager creates a new simplified manager.
func NewManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*Manager, error) {
if client == nil {
return nil, ErrInvalidClient
}
hm := &HitlessManager{
hm := &Manager{
client: client,
pool: pool,
options: client.GetOptions(),
@@ -104,25 +106,25 @@ func NewHitlessManager(client interfaces.ClientInterface, pool pool.Pooler, conf
}
// GetPoolHook creates a pool hook with a custom dialer.
func (hm *HitlessManager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
func (hm *Manager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
poolHook := hm.createPoolHook(baseDialer)
hm.pool.AddPoolHook(poolHook)
}
// setupPushNotifications sets up push notification handling by registering with the client's processor.
func (hm *HitlessManager) setupPushNotifications() error {
func (hm *Manager) setupPushNotifications() error {
processor := hm.client.GetPushProcessor()
if processor == nil {
return ErrInvalidClient // Client doesn't support push notifications
}
// Create our notification handler
handler := &NotificationHandler{manager: hm}
handler := &NotificationHandler{manager: hm, operationsManager: hm}
// Register handlers for all hitless upgrade notifications with the client's processor
for _, notificationType := range hitlessNotificationTypes {
// Register handlers for all upgrade notifications with the client's processor
for _, notificationType := range maintenanceNotificationTypes {
if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
return fmt.Errorf("failed to register handler for %s: %w", notificationType, err)
return errors.New(logs.FailedToRegisterHandler(notificationType, err))
}
}
@@ -130,7 +132,7 @@ func (hm *HitlessManager) setupPushNotifications() error {
}
// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
func (hm *Manager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
@@ -148,13 +150,13 @@ func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, ne
// Use LoadOrStore for atomic check-and-set operation
if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
// Duplicate MOVING notification, ignore
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Duplicate MOVING operation ignored: %s", connID, seqID, key.String())
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID))
}
return nil
}
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Tracking MOVING operation: %s", connID, seqID, key.String())
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID))
}
// Increment active operation count atomically
@@ -164,7 +166,7 @@ func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, ne
}
// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64) {
func (hm *Manager) UntrackOperationWithConnID(seqID int64, connID uint64) {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
@@ -173,14 +175,14 @@ func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64)
// Remove from active operations atomically
if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Untracking MOVING operation: %s", connID, seqID, key.String())
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.UntrackingMovingOperation(connID, seqID))
}
// Decrement active operation count only if operation existed
hm.activeOperationCount.Add(-1)
} else {
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Operation not found for untracking: %s", connID, seqID, key.String())
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.OperationNotTracked(connID, seqID))
}
}
}
@@ -188,7 +190,7 @@ func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64)
// GetActiveMovingOperations returns active operations with composite keys.
// WARNING: This method creates a new map and copies all operations on every call.
// Use sparingly, especially in hot paths or high-frequency logging.
func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
func (hm *Manager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
result := make(map[MovingOperationKey]*MovingOperation)
// Iterate over sync.Map to build result
@@ -211,18 +213,18 @@ func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*Mo
// IsHandoffInProgress returns true if any handoff is in progress.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) IsHandoffInProgress() bool {
func (hm *Manager) IsHandoffInProgress() bool {
return hm.activeOperationCount.Load() > 0
}
// GetActiveOperationCount returns the number of active operations.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) GetActiveOperationCount() int64 {
func (hm *Manager) GetActiveOperationCount() int64 {
return hm.activeOperationCount.Load()
}
// Close closes the hitless manager.
func (hm *HitlessManager) Close() error {
// Close closes the manager.
func (hm *Manager) Close() error {
// Use atomic operation for thread-safe close check
if !hm.closed.CompareAndSwap(false, true) {
return nil // Already closed
@@ -259,7 +261,7 @@ func (hm *HitlessManager) Close() error {
}
// GetState returns current state using atomic counter for lock-free operation.
func (hm *HitlessManager) GetState() State {
func (hm *Manager) GetState() State {
if hm.activeOperationCount.Load() > 0 {
return StateMoving
}
@@ -267,7 +269,7 @@ func (hm *HitlessManager) GetState() State {
}
// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
func (hm *Manager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
@@ -285,7 +287,7 @@ func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationCtx p
}
// processPostHooks calls all post-hooks with the processing result.
func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
func (hm *Manager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
@@ -295,7 +297,7 @@ func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationCtx
}
// createPoolHook creates a pool hook with this manager already set.
func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
func (hm *Manager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
if hm.poolHooksRef != nil {
return hm.poolHooksRef
}
@@ -311,7 +313,7 @@ func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string
return hm.poolHooksRef
}
func (hm *HitlessManager) AddNotificationHook(notificationHook NotificationHook) {
func (hm *Manager) AddNotificationHook(notificationHook NotificationHook) {
hm.hooksMu.Lock()
defer hm.hooksMu.Unlock()
hm.hooks = append(hm.hooks, notificationHook)

View File

@@ -1,4 +1,4 @@
package hitless
package maintnotifications
import (
"context"
@@ -74,14 +74,14 @@ func (mo *MockOptions) NewDialer() func(context.Context) (net.Conn, error) {
}
}
func TestHitlessManagerRefactoring(t *testing.T) {
func TestManagerRefactoring(t *testing.T) {
t.Run("AtomicStateTracking", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
manager, err := NewManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
t.Fatalf("Failed to create maintnotifications manager: %v", err)
}
defer manager.Close()
@@ -140,9 +140,9 @@ func TestHitlessManagerRefactoring(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
manager, err := NewManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
t.Fatalf("Failed to create maintnotifications manager: %v", err)
}
defer manager.Close()
@@ -182,9 +182,9 @@ func TestHitlessManagerRefactoring(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
manager, err := NewManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
t.Fatalf("Failed to create maintnotifications manager: %v", err)
}
defer manager.Close()
@@ -219,23 +219,23 @@ func TestHitlessManagerRefactoring(t *testing.T) {
NotificationFailedOver,
}
if len(hitlessNotificationTypes) != len(expectedTypes) {
t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(hitlessNotificationTypes))
if len(maintenanceNotificationTypes) != len(expectedTypes) {
t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(maintenanceNotificationTypes))
}
// Test that all expected types are present
typeMap := make(map[string]bool)
for _, t := range hitlessNotificationTypes {
for _, t := range maintenanceNotificationTypes {
typeMap[t] = true
}
for _, expected := range expectedTypes {
if !typeMap[expected] {
t.Errorf("Expected notification type %s not found in hitlessNotificationTypes", expected)
t.Errorf("Expected notification type %s not found in maintenanceNotificationTypes", expected)
}
}
// Test that hitlessNotificationTypes contains all expected constants
// Test that maintenanceNotificationTypes contains all expected constants
expectedConstants := []string{
NotificationMoving,
NotificationMigrating,
@@ -246,14 +246,14 @@ func TestHitlessManagerRefactoring(t *testing.T) {
for _, expected := range expectedConstants {
found := false
for _, actual := range hitlessNotificationTypes {
for _, actual := range maintenanceNotificationTypes {
if actual == expected {
found = true
break
}
}
if !found {
t.Errorf("Expected constant %s not found in hitlessNotificationTypes", expected)
t.Errorf("Expected constant %s not found in maintenanceNotificationTypes", expected)
}
}
})

View File

@@ -1,4 +1,4 @@
package hitless
package maintnotifications
import (
"context"
@@ -7,11 +7,12 @@ import (
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
)
// HitlessManagerInterface defines the interface for completing handoff operations
type HitlessManagerInterface interface {
// OperationsManagerInterface defines the interface for completing handoff operations
type OperationsManagerInterface interface {
TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
UntrackOperationWithConnID(seqID int64, connID uint64)
}
@@ -26,7 +27,7 @@ type HandoffRequest struct {
}
// PoolHook implements pool.PoolHook for Redis-specific connection handling
// with hitless upgrade support.
// with maintenance notifications support.
type PoolHook struct {
// Base dialer for creating connections to new endpoints during handoffs
// args are network and address
@@ -38,23 +39,23 @@ type PoolHook struct {
// Worker manager for background handoff processing
workerManager *handoffWorkerManager
// Configuration for the hitless upgrade
// Configuration for the maintenance notifications
config *Config
// Hitless manager for operation completion tracking
hitlessManager HitlessManagerInterface
// Operations manager interface for operation completion tracking
operationsManager OperationsManagerInterface
// Pool interface for removing connections on handoff failure
pool pool.Pooler
}
// NewPoolHook creates a new pool hook
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface) *PoolHook {
return NewPoolHookWithPoolSize(baseDialer, network, config, hitlessManager, 0)
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface) *PoolHook {
return NewPoolHookWithPoolSize(baseDialer, network, config, operationsManager, 0)
}
// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface, poolSize int) *PoolHook {
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface, poolSize int) *PoolHook {
// Apply defaults if config is nil or has zero values
if config == nil {
config = config.ApplyDefaultsWithPoolSize(poolSize)
@@ -65,8 +66,7 @@ func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (n
baseDialer: baseDialer,
network: network,
config: config,
// Hitless manager for operation completion tracking
hitlessManager: hitlessManager,
operationsManager: operationsManager,
}
// Create worker manager
@@ -150,7 +150,7 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool
if err := ph.workerManager.queueHandoff(conn); err != nil {
// Failed to queue handoff, remove the connection
internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err)
internal.Logger.Printf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err))
// Don't pool, remove connection, no error to caller
return false, true, nil
}
@@ -170,6 +170,7 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool
// Other error - remove the connection
return false, true, nil
}
internal.Logger.Printf(ctx, logs.MarkedForHandoff(conn.GetID()))
return true, false, nil
}

View File

@@ -1,4 +1,4 @@
package hitless
package maintnotifications
import (
"context"
@@ -113,12 +113,11 @@ func TestConnectionHook(t *testing.T) {
t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) {
config := &Config{
Mode: MaintNotificationsAuto,
Mode: ModeAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 1, // Use only 1 worker to ensure synchronization
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
@@ -263,13 +262,12 @@ func TestConnectionHook(t *testing.T) {
}
config := &Config{
Mode: MaintNotificationsAuto,
Mode: ModeAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 2, // Reduced retries for faster test
HandoffTimeout: 500 * time.Millisecond, // Shorter timeout for faster test
LogLevel: 2,
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
@@ -366,12 +364,11 @@ func TestConnectionHook(t *testing.T) {
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
config := &Config{
Mode: MaintNotificationsAuto,
Mode: ModeAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3, // Explicit queue size to avoid 0-size queue
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
@@ -443,7 +440,6 @@ func TestConnectionHook(t *testing.T) {
MaxWorkers: 3,
HandoffQueueSize: 2,
MaxHandoffRetries: 3, // Small queue to trigger optimizations
LogLevel: 3, // Debug level to see optimization logs
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -500,7 +496,6 @@ func TestConnectionHook(t *testing.T) {
MaxWorkers: 15, // Set to >= 10 to test explicit value preservation
HandoffQueueSize: 1,
MaxHandoffRetries: 3, // Very small queue to force scaling
LogLevel: 2, // Info level to see scaling logs
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
@@ -528,7 +523,6 @@ func TestConnectionHook(t *testing.T) {
MaxHandoffRetries: 3, // Allow retries for successful handoff
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing
RelaxedTimeout: 5 * time.Second,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
@@ -607,7 +601,6 @@ func TestConnectionHook(t *testing.T) {
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
@@ -694,7 +687,6 @@ func TestConnectionHook(t *testing.T) {
MaxWorkers: 3,
HandoffQueueSize: 50,
MaxHandoffRetries: 3, // Explicit static queue size
LogLevel: 2,
}
processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100
@@ -755,7 +747,6 @@ func TestConnectionHook(t *testing.T) {
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
@@ -906,7 +897,6 @@ func TestConnectionHook(t *testing.T) {
HandoffQueueSize: 10,
HandoffTimeout: customTimeout, // Custom timeout
MaxHandoffRetries: 1, // Single retry to speed up test
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
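
With the per-Config LogLevel field removed above (verbosity is now controlled globally through the new SetLogLevel in redis.go further down), a test-style configuration reduces to the handoff-related knobs; the values below are illustrative only.

package main

import (
	"fmt"
	"time"

	"github.com/redis/go-redis/v9/maintnotifications"
)

func main() {
	cfg := &maintnotifications.Config{
		Mode:                       maintnotifications.ModeAuto,
		EndpointType:               maintnotifications.EndpointTypeAuto,
		MaxWorkers:                 2,
		HandoffQueueSize:           10,
		MaxHandoffRetries:          3,
		HandoffTimeout:             500 * time.Millisecond,
		RelaxedTimeout:             5 * time.Second,
		PostHandoffRelaxedDuration: 100 * time.Millisecond,
	}
	fmt.Printf("%+v\n", cfg)
}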


@@ -1,30 +1,33 @@
package hitless
package maintnotifications
import (
"context"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// NotificationHandler handles push notifications for the simplified manager.
type NotificationHandler struct {
manager *HitlessManager
manager *Manager
operationsManager OperationsManagerInterface
}
// HandlePushNotification processes push notifications with hook support.
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) == 0 {
internal.Logger.Printf(ctx, "hitless: invalid notification format: %v", notification)
internal.Logger.Printf(ctx, logs.InvalidNotificationFormat(notification))
return ErrInvalidNotification
}
notificationType, ok := notification[0].(string)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid notification type format: %v", notification[0])
internal.Logger.Printf(ctx, logs.InvalidNotificationTypeFormat(notification[0]))
return ErrInvalidNotification
}
@@ -61,19 +64,19 @@ func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, hand
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 3 {
internal.Logger.Printf(ctx, "hitless: invalid MOVING notification: %v", notification)
internal.Logger.Printf(ctx, logs.InvalidNotification("MOVING", notification))
return ErrInvalidNotification
}
seqID, ok := notification[1].(int64)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid seqID in MOVING notification: %v", notification[1])
internal.Logger.Printf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1]))
return ErrInvalidNotification
}
// Extract timeS
timeS, ok := notification[2].(int64)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid timeS in MOVING notification: %v", notification[2])
internal.Logger.Printf(ctx, logs.InvalidTimeSInMovingNotification(notification[2]))
return ErrInvalidNotification
}
@@ -82,15 +85,21 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus
// Extract new endpoint
newEndpoint, ok = notification[3].(string)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid newEndpoint in MOVING notification: %v", notification[3])
stringified := fmt.Sprintf("%v", notification[3])
// this could be <nil> which is valid
if notification[3] == nil || stringified == internal.RedisNull {
newEndpoint = ""
} else {
internal.Logger.Printf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3]))
return ErrInvalidNotification
}
}
}
// Get the connection that received this notification
conn := handlerCtx.Conn
if conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MOVING notification")
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MOVING"))
return ErrInvalidNotification
}
@@ -99,7 +108,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus
if pc, ok := conn.(*pool.Conn); ok {
poolConn = pc
} else {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MOVING notification - %T %#v", conn, handlerCtx)
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx))
return ErrInvalidNotification
}
@@ -115,9 +124,8 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
if newEndpoint == "" || newEndpoint == internal.RedisNull {
if snh.manager.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(ctx, "hitless: conn[%d] scheduling handoff to current endpoint in %v seconds",
poolConn.GetID(), timeS/2)
if internal.LogLevel.DebugOrAbove() {
internal.Logger.Printf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2))
}
// same as current endpoint
newEndpoint = snh.manager.options.GetAddr()
@@ -131,7 +139,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus
}
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
// Log error but don't fail the goroutine - use background context since original may be cancelled
internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err))
}
})
return nil
@@ -142,18 +150,18 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err))
// Connection is already marked for handoff, which is acceptable
// This can happen if multiple MOVING notifications are received for the same connection
return nil
}
// Optionally track in hitless manager for monitoring/debugging
if snh.manager != nil {
// Optionally track in the operations manager for monitoring/debugging
if snh.operationsManager != nil {
connID := conn.GetID()
// Track the operation (ignore errors since this is optional)
_ = snh.manager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
_ = snh.operationsManager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
} else {
return fmt.Errorf("hitless: manager not initialized")
return errors.New(logs.ManagerNotInitialized())
}
return nil
}
@@ -163,26 +171,24 @@ func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx
// MIGRATING notifications indicate that a connection is about to be migrated
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid MIGRATING notification: %v", notification)
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATING", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATING notification")
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATING"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATING notification")
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Debug level
internal.Logger.Printf(ctx, "hitless: conn[%d] applying relaxed timeout (%v) for MIGRATING notification",
conn.GetID(),
snh.manager.config.RelaxedTimeout)
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout))
}
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
@@ -193,25 +199,25 @@ func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx p
// MIGRATED notifications indicate that a connection migration has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid MIGRATED notification: %v", notification)
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATED", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATED notification")
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATED"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATED notification")
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Debug level
if internal.LogLevel.InfoOrAbove() {
connID := conn.GetID()
internal.Logger.Printf(ctx, "hitless: conn[%d] clearing relaxed timeout for MIGRATED notification", connID)
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
}
conn.ClearRelaxedTimeout()
return nil
@@ -222,25 +228,25 @@ func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCt
// FAILING_OVER notifications indicate that a connection is about to failover
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid FAILING_OVER notification: %v", notification)
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILING_OVER", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILING_OVER notification")
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILING_OVER notification")
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Debug level
if internal.LogLevel.InfoOrAbove() {
connID := conn.GetID()
internal.Logger.Printf(ctx, "hitless: conn[%d] applying relaxed timeout (%v) for FAILING_OVER notification", connID, snh.manager.config.RelaxedTimeout)
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout))
}
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
@@ -251,25 +257,25 @@ func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx
// FAILED_OVER notifications indicate that a connection failover has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid FAILED_OVER notification: %v", notification)
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILED_OVER", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILED_OVER notification")
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILED_OVER notification")
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Debug level
if internal.LogLevel.InfoOrAbove() {
connID := conn.GetID()
internal.Logger.Printf(ctx, "hitless: conn[%d] clearing relaxed timeout for FAILED_OVER notification", connID)
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
}
conn.ClearRelaxedTimeout()
return nil
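
For orientation, the payload shapes the handlers above expect look roughly like this; the literal values are made up, and the nil endpoint variant is the one that makes handleMoving fall back to the current address after timeS/2 seconds.

package main

import "fmt"

func main() {
	// ["MOVING", seqID, timeS, endpoint]: per-connection handoff to a new endpoint.
	moving := []interface{}{"MOVING", int64(42), int64(30), "10.0.0.5:6379"}

	// A nil (or "<nil>") endpoint is valid: the connection is handed off back to
	// the current endpoint, scheduled after half of the advertised timeS window.
	movingNil := []interface{}{"MOVING", int64(43), int64(30), nil}

	// MIGRATING/FAILING_OVER relax the receiving connection's timeouts;
	// MIGRATED/FAILED_OVER clear them again.
	migrating := []interface{}{"MIGRATING", int64(44)}

	fmt.Println(moving, movingNil, migrating)
}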


@@ -1,6 +1,6 @@
package hitless
package maintnotifications
// State represents the current state of a hitless upgrade operation.
// State represents the current state of a maintenance operation
type State int
const (


@@ -14,10 +14,10 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/push"
)
@@ -258,18 +258,14 @@ type Options struct {
// Default is 15 seconds.
FailingTimeoutSeconds int
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// MaintNotificationsConfig provides custom configuration for maintnotifications.
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
HitlessUpgradeConfig *HitlessUpgradeConfig
// If nil, maintnotifications are in "auto" mode and will be enabled if the server supports it.
MaintNotificationsConfig *maintnotifications.Config
}
// HitlessUpgradeConfig provides configuration options for hitless upgrades.
// This is an alias to hitless.Config for convenience.
type HitlessUpgradeConfig = hitless.Config
func (opt *Options) init() {
if opt.Addr == "" {
opt.Addr = "localhost:6379"
@@ -351,24 +347,24 @@ func (opt *Options) init() {
opt.MaxRetryBackoff = 512 * time.Millisecond
}
opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns)
opt.MaintNotificationsConfig = opt.MaintNotificationsConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns)
// auto-detect endpoint type if not specified
endpointType := opt.HitlessUpgradeConfig.EndpointType
if endpointType == "" || endpointType == hitless.EndpointTypeAuto {
endpointType := opt.MaintNotificationsConfig.EndpointType
if endpointType == "" || endpointType == maintnotifications.EndpointTypeAuto {
// Auto-detect endpoint type if not specified
endpointType = hitless.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
endpointType = maintnotifications.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
}
opt.HitlessUpgradeConfig.EndpointType = endpointType
opt.MaintNotificationsConfig.EndpointType = endpointType
}
func (opt *Options) clone() *Options {
clone := *opt
// Deep clone HitlessUpgradeConfig to avoid sharing between clients
if opt.HitlessUpgradeConfig != nil {
configClone := *opt.HitlessUpgradeConfig
clone.HitlessUpgradeConfig = &configClone
// Deep clone MaintNotificationsConfig to avoid sharing between clients
if opt.MaintNotificationsConfig != nil {
configClone := *opt.MaintNotificationsConfig
clone.MaintNotificationsConfig = &configClone
}
return &clone
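
Callers opt in through the renamed field; a minimal sketch, with a placeholder address and RESP3 enabled as the comments above require for push notifications.

package main

import (
	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/maintnotifications"
)

func main() {
	rdb := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379",
		Protocol: 3, // RESP3 is required for push notifications
		MaintNotificationsConfig: &maintnotifications.Config{
			Mode: maintnotifications.ModeAuto, // "auto" is also the documented default when the config is nil
		},
	})
	defer rdb.Close()
}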


@@ -20,6 +20,7 @@ import (
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/rand"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/push"
)
@@ -39,7 +40,7 @@ type ClusterOptions struct {
ClientName string
// NewClient creates a cluster node client with provided name and options.
// If NewClient is set by the user, the user is responsible for handling hitless upgrades and push notifications.
// If NewClient is set by the user, the user is responsible for handling maintnotifications upgrades and push notifications.
NewClient func(opt *Options) *Client
// The maximum number of retries before giving up. Command is retried
@@ -136,13 +137,13 @@ type ClusterOptions struct {
// Default is 15 seconds.
FailingTimeoutSeconds int
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// MaintNotificationsConfig provides custom configuration for maintnotifications upgrades.
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
// The ClusterClient does not directly work with hitless, it is up to the clients in the Nodes map to work with hitless.
HitlessUpgradeConfig *HitlessUpgradeConfig
// If nil, maintnotifications upgrades are in "auto" mode and will be enabled if the server supports it.
// The ClusterClient does not directly work with maintnotifications, it is up to the clients in the Nodes map to work with maintnotifications.
MaintNotificationsConfig *maintnotifications.Config
}
func (opt *ClusterOptions) init() {
@@ -333,11 +334,11 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
}
func (opt *ClusterOptions) clientOptions() *Options {
// Clone HitlessUpgradeConfig to avoid sharing between cluster node clients
var hitlessConfig *HitlessUpgradeConfig
if opt.HitlessUpgradeConfig != nil {
configClone := *opt.HitlessUpgradeConfig
hitlessConfig = &configClone
// Clone MaintNotificationsConfig to avoid sharing between cluster node clients
var maintNotificationsConfig *maintnotifications.Config
if opt.MaintNotificationsConfig != nil {
configClone := *opt.MaintNotificationsConfig
maintNotificationsConfig = &configClone
}
return &Options{
@@ -383,7 +384,7 @@ func (opt *ClusterOptions) clientOptions() *Options {
// situations in the options below will prevent that from happening.
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
HitlessUpgradeConfig: hitlessConfig,
MaintNotificationsConfig: maintNotificationsConfig,
PushNotificationProcessor: opt.PushNotificationProcessor,
}
}
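
Each node client receives its own copy of the configuration via clientOptions above, so nodes never share a Config instance; a sketch at the cluster level, with placeholder addresses.

package main

import (
	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/maintnotifications"
)

func main() {
	cc := redis.NewClusterClient(&redis.ClusterOptions{
		Addrs:    []string{"127.0.0.1:7000", "127.0.0.1:7001"},
		Protocol: 3, // RESP3 is required for push notifications
		MaintNotificationsConfig: &maintnotifications.Config{
			Mode: maintnotifications.ModeAuto,
		},
	})
	defer cc.Close()
}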
@@ -1872,7 +1873,7 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
return err
}
// hitless won't work here for now
// maintenance notifications won't work here for now
func (c *ClusterClient) pubSub() *PubSub {
var node *clusterNode
pubsub := &PubSub{


@@ -43,7 +43,7 @@ type PubSub struct {
// Push notification processor for handling generic push notifications
pushProcessor push.NotificationProcessor
// Cleanup callback for hitless upgrade tracking
// Cleanup callback for maintenanceNotifications upgrade tracking
onClose func()
}
@@ -77,10 +77,10 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er
}
if c.opt.Addr == "" {
// TODO(hitless):
// TODO(maintenanceNotifications):
// this is probably cluster client
// c.newConn will ignore the addr argument
// will be changed when we have hitless upgrades for cluster clients
// will be changed when we have maintenanceNotifications upgrades for cluster clients
c.opt.Addr = internal.RedisNull
}

redis.go

@@ -10,11 +10,11 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/hscan"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/push"
)
@@ -30,6 +30,11 @@ func SetLogger(logger internal.Logging) {
internal.Logger = logger
}
// SetLogLevel sets the log level for the library.
func SetLogLevel(logLevel internal.LogLevelT) {
internal.LogLevel = logLevel
}
//------------------------------------------------------------------------------
type Hook interface {
@@ -216,22 +221,22 @@ type baseClient struct {
// Push notification processing
pushProcessor push.NotificationProcessor
// Hitless upgrade manager
hitlessManager *hitless.HitlessManager
hitlessManagerLock sync.RWMutex
// Maintenance notifications manager
maintNotificationsManager *maintnotifications.Manager
maintNotificationsManagerLock sync.RWMutex
}
func (c *baseClient) clone() *baseClient {
c.hitlessManagerLock.RLock()
hitlessManager := c.hitlessManager
c.hitlessManagerLock.RUnlock()
c.maintNotificationsManagerLock.RLock()
maintNotificationsManager := c.maintNotificationsManager
c.maintNotificationsManagerLock.RUnlock()
clone := &baseClient{
opt: c.opt,
connPool: c.connPool,
onClose: c.onClose,
pushProcessor: c.pushProcessor,
hitlessManager: hitlessManager,
maintNotificationsManager: maintNotificationsManager,
}
return clone
}
@@ -430,39 +435,39 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return fmt.Errorf("failed to initialize connection options: %w", err)
}
// Enable maintenance notifications if hitless upgrades are configured
// Enable maintenance notifications if they are configured
c.optLock.RLock()
hitlessEnabled := c.opt.HitlessUpgradeConfig != nil && c.opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled
maintNotifEnabled := c.opt.MaintNotificationsConfig != nil && c.opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled
protocol := c.opt.Protocol
endpointType := c.opt.HitlessUpgradeConfig.EndpointType
endpointType := c.opt.MaintNotificationsConfig.EndpointType
c.optLock.RUnlock()
var hitlessHandshakeErr error
if hitlessEnabled && protocol == 3 {
hitlessHandshakeErr = conn.ClientMaintNotifications(
var maintNotifHandshakeErr error
if maintNotifEnabled && protocol == 3 {
maintNotifHandshakeErr = conn.ClientMaintNotifications(
ctx,
true,
endpointType.String(),
).Err()
if hitlessHandshakeErr != nil {
if !isRedisError(hitlessHandshakeErr) {
if maintNotifHandshakeErr != nil {
if !isRedisError(maintNotifHandshakeErr) {
// if not redis error, fail the connection
return hitlessHandshakeErr
return maintNotifHandshakeErr
}
c.optLock.Lock()
// handshake failed - check and modify config atomically
switch c.opt.HitlessUpgradeConfig.Mode {
case hitless.MaintNotificationsEnabled:
switch c.opt.MaintNotificationsConfig.Mode {
case maintnotifications.ModeEnabled:
// enabled mode, fail the connection
c.optLock.Unlock()
return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr)
return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr)
default: // will handle auto and any other
internal.Logger.Printf(ctx, "hitless: auto mode fallback: hitless upgrades disabled due to handshake error: %v", hitlessHandshakeErr)
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsDisabled
internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr)
c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled
c.optLock.Unlock()
// auto mode, disable hitless upgrades and continue
if err := c.disableHitlessUpgrades(); err != nil {
// auto mode, disable maintnotifications and continue
if err := c.disableMaintNotificationsUpgrades(); err != nil {
// Log error but continue - auto mode should be resilient
internal.Logger.Printf(ctx, "hitless: failed to disable hitless upgrades in auto mode: %v", err)
internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", err)
}
}
} else {
@@ -470,7 +475,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
// to make sure that the handshake will be executed on other connections as well if it was successfully
// executed on this connection, we will force the handshake to be executed on all connections
c.optLock.Lock()
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsEnabled
c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeEnabled
c.optLock.Unlock()
}
}
@@ -657,39 +662,39 @@ func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) erro
}
}
// enableHitlessUpgrades initializes the hitless upgrade manager and pool hook.
// enableMaintNotificationsUpgrades initializes the maintnotifications upgrade manager and pool hook.
// This function is called during client initialization.
// will register push notification handlers for all hitless upgrade events.
// will register push notification handlers for all maintenance upgrade events.
// will start background workers for handoff processing in the pool hook.
func (c *baseClient) enableHitlessUpgrades() error {
func (c *baseClient) enableMaintNotificationsUpgrades() error {
// Create client adapter
clientAdapterInstance := newClientAdapter(c)
// Create hitless manager directly
manager, err := hitless.NewHitlessManager(clientAdapterInstance, c.connPool, c.opt.HitlessUpgradeConfig)
// Create maintnotifications manager directly
manager, err := maintnotifications.NewManager(clientAdapterInstance, c.connPool, c.opt.MaintNotificationsConfig)
if err != nil {
return err
}
// Set the manager reference and initialize pool hook
c.hitlessManagerLock.Lock()
c.hitlessManager = manager
c.hitlessManagerLock.Unlock()
c.maintNotificationsManagerLock.Lock()
c.maintNotificationsManager = manager
c.maintNotificationsManagerLock.Unlock()
// Initialize pool hook (safe to call without lock since manager is now set)
manager.InitPoolHook(c.dialHook)
return nil
}
func (c *baseClient) disableHitlessUpgrades() error {
c.hitlessManagerLock.Lock()
defer c.hitlessManagerLock.Unlock()
func (c *baseClient) disableMaintNotificationsUpgrades() error {
c.maintNotificationsManagerLock.Lock()
defer c.maintNotificationsManagerLock.Unlock()
// Close the hitless manager
if c.hitlessManager != nil {
// Close the maintnotifications manager
if c.maintNotificationsManager != nil {
// Closing the manager will also shutdown the pool hook
// and remove it from the pool
c.hitlessManager.Close()
c.hitlessManager = nil
c.maintNotificationsManager.Close()
c.maintNotificationsManager = nil
}
return nil
}
@@ -701,8 +706,8 @@ func (c *baseClient) disableHitlessUpgrades() error {
func (c *baseClient) Close() error {
var firstErr error
// Close hitless manager first
if err := c.disableHitlessUpgrades(); err != nil {
// Close maintnotifications manager first
if err := c.disableMaintNotificationsUpgrades(); err != nil {
firstErr = err
}
@@ -947,23 +952,23 @@ func NewClient(opt *Options) *Client {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}
// Initialize hitless upgrades first if enabled and protocol is RESP3
if opt.HitlessUpgradeConfig != nil && opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled && opt.Protocol == 3 {
err := c.enableHitlessUpgrades()
// Initialize maintnotifications first if enabled and protocol is RESP3
if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 {
err := c.enableMaintNotificationsUpgrades()
if err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
if opt.HitlessUpgradeConfig.Mode == hitless.MaintNotificationsEnabled {
internal.Logger.Printf(context.Background(), "failed to initialize maintnotifications: %v", err)
if opt.MaintNotificationsConfig.Mode == maintnotifications.ModeEnabled {
/*
Design decision: panic here to fail fast if hitless upgrades cannot be enabled when explicitly requested.
Design decision: panic here to fail fast if maintnotifications cannot be enabled when explicitly requested.
We choose to panic instead of returning an error to avoid breaking the existing client API, which does not expect
an error from NewClient. This ensures that misconfiguration or critical initialization failures are surfaced
immediately, rather than allowing the client to continue in a partially initialized or inconsistent state.
Clients relying on hitless upgrades should be aware that initialization errors will cause a panic, and should
Clients relying on maintnotifications should be aware that initialization errors will cause a panic, and should
handle this accordingly (e.g., via recover or by validating configuration before calling NewClient).
This approach is only used when HitlessUpgradeConfig.Mode is MaintNotificationsEnabled, indicating that hitless
This approach is only used when MaintNotificationsConfig.Mode is ModeEnabled, indicating that maintenance notifications
upgrades are required for correct operation. In other modes, initialization failures are logged but do not panic.
*/
panic(fmt.Errorf("failed to enable hitless upgrades: %w", err))
panic(fmt.Errorf("failed to enable maintnotifications: %w", err))
}
}
}
@@ -1003,12 +1008,12 @@ func (c *Client) Options() *Options {
return c.opt
}
// GetHitlessManager returns the hitless manager instance for monitoring and control.
// Returns nil if hitless upgrades are not enabled.
func (c *Client) GetHitlessManager() *hitless.HitlessManager {
c.hitlessManagerLock.RLock()
defer c.hitlessManagerLock.RUnlock()
return c.hitlessManager
// GetMaintNotificationsManager returns the maintnotifications manager instance for monitoring and control.
// Returns nil if maintnotifications are not enabled.
func (c *Client) GetMaintNotificationsManager() *maintnotifications.Manager {
c.maintNotificationsManagerLock.RLock()
defer c.maintNotificationsManagerLock.RUnlock()
return c.maintNotificationsManager
}
// initializePushProcessor initializes the push notification processor for any client type.
@@ -1260,7 +1265,7 @@ func (c *baseClient) processPushNotifications(ctx context.Context, cn *pool.Conn
}
// Use WithReader to access the reader and process push notifications
// This is critical for hitless upgrades to work properly
// This is critical for maintnotifications to work properly
// NOTE: almost no timeouts are set for this read, so it should not block
// longer than necessary, 10us should be plenty of time to read if there are any push notifications
// on the socket.
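
Putting the redis.go changes together: with ModeEnabled an initialization failure is fatal by design (NewClient panics), while ModeAuto falls back to disabled, and the renamed accessor exposes the manager for monitoring. A sketch with a placeholder address:

package main

import (
	"fmt"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/maintnotifications"
)

func main() {
	rdb := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379",
		Protocol: 3,
		MaintNotificationsConfig: &maintnotifications.Config{
			// ModeEnabled: handshake or manager initialization failures are fatal.
			// ModeAuto: failures fall back to ModeDisabled and are only logged.
			Mode: maintnotifications.ModeEnabled,
		},
	})
	defer rdb.Close()

	if mgr := rdb.GetMaintNotificationsManager(); mgr != nil {
		fmt.Printf("maintenance notifications manager active: %T\n", mgr)
	}
}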


@@ -140,13 +140,14 @@ type FailoverOptions struct {
UnstableResp3 bool
// Hitless is not supported for FailoverClients at the moment
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// MaintNotificationsConfig is not supported for FailoverClients at the moment
// MaintNotificationsConfig provides custom configuration for maintnotifications upgrades.
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
// upgrade notifications gracefully and manage connection/pool state transitions
// seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are disabled.
//HitlessUpgradeConfig *HitlessUpgradeConfig
// If nil, maintnotifications upgrades are disabled.
// (however, if the config is nil, it defaults to "auto" - enabled if the server supports it)
//MaintNotificationsConfig *maintnotifications.Config
}
func (opt *FailoverOptions) clientOptions() *Options {


@@ -7,6 +7,7 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/maintnotifications"
)
// UniversalOptions information is required by UniversalClient to establish
@@ -123,8 +124,8 @@ type UniversalOptions struct {
// IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint).
IsClusterMode bool
// HitlessUpgradeConfig provides configuration for hitless upgrades.
HitlessUpgradeConfig *HitlessUpgradeConfig
// MaintNotificationsConfig provides configuration for maintnotifications upgrades.
MaintNotificationsConfig *maintnotifications.Config
}
// Cluster returns cluster options created from the universal options.
@@ -180,7 +181,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions {
IdentitySuffix: o.IdentitySuffix,
FailingTimeoutSeconds: o.FailingTimeoutSeconds,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
MaintNotificationsConfig: o.MaintNotificationsConfig,
}
}
@@ -241,7 +242,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions {
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
// Note: HitlessUpgradeConfig not supported for FailoverOptions
// Note: MaintNotificationsConfig not supported for FailoverOptions
}
}
@@ -293,7 +294,7 @@ func (o *UniversalOptions) Simple() *Options {
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
MaintNotificationsConfig: o.MaintNotificationsConfig,
}
}
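
The universal options forward the same configuration to the cluster and simple paths, while the failover path drops it as noted above; a sketch with a placeholder address:

package main

import (
	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/maintnotifications"
)

func main() {
	uc := redis.NewUniversalClient(&redis.UniversalOptions{
		Addrs:    []string{"localhost:6379"},
		Protocol: 3,
		// Propagated by Cluster() and Simple(); Failover() does not support
		// maintenance notifications yet.
		MaintNotificationsConfig: &maintnotifications.Config{
			Mode: maintnotifications.ModeAuto,
		},
	})
	defer uc.Close()
}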