From cb3af0800e5e66bba751d24acd80b432cf07b4cf Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> Date: Wed, 3 Sep 2025 14:49:16 +0300 Subject: [PATCH] [CAE-1072] Hitless Upgrades (#3447) * feat(hitless): Introduce handlers for hitless upgrades This commit includes all the work on hitless upgrades with the addition of: - Pubsub Pool - Examples - Refactor of push - Refactor of pool (using atomics for most things) - Introducing of hooks in pool --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .gitignore | 3 + adapters.go | 111 +++ async_handoff_integration_test.go | 353 ++++++++ commands.go | 18 + example/pubsub/go.mod | 12 + example/pubsub/go.sum | 6 + example/pubsub/main.go | 171 ++++ example_instrumentation_test.go | 6 + hitless/README.md | 98 +++ hitless/circuit_breaker.go | 360 ++++++++ hitless/circuit_breaker_test.go | 356 ++++++++ hitless/config.go | 472 ++++++++++ hitless/config_test.go | 490 +++++++++++ hitless/errors.go | 105 +++ hitless/example_hooks.go | 100 +++ hitless/handoff_worker.go | 455 ++++++++++ hitless/hitless_manager.go | 318 +++++++ hitless/hitless_manager_test.go | 260 ++++++ hitless/hooks.go | 47 + hitless/pool_hook.go | 179 ++++ hitless/pool_hook_test.go | 964 +++++++++++++++++++++ hitless/push_notification_handler.go | 276 ++++++ hitless/state.go | 24 + internal/interfaces/interfaces.go | 54 ++ internal/log.go | 24 +- internal/pool/bench_test.go | 7 +- internal/pool/buffer_size_test.go | 8 +- internal/pool/conn.go | 468 +++++++++- internal/pool/conn_relaxed_timeout_test.go | 92 ++ internal/pool/export_test.go | 2 +- internal/pool/hooks.go | 114 +++ internal/pool/hooks_test.go | 213 +++++ internal/pool/pool.go | 475 +++++++--- internal/pool/pool_single.go | 8 +- internal/pool/pool_sticky.go | 4 + internal/pool/pool_test.go | 112 ++- internal/pool/pubsub.go | 78 ++ internal/redis.go | 3 + internal/util/convert.go | 11 + internal/util/math.go | 17 + logging/logging.go | 121 +++ logging/logging_test.go | 59 ++ main_test.go | 2 + options.go | 150 +++- osscluster.go | 48 +- pool_pubsub_bench_test.go | 375 ++++++++ pubsub.go | 54 +- pubsub_test.go | 3 + push/handler_context.go | 10 +- push/processor_unit_test.go | 315 +++++++ push_notifications.go | 18 - redis.go | 234 ++++- redis_test.go | 1 - sentinel.go | 68 +- tx.go | 2 +- universal.go | 14 +- 56 files changed, 8062 insertions(+), 286 deletions(-) create mode 100644 adapters.go create mode 100644 async_handoff_integration_test.go create mode 100644 example/pubsub/go.mod create mode 100644 example/pubsub/go.sum create mode 100644 example/pubsub/main.go create mode 100644 hitless/README.md create mode 100644 hitless/circuit_breaker.go create mode 100644 hitless/circuit_breaker_test.go create mode 100644 hitless/config.go create mode 100644 hitless/config_test.go create mode 100644 hitless/errors.go create mode 100644 hitless/example_hooks.go create mode 100644 hitless/handoff_worker.go create mode 100644 hitless/hitless_manager.go create mode 100644 hitless/hitless_manager_test.go create mode 100644 hitless/hooks.go create mode 100644 hitless/pool_hook.go create mode 100644 hitless/pool_hook_test.go create mode 100644 hitless/push_notification_handler.go create mode 100644 hitless/state.go create mode 100644 internal/interfaces/interfaces.go create mode 100644 internal/pool/conn_relaxed_timeout_test.go create mode 100644 internal/pool/hooks.go create mode 100644 internal/pool/hooks_test.go create mode 100644 internal/pool/pubsub.go create mode 
100644 internal/redis.go create mode 100644 internal/util/math.go create mode 100644 logging/logging.go create mode 100644 logging/logging_test.go create mode 100644 pool_pubsub_bench_test.go create mode 100644 push/processor_unit_test.go diff --git a/.gitignore b/.gitignore index 0d99709e..5fe0716e 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ coverage.txt **/coverage.txt .vscode tmp/* + +# Hitless upgrade documentation (temporary) +hitless/docs/ diff --git a/adapters.go b/adapters.go new file mode 100644 index 00000000..4146153b --- /dev/null +++ b/adapters.go @@ -0,0 +1,111 @@ +package redis + +import ( + "context" + "errors" + "net" + "time" + + "github.com/redis/go-redis/v9/internal/interfaces" + "github.com/redis/go-redis/v9/push" +) + +// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand. +var ErrInvalidCommand = errors.New("invalid command type") + +// ErrInvalidPool is returned when the pool type is not supported. +var ErrInvalidPool = errors.New("invalid pool type") + +// newClientAdapter creates a new client adapter for regular Redis clients. +func newClientAdapter(client *baseClient) interfaces.ClientInterface { + return &clientAdapter{client: client} +} + +// clientAdapter adapts a Redis client to implement interfaces.ClientInterface. +type clientAdapter struct { + client *baseClient +} + +// GetOptions returns the client options. +func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface { + return &optionsAdapter{options: ca.client.opt} +} + +// GetPushProcessor returns the client's push notification processor. +func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor { + return &pushProcessorAdapter{processor: ca.client.pushProcessor} +} + +// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface. +type optionsAdapter struct { + options *Options +} + +// GetReadTimeout returns the read timeout. +func (oa *optionsAdapter) GetReadTimeout() time.Duration { + return oa.options.ReadTimeout +} + +// GetWriteTimeout returns the write timeout. +func (oa *optionsAdapter) GetWriteTimeout() time.Duration { + return oa.options.WriteTimeout +} + +// GetNetwork returns the network type. +func (oa *optionsAdapter) GetNetwork() string { + return oa.options.Network +} + +// GetAddr returns the connection address. +func (oa *optionsAdapter) GetAddr() string { + return oa.options.Addr +} + +// IsTLSEnabled returns true if TLS is enabled. +func (oa *optionsAdapter) IsTLSEnabled() bool { + return oa.options.TLSConfig != nil +} + +// GetProtocol returns the protocol version. +func (oa *optionsAdapter) GetProtocol() int { + return oa.options.Protocol +} + +// GetPoolSize returns the connection pool size. +func (oa *optionsAdapter) GetPoolSize() int { + return oa.options.PoolSize +} + +// NewDialer returns a new dialer function for the connection. +func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) { + baseDialer := oa.options.NewDialer() + return func(ctx context.Context) (net.Conn, error) { + // Extract network and address from the options + network := oa.options.Network + addr := oa.options.Addr + return baseDialer(ctx, network, addr) + } +} + +// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor. +type pushProcessorAdapter struct { + processor push.NotificationProcessor +} + +// RegisterHandler registers a handler for a specific push notification name. 
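The adapters above exist so that module-internal consumers (such as the hitless package) can program against the narrow `interfaces` package instead of `*redis.Options`. A small illustrative sketch, not part of the patch — `describeAndDial` and the `hitlessdemo` package name are hypothetical, and the snippet only compiles inside the go-redis module tree because `internal/interfaces` is an internal package; it uses only methods the adapter above exposes:

```go
package hitlessdemo

import (
	"context"
	"log"
	"net"

	"github.com/redis/go-redis/v9/internal/interfaces"
)

// describeAndDial is a hypothetical consumer of interfaces.OptionsInterface.
func describeAndDial(ctx context.Context, opts interfaces.OptionsInterface) (net.Conn, error) {
	scheme := "redis"
	if opts.IsTLSEnabled() {
		scheme = "rediss"
	}
	log.Printf("dialing %s://%s (RESP%d, pool size %d)",
		scheme, opts.GetAddr(), opts.GetProtocol(), opts.GetPoolSize())

	// NewDialer returns a dialer already bound to the configured network and
	// address (see the adapter above), so callers only need a context.
	return opts.NewDialer()(ctx)
}
```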
+func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error { + if pushHandler, ok := handler.(push.NotificationHandler); ok { + return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected) + } + return errors.New("handler must implement push.NotificationHandler") +} + +// UnregisterHandler removes a handler for a specific push notification name. +func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error { + return ppa.processor.UnregisterHandler(pushNotificationName) +} + +// GetHandler returns the handler for a specific push notification name. +func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} { + return ppa.processor.GetHandler(pushNotificationName) +} diff --git a/async_handoff_integration_test.go b/async_handoff_integration_test.go new file mode 100644 index 00000000..7e34bf9d --- /dev/null +++ b/async_handoff_integration_test.go @@ -0,0 +1,353 @@ +package redis + +import ( + "context" + "net" + "sync" + "testing" + "time" + + "github.com/redis/go-redis/v9/hitless" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" +) + +// mockNetConn implements net.Conn for testing +type mockNetConn struct { + addr string +} + +func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil } +func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (m *mockNetConn) Close() error { return nil } +func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) SetDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type mockAddr struct { + addr string +} + +func (m *mockAddr) Network() string { return "tcp" } +func (m *mockAddr) String() string { return m.addr } + +// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow +func TestEventDrivenHandoffIntegration(t *testing.T) { + t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) { + // Create a base dialer for testing + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + // Create processor with event-driven handoff support + processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil) + defer processor.Shutdown(context.Background()) + + // Create a test pool with hooks + hookManager := pool.NewPoolHookManager() + hookManager.AddHook(processor) + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + PoolSize: int32(5), + PoolTimeout: time.Second, + }) + + // Add the hook to the pool after creation + testPool.AddPoolHook(processor) + defer testPool.Close() + + // Set the pool reference in the processor for connection removal on handoff failure + processor.SetPool(testPool) + + ctx := context.Background() + + // Get a connection and mark it for handoff + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + // Set initialization function with a small delay to ensure handoff is pending + initConnCalled := false + initConnFunc := func(ctx context.Context, cn *pool.Conn) error { + time.Sleep(50 * time.Millisecond) // Add delay to keep 
handoff pending + initConnCalled = true + return nil + } + conn.SetInitConnFunc(initConnFunc) + + // Mark connection for handoff + err = conn.MarkForHandoff("new-endpoint:6379", 12345) + if err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Return connection to pool - this should queue handoff + testPool.Put(ctx, conn) + + // Give the on-demand worker a moment to start processing + time.Sleep(10 * time.Millisecond) + + // Verify handoff was queued + if !processor.IsHandoffPending(conn) { + t.Error("Handoff should be queued in pending map") + } + + // Try to get the same connection - should be skipped due to pending handoff + conn2, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get second connection: %v", err) + } + + // Should get a different connection (the pending one should be skipped) + if conn == conn2 { + t.Error("Should have gotten a different connection while handoff is pending") + } + + // Return the second connection + testPool.Put(ctx, conn2) + + // Wait for handoff to complete + time.Sleep(200 * time.Millisecond) + + // Verify handoff completed (removed from pending map) + if processor.IsHandoffPending(conn) { + t.Error("Handoff should have completed and been removed from pending map") + } + + if !initConnCalled { + t.Error("InitConn should have been called during handoff") + } + + // Now the original connection should be available again + conn3, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get third connection: %v", err) + } + + // Could be the original connection (now handed off) or a new one + testPool.Put(ctx, conn3) + }) + + t.Run("ConcurrentHandoffs", func(t *testing.T) { + // Create a base dialer that simulates slow handoffs + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + time.Sleep(50 * time.Millisecond) // Simulate network delay + return &mockNetConn{addr: addr}, nil + } + + processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil) + defer processor.Shutdown(context.Background()) + + // Create hooks manager and add processor as hook + hookManager := pool.NewPoolHookManager() + hookManager.AddHook(processor) + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + + PoolSize: int32(10), + PoolTimeout: time.Second, + }) + defer testPool.Close() + + // Add the hook to the pool after creation + testPool.AddPoolHook(processor) + + // Set the pool reference in the processor + processor.SetPool(testPool) + + ctx := context.Background() + var wg sync.WaitGroup + + // Start multiple concurrent handoffs + for i := 0; i < 5; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // Get connection + conn, err := testPool.Get(ctx) + if err != nil { + t.Errorf("Failed to get conn[%d]: %v", id, err) + return + } + + // Set initialization function + initConnFunc := func(ctx context.Context, cn *pool.Conn) error { + return nil + } + conn.SetInitConnFunc(initConnFunc) + + // Mark for handoff + conn.MarkForHandoff("new-endpoint:6379", int64(id)) + + // Return to pool (starts async handoff) + testPool.Put(ctx, conn) + }(i) + } + + wg.Wait() + + // Wait for all handoffs to complete + time.Sleep(300 * time.Millisecond) + + // Verify pool is still functional + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err) + } + testPool.Put(ctx, conn) + }) + + t.Run("HandoffFailureRecovery", func(t 
*testing.T) { + // Create a failing base dialer + failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}} + } + + processor := hitless.NewPoolHook(failingDialer, "tcp", nil, nil) + defer processor.Shutdown(context.Background()) + + // Create hooks manager and add processor as hook + hookManager := pool.NewPoolHookManager() + hookManager.AddHook(processor) + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + + PoolSize: int32(3), + PoolTimeout: time.Second, + }) + defer testPool.Close() + + // Add the hook to the pool after creation + testPool.AddPoolHook(processor) + + // Set the pool reference in the processor + processor.SetPool(testPool) + + ctx := context.Background() + + // Get connection and mark for handoff + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + conn.MarkForHandoff("unreachable-endpoint:6379", 12345) + + // Return to pool (starts async handoff that will fail) + testPool.Put(ctx, conn) + + // Wait for handoff to fail + time.Sleep(200 * time.Millisecond) + + // Connection should be removed from pending map after failed handoff + if processor.IsHandoffPending(conn) { + t.Error("Connection should be removed from pending map after failed handoff") + } + + // Pool should still be functional + conn2, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Pool should still be functional: %v", err) + } + + // In event-driven approach, the original connection remains in pool + // even after failed handoff (it's still a valid connection) + // We might get the same connection or a different one + testPool.Put(ctx, conn2) + }) + + t.Run("GracefulShutdown", func(t *testing.T) { + // Create a slow base dialer + slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + time.Sleep(100 * time.Millisecond) + return &mockNetConn{addr: addr}, nil + } + + processor := hitless.NewPoolHook(slowDialer, "tcp", nil, nil) + defer processor.Shutdown(context.Background()) + + // Create hooks manager and add processor as hook + hookManager := pool.NewPoolHookManager() + hookManager.AddHook(processor) + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + + PoolSize: int32(2), + PoolTimeout: time.Second, + }) + defer testPool.Close() + + // Add the hook to the pool after creation + testPool.AddPoolHook(processor) + + // Set the pool reference in the processor + processor.SetPool(testPool) + + ctx := context.Background() + + // Start a handoff + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function with delay to ensure handoff is pending + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending + return nil + }) + + testPool.Put(ctx, conn) + + // Give the on-demand worker a moment to start and begin processing + // The handoff should be pending because the slowDialer takes 100ms + time.Sleep(10 * time.Millisecond) + + // Verify handoff was queued and is being processed + if 
!processor.IsHandoffPending(conn) { + t.Error("Handoff should be queued in pending map") + } + + // Give the handoff a moment to start processing + time.Sleep(50 * time.Millisecond) + + // Shutdown processor gracefully + // Use a longer timeout to account for slow dialer (100ms) plus processing overhead + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = processor.Shutdown(shutdownCtx) + if err != nil { + t.Errorf("Graceful shutdown should succeed: %v", err) + } + + // Handoff should have completed (removed from pending map) + if processor.IsHandoffPending(conn) { + t.Error("Handoff should have completed and been removed from pending map after shutdown") + } + }) +} + +func init() { + logging.Disable() +} diff --git a/commands.go b/commands.go index c0358001..3a1cfdef 100644 --- a/commands.go +++ b/commands.go @@ -193,6 +193,7 @@ type Cmdable interface { ClientID(ctx context.Context) *IntCmd ClientUnblock(ctx context.Context, id int64) *IntCmd ClientUnblockWithError(ctx context.Context, id int64) *IntCmd + ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd ConfigResetStat(ctx context.Context) *StatusCmd ConfigSet(ctx context.Context, parameter, value string) *StatusCmd @@ -518,6 +519,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd { return cmd } +// ClientMaintNotifications enables or disables maintenance notifications for hitless upgrades. +// When enabled, the client will receive push notifications about Redis maintenance events. +func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd { + args := []interface{}{"client", "maint_notifications"} + if enabled { + if endpointType == "" { + endpointType = "none" + } + args = append(args, "on", "moving-endpoint-type", endpointType) + } else { + args = append(args, "off") + } + cmd := NewStatusCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + // ------------------------------------------------------------------------------------------------ func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd { diff --git a/example/pubsub/go.mod b/example/pubsub/go.mod new file mode 100644 index 00000000..731a9283 --- /dev/null +++ b/example/pubsub/go.mod @@ -0,0 +1,12 @@ +module github.com/redis/go-redis/example/pubsub + +go 1.18 + +replace github.com/redis/go-redis/v9 => ../.. 
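The `ClientMaintNotifications` helper added in the commands.go hunk above can also be called directly. A minimal usage sketch, assuming a server that accepts `CLIENT MAINT_NOTIFICATIONS`; the `internal-fqdn` endpoint type mirrors the handshake visible in the instrumentation example output:

```go
package main

import (
	"context"
	"log"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Protocol: 3})

	// Sends CLIENT MAINT_NOTIFICATIONS ON moving-endpoint-type internal-fqdn.
	// The client performs an equivalent handshake automatically when hitless
	// upgrades are enabled.
	if err := rdb.ClientMaintNotifications(ctx, true, "internal-fqdn").Err(); err != nil {
		// In MaintNotificationsAuto mode the client disables the feature on an
		// error like this instead of failing the connection.
		log.Printf("maintenance notifications not enabled: %v", err)
	}

	// Passing false sends CLIENT MAINT_NOTIFICATIONS OFF.
	_ = rdb.ClientMaintNotifications(ctx, false, "").Err()
}
```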
+ +require github.com/redis/go-redis/v9 v9.11.0 + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect +) diff --git a/example/pubsub/go.sum b/example/pubsub/go.sum new file mode 100644 index 00000000..d64ea030 --- /dev/null +++ b/example/pubsub/go.sum @@ -0,0 +1,6 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= diff --git a/example/pubsub/main.go b/example/pubsub/main.go new file mode 100644 index 00000000..1017c0ca --- /dev/null +++ b/example/pubsub/main.go @@ -0,0 +1,171 @@ +package main + +import ( + "context" + "fmt" + "log" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/hitless" + "github.com/redis/go-redis/v9/logging" +) + +var ctx = context.Background() +var cntErrors atomic.Int64 +var cntSuccess atomic.Int64 +var startTime = time.Now() + +// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management. +// It was used to find regressions in pool management in hitless mode. +// Please don't use it as a reference for how to use pubsub. +func main() { + startTime = time.Now() + wg := &sync.WaitGroup{} + rdb := redis.NewClient(&redis.Options{ + Addr: ":6379", + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Mode: hitless.MaintNotificationsEnabled, + }, + }) + _ = rdb.FlushDB(ctx).Err() + hitlessManager := rdb.GetHitlessManager() + if hitlessManager == nil { + panic("hitless manager is nil") + } + loggingHook := hitless.NewLoggingHook(logging.LogLevelDebug) + hitlessManager.AddNotificationHook(loggingHook) + + go func() { + for { + time.Sleep(2 * time.Second) + fmt.Printf("pool stats: %+v\n", rdb.PoolStats()) + } + }() + err := rdb.Ping(ctx).Err() + if err != nil { + panic(err) + } + if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil { + panic(err) + } + if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil { + panic(err) + } + if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil { + panic(err) + } + if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil { + panic(err) + } + fmt.Println("published", rdb.Get(ctx, "published").Val()) + fmt.Println("received", rdb.Get(ctx, "received").Val()) + subCtx, cancelSubCtx := context.WithCancel(ctx) + pubCtx, cancelPublishers := context.WithCancel(ctx) + for i := 0; i < 10; i++ { + wg.Add(1) + go subscribe(subCtx, rdb, "test", i, wg) + } + time.Sleep(time.Second) + cancelSubCtx() + time.Sleep(time.Second) + subCtx, cancelSubCtx = context.WithCancel(ctx) + for i := 0; i < 10; i++ { + if err := rdb.Incr(ctx, "publishers").Err(); err != nil { + fmt.Println("incr error:", err) + cntErrors.Add(1) + } + wg.Add(1) + go floodThePool(pubCtx, rdb, wg) + } + + for i := 0; i < 500; i++ { + if err := rdb.Incr(ctx, "subscribers").Err(); err != nil { + fmt.Println("incr error:", err) + cntErrors.Add(1) + } + + wg.Add(1) + go subscribe(subCtx, rdb, "test2", i, wg) + } + time.Sleep(120 * 
time.Second) + fmt.Println("canceling publishers") + cancelPublishers() + time.Sleep(10 * time.Second) + fmt.Println("canceling subscribers") + cancelSubCtx() + wg.Wait() + published, err := rdb.Get(ctx, "published").Result() + received, err := rdb.Get(ctx, "received").Result() + publishers, err := rdb.Get(ctx, "publishers").Result() + subscribers, err := rdb.Get(ctx, "subscribers").Result() + fmt.Printf("publishers: %s\n", publishers) + fmt.Printf("published: %s\n", published) + fmt.Printf("subscribers: %s\n", subscribers) + fmt.Printf("received: %s\n", received) + publishedInt, err := rdb.Get(ctx, "published").Int() + subscribersInt, err := rdb.Get(ctx, "subscribers").Int() + fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt) + + time.Sleep(2 * time.Second) + fmt.Println("errors:", cntErrors.Load()) + fmt.Println("success:", cntSuccess.Load()) + fmt.Println("time:", time.Since(startTime)) +} + +func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + default: + } + err := rdb.Publish(ctx, "test2", "hello").Err() + if err != nil { + if err.Error() != "context canceled" { + log.Println("publish error:", err) + cntErrors.Add(1) + } + } + + err = rdb.Incr(ctx, "published").Err() + if err != nil { + if err.Error() != "context canceled" { + log.Println("incr error:", err) + cntErrors.Add(1) + } + } + time.Sleep(10 * time.Nanosecond) + } +} + +func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) { + defer wg.Done() + rec := rdb.Subscribe(ctx, topic) + recChan := rec.Channel() + for { + select { + case <-ctx.Done(): + rec.Close() + return + default: + select { + case <-ctx.Done(): + rec.Close() + return + case msg := <-recChan: + err := rdb.Incr(ctx, "received").Err() + if err != nil { + if err.Error() != "context canceled" { + log.Printf("%s\n", err.Error()) + cntErrors.Add(1) + } + } + _ = msg // Use the message to avoid unused variable warning + } + } + } +} diff --git a/example_instrumentation_test.go b/example_instrumentation_test.go index 36234ff0..fa776fcf 100644 --- a/example_instrumentation_test.go +++ b/example_instrumentation_test.go @@ -57,6 +57,8 @@ func Example_instrumentation() { // finished dialing tcp :6379 // starting processing: <[hello 3]> // finished processing: <[hello 3]> + // starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> + // finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> // finished processing: <[ping]> } @@ -78,6 +80,8 @@ func ExamplePipeline_instrumentation() { // finished dialing tcp :6379 // starting processing: <[hello 3]> // finished processing: <[hello 3]> + // starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> + // finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> // pipeline finished processing: [[ping] [ping]] } @@ -99,6 +103,8 @@ func ExampleClient_Watch_instrumentation() { // finished dialing tcp :6379 // starting processing: <[hello 3]> // finished processing: <[hello 3]> + // starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> + // finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> // finished processing: <[watch foo]> // starting processing: <[ping]> // finished processing: <[ping]> diff --git a/hitless/README.md b/hitless/README.md new file 
mode 100644 index 00000000..0803c0d4 --- /dev/null +++ b/hitless/README.md @@ -0,0 +1,98 @@ +# Hitless Upgrades + +Seamless Redis connection handoffs during cluster changes without dropping connections. + +## Quick Start + +```go +client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, // RESP3 required + HitlessUpgrades: &hitless.Config{ + Mode: hitless.MaintNotificationsEnabled, + }, +}) +``` + +## Modes + +- **`MaintNotificationsDisabled`** - Hitless upgrades disabled +- **`MaintNotificationsEnabled`** - Forcefully enabled (fails if server doesn't support) +- **`MaintNotificationsAuto`** - Auto-detect server support (default) + +## Configuration + +```go +&hitless.Config{ + Mode: hitless.MaintNotificationsAuto, + EndpointType: hitless.EndpointTypeAuto, + RelaxedTimeout: 10 * time.Second, + HandoffTimeout: 15 * time.Second, + MaxHandoffRetries: 3, + MaxWorkers: 0, // Auto-calculated + HandoffQueueSize: 0, // Auto-calculated + PostHandoffRelaxedDuration: 0, // 2 * RelaxedTimeout + LogLevel: logging.LogLevelError, +} +``` + +### Endpoint Types + +- **`EndpointTypeAuto`** - Auto-detect based on connection (default) +- **`EndpointTypeInternalIP`** - Internal IP address +- **`EndpointTypeInternalFQDN`** - Internal FQDN +- **`EndpointTypeExternalIP`** - External IP address +- **`EndpointTypeExternalFQDN`** - External FQDN +- **`EndpointTypeNone`** - No endpoint (reconnect with current config) + +### Auto-Scaling + +**Workers**: `min(PoolSize/2, max(10, PoolSize/3))` when auto-calculated +**Queue**: `max(20×Workers, PoolSize)` capped by `MaxActiveConns+1` or `5×PoolSize` + +**Examples:** +- Pool 100: 33 workers, 660 queue (capped at 500) +- Pool 100 + MaxActiveConns 150: 33 workers, 151 queue + +## How It Works + +1. Redis sends push notifications about cluster changes +2. Client creates new connections to updated endpoints +3. Active operations transfer to new connections +4. 
Old connections close gracefully + +## Supported Notifications + +- `MOVING` - Slot moving to new node +- `MIGRATING` - Slot in migration state +- `MIGRATED` - Migration completed +- `FAILING_OVER` - Node failing over +- `FAILED_OVER` - Failover completed + +## Hooks (Optional) + +Monitor and customize hitless operations: + +```go +type NotificationHook interface { + PreHook(ctx, notificationCtx, notificationType, notification) ([]interface{}, bool) + PostHook(ctx, notificationCtx, notificationType, notification, result) +} + +// Add custom hook +manager.AddNotificationHook(&MyHook{}) +``` + +### Metrics Hook Example + +```go +// Create metrics hook +metricsHook := hitless.NewMetricsHook() +manager.AddNotificationHook(metricsHook) + +// Access collected metrics +metrics := metricsHook.GetMetrics() +fmt.Printf("Notification counts: %v\n", metrics["notification_counts"]) +fmt.Printf("Processing times: %v\n", metrics["processing_times"]) +fmt.Printf("Error counts: %v\n", metrics["error_counts"]) +``` diff --git a/hitless/circuit_breaker.go b/hitless/circuit_breaker.go new file mode 100644 index 00000000..8f985123 --- /dev/null +++ b/hitless/circuit_breaker.go @@ -0,0 +1,360 @@ +package hitless + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9/internal" +) + +// CircuitBreakerState represents the state of a circuit breaker +type CircuitBreakerState int32 + +const ( + // CircuitBreakerClosed - normal operation, requests allowed + CircuitBreakerClosed CircuitBreakerState = iota + // CircuitBreakerOpen - failing fast, requests rejected + CircuitBreakerOpen + // CircuitBreakerHalfOpen - testing if service recovered + CircuitBreakerHalfOpen +) + +func (s CircuitBreakerState) String() string { + switch s { + case CircuitBreakerClosed: + return "closed" + case CircuitBreakerOpen: + return "open" + case CircuitBreakerHalfOpen: + return "half-open" + default: + return "unknown" + } +} + +// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling +type CircuitBreaker struct { + // Configuration + failureThreshold int // Number of failures before opening + resetTimeout time.Duration // How long to stay open before testing + maxRequests int // Max requests allowed in half-open state + + // State tracking (atomic for lock-free access) + state atomic.Int32 // CircuitBreakerState + failures atomic.Int64 // Current failure count + successes atomic.Int64 // Success count in half-open state + requests atomic.Int64 // Request count in half-open state + lastFailureTime atomic.Int64 // Unix timestamp of last failure + lastSuccessTime atomic.Int64 // Unix timestamp of last success + + // Endpoint identification + endpoint string + config *Config +} + +// newCircuitBreaker creates a new circuit breaker for an endpoint +func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker { + // Use configuration values with sensible defaults + failureThreshold := 5 + resetTimeout := 60 * time.Second + maxRequests := 3 + + if config != nil { + failureThreshold = config.CircuitBreakerFailureThreshold + resetTimeout = config.CircuitBreakerResetTimeout + maxRequests = config.CircuitBreakerMaxRequests + } + + return &CircuitBreaker{ + failureThreshold: failureThreshold, + resetTimeout: resetTimeout, + maxRequests: maxRequests, + endpoint: endpoint, + config: config, + state: atomic.Int32{}, // Defaults to CircuitBreakerClosed (0) + } +} + +// IsOpen returns true if the circuit breaker is open (rejecting requests) +func (cb *CircuitBreaker) 
IsOpen() bool { + state := CircuitBreakerState(cb.state.Load()) + return state == CircuitBreakerOpen +} + +// shouldAttemptReset checks if enough time has passed to attempt reset +func (cb *CircuitBreaker) shouldAttemptReset() bool { + lastFailure := time.Unix(cb.lastFailureTime.Load(), 0) + return time.Since(lastFailure) >= cb.resetTimeout +} + +// Execute runs the given function with circuit breaker protection +func (cb *CircuitBreaker) Execute(fn func() error) error { + // Single atomic state load for consistency + state := CircuitBreakerState(cb.state.Load()) + + switch state { + case CircuitBreakerOpen: + if cb.shouldAttemptReset() { + // Attempt transition to half-open + if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) { + cb.requests.Store(0) + cb.successes.Store(0) + if cb.config != nil && cb.config.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), + "hitless: circuit breaker for %s transitioning to half-open", cb.endpoint) + } + // Fall through to half-open logic + } else { + return ErrCircuitBreakerOpen + } + } else { + return ErrCircuitBreakerOpen + } + fallthrough + case CircuitBreakerHalfOpen: + requests := cb.requests.Add(1) + if requests > int64(cb.maxRequests) { + cb.requests.Add(-1) // Revert the increment + return ErrCircuitBreakerOpen + } + } + + // Execute the function with consistent state + err := fn() + + if err != nil { + cb.recordFailure() + return err + } + + cb.recordSuccess() + return nil +} + +// recordFailure records a failure and potentially opens the circuit +func (cb *CircuitBreaker) recordFailure() { + cb.lastFailureTime.Store(time.Now().Unix()) + failures := cb.failures.Add(1) + + state := CircuitBreakerState(cb.state.Load()) + + switch state { + case CircuitBreakerClosed: + if failures >= int64(cb.failureThreshold) { + if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) { + if cb.config != nil && cb.config.LogLevel.WarnOrAbove() { + internal.Logger.Printf(context.Background(), + "hitless: circuit breaker opened for endpoint %s after %d failures", + cb.endpoint, failures) + } + } + } + case CircuitBreakerHalfOpen: + // Any failure in half-open state immediately opens the circuit + if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) { + if cb.config != nil && cb.config.LogLevel.WarnOrAbove() { + internal.Logger.Printf(context.Background(), + "hitless: circuit breaker reopened for endpoint %s due to failure in half-open state", + cb.endpoint) + } + } + } +} + +// recordSuccess records a success and potentially closes the circuit +func (cb *CircuitBreaker) recordSuccess() { + cb.lastSuccessTime.Store(time.Now().Unix()) + + state := CircuitBreakerState(cb.state.Load()) + + switch state { + case CircuitBreakerClosed: + // Reset failure count on success in closed state + cb.failures.Store(0) + case CircuitBreakerHalfOpen: + successes := cb.successes.Add(1) + + // If we've had enough successful requests, close the circuit + if successes >= int64(cb.maxRequests) { + if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) { + cb.failures.Store(0) + if cb.config != nil && cb.config.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), + "hitless: circuit breaker closed for endpoint %s after %d successful requests", + cb.endpoint, successes) + } + } + } + } +} + +// GetState returns the current state of the circuit breaker +func (cb *CircuitBreaker) GetState() CircuitBreakerState { + return 
CircuitBreakerState(cb.state.Load()) +} + +// GetStats returns current statistics for monitoring +func (cb *CircuitBreaker) GetStats() CircuitBreakerStats { + return CircuitBreakerStats{ + Endpoint: cb.endpoint, + State: cb.GetState(), + Failures: cb.failures.Load(), + Successes: cb.successes.Load(), + Requests: cb.requests.Load(), + LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0), + LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0), + } +} + +// CircuitBreakerStats provides statistics about a circuit breaker +type CircuitBreakerStats struct { + Endpoint string + State CircuitBreakerState + Failures int64 + Successes int64 + Requests int64 + LastFailureTime time.Time + LastSuccessTime time.Time +} + +// CircuitBreakerEntry wraps a circuit breaker with access tracking +type CircuitBreakerEntry struct { + breaker *CircuitBreaker + lastAccess atomic.Int64 // Unix timestamp + created time.Time +} + +// CircuitBreakerManager manages circuit breakers for multiple endpoints +type CircuitBreakerManager struct { + breakers sync.Map // map[string]*CircuitBreakerEntry + config *Config + cleanupStop chan struct{} + cleanupMu sync.Mutex + lastCleanup atomic.Int64 // Unix timestamp +} + +// newCircuitBreakerManager creates a new circuit breaker manager +func newCircuitBreakerManager(config *Config) *CircuitBreakerManager { + cbm := &CircuitBreakerManager{ + config: config, + cleanupStop: make(chan struct{}), + } + cbm.lastCleanup.Store(time.Now().Unix()) + + // Start background cleanup goroutine + go cbm.cleanupLoop() + + return cbm +} + +// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary +func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker { + now := time.Now().Unix() + + if entry, ok := cbm.breakers.Load(endpoint); ok { + cbEntry := entry.(*CircuitBreakerEntry) + cbEntry.lastAccess.Store(now) + return cbEntry.breaker + } + + // Create new circuit breaker with metadata + newBreaker := newCircuitBreaker(endpoint, cbm.config) + newEntry := &CircuitBreakerEntry{ + breaker: newBreaker, + created: time.Now(), + } + newEntry.lastAccess.Store(now) + + actual, _ := cbm.breakers.LoadOrStore(endpoint, newEntry) + return actual.(*CircuitBreakerEntry).breaker +} + +// GetAllStats returns statistics for all circuit breakers +func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats { + var stats []CircuitBreakerStats + cbm.breakers.Range(func(key, value interface{}) bool { + entry := value.(*CircuitBreakerEntry) + stats = append(stats, entry.breaker.GetStats()) + return true + }) + return stats +} + +// cleanupLoop runs background cleanup of unused circuit breakers +func (cbm *CircuitBreakerManager) cleanupLoop() { + ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes + defer ticker.Stop() + + for { + select { + case <-ticker.C: + cbm.cleanup() + case <-cbm.cleanupStop: + return + } + } +} + +// cleanup removes circuit breakers that haven't been accessed recently +func (cbm *CircuitBreakerManager) cleanup() { + // Prevent concurrent cleanups + if !cbm.cleanupMu.TryLock() { + return + } + defer cbm.cleanupMu.Unlock() + + now := time.Now() + cutoff := now.Add(-30 * time.Minute).Unix() // 30 minute TTL + + var toDelete []string + count := 0 + + cbm.breakers.Range(func(key, value interface{}) bool { + endpoint := key.(string) + entry := value.(*CircuitBreakerEntry) + + count++ + + // Remove if not accessed recently + if entry.lastAccess.Load() < cutoff { + toDelete = append(toDelete, endpoint) + } + + 
return true + }) + + // Delete expired entries + for _, endpoint := range toDelete { + cbm.breakers.Delete(endpoint) + } + + // Log cleanup results + if len(toDelete) > 0 && cbm.config != nil && cbm.config.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), + "hitless: circuit breaker cleanup removed %d/%d entries", len(toDelete), count) + } + + cbm.lastCleanup.Store(now.Unix()) +} + +// Shutdown stops the cleanup goroutine +func (cbm *CircuitBreakerManager) Shutdown() { + close(cbm.cleanupStop) +} + +// Reset resets all circuit breakers (useful for testing) +func (cbm *CircuitBreakerManager) Reset() { + cbm.breakers.Range(func(key, value interface{}) bool { + entry := value.(*CircuitBreakerEntry) + breaker := entry.breaker + breaker.state.Store(int32(CircuitBreakerClosed)) + breaker.failures.Store(0) + breaker.successes.Store(0) + breaker.requests.Store(0) + breaker.lastFailureTime.Store(0) + breaker.lastSuccessTime.Store(0) + return true + }) +} diff --git a/hitless/circuit_breaker_test.go b/hitless/circuit_breaker_test.go new file mode 100644 index 00000000..16015ec8 --- /dev/null +++ b/hitless/circuit_breaker_test.go @@ -0,0 +1,356 @@ +package hitless + +import ( + "errors" + "testing" + "time" + + "github.com/redis/go-redis/v9/logging" +) + +func TestCircuitBreaker(t *testing.T) { + config := &Config{ + LogLevel: logging.LogLevelError, // Reduce noise in tests + CircuitBreakerFailureThreshold: 5, + CircuitBreakerResetTimeout: 60 * time.Second, + CircuitBreakerMaxRequests: 3, + } + + t.Run("InitialState", func(t *testing.T) { + cb := newCircuitBreaker("test-endpoint:6379", config) + + if cb.IsOpen() { + t.Error("Circuit breaker should start in closed state") + } + + if cb.GetState() != CircuitBreakerClosed { + t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState()) + } + }) + + t.Run("SuccessfulExecution", func(t *testing.T) { + cb := newCircuitBreaker("test-endpoint:6379", config) + + err := cb.Execute(func() error { + return nil // Success + }) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if cb.GetState() != CircuitBreakerClosed { + t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState()) + } + }) + + t.Run("FailureThreshold", func(t *testing.T) { + cb := newCircuitBreaker("test-endpoint:6379", config) + testError := errors.New("test error") + + // Fail 4 times (below threshold of 5) + for i := 0; i < 4; i++ { + err := cb.Execute(func() error { + return testError + }) + if err != testError { + t.Errorf("Expected test error, got %v", err) + } + if cb.GetState() != CircuitBreakerClosed { + t.Errorf("Circuit should still be closed after %d failures", i+1) + } + } + + // 5th failure should open the circuit + err := cb.Execute(func() error { + return testError + }) + if err != testError { + t.Errorf("Expected test error, got %v", err) + } + + if cb.GetState() != CircuitBreakerOpen { + t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState()) + } + }) + + t.Run("OpenCircuitFailsFast", func(t *testing.T) { + cb := newCircuitBreaker("test-endpoint:6379", config) + testError := errors.New("test error") + + // Force circuit to open + for i := 0; i < 5; i++ { + cb.Execute(func() error { return testError }) + } + + // Now it should fail fast + err := cb.Execute(func() error { + t.Error("Function should not be called when circuit is open") + return nil + }) + + if err != ErrCircuitBreakerOpen { + t.Errorf("Expected ErrCircuitBreakerOpen, got %v", err) + } + }) + + t.Run("HalfOpenTransition", func(t 
*testing.T) { + testConfig := &Config{ + LogLevel: logging.LogLevelError, + CircuitBreakerFailureThreshold: 5, + CircuitBreakerResetTimeout: 100 * time.Millisecond, // Short timeout for testing + CircuitBreakerMaxRequests: 3, + } + cb := newCircuitBreaker("test-endpoint:6379", testConfig) + testError := errors.New("test error") + + // Force circuit to open + for i := 0; i < 5; i++ { + cb.Execute(func() error { return testError }) + } + + if cb.GetState() != CircuitBreakerOpen { + t.Error("Circuit should be open") + } + + // Wait for reset timeout + time.Sleep(150 * time.Millisecond) + + // Next call should transition to half-open + executed := false + err := cb.Execute(func() error { + executed = true + return nil // Success + }) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if !executed { + t.Error("Function should have been executed in half-open state") + } + }) + + t.Run("HalfOpenToClosedTransition", func(t *testing.T) { + testConfig := &Config{ + LogLevel: logging.LogLevelError, + CircuitBreakerFailureThreshold: 5, + CircuitBreakerResetTimeout: 50 * time.Millisecond, + CircuitBreakerMaxRequests: 3, + } + cb := newCircuitBreaker("test-endpoint:6379", testConfig) + testError := errors.New("test error") + + // Force circuit to open + for i := 0; i < 5; i++ { + cb.Execute(func() error { return testError }) + } + + // Wait for reset timeout + time.Sleep(100 * time.Millisecond) + + // Execute successful requests in half-open state + for i := 0; i < 3; i++ { + err := cb.Execute(func() error { + return nil // Success + }) + if err != nil { + t.Errorf("Expected no error on attempt %d, got %v", i+1, err) + } + } + + // Circuit should now be closed + if cb.GetState() != CircuitBreakerClosed { + t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState()) + } + }) + + t.Run("HalfOpenToOpenOnFailure", func(t *testing.T) { + testConfig := &Config{ + LogLevel: logging.LogLevelError, + CircuitBreakerFailureThreshold: 5, + CircuitBreakerResetTimeout: 50 * time.Millisecond, + CircuitBreakerMaxRequests: 3, + } + cb := newCircuitBreaker("test-endpoint:6379", testConfig) + testError := errors.New("test error") + + // Force circuit to open + for i := 0; i < 5; i++ { + cb.Execute(func() error { return testError }) + } + + // Wait for reset timeout + time.Sleep(100 * time.Millisecond) + + // First request in half-open state fails + err := cb.Execute(func() error { + return testError + }) + + if err != testError { + t.Errorf("Expected test error, got %v", err) + } + + // Circuit should be open again + if cb.GetState() != CircuitBreakerOpen { + t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState()) + } + }) + + t.Run("Stats", func(t *testing.T) { + cb := newCircuitBreaker("test-endpoint:6379", config) + testError := errors.New("test error") + + // Execute some operations + cb.Execute(func() error { return testError }) // Failure + cb.Execute(func() error { return testError }) // Failure + + stats := cb.GetStats() + + if stats.Endpoint != "test-endpoint:6379" { + t.Errorf("Expected endpoint 'test-endpoint:6379', got %s", stats.Endpoint) + } + + if stats.Failures != 2 { + t.Errorf("Expected 2 failures, got %d", stats.Failures) + } + + if stats.State != CircuitBreakerClosed { + t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, stats.State) + } + + // Test that success resets failure count + cb.Execute(func() error { return nil }) // Success + stats = cb.GetStats() + + if stats.Failures != 0 { + t.Errorf("Expected 0 failures after success, got 
%d", stats.Failures) + } + }) +} + +func TestCircuitBreakerManager(t *testing.T) { + config := &Config{ + LogLevel: logging.LogLevelError, + CircuitBreakerFailureThreshold: 5, + CircuitBreakerResetTimeout: 60 * time.Second, + CircuitBreakerMaxRequests: 3, + } + + t.Run("GetCircuitBreaker", func(t *testing.T) { + manager := newCircuitBreakerManager(config) + + cb1 := manager.GetCircuitBreaker("endpoint1:6379") + cb2 := manager.GetCircuitBreaker("endpoint2:6379") + cb3 := manager.GetCircuitBreaker("endpoint1:6379") // Same as cb1 + + if cb1 == cb2 { + t.Error("Different endpoints should have different circuit breakers") + } + + if cb1 != cb3 { + t.Error("Same endpoint should return the same circuit breaker") + } + }) + + t.Run("GetAllStats", func(t *testing.T) { + manager := newCircuitBreakerManager(config) + + // Create circuit breakers for different endpoints + cb1 := manager.GetCircuitBreaker("endpoint1:6379") + cb2 := manager.GetCircuitBreaker("endpoint2:6379") + + // Execute some operations + cb1.Execute(func() error { return nil }) + cb2.Execute(func() error { return errors.New("test error") }) + + stats := manager.GetAllStats() + + if len(stats) != 2 { + t.Errorf("Expected 2 circuit breaker stats, got %d", len(stats)) + } + + // Check that we have stats for both endpoints + endpoints := make(map[string]bool) + for _, stat := range stats { + endpoints[stat.Endpoint] = true + } + + if !endpoints["endpoint1:6379"] || !endpoints["endpoint2:6379"] { + t.Error("Missing stats for expected endpoints") + } + }) + + t.Run("Reset", func(t *testing.T) { + manager := newCircuitBreakerManager(config) + testError := errors.New("test error") + + cb := manager.GetCircuitBreaker("test-endpoint:6379") + + // Force circuit to open + for i := 0; i < 5; i++ { + cb.Execute(func() error { return testError }) + } + + if cb.GetState() != CircuitBreakerOpen { + t.Error("Circuit should be open") + } + + // Reset all circuit breakers + manager.Reset() + + if cb.GetState() != CircuitBreakerClosed { + t.Error("Circuit should be closed after reset") + } + + if cb.failures.Load() != 0 { + t.Error("Failure count should be reset to 0") + } + }) + + t.Run("ConfigurableParameters", func(t *testing.T) { + config := &Config{ + LogLevel: logging.LogLevelError, + CircuitBreakerFailureThreshold: 10, + CircuitBreakerResetTimeout: 30 * time.Second, + CircuitBreakerMaxRequests: 5, + } + + cb := newCircuitBreaker("test-endpoint:6379", config) + + // Test that configuration values are used + if cb.failureThreshold != 10 { + t.Errorf("Expected failureThreshold=10, got %d", cb.failureThreshold) + } + if cb.resetTimeout != 30*time.Second { + t.Errorf("Expected resetTimeout=30s, got %v", cb.resetTimeout) + } + if cb.maxRequests != 5 { + t.Errorf("Expected maxRequests=5, got %d", cb.maxRequests) + } + + // Test that circuit opens after configured threshold + testError := errors.New("test error") + for i := 0; i < 9; i++ { + err := cb.Execute(func() error { return testError }) + if err != testError { + t.Errorf("Expected test error, got %v", err) + } + if cb.GetState() != CircuitBreakerClosed { + t.Errorf("Circuit should still be closed after %d failures", i+1) + } + } + + // 10th failure should open the circuit + err := cb.Execute(func() error { return testError }) + if err != testError { + t.Errorf("Expected test error, got %v", err) + } + + if cb.GetState() != CircuitBreakerOpen { + t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState()) + } + }) +} diff --git a/hitless/config.go b/hitless/config.go new file mode 
100644 index 00000000..6b9b7b37 --- /dev/null +++ b/hitless/config.go @@ -0,0 +1,472 @@ +package hitless + +import ( + "context" + "net" + "runtime" + "strings" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/logging" +) + +// MaintNotificationsMode represents the maintenance notifications mode +type MaintNotificationsMode string + +// Constants for maintenance push notifications modes +const ( + MaintNotificationsDisabled MaintNotificationsMode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command + MaintNotificationsEnabled MaintNotificationsMode = "enabled" // Client forcefully sends command, interrupts connection on error + MaintNotificationsAuto MaintNotificationsMode = "auto" // Client tries to send command, disables feature on error +) + +// IsValid returns true if the maintenance notifications mode is valid +func (m MaintNotificationsMode) IsValid() bool { + switch m { + case MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto: + return true + default: + return false + } +} + +// String returns the string representation of the mode +func (m MaintNotificationsMode) String() string { + return string(m) +} + +// EndpointType represents the type of endpoint to request in MOVING notifications +type EndpointType string + +// Constants for endpoint types +const ( + EndpointTypeAuto EndpointType = "auto" // Auto-detect based on connection + EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address + EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN + EndpointTypeExternalIP EndpointType = "external-ip" // External IP address + EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN + EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config) +) + +// IsValid returns true if the endpoint type is valid +func (e EndpointType) IsValid() bool { + switch e { + case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN, + EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone: + return true + default: + return false + } +} + +// String returns the string representation of the endpoint type +func (e EndpointType) String() string { + return string(e) +} + +// Config provides configuration options for hitless upgrades. +type Config struct { + // Mode controls how client maintenance notifications are handled. + // Valid values: MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto + // Default: MaintNotificationsAuto + Mode MaintNotificationsMode + + // EndpointType specifies the type of endpoint to request in MOVING notifications. + // Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN, + // EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone + // Default: EndpointTypeAuto + EndpointType EndpointType + + // RelaxedTimeout is the concrete timeout value to use during + // MIGRATING/FAILING_OVER states to accommodate increased latency. + // This applies to both read and write timeouts. + // Default: 10 seconds + RelaxedTimeout time.Duration + + // HandoffTimeout is the maximum time to wait for connection handoff to complete. + // If handoff takes longer than this, the old connection will be forcibly closed. + // Default: 15 seconds (matches server-side eviction timeout) + HandoffTimeout time.Duration + + // MaxWorkers is the maximum number of worker goroutines for processing handoff requests. 
+ // Workers are created on-demand and automatically cleaned up when idle. + // If zero, defaults to min(10, PoolSize/2) to handle bursts effectively. + // If explicitly set, enforces minimum of PoolSize/2 + // + // Default: min(PoolSize/2, max(10, PoolSize/3)), Minimum when set: PoolSize/2 + MaxWorkers int + + // HandoffQueueSize is the size of the buffered channel used to queue handoff requests. + // If the queue is full, new handoff requests will be rejected. + // Scales with both worker count and pool size for better burst handling. + // + // Default: max(20×MaxWorkers, PoolSize), capped by MaxActiveConns+1 (if set) or 5×PoolSize + // When set: minimum 200, capped by MaxActiveConns+1 (if set) or 5×PoolSize + HandoffQueueSize int + + // PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection + // after a handoff completes. This provides additional resilience during cluster transitions. + // Default: 2 * RelaxedTimeout + PostHandoffRelaxedDuration time.Duration + + // LogLevel controls the verbosity of hitless upgrade logging. + // LogLevelError (0) = errors only, LogLevelWarn (1) = warnings, LogLevelInfo (2) = info, LogLevelDebug (3) = debug + // Default: logging.LogLevelError(0) + LogLevel logging.LogLevel + + // Circuit breaker configuration for endpoint failure handling + // CircuitBreakerFailureThreshold is the number of failures before opening the circuit. + // Default: 5 + CircuitBreakerFailureThreshold int + + // CircuitBreakerResetTimeout is how long to wait before testing if the endpoint recovered. + // Default: 60 seconds + CircuitBreakerResetTimeout time.Duration + + // CircuitBreakerMaxRequests is the maximum number of requests allowed in half-open state. + // Default: 3 + CircuitBreakerMaxRequests int + + // MaxHandoffRetries is the maximum number of times to retry a failed handoff. + // After this many retries, the connection will be removed from the pool. + // Default: 3 + MaxHandoffRetries int +} + +func (c *Config) IsEnabled() bool { + return c != nil && c.Mode != MaintNotificationsDisabled +} + +// DefaultConfig returns a Config with sensible defaults. +func DefaultConfig() *Config { + return &Config{ + Mode: MaintNotificationsAuto, // Enable by default for Redis Cloud + EndpointType: EndpointTypeAuto, // Auto-detect based on connection + RelaxedTimeout: 10 * time.Second, + HandoffTimeout: 15 * time.Second, + MaxWorkers: 0, // Auto-calculated based on pool size + HandoffQueueSize: 0, // Auto-calculated based on max workers + PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout + LogLevel: logging.LogLevelError, + + // Circuit breaker configuration + CircuitBreakerFailureThreshold: 5, + CircuitBreakerResetTimeout: 60 * time.Second, + CircuitBreakerMaxRequests: 3, + + // Connection Handoff Configuration + MaxHandoffRetries: 3, + } +} + +// Validate checks if the configuration is valid. 
+func (c *Config) Validate() error { + if c.RelaxedTimeout <= 0 { + return ErrInvalidRelaxedTimeout + } + if c.HandoffTimeout <= 0 { + return ErrInvalidHandoffTimeout + } + // Validate worker configuration + // Allow 0 for auto-calculation, but negative values are invalid + if c.MaxWorkers < 0 { + return ErrInvalidHandoffWorkers + } + // HandoffQueueSize validation - allow 0 for auto-calculation + if c.HandoffQueueSize < 0 { + return ErrInvalidHandoffQueueSize + } + if c.PostHandoffRelaxedDuration < 0 { + return ErrInvalidPostHandoffRelaxedDuration + } + if !c.LogLevel.IsValid() { + return ErrInvalidLogLevel + } + + // Circuit breaker validation + if c.CircuitBreakerFailureThreshold < 1 { + return ErrInvalidCircuitBreakerFailureThreshold + } + if c.CircuitBreakerResetTimeout < 0 { + return ErrInvalidCircuitBreakerResetTimeout + } + if c.CircuitBreakerMaxRequests < 1 { + return ErrInvalidCircuitBreakerMaxRequests + } + + // Validate Mode (maintenance notifications mode) + if !c.Mode.IsValid() { + return ErrInvalidMaintNotifications + } + + // Validate EndpointType + if !c.EndpointType.IsValid() { + return ErrInvalidEndpointType + } + + // Validate configuration fields + if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 { + return ErrInvalidHandoffRetries + } + + return nil +} + +// ApplyDefaults applies default values to any zero-value fields in the configuration. +// This ensures that partially configured structs get sensible defaults for missing fields. +func (c *Config) ApplyDefaults() *Config { + return c.ApplyDefaultsWithPoolSize(0) +} + +// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration, +// using the provided pool size to calculate worker defaults. +// This ensures that partially configured structs get sensible defaults for missing fields. +func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config { + return c.ApplyDefaultsWithPoolConfig(poolSize, 0) +} + +// ApplyDefaultsWithPoolConfig applies default values to any zero-value fields in the configuration, +// using the provided pool size and max active connections to calculate worker and queue defaults. +// This ensures that partially configured structs get sensible defaults for missing fields. 
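+//
+// For example (a sketch of the sizing math, assuming poolSize=100 and no
+// MaxActiveConns limit): MaxWorkers defaults to min(100/2, max(10, 100/3)) = 33,
+// and HandoffQueueSize defaults to max(20*33, 100) = 660, capped at 5*100 = 500.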
+func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *Config { + if c == nil { + return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize) + } + + defaults := DefaultConfig() + result := &Config{} + + // Apply defaults for enum fields (empty/zero means not set) + result.Mode = defaults.Mode + if c.Mode != "" { + result.Mode = c.Mode + } + + result.EndpointType = defaults.EndpointType + if c.EndpointType != "" { + result.EndpointType = c.EndpointType + } + + // Apply defaults for duration fields (zero means not set) + result.RelaxedTimeout = defaults.RelaxedTimeout + if c.RelaxedTimeout > 0 { + result.RelaxedTimeout = c.RelaxedTimeout + } + + result.HandoffTimeout = defaults.HandoffTimeout + if c.HandoffTimeout > 0 { + result.HandoffTimeout = c.HandoffTimeout + } + + // Copy worker configuration + result.MaxWorkers = c.MaxWorkers + + // Apply worker defaults based on pool size + result.applyWorkerDefaults(poolSize) + + // Apply queue size defaults with new scaling approach + // Default: max(20x workers, PoolSize), capped by maxActiveConns or 5x pool size + workerBasedSize := result.MaxWorkers * 20 + poolBasedSize := poolSize + result.HandoffQueueSize = util.Max(workerBasedSize, poolBasedSize) + if c.HandoffQueueSize > 0 { + // When explicitly set: enforce minimum of 200 + result.HandoffQueueSize = util.Max(200, c.HandoffQueueSize) + } + + // Cap queue size: use maxActiveConns+1 if set, otherwise 5x pool size + var queueCap int + if maxActiveConns > 0 { + queueCap = maxActiveConns + 1 + // Ensure queue cap is at least 2 for very small maxActiveConns + if queueCap < 2 { + queueCap = 2 + } + } else { + queueCap = poolSize * 5 + } + result.HandoffQueueSize = util.Min(result.HandoffQueueSize, queueCap) + + // Ensure minimum queue size of 2 (fallback for very small pools) + if result.HandoffQueueSize < 2 { + result.HandoffQueueSize = 2 + } + + result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2 + if c.PostHandoffRelaxedDuration > 0 { + result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration + } + + // LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set + // We'll use the provided value as-is, since 0 is valid + result.LogLevel = c.LogLevel + + // Apply defaults for configuration fields + result.MaxHandoffRetries = defaults.MaxHandoffRetries + if c.MaxHandoffRetries > 0 { + result.MaxHandoffRetries = c.MaxHandoffRetries + } + + // Circuit breaker configuration + result.CircuitBreakerFailureThreshold = defaults.CircuitBreakerFailureThreshold + if c.CircuitBreakerFailureThreshold > 0 { + result.CircuitBreakerFailureThreshold = c.CircuitBreakerFailureThreshold + } + + result.CircuitBreakerResetTimeout = defaults.CircuitBreakerResetTimeout + if c.CircuitBreakerResetTimeout > 0 { + result.CircuitBreakerResetTimeout = c.CircuitBreakerResetTimeout + } + + result.CircuitBreakerMaxRequests = defaults.CircuitBreakerMaxRequests + if c.CircuitBreakerMaxRequests > 0 { + result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests + } + + if result.LogLevel.DebugOrAbove() { + internal.Logger.Printf(context.Background(), "hitless: debug logging enabled") + internal.Logger.Printf(context.Background(), "hitless: config: %+v", result) + } + return result +} + +// Clone creates a deep copy of the configuration. 
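+// Mutating the clone does not affect the original, for example:
+//
+//	base := DefaultConfig()
+//	custom := base.Clone()
+//	custom.MaxHandoffRetries = 5 // base.MaxHandoffRetries is still 3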
+func (c *Config) Clone() *Config {
+	if c == nil {
+		return DefaultConfig()
+	}
+
+	return &Config{
+		Mode:                       c.Mode,
+		EndpointType:               c.EndpointType,
+		RelaxedTimeout:             c.RelaxedTimeout,
+		HandoffTimeout:             c.HandoffTimeout,
+		MaxWorkers:                 c.MaxWorkers,
+		HandoffQueueSize:           c.HandoffQueueSize,
+		PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
+		LogLevel:                   c.LogLevel,
+
+		// Circuit breaker configuration
+		CircuitBreakerFailureThreshold: c.CircuitBreakerFailureThreshold,
+		CircuitBreakerResetTimeout:     c.CircuitBreakerResetTimeout,
+		CircuitBreakerMaxRequests:      c.CircuitBreakerMaxRequests,
+
+		// Configuration fields
+		MaxHandoffRetries: c.MaxHandoffRetries,
+	}
+}
+
+// applyWorkerDefaults calculates and applies worker defaults based on pool size
+func (c *Config) applyWorkerDefaults(poolSize int) {
+	// Calculate defaults based on pool size
+	if poolSize <= 0 {
+		poolSize = 10 * runtime.GOMAXPROCS(0)
+	}
+
+	// When not set: min(poolSize/2, max(10, poolSize/3)) - balanced scaling approach
+	originalMaxWorkers := c.MaxWorkers
+	c.MaxWorkers = util.Min(poolSize/2, util.Max(10, poolSize/3))
+	if originalMaxWorkers != 0 {
+		// When explicitly set: max(poolSize/2, set_value) - ensure at least poolSize/2 workers
+		c.MaxWorkers = util.Max(poolSize/2, originalMaxWorkers)
+	}
+
+	// Ensure minimum of 1 worker (fallback for very small pools)
+	if c.MaxWorkers < 1 {
+		c.MaxWorkers = 1
+	}
+}
+
+// DetectEndpointType automatically detects the appropriate endpoint type
+// based on the connection address and TLS configuration.
+//
+// For IP addresses:
+//   - If TLS is enabled: requests FQDN for proper certificate validation
+//   - If TLS is disabled: requests IP for better performance
+//
+// For hostnames:
+//   - Always requests FQDN (internal or external), regardless of TLS;
+//     IP endpoints are only requested for addresses that are already IPs
+//
+// Internal vs External detection:
+//   - For IPs: uses private IP range detection
+//   - For hostnames: uses heuristics based on common internal naming patterns
+func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
+	// Extract host from "host:port" format
+	host, _, err := net.SplitHostPort(addr)
+	if err != nil {
+		host = addr // Assume no port
+	}
+
+	// Check if the host is an IP address or hostname
+	ip := net.ParseIP(host)
+	isIPAddress := ip != nil
+	var endpointType EndpointType
+
+	if isIPAddress {
+		// Address is an IP - determine if it's private or public
+		isPrivate := ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
+
+		if tlsEnabled {
+			// TLS with IP addresses - still prefer FQDN for certificate validation
+			if isPrivate {
+				endpointType = EndpointTypeInternalFQDN
+			} else {
+				endpointType = EndpointTypeExternalFQDN
+			}
+		} else {
+			// No TLS - can use IP addresses directly
+			if isPrivate {
+				endpointType = EndpointTypeInternalIP
+			} else {
+				endpointType = EndpointTypeExternalIP
+			}
+		}
+	} else {
+		// Address is a hostname
+		if isInternalHostname(host) {
+			endpointType = EndpointTypeInternalFQDN
+		} else {
+			endpointType = EndpointTypeExternalFQDN
+		}
+	}
+
+	return endpointType
+}
+
+// isInternalHostname determines if a hostname appears to be internal/private.
+// This is a heuristic based on common naming patterns.
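+// For example, "localhost", "redis-1" (no dots) and "cache.corp" are treated
+// as internal, while "redis.example.com" is treated as external.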
+func isInternalHostname(hostname string) bool { + // Convert to lowercase for comparison + hostname = strings.ToLower(hostname) + + // Common internal hostname patterns + internalPatterns := []string{ + "localhost", + ".local", + ".internal", + ".corp", + ".lan", + ".intranet", + ".private", + } + + // Check for exact match or suffix match + for _, pattern := range internalPatterns { + if hostname == pattern || strings.HasSuffix(hostname, pattern) { + return true + } + } + + // Check for RFC 1918 style hostnames (e.g., redis-1, db-server, etc.) + // If hostname doesn't contain dots, it's likely internal + if !strings.Contains(hostname, ".") { + return true + } + + // Default to external for fully qualified domain names + return false +} diff --git a/hitless/config_test.go b/hitless/config_test.go new file mode 100644 index 00000000..6c74823c --- /dev/null +++ b/hitless/config_test.go @@ -0,0 +1,490 @@ +package hitless + +import ( + "context" + "net" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/logging" +) + +func TestConfig(t *testing.T) { + t.Run("DefaultConfig", func(t *testing.T) { + config := DefaultConfig() + + // MaxWorkers should be 0 in default config (auto-calculated) + if config.MaxWorkers != 0 { + t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers) + } + + // HandoffQueueSize should be 0 in default config (auto-calculated) + if config.HandoffQueueSize != 0 { + t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize) + } + + if config.RelaxedTimeout != 10*time.Second { + t.Errorf("Expected RelaxedTimeout to be 10s, got %v", config.RelaxedTimeout) + } + + // Test configuration fields have proper defaults + if config.MaxHandoffRetries != 3 { + t.Errorf("Expected MaxHandoffRetries to be 3, got %d", config.MaxHandoffRetries) + } + + // Circuit breaker defaults + if config.CircuitBreakerFailureThreshold != 5 { + t.Errorf("Expected CircuitBreakerFailureThreshold=5, got %d", config.CircuitBreakerFailureThreshold) + } + if config.CircuitBreakerResetTimeout != 60*time.Second { + t.Errorf("Expected CircuitBreakerResetTimeout=60s, got %v", config.CircuitBreakerResetTimeout) + } + if config.CircuitBreakerMaxRequests != 3 { + t.Errorf("Expected CircuitBreakerMaxRequests=3, got %d", config.CircuitBreakerMaxRequests) + } + + if config.HandoffTimeout != 15*time.Second { + t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout) + } + + if config.PostHandoffRelaxedDuration != 0 { + t.Errorf("Expected PostHandoffRelaxedDuration to be 0 (auto-calculated), got %v", config.PostHandoffRelaxedDuration) + } + + // Test that defaults are applied correctly + configWithDefaults := config.ApplyDefaultsWithPoolSize(100) + if configWithDefaults.PostHandoffRelaxedDuration != 20*time.Second { + t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout) after applying defaults, got %v", configWithDefaults.PostHandoffRelaxedDuration) + } + }) + + t.Run("ConfigValidation", func(t *testing.T) { + // Valid config with applied defaults + config := DefaultConfig().ApplyDefaults() + if err := config.Validate(); err != nil { + t.Errorf("Default config with applied defaults should be valid: %v", err) + } + + // Invalid worker configuration (negative MaxWorkers) + config = &Config{ + RelaxedTimeout: 30 * time.Second, + HandoffTimeout: 15 * time.Second, + MaxWorkers: -1, // This should be invalid + HandoffQueueSize: 100, + PostHandoffRelaxedDuration: 10 * 
time.Second, + LogLevel: 1, + MaxHandoffRetries: 3, // Add required field + } + if err := config.Validate(); err != ErrInvalidHandoffWorkers { + t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err) + } + + // Invalid HandoffQueueSize + config = DefaultConfig().ApplyDefaults() + config.HandoffQueueSize = -1 + if err := config.Validate(); err != ErrInvalidHandoffQueueSize { + t.Errorf("Expected ErrInvalidHandoffQueueSize, got %v", err) + } + + // Invalid PostHandoffRelaxedDuration + config = DefaultConfig().ApplyDefaults() + config.PostHandoffRelaxedDuration = -1 * time.Second + if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration { + t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err) + } + }) + + t.Run("ConfigClone", func(t *testing.T) { + original := DefaultConfig() + original.MaxWorkers = 20 + original.HandoffQueueSize = 200 + + cloned := original.Clone() + + if cloned.MaxWorkers != 20 { + t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers) + } + + if cloned.HandoffQueueSize != 200 { + t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize) + } + + // Modify original to ensure clone is independent + original.MaxWorkers = 2 + if cloned.MaxWorkers != 20 { + t.Error("Clone should be independent of original") + } + }) +} + +func TestApplyDefaults(t *testing.T) { + t.Run("NilConfig", func(t *testing.T) { + var config *Config + result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // With nil config, should get default config with auto-calculated workers + if result.MaxWorkers <= 0 { + t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers) + } + + // HandoffQueueSize should be auto-calculated with hybrid scaling + workerBasedSize := result.MaxWorkers * 20 + poolSize := 100 // Default pool size used in ApplyDefaults + poolBasedSize := poolSize + expectedQueueSize := util.Max(workerBasedSize, poolBasedSize) + expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size + if result.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d", + expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize) + } + }) + + t.Run("PartialConfig", func(t *testing.T) { + config := &Config{ + MaxWorkers: 60, // Set this field explicitly (> poolSize/2 = 50) + // Leave other fields as zero values + } + + result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // Should keep the explicitly set values when > poolSize/2 + if result.MaxWorkers != 60 { + t.Errorf("Expected MaxWorkers to be 60 (explicitly set), got %d", result.MaxWorkers) + } + + // Should apply default for unset fields (auto-calculated queue size with hybrid scaling) + workerBasedSize := result.MaxWorkers * 20 + poolSize := 100 // Default pool size used in ApplyDefaults + poolBasedSize := poolSize + expectedQueueSize := util.Max(workerBasedSize, poolBasedSize) + expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size + if result.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d", + expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize) + } + + // Test explicit queue size capping by 5x pool size + configWithLargeQueue := &Config{ + 
MaxWorkers: 5, + HandoffQueueSize: 1000, // Much larger than 5x pool size + } + + resultCapped := configWithLargeQueue.ApplyDefaultsWithPoolSize(20) // Small pool size + expectedCap := 20 * 5 // 5x pool size = 100 + if resultCapped.HandoffQueueSize != expectedCap { + t.Errorf("Expected HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedCap, resultCapped.HandoffQueueSize) + } + + // Test explicit queue size minimum enforcement + configWithSmallQueue := &Config{ + MaxWorkers: 5, + HandoffQueueSize: 10, // Below minimum of 200 + } + + resultMinimum := configWithSmallQueue.ApplyDefaultsWithPoolSize(100) // Large pool size + if resultMinimum.HandoffQueueSize != 200 { + t.Errorf("Expected HandoffQueueSize to be enforced minimum (200), got %d", resultMinimum.HandoffQueueSize) + } + + // Test that large explicit values are capped by 5x pool size + configWithVeryLargeQueue := &Config{ + MaxWorkers: 5, + HandoffQueueSize: 1000, // Much larger than 5x pool size + } + + resultVeryLarge := configWithVeryLargeQueue.ApplyDefaultsWithPoolSize(100) // Pool size 100 + expectedVeryLargeCap := 100 * 5 // 5x pool size = 500 + if resultVeryLarge.HandoffQueueSize != expectedVeryLargeCap { + t.Errorf("Expected very large HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedVeryLargeCap, resultVeryLarge.HandoffQueueSize) + } + + if result.RelaxedTimeout != 10*time.Second { + t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout) + } + + if result.HandoffTimeout != 15*time.Second { + t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", result.HandoffTimeout) + } + }) + + t.Run("ZeroValues", func(t *testing.T) { + config := &Config{ + MaxWorkers: 0, // Zero value should get auto-calculated defaults + HandoffQueueSize: 0, // Zero value should get default + RelaxedTimeout: 0, // Zero value should get default + LogLevel: 0, // Zero is valid for LogLevel (errors only) + } + + result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // Zero values should get auto-calculated defaults + if result.MaxWorkers <= 0 { + t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers) + } + + // HandoffQueueSize should be auto-calculated with hybrid scaling + workerBasedSize := result.MaxWorkers * 20 + poolSize := 100 // Default pool size used in ApplyDefaults + poolBasedSize := poolSize + expectedQueueSize := util.Max(workerBasedSize, poolBasedSize) + expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size + if result.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d", + expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize) + } + + if result.RelaxedTimeout != 10*time.Second { + t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout) + } + + // LogLevel 0 should be preserved (it's a valid value) + if result.LogLevel != 0 { + t.Errorf("Expected LogLevel to be 0 (preserved), got %d", result.LogLevel) + } + }) +} + +func TestProcessorWithConfig(t *testing.T) { + t.Run("ProcessorUsesConfigValues", func(t *testing.T) { + config := &Config{ + MaxWorkers: 5, + HandoffQueueSize: 50, + RelaxedTimeout: 10 * time.Second, + HandoffTimeout: 5 * time.Second, + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := 
NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // The processor should be created successfully with custom config + if processor == nil { + t.Error("Processor should be created with custom config") + } + }) + + t.Run("ProcessorWithPartialConfig", func(t *testing.T) { + config := &Config{ + MaxWorkers: 7, // Only set worker field + // Other fields will get defaults + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // Should work with partial config (defaults applied) + if processor == nil { + t.Error("Processor should be created with partial config") + } + }) + + t.Run("ProcessorWithNilConfig", func(t *testing.T) { + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + defer processor.Shutdown(context.Background()) + + // Should use default config when nil is passed + if processor == nil { + t.Error("Processor should be created with nil config (using defaults)") + } + }) +} + +func TestIntegrationWithApplyDefaults(t *testing.T) { + t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) { + // Create a partial config with only some fields set + partialConfig := &Config{ + MaxWorkers: 15, // Custom value (>= 10 to test preservation) + LogLevel: logging.LogLevelInfo, // Custom value + // Other fields left as zero values - should get defaults + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + // Create processor - should apply defaults to missing fields + processor := NewPoolHook(baseDialer, "tcp", partialConfig, nil) + defer processor.Shutdown(context.Background()) + + // Processor should be created successfully + if processor == nil { + t.Error("Processor should be created with partial config") + } + + // Test that the ApplyDefaults method worked correctly by creating the same config + // and applying defaults manually + expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // Should preserve custom values (when >= poolSize/2) + if expectedConfig.MaxWorkers != 50 { // max(poolSize/2, 15) = max(50, 15) = 50 + t.Errorf("Expected MaxWorkers to be 50, got %d", expectedConfig.MaxWorkers) + } + + if expectedConfig.LogLevel != 2 { + t.Errorf("Expected LogLevel to be 2, got %d", expectedConfig.LogLevel) + } + + // Should apply defaults for missing fields (auto-calculated queue size with hybrid scaling) + workerBasedSize := expectedConfig.MaxWorkers * 20 + poolSize := 100 // Default pool size used in ApplyDefaults + poolBasedSize := poolSize + expectedQueueSize := util.Max(workerBasedSize, poolBasedSize) + expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size + if expectedConfig.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d", + expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, expectedConfig.HandoffQueueSize) + } + + // Test that queue size is always capped by 5x pool size + if expectedConfig.HandoffQueueSize > poolSize*5 { + t.Errorf("HandoffQueueSize (%d) should never exceed 5x pool size (%d)", + expectedConfig.HandoffQueueSize, 
poolSize*2) + } + + if expectedConfig.RelaxedTimeout != 10*time.Second { + t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", expectedConfig.RelaxedTimeout) + } + + if expectedConfig.HandoffTimeout != 15*time.Second { + t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout) + } + + if expectedConfig.PostHandoffRelaxedDuration != 20*time.Second { + t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout), got %v", expectedConfig.PostHandoffRelaxedDuration) + } + }) +} + +func TestEnhancedConfigValidation(t *testing.T) { + t.Run("ValidateFields", func(t *testing.T) { + config := DefaultConfig() + config.ApplyDefaultsWithPoolSize(100) // Apply defaults with pool size 100 + + // Should pass validation with default values + if err := config.Validate(); err != nil { + t.Errorf("Default config should be valid, got error: %v", err) + } + + // Test invalid MaxHandoffRetries + config.MaxHandoffRetries = 0 + if err := config.Validate(); err == nil { + t.Error("Expected validation error for MaxHandoffRetries = 0") + } + config.MaxHandoffRetries = 11 + if err := config.Validate(); err == nil { + t.Error("Expected validation error for MaxHandoffRetries = 11") + } + config.MaxHandoffRetries = 3 // Reset to valid value + + // Test circuit breaker validation + config.CircuitBreakerFailureThreshold = 0 + if err := config.Validate(); err != ErrInvalidCircuitBreakerFailureThreshold { + t.Errorf("Expected ErrInvalidCircuitBreakerFailureThreshold, got %v", err) + } + config.CircuitBreakerFailureThreshold = 5 // Reset to valid value + + config.CircuitBreakerResetTimeout = -1 * time.Second + if err := config.Validate(); err != ErrInvalidCircuitBreakerResetTimeout { + t.Errorf("Expected ErrInvalidCircuitBreakerResetTimeout, got %v", err) + } + config.CircuitBreakerResetTimeout = 60 * time.Second // Reset to valid value + + config.CircuitBreakerMaxRequests = 0 + if err := config.Validate(); err != ErrInvalidCircuitBreakerMaxRequests { + t.Errorf("Expected ErrInvalidCircuitBreakerMaxRequests, got %v", err) + } + config.CircuitBreakerMaxRequests = 3 // Reset to valid value + + // Should pass validation again + if err := config.Validate(); err != nil { + t.Errorf("Config should be valid after reset, got error: %v", err) + } + }) +} + +func TestConfigClone(t *testing.T) { + original := DefaultConfig() + original.MaxHandoffRetries = 7 + original.HandoffTimeout = 8 * time.Second + + cloned := original.Clone() + + // Test that values are copied + if cloned.MaxHandoffRetries != 7 { + t.Errorf("Expected cloned MaxHandoffRetries to be 7, got %d", cloned.MaxHandoffRetries) + } + if cloned.HandoffTimeout != 8*time.Second { + t.Errorf("Expected cloned HandoffTimeout to be 8s, got %v", cloned.HandoffTimeout) + } + + // Test that modifying clone doesn't affect original + cloned.MaxHandoffRetries = 10 + if original.MaxHandoffRetries != 7 { + t.Errorf("Modifying clone should not affect original, original MaxHandoffRetries changed to %d", original.MaxHandoffRetries) + } +} + +func TestMaxWorkersLogic(t *testing.T) { + t.Run("AutoCalculatedMaxWorkers", func(t *testing.T) { + testCases := []struct { + poolSize int + expectedWorkers int + description string + }{ + {6, 3, "Small pool: min(6/2, max(10, 6/3)) = min(3, max(10, 2)) = min(3, 10) = 3"}, + {15, 7, "Medium pool: min(15/2, max(10, 15/3)) = min(7, max(10, 5)) = min(7, 10) = 7"}, + {30, 10, "Large pool: min(30/2, max(10, 30/3)) = min(15, max(10, 10)) = min(15, 10) = 10"}, + {60, 20, "Very large pool: min(60/2, 
max(10, 60/3)) = min(30, max(10, 20)) = min(30, 20) = 20"}, + {120, 40, "Huge pool: min(120/2, max(10, 120/3)) = min(60, max(10, 40)) = min(60, 40) = 40"}, + } + + for _, tc := range testCases { + config := &Config{} // MaxWorkers = 0 (not set) + result := config.ApplyDefaultsWithPoolSize(tc.poolSize) + + if result.MaxWorkers != tc.expectedWorkers { + t.Errorf("PoolSize=%d: expected MaxWorkers=%d, got %d (%s)", + tc.poolSize, tc.expectedWorkers, result.MaxWorkers, tc.description) + } + } + }) + + t.Run("ExplicitlySetMaxWorkers", func(t *testing.T) { + testCases := []struct { + setValue int + expectedWorkers int + description string + }{ + {1, 50, "Set 1: max(poolSize/2, 1) = max(50, 1) = 50 (enforced minimum)"}, + {5, 50, "Set 5: max(poolSize/2, 5) = max(50, 5) = 50 (enforced minimum)"}, + {8, 50, "Set 8: max(poolSize/2, 8) = max(50, 8) = 50 (enforced minimum)"}, + {10, 50, "Set 10: max(poolSize/2, 10) = max(50, 10) = 50 (enforced minimum)"}, + {15, 50, "Set 15: max(poolSize/2, 15) = max(50, 15) = 50 (enforced minimum)"}, + {60, 60, "Set 60: max(poolSize/2, 60) = max(50, 60) = 60 (respects user choice)"}, + } + + for _, tc := range testCases { + config := &Config{ + MaxWorkers: tc.setValue, // Explicitly set + } + result := config.ApplyDefaultsWithPoolSize(100) // Pool size doesn't affect explicit values + + if result.MaxWorkers != tc.expectedWorkers { + t.Errorf("Set MaxWorkers=%d: expected %d, got %d (%s)", + tc.setValue, tc.expectedWorkers, result.MaxWorkers, tc.description) + } + } + }) +} diff --git a/hitless/errors.go b/hitless/errors.go new file mode 100644 index 00000000..7f8ab4c7 --- /dev/null +++ b/hitless/errors.go @@ -0,0 +1,105 @@ +package hitless + +import ( + "errors" + "fmt" + "time" +) + +// Configuration errors +var ( + ErrInvalidRelaxedTimeout = errors.New("hitless: relaxed timeout must be greater than 0") + ErrInvalidHandoffTimeout = errors.New("hitless: handoff timeout must be greater than 0") + ErrInvalidHandoffWorkers = errors.New("hitless: MaxWorkers must be greater than or equal to 0") + ErrInvalidHandoffQueueSize = errors.New("hitless: handoff queue size must be greater than 0") + ErrInvalidPostHandoffRelaxedDuration = errors.New("hitless: post-handoff relaxed duration must be greater than or equal to 0") + ErrInvalidLogLevel = errors.New("hitless: log level must be LogLevelError (0), LogLevelWarn (1), LogLevelInfo (2), or LogLevelDebug (3)") + ErrInvalidEndpointType = errors.New("hitless: invalid endpoint type") + ErrInvalidMaintNotifications = errors.New("hitless: invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')") + ErrMaxHandoffRetriesReached = errors.New("hitless: max handoff retries reached") + + // Configuration validation errors + ErrInvalidHandoffRetries = errors.New("hitless: MaxHandoffRetries must be between 1 and 10") +) + +// Integration errors +var ( + ErrInvalidClient = errors.New("hitless: invalid client type") +) + +// Handoff errors +var ( + ErrHandoffQueueFull = errors.New("hitless: handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration") +) + +// Notification errors +var ( + ErrInvalidNotification = errors.New("hitless: invalid notification format") +) + +// connection handoff errors +var ( + // ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff + // and should not be used until the handoff is complete + ErrConnectionMarkedForHandoff = errors.New("hitless: connection marked for handoff") + // 
ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff + ErrConnectionInvalidHandoffState = errors.New("hitless: connection is in invalid state for handoff") +) + +// general errors +var ( + ErrShutdown = errors.New("hitless: shutdown") +) + +// circuit breaker errors +var ( + ErrCircuitBreakerOpen = errors.New("hitless: circuit breaker is open, failing fast") +) + +// CircuitBreakerError provides detailed context for circuit breaker failures +type CircuitBreakerError struct { + Endpoint string + State string + Failures int64 + LastFailure time.Time + NextAttempt time.Time + Message string +} + +func (e *CircuitBreakerError) Error() string { + if e.NextAttempt.IsZero() { + return fmt.Sprintf("hitless: circuit breaker %s for %s (failures: %d, last: %v): %s", + e.State, e.Endpoint, e.Failures, e.LastFailure, e.Message) + } + return fmt.Sprintf("hitless: circuit breaker %s for %s (failures: %d, last: %v, next attempt: %v): %s", + e.State, e.Endpoint, e.Failures, e.LastFailure, e.NextAttempt, e.Message) +} + +// HandoffError provides detailed context for connection handoff failures +type HandoffError struct { + ConnectionID uint64 + SourceEndpoint string + TargetEndpoint string + Attempt int + MaxAttempts int + Duration time.Duration + FinalError error + Message string +} + +func (e *HandoffError) Error() string { + return fmt.Sprintf("hitless: handoff failed for conn[%d] %s→%s (attempt %d/%d, duration: %v): %s", + e.ConnectionID, e.SourceEndpoint, e.TargetEndpoint, + e.Attempt, e.MaxAttempts, e.Duration, e.Message) +} + +func (e *HandoffError) Unwrap() error { + return e.FinalError +} + +// circuit breaker configuration errors +var ( + ErrInvalidCircuitBreakerFailureThreshold = errors.New("hitless: circuit breaker failure threshold must be >= 1") + ErrInvalidCircuitBreakerResetTimeout = errors.New("hitless: circuit breaker reset timeout must be >= 0") + ErrInvalidCircuitBreakerMaxRequests = errors.New("hitless: circuit breaker max requests must be >= 1") +) diff --git a/hitless/example_hooks.go b/hitless/example_hooks.go new file mode 100644 index 00000000..54e28b3c --- /dev/null +++ b/hitless/example_hooks.go @@ -0,0 +1,100 @@ +package hitless + +import ( + "context" + "fmt" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// contextKey is a custom type for context keys to avoid collisions +type contextKey string + +const ( + startTimeKey contextKey = "notif_hitless_start_time" +) + +// MetricsHook collects metrics about notification processing. +type MetricsHook struct { + NotificationCounts map[string]int64 + ProcessingTimes map[string]time.Duration + ErrorCounts map[string]int64 + HandoffCounts int64 // Total handoffs initiated + HandoffSuccesses int64 // Successful handoffs + HandoffFailures int64 // Failed handoffs +} + +// NewMetricsHook creates a new metrics collection hook. +func NewMetricsHook() *MetricsHook { + return &MetricsHook{ + NotificationCounts: make(map[string]int64), + ProcessingTimes: make(map[string]time.Duration), + ErrorCounts: make(map[string]int64), + } +} + +// PreHook records the start time for processing metrics. 
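+//
+// A wiring sketch for the hook as a whole (manager is assumed to be an
+// initialized *HitlessManager):
+//
+//	hook := NewMetricsHook()
+//	manager.AddNotificationHook(hook)
+//	// ... after some notifications have been processed ...
+//	fmt.Println(hook.GetMetrics())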
+func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) { + mh.NotificationCounts[notificationType]++ + + // Log connection information if available + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + internal.Logger.Printf(ctx, "hitless: metrics hook processing %s notification on conn[%d]", notificationType, conn.GetID()) + } + + // Store start time in context for duration calculation + startTime := time.Now() + _ = context.WithValue(ctx, startTimeKey, startTime) // Context not used further + + return notification, true +} + +// PostHook records processing completion and any errors. +func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) { + // Calculate processing duration + if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok { + duration := time.Since(startTime) + mh.ProcessingTimes[notificationType] = duration + } + + // Record errors + if result != nil { + mh.ErrorCounts[notificationType]++ + + // Log error details with connection information + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + internal.Logger.Printf(ctx, "hitless: metrics hook recorded error for %s notification on conn[%d]: %v", notificationType, conn.GetID(), result) + } + } +} + +// GetMetrics returns a summary of collected metrics. +func (mh *MetricsHook) GetMetrics() map[string]interface{} { + return map[string]interface{}{ + "notification_counts": mh.NotificationCounts, + "processing_times": mh.ProcessingTimes, + "error_counts": mh.ErrorCounts, + } +} + +// ExampleCircuitBreakerMonitor demonstrates how to monitor circuit breaker status +func ExampleCircuitBreakerMonitor(poolHook *PoolHook) { + // Get circuit breaker statistics + stats := poolHook.GetCircuitBreakerStats() + + for _, stat := range stats { + fmt.Printf("Circuit Breaker for %s:\n", stat.Endpoint) + fmt.Printf(" State: %s\n", stat.State) + fmt.Printf(" Failures: %d\n", stat.Failures) + fmt.Printf(" Last Failure: %v\n", stat.LastFailureTime) + fmt.Printf(" Last Success: %v\n", stat.LastSuccessTime) + + // Alert if circuit breaker is open + if stat.State.String() == "open" { + fmt.Printf(" ⚠️ ALERT: Circuit breaker is OPEN for %s\n", stat.Endpoint) + } + } +} diff --git a/hitless/handoff_worker.go b/hitless/handoff_worker.go new file mode 100644 index 00000000..ae22b684 --- /dev/null +++ b/hitless/handoff_worker.go @@ -0,0 +1,455 @@ +package hitless + +import ( + "context" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" +) + +// handoffWorkerManager manages background workers and queue for connection handoffs +type handoffWorkerManager struct { + // Event-driven handoff support + handoffQueue chan HandoffRequest // Queue for handoff requests + shutdown chan struct{} // Shutdown signal + shutdownOnce sync.Once // Ensure clean shutdown + workerWg sync.WaitGroup // Track worker goroutines + + // On-demand worker management + maxWorkers int + activeWorkers atomic.Int32 + workerTimeout time.Duration // How long workers wait for work before exiting + workersScaling atomic.Bool + + // Simple state tracking + pending sync.Map // map[uint64]int64 (connID -> seqID) + + // Configuration for the hitless upgrade + config *Config + + // Pool hook reference for handoff processing + poolHook *PoolHook + + // Circuit breaker manager 
for endpoint failure handling + circuitBreakerManager *CircuitBreakerManager +} + +// newHandoffWorkerManager creates a new handoff worker manager +func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager { + return &handoffWorkerManager{ + handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize), + shutdown: make(chan struct{}), + maxWorkers: config.MaxWorkers, + activeWorkers: atomic.Int32{}, // Start with no workers - create on demand + workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity + config: config, + poolHook: poolHook, + circuitBreakerManager: newCircuitBreakerManager(config), + } +} + +// getCurrentWorkers returns the current number of active workers (for testing) +func (hwm *handoffWorkerManager) getCurrentWorkers() int { + return int(hwm.activeWorkers.Load()) +} + +// getPendingMap returns the pending map for testing purposes +func (hwm *handoffWorkerManager) getPendingMap() *sync.Map { + return &hwm.pending +} + +// getMaxWorkers returns the max workers for testing purposes +func (hwm *handoffWorkerManager) getMaxWorkers() int { + return hwm.maxWorkers +} + +// getHandoffQueue returns the handoff queue for testing purposes +func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest { + return hwm.handoffQueue +} + +// getCircuitBreakerStats returns circuit breaker statistics for monitoring +func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats { + return hwm.circuitBreakerManager.GetAllStats() +} + +// resetCircuitBreakers resets all circuit breakers (useful for testing) +func (hwm *handoffWorkerManager) resetCircuitBreakers() { + hwm.circuitBreakerManager.Reset() +} + +// isHandoffPending returns true if the given connection has a pending handoff +func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool { + _, pending := hwm.pending.Load(conn.GetID()) + return pending +} + +// ensureWorkerAvailable ensures at least one worker is available to process requests +// Creates a new worker if needed and under the max limit +func (hwm *handoffWorkerManager) ensureWorkerAvailable() { + select { + case <-hwm.shutdown: + return + default: + if hwm.workersScaling.CompareAndSwap(false, true) { + defer hwm.workersScaling.Store(false) + // Check if we need a new worker + currentWorkers := hwm.activeWorkers.Load() + workersWas := currentWorkers + for currentWorkers < int32(hwm.maxWorkers) { + hwm.workerWg.Add(1) + go hwm.onDemandWorker() + currentWorkers++ + } + // workersWas is always <= currentWorkers + // currentWorkers will be maxWorkers, but if we have a worker that was closed + // while we were creating new workers, just add the difference between + // the currentWorkers and the number of workers we observed initially (i.e. 
the number of workers we created) + hwm.activeWorkers.Add(currentWorkers - workersWas) + } + } +} + +// onDemandWorker processes handoff requests and exits when idle +func (hwm *handoffWorkerManager) onDemandWorker() { + defer func() { + // Decrement active worker count when exiting + hwm.activeWorkers.Add(-1) + hwm.workerWg.Done() + }() + + // Create reusable timer to prevent timer leaks + timer := time.NewTimer(hwm.workerTimeout) + defer timer.Stop() + + for { + // Reset timer for next iteration + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(hwm.workerTimeout) + + select { + case <-hwm.shutdown: + return + case <-timer.C: + // Worker has been idle for too long, exit to save resources + if hwm.config != nil && hwm.config.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), + "hitless: worker exiting due to inactivity timeout (%v)", hwm.workerTimeout) + } + return + case request := <-hwm.handoffQueue: + // Check for shutdown before processing + select { + case <-hwm.shutdown: + // Clean up the request before exiting + hwm.pending.Delete(request.ConnID) + return + default: + // Process the request + hwm.processHandoffRequest(request) + } + } + } +} + +// processHandoffRequest processes a single handoff request +func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) { + // Remove from pending map + defer hwm.pending.Delete(request.Conn.GetID()) + internal.Logger.Printf(context.Background(), "hitless: conn[%d] Processing handoff request start", request.Conn.GetID()) + + // Create a context with handoff timeout from config + handoffTimeout := 15 * time.Second // Default timeout + if hwm.config != nil && hwm.config.HandoffTimeout > 0 { + handoffTimeout = hwm.config.HandoffTimeout + } + ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout) + defer cancel() + + // Create a context that also respects the shutdown signal + shutdownCtx, shutdownCancel := context.WithCancel(ctx) + defer shutdownCancel() + + // Monitor shutdown signal in a separate goroutine + go func() { + select { + case <-hwm.shutdown: + shutdownCancel() + case <-shutdownCtx.Done(): + } + }() + + // Perform the handoff with cancellable context + shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn) + minRetryBackoff := 500 * time.Millisecond + if err != nil { + if shouldRetry { + now := time.Now() + deadline, ok := shutdownCtx.Deadline() + thirdOfTimeout := handoffTimeout / 3 + if !ok || deadline.Before(now) { + // wait half the timeout before retrying if no deadline or deadline has passed + deadline = now.Add(thirdOfTimeout) + } + afterTime := deadline.Sub(now) + if afterTime < minRetryBackoff { + afterTime = minRetryBackoff + } + + internal.Logger.Printf(context.Background(), "Handoff failed for conn[%d] WILL RETRY After %v: %v", request.ConnID, afterTime, err) + time.AfterFunc(afterTime, func() { + if err := hwm.queueHandoff(request.Conn); err != nil { + internal.Logger.Printf(context.Background(), "can't queue handoff for retry: %v", err) + hwm.closeConnFromRequest(context.Background(), request, err) + } + }) + return + } else { + go hwm.closeConnFromRequest(ctx, request, err) + } + + // Clear handoff state if not returned for retry + seqID := request.Conn.GetMovingSeqID() + connID := request.Conn.GetID() + if hwm.poolHook.hitlessManager != nil { + hwm.poolHook.hitlessManager.UntrackOperationWithConnID(seqID, connID) + } + } +} + +// queueHandoff queues a handoff request for processing +// if err is returned, 
connection will be removed from pool +func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error { + // Create handoff request + request := HandoffRequest{ + Conn: conn, + ConnID: conn.GetID(), + Endpoint: conn.GetHandoffEndpoint(), + SeqID: conn.GetMovingSeqID(), + Pool: hwm.poolHook.pool, // Include pool for connection removal on failure + } + + select { + // priority to shutdown + case <-hwm.shutdown: + return ErrShutdown + default: + select { + case <-hwm.shutdown: + return ErrShutdown + case hwm.handoffQueue <- request: + // Store in pending map + hwm.pending.Store(request.ConnID, request.SeqID) + // Ensure we have a worker to process this request + hwm.ensureWorkerAvailable() + return nil + default: + select { + case <-hwm.shutdown: + return ErrShutdown + case hwm.handoffQueue <- request: + // Store in pending map + hwm.pending.Store(request.ConnID, request.SeqID) + // Ensure we have a worker to process this request + hwm.ensureWorkerAvailable() + return nil + case <-time.After(100 * time.Millisecond): // give workers a chance to process + // Queue is full - log and attempt scaling + queueLen := len(hwm.handoffQueue) + queueCap := cap(hwm.handoffQueue) + if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level + internal.Logger.Printf(context.Background(), + "hitless: handoff queue is full (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration", + queueLen, queueCap) + } + } + } + } + + // Ensure we have workers available to handle the load + hwm.ensureWorkerAvailable() + return ErrHandoffQueueFull +} + +// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete +func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error { + hwm.shutdownOnce.Do(func() { + close(hwm.shutdown) + // workers will exit when they finish their current request + + // Shutdown circuit breaker manager cleanup goroutine + if hwm.circuitBreakerManager != nil { + hwm.circuitBreakerManager.Shutdown() + } + }) + + // Wait for workers to complete + done := make(chan struct{}) + go func() { + hwm.workerWg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// performConnectionHandoff performs the actual connection handoff +// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached +func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) { + // Clear handoff state after successful handoff + connID := conn.GetID() + + newEndpoint := conn.GetHandoffEndpoint() + if newEndpoint == "" { + return false, ErrConnectionInvalidHandoffState + } + + // Use circuit breaker to protect against failing endpoints + circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint) + + // Check if circuit breaker is open before attempting handoff + if circuitBreaker.IsOpen() { + internal.Logger.Printf(ctx, "hitless: conn[%d] handoff to %s failed fast due to circuit breaker", connID, newEndpoint) + return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open + } + + // Perform the handoff + shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID) + + // Update circuit breaker based on result + if err != nil { + // Only track dial/network errors in circuit breaker, not initialization errors + if shouldRetry { + circuitBreaker.recordFailure() + } + return shouldRetry, err 
+ } + + // Success - record in circuit breaker + circuitBreaker.recordSuccess() + return false, nil +} + +// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration) +func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) { + + retries := conn.IncrementAndGetHandoffRetries(1) + internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", connID, retries, newEndpoint, conn.RemoteAddr().String()) + maxRetries := 3 // Default fallback + if hwm.config != nil { + maxRetries = hwm.config.MaxHandoffRetries + } + + if retries > maxRetries { + if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level + internal.Logger.Printf(ctx, + "hitless: reached max retries (%d) for handoff of conn[%d] to %s", + maxRetries, connID, newEndpoint) + } + // won't retry on ErrMaxHandoffRetriesReached + return false, ErrMaxHandoffRetriesReached + } + + // Create endpoint-specific dialer + endpointDialer := hwm.createEndpointDialer(newEndpoint) + + // Create new connection to the new endpoint + newNetConn, err := endpointDialer(ctx) + if err != nil { + internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", connID, newEndpoint, err) + // hitless: will retry + // Maybe a network error - retry after a delay + return true, err + } + + // Get the old connection + oldConn := conn.GetNetConn() + + // Apply relaxed timeout to the new connection for the configured post-handoff duration + // This gives the new connection more time to handle operations during cluster transition + // Setting this here (before initing the connection) ensures that the connection is going + // to use the relaxed timeout for the first operation (auth/ACL select) + if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 { + relaxedTimeout := hwm.config.RelaxedTimeout + // Set relaxed timeout with deadline - no background goroutine needed + deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration) + conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline) + + if hwm.config.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), + "hitless: conn[%d] applied post-handoff relaxed timeout (%v) until %v", + connID, relaxedTimeout, deadline.Format("15:04:05.000")) + } + } + + // Replace the connection and execute initialization + err = conn.SetNetConnAndInitConn(ctx, newNetConn) + if err != nil { + // hitless: won't retry + // Initialization failed - remove the connection + return false, err + } + defer func() { + if oldConn != nil { + oldConn.Close() + } + }() + + conn.ClearHandoffState() + internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", connID, newEndpoint) + + return false, nil +} + +// createEndpointDialer creates a dialer function that connects to a specific endpoint +func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) { + return func(ctx context.Context) (net.Conn, error) { + // Parse endpoint to extract host and port + host, port, err := net.SplitHostPort(endpoint) + if err != nil { + // If no port specified, assume default Redis port + host = endpoint + if port == "" { + port = "6379" + } + } + + // Use the base dialer to connect to the new endpoint + return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port)) + } +} + +// closeConnFromRequest closes the 
connection and logs the reason +func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) { + pooler := request.Pool + conn := request.Conn + if pooler != nil { + pooler.Remove(ctx, conn, err) + if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level + internal.Logger.Printf(ctx, + "hitless: removed conn[%d] from pool due: %v", + conn.GetID(), err) + } + } else { + conn.Close() + if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level + internal.Logger.Printf(ctx, + "hitless: no pool provided for conn[%d], cannot remove due to: %v", + conn.GetID(), err) + } + } +} diff --git a/hitless/hitless_manager.go b/hitless/hitless_manager.go new file mode 100644 index 00000000..bb0c35d8 --- /dev/null +++ b/hitless/hitless_manager.go @@ -0,0 +1,318 @@ +package hitless + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/interfaces" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// Push notification type constants for hitless upgrades +const ( + NotificationMoving = "MOVING" + NotificationMigrating = "MIGRATING" + NotificationMigrated = "MIGRATED" + NotificationFailingOver = "FAILING_OVER" + NotificationFailedOver = "FAILED_OVER" +) + +// hitlessNotificationTypes contains all notification types that hitless upgrades handles +var hitlessNotificationTypes = []string{ + NotificationMoving, + NotificationMigrating, + NotificationMigrated, + NotificationFailingOver, + NotificationFailedOver, +} + +// NotificationHook is called before and after notification processing +// PreHook can modify the notification and return false to skip processing +// PostHook is called after successful processing +type NotificationHook interface { + PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) + PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) +} + +// MovingOperationKey provides a unique key for tracking MOVING operations +// that combines sequence ID with connection identifier to handle duplicate +// sequence IDs across multiple connections to the same node. +type MovingOperationKey struct { + SeqID int64 // Sequence ID from MOVING notification + ConnID uint64 // Unique connection identifier +} + +// String returns a string representation of the key for debugging +func (k MovingOperationKey) String() string { + return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID) +} + +// HitlessManager provides a simplified hitless upgrade functionality with hooks and atomic state. +type HitlessManager struct { + client interfaces.ClientInterface + config *Config + options interfaces.OptionsInterface + pool pool.Pooler + + // MOVING operation tracking - using sync.Map for better concurrent performance + activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation + + // Atomic state tracking - no locks needed for state queries + activeOperationCount atomic.Int64 // Number of active operations + closed atomic.Bool // Manager closed state + + // Notification hooks for extensibility + hooks []NotificationHook + hooksMu sync.RWMutex // Protects hooks slice + poolHooksRef *PoolHook +} + +// MovingOperation tracks an active MOVING operation. 
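+// Active operations are keyed by MovingOperationKey (sequence ID plus
+// connection ID); a snapshot can be inspected, for example:
+//
+//	for key, op := range hm.GetActiveMovingOperations() {
+//		fmt.Printf("%s -> %s (deadline %s)\n", key, op.NewEndpoint, op.Deadline)
+//	}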
+type MovingOperation struct { + SeqID int64 + NewEndpoint string + StartTime time.Time + Deadline time.Time +} + +// NewHitlessManager creates a new simplified hitless manager. +func NewHitlessManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*HitlessManager, error) { + if client == nil { + return nil, ErrInvalidClient + } + + hm := &HitlessManager{ + client: client, + pool: pool, + options: client.GetOptions(), + config: config.Clone(), + hooks: make([]NotificationHook, 0), + } + + // Set up push notification handling + if err := hm.setupPushNotifications(); err != nil { + return nil, err + } + + return hm, nil +} + +// GetPoolHook creates a pool hook with a custom dialer. +func (hm *HitlessManager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) { + poolHook := hm.createPoolHook(baseDialer) + hm.pool.AddPoolHook(poolHook) +} + +// setupPushNotifications sets up push notification handling by registering with the client's processor. +func (hm *HitlessManager) setupPushNotifications() error { + processor := hm.client.GetPushProcessor() + if processor == nil { + return ErrInvalidClient // Client doesn't support push notifications + } + + // Create our notification handler + handler := &NotificationHandler{manager: hm} + + // Register handlers for all hitless upgrade notifications with the client's processor + for _, notificationType := range hitlessNotificationTypes { + if err := processor.RegisterHandler(notificationType, handler, true); err != nil { + return fmt.Errorf("failed to register handler for %s: %w", notificationType, err) + } + } + + return nil +} + +// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID. +func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error { + // Create composite key + key := MovingOperationKey{ + SeqID: seqID, + ConnID: connID, + } + + // Create MOVING operation record + movingOp := &MovingOperation{ + SeqID: seqID, + NewEndpoint: newEndpoint, + StartTime: time.Now(), + Deadline: deadline, + } + + // Use LoadOrStore for atomic check-and-set operation + if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded { + // Duplicate MOVING notification, ignore + if hm.config.LogLevel.DebugOrAbove() { // Debug level + internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Duplicate MOVING operation ignored: %s", connID, seqID, key.String()) + } + return nil + } + if hm.config.LogLevel.DebugOrAbove() { // Debug level + internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Tracking MOVING operation: %s", connID, seqID, key.String()) + } + + // Increment active operation count atomically + hm.activeOperationCount.Add(1) + + return nil +} + +// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID. 
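+// It is the counterpart of TrackMovingOperationWithConnID and uses the same
+// (seqID, connID) pair as the key. For example (ctx, deadline, seqID and
+// connID as in the surrounding call sites):
+//
+//	_ = hm.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, seqID, connID)
+//	// ... handoff finishes or is abandoned ...
+//	hm.UntrackOperationWithConnID(seqID, connID)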
+func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64) { + // Create composite key + key := MovingOperationKey{ + SeqID: seqID, + ConnID: connID, + } + + // Remove from active operations atomically + if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded { + if hm.config.LogLevel.DebugOrAbove() { // Debug level + internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Untracking MOVING operation: %s", connID, seqID, key.String()) + } + // Decrement active operation count only if operation existed + hm.activeOperationCount.Add(-1) + } else { + if hm.config.LogLevel.DebugOrAbove() { // Debug level + internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Operation not found for untracking: %s", connID, seqID, key.String()) + } + } +} + +// GetActiveMovingOperations returns active operations with composite keys. +// WARNING: This method creates a new map and copies all operations on every call. +// Use sparingly, especially in hot paths or high-frequency logging. +func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation { + result := make(map[MovingOperationKey]*MovingOperation) + + // Iterate over sync.Map to build result + hm.activeMovingOps.Range(func(key, value interface{}) bool { + k := key.(MovingOperationKey) + op := value.(*MovingOperation) + + // Create a copy to avoid sharing references + result[k] = &MovingOperation{ + SeqID: op.SeqID, + NewEndpoint: op.NewEndpoint, + StartTime: op.StartTime, + Deadline: op.Deadline, + } + return true // Continue iteration + }) + + return result +} + +// IsHandoffInProgress returns true if any handoff is in progress. +// Uses atomic counter for lock-free operation. +func (hm *HitlessManager) IsHandoffInProgress() bool { + return hm.activeOperationCount.Load() > 0 +} + +// GetActiveOperationCount returns the number of active operations. +// Uses atomic counter for lock-free operation. +func (hm *HitlessManager) GetActiveOperationCount() int64 { + return hm.activeOperationCount.Load() +} + +// Close closes the hitless manager. +func (hm *HitlessManager) Close() error { + // Use atomic operation for thread-safe close check + if !hm.closed.CompareAndSwap(false, true) { + return nil // Already closed + } + + // Shutdown the pool hook if it exists + if hm.poolHooksRef != nil { + // Use a timeout to prevent hanging indefinitely + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err := hm.poolHooksRef.Shutdown(shutdownCtx) + if err != nil { + // was not able to close pool hook, keep closed state false + hm.closed.Store(false) + return err + } + // Remove the pool hook from the pool + if hm.pool != nil { + hm.pool.RemovePoolHook(hm.poolHooksRef) + } + } + + // Clear all active operations + hm.activeMovingOps.Range(func(key, value interface{}) bool { + hm.activeMovingOps.Delete(key) + return true + }) + + // Reset counter + hm.activeOperationCount.Store(0) + + return nil +} + +// GetState returns current state using atomic counter for lock-free operation. +func (hm *HitlessManager) GetState() State { + if hm.activeOperationCount.Load() > 0 { + return StateMoving + } + return StateIdle +} + +// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing. 
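+// Hooks run in registration order: each PreHook may rewrite the notification
+// before the next hook sees it, and the first PreHook that returns false stops
+// processing. A sketch of one possible call pattern (handlerCtx is an assumed
+// push.NotificationHandlerContext and handleNotification a hypothetical handler):
+//
+//	modified, ok := hm.processPreHooks(ctx, handlerCtx, NotificationMoving, notification)
+//	if ok {
+//		err := handleNotification(modified)
+//		hm.processPostHooks(ctx, handlerCtx, NotificationMoving, modified, err)
+//	}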
+func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) { + hm.hooksMu.RLock() + defer hm.hooksMu.RUnlock() + + currentNotification := notification + + for _, hook := range hm.hooks { + modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification) + if !shouldContinue { + return modifiedNotification, false + } + currentNotification = modifiedNotification + } + + return currentNotification, true +} + +// processPostHooks calls all post-hooks with the processing result. +func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) { + hm.hooksMu.RLock() + defer hm.hooksMu.RUnlock() + + for _, hook := range hm.hooks { + hook.PostHook(ctx, notificationCtx, notificationType, notification, result) + } +} + +// createPoolHook creates a pool hook with this manager already set. +func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook { + if hm.poolHooksRef != nil { + return hm.poolHooksRef + } + // Get pool size from client options for better worker defaults + poolSize := 0 + if hm.options != nil { + poolSize = hm.options.GetPoolSize() + } + + hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize) + hm.poolHooksRef.SetPool(hm.pool) + + return hm.poolHooksRef +} + +func (hm *HitlessManager) AddNotificationHook(notificationHook NotificationHook) { + hm.hooksMu.Lock() + defer hm.hooksMu.Unlock() + hm.hooks = append(hm.hooks, notificationHook) +} diff --git a/hitless/hitless_manager_test.go b/hitless/hitless_manager_test.go new file mode 100644 index 00000000..b1f55bf3 --- /dev/null +++ b/hitless/hitless_manager_test.go @@ -0,0 +1,260 @@ +package hitless + +import ( + "context" + "net" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/interfaces" +) + +// MockClient implements interfaces.ClientInterface for testing +type MockClient struct { + options interfaces.OptionsInterface +} + +func (mc *MockClient) GetOptions() interfaces.OptionsInterface { + return mc.options +} + +func (mc *MockClient) GetPushProcessor() interfaces.NotificationProcessor { + return &MockPushProcessor{} +} + +// MockPushProcessor implements interfaces.NotificationProcessor for testing +type MockPushProcessor struct{} + +func (mpp *MockPushProcessor) RegisterHandler(notificationType string, handler interface{}, protected bool) error { + return nil +} + +func (mpp *MockPushProcessor) UnregisterHandler(pushNotificationName string) error { + return nil +} + +func (mpp *MockPushProcessor) GetHandler(pushNotificationName string) interface{} { + return nil +} + +// MockOptions implements interfaces.OptionsInterface for testing +type MockOptions struct{} + +func (mo *MockOptions) GetReadTimeout() time.Duration { + return 5 * time.Second +} + +func (mo *MockOptions) GetWriteTimeout() time.Duration { + return 5 * time.Second +} + +func (mo *MockOptions) GetAddr() string { + return "localhost:6379" +} + +func (mo *MockOptions) IsTLSEnabled() bool { + return false +} + +func (mo *MockOptions) GetProtocol() int { + return 3 // RESP3 +} + +func (mo *MockOptions) GetPoolSize() int { + return 10 +} + +func (mo *MockOptions) GetNetwork() string { + return "tcp" +} + +func (mo *MockOptions) NewDialer() func(context.Context) 
(net.Conn, error) { + return func(ctx context.Context) (net.Conn, error) { + return nil, nil + } +} + +func TestHitlessManagerRefactoring(t *testing.T) { + t.Run("AtomicStateTracking", func(t *testing.T) { + config := DefaultConfig() + client := &MockClient{options: &MockOptions{}} + + manager, err := NewHitlessManager(client, nil, config) + if err != nil { + t.Fatalf("Failed to create hitless manager: %v", err) + } + defer manager.Close() + + // Test initial state + if manager.IsHandoffInProgress() { + t.Error("Expected no handoff in progress initially") + } + + if manager.GetActiveOperationCount() != 0 { + t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount()) + } + + if manager.GetState() != StateIdle { + t.Errorf("Expected StateIdle, got %v", manager.GetState()) + } + + // Add an operation + ctx := context.Background() + deadline := time.Now().Add(30 * time.Second) + err = manager.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 12345, 1) + if err != nil { + t.Fatalf("Failed to track operation: %v", err) + } + + // Test state after adding operation + if !manager.IsHandoffInProgress() { + t.Error("Expected handoff in progress after adding operation") + } + + if manager.GetActiveOperationCount() != 1 { + t.Errorf("Expected 1 active operation, got %d", manager.GetActiveOperationCount()) + } + + if manager.GetState() != StateMoving { + t.Errorf("Expected StateMoving, got %v", manager.GetState()) + } + + // Remove the operation + manager.UntrackOperationWithConnID(12345, 1) + + // Test state after removing operation + if manager.IsHandoffInProgress() { + t.Error("Expected no handoff in progress after removing operation") + } + + if manager.GetActiveOperationCount() != 0 { + t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount()) + } + + if manager.GetState() != StateIdle { + t.Errorf("Expected StateIdle, got %v", manager.GetState()) + } + }) + + t.Run("SyncMapPerformance", func(t *testing.T) { + config := DefaultConfig() + client := &MockClient{options: &MockOptions{}} + + manager, err := NewHitlessManager(client, nil, config) + if err != nil { + t.Fatalf("Failed to create hitless manager: %v", err) + } + defer manager.Close() + + ctx := context.Background() + deadline := time.Now().Add(30 * time.Second) + + // Test concurrent operations + const numOps = 100 + for i := 0; i < numOps; i++ { + err := manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, int64(i), uint64(i)) + if err != nil { + t.Fatalf("Failed to track operation %d: %v", i, err) + } + } + + if manager.GetActiveOperationCount() != numOps { + t.Errorf("Expected %d active operations, got %d", numOps, manager.GetActiveOperationCount()) + } + + // Test GetActiveMovingOperations + operations := manager.GetActiveMovingOperations() + if len(operations) != numOps { + t.Errorf("Expected %d operations in map, got %d", numOps, len(operations)) + } + + // Remove all operations + for i := 0; i < numOps; i++ { + manager.UntrackOperationWithConnID(int64(i), uint64(i)) + } + + if manager.GetActiveOperationCount() != 0 { + t.Errorf("Expected 0 active operations after cleanup, got %d", manager.GetActiveOperationCount()) + } + }) + + t.Run("DuplicateOperationHandling", func(t *testing.T) { + config := DefaultConfig() + client := &MockClient{options: &MockOptions{}} + + manager, err := NewHitlessManager(client, nil, config) + if err != nil { + t.Fatalf("Failed to create hitless manager: %v", err) + } + defer manager.Close() + + ctx := context.Background() 
+ deadline := time.Now().Add(30 * time.Second) + + // Add operation + err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1) + if err != nil { + t.Fatalf("Failed to track operation: %v", err) + } + + // Try to add duplicate operation + err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1) + if err != nil { + t.Fatalf("Duplicate operation should not return error: %v", err) + } + + // Should still have only 1 operation + if manager.GetActiveOperationCount() != 1 { + t.Errorf("Expected 1 active operation after duplicate, got %d", manager.GetActiveOperationCount()) + } + }) + + t.Run("NotificationTypeConstants", func(t *testing.T) { + // Test that constants are properly defined + expectedTypes := []string{ + NotificationMoving, + NotificationMigrating, + NotificationMigrated, + NotificationFailingOver, + NotificationFailedOver, + } + + if len(hitlessNotificationTypes) != len(expectedTypes) { + t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(hitlessNotificationTypes)) + } + + // Test that all expected types are present + typeMap := make(map[string]bool) + for _, t := range hitlessNotificationTypes { + typeMap[t] = true + } + + for _, expected := range expectedTypes { + if !typeMap[expected] { + t.Errorf("Expected notification type %s not found in hitlessNotificationTypes", expected) + } + } + + // Test that hitlessNotificationTypes contains all expected constants + expectedConstants := []string{ + NotificationMoving, + NotificationMigrating, + NotificationMigrated, + NotificationFailingOver, + NotificationFailedOver, + } + + for _, expected := range expectedConstants { + found := false + for _, actual := range hitlessNotificationTypes { + if actual == expected { + found = true + break + } + } + if !found { + t.Errorf("Expected constant %s not found in hitlessNotificationTypes", expected) + } + } + }) +} diff --git a/hitless/hooks.go b/hitless/hooks.go new file mode 100644 index 00000000..24d4fc34 --- /dev/null +++ b/hitless/hooks.go @@ -0,0 +1,47 @@ +package hitless + +import ( + "context" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" + "github.com/redis/go-redis/v9/push" +) + +// LoggingHook is an example hook implementation that logs all notifications. +type LoggingHook struct { + LogLevel logging.LogLevel +} + +// PreHook logs the notification before processing and allows modification. +func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) { + if lh.LogLevel.InfoOrAbove() { // Info level + // Log the notification type and content + connID := uint64(0) + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + connID = conn.GetID() + } + internal.Logger.Printf(ctx, "hitless: conn[%d] processing %s notification: %v", connID, notificationType, notification) + } + return notification, true // Continue processing with unmodified notification +} + +// PostHook logs the result after processing. 
+func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) { + connID := uint64(0) + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + connID = conn.GetID() + } + if result != nil && lh.LogLevel.WarnOrAbove() { // Warning level + internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processing failed: %v - %v", connID, notificationType, result, notification) + } else if lh.LogLevel.DebugOrAbove() { // Debug level + internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processed successfully", connID, notificationType) + } +} + +// NewLoggingHook creates a new logging hook with the specified log level. +// Log levels: LogLevelError=errors, LogLevelWarn=warnings, LogLevelInfo=info, LogLevelDebug=debug +func NewLoggingHook(logLevel logging.LogLevel) *LoggingHook { + return &LoggingHook{LogLevel: logLevel} +} diff --git a/hitless/pool_hook.go b/hitless/pool_hook.go new file mode 100644 index 00000000..b530dce0 --- /dev/null +++ b/hitless/pool_hook.go @@ -0,0 +1,179 @@ +package hitless + +import ( + "context" + "net" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" +) + +// HitlessManagerInterface defines the interface for completing handoff operations +type HitlessManagerInterface interface { + TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error + UntrackOperationWithConnID(seqID int64, connID uint64) +} + +// HandoffRequest represents a request to handoff a connection to a new endpoint +type HandoffRequest struct { + Conn *pool.Conn + ConnID uint64 // Unique connection identifier + Endpoint string + SeqID int64 + Pool pool.Pooler // Pool to remove connection from on failure +} + +// PoolHook implements pool.PoolHook for Redis-specific connection handling +// with hitless upgrade support. 
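+// Typical wiring, sketched under the assumption that a pool.Pooler, a Config, and a
+// manager are already available (HitlessManager normally performs these steps itself
+// via InitPoolHook):
+//
+//	hook := NewPoolHook(netDialer, "tcp", config, manager)
+//	hook.SetPool(connPool)
+//	connPool.AddPoolHook(hook)
+//	defer hook.Shutdown(context.Background())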
+type PoolHook struct { + // Base dialer for creating connections to new endpoints during handoffs + // args are network and address + baseDialer func(context.Context, string, string) (net.Conn, error) + + // Network type (e.g., "tcp", "unix") + network string + + // Worker manager for background handoff processing + workerManager *handoffWorkerManager + + // Configuration for the hitless upgrade + config *Config + + // Hitless manager for operation completion tracking + hitlessManager HitlessManagerInterface + + // Pool interface for removing connections on handoff failure + pool pool.Pooler +} + +// NewPoolHook creates a new pool hook +func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface) *PoolHook { + return NewPoolHookWithPoolSize(baseDialer, network, config, hitlessManager, 0) +} + +// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults +func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface, poolSize int) *PoolHook { + // Apply defaults if config is nil or has zero values + if config == nil { + config = config.ApplyDefaultsWithPoolSize(poolSize) + } + + ph := &PoolHook{ + // baseDialer is used to create connections to new endpoints during handoffs + baseDialer: baseDialer, + network: network, + config: config, + // Hitless manager for operation completion tracking + hitlessManager: hitlessManager, + } + + // Create worker manager + ph.workerManager = newHandoffWorkerManager(config, ph) + + return ph +} + +// SetPool sets the pool interface for removing connections on handoff failure +func (ph *PoolHook) SetPool(pooler pool.Pooler) { + ph.pool = pooler +} + +// GetCurrentWorkers returns the current number of active workers (for testing) +func (ph *PoolHook) GetCurrentWorkers() int { + return ph.workerManager.getCurrentWorkers() +} + +// IsHandoffPending returns true if the given connection has a pending handoff +func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool { + return ph.workerManager.isHandoffPending(conn) +} + +// GetPendingMap returns the pending map for testing purposes +func (ph *PoolHook) GetPendingMap() *sync.Map { + return ph.workerManager.getPendingMap() +} + +// GetMaxWorkers returns the max workers for testing purposes +func (ph *PoolHook) GetMaxWorkers() int { + return ph.workerManager.getMaxWorkers() +} + +// GetHandoffQueue returns the handoff queue for testing purposes +func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest { + return ph.workerManager.getHandoffQueue() +} + +// GetCircuitBreakerStats returns circuit breaker statistics for monitoring +func (ph *PoolHook) GetCircuitBreakerStats() []CircuitBreakerStats { + return ph.workerManager.getCircuitBreakerStats() +} + +// ResetCircuitBreakers resets all circuit breakers (useful for testing) +func (ph *PoolHook) ResetCircuitBreakers() { + ph.workerManager.resetCircuitBreakers() +} + +// OnGet is called when a connection is retrieved from the pool +func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error { + // NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is + // in a handoff state at the moment. + + // Check if connection is usable (not in a handoff state) + // Should not happen since the pool will not return a connection that is not usable. 
+ if !conn.IsUsable() { + return ErrConnectionMarkedForHandoff + } + + // Check if connection is marked for handoff, which means it will be queued for handoff on put. + if conn.ShouldHandoff() { + return ErrConnectionMarkedForHandoff + } + + return nil +} + +// OnPut is called when a connection is returned to the pool +func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) { + // first check if we should handoff for faster rejection + if !conn.ShouldHandoff() { + // Default behavior (no handoff): pool the connection + return true, false, nil + } + + // check pending handoff to not queue the same connection twice + if ph.workerManager.isHandoffPending(conn) { + // Default behavior (pending handoff): pool the connection + return true, false, nil + } + + if err := ph.workerManager.queueHandoff(conn); err != nil { + // Failed to queue handoff, remove the connection + internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err) + // Don't pool, remove connection, no error to caller + return false, true, nil + } + + // Check if handoff was already processed by a worker before we can mark it as queued + if !conn.ShouldHandoff() { + // Handoff was already processed - this is normal and the connection should be pooled + return true, false, nil + } + + if err := conn.MarkQueuedForHandoff(); err != nil { + // If marking fails, check if handoff was processed in the meantime + if !conn.ShouldHandoff() { + // Handoff was processed - this is normal, pool the connection + return true, false, nil + } + // Other error - remove the connection + return false, true, nil + } + return true, false, nil +} + +// Shutdown gracefully shuts down the processor, waiting for workers to complete +func (ph *PoolHook) Shutdown(ctx context.Context) error { + return ph.workerManager.shutdownWorkers(ctx) +} diff --git a/hitless/pool_hook_test.go b/hitless/pool_hook_test.go new file mode 100644 index 00000000..6f84002e --- /dev/null +++ b/hitless/pool_hook_test.go @@ -0,0 +1,964 @@ +package hitless + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/pool" +) + +// mockNetConn implements net.Conn for testing +type mockNetConn struct { + addr string + shouldFailInit bool +} + +func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil } +func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (m *mockNetConn) Close() error { return nil } +func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) SetDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type mockAddr struct { + addr string +} + +func (m *mockAddr) Network() string { return "tcp" } +func (m *mockAddr) String() string { return m.addr } + +// createMockPoolConnection creates a mock pool connection for testing +func createMockPoolConnection() *pool.Conn { + mockNetConn := &mockNetConn{addr: "test:6379"} + conn := pool.NewConn(mockNetConn) + conn.SetUsable(true) // Make connection usable for testing + return conn +} + +// mockPool implements pool.Pooler for testing +type mockPool struct { + removedConnections map[uint64]bool + mu sync.Mutex +} + +func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) { + return nil, 
errors.New("not implemented") +} + +func (mp *mockPool) CloseConn(conn *pool.Conn) error { + return nil +} + +func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) { + return nil, errors.New("not implemented") +} + +func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) { + // Not implemented for testing +} + +func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) { + mp.mu.Lock() + defer mp.mu.Unlock() + + // Use pool.Conn directly - no adapter needed + mp.removedConnections[conn.GetID()] = true +} + +// WasRemoved safely checks if a connection was removed from the pool +func (mp *mockPool) WasRemoved(connID uint64) bool { + mp.mu.Lock() + defer mp.mu.Unlock() + return mp.removedConnections[connID] +} + +func (mp *mockPool) Len() int { + return 0 +} + +func (mp *mockPool) IdleLen() int { + return 0 +} + +func (mp *mockPool) Stats() *pool.Stats { + return &pool.Stats{} +} + +func (mp *mockPool) AddPoolHook(hook pool.PoolHook) { + // Mock implementation - do nothing +} + +func (mp *mockPool) RemovePoolHook(hook pool.PoolHook) { + // Mock implementation - do nothing +} + +func (mp *mockPool) Close() error { + return nil +} + +// TestConnectionHook tests the Redis connection processor functionality +func TestConnectionHook(t *testing.T) { + // Create a base dialer for testing + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) { + config := &Config{ + Mode: MaintNotificationsAuto, + EndpointType: EndpointTypeAuto, + MaxWorkers: 1, // Use only 1 worker to ensure synchronization + HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue + MaxHandoffRetries: 3, + LogLevel: 2, + } + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Verify connection is marked for handoff + if !conn.ShouldHandoff() { + t.Fatal("Connection should be marked for handoff") + } + // Set a mock initialization function with synchronization + initConnCalled := make(chan bool, 1) + proceedWithInit := make(chan bool, 1) + initConnFunc := func(ctx context.Context, cn *pool.Conn) error { + select { + case initConnCalled <- true: + default: + } + // Wait for test to proceed + <-proceedWithInit + return nil + } + conn.SetInitConnFunc(initConnFunc) + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not error: %v", err) + } + + // Should pool the connection immediately (handoff queued) + if !shouldPool { + t.Error("Connection should be pooled immediately with event-driven handoff") + } + if shouldRemove { + t.Error("Connection should not be removed when queuing handoff") + } + + // Wait for initialization to be called (indicates handoff started) + select { + case <-initConnCalled: + // Good, initialization was called + case <-time.After(1 * time.Second): + t.Fatal("Timeout waiting for initialization function to be called") + } + + // Connection should be in pending map while initialization is blocked + if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending { + t.Error("Connection should be in pending handoffs map") + } + + // Allow initialization to proceed + proceedWithInit <- true + + // Wait for handoff 
to complete with proper timeout and polling + timeout := time.After(2 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for handoff to complete") + case <-ticker.C: + if _, pending := processor.GetPendingMap().Load(conn); !pending { + handoffCompleted = true + } + } + } + + // Verify handoff completed (removed from pending map) + if _, pending := processor.GetPendingMap().Load(conn); pending { + t.Error("Connection should be removed from pending map after handoff") + } + + // Verify connection is usable again + if !conn.IsUsable() { + t.Error("Connection should be usable after successful handoff") + } + + // Verify handoff state is cleared + if conn.ShouldHandoff() { + t.Error("Connection should not be marked for handoff after completion") + } + }) + + t.Run("HandoffNotNeeded", func(t *testing.T) { + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + conn := createMockPoolConnection() + // Don't mark for handoff + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not error when handoff not needed: %v", err) + } + + // Should pool the connection normally + if !shouldPool { + t.Error("Connection should be pooled when no handoff needed") + } + if shouldRemove { + t.Error("Connection should not be removed when no handoff needed") + } + }) + + t.Run("EmptyEndpoint", func(t *testing.T) { + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not error with empty endpoint: %v", err) + } + + // Should pool the connection (empty endpoint clears state) + if !shouldPool { + t.Error("Connection should be pooled after clearing empty endpoint") + } + if shouldRemove { + t.Error("Connection should not be removed after clearing empty endpoint") + } + + // State should be cleared + if conn.ShouldHandoff() { + t.Error("Connection should not be marked for handoff after clearing empty endpoint") + } + }) + + t.Run("EventDrivenHandoffDialerError", func(t *testing.T) { + // Create a failing base dialer + failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("dial failed") + } + + config := &Config{ + Mode: MaintNotificationsAuto, + EndpointType: EndpointTypeAuto, + MaxWorkers: 2, + HandoffQueueSize: 10, + MaxHandoffRetries: 2, // Reduced retries for faster test + HandoffTimeout: 500 * time.Millisecond, // Shorter timeout for faster test + LogLevel: 2, + } + processor := NewPoolHook(failingDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not return error to caller: %v", err) + } + + // Should pool the connection initially (handoff queued) + if !shouldPool { + t.Error("Connection should be pooled initially with event-driven handoff") + } + if shouldRemove { + t.Error("Connection 
should not be removed when queuing handoff") + } + + // Wait for handoff to complete and fail with proper timeout and polling + timeout := time.After(3 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + // wait for handoff to start + time.Sleep(50 * time.Millisecond) + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for failed handoff to complete") + case <-ticker.C: + if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending { + handoffCompleted = true + } + } + } + + // Connection should be removed from pending map after failed handoff + if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending { + t.Error("Connection should be removed from pending map after failed handoff") + } + + // Wait for retries to complete (with MaxHandoffRetries=2, it will retry twice then give up) + // Each retry has a delay of handoffTimeout/2 = 250ms, so wait for all retries to complete + time.Sleep(800 * time.Millisecond) + + // After max retries are reached, the connection should be removed from pool + // and handoff state should be cleared + if conn.ShouldHandoff() { + t.Error("Connection should not be marked for handoff after max retries reached") + } + + t.Logf("EventDrivenHandoffDialerError test completed successfully") + }) + + t.Run("BufferedDataRESP2", func(t *testing.T) { + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + conn := createMockPoolConnection() + + // For this test, we'll just verify the logic works for connections without buffered data + // The actual buffered data detection is handled by the pool's connection health check + // which is outside the scope of the Redis connection processor + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not error: %v", err) + } + + // Should pool the connection normally (no buffered data in mock) + if !shouldPool { + t.Error("Connection should be pooled when no buffered data") + } + if shouldRemove { + t.Error("Connection should not be removed when no buffered data") + } + }) + + t.Run("OnGet", func(t *testing.T) { + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + conn := createMockPoolConnection() + + ctx := context.Background() + err := processor.OnGet(ctx, conn, false) + if err != nil { + t.Errorf("OnGet should not error for normal connection: %v", err) + } + }) + + t.Run("OnGetWithPendingHandoff", func(t *testing.T) { + config := &Config{ + Mode: MaintNotificationsAuto, + EndpointType: EndpointTypeAuto, + MaxWorkers: 2, + HandoffQueueSize: 10, + MaxHandoffRetries: 3, // Explicit queue size to avoid 0-size queue + LogLevel: 2, + } + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + + // Simulate a pending handoff by marking for handoff and queuing + conn.MarkForHandoff("new-endpoint:6379", 12345) + processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID + conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) + + ctx := context.Background() + err := processor.OnGet(ctx, conn, false) + if err != ErrConnectionMarkedForHandoff { + t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) + } + + // Clean up + processor.GetPendingMap().Delete(conn) + }) + + t.Run("EventDrivenStateManagement", func(t *testing.T) { + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + defer 
processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + + // Test initial state - no pending handoffs + if _, pending := processor.GetPendingMap().Load(conn); pending { + t.Error("New connection should not have pending handoffs") + } + + // Test adding to pending map + conn.MarkForHandoff("new-endpoint:6379", 12345) + processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID + conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) + + if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending { + t.Error("Connection should be in pending map") + } + + // Test OnGet with pending handoff + ctx := context.Background() + err := processor.OnGet(ctx, conn, false) + if err != ErrConnectionMarkedForHandoff { + t.Error("Should return ErrConnectionMarkedForHandoff for pending connection") + } + + // Test removing from pending map and clearing handoff state + processor.GetPendingMap().Delete(conn) + if _, pending := processor.GetPendingMap().Load(conn); pending { + t.Error("Connection should be removed from pending map") + } + + // Clear handoff state to simulate completed handoff + conn.ClearHandoffState() + conn.SetUsable(true) // Make connection usable again + + // Test OnGet without pending handoff + err = processor.OnGet(ctx, conn, false) + if err != nil { + t.Errorf("Should not return error for non-pending connection: %v", err) + } + }) + + t.Run("EventDrivenQueueOptimization", func(t *testing.T) { + // Create processor with small queue to test optimization features + config := &Config{ + MaxWorkers: 3, + HandoffQueueSize: 2, + MaxHandoffRetries: 3, // Small queue to trigger optimizations + LogLevel: 3, // Debug level to see optimization logs + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + // Add small delay to simulate network latency + time.Sleep(10 * time.Millisecond) + return &mockNetConn{addr: addr}, nil + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // Create multiple connections that need handoff to fill the queue + connections := make([]*pool.Conn, 5) + for i := 0; i < 5; i++ { + connections[i] = createMockPoolConnection() + if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil { + t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err) + } + // Set a mock initialization function + connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + } + + ctx := context.Background() + successCount := 0 + + // Process connections - should trigger scaling and timeout logic + for _, conn := range connections { + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Logf("OnPut returned error (expected with timeout): %v", err) + } + + if shouldPool && !shouldRemove { + successCount++ + } + } + + // With timeout and scaling, most handoffs should eventually succeed + if successCount == 0 { + t.Error("Should have queued some handoffs with timeout and scaling") + } + + t.Logf("Successfully queued %d handoffs with optimization features", successCount) + + // Give time for workers to process and scaling to occur + time.Sleep(100 * time.Millisecond) + }) + + t.Run("WorkerScalingBehavior", func(t *testing.T) { + // Create processor with small queue to test scaling behavior + config := &Config{ + MaxWorkers: 15, // Set to >= 10 to test explicit value preservation + HandoffQueueSize: 1, + MaxHandoffRetries: 3, // Very 
small queue to force scaling + LogLevel: 2, // Info level to see scaling logs + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // Verify initial worker count (should be 0 with on-demand workers) + if processor.GetCurrentWorkers() != 0 { + t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers()) + } + if processor.GetMaxWorkers() != 15 { + t.Errorf("Expected maxWorkers=15, got %d", processor.GetMaxWorkers()) + } + + // The on-demand worker behavior creates workers only when needed + // This test just verifies the basic configuration is correct + t.Logf("On-demand worker configuration verified - Max: %d, Current: %d", + processor.GetMaxWorkers(), processor.GetCurrentWorkers()) + }) + + t.Run("PassiveTimeoutRestoration", func(t *testing.T) { + // Create processor with fast post-handoff duration for testing + config := &Config{ + MaxWorkers: 2, + HandoffQueueSize: 10, + MaxHandoffRetries: 3, // Allow retries for successful handoff + PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing + RelaxedTimeout: 5 * time.Second, + LogLevel: 2, + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + ctx := context.Background() + + // Create a connection and trigger handoff + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + // Process the connection to trigger handoff + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("Handoff should succeed: %v", err) + } + if !shouldPool || shouldRemove { + t.Error("Connection should be pooled after handoff") + } + + // Wait for handoff to complete with proper timeout and polling + timeout := time.After(1 * time.Second) + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for handoff to complete") + case <-ticker.C: + if _, pending := processor.GetPendingMap().Load(conn); !pending { + handoffCompleted = true + } + } + } + + // Verify relaxed timeout is set with deadline + if !conn.HasRelaxedTimeout() { + t.Error("Connection should have relaxed timeout after handoff") + } + + // Test that timeout is still active before deadline + // We'll use HasRelaxedTimeout which internally checks the deadline + if !conn.HasRelaxedTimeout() { + t.Error("Connection should still have active relaxed timeout before deadline") + } + + // Wait for deadline to pass + time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer + + // Test that timeout is automatically restored after deadline + // HasRelaxedTimeout should return false after deadline passes + if conn.HasRelaxedTimeout() { + t.Error("Connection should not have active relaxed timeout after deadline") + } + + // Additional verification: calling HasRelaxedTimeout again should still return false + // and should have cleared the internal timeout values + if conn.HasRelaxedTimeout() { + t.Error("Connection should not have relaxed timeout after deadline (second check)") + } + + t.Logf("Passive timeout restoration test completed successfully") + }) + + t.Run("UsableFlagBehavior", func(t *testing.T) { + config := 
&Config{ + MaxWorkers: 2, + HandoffQueueSize: 10, + MaxHandoffRetries: 3, + LogLevel: 2, + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + ctx := context.Background() + + // Create a new connection without setting it usable + mockNetConn := &mockNetConn{addr: "test:6379"} + conn := pool.NewConn(mockNetConn) + + // Initially, connection should not be usable (not initialized) + if conn.IsUsable() { + t.Error("New connection should not be usable before initialization") + } + + // Simulate initialization by setting usable to true + conn.SetUsable(true) + if !conn.IsUsable() { + t.Error("Connection should be usable after initialization") + } + + // OnGet should succeed for usable connection + err := processor.OnGet(ctx, conn, false) + if err != nil { + t.Errorf("OnGet should succeed for usable connection: %v", err) + } + + // Mark connection for handoff + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + // Connection should still be usable until queued, but marked for handoff + if !conn.IsUsable() { + t.Error("Connection should still be usable after being marked for handoff (until queued)") + } + if !conn.ShouldHandoff() { + t.Error("Connection should be marked for handoff") + } + + // OnGet should fail for connection marked for handoff + err = processor.OnGet(ctx, conn, false) + if err == nil { + t.Error("OnGet should fail for connection marked for handoff") + } + if err != ErrConnectionMarkedForHandoff { + t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) + } + + // Process the connection to trigger handoff + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should succeed: %v", err) + } + if !shouldPool || shouldRemove { + t.Error("Connection should be pooled after handoff") + } + + // Wait for handoff to complete + time.Sleep(50 * time.Millisecond) + + // After handoff completion, connection should be usable again + if !conn.IsUsable() { + t.Error("Connection should be usable after handoff completion") + } + + // OnGet should succeed again + err = processor.OnGet(ctx, conn, false) + if err != nil { + t.Errorf("OnGet should succeed after handoff completion: %v", err) + } + + t.Logf("Usable flag behavior test completed successfully") + }) + + t.Run("StaticQueueBehavior", func(t *testing.T) { + config := &Config{ + MaxWorkers: 3, + HandoffQueueSize: 50, + MaxHandoffRetries: 3, // Explicit static queue size + LogLevel: 2, + } + + processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100 + defer processor.Shutdown(context.Background()) + + // Verify queue capacity matches configured size + queueCapacity := cap(processor.GetHandoffQueue()) + if queueCapacity != 50 { + t.Errorf("Expected queue capacity 50, got %d", queueCapacity) + } + + // Test that queue size is static regardless of pool size + // (No dynamic resizing should occur) + + ctx := context.Background() + + // Fill part of the queue + for i := 0; i < 10; i++ { + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil { + t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err) + } + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil 
+ }) + + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("Failed to queue handoff %d: %v", i, err) + } + + if !shouldPool || shouldRemove { + t.Errorf("conn[%d] should be pooled after handoff (shouldPool=%v, shouldRemove=%v)", + i, shouldPool, shouldRemove) + } + } + + // Verify queue capacity remains static (the main purpose of this test) + finalCapacity := cap(processor.GetHandoffQueue()) + + if finalCapacity != 50 { + t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity) + } + + // Note: We don't check queue size here because workers process items quickly + // The important thing is that the capacity remains static regardless of pool size + }) + + t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) { + // Create a failing dialer that will cause handoff initialization to fail + failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + // Return a connection that will fail during initialization + return &mockNetConn{addr: addr, shouldFailInit: true}, nil + } + + config := &Config{ + MaxWorkers: 2, + HandoffQueueSize: 10, + MaxHandoffRetries: 3, + LogLevel: 2, + } + + processor := NewPoolHook(failingDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // Create a mock pool that tracks removals + mockPool := &mockPool{removedConnections: make(map[uint64]bool)} + processor.SetPool(mockPool) + + ctx := context.Background() + + // Create a connection and mark it for handoff + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a failing initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return fmt.Errorf("initialization failed") + }) + + // Process the connection - handoff should fail and connection should be removed + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not error: %v", err) + } + if !shouldPool || shouldRemove { + t.Error("Connection should be pooled after failed handoff attempt") + } + + // Wait for handoff to be attempted and fail + time.Sleep(100 * time.Millisecond) + + // Verify that the connection was removed from the pool + if !mockPool.WasRemoved(conn.GetID()) { + t.Errorf("conn[%d] should have been removed from pool after handoff failure", conn.GetID()) + } + + t.Logf("Connection removal on handoff failure test completed successfully") + }) + + t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) { + // Create config with short post-handoff duration for testing + config := &Config{ + MaxWorkers: 2, + HandoffQueueSize: 10, + MaxHandoffRetries: 3, // Allow retries for successful handoff + RelaxedTimeout: 5 * time.Second, + PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + ctx := context.Background() + shouldPool, shouldRemove, err := 
processor.OnPut(ctx, conn) + + if err != nil { + t.Fatalf("OnPut failed: %v", err) + } + + if !shouldPool { + t.Error("Connection should be pooled after successful handoff") + } + + if shouldRemove { + t.Error("Connection should not be removed after successful handoff") + } + + // Wait for the handoff to complete (it happens asynchronously) + timeout := time.After(1 * time.Second) + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for handoff to complete") + case <-ticker.C: + if _, pending := processor.GetPendingMap().Load(conn); !pending { + handoffCompleted = true + } + } + } + + // Verify that relaxed timeout was applied to the new connection + if !conn.HasRelaxedTimeout() { + t.Error("New connection should have relaxed timeout applied after handoff") + } + + // Wait for the post-handoff duration to expire + time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration + + // Verify that relaxed timeout was automatically cleared + if conn.HasRelaxedTimeout() { + t.Error("Relaxed timeout should be automatically cleared after post-handoff duration") + } + }) + + t.Run("MarkForHandoff returns error when already marked", func(t *testing.T) { + conn := createMockPoolConnection() + + // First mark should succeed + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + t.Fatalf("First MarkForHandoff should succeed: %v", err) + } + + // Second mark should fail + if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil { + t.Fatal("Second MarkForHandoff should return error") + } else if err.Error() != "connection is already marked for handoff" { + t.Fatalf("Expected specific error message, got: %v", err) + } + + // Verify original handoff data is preserved + if !conn.ShouldHandoff() { + t.Fatal("Connection should still be marked for handoff") + } + if conn.GetHandoffEndpoint() != "new-endpoint:6379" { + t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint()) + } + if conn.GetMovingSeqID() != 1 { + t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID()) + } + }) + + t.Run("HandoffTimeoutConfiguration", func(t *testing.T) { + // Test that HandoffTimeout from config is actually used + customTimeout := 2 * time.Second + config := &Config{ + MaxWorkers: 2, + HandoffQueueSize: 10, + HandoffTimeout: customTimeout, // Custom timeout + MaxHandoffRetries: 1, // Single retry to speed up test + LogLevel: 2, + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // Create a connection that will test the timeout + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("test-endpoint:6379", 123); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a dialer that will check the context timeout + var timeoutVerified int32 // Use atomic for thread safety + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + // Check that the context has the expected timeout + deadline, ok := ctx.Deadline() + if !ok { + t.Error("Context should have a deadline") + return errors.New("no deadline") + } + + // The deadline should be approximately customTimeout from now + expectedDeadline := time.Now().Add(customTimeout) + timeDiff := deadline.Sub(expectedDeadline) + if timeDiff < -500*time.Millisecond || timeDiff > 500*time.Millisecond { + t.Errorf("Context deadline not as expected. 
Expected around %v, got %v (diff: %v)", + expectedDeadline, deadline, timeDiff) + } else { + atomic.StoreInt32(&timeoutVerified, 1) + } + + return nil // Successful handoff + }) + + // Trigger handoff + shouldPool, shouldRemove, err := processor.OnPut(context.Background(), conn) + if err != nil { + t.Errorf("OnPut should not return error: %v", err) + } + + // Connection should be queued for handoff + if !shouldPool || shouldRemove { + t.Errorf("Connection should be pooled for handoff processing") + } + + // Wait for handoff to complete + time.Sleep(500 * time.Millisecond) + + if atomic.LoadInt32(&timeoutVerified) == 0 { + t.Error("HandoffTimeout was not properly applied to context") + } + + t.Logf("HandoffTimeout configuration test completed successfully") + }) +} diff --git a/hitless/push_notification_handler.go b/hitless/push_notification_handler.go new file mode 100644 index 00000000..33a4fd3e --- /dev/null +++ b/hitless/push_notification_handler.go @@ -0,0 +1,276 @@ +package hitless + +import ( + "context" + "fmt" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// NotificationHandler handles push notifications for the simplified manager. +type NotificationHandler struct { + manager *HitlessManager +} + +// HandlePushNotification processes push notifications with hook support. +func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) == 0 { + internal.Logger.Printf(ctx, "hitless: invalid notification format: %v", notification) + return ErrInvalidNotification + } + + notificationType, ok := notification[0].(string) + if !ok { + internal.Logger.Printf(ctx, "hitless: invalid notification type format: %v", notification[0]) + return ErrInvalidNotification + } + + // Process pre-hooks - they can modify the notification or skip processing + modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, handlerCtx, notificationType, notification) + if !shouldContinue { + return nil // Hooks decided to skip processing + } + + var err error + switch notificationType { + case NotificationMoving: + err = snh.handleMoving(ctx, handlerCtx, modifiedNotification) + case NotificationMigrating: + err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification) + case NotificationMigrated: + err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification) + case NotificationFailingOver: + err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification) + case NotificationFailedOver: + err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification) + default: + // Ignore other notification types (e.g., pub/sub messages) + err = nil + } + + // Process post-hooks with the result + snh.manager.processPostHooks(ctx, handlerCtx, notificationType, modifiedNotification, err) + + return err +} + +// handleMoving processes MOVING notifications. 
+// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff +func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 3 { + internal.Logger.Printf(ctx, "hitless: invalid MOVING notification: %v", notification) + return ErrInvalidNotification + } + seqID, ok := notification[1].(int64) + if !ok { + internal.Logger.Printf(ctx, "hitless: invalid seqID in MOVING notification: %v", notification[1]) + return ErrInvalidNotification + } + + // Extract timeS + timeS, ok := notification[2].(int64) + if !ok { + internal.Logger.Printf(ctx, "hitless: invalid timeS in MOVING notification: %v", notification[2]) + return ErrInvalidNotification + } + + newEndpoint := "" + if len(notification) > 3 { + // Extract new endpoint + newEndpoint, ok = notification[3].(string) + if !ok { + internal.Logger.Printf(ctx, "hitless: invalid newEndpoint in MOVING notification: %v", notification[3]) + return ErrInvalidNotification + } + } + + // Get the connection that received this notification + conn := handlerCtx.Conn + if conn == nil { + internal.Logger.Printf(ctx, "hitless: no connection in handler context for MOVING notification") + return ErrInvalidNotification + } + + // Type assert to get the underlying pool connection + var poolConn *pool.Conn + if pc, ok := conn.(*pool.Conn); ok { + poolConn = pc + } else { + internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MOVING notification - %T %#v", conn, handlerCtx) + return ErrInvalidNotification + } + + // If the connection is closed or not pooled, we can ignore the notification + // this connection won't be remembered by the pool and will be garbage collected + // Keep pubsub connections around since they are not pooled but are long-lived + // and should be allowed to handoff (the pubsub instance will reconnect and change + // the underlying *pool.Conn) + if (poolConn.IsClosed() || !poolConn.IsPooled()) && !poolConn.IsPubSub() { + return nil + } + + deadline := time.Now().Add(time.Duration(timeS) * time.Second) + // If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds + if newEndpoint == "" || newEndpoint == internal.RedisNull { + if snh.manager.config.LogLevel.DebugOrAbove() { // Debug level + internal.Logger.Printf(ctx, "hitless: conn[%d] scheduling handoff to current endpoint in %v seconds", + poolConn.GetID(), timeS/2) + } + // same as current endpoint + newEndpoint = snh.manager.options.GetAddr() + // delay the handoff for timeS/2 seconds to the same endpoint + // do this in a goroutine to avoid blocking the notification handler + // NOTE: This timer is started while parsing the notification, so the connection is not marked for handoff + // and there should be no possibility of a race condition or double handoff. 
+		time.AfterFunc(time.Duration(timeS/2)*time.Second, func() {
+			if poolConn == nil || poolConn.IsClosed() {
+				return
+			}
+			if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
+				// Log error but don't fail the goroutine - use background context since original may be cancelled
+				internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
+			}
+		})
+		return nil
+	}
+
+	return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
+}
+
+func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
+	if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
+		internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
+		// Connection is already marked for handoff, which is acceptable
+		// This can happen if multiple MOVING notifications are received for the same connection
+		return nil
+	}
+	// Optionally track in hitless manager for monitoring/debugging
+	if snh.manager != nil {
+		connID := conn.GetID()
+		// Track the operation (ignore errors since this is optional)
+		_ = snh.manager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
+	} else {
+		return fmt.Errorf("hitless: manager not initialized")
+	}
+	return nil
+}
+
+// handleMigrating processes MIGRATING notifications.
+func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
+	// MIGRATING notifications indicate that a connection is about to be migrated
+	// Apply relaxed timeouts to the specific connection that received this notification
+	if len(notification) < 2 {
+		internal.Logger.Printf(ctx, "hitless: invalid MIGRATING notification: %v", notification)
+		return ErrInvalidNotification
+	}
+
+	if handlerCtx.Conn == nil {
+		internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATING notification")
+		return ErrInvalidNotification
+	}
+
+	conn, ok := handlerCtx.Conn.(*pool.Conn)
+	if !ok {
+		internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATING notification")
+		return ErrInvalidNotification
+	}
+
+	// Apply relaxed timeout to this specific connection
+	if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
+		internal.Logger.Printf(ctx, "hitless: conn[%d] applying relaxed timeout (%v) for MIGRATING notification",
+			conn.GetID(),
+			snh.manager.config.RelaxedTimeout)
+	}
+	conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
+	return nil
+}
+
+// handleMigrated processes MIGRATED notifications.
+func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
+	// MIGRATED notifications indicate that a connection migration has completed
+	// Restore normal timeouts for the specific connection that received this notification
+	if len(notification) < 2 {
+		internal.Logger.Printf(ctx, "hitless: invalid MIGRATED notification: %v", notification)
+		return ErrInvalidNotification
+	}
+
+	if handlerCtx.Conn == nil {
+		internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATED notification")
+		return ErrInvalidNotification
+	}
+
+	conn, ok := handlerCtx.Conn.(*pool.Conn)
+	if !ok {
+		internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATED notification")
+		return ErrInvalidNotification
+	}
+
+	// Clear relaxed timeout for this specific connection
+	if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
+		connID := conn.GetID()
+		internal.Logger.Printf(ctx, "hitless: conn[%d] clearing relaxed timeout for MIGRATED notification", connID)
+	}
+	conn.ClearRelaxedTimeout()
+	return nil
+}
+
+// handleFailingOver processes FAILING_OVER notifications.
+func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
+	// FAILING_OVER notifications indicate that a connection is about to fail over
+	// Apply relaxed timeouts to the specific connection that received this notification
+	if len(notification) < 2 {
+		internal.Logger.Printf(ctx, "hitless: invalid FAILING_OVER notification: %v", notification)
+		return ErrInvalidNotification
+	}
+
+	if handlerCtx.Conn == nil {
+		internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILING_OVER notification")
+		return ErrInvalidNotification
+	}
+
+	conn, ok := handlerCtx.Conn.(*pool.Conn)
+	if !ok {
+		internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILING_OVER notification")
+		return ErrInvalidNotification
+	}
+
+	// Apply relaxed timeout to this specific connection
+	if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
+		connID := conn.GetID()
+		internal.Logger.Printf(ctx, "hitless: conn[%d] applying relaxed timeout (%v) for FAILING_OVER notification", connID, snh.manager.config.RelaxedTimeout)
+	}
+	conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
+	return nil
+}
+
+// handleFailedOver processes FAILED_OVER notifications.
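+// This is the second half of a symmetric pattern: MIGRATING/FAILING_OVER relax the
+// per-connection timeouts and MIGRATED/FAILED_OVER restore them. A condensed sketch
+// of the pair (cfg stands in for snh.manager.config):
+//
+//	conn.SetRelaxedTimeout(cfg.RelaxedTimeout, cfg.RelaxedTimeout) // on MIGRATING / FAILING_OVER
+//	conn.ClearRelaxedTimeout()                                     // on MIGRATED / FAILED_OVER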
+func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
+	// FAILED_OVER notifications indicate that a connection failover has completed
+	// Restore normal timeouts for the specific connection that received this notification
+	if len(notification) < 2 {
+		internal.Logger.Printf(ctx, "hitless: invalid FAILED_OVER notification: %v", notification)
+		return ErrInvalidNotification
+	}
+
+	if handlerCtx.Conn == nil {
+		internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILED_OVER notification")
+		return ErrInvalidNotification
+	}
+
+	conn, ok := handlerCtx.Conn.(*pool.Conn)
+	if !ok {
+		internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILED_OVER notification")
+		return ErrInvalidNotification
+	}
+
+	// Clear relaxed timeout for this specific connection
+	if snh.manager.config.LogLevel.InfoOrAbove() {
+		connID := conn.GetID()
+		internal.Logger.Printf(ctx, "hitless: conn[%d] clearing relaxed timeout for FAILED_OVER notification", connID)
+	}
+	conn.ClearRelaxedTimeout()
+	return nil
+}
diff --git a/hitless/state.go b/hitless/state.go
new file mode 100644
index 00000000..109d939f
--- /dev/null
+++ b/hitless/state.go
@@ -0,0 +1,24 @@
+package hitless
+
+// State represents the current state of a hitless upgrade operation.
+type State int
+
+const (
+	// StateIdle indicates no upgrade is in progress
+	StateIdle State = iota
+
+	// StateMoving indicates a connection handoff is in progress
+	StateMoving
+)
+
+// String returns a string representation of the state.
+func (s State) String() string {
+	switch s {
+	case StateIdle:
+		return "idle"
+	case StateMoving:
+		return "moving"
+	default:
+		return "unknown"
+	}
+}
diff --git a/internal/interfaces/interfaces.go b/internal/interfaces/interfaces.go
new file mode 100644
index 00000000..5352436f
--- /dev/null
+++ b/internal/interfaces/interfaces.go
@@ -0,0 +1,54 @@
+// Package interfaces provides shared interfaces used by both the main redis package
+// and the hitless upgrade package to avoid circular dependencies.
+package interfaces
+
+import (
+	"context"
+	"net"
+	"time"
+)
+
+// NotificationProcessor is a forward declaration of push.NotificationProcessor,
+// kept here to avoid a circular import.
+type NotificationProcessor interface {
+	RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
+	UnregisterHandler(pushNotificationName string) error
+	GetHandler(pushNotificationName string) interface{}
+}
+
+// ClientInterface defines the interface that clients must implement for hitless upgrades.
+type ClientInterface interface {
+	// GetOptions returns the client options.
+	GetOptions() OptionsInterface
+
+	// GetPushProcessor returns the client's push notification processor.
+	GetPushProcessor() NotificationProcessor
+}
+
+// OptionsInterface defines the interface for client options.
+// Uses an adapter pattern to avoid circular dependencies.
+type OptionsInterface interface {
+	// GetReadTimeout returns the read timeout.
+	GetReadTimeout() time.Duration
+
+	// GetWriteTimeout returns the write timeout.
+	GetWriteTimeout() time.Duration
+
+	// GetNetwork returns the network type.
+	GetNetwork() string
+
+	// GetAddr returns the connection address.
+	GetAddr() string
+
+	// IsTLSEnabled returns true if TLS is enabled.
+	IsTLSEnabled() bool
+
+	// GetProtocol returns the protocol version.
+ GetProtocol() int + + // GetPoolSize returns the connection pool size. + GetPoolSize() int + + // NewDialer returns a new dialer function for the connection. + NewDialer() func(context.Context) (net.Conn, error) +} diff --git a/internal/log.go b/internal/log.go index 4fe3d7db..eef9c0a3 100644 --- a/internal/log.go +++ b/internal/log.go @@ -14,26 +14,20 @@ type Logging interface { Printf(ctx context.Context, format string, v ...interface{}) } -type logger struct { +type DefaultLogger struct { log *log.Logger } -func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) { +func (l *DefaultLogger) Printf(ctx context.Context, format string, v ...interface{}) { _ = l.log.Output(2, fmt.Sprintf(format, v...)) } +func NewDefaultLogger() Logging { + return &DefaultLogger{ + log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile), + } +} + // Logger calls Output to print to the stderr. // Arguments are handled in the manner of fmt.Print. -var Logger Logging = &logger{ - log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile), -} - -// VoidLogger is a logger that does nothing. -// Used to disable logging and thus speed up the library. -type VoidLogger struct{} - -func (v *VoidLogger) Printf(_ context.Context, _ string, _ ...interface{}) { - // do nothing -} - -var _ Logging = (*VoidLogger)(nil) +var Logger Logging = NewDefaultLogger() diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index 72308e12..fc37b821 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -2,6 +2,7 @@ package pool_test import ( "context" + "errors" "fmt" "testing" "time" @@ -31,7 +32,7 @@ func BenchmarkPoolGetPut(b *testing.B) { b.Run(bm.String(), func(b *testing.B) { connPool := pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: bm.poolSize, + PoolSize: int32(bm.poolSize), PoolTimeout: time.Second, DialTimeout: 1 * time.Second, ConnMaxIdleTime: time.Hour, @@ -75,7 +76,7 @@ func BenchmarkPoolGetRemove(b *testing.B) { b.Run(bm.String(), func(b *testing.B) { connPool := pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: bm.poolSize, + PoolSize: int32(bm.poolSize), PoolTimeout: time.Second, DialTimeout: 1 * time.Second, ConnMaxIdleTime: time.Hour, @@ -89,7 +90,7 @@ func BenchmarkPoolGetRemove(b *testing.B) { if err != nil { b.Fatal(err) } - connPool.Remove(ctx, cn, nil) + connPool.Remove(ctx, cn, errors.New("Bench test remove")) } }) }) diff --git a/internal/pool/buffer_size_test.go b/internal/pool/buffer_size_test.go index 7f4bd37e..71223d70 100644 --- a/internal/pool/buffer_size_test.go +++ b/internal/pool/buffer_size_test.go @@ -26,7 +26,7 @@ var _ = Describe("Buffer Size Configuration", func() { It("should use default buffer sizes when not specified", func() { connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: 1000, }) @@ -48,7 +48,7 @@ var _ = Describe("Buffer Size Configuration", func() { connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: 1000, ReadBufferSize: customReadSize, WriteBufferSize: customWriteSize, @@ -69,7 +69,7 @@ var _ = Describe("Buffer Size Configuration", func() { It("should handle zero buffer sizes by using defaults", func() { connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: 1000, ReadBufferSize: 0, // Should use default WriteBufferSize: 0, // Should use default @@ -105,7 +105,7 @@ var _ = Describe("Buffer 
Size Configuration", func() { // without setting ReadBufferSize and WriteBufferSize connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: 1000, // ReadBufferSize and WriteBufferSize are not set (will be 0) }) diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 8fcdfa67..239b86dc 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -3,7 +3,10 @@ package pool import ( "bufio" "context" + "errors" + "fmt" "net" + "sync" "sync/atomic" "time" @@ -12,17 +15,65 @@ import ( var noDeadline = time.Time{} +// Global atomic counter for connection IDs +var connIDCounter uint64 + +// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value +type atomicNetConn struct { + conn net.Conn +} + +// generateConnID generates a fast unique identifier for a connection with zero allocations +func generateConnID() uint64 { + return atomic.AddUint64(&connIDCounter, 1) +} + type Conn struct { - usedAt int64 // atomic - netConn net.Conn + usedAt int64 // atomic + + // Lock-free netConn access using atomic.Value + // Contains *atomicNetConn wrapper, accessed atomically for better performance + netConnAtomic atomic.Value // stores *atomicNetConn rd *proto.Reader bw *bufio.Writer wr *proto.Writer - Inited bool + // Lightweight mutex to protect reader operations during handoff + // Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe + readerMu sync.RWMutex + + Inited atomic.Bool pooled bool + pubsub bool + closed atomic.Bool createdAt time.Time + expiresAt time.Time + + // Hitless upgrade support: relaxed timeouts during migrations/failovers + // Using atomic operations for lock-free access to avoid mutex contention + relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds + relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds + relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch + + // Counter to track multiple relaxed timeout setters if we have nested calls + // will be decremented when ClearRelaxedTimeout is called or deadline is reached + // if counter reaches 0, we clear the relaxed timeouts + relaxedCounter atomic.Int32 + + // Connection initialization function for reconnections + initConnFunc func(context.Context, *Conn) error + + // Connection identifier for unique tracking across handoffs + id uint64 // Unique numeric identifier for this connection + + // Handoff state - using atomic operations for lock-free access + usableAtomic atomic.Bool // Connection usability state + shouldHandoffAtomic atomic.Bool // Whether connection should be handed off + movingSeqIDAtomic atomic.Int64 // Sequence ID from MOVING notification + handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts + // newEndpointAtomic needs special handling as it's a string + newEndpointAtomic atomic.Value // stores string onClose func() error } @@ -33,8 +84,8 @@ func NewConn(netConn net.Conn) *Conn { func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn { cn := &Conn{ - netConn: netConn, createdAt: time.Now(), + id: generateConnID(), // Generate unique ID for this connection } // Use specified buffer sizes, or fall back to 32KiB defaults if 0 @@ -50,6 +101,16 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize) } + // Store netConn atomically for lock-free access using wrapper + cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) + + // 
Initialize atomic handoff state + cn.usableAtomic.Store(false) // false initially, set to true after initialization + cn.shouldHandoffAtomic.Store(false) // false initially + cn.movingSeqIDAtomic.Store(0) // 0 initially + cn.handoffRetriesAtomic.Store(0) // 0 initially + cn.newEndpointAtomic.Store("") // empty string initially + cn.wr = proto.NewWriter(cn.bw) cn.SetUsedAt(time.Now()) return cn @@ -64,23 +125,366 @@ func (cn *Conn) SetUsedAt(tm time.Time) { atomic.StoreInt64(&cn.usedAt, tm.Unix()) } +// getNetConn returns the current network connection using atomic load (lock-free). +// This is the fast path for accessing netConn without mutex overhead. +func (cn *Conn) getNetConn() net.Conn { + if v := cn.netConnAtomic.Load(); v != nil { + if wrapper, ok := v.(*atomicNetConn); ok { + return wrapper.conn + } + } + return nil +} + +// setNetConn stores the network connection atomically (lock-free). +// This is used for the fast path of connection replacement. +func (cn *Conn) setNetConn(netConn net.Conn) { + cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) +} + +// Lock-free helper methods for handoff state management + +// isUsable returns true if the connection is safe to use (lock-free). +func (cn *Conn) isUsable() bool { + return cn.usableAtomic.Load() +} + +// setUsable sets the usable flag atomically (lock-free). +func (cn *Conn) setUsable(usable bool) { + cn.usableAtomic.Store(usable) +} + +// shouldHandoff returns true if connection needs handoff (lock-free). +func (cn *Conn) shouldHandoff() bool { + return cn.shouldHandoffAtomic.Load() +} + +// setShouldHandoff sets the handoff flag atomically (lock-free). +func (cn *Conn) setShouldHandoff(should bool) { + cn.shouldHandoffAtomic.Store(should) +} + +// getMovingSeqID returns the sequence ID atomically (lock-free). +func (cn *Conn) getMovingSeqID() int64 { + return cn.movingSeqIDAtomic.Load() +} + +// setMovingSeqID sets the sequence ID atomically (lock-free). +func (cn *Conn) setMovingSeqID(seqID int64) { + cn.movingSeqIDAtomic.Store(seqID) +} + +// getNewEndpoint returns the new endpoint atomically (lock-free). +func (cn *Conn) getNewEndpoint() string { + if endpoint := cn.newEndpointAtomic.Load(); endpoint != nil { + return endpoint.(string) + } + return "" +} + +// setNewEndpoint sets the new endpoint atomically (lock-free). +func (cn *Conn) setNewEndpoint(endpoint string) { + cn.newEndpointAtomic.Store(endpoint) +} + +// setHandoffRetries sets the retry count atomically (lock-free). +func (cn *Conn) setHandoffRetries(retries int) { + cn.handoffRetriesAtomic.Store(uint32(retries)) +} + +// incrementHandoffRetries atomically increments and returns the new retry count (lock-free). +func (cn *Conn) incrementHandoffRetries(delta int) int { + return int(cn.handoffRetriesAtomic.Add(uint32(delta))) +} + +// IsUsable returns true if the connection is safe to use for new commands (lock-free). +func (cn *Conn) IsUsable() bool { + return cn.isUsable() +} + +// IsPooled returns true if the connection is managed by a pool and will be pooled on Put. +func (cn *Conn) IsPooled() bool { + return cn.pooled +} + +// IsPubSub returns true if the connection is used for PubSub. +func (cn *Conn) IsPubSub() bool { + return cn.pubsub +} + +func (cn *Conn) IsInited() bool { + return cn.Inited.Load() +} + +// SetUsable sets the usable flag for the connection (lock-free). +func (cn *Conn) SetUsable(usable bool) { + cn.setUsable(usable) +} + +// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades. 
+// These timeouts will be used for all subsequent commands until the deadline expires. +// Uses atomic operations for lock-free access. +func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) { + cn.relaxedCounter.Add(1) + cn.relaxedReadTimeoutNs.Store(int64(readTimeout)) + cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout)) +} + +// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline. +// After the deadline, timeouts automatically revert to normal values. +// Uses atomic operations for lock-free access. +func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) { + cn.SetRelaxedTimeout(readTimeout, writeTimeout) + cn.relaxedDeadlineNs.Store(deadline.UnixNano()) +} + +// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior. +// Uses atomic operations for lock-free access. +func (cn *Conn) ClearRelaxedTimeout() { + // Atomically decrement counter and check if we should clear + newCount := cn.relaxedCounter.Add(-1) + if newCount <= 0 { + // Use atomic load to get current value for CAS to avoid stale value race + current := cn.relaxedCounter.Load() + if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) { + cn.clearRelaxedTimeout() + } + } +} + +func (cn *Conn) clearRelaxedTimeout() { + cn.relaxedReadTimeoutNs.Store(0) + cn.relaxedWriteTimeoutNs.Store(0) + cn.relaxedDeadlineNs.Store(0) + cn.relaxedCounter.Store(0) +} + +// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection. +// This checks both the timeout values and the deadline (if set). +// Uses atomic operations for lock-free access. +func (cn *Conn) HasRelaxedTimeout() bool { + // Fast path: no relaxed timeouts are set + if cn.relaxedCounter.Load() <= 0 { + return false + } + + readTimeoutNs := cn.relaxedReadTimeoutNs.Load() + writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load() + + // If no relaxed timeouts are set, return false + if readTimeoutNs <= 0 && writeTimeoutNs <= 0 { + return false + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, relaxed timeouts are active + if deadlineNs == 0 { + return true + } + + // If deadline is set, check if it's still in the future + return time.Now().UnixNano() < deadlineNs +} + +// getEffectiveReadTimeout returns the timeout to use for read operations. +// If relaxed timeout is set and not expired, it takes precedence over the provided timeout. +// This method automatically clears expired relaxed timeouts using atomic operations. +func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration { + readTimeoutNs := cn.relaxedReadTimeoutNs.Load() + + // Fast path: no relaxed timeout set + if readTimeoutNs <= 0 { + return normalTimeout + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, use relaxed timeout + if deadlineNs == 0 { + return time.Duration(readTimeoutNs) + } + + nowNs := time.Now().UnixNano() + // Check if deadline has passed + if nowNs < deadlineNs { + // Deadline is in the future, use relaxed timeout + return time.Duration(readTimeoutNs) + } else { + // Deadline has passed, clear relaxed timeouts atomically and use normal timeout + cn.relaxedCounter.Add(-1) + if cn.relaxedCounter.Load() <= 0 { + cn.clearRelaxedTimeout() + } + return normalTimeout + } +} + +// getEffectiveWriteTimeout returns the timeout to use for write operations. +// If relaxed timeout is set and not expired, it takes precedence over the provided timeout. 
+// This method automatically clears expired relaxed timeouts using atomic operations. +func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration { + writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load() + + // Fast path: no relaxed timeout set + if writeTimeoutNs <= 0 { + return normalTimeout + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, use relaxed timeout + if deadlineNs == 0 { + return time.Duration(writeTimeoutNs) + } + + nowNs := time.Now().UnixNano() + // Check if deadline has passed + if nowNs < deadlineNs { + // Deadline is in the future, use relaxed timeout + return time.Duration(writeTimeoutNs) + } else { + // Deadline has passed, clear relaxed timeouts atomically and use normal timeout + cn.relaxedCounter.Add(-1) + if cn.relaxedCounter.Load() <= 0 { + cn.clearRelaxedTimeout() + } + return normalTimeout + } +} + func (cn *Conn) SetOnClose(fn func() error) { cn.onClose = fn } +// SetInitConnFunc sets the connection initialization function to be called on reconnections. +func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) { + cn.initConnFunc = fn +} + +// ExecuteInitConn runs the stored connection initialization function if available. +func (cn *Conn) ExecuteInitConn(ctx context.Context) error { + if cn.initConnFunc != nil { + return cn.initConnFunc(ctx, cn) + } + return fmt.Errorf("redis: no initConnFunc set for conn[%d]", cn.GetID()) +} + func (cn *Conn) SetNetConn(netConn net.Conn) { - cn.netConn = netConn + // Store the new connection atomically first (lock-free) + cn.setNetConn(netConn) + // Protect reader reset operations to avoid data races + // Use write lock since we're modifying the reader state + cn.readerMu.Lock() cn.rd.Reset(netConn) + cn.readerMu.Unlock() + cn.bw.Reset(netConn) } +// GetNetConn safely returns the current network connection using atomic load (lock-free). +// This method is used by the pool for health checks and provides better performance. +func (cn *Conn) GetNetConn() net.Conn { + return cn.getNetConn() +} + +// SetNetConnAndInitConn replaces the underlying connection and executes the initialization. +func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error { + // New connection is not initialized yet + cn.Inited.Store(false) + // Replace the underlying connection + cn.SetNetConn(netConn) + return cn.ExecuteInitConn(ctx) +} + +// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free). +// Returns an error if the connection is already marked for handoff. +func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error { + // Use single atomic CAS operation for state transition + if !cn.shouldHandoffAtomic.CompareAndSwap(false, true) { + return errors.New("connection is already marked for handoff") + } + + cn.setNewEndpoint(newEndpoint) + cn.setMovingSeqID(seqID) + return nil +} + +func (cn *Conn) MarkQueuedForHandoff() error { + // Use single atomic CAS operation for state transition + if !cn.shouldHandoffAtomic.CompareAndSwap(true, false) { + return errors.New("connection was not marked for handoff") + } + cn.setUsable(false) + return nil +} + +// ShouldHandoff returns true if the connection needs to be handed off (lock-free). +func (cn *Conn) ShouldHandoff() bool { + return cn.shouldHandoff() +} + +// GetHandoffEndpoint returns the new endpoint for handoff (lock-free). 
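// Illustrative sketch (not part of this patch): the handoff state transitions a MOVING
// notification drives on a single connection, using the methods defined in this file.
// The endpoint and sequence ID are arbitrary; the dummy net.TCPConn mirrors the pool tests.
package pool_test

import (
	"net"

	"github.com/redis/go-redis/v9/internal/pool"
)

func handoffLifecycleSketch() {
	cn := pool.NewConn(&net.TCPConn{})
	defer cn.Close()

	_ = cn.MarkForHandoff("new-endpoint:6379", 42) // MOVING received; a second call returns an error
	_ = cn.ShouldHandoff()                         // true: the pool hook should queue a handoff

	_ = cn.MarkQueuedForHandoff() // the handoff worker picked it up
	_ = cn.IsUsable()             // false until the handoff completes

	cn.ClearHandoffState()      // handoff finished on the new endpoint
	_ = cn.IsUsable()           // true: the connection can serve commands again
	_ = cn.GetHandoffEndpoint() // "" once the state is cleared
}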
+func (cn *Conn) GetHandoffEndpoint() string { + return cn.getNewEndpoint() +} + +// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free). +func (cn *Conn) GetMovingSeqID() int64 { + return cn.getMovingSeqID() +} + +// GetID returns the unique identifier for this connection. +func (cn *Conn) GetID() uint64 { + return cn.id +} + +// ClearHandoffState clears the handoff state after successful handoff (lock-free). +func (cn *Conn) ClearHandoffState() { + // clear handoff state + cn.setShouldHandoff(false) + cn.setNewEndpoint("") + cn.setMovingSeqID(0) + cn.setHandoffRetries(0) + cn.setUsable(true) // Connection is safe to use again after handoff completes +} + +// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free). +func (cn *Conn) IncrementAndGetHandoffRetries(n int) int { + return cn.incrementHandoffRetries(n) +} + +// HasBufferedData safely checks if the connection has buffered data. +// This method is used to avoid data races when checking for push notifications. +func (cn *Conn) HasBufferedData() bool { + // Use read lock for concurrent access to reader state + cn.readerMu.RLock() + defer cn.readerMu.RUnlock() + return cn.rd.Buffered() > 0 +} + +// PeekReplyTypeSafe safely peeks at the reply type. +// This method is used to avoid data races when checking for push notifications. +func (cn *Conn) PeekReplyTypeSafe() (byte, error) { + // Use read lock for concurrent access to reader state + cn.readerMu.RLock() + defer cn.readerMu.RUnlock() + + if cn.rd.Buffered() <= 0 { + return 0, fmt.Errorf("redis: can't peek reply type, no data available") + } + return cn.rd.PeekReplyType() +} + func (cn *Conn) Write(b []byte) (int, error) { - return cn.netConn.Write(b) + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.Write(b) + } + return 0, net.ErrClosed } func (cn *Conn) RemoteAddr() net.Addr { - if cn.netConn != nil { - return cn.netConn.RemoteAddr() + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.RemoteAddr() } return nil } @@ -89,7 +493,16 @@ func (cn *Conn) WithReader( ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error, ) error { if timeout >= 0 { - if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil { + // Use relaxed timeout if set, otherwise use provided timeout + effectiveTimeout := cn.getEffectiveReadTimeout(timeout) + + // Get the connection directly from atomic storage + netConn := cn.getNetConn() + if netConn == nil { + return fmt.Errorf("redis: connection not available") + } + + if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil { return err } } @@ -100,13 +513,26 @@ func (cn *Conn) WithWriter( ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error, ) error { if timeout >= 0 { - if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil { - return err + // Use relaxed timeout if set, otherwise use provided timeout + effectiveTimeout := cn.getEffectiveWriteTimeout(timeout) + + // Always set write deadline, even if getNetConn() returns nil + // This prevents write operations from hanging indefinitely + if netConn := cn.getNetConn(); netConn != nil { + if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil { + return err + } + } else { + // If getNetConn() returns nil, we still need to respect the timeout + // Return an error to prevent 
indefinite blocking + return fmt.Errorf("redis: connection not available for write operation") } } if cn.bw.Buffered() > 0 { - cn.bw.Reset(cn.netConn) + if netConn := cn.getNetConn(); netConn != nil { + cn.bw.Reset(netConn) + } } if err := fn(cn.wr); err != nil { @@ -116,19 +542,33 @@ func (cn *Conn) WithWriter( return cn.bw.Flush() } +func (cn *Conn) IsClosed() bool { + return cn.closed.Load() +} + func (cn *Conn) Close() error { + cn.closed.Store(true) if cn.onClose != nil { // ignore error _ = cn.onClose() } - return cn.netConn.Close() + + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.Close() + } + return nil } // MaybeHasData tries to peek at the next byte in the socket without consuming it // This is used to check if there are push notifications available // Important: This will work on Linux, but not on Windows func (cn *Conn) MaybeHasData() bool { - return maybeHasData(cn.netConn) + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return maybeHasData(netConn) + } + return false } func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { diff --git a/internal/pool/conn_relaxed_timeout_test.go b/internal/pool/conn_relaxed_timeout_test.go new file mode 100644 index 00000000..503107ab --- /dev/null +++ b/internal/pool/conn_relaxed_timeout_test.go @@ -0,0 +1,92 @@ +package pool + +import ( + "net" + "sync" + "testing" + "time" +) + +// TestConcurrentRelaxedTimeoutClearing tests the race condition fix in ClearRelaxedTimeout +func TestConcurrentRelaxedTimeoutClearing(t *testing.T) { + // Create a dummy connection for testing + netConn := &net.TCPConn{} + cn := NewConn(netConn) + defer cn.Close() + + // Set relaxed timeout multiple times to increase counter + cn.SetRelaxedTimeout(time.Second, time.Second) + cn.SetRelaxedTimeout(time.Second, time.Second) + cn.SetRelaxedTimeout(time.Second, time.Second) + + // Verify counter is 3 + if count := cn.relaxedCounter.Load(); count != 3 { + t.Errorf("Expected relaxed counter to be 3, got %d", count) + } + + // Clear timeouts concurrently to test race condition fix + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + cn.ClearRelaxedTimeout() + }() + } + wg.Wait() + + // Verify counter is 0 and timeouts are cleared + if count := cn.relaxedCounter.Load(); count != 0 { + t.Errorf("Expected relaxed counter to be 0 after clearing, got %d", count) + } + if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 { + t.Errorf("Expected relaxed read timeout to be 0, got %d", timeout) + } + if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 { + t.Errorf("Expected relaxed write timeout to be 0, got %d", timeout) + } +} + +// TestRelaxedTimeoutCounterRaceCondition tests the specific race condition scenario +func TestRelaxedTimeoutCounterRaceCondition(t *testing.T) { + netConn := &net.TCPConn{} + cn := NewConn(netConn) + defer cn.Close() + + // Set relaxed timeout once + cn.SetRelaxedTimeout(time.Second, time.Second) + + // Verify counter is 1 + if count := cn.relaxedCounter.Load(); count != 1 { + t.Errorf("Expected relaxed counter to be 1, got %d", count) + } + + // Test concurrent clearing with race condition scenario + var wg sync.WaitGroup + + // Multiple goroutines try to clear simultaneously + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + cn.ClearRelaxedTimeout() + }() + } + wg.Wait() + + // Verify final state is consistent + if 
count := cn.relaxedCounter.Load(); count != 0 { + t.Errorf("Expected relaxed counter to be 0 after concurrent clearing, got %d", count) + } + + // Verify timeouts are actually cleared + if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 { + t.Errorf("Expected relaxed read timeout to be cleared, got %d", timeout) + } + if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 { + t.Errorf("Expected relaxed write timeout to be cleared, got %d", timeout) + } + if deadline := cn.relaxedDeadlineNs.Load(); deadline != 0 { + t.Errorf("Expected relaxed deadline to be cleared, got %d", deadline) + } +} diff --git a/internal/pool/export_test.go b/internal/pool/export_test.go index 40e387c9..20456b81 100644 --- a/internal/pool/export_test.go +++ b/internal/pool/export_test.go @@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) { } func (cn *Conn) NetConn() net.Conn { - return cn.netConn + return cn.getNetConn() } func (p *ConnPool) CheckMinIdleConns() { diff --git a/internal/pool/hooks.go b/internal/pool/hooks.go new file mode 100644 index 00000000..adbcfbbf --- /dev/null +++ b/internal/pool/hooks.go @@ -0,0 +1,114 @@ +package pool + +import ( + "context" + "sync" +) + +// PoolHook defines the interface for connection lifecycle hooks. +type PoolHook interface { + // OnGet is called when a connection is retrieved from the pool. + // It can modify the connection or return an error to prevent its use. + // It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool) + // The flag can be used for gathering metrics on pool hit/miss ratio. + OnGet(ctx context.Context, conn *Conn, isNewConn bool) error + + // OnPut is called when a connection is returned to the pool. + // It returns whether the connection should be pooled and whether it should be removed. + OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) +} + +// PoolHookManager manages multiple pool hooks. +type PoolHookManager struct { + hooks []PoolHook + hooksMu sync.RWMutex +} + +// NewPoolHookManager creates a new pool hook manager. +func NewPoolHookManager() *PoolHookManager { + return &PoolHookManager{ + hooks: make([]PoolHook, 0), + } +} + +// AddHook adds a pool hook to the manager. +// Hooks are called in the order they were added. +func (phm *PoolHookManager) AddHook(hook PoolHook) { + phm.hooksMu.Lock() + defer phm.hooksMu.Unlock() + phm.hooks = append(phm.hooks, hook) +} + +// RemoveHook removes a pool hook from the manager. +func (phm *PoolHookManager) RemoveHook(hook PoolHook) { + phm.hooksMu.Lock() + defer phm.hooksMu.Unlock() + + for i, h := range phm.hooks { + if h == hook { + // Remove hook by swapping with last element and truncating + phm.hooks[i] = phm.hooks[len(phm.hooks)-1] + phm.hooks = phm.hooks[:len(phm.hooks)-1] + break + } + } +} + +// ProcessOnGet calls all OnGet hooks in order. +// If any hook returns an error, processing stops and the error is returned. +func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + + for _, hook := range phm.hooks { + if err := hook.OnGet(ctx, conn, isNewConn); err != nil { + return err + } + } + return nil +} + +// ProcessOnPut calls all OnPut hooks in order. +// The first hook that returns shouldRemove=true or shouldPool=false will stop processing. 
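// Illustrative sketch (not part of this patch): a minimal PoolHook that uses the isNewConn
// flag to count pool hits and misses and never vetoes pooling. The PoolHook interface and
// AddPoolHook come from this patch; the counter fields and type name are assumptions.
package pool_test

import (
	"context"
	"sync/atomic"

	"github.com/redis/go-redis/v9/internal/pool"
)

type countingHook struct {
	hits, misses atomic.Int64
}

func (h *countingHook) OnGet(ctx context.Context, cn *pool.Conn, isNewConn bool) error {
	if isNewConn {
		h.misses.Add(1) // a new dial means no usable idle connection was found
	} else {
		h.hits.Add(1) // reused an idle connection
	}
	return nil
}

func (h *countingHook) OnPut(ctx context.Context, cn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
	return true, false, nil // keep pooling the connection unchanged
}

// Usage: connPool.AddPoolHook(&countingHook{}) on a *pool.ConnPool.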
+func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + + shouldPool = true // Default to pooling the connection + + for _, hook := range phm.hooks { + hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn) + + if hookErr != nil { + return false, true, hookErr + } + + // If any hook says to remove or not pool, respect that decision + if hookShouldRemove { + return false, true, nil + } + + if !hookShouldPool { + shouldPool = false + } + } + + return shouldPool, false, nil +} + +// GetHookCount returns the number of registered hooks (for testing). +func (phm *PoolHookManager) GetHookCount() int { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + return len(phm.hooks) +} + +// GetHooks returns a copy of all registered hooks. +func (phm *PoolHookManager) GetHooks() []PoolHook { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + + hooks := make([]PoolHook, len(phm.hooks)) + copy(hooks, phm.hooks) + return hooks +} diff --git a/internal/pool/hooks_test.go b/internal/pool/hooks_test.go new file mode 100644 index 00000000..e6100115 --- /dev/null +++ b/internal/pool/hooks_test.go @@ -0,0 +1,213 @@ +package pool + +import ( + "context" + "errors" + "net" + "testing" + "time" +) + +// TestHook for testing hook functionality +type TestHook struct { + OnGetCalled int + OnPutCalled int + GetError error + PutError error + ShouldPool bool + ShouldRemove bool +} + +func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error { + th.OnGetCalled++ + return th.GetError +} + +func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) { + th.OnPutCalled++ + return th.ShouldPool, th.ShouldRemove, th.PutError +} + +func TestPoolHookManager(t *testing.T) { + manager := NewPoolHookManager() + + // Test initial state + if manager.GetHookCount() != 0 { + t.Errorf("Expected 0 hooks initially, got %d", manager.GetHookCount()) + } + + // Add hooks + hook1 := &TestHook{ShouldPool: true} + hook2 := &TestHook{ShouldPool: true} + + manager.AddHook(hook1) + manager.AddHook(hook2) + + if manager.GetHookCount() != 2 { + t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount()) + } + + // Test ProcessOnGet + ctx := context.Background() + conn := &Conn{} // Mock connection + + err := manager.ProcessOnGet(ctx, conn, false) + if err != nil { + t.Errorf("ProcessOnGet should not error: %v", err) + } + + if hook1.OnGetCalled != 1 { + t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled) + } + + if hook2.OnGetCalled != 1 { + t.Errorf("Expected hook2.OnGetCalled to be 1, got %d", hook2.OnGetCalled) + } + + // Test ProcessOnPut + shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn) + if err != nil { + t.Errorf("ProcessOnPut should not error: %v", err) + } + + if !shouldPool { + t.Error("Expected shouldPool to be true") + } + + if shouldRemove { + t.Error("Expected shouldRemove to be false") + } + + if hook1.OnPutCalled != 1 { + t.Errorf("Expected hook1.OnPutCalled to be 1, got %d", hook1.OnPutCalled) + } + + if hook2.OnPutCalled != 1 { + t.Errorf("Expected hook2.OnPutCalled to be 1, got %d", hook2.OnPutCalled) + } + + // Remove a hook + manager.RemoveHook(hook1) + + if manager.GetHookCount() != 1 { + t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount()) + } +} + +func TestHookErrorHandling(t *testing.T) { + manager := NewPoolHookManager() + + // Hook that returns 
error on Get + errorHook := &TestHook{ + GetError: errors.New("test error"), + ShouldPool: true, + } + + normalHook := &TestHook{ShouldPool: true} + + manager.AddHook(errorHook) + manager.AddHook(normalHook) + + ctx := context.Background() + conn := &Conn{} + + // Test that error stops processing + err := manager.ProcessOnGet(ctx, conn, false) + if err == nil { + t.Error("Expected error from ProcessOnGet") + } + + if errorHook.OnGetCalled != 1 { + t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled) + } + + // normalHook should not be called due to error + if normalHook.OnGetCalled != 0 { + t.Errorf("Expected normalHook.OnGetCalled to be 0, got %d", normalHook.OnGetCalled) + } +} + +func TestHookShouldRemove(t *testing.T) { + manager := NewPoolHookManager() + + // Hook that says to remove connection + removeHook := &TestHook{ + ShouldPool: false, + ShouldRemove: true, + } + + normalHook := &TestHook{ShouldPool: true} + + manager.AddHook(removeHook) + manager.AddHook(normalHook) + + ctx := context.Background() + conn := &Conn{} + + shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn) + if err != nil { + t.Errorf("ProcessOnPut should not error: %v", err) + } + + if shouldPool { + t.Error("Expected shouldPool to be false") + } + + if !shouldRemove { + t.Error("Expected shouldRemove to be true") + } + + if removeHook.OnPutCalled != 1 { + t.Errorf("Expected removeHook.OnPutCalled to be 1, got %d", removeHook.OnPutCalled) + } + + // normalHook should not be called due to early return + if normalHook.OnPutCalled != 0 { + t.Errorf("Expected normalHook.OnPutCalled to be 0, got %d", normalHook.OnPutCalled) + } +} + +func TestPoolWithHooks(t *testing.T) { + // Create a pool with hooks + hookManager := NewPoolHookManager() + testHook := &TestHook{ShouldPool: true} + hookManager.AddHook(testHook) + + opt := &Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &net.TCPConn{}, nil // Mock connection + }, + PoolSize: 1, + DialTimeout: time.Second, + } + + pool := NewConnPool(opt) + defer pool.Close() + + // Add hook to pool after creation + pool.AddPoolHook(testHook) + + // Verify hooks are initialized + if pool.hookManager == nil { + t.Error("Expected hookManager to be initialized") + } + + if pool.hookManager.GetHookCount() != 1 { + t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount()) + } + + // Test adding hook to pool + additionalHook := &TestHook{ShouldPool: true} + pool.AddPoolHook(additionalHook) + + if pool.hookManager.GetHookCount() != 2 { + t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount()) + } + + // Test removing hook from pool + pool.RemovePoolHook(additionalHook) + + if pool.hookManager.GetHookCount() != 1 { + t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount()) + } +} diff --git a/internal/pool/pool.go b/internal/pool/pool.go index fa0306c3..b2cdbef5 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -10,6 +10,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/util" ) var ( @@ -22,6 +23,23 @@ var ( // ErrPoolTimeout timed out waiting to get a connection from the connection pool. ErrPoolTimeout = errors.New("redis: connection pool timeout") + + // popAttempts is the maximum number of attempts to find a usable connection + // when popping from the idle connection pool. 
This handles cases where connections + // are temporarily marked as unusable (e.g., during hitless upgrades or network issues). + // Value of 50 provides sufficient resilience without excessive overhead. + // This is capped by the idle connection count, so we won't loop excessively. + popAttempts = 50 + + // getAttempts is the maximum number of attempts to get a connection that passes + // hook validation (e.g., hitless upgrade hooks). This protects against race conditions + // where hooks might temporarily reject connections during cluster transitions. + // Value of 3 balances resilience with performance - most hook rejections resolve quickly. + getAttempts = 3 + + minTime = time.Unix(-2208988800, 0) // Jan 1, 1900 + maxTime = minTime.Add(1<<63 - 1) + noExpiration = maxTime ) var timers = sync.Pool{ @@ -38,11 +56,14 @@ type Stats struct { Misses uint32 // number of times free connection was NOT found in the pool Timeouts uint32 // number of times a wait timeout occurred WaitCount uint32 // number of times a connection was waited + Unusable uint32 // number of times a connection was found to be unusable WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds TotalConns uint32 // number of total connections in the pool IdleConns uint32 // number of idle connections in the pool StaleConns uint32 // number of stale connections removed from the pool + + PubSubStats PubSubStats } type Pooler interface { @@ -57,29 +78,35 @@ type Pooler interface { IdleLen() int Stats() *Stats + AddPoolHook(hook PoolHook) + RemovePoolHook(hook PoolHook) + Close() error } type Options struct { - Dialer func(context.Context) (net.Conn, error) - - PoolFIFO bool - PoolSize int - DialTimeout time.Duration - PoolTimeout time.Duration - MinIdleConns int - MaxIdleConns int - MaxActiveConns int - ConnMaxIdleTime time.Duration - ConnMaxLifetime time.Duration - - - // Protocol version for optimization (3 = RESP3 with push notifications, 2 = RESP2 without) - Protocol int - + Dialer func(context.Context) (net.Conn, error) ReadBufferSize int WriteBufferSize int + PoolFIFO bool + PoolSize int32 + DialTimeout time.Duration + PoolTimeout time.Duration + MinIdleConns int32 + MaxIdleConns int32 + MaxActiveConns int32 + ConnMaxIdleTime time.Duration + ConnMaxLifetime time.Duration + PushNotificationsEnabled bool + + // DialerRetries is the maximum number of retry attempts when dialing fails. + // Default: 5 + DialerRetries int + + // DialerRetryTimeout is the backoff duration between retry attempts. 
+ // Default: 100ms + DialerRetryTimeout time.Duration } type lastDialErrorWrap struct { @@ -95,16 +122,21 @@ type ConnPool struct { queue chan struct{} connsMu sync.Mutex - conns []*Conn + conns map[uint64]*Conn idleConns []*Conn - poolSize int - idleConnsLen int + poolSize atomic.Int32 + idleConnsLen atomic.Int32 + idleCheckInProgress atomic.Bool stats Stats waitDurationNs atomic.Int64 _closed uint32 // atomic + + // Pool hooks manager for flexible connection processing + hookManagerMu sync.RWMutex + hookManager *PoolHookManager } var _ Pooler = (*ConnPool)(nil) @@ -114,34 +146,69 @@ func NewConnPool(opt *Options) *ConnPool { cfg: opt, queue: make(chan struct{}, opt.PoolSize), - conns: make([]*Conn, 0, opt.PoolSize), + conns: make(map[uint64]*Conn), idleConns: make([]*Conn, 0, opt.PoolSize), } - p.connsMu.Lock() - p.checkMinIdleConns() - p.connsMu.Unlock() + // Only create MinIdleConns if explicitly requested (> 0) + // This avoids creating connections during pool initialization for tests + if opt.MinIdleConns > 0 { + p.connsMu.Lock() + p.checkMinIdleConns() + p.connsMu.Unlock() + } return p } +// initializeHooks sets up the pool hooks system. +func (p *ConnPool) initializeHooks() { + p.hookManager = NewPoolHookManager() +} + +// AddPoolHook adds a pool hook to the pool. +func (p *ConnPool) AddPoolHook(hook PoolHook) { + p.hookManagerMu.Lock() + defer p.hookManagerMu.Unlock() + + if p.hookManager == nil { + p.initializeHooks() + } + p.hookManager.AddHook(hook) +} + +// RemovePoolHook removes a pool hook from the pool. +func (p *ConnPool) RemovePoolHook(hook PoolHook) { + p.hookManagerMu.Lock() + defer p.hookManagerMu.Unlock() + + if p.hookManager != nil { + p.hookManager.RemoveHook(hook) + } +} + func (p *ConnPool) checkMinIdleConns() { + if !p.idleCheckInProgress.CompareAndSwap(false, true) { + return + } + defer p.idleCheckInProgress.Store(false) + if p.cfg.MinIdleConns == 0 { return } - for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns { + + // Only create idle connections if we haven't reached the total pool size limit + // MinIdleConns should be a subset of PoolSize, not additional connections + for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns { select { case p.queue <- struct{}{}: - p.poolSize++ - p.idleConnsLen++ - + p.poolSize.Add(1) + p.idleConnsLen.Add(1) go func() { defer func() { if err := recover(); err != nil { - p.connsMu.Lock() - p.poolSize-- - p.idleConnsLen-- - p.connsMu.Unlock() + p.poolSize.Add(-1) + p.idleConnsLen.Add(-1) p.freeTurn() internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err) @@ -150,12 +217,9 @@ func (p *ConnPool) checkMinIdleConns() { err := p.addIdleConn() if err != nil && err != ErrClosed { - p.connsMu.Lock() - p.poolSize-- - p.idleConnsLen-- - p.connsMu.Unlock() + p.poolSize.Add(-1) + p.idleConnsLen.Add(-1) } - p.freeTurn() }() default: @@ -172,6 +236,9 @@ func (p *ConnPool) addIdleConn() error { if err != nil { return err } + // Mark connection as usable after successful creation + // This is essential for normal pool operations + cn.SetUsable(true) p.connsMu.Lock() defer p.connsMu.Unlock() @@ -182,11 +249,15 @@ func (p *ConnPool) addIdleConn() error { return ErrClosed } - p.conns = append(p.conns, cn) + p.conns[cn.GetID()] = cn p.idleConns = append(p.idleConns, cn) return nil } +// NewConn creates a new connection and returns it to the user. +// This will still obey MaxActiveConns but will not include it in the pool and won't increase the pool size. 
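// Illustrative sketch (not part of this patch): configuring the dialer retry knobs that
// dialConn below consumes, together with the int32 pool sizing fields. Field names come
// from the Options struct in this hunk; the address and values are arbitrary.
package pool_test

import (
	"context"
	"net"
	"time"

	"github.com/redis/go-redis/v9/internal/pool"
)

func newPoolWithRetriesSketch() *pool.ConnPool {
	return pool.NewConnPool(&pool.Options{
		Dialer: func(ctx context.Context) (net.Conn, error) {
			return (&net.Dialer{}).DialContext(ctx, "tcp", "localhost:6379")
		},
		PoolSize:           int32(10),
		MinIdleConns:       int32(2),
		DialTimeout:        5 * time.Second,
		PoolTimeout:        4 * time.Second,
		DialerRetries:      5,                      // default used when left at zero
		DialerRetryTimeout: 100 * time.Millisecond, // default used when left at zero
	})
}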
+// +// NOTE: If you directly get a connection from the pool, it won't be pooled and won't support hitless upgrades. func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) { return p.newConn(ctx, false) } @@ -196,33 +267,44 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { return nil, ErrClosed } - p.connsMu.Lock() - if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns { - p.connsMu.Unlock() + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) { return nil, ErrPoolExhausted } - p.connsMu.Unlock() - cn, err := p.dialConn(ctx, pooled) + dialCtx, cancel := context.WithTimeout(ctx, p.cfg.DialTimeout) + defer cancel() + cn, err := p.dialConn(dialCtx, pooled) if err != nil { return nil, err } + // Mark connection as usable after successful creation + // This is essential for normal pool operations + cn.SetUsable(true) - p.connsMu.Lock() - defer p.connsMu.Unlock() - - if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns { + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) { _ = cn.Close() return nil, ErrPoolExhausted } - p.conns = append(p.conns, cn) + p.connsMu.Lock() + defer p.connsMu.Unlock() + if p.closed() { + _ = cn.Close() + return nil, ErrClosed + } + // Check if pool was closed while we were waiting for the lock + if p.conns == nil { + p.conns = make(map[uint64]*Conn) + } + p.conns[cn.GetID()] = cn + if pooled { // If pool is full remove the cn on next Put. - if p.poolSize >= p.cfg.PoolSize { + currentPoolSize := p.poolSize.Load() + if currentPoolSize >= p.cfg.PoolSize { cn.pooled = false } else { - p.poolSize++ + p.poolSize.Add(1) } } @@ -238,18 +320,57 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { return nil, p.getLastDialError() } - netConn, err := p.cfg.Dialer(ctx) - if err != nil { - p.setLastDialError(err) - if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) { - go p.tryDial() - } - return nil, err + // Retry dialing with backoff + // the context timeout is already handled by the context passed in + // so we may never reach the max retries, higher values don't hurt + maxRetries := p.cfg.DialerRetries + if maxRetries <= 0 { + maxRetries = 5 // Default value + } + backoffDuration := p.cfg.DialerRetryTimeout + if backoffDuration <= 0 { + backoffDuration = 100 * time.Millisecond // Default value } - cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize) - cn.pooled = pooled - return cn, nil + var lastErr error + shouldLoop := true + // when the timeout is reached, we should stop retrying + // but keep the lastErr to return to the caller + // instead of a generic context deadline exceeded error + for attempt := 0; (attempt < maxRetries) && shouldLoop; attempt++ { + netConn, err := p.cfg.Dialer(ctx) + if err != nil { + lastErr = err + // Add backoff delay for retry attempts + // (not for the first attempt, do at least one) + select { + case <-ctx.Done(): + shouldLoop = false + case <-time.After(backoffDuration): + // Continue with retry + } + continue + } + + // Success - create connection + cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize) + cn.pooled = pooled + if p.cfg.ConnMaxLifetime > 0 { + cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime) + } else { + cn.expiresAt = noExpiration + } + + return cn, nil + } + + internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr) + // All retries failed - 
handle error tracking + p.setLastDialError(lastErr) + if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) { + go p.tryDial() + } + return nil, lastErr } func (p *ConnPool) tryDial() { @@ -289,6 +410,14 @@ func (p *ConnPool) getLastDialError() error { // Get returns existed connection from the pool or creates a new one. func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { + return p.getConn(ctx) +} + +// getConn returns a connection from the pool. +func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { + var cn *Conn + var err error + if p.closed() { return nil, ErrClosed } @@ -297,9 +426,17 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { return nil, err } + now := time.Now() + attempts := 0 for { + if attempts >= getAttempts { + internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts) + break + } + attempts++ + p.connsMu.Lock() - cn, err := p.popIdle() + cn, err = p.popIdle() p.connsMu.Unlock() if err != nil { @@ -311,11 +448,25 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { break } - if !p.isHealthyConn(cn) { + if !p.isHealthyConn(cn, now) { _ = p.CloseConn(cn) continue } + // Process connection using the hooks system + p.hookManagerMu.RLock() + hookManager := p.hookManager + p.hookManagerMu.RUnlock() + + if hookManager != nil { + if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil { + internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) + // Failed to process connection, discard it + _ = p.CloseConn(cn) + continue + } + } + atomic.AddUint32(&p.stats.Hits, 1) return cn, nil } @@ -328,6 +479,19 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { return nil, err } + // Process connection using the hooks system + p.hookManagerMu.RLock() + hookManager := p.hookManager + p.hookManagerMu.RUnlock() + + if hookManager != nil { + if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil { + // Failed to process connection, discard it + internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err) + _ = p.CloseConn(newcn) + return nil, err + } + } return newcn, nil } @@ -356,7 +520,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error { } return ctx.Err() case p.queue <- struct{}{}: - p.waitDurationNs.Add(time.Since(start).Nanoseconds()) + p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano()) atomic.AddUint32(&p.stats.WaitCount, 1) if !timer.Stop() { <-timer.C @@ -376,68 +540,130 @@ func (p *ConnPool) popIdle() (*Conn, error) { if p.closed() { return nil, ErrClosed } + defer p.checkMinIdleConns() + n := len(p.idleConns) if n == 0 { return nil, nil } var cn *Conn - if p.cfg.PoolFIFO { - cn = p.idleConns[0] - copy(p.idleConns, p.idleConns[1:]) - p.idleConns = p.idleConns[:n-1] - } else { - idx := n - 1 - cn = p.idleConns[idx] - p.idleConns = p.idleConns[:idx] + attempts := 0 + + maxAttempts := util.Min(popAttempts, n) + for attempts < maxAttempts { + if len(p.idleConns) == 0 { + return nil, nil + } + + if p.cfg.PoolFIFO { + cn = p.idleConns[0] + copy(p.idleConns, p.idleConns[1:]) + p.idleConns = p.idleConns[:len(p.idleConns)-1] + } else { + idx := len(p.idleConns) - 1 + cn = p.idleConns[idx] + p.idleConns = p.idleConns[:idx] + } + attempts++ + + if cn.IsUsable() { + p.idleConnsLen.Add(-1) + break + } + + // Connection is not usable, put it back in the pool + if p.cfg.PoolFIFO { + // FIFO: put at end (will be picked up last since we 
pop from front) + p.idleConns = append(p.idleConns, cn) + } else { + // LIFO: put at beginning (will be picked up last since we pop from end) + p.idleConns = append([]*Conn{cn}, p.idleConns...) + } + cn = nil } - p.idleConnsLen-- - p.checkMinIdleConns() + + // If we exhausted all attempts without finding a usable connection, return nil + if attempts > 1 && attempts >= maxAttempts && int32(attempts) >= p.poolSize.Load() { + internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts) + return nil, nil + } + return cn, nil } func (p *ConnPool) Put(ctx context.Context, cn *Conn) { + // Process connection using the hooks system + shouldPool := true shouldRemove := false - if cn.rd.Buffered() > 0 { - // Check if this might be push notification data - if p.cfg.Protocol == 3 { - // we know that there is something in the buffer, so peek at the next reply type without - // the potential to block and check if it's a push notification - if replyType, err := cn.rd.PeekReplyType(); err != nil || replyType != proto.RespPush { - shouldRemove = true - } - } else { - // not a push notification since protocol 2 doesn't support them - shouldRemove = true - } + var err error - if shouldRemove { - // For non-RESP3 or data that is not a push notification, buffered data is unexpected - internal.Logger.Printf(ctx, "Conn has unread data, closing it") - p.Remove(ctx, cn, BadConnError{}) + if cn.HasBufferedData() { + // Peek at the reply type to check if it's a push notification + if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush { + // Not a push notification or error peeking, remove connection + internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it") + p.Remove(ctx, cn, err) + } + // It's a push notification, allow pooling (client will handle it) + } + + p.hookManagerMu.RLock() + hookManager := p.hookManager + p.hookManagerMu.RUnlock() + + if hookManager != nil { + shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn) + if err != nil { + internal.Logger.Printf(ctx, "Connection hook error: %v", err) + p.Remove(ctx, cn, err) return } } + // If hooks say to remove the connection, do so + if shouldRemove { + p.Remove(ctx, cn, errors.New("hook requested removal")) + return + } + + // If processor says not to pool the connection, remove it + if !shouldPool { + p.Remove(ctx, cn, errors.New("hook requested no pooling")) + return + } + if !cn.pooled { - p.Remove(ctx, cn, nil) + p.Remove(ctx, cn, errors.New("connection not pooled")) return } var shouldCloseConn bool - p.connsMu.Lock() - - if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns { - p.idleConns = append(p.idleConns, cn) - p.idleConnsLen++ + if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns { + // unusable conns are expected to become usable at some point (background process is reconnecting them) + // put them at the opposite end of the queue + if !cn.IsUsable() { + if p.cfg.PoolFIFO { + p.connsMu.Lock() + p.idleConns = append(p.idleConns, cn) + p.connsMu.Unlock() + } else { + p.connsMu.Lock() + p.idleConns = append([]*Conn{cn}, p.idleConns...) 
+ p.connsMu.Unlock() + } + } else { + p.connsMu.Lock() + p.idleConns = append(p.idleConns, cn) + p.connsMu.Unlock() + } + p.idleConnsLen.Add(1) } else { - p.removeConn(cn) + p.removeConnWithLock(cn) shouldCloseConn = true } - p.connsMu.Unlock() - p.freeTurn() if shouldCloseConn { @@ -447,8 +673,13 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) { p.removeConnWithLock(cn) + p.freeTurn() + _ = p.closeConn(cn) + + // Check if we need to create new idle connections to maintain MinIdleConns + p.checkMinIdleConns() } func (p *ConnPool) CloseConn(cn *Conn) error { @@ -463,17 +694,23 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) { } func (p *ConnPool) removeConn(cn *Conn) { - for i, c := range p.conns { - if c == cn { - p.conns = append(p.conns[:i], p.conns[i+1:]...) - if cn.pooled { - p.poolSize-- - p.checkMinIdleConns() + cid := cn.GetID() + delete(p.conns, cid) + atomic.AddUint32(&p.stats.StaleConns, 1) + + // Decrement pool size counter when removing a connection + if cn.pooled { + p.poolSize.Add(-1) + // this can be idle conn + for idx, ic := range p.idleConns { + if ic.GetID() == cid { + internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid) + p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...) + p.idleConnsLen.Add(-1) + break } - break } } - atomic.AddUint32(&p.stats.StaleConns, 1) } func (p *ConnPool) closeConn(cn *Conn) error { @@ -491,9 +728,9 @@ func (p *ConnPool) Len() int { // IdleLen returns number of idle connections. func (p *ConnPool) IdleLen() int { p.connsMu.Lock() - n := p.idleConnsLen + n := p.idleConnsLen.Load() p.connsMu.Unlock() - return n + return int(n) } func (p *ConnPool) Stats() *Stats { @@ -502,6 +739,7 @@ func (p *ConnPool) Stats() *Stats { Misses: atomic.LoadUint32(&p.stats.Misses), Timeouts: atomic.LoadUint32(&p.stats.Timeouts), WaitCount: atomic.LoadUint32(&p.stats.WaitCount), + Unusable: atomic.LoadUint32(&p.stats.Unusable), WaitDurationNs: p.waitDurationNs.Load(), TotalConns: uint32(p.Len()), @@ -542,30 +780,33 @@ func (p *ConnPool) Close() error { } } p.conns = nil - p.poolSize = 0 + p.poolSize.Store(0) p.idleConns = nil - p.idleConnsLen = 0 + p.idleConnsLen.Store(0) p.connsMu.Unlock() return firstErr } -func (p *ConnPool) isHealthyConn(cn *Conn) bool { - now := time.Now() - - if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime { +func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { + // slight optimization, check expiresAt first. 
+ if cn.expiresAt.Before(now) { return false } + + // Check if connection has exceeded idle timeout if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { return false } - // Check connection health, but be aware of push notifications - if err := connCheck(cn.netConn); err != nil { + cn.SetUsedAt(now) + // Check basic connection health + // Use GetNetConn() to safely access netConn and avoid data races + if err := connCheck(cn.getNetConn()); err != nil { // If there's unexpected data, it might be push notifications (RESP3) // However, push notification processing is now handled by the client // before WithReader to ensure proper context is available to handlers - if err == errUnexpectedRead && p.cfg.Protocol == 3 { + if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead { // we know that there is something in the buffer, so peek at the next reply type without // the potential to block if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush { @@ -579,7 +820,5 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool { return false } } - - cn.SetUsedAt(now) return true } diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 5a3fde19..136d6f2d 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -1,6 +1,8 @@ package pool -import "context" +import ( + "context" +) type SingleConnPool struct { pool Pooler @@ -56,3 +58,7 @@ func (p *SingleConnPool) IdleLen() int { func (p *SingleConnPool) Stats() *Stats { return &Stats{} } + +func (p *SingleConnPool) AddPoolHook(hook PoolHook) {} + +func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {} diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index 3adb99bc..dc4266a4 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -199,3 +199,7 @@ func (p *StickyConnPool) IdleLen() int { func (p *StickyConnPool) Stats() *Stats { return &Stats{} } + +func (p *StickyConnPool) AddPoolHook(hook PoolHook) {} + +func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {} diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 736323d9..6a7870b5 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -2,15 +2,17 @@ package pool_test import ( "context" + "errors" "net" "sync" + "sync/atomic" "testing" "time" . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" - "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" ) var _ = Describe("ConnPool", func() { @@ -20,7 +22,7 @@ var _ = Describe("ConnPool", func() { BeforeEach(func() { connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 10, + PoolSize: int32(10), PoolTimeout: time.Hour, DialTimeout: 1 * time.Second, ConnMaxIdleTime: time.Millisecond, @@ -45,11 +47,11 @@ var _ = Describe("ConnPool", func() { <-closedChan return &net.TCPConn{}, nil }, - PoolSize: 10, + PoolSize: int32(10), PoolTimeout: time.Hour, DialTimeout: 1 * time.Second, ConnMaxIdleTime: time.Millisecond, - MinIdleConns: minIdleConns, + MinIdleConns: int32(minIdleConns), }) wg.Wait() Expect(connPool.Close()).NotTo(HaveOccurred()) @@ -105,7 +107,7 @@ var _ = Describe("ConnPool", func() { // ok } - connPool.Remove(ctx, cn, nil) + connPool.Remove(ctx, cn, errors.New("test")) // Check that Get is unblocked. 
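The reworked health check above orders its work from cheapest to most expensive: connection lifetime (expiresAt) first, then idle time, then a raw socket probe, where an unexpected readable byte on a push-enabled connection is tolerated if it peeks as a push frame. A standalone sketch of that decision order, using stand-in names (fakeConn, checkSocket, errUnexpectedRead) rather than the pool internals:

package main

import (
	"errors"
	"fmt"
	"time"
)

// errUnexpectedRead stands in for the sentinel returned by connCheck when the
// socket has readable data at health-check time.
var errUnexpectedRead = errors.New("unexpected read from socket")

type fakeConn struct {
	expiresAt   time.Time
	usedAt      time.Time
	checkSocket func() error // stand-in for connCheck
	peekIsPush  bool         // stand-in for PeekReplyType() == proto.RespPush
}

func healthy(cn fakeConn, maxIdle time.Duration, pushEnabled bool, now time.Time) bool {
	if cn.expiresAt.Before(now) { // ConnMaxLifetime, cheapest check first
		return false
	}
	if maxIdle > 0 && now.Sub(cn.usedAt) >= maxIdle { // ConnMaxIdleTime
		return false
	}
	if err := cn.checkSocket(); err != nil {
		// Buffered data on a push-enabled (RESP3) connection may simply be a
		// pending push notification, which the client drains later.
		if pushEnabled && errors.Is(err, errUnexpectedRead) && cn.peekIsPush {
			return true
		}
		return false
	}
	return true
}

func main() {
	now := time.Now()
	cn := fakeConn{
		expiresAt:   now.Add(time.Hour),
		usedAt:      now,
		checkSocket: func() error { return errUnexpectedRead },
		peekIsPush:  true,
	}
	fmt.Println(healthy(cn, 30*time.Minute, true, now)) // true: pending push data is tolerated
}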
select { @@ -130,8 +132,8 @@ var _ = Describe("MinIdleConns", func() { newConnPool := func() *pool.ConnPool { connPool := pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: poolSize, - MinIdleConns: minIdleConns, + PoolSize: int32(poolSize), + MinIdleConns: int32(minIdleConns), PoolTimeout: 100 * time.Millisecond, DialTimeout: 1 * time.Second, ConnMaxIdleTime: -1, @@ -168,7 +170,7 @@ var _ = Describe("MinIdleConns", func() { Context("after Remove", func() { BeforeEach(func() { - connPool.Remove(ctx, cn, nil) + connPool.Remove(ctx, cn, errors.New("test")) }) It("has idle connections", func() { @@ -245,7 +247,7 @@ var _ = Describe("MinIdleConns", func() { BeforeEach(func() { perform(len(cns), func(i int) { mu.RLock() - connPool.Remove(ctx, cns[i], nil) + connPool.Remove(ctx, cns[i], errors.New("test")) mu.RUnlock() }) @@ -309,7 +311,7 @@ var _ = Describe("race", func() { It("does not happen on Get, Put, and Remove", func() { connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 10, + PoolSize: int32(10), PoolTimeout: time.Minute, DialTimeout: 1 * time.Second, ConnMaxIdleTime: time.Millisecond, @@ -328,7 +330,7 @@ var _ = Describe("race", func() { cn, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) if err == nil { - connPool.Remove(ctx, cn, nil) + connPool.Remove(ctx, cn, errors.New("test")) } } }) @@ -339,15 +341,15 @@ var _ = Describe("race", func() { Dialer: func(ctx context.Context) (net.Conn, error) { return &net.TCPConn{}, nil }, - PoolSize: 1000, - MinIdleConns: 50, + PoolSize: int32(1000), + MinIdleConns: int32(50), PoolTimeout: 3 * time.Second, DialTimeout: 1 * time.Second, } p := pool.NewConnPool(opt) var wg sync.WaitGroup - for i := 0; i < opt.PoolSize; i++ { + for i := int32(0); i < opt.PoolSize; i++ { wg.Add(1) go func() { defer wg.Done() @@ -366,8 +368,8 @@ var _ = Describe("race", func() { Dialer: func(ctx context.Context) (net.Conn, error) { panic("test panic") }, - PoolSize: 100, - MinIdleConns: 30, + PoolSize: int32(100), + MinIdleConns: int32(30), } p := pool.NewConnPool(opt) @@ -377,14 +379,14 @@ var _ = Describe("race", func() { state := p.Stats() return state.TotalConns == 0 && state.IdleConns == 0 && p.QueueLen() == 0 }, "3s", "50ms").Should(BeTrue()) - }) - + }) + It("wait", func() { opt := &pool.Options{ Dialer: func(ctx context.Context) (net.Conn, error) { return &net.TCPConn{}, nil }, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: 3 * time.Second, } p := pool.NewConnPool(opt) @@ -415,7 +417,7 @@ var _ = Describe("race", func() { return &net.TCPConn{}, nil }, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: testPoolTimeout, } p := pool.NewConnPool(opt) @@ -435,3 +437,73 @@ var _ = Describe("race", func() { Expect(stats.Timeouts).To(Equal(uint32(1))) }) }) + +// TestDialerRetryConfiguration tests the new DialerRetries and DialerRetryTimeout options +func TestDialerRetryConfiguration(t *testing.T) { + ctx := context.Background() + + t.Run("CustomDialerRetries", func(t *testing.T) { + var attempts int64 + failingDialer := func(ctx context.Context) (net.Conn, error) { + atomic.AddInt64(&attempts, 1) + return nil, errors.New("dial failed") + } + + connPool := pool.NewConnPool(&pool.Options{ + Dialer: failingDialer, + PoolSize: 1, + PoolTimeout: time.Second, + DialTimeout: time.Second, + DialerRetries: 3, // Custom retry count + DialerRetryTimeout: 10 * time.Millisecond, // Fast retries for testing + }) + defer connPool.Close() + + _, err := connPool.Get(ctx) + if err == nil { + t.Error("Expected error from failing 
dialer") + } + + // Should have attempted at least 3 times (DialerRetries = 3) + // There might be additional attempts due to pool logic + finalAttempts := atomic.LoadInt64(&attempts) + if finalAttempts < 3 { + t.Errorf("Expected at least 3 dial attempts, got %d", finalAttempts) + } + if finalAttempts > 6 { + t.Errorf("Expected around 3 dial attempts, got %d (too many)", finalAttempts) + } + }) + + t.Run("DefaultDialerRetries", func(t *testing.T) { + var attempts int64 + failingDialer := func(ctx context.Context) (net.Conn, error) { + atomic.AddInt64(&attempts, 1) + return nil, errors.New("dial failed") + } + + connPool := pool.NewConnPool(&pool.Options{ + Dialer: failingDialer, + PoolSize: 1, + PoolTimeout: time.Second, + DialTimeout: time.Second, + // DialerRetries and DialerRetryTimeout not set - should use defaults + }) + defer connPool.Close() + + _, err := connPool.Get(ctx) + if err == nil { + t.Error("Expected error from failing dialer") + } + + // Should have attempted 5 times (default DialerRetries = 5) + finalAttempts := atomic.LoadInt64(&attempts) + if finalAttempts != 5 { + t.Errorf("Expected 5 dial attempts (default), got %d", finalAttempts) + } + }) +} + +func init() { + logging.Disable() +} diff --git a/internal/pool/pubsub.go b/internal/pool/pubsub.go new file mode 100644 index 00000000..73ee4b3e --- /dev/null +++ b/internal/pool/pubsub.go @@ -0,0 +1,78 @@ +package pool + +import ( + "context" + "net" + "sync" + "sync/atomic" +) + +type PubSubStats struct { + Created uint32 + Untracked uint32 + Active uint32 +} + +// PubSubPool manages a pool of PubSub connections. +type PubSubPool struct { + opt *Options + netDialer func(ctx context.Context, network, addr string) (net.Conn, error) + + // Map to track active PubSub connections + activeConns sync.Map // map[uint64]*Conn (connID -> conn) + closed atomic.Bool + stats PubSubStats +} + +func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool { + return &PubSubPool{ + opt: opt, + netDialer: netDialer, + } +} + +func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) { + if p.closed.Load() { + return nil, ErrClosed + } + + netConn, err := p.netDialer(ctx, network, addr) + if err != nil { + return nil, err + } + cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize) + cn.pubsub = true + atomic.AddUint32(&p.stats.Created, 1) + return cn, nil + +} + +func (p *PubSubPool) TrackConn(cn *Conn) { + atomic.AddUint32(&p.stats.Active, 1) + p.activeConns.Store(cn.GetID(), cn) +} + +func (p *PubSubPool) UntrackConn(cn *Conn) { + atomic.AddUint32(&p.stats.Active, ^uint32(0)) + atomic.AddUint32(&p.stats.Untracked, 1) + p.activeConns.Delete(cn.GetID()) +} + +func (p *PubSubPool) Close() error { + p.closed.Store(true) + p.activeConns.Range(func(key, value interface{}) bool { + cn := value.(*Conn) + _ = cn.Close() + return true + }) + return nil +} + +func (p *PubSubPool) Stats() *PubSubStats { + // load stats atomically + return &PubSubStats{ + Created: atomic.LoadUint32(&p.stats.Created), + Untracked: atomic.LoadUint32(&p.stats.Untracked), + Active: atomic.LoadUint32(&p.stats.Active), + } +} diff --git a/internal/redis.go b/internal/redis.go new file mode 100644 index 00000000..0459e42b --- /dev/null +++ b/internal/redis.go @@ -0,0 +1,3 @@ +package internal + +const RedisNull = "null" diff --git a/internal/util/convert.go b/internal/util/convert.go index d326d50d..b743a4f0 100644 --- 
a/internal/util/convert.go +++ b/internal/util/convert.go @@ -28,3 +28,14 @@ func MustParseFloat(s string) float64 { } return f } + +// SafeIntToInt32 safely converts an int to int32, returning an error if overflow would occur. +func SafeIntToInt32(value int, fieldName string) (int32, error) { + if value > math.MaxInt32 { + return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32) + } + if value < math.MinInt32 { + return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32) + } + return int32(value), nil +} diff --git a/internal/util/math.go b/internal/util/math.go new file mode 100644 index 00000000..e707c47a --- /dev/null +++ b/internal/util/math.go @@ -0,0 +1,17 @@ +package util + +// Max returns the maximum of two integers +func Max(a, b int) int { + if a > b { + return a + } + return b +} + +// Min returns the minimum of two integers +func Min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/logging/logging.go b/logging/logging.go new file mode 100644 index 00000000..e2759284 --- /dev/null +++ b/logging/logging.go @@ -0,0 +1,121 @@ +// Package logging provides logging level constants and utilities for the go-redis library. +// This package centralizes logging configuration to ensure consistency across all components. +package logging + +import ( + "context" + "fmt" + "strings" + + "github.com/redis/go-redis/v9/internal" +) + +// LogLevel represents the logging level +type LogLevel int + +// Log level constants for the entire go-redis library +const ( + LogLevelError LogLevel = iota // 0 - errors only + LogLevelWarn // 1 - warnings and errors + LogLevelInfo // 2 - info, warnings, and errors + LogLevelDebug // 3 - debug, info, warnings, and errors +) + +// String returns the string representation of the log level +func (l LogLevel) String() string { + switch l { + case LogLevelError: + return "ERROR" + case LogLevelWarn: + return "WARN" + case LogLevelInfo: + return "INFO" + case LogLevelDebug: + return "DEBUG" + default: + return "UNKNOWN" + } +} + +// IsValid returns true if the log level is valid +func (l LogLevel) IsValid() bool { + return l >= LogLevelError && l <= LogLevelDebug +} + +func (l LogLevel) WarnOrAbove() bool { + return l >= LogLevelWarn +} + +func (l LogLevel) InfoOrAbove() bool { + return l >= LogLevelInfo +} + +func (l LogLevel) DebugOrAbove() bool { + return l >= LogLevelDebug +} + +// VoidLogger is a logger that does nothing. +// Used to disable logging and thus speed up the library. +type VoidLogger struct{} + +func (v *VoidLogger) Printf(_ context.Context, _ string, _ ...interface{}) { + // do nothing +} + +// Disable disables logging by setting the internal logger to a void logger. +// This can be used to speed up the library if logging is not needed. +// It will override any custom logger that was set before and set the VoidLogger. +func Disable() { + internal.Logger = &VoidLogger{} +} + +// Enable enables logging by setting the internal logger to the default logger. +// This is the default behavior. +// You can use redis.SetLogger to set a custom logger. +// +// NOTE: This function is not thread-safe. +// It will override any custom logger that was set before and set the DefaultLogger. +func Enable() { + internal.Logger = internal.NewDefaultLogger() +} + +// NewBlacklistLogger returns a new logger that filters out messages containing any of the substrings. +// This can be used to filter out messages containing sensitive information. 
+func NewBlacklistLogger(substr []string) internal.Logging { + l := internal.NewDefaultLogger() + return &filterLogger{logger: l, substr: substr, blacklist: true} +} + +// NewWhitelistLogger returns a new logger that only logs messages containing any of the substrings. +// This can be used to only log messages related to specific commands or patterns. +func NewWhitelistLogger(substr []string) internal.Logging { + l := internal.NewDefaultLogger() + return &filterLogger{logger: l, substr: substr, blacklist: false} +} + +type filterLogger struct { + logger internal.Logging + blacklist bool + substr []string +} + +func (l *filterLogger) Printf(ctx context.Context, format string, v ...interface{}) { + msg := fmt.Sprintf(format, v...) + found := false + for _, substr := range l.substr { + if strings.Contains(msg, substr) { + found = true + if l.blacklist { + return + } + } + } + // whitelist, only log if one of the substrings is present + if !l.blacklist && !found { + return + } + if l.logger != nil { + l.logger.Printf(ctx, format, v...) + return + } +} diff --git a/logging/logging_test.go b/logging/logging_test.go new file mode 100644 index 00000000..9f26d222 --- /dev/null +++ b/logging/logging_test.go @@ -0,0 +1,59 @@ +package logging + +import "testing" + +func TestLogLevel_String(t *testing.T) { + tests := []struct { + level LogLevel + expected string + }{ + {LogLevelError, "ERROR"}, + {LogLevelWarn, "WARN"}, + {LogLevelInfo, "INFO"}, + {LogLevelDebug, "DEBUG"}, + {LogLevel(99), "UNKNOWN"}, + } + + for _, test := range tests { + if got := test.level.String(); got != test.expected { + t.Errorf("LogLevel(%d).String() = %q, want %q", test.level, got, test.expected) + } + } +} + +func TestLogLevel_IsValid(t *testing.T) { + tests := []struct { + level LogLevel + expected bool + }{ + {LogLevelError, true}, + {LogLevelWarn, true}, + {LogLevelInfo, true}, + {LogLevelDebug, true}, + {LogLevel(-1), false}, + {LogLevel(4), false}, + {LogLevel(99), false}, + } + + for _, test := range tests { + if got := test.level.IsValid(); got != test.expected { + t.Errorf("LogLevel(%d).IsValid() = %v, want %v", test.level, got, test.expected) + } + } +} + +func TestLogLevelConstants(t *testing.T) { + // Test that constants have expected values + if LogLevelError != 0 { + t.Errorf("LogLevelError = %d, want 0", LogLevelError) + } + if LogLevelWarn != 1 { + t.Errorf("LogLevelWarn = %d, want 1", LogLevelWarn) + } + if LogLevelInfo != 2 { + t.Errorf("LogLevelInfo = %d, want 2", LogLevelInfo) + } + if LogLevelDebug != 3 { + t.Errorf("LogLevelDebug = %d, want 3", LogLevelDebug) + } +} diff --git a/main_test.go b/main_test.go index 29e6014b..a192aa3a 100644 --- a/main_test.go +++ b/main_test.go @@ -13,6 +13,7 @@ import ( . "github.com/bsm/ginkgo/v2" . 
"github.com/bsm/gomega" "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/logging" ) const ( @@ -102,6 +103,7 @@ var _ = BeforeSuite(func() { fmt.Printf("RCEDocker: %v\n", RCEDocker) fmt.Printf("REDIS_VERSION: %.1f\n", RedisVersion) fmt.Printf("CLIENT_LIBS_TEST_IMAGE: %v\n", os.Getenv("CLIENT_LIBS_TEST_IMAGE")) + logging.Disable() if RedisVersion < 7.0 || RedisVersion > 9 { panic("incorrect or not supported redis version") diff --git a/options.go b/options.go index 237be6be..0e154ac0 100644 --- a/options.go +++ b/options.go @@ -14,9 +14,11 @@ import ( "time" "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/hitless" "github.com/redis/go-redis/v9/internal/pool" - "github.com/redis/go-redis/v9/push" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/push" ) // Limiter is the interface of a rate limiter or a circuit breaker. @@ -107,9 +109,19 @@ type Options struct { // DialTimeout for establishing new connections. // - // default: 5 seconds + // default: 10 seconds DialTimeout time.Duration + // DialerRetries is the maximum number of retry attempts when dialing fails. + // + // default: 5 + DialerRetries int + + // DialerRetryTimeout is the backoff duration between retry attempts. + // + // default: 100 milliseconds + DialerRetryTimeout time.Duration + // ReadTimeout for socket reads. If reached, commands will fail // with a timeout instead of blocking. Supported values: // @@ -153,6 +165,7 @@ type Options struct { // // Note that FIFO has slightly higher overhead compared to LIFO, // but it helps closing idle connections faster reducing the pool size. + // default: false PoolFIFO bool // PoolSize is the base number of socket connections. @@ -244,8 +257,19 @@ type Options struct { // When a node is marked as failing, it will be avoided for this duration. // Default is 15 seconds. FailingTimeoutSeconds int + + // HitlessUpgradeConfig provides custom configuration for hitless upgrades. + // When HitlessUpgradeConfig.Mode is not "disabled", the client will handle + // cluster upgrade notifications gracefully and manage connection/pool state + // transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it. + HitlessUpgradeConfig *HitlessUpgradeConfig } +// HitlessUpgradeConfig provides configuration options for hitless upgrades. +// This is an alias to hitless.Config for convenience. 
+type HitlessUpgradeConfig = hitless.Config + func (opt *Options) init() { if opt.Addr == "" { opt.Addr = "localhost:6379" @@ -261,7 +285,13 @@ func (opt *Options) init() { opt.Protocol = 3 } if opt.DialTimeout == 0 { - opt.DialTimeout = 5 * time.Second + opt.DialTimeout = 10 * time.Second + } + if opt.DialerRetries == 0 { + opt.DialerRetries = 5 + } + if opt.DialerRetryTimeout == 0 { + opt.DialerRetryTimeout = 100 * time.Millisecond } if opt.Dialer == nil { opt.Dialer = NewDialer(opt) @@ -320,13 +350,36 @@ func (opt *Options) init() { case 0: opt.MaxRetryBackoff = 512 * time.Millisecond } + + opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns) + + // auto-detect endpoint type if not specified + endpointType := opt.HitlessUpgradeConfig.EndpointType + if endpointType == "" || endpointType == hitless.EndpointTypeAuto { + // Auto-detect endpoint type if not specified + endpointType = hitless.DetectEndpointType(opt.Addr, opt.TLSConfig != nil) + } + opt.HitlessUpgradeConfig.EndpointType = endpointType } func (opt *Options) clone() *Options { clone := *opt + + // Deep clone HitlessUpgradeConfig to avoid sharing between clients + if opt.HitlessUpgradeConfig != nil { + configClone := *opt.HitlessUpgradeConfig + clone.HitlessUpgradeConfig = &configClone + } + return &clone } +// NewDialer returns a function that will be used as the default dialer +// when none is specified in Options.Dialer. +func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) { + return NewDialer(opt) +} + // NewDialer returns a function that will be used as the default dialer // when none is specified in Options.Dialer. func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) { @@ -612,23 +665,84 @@ func getUserPassword(u *url.URL) (string, string) { func newConnPool( opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error), -) *pool.ConnPool { +) (*pool.ConnPool, error) { + poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize") + if err != nil { + return nil, err + } + + minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns") + if err != nil { + return nil, err + } + + maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns") + if err != nil { + return nil, err + } + + maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns") + if err != nil { + return nil, err + } + return pool.NewConnPool(&pool.Options{ Dialer: func(ctx context.Context) (net.Conn, error) { return dialer(ctx, opt.Network, opt.Addr) }, - PoolFIFO: opt.PoolFIFO, - PoolSize: opt.PoolSize, - PoolTimeout: opt.PoolTimeout, - DialTimeout: opt.DialTimeout, - MinIdleConns: opt.MinIdleConns, - MaxIdleConns: opt.MaxIdleConns, - MaxActiveConns: opt.MaxActiveConns, - ConnMaxIdleTime: opt.ConnMaxIdleTime, - ConnMaxLifetime: opt.ConnMaxLifetime, - // Pass protocol version for push notification optimization - Protocol: opt.Protocol, - ReadBufferSize: opt.ReadBufferSize, - WriteBufferSize: opt.WriteBufferSize, - }) + PoolFIFO: opt.PoolFIFO, + PoolSize: poolSize, + PoolTimeout: opt.PoolTimeout, + DialTimeout: opt.DialTimeout, + DialerRetries: opt.DialerRetries, + DialerRetryTimeout: opt.DialerRetryTimeout, + MinIdleConns: minIdleConns, + MaxIdleConns: maxIdleConns, + MaxActiveConns: maxActiveConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ReadBufferSize: opt.ReadBufferSize, + WriteBufferSize: opt.WriteBufferSize, + 
PushNotificationsEnabled: opt.Protocol == 3, + }), nil +} + +func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error), +) (*pool.PubSubPool, error) { + poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize") + if err != nil { + return nil, err + } + + minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns") + if err != nil { + return nil, err + } + + maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns") + if err != nil { + return nil, err + } + + maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns") + if err != nil { + return nil, err + } + + return pool.NewPubSubPool(&pool.Options{ + PoolFIFO: opt.PoolFIFO, + PoolSize: poolSize, + PoolTimeout: opt.PoolTimeout, + DialTimeout: opt.DialTimeout, + DialerRetries: opt.DialerRetries, + DialerRetryTimeout: opt.DialerRetryTimeout, + MinIdleConns: minIdleConns, + MaxIdleConns: maxIdleConns, + MaxActiveConns: maxActiveConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ReadBufferSize: 32 * 1024, + WriteBufferSize: 32 * 1024, + PushNotificationsEnabled: opt.Protocol == 3, + }, dialer), nil } diff --git a/osscluster.go b/osscluster.go index ec77a95c..5bae4555 100644 --- a/osscluster.go +++ b/osscluster.go @@ -20,6 +20,7 @@ import ( "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/rand" + "github.com/redis/go-redis/v9/push" ) const ( @@ -38,6 +39,7 @@ type ClusterOptions struct { ClientName string // NewClient creates a cluster node client with provided name and options. + // If NewClient is set by the user, the user is responsible for handling hitless upgrades and push notifications. NewClient func(opt *Options) *Client // The maximum number of retries before giving up. Command is retried @@ -125,10 +127,22 @@ type ClusterOptions struct { // UnstableResp3 enables Unstable mode for Redis Search module with RESP3. UnstableResp3 bool + // PushNotificationProcessor is the processor for handling push notifications. + // If nil, a default processor will be created for RESP3 connections. + PushNotificationProcessor push.NotificationProcessor + // FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing. // When a node is marked as failing, it will be avoided for this duration. // Default is 15 seconds. FailingTimeoutSeconds int + + // HitlessUpgradeConfig provides custom configuration for hitless upgrades. + // When HitlessUpgradeConfig.Mode is not "disabled", the client will handle + // cluster upgrade notifications gracefully and manage connection/pool state + // transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it. + // The ClusterClient does not directly work with hitless, it is up to the clients in the Nodes map to work with hitless. 
+ HitlessUpgradeConfig *HitlessUpgradeConfig } func (opt *ClusterOptions) init() { @@ -319,6 +333,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er } func (opt *ClusterOptions) clientOptions() *Options { + // Clone HitlessUpgradeConfig to avoid sharing between cluster node clients + var hitlessConfig *HitlessUpgradeConfig + if opt.HitlessUpgradeConfig != nil { + configClone := *opt.HitlessUpgradeConfig + hitlessConfig = &configClone + } + return &Options{ ClientName: opt.ClientName, Dialer: opt.Dialer, @@ -360,8 +381,10 @@ func (opt *ClusterOptions) clientOptions() *Options { // much use for ClusterSlots config). This means we cannot execute the // READONLY command against that node -- setting readOnly to false in such // situations in the options below will prevent that from happening. - readOnly: opt.ReadOnly && opt.ClusterSlots == nil, - UnstableResp3: opt.UnstableResp3, + readOnly: opt.ReadOnly && opt.ClusterSlots == nil, + UnstableResp3: opt.UnstableResp3, + HitlessUpgradeConfig: hitlessConfig, + PushNotificationProcessor: opt.PushNotificationProcessor, } } @@ -1830,12 +1853,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s return err } +// hitless won't work here for now func (c *ClusterClient) pubSub() *PubSub { var node *clusterNode pubsub := &PubSub{ opt: c.opt.clientOptions(), - - newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { + newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) { if node != nil { panic("node != nil") } @@ -1850,18 +1873,25 @@ func (c *ClusterClient) pubSub() *PubSub { if err != nil { return nil, err } - - cn, err := node.Client.newConn(context.TODO()) + cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels) if err != nil { node = nil - return nil, err } - + // will return nil if already initialized + err = node.Client.initConn(ctx, cn) + if err != nil { + _ = cn.Close() + node = nil + return nil, err + } + node.Client.pubSubPool.TrackConn(cn) return cn, nil }, closeConn: func(cn *pool.Conn) error { - err := node.Client.connPool.CloseConn(cn) + // Untrack connection from PubSubPool + node.Client.pubSubPool.UntrackConn(cn) + err := cn.Close() node = nil return err }, diff --git a/pool_pubsub_bench_test.go b/pool_pubsub_bench_test.go new file mode 100644 index 00000000..0db8ec55 --- /dev/null +++ b/pool_pubsub_bench_test.go @@ -0,0 +1,375 @@ +// Pool and PubSub Benchmark Suite +// +// This file contains comprehensive benchmarks for both pool operations and PubSub initialization. +// It's designed to be run against different branches to compare performance. +// +// Usage Examples: +// # Run all benchmarks +// go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go +// +// # Run only pool benchmarks +// go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go +// +// # Run only PubSub benchmarks +// go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go +// +// # Compare between branches +// git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt +// git checkout branch2 && go test -bench=. 
-run='^$' pool_pubsub_bench_test.go > branch2.txt +// benchcmp branch1.txt branch2.txt +// +// # Run with memory profiling +// go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go +// +// # Run with CPU profiling +// go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go + +package redis_test + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/internal/pool" +) + +// dummyDialer creates a mock connection for benchmarking +func dummyDialer(ctx context.Context) (net.Conn, error) { + return &dummyConn{}, nil +} + +// dummyConn implements net.Conn for benchmarking +type dummyConn struct{} + +func (c *dummyConn) Read(b []byte) (n int, err error) { return len(b), nil } +func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (c *dummyConn) Close() error { return nil } +func (c *dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} } +func (c *dummyConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} +} +func (c *dummyConn) SetDeadline(t time.Time) error { return nil } +func (c *dummyConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil } + +// ============================================================================= +// POOL BENCHMARKS +// ============================================================================= + +// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations +func BenchmarkPoolGetPut(b *testing.B) { + ctx := context.Background() + + poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128} + + for _, poolSize := range poolSizes { + b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) { + connPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: int32(poolSize), + PoolTimeout: time.Second, + DialTimeout: time.Second, + ConnMaxIdleTime: time.Hour, + MinIdleConns: int32(0), // Start with no idle connections + }) + defer connPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cn, err := connPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connPool.Put(ctx, cn) + } + }) + }) + } +} + +// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with MinIdleConns +func BenchmarkPoolGetPutWithMinIdle(b *testing.B) { + ctx := context.Background() + + configs := []struct { + poolSize int + minIdleConns int + }{ + {8, 2}, + {16, 4}, + {32, 8}, + {64, 16}, + } + + for _, config := range configs { + b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) { + connPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: int32(config.poolSize), + MinIdleConns: int32(config.minIdleConns), + PoolTimeout: time.Second, + DialTimeout: time.Second, + ConnMaxIdleTime: time.Hour, + }) + defer connPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cn, err := connPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connPool.Put(ctx, cn) + } + }) + }) + } +} + +// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency +func BenchmarkPoolConcurrentGetPut(b *testing.B) { + ctx := context.Background() + + connPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: int32(32), + PoolTimeout: time.Second, + DialTimeout: 
time.Second, + ConnMaxIdleTime: time.Hour, + MinIdleConns: int32(0), + }) + defer connPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + // Test with different levels of concurrency + concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64} + + for _, concurrency := range concurrencyLevels { + b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) { + b.SetParallelism(concurrency) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cn, err := connPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connPool.Put(ctx, cn) + } + }) + }) + } +} + +// ============================================================================= +// PUBSUB BENCHMARKS +// ============================================================================= + +// benchmarkClient creates a Redis client for benchmarking with mock dialer +func benchmarkClient(poolSize int) *redis.Client { + return redis.NewClient(&redis.Options{ + Addr: "localhost:6379", // Mock address + DialTimeout: time.Second, + ReadTimeout: time.Second, + WriteTimeout: time.Second, + PoolSize: poolSize, + MinIdleConns: 0, // Start with no idle connections for consistent benchmarks + }) +} + +// BenchmarkPubSubCreation benchmarks PubSub creation and subscription +func BenchmarkPubSubCreation(b *testing.B) { + ctx := context.Background() + + poolSizes := []int{1, 4, 8, 16, 32} + + for _, poolSize := range poolSizes { + b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) { + client := benchmarkClient(poolSize) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + } + }) + } +} + +// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription +func BenchmarkPubSubPatternCreation(b *testing.B) { + ctx := context.Background() + + poolSizes := []int{1, 4, 8, 16, 32} + + for _, poolSize := range poolSizes { + b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) { + client := benchmarkClient(poolSize) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + pubsub := client.PSubscribe(ctx, "test-*") + pubsub.Close() + } + }) + } +} + +// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation +func BenchmarkPubSubConcurrentCreation(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(32) + defer client.Close() + + concurrencyLevels := []int{1, 2, 4, 8, 16} + + for _, concurrency := range concurrencyLevels { + b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + + var wg sync.WaitGroup + semaphore := make(chan struct{}, concurrency) + + for i := 0; i < b.N; i++ { + wg.Add(1) + semaphore <- struct{}{} + + go func() { + defer wg.Done() + defer func() { <-semaphore }() + + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + }() + } + + wg.Wait() + }) + } +} + +// BenchmarkPubSubMultipleChannels benchmarks subscribing to multiple channels +func BenchmarkPubSubMultipleChannels(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(16) + defer client.Close() + + channelCounts := []int{1, 5, 10, 25, 50, 100} + + for _, channelCount := range channelCounts { + b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) { + // Prepare channel names + channels := make([]string, channelCount) + for i := 0; i < channelCount; i++ { + channels[i] = fmt.Sprintf("channel-%d", i) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + 
pubsub := client.Subscribe(ctx, channels...) + pubsub.Close() + } + }) + } +} + +// BenchmarkPubSubReuse benchmarks reusing PubSub connections +func BenchmarkPubSubReuse(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(16) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Benchmark just the creation and closing of PubSub connections + // This simulates reuse patterns without requiring actual Redis operations + pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i)) + pubsub.Close() + } +} + +// ============================================================================= +// COMBINED BENCHMARKS +// ============================================================================= + +// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations +func BenchmarkPoolAndPubSubMixed(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(32) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Mix of pool stats collection and PubSub creation + if pb.Next() { + // Pool stats operation + stats := client.PoolStats() + _ = stats.Hits + stats.Misses // Use the stats to prevent optimization + } + + if pb.Next() { + // PubSub operation + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + } + } + }) +} + +// BenchmarkPoolStatsCollection benchmarks pool statistics collection +func BenchmarkPoolStatsCollection(b *testing.B) { + client := benchmarkClient(16) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + stats := client.PoolStats() + _ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization + } +} + +// BenchmarkPoolHighContention tests pool performance under high contention +func BenchmarkPoolHighContention(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(32) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // High contention Get/Put operations + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + } + }) +} diff --git a/pubsub.go b/pubsub.go index 75327dd2..0f535a03 100644 --- a/pubsub.go +++ b/pubsub.go @@ -22,7 +22,7 @@ import ( type PubSub struct { opt *Options - newConn func(ctx context.Context, channels []string) (*pool.Conn, error) + newConn func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) closeConn func(*pool.Conn) error mu sync.Mutex @@ -42,6 +42,9 @@ type PubSub struct { // Push notification processor for handling generic push notifications pushProcessor push.NotificationProcessor + + // Cleanup callback for hitless upgrade tracking + onClose func() } func (c *PubSub) init() { @@ -73,10 +76,18 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er return c.cn, nil } + if c.opt.Addr == "" { + // TODO(hitless): + // this is probably cluster client + // c.newConn will ignore the addr argument + // will be changed when we have hitless upgrades for cluster clients + c.opt.Addr = internal.RedisNull + } + channels := mapKeys(c.channels) channels = append(channels, newChannels...) 
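Just above, a cluster PubSub that has no concrete address yet passes the internal.RedisNull sentinel and lets the dial callback choose the node. A minimal sketch of that sentinel-address pattern; redisNull and the node address below are illustrative assumptions, not the library API:

package main

import (
	"context"
	"fmt"
)

// redisNull mirrors internal.RedisNull; it marks "address unknown, let the
// dial function decide", which is what the cluster PubSub path does today.
const redisNull = "null"

type newConnFn func(ctx context.Context, addr string, channels []string) (string, error)

func connect(newConn newConnFn, addr string, channels []string) (string, error) {
	if addr == "" {
		addr = redisNull // caller (cluster client) does not know the target node yet
	}
	return newConn(context.Background(), addr, channels)
}

func main() {
	// A dial function that resolves the sentinel to a node chosen by routing.
	newConn := func(_ context.Context, addr string, _ []string) (string, error) {
		if addr == redisNull {
			addr = "10.0.0.7:6379" // hypothetical node picked by cluster routing
		}
		return "subscribed via " + addr, nil
	}
	out, _ := connect(newConn, "", []string{"mychannel"})
	fmt.Println(out)
}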
- cn, err := c.newConn(ctx, channels) + cn, err := c.newConn(ctx, c.opt.Addr, channels) if err != nil { return nil, err } @@ -157,12 +168,31 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo if c.cn != cn { return } + + if !cn.IsUsable() || cn.ShouldHandoff() { + c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable")) + } + if isBadConn(err, allowTimeout, c.opt.Addr) { c.reconnect(ctx, err) } } func (c *PubSub) reconnect(ctx context.Context, reason error) { + if c.cn != nil && c.cn.ShouldHandoff() { + newEndpoint := c.cn.GetHandoffEndpoint() + // If new endpoint is NULL, use the original address + if newEndpoint == internal.RedisNull { + newEndpoint = c.opt.Addr + } + + if newEndpoint != "" { + // Update the address in the options + oldAddr := c.cn.RemoteAddr().String() + c.opt.Addr = newEndpoint + internal.Logger.Printf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr) + } + } _ = c.closeTheCn(reason) _, _ = c.conn(ctx, nil) } @@ -171,9 +201,6 @@ func (c *PubSub) closeTheCn(reason error) error { if c.cn == nil { return nil } - if !c.closed { - internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason) - } err := c.closeConn(c.cn) c.cn = nil return err @@ -189,6 +216,11 @@ func (c *PubSub) Close() error { c.closed = true close(c.exit) + // Call cleanup callback if set + if c.onClose != nil { + c.onClose() + } + return c.closeTheCn(pool.ErrClosed) } @@ -444,11 +476,10 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + internal.Logger.Printf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err) } return c.cmd.readReply(rd) }) - c.releaseConnWithLock(ctx, cn, err, timeout > 0) if err != nil { @@ -461,6 +492,12 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int // Receive returns a message as a Subscription, Message, Pong or error. // See PubSub example for details. This is low-level API and in most cases // Channel should be used instead. +// Receive returns a message as a Subscription, Message, Pong, or an error. +// See PubSub example for details. This is a low-level API and in most cases +// Channel should be used instead. +// This method blocks until a message is received or an error occurs. +// It may return early with an error if the context is canceled, the connection fails, +// or other internal errors occur. 
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) { return c.ReceiveTimeout(ctx, 0) } @@ -543,7 +580,8 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac } func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error { - if c.pushProcessor == nil { + // Only process push notifications for RESP3 connections with a processor + if c.opt.Protocol != 3 || c.pushProcessor == nil { return nil } diff --git a/pubsub_test.go b/pubsub_test.go index 2f3f4604..585433eb 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -113,6 +113,9 @@ var _ = Describe("PubSub", func() { pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2") defer pubsub.Close() + // sleep a bit to make sure Redis knows about the subscriptions + time.Sleep(10 * time.Millisecond) + channels, err = client.PubSubShardChannels(ctx, "mychannel*").Result() Expect(err).NotTo(HaveOccurred()) Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"})) diff --git a/push/handler_context.go b/push/handler_context.go index 3bcf128f..c39e186b 100644 --- a/push/handler_context.go +++ b/push/handler_context.go @@ -1,8 +1,6 @@ package push -import ( - "github.com/redis/go-redis/v9/internal/pool" -) +// No imports needed for this file // NotificationHandlerContext provides context information about where a push notification was received. // This struct allows handlers to make informed decisions based on the source of the notification @@ -35,7 +33,11 @@ type NotificationHandlerContext struct { PubSub interface{} // Conn is the specific connection on which the notification was received. - Conn *pool.Conn + // It is an interface to both allow for future expansion and to avoid + // circular dependencies. The developer is responsible for type assertion. + // It can be one of the following types: + // - *pool.Conn + Conn interface{} // IsBlocking indicates if the notification was received on a blocking connection.
IsBlocking bool diff --git a/push/processor_unit_test.go b/push/processor_unit_test.go new file mode 100644 index 00000000..ce799048 --- /dev/null +++ b/push/processor_unit_test.go @@ -0,0 +1,315 @@ +package push + +import ( + "context" + "testing" +) + +// TestProcessorCreation tests processor creation and initialization +func TestProcessorCreation(t *testing.T) { + t.Run("NewProcessor", func(t *testing.T) { + processor := NewProcessor() + if processor == nil { + t.Fatal("NewProcessor should not return nil") + } + if processor.registry == nil { + t.Error("Processor should have a registry") + } + }) + + t.Run("NewVoidProcessor", func(t *testing.T) { + voidProcessor := NewVoidProcessor() + if voidProcessor == nil { + t.Fatal("NewVoidProcessor should not return nil") + } + }) +} + +// TestProcessorHandlerManagement tests handler registration and retrieval +func TestProcessorHandlerManagement(t *testing.T) { + processor := NewProcessor() + handler := &UnitTestHandler{name: "test-handler"} + + t.Run("RegisterHandler", func(t *testing.T) { + err := processor.RegisterHandler("TEST", handler, false) + if err != nil { + t.Errorf("RegisterHandler should not error: %v", err) + } + + // Verify handler is registered + retrievedHandler := processor.GetHandler("TEST") + if retrievedHandler != handler { + t.Error("GetHandler should return the registered handler") + } + }) + + t.Run("RegisterProtectedHandler", func(t *testing.T) { + protectedHandler := &UnitTestHandler{name: "protected-handler"} + err := processor.RegisterHandler("PROTECTED", protectedHandler, true) + if err != nil { + t.Errorf("RegisterHandler should not error for protected handler: %v", err) + } + + // Verify handler is registered + retrievedHandler := processor.GetHandler("PROTECTED") + if retrievedHandler != protectedHandler { + t.Error("GetHandler should return the protected handler") + } + }) + + t.Run("GetNonExistentHandler", func(t *testing.T) { + handler := processor.GetHandler("NONEXISTENT") + if handler != nil { + t.Error("GetHandler should return nil for non-existent handler") + } + }) + + t.Run("UnregisterHandler", func(t *testing.T) { + err := processor.UnregisterHandler("TEST") + if err != nil { + t.Errorf("UnregisterHandler should not error: %v", err) + } + + // Verify handler is removed + retrievedHandler := processor.GetHandler("TEST") + if retrievedHandler != nil { + t.Error("GetHandler should return nil after unregistering") + } + }) + + t.Run("UnregisterProtectedHandler", func(t *testing.T) { + err := processor.UnregisterHandler("PROTECTED") + if err == nil { + t.Error("UnregisterHandler should error for protected handler") + } + + // Verify handler is still there + retrievedHandler := processor.GetHandler("PROTECTED") + if retrievedHandler == nil { + t.Error("Protected handler should not be removed") + } + }) +} + +// TestVoidProcessorBehavior tests void processor behavior +func TestVoidProcessorBehavior(t *testing.T) { + voidProcessor := NewVoidProcessor() + handler := &UnitTestHandler{name: "test-handler"} + + t.Run("GetHandler", func(t *testing.T) { + retrievedHandler := voidProcessor.GetHandler("ANY") + if retrievedHandler != nil { + t.Error("VoidProcessor GetHandler should always return nil") + } + }) + + t.Run("RegisterHandler", func(t *testing.T) { + err := voidProcessor.RegisterHandler("TEST", handler, false) + if err == nil { + t.Error("VoidProcessor RegisterHandler should return error") + } + + // Check error type + if !IsVoidProcessorError(err) { + t.Error("Error should be a VoidProcessorError") + } + }) + 
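Beyond the registry tests above, here is a consumer-level sketch of registering a handler. It assumes the handler interface consists of the single HandlePushNotification method that UnitTestHandler implements later in this file, and that push.NewProcessor and push.NotificationHandlerContext are the exported names; the notification name "MOVING" follows the constants used elsewhere in this PR, and protected registration (third argument true) refuses later unregistration, as the protected-handler tests demonstrate.

package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9/push"
)

// movingHandler reacts to the "MOVING" maintenance notification name.
type movingHandler struct{}

func (movingHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	fmt.Println("got MOVING notification:", notification)
	return nil
}

func main() {
	p := push.NewProcessor()

	// protected=true: the registry refuses to unregister this handler later,
	// mirroring the UnregisterProtectedHandler test above.
	if err := p.RegisterHandler("MOVING", movingHandler{}, true); err != nil {
		panic(err)
	}

	if err := p.UnregisterHandler("MOVING"); err != nil {
		fmt.Println("expected error for protected handler:", err)
	}
}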
+ t.Run("UnregisterHandler", func(t *testing.T) { + err := voidProcessor.UnregisterHandler("TEST") + if err == nil { + t.Error("VoidProcessor UnregisterHandler should return error") + } + + // Check error type + if !IsVoidProcessorError(err) { + t.Error("Error should be a VoidProcessorError") + } + }) +} + +// TestProcessPendingNotificationsNilReader tests handling of nil reader +func TestProcessPendingNotificationsNilReader(t *testing.T) { + t.Run("ProcessorWithNilReader", func(t *testing.T) { + processor := NewProcessor() + ctx := context.Background() + handlerCtx := NotificationHandlerContext{} + + err := processor.ProcessPendingNotifications(ctx, handlerCtx, nil) + if err != nil { + t.Errorf("ProcessPendingNotifications should not error with nil reader: %v", err) + } + }) + + t.Run("VoidProcessorWithNilReader", func(t *testing.T) { + voidProcessor := NewVoidProcessor() + ctx := context.Background() + handlerCtx := NotificationHandlerContext{} + + err := voidProcessor.ProcessPendingNotifications(ctx, handlerCtx, nil) + if err != nil { + t.Errorf("VoidProcessor ProcessPendingNotifications should not error with nil reader: %v", err) + } + }) +} + +// TestWillHandleNotificationInClient tests the notification filtering logic +func TestWillHandleNotificationInClient(t *testing.T) { + testCases := []struct { + name string + notificationType string + shouldHandle bool + }{ + // Pub/Sub notifications (should be handled in client) + {"message", "message", true}, + {"pmessage", "pmessage", true}, + {"subscribe", "subscribe", true}, + {"unsubscribe", "unsubscribe", true}, + {"psubscribe", "psubscribe", true}, + {"punsubscribe", "punsubscribe", true}, + {"smessage", "smessage", true}, + {"ssubscribe", "ssubscribe", true}, + {"sunsubscribe", "sunsubscribe", true}, + + // Push notifications (should be handled by processor) + {"MOVING", "MOVING", false}, + {"MIGRATING", "MIGRATING", false}, + {"MIGRATED", "MIGRATED", false}, + {"FAILING_OVER", "FAILING_OVER", false}, + {"FAILED_OVER", "FAILED_OVER", false}, + {"custom", "custom", false}, + {"unknown", "unknown", false}, + {"empty", "", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := willHandleNotificationInClient(tc.notificationType) + if result != tc.shouldHandle { + t.Errorf("willHandleNotificationInClient(%q) = %v, want %v", tc.notificationType, result, tc.shouldHandle) + } + }) + } +} + +// TestProcessorErrorHandlingUnit tests error handling scenarios +func TestProcessorErrorHandlingUnit(t *testing.T) { + processor := NewProcessor() + + t.Run("RegisterNilHandler", func(t *testing.T) { + err := processor.RegisterHandler("TEST", nil, false) + if err == nil { + t.Error("RegisterHandler should error with nil handler") + } + + // Check error type + if !IsHandlerNilError(err) { + t.Error("Error should be a HandlerNilError") + } + }) + + t.Run("RegisterDuplicateHandler", func(t *testing.T) { + handler1 := &UnitTestHandler{name: "handler1"} + handler2 := &UnitTestHandler{name: "handler2"} + + // Register first handler + err := processor.RegisterHandler("DUPLICATE", handler1, false) + if err != nil { + t.Errorf("First RegisterHandler should not error: %v", err) + } + + // Try to register second handler with same name + err = processor.RegisterHandler("DUPLICATE", handler2, false) + if err == nil { + t.Error("RegisterHandler should error when registering duplicate handler") + } + + // Verify original handler is still there + retrievedHandler := processor.GetHandler("DUPLICATE") + if retrievedHandler != 
handler1 { + t.Error("Original handler should remain after failed duplicate registration") + } + }) + + t.Run("UnregisterNonExistentHandler", func(t *testing.T) { + err := processor.UnregisterHandler("NONEXISTENT") + if err != nil { + t.Errorf("UnregisterHandler should not error for non-existent handler: %v", err) + } + }) +} + +// TestProcessorConcurrentAccess tests concurrent access to processor +func TestProcessorConcurrentAccess(t *testing.T) { + processor := NewProcessor() + + t.Run("ConcurrentRegisterAndGet", func(t *testing.T) { + done := make(chan bool, 2) + + // Goroutine 1: Register handlers + go func() { + defer func() { done <- true }() + for i := 0; i < 100; i++ { + handler := &UnitTestHandler{name: "concurrent-handler"} + processor.RegisterHandler("CONCURRENT", handler, false) + processor.UnregisterHandler("CONCURRENT") + } + }() + + // Goroutine 2: Get handlers + go func() { + defer func() { done <- true }() + for i := 0; i < 100; i++ { + processor.GetHandler("CONCURRENT") + } + }() + + // Wait for both goroutines to complete + <-done + <-done + }) +} + +// TestProcessorInterfaceCompliance tests interface compliance +func TestProcessorInterfaceCompliance(t *testing.T) { + t.Run("ProcessorImplementsInterface", func(t *testing.T) { + var _ NotificationProcessor = (*Processor)(nil) + }) + + t.Run("VoidProcessorImplementsInterface", func(t *testing.T) { + var _ NotificationProcessor = (*VoidProcessor)(nil) + }) +} + +// UnitTestHandler is a test implementation of NotificationHandler +type UnitTestHandler struct { + name string + lastNotification []interface{} + errorToReturn error + callCount int +} + +func (h *UnitTestHandler) HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error { + h.callCount++ + h.lastNotification = notification + return h.errorToReturn +} + +// Helper methods for UnitTestHandler +func (h *UnitTestHandler) GetCallCount() int { + return h.callCount +} + +func (h *UnitTestHandler) GetLastNotification() []interface{} { + return h.lastNotification +} + +func (h *UnitTestHandler) SetErrorToReturn(err error) { + h.errorToReturn = err +} + +func (h *UnitTestHandler) Reset() { + h.callCount = 0 + h.lastNotification = nil + h.errorToReturn = nil +} diff --git a/push_notifications.go b/push_notifications.go index ceffe04a..572955fe 100644 --- a/push_notifications.go +++ b/push_notifications.go @@ -4,24 +4,6 @@ import ( "github.com/redis/go-redis/v9/push" ) -// Push notification constants for cluster operations -const ( - // MOVING indicates a slot is being moved to a different node - PushNotificationMoving = "MOVING" - - // MIGRATING indicates a slot is being migrated from this node - PushNotificationMigrating = "MIGRATING" - - // MIGRATED indicates a slot has been migrated to this node - PushNotificationMigrated = "MIGRATED" - - // FAILING_OVER indicates a failover is starting - PushNotificationFailingOver = "FAILING_OVER" - - // FAILED_OVER indicates a failover has completed - PushNotificationFailedOver = "FAILED_OVER" -) - // NewPushNotificationProcessor creates a new push notification processor // This processor maintains a registry of handlers and processes push notifications // It is used for RESP3 connections where push notifications are available diff --git a/redis.go b/redis.go index b3608c5f..f2b80cf8 100644 --- a/redis.go +++ b/redis.go @@ -10,6 +10,7 @@ import ( "time" "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/hitless" "github.com/redis/go-redis/v9/internal" 
"github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/pool" @@ -204,19 +205,35 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e //------------------------------------------------------------------------------ type baseClient struct { - opt *Options - connPool pool.Pooler + opt *Options + optLock sync.RWMutex + connPool pool.Pooler + pubSubPool *pool.PubSubPool hooksMixin onClose func() error // hook called when client is closed // Push notification processing pushProcessor push.NotificationProcessor + + // Hitless upgrade manager + hitlessManager *hitless.HitlessManager + hitlessManagerLock sync.RWMutex } func (c *baseClient) clone() *baseClient { - clone := *c - return &clone + c.hitlessManagerLock.RLock() + hitlessManager := c.hitlessManager + c.hitlessManagerLock.RUnlock() + + clone := &baseClient{ + opt: c.opt, + connPool: c.connPool, + onClose: c.onClose, + pushProcessor: c.pushProcessor, + hitlessManager: hitlessManager, + } + return clone } func (c *baseClient) withTimeout(timeout time.Duration) *baseClient { @@ -234,21 +251,6 @@ func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) } -func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) { - cn, err := c.connPool.NewConn(ctx) - if err != nil { - return nil, err - } - - err = c.initConn(ctx, cn) - if err != nil { - _ = c.connPool.CloseConn(cn) - return nil, err - } - - return cn, nil -} - func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { if c.opt.Limiter != nil { err := c.opt.Limiter.Allow() @@ -274,7 +276,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return nil, err } - if cn.Inited { + if cn.IsInited() { return cn, nil } @@ -356,12 +358,10 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error { } func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { - if cn.Inited { + if !cn.Inited.CompareAndSwap(false, true) { return nil } - var err error - cn.Inited = true connPool := pool.NewSingleConnPool(c.connPool, cn) conn := newConn(c.opt, connPool, &c.hooksMixin) @@ -430,6 +430,51 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return fmt.Errorf("failed to initialize connection options: %w", err) } + // Enable maintenance notifications if hitless upgrades are configured + c.optLock.RLock() + hitlessEnabled := c.opt.HitlessUpgradeConfig != nil && c.opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled + protocol := c.opt.Protocol + endpointType := c.opt.HitlessUpgradeConfig.EndpointType + c.optLock.RUnlock() + var hitlessHandshakeErr error + if hitlessEnabled && protocol == 3 { + hitlessHandshakeErr = conn.ClientMaintNotifications( + ctx, + true, + endpointType.String(), + ).Err() + if hitlessHandshakeErr != nil { + if !isRedisError(hitlessHandshakeErr) { + // if not redis error, fail the connection + return hitlessHandshakeErr + } + c.optLock.Lock() + // handshake failed - check and modify config atomically + switch c.opt.HitlessUpgradeConfig.Mode { + case hitless.MaintNotificationsEnabled: + // enabled mode, fail the connection + c.optLock.Unlock() + return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr) + default: // will handle auto and any other + internal.Logger.Printf(ctx, "hitless: auto mode fallback: hitless upgrades disabled due to handshake error: %v", hitlessHandshakeErr) + c.opt.HitlessUpgradeConfig.Mode = 
hitless.MaintNotificationsDisabled + c.optLock.Unlock() + // auto mode, disable hitless upgrades and continue + if err := c.disableHitlessUpgrades(); err != nil { + // Log error but continue - auto mode should be resilient + internal.Logger.Printf(ctx, "hitless: failed to disable hitless upgrades in auto mode: %v", err) + } + } + } else { + // The handshake succeeded on this connection. + // Force enabled mode so that the same handshake is also performed + // on every other connection created by this client. + c.optLock.Lock() + c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsEnabled + c.optLock.Unlock() + } + } + if !c.opt.DisableIdentity && !c.opt.DisableIndentity { libName := "" libVer := Version() @@ -446,6 +491,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } } + cn.SetUsable(true) + cn.Inited.Store(true) + + // Set the connection initialization function for potential reconnections + cn.SetInitConnFunc(c.createInitConnFunc()) + if c.opt.OnConnect != nil { return c.opt.OnConnect(ctx, conn) } @@ -512,6 +563,8 @@ func (c *baseClient) assertUnstableCommand(cmd Cmder) bool { if c.opt.UnstableResp3 { return true } else { + // TODO: find the best way to remove this panic and return an error instead. + // The client should not panic while executing a command, only during initialization. panic("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3 . See the [README](https://github.com/redis/go-redis/blob/master/README.md) and the release notes for guidance.") } default: @@ -593,19 +646,76 @@ func (c *baseClient) context(ctx context.Context) context.Context { return context.Background() } +// createInitConnFunc creates a connection initialization function that can be used for reconnections. +func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error { + return func(ctx context.Context, cn *pool.Conn) error { + return c.initConn(ctx, cn) + } +} + +// enableHitlessUpgrades initializes the hitless upgrade manager and pool hook. +// This function is called during client initialization. +// It registers push notification handlers for all hitless upgrade events +// and starts background workers for handoff processing in the pool hook. +func (c *baseClient) enableHitlessUpgrades() error { + // Create client adapter + clientAdapterInstance := newClientAdapter(c) + + // Create hitless manager directly + manager, err := hitless.NewHitlessManager(clientAdapterInstance, c.connPool, c.opt.HitlessUpgradeConfig) + if err != nil { + return err + } + // Set the manager reference and initialize pool hook + c.hitlessManagerLock.Lock() + c.hitlessManager = manager + c.hitlessManagerLock.Unlock() + + // Initialize pool hook (safe to call without lock since manager is now set) + manager.InitPoolHook(c.dialHook) + return nil +} + +func (c *baseClient) disableHitlessUpgrades() error { + c.hitlessManagerLock.Lock() + defer c.hitlessManagerLock.Unlock() + + // Close the hitless manager + if c.hitlessManager != nil { + // Closing the manager also shuts down the pool hook + // and removes it from the pool + c.hitlessManager.Close() + c.hitlessManager = nil + } + return nil +} + // Close closes the client, releasing any open resources. // // It is rare to Close a Client, as the Client is meant to be // long-lived and shared between many goroutines.
func (c *baseClient) Close() error { var firstErr error + + // Close hitless manager first + if err := c.disableHitlessUpgrades(); err != nil { + firstErr = err + } + if c.onClose != nil { - if err := c.onClose(); err != nil { + if err := c.onClose(); err != nil && firstErr == nil { firstErr = err } } - if err := c.connPool.Close(); err != nil && firstErr == nil { - firstErr = err + if c.connPool != nil { + if err := c.connPool.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + if c.pubSubPool != nil { + if err := c.pubSubPool.Close(); err != nil && firstErr == nil { + firstErr = err + } } return firstErr } @@ -796,6 +906,8 @@ func NewClient(opt *Options) *Client { if opt == nil { panic("redis: NewClient nil options") } + // clone to not share options with the caller + opt = opt.clone() opt.init() // Push notifications are always enabled for RESP3 (cannot be disabled) @@ -810,11 +922,40 @@ func NewClient(opt *Options) *Client { // Initialize push notification processor using shared helper // Use void processor for RESP2 connections (push notifications not available) c.pushProcessor = initializePushProcessor(opt) + // set opt push processor for child clients + c.opt.PushNotificationProcessor = c.pushProcessor - // Update options with the initialized push processor for connection pool - opt.PushNotificationProcessor = c.pushProcessor + // Create connection pools + var err error + c.connPool, err = newConnPool(opt, c.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create connection pool: %w", err)) + } + c.pubSubPool, err = newPubSubPool(opt, c.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err)) + } - c.connPool = newConnPool(opt, c.dialHook) + // Initialize hitless upgrades first if enabled and protocol is RESP3 + if opt.HitlessUpgradeConfig != nil && opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled && opt.Protocol == 3 { + err := c.enableHitlessUpgrades() + if err != nil { + internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err) + if opt.HitlessUpgradeConfig.Mode == hitless.MaintNotificationsEnabled { + /* + Design decision: panic here to fail fast if hitless upgrades cannot be enabled when explicitly requested. + We choose to panic instead of returning an error to avoid breaking the existing client API, which does not expect + an error from NewClient. This ensures that misconfiguration or critical initialization failures are surfaced + immediately, rather than allowing the client to continue in a partially initialized or inconsistent state. + Clients relying on hitless upgrades should be aware that initialization errors will cause a panic, and should + handle this accordingly (e.g., via recover or by validating configuration before calling NewClient). + This approach is only used when HitlessUpgradeConfig.Mode is MaintNotificationsEnabled, indicating that hitless + upgrades are required for correct operation. In other modes, initialization failures are logged but do not panic. + */ + panic(fmt.Errorf("failed to enable hitless upgrades: %w", err)) + } + } + } return &c } @@ -851,6 +992,14 @@ func (c *Client) Options() *Options { return c.opt } +// GetHitlessManager returns the hitless manager instance for monitoring and control. +// Returns nil if hitless upgrades are not enabled. 
+func (c *Client) GetHitlessManager() *hitless.HitlessManager { + c.hitlessManagerLock.RLock() + defer c.hitlessManagerLock.RUnlock() + return c.hitlessManager +} + // initializePushProcessor initializes the push notification processor for any client type. // This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient. func initializePushProcessor(opt *Options) push.NotificationProcessor { @@ -887,6 +1036,7 @@ type PoolStats pool.Stats // PoolStats returns connection pool stats. func (c *Client) PoolStats() *PoolStats { stats := c.connPool.Stats() + stats.PubSubStats = *(c.pubSubPool.Stats()) return (*PoolStats)(stats) } @@ -921,11 +1071,27 @@ func (c *Client) TxPipeline() Pipeliner { func (c *Client) pubSub() *PubSub { pubsub := &PubSub{ opt: c.opt, - - newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { - return c.newConn(ctx) + newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) { + cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels) + if err != nil { + return nil, err + } + // will return nil if already initialized + err = c.initConn(ctx, cn) + if err != nil { + _ = cn.Close() + return nil, err + } + // Track connection in PubSubPool + c.pubSubPool.TrackConn(cn) + return cn, nil + }, + closeConn: func(cn *pool.Conn) error { + // Untrack connection from PubSubPool + c.pubSubPool.UntrackConn(cn) + _ = cn.Close() + return nil }, - closeConn: c.connPool.CloseConn, pushProcessor: c.pushProcessor, } pubsub.init() @@ -1113,6 +1279,6 @@ func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.Notifica return push.NotificationHandlerContext{ Client: c, ConnPool: c.connPool, - Conn: cn, + Conn: cn, // Underlying pool connection on which the notification arrived } } diff --git a/redis_test.go b/redis_test.go index 6aaa0a75..27b69ed1 100644 --- a/redis_test.go +++ b/redis_test.go @@ -12,7 +12,6 @@ import ( . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" - "github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9/auth" ) diff --git a/sentinel.go b/sentinel.go index 2509d70f..e52e8407 100644 --- a/sentinel.go +++ b/sentinel.go @@ -16,8 +16,8 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/rand" - "github.com/redis/go-redis/v9/push" "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/push" ) //------------------------------------------------------------------------------ @@ -139,6 +139,14 @@ type FailoverOptions struct { FailingTimeoutSeconds int UnstableResp3 bool + + // Hitless upgrades are not supported for failover clients at the moment. + // HitlessUpgradeConfig provides custom configuration for hitless upgrades. + // When HitlessUpgradeConfig.Mode is not "disabled", the client will handle + // upgrade notifications gracefully and manage connection/pool state transitions + // seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, hitless upgrades are disabled.
+ //HitlessUpgradeConfig *HitlessUpgradeConfig } func (opt *FailoverOptions) clientOptions() *Options { @@ -456,8 +464,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { opt.Dialer = masterReplicaDialer(failover) opt.init() - var connPool *pool.ConnPool - rdb := &Client{ baseClient: &baseClient{ opt: opt, @@ -469,15 +475,25 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { // Use void processor by default for RESP2 connections rdb.pushProcessor = initializePushProcessor(opt) - connPool = newConnPool(opt, rdb.dialHook) - rdb.connPool = connPool + var err error + rdb.connPool, err = newConnPool(opt, rdb.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create connection pool: %w", err)) + } + rdb.pubSubPool, err = newPubSubPool(opt, rdb.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err)) + } + rdb.onClose = rdb.wrappedOnClose(failover.Close) failover.mu.Lock() failover.onFailover = func(ctx context.Context, addr string) { - _ = connPool.Filter(func(cn *pool.Conn) bool { - return cn.RemoteAddr().String() != addr - }) + if connPool, ok := rdb.connPool.(*pool.ConnPool); ok { + _ = connPool.Filter(func(cn *pool.Conn) bool { + return cn.RemoteAddr().String() != addr + }) + } } failover.mu.Unlock() @@ -543,7 +559,15 @@ func NewSentinelClient(opt *Options) *SentinelClient { dial: c.baseClient.dial, process: c.baseClient.process, }) - c.connPool = newConnPool(opt, c.dialHook) + var err error + c.connPool, err = newConnPool(opt, c.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create connection pool: %w", err)) + } + c.pubSubPool, err = newPubSubPool(opt, c.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err)) + } return c } @@ -570,13 +594,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error { func (c *SentinelClient) pubSub() *PubSub { pubsub := &PubSub{ opt: c.opt, - - newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { - return c.newConn(ctx) + newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) { + cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels) + if err != nil { + return nil, err + } + // will return nil if already initialized + err = c.initConn(ctx, cn) + if err != nil { + _ = cn.Close() + return nil, err + } + // Track connection in PubSubPool + c.pubSubPool.TrackConn(cn) + return cn, nil }, - closeConn: c.connPool.CloseConn, + closeConn: func(cn *pool.Conn) error { + // Untrack connection from PubSubPool + c.pubSubPool.UntrackConn(cn) + _ = cn.Close() + return nil + }, + pushProcessor: c.pushProcessor, } pubsub.init() + return pubsub } diff --git a/tx.go b/tx.go index 67689f57..40bc1d66 100644 --- a/tx.go +++ b/tx.go @@ -24,7 +24,7 @@ type Tx struct { func (c *Client) newTx() *Tx { tx := Tx{ baseClient: baseClient{ - opt: c.opt, + opt: c.opt.clone(), // Clone options to avoid sharing mutable state between transaction and parent client connPool: pool.NewStickyConnPool(c.connPool), hooksMixin: c.hooksMixin.clone(), pushProcessor: c.pushProcessor, // Copy push processor from parent client diff --git a/universal.go b/universal.go index 02da3be8..2f4b4a53 100644 --- a/universal.go +++ b/universal.go @@ -122,6 +122,9 @@ type UniversalOptions struct { // IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint). 
IsClusterMode bool + + // HitlessUpgradeConfig provides configuration for hitless upgrades. + HitlessUpgradeConfig *HitlessUpgradeConfig } // Cluster returns cluster options created from the universal options. @@ -177,6 +180,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions { IdentitySuffix: o.IdentitySuffix, FailingTimeoutSeconds: o.FailingTimeoutSeconds, UnstableResp3: o.UnstableResp3, + HitlessUpgradeConfig: o.HitlessUpgradeConfig, } } @@ -237,6 +241,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions { DisableIndentity: o.DisableIndentity, IdentitySuffix: o.IdentitySuffix, UnstableResp3: o.UnstableResp3, + // Note: HitlessUpgradeConfig not supported for FailoverOptions } } @@ -284,10 +289,11 @@ func (o *UniversalOptions) Simple() *Options { TLSConfig: o.TLSConfig, - DisableIdentity: o.DisableIdentity, - DisableIndentity: o.DisableIndentity, - IdentitySuffix: o.IdentitySuffix, - UnstableResp3: o.UnstableResp3, + DisableIdentity: o.DisableIdentity, + DisableIndentity: o.DisableIndentity, + IdentitySuffix: o.IdentitySuffix, + UnstableResp3: o.UnstableResp3, + HitlessUpgradeConfig: o.HitlessUpgradeConfig, } }
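
A minimal usage sketch of the options added above: it shows how a caller might opt in to hitless upgrades through Options.HitlessUpgradeConfig and inspect the manager via Client.GetHitlessManager(). It assumes HitlessUpgradeConfig is the package-level type referenced by Options in this patch, that Mode accepts the hitless.MaintNotificationsEnabled constant used in redis.go, and that opt.init() fills in defaults for any field not set here (e.g. EndpointType); the exact config shape lives in hitless/config.go and options.go.

package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	ctx := context.Background()

	// Hitless upgrades rely on RESP3 push notifications, so Protocol must be 3.
	rdb := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379",
		Protocol: 3,
		// Enabled mode: NewClient panics if the upgrade machinery cannot be
		// initialized, and connection setup fails if the handshake is rejected.
		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
			Mode: hitless.MaintNotificationsEnabled,
		},
	})
	defer rdb.Close()

	if err := rdb.Ping(ctx).Err(); err != nil {
		panic(err)
	}

	// GetHitlessManager returns nil when hitless upgrades are not enabled.
	if mgr := rdb.GetHitlessManager(); mgr != nil {
		fmt.Println("hitless upgrade manager is active")
	}
}

In auto mode (the default handling shown in initConn) a handshake failure only logs and disables the feature, so the same sketch without MaintNotificationsEnabled degrades gracefully on servers that do not support maintenance notifications.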