mirror of https://github.com/redis/go-redis.git synced 2025-10-18 22:08:50 +03:00

feat: RESP3 notifications support & Hitless notifications handling [CAE-1088] & [CAE-1072] (#3418)

- Adds support for handling push notifications with RESP3.
- Builds on that support to add handlers for hitless upgrades (a minimal usage sketch follows below).
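
A minimal sketch of opting in, modeled on the `example/pubsub` program added in this commit (the option field `HitlessUpgradeConfig` and the mode constant are taken from that example, not from separate API documentation):

```go
package main

import (
	"context"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	// RESP3 (the default protocol in go-redis v9) is required for push
	// notifications; hitless handling is switched on via the client options,
	// mirroring example/pubsub/main.go from this commit.
	rdb := redis.NewClient(&redis.Options{
		Addr: "localhost:6379",
		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
			Mode: hitless.MaintNotificationsEnabled,
		},
	})
	defer rdb.Close()

	if err := rdb.Ping(context.Background()).Err(); err != nil {
		panic(err)
	}
}
```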

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Hristo Temelski <hristo.temelski@redis.com>
Author: Nedyalko Dyakov (committed by GitHub)
Date: 2025-09-10 22:18:01 +03:00
Parent: 2da6ca07c0
Commit: 0ef6d0727d
70 changed files with 11668 additions and 596 deletions

3
.gitignore vendored

@@ -9,3 +9,6 @@ coverage.txt
**/coverage.txt
.vscode
tmp/*
# Hitless upgrade documentation (temporary)
hitless/docs/

111
adapters.go Normal file

@@ -0,0 +1,111 @@
package redis
import (
"context"
"errors"
"net"
"time"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/push"
)
// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand.
var ErrInvalidCommand = errors.New("invalid command type")
// ErrInvalidPool is returned when the pool type is not supported.
var ErrInvalidPool = errors.New("invalid pool type")
// newClientAdapter creates a new client adapter for regular Redis clients.
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
return &clientAdapter{client: client}
}
// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
type clientAdapter struct {
client *baseClient
}
// GetOptions returns the client options.
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
return &optionsAdapter{options: ca.client.opt}
}
// GetPushProcessor returns the client's push notification processor.
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
return &pushProcessorAdapter{processor: ca.client.pushProcessor}
}
// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
type optionsAdapter struct {
options *Options
}
// GetReadTimeout returns the read timeout.
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
return oa.options.ReadTimeout
}
// GetWriteTimeout returns the write timeout.
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
return oa.options.WriteTimeout
}
// GetNetwork returns the network type.
func (oa *optionsAdapter) GetNetwork() string {
return oa.options.Network
}
// GetAddr returns the connection address.
func (oa *optionsAdapter) GetAddr() string {
return oa.options.Addr
}
// IsTLSEnabled returns true if TLS is enabled.
func (oa *optionsAdapter) IsTLSEnabled() bool {
return oa.options.TLSConfig != nil
}
// GetProtocol returns the protocol version.
func (oa *optionsAdapter) GetProtocol() int {
return oa.options.Protocol
}
// GetPoolSize returns the connection pool size.
func (oa *optionsAdapter) GetPoolSize() int {
return oa.options.PoolSize
}
// NewDialer returns a new dialer function for the connection.
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
baseDialer := oa.options.NewDialer()
return func(ctx context.Context) (net.Conn, error) {
// Extract network and address from the options
network := oa.options.Network
addr := oa.options.Addr
return baseDialer(ctx, network, addr)
}
}
// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
type pushProcessorAdapter struct {
processor push.NotificationProcessor
}
// RegisterHandler registers a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
if pushHandler, ok := handler.(push.NotificationHandler); ok {
return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
}
return errors.New("handler must implement push.NotificationHandler")
}
// UnregisterHandler removes a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
return ppa.processor.UnregisterHandler(pushNotificationName)
}
// GetHandler returns the handler for a specific push notification name.
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
return ppa.processor.GetHandler(pushNotificationName)
}


@@ -0,0 +1,353 @@
package redis
import (
"context"
"net"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/logging"
)
// mockNetConn implements net.Conn for testing
type mockNetConn struct {
addr string
}
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (m *mockNetConn) Close() error { return nil }
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
type mockAddr struct {
addr string
}
func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string { return m.addr }
// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow
func TestEventDrivenHandoffIntegration(t *testing.T) {
t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) {
// Create a base dialer for testing
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
// Create processor with event-driven handoff support
processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create a test pool with hooks
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(5),
PoolTimeout: time.Second,
})
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
defer testPool.Close()
// Set the pool reference in the processor for connection removal on handoff failure
processor.SetPool(testPool)
ctx := context.Background()
// Get a connection and mark it for handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
// Set initialization function with a small delay to ensure handoff is pending
initConnCalled := false
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
initConnCalled = true
return nil
}
conn.SetInitConnFunc(initConnFunc)
// Mark connection for handoff
err = conn.MarkForHandoff("new-endpoint:6379", 12345)
if err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Return connection to pool - this should queue handoff
testPool.Put(ctx, conn)
// Give the on-demand worker a moment to start processing
time.Sleep(10 * time.Millisecond)
// Verify handoff was queued
if !processor.IsHandoffPending(conn) {
t.Error("Handoff should be queued in pending map")
}
// Try to get the same connection - should be skipped due to pending handoff
conn2, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get second connection: %v", err)
}
// Should get a different connection (the pending one should be skipped)
if conn == conn2 {
t.Error("Should have gotten a different connection while handoff is pending")
}
// Return the second connection
testPool.Put(ctx, conn2)
// Wait for handoff to complete
time.Sleep(200 * time.Millisecond)
// Verify handoff completed (removed from pending map)
if processor.IsHandoffPending(conn) {
t.Error("Handoff should have completed and been removed from pending map")
}
if !initConnCalled {
t.Error("InitConn should have been called during handoff")
}
// Now the original connection should be available again
conn3, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get third connection: %v", err)
}
// Could be the original connection (now handed off) or a new one
testPool.Put(ctx, conn3)
})
t.Run("ConcurrentHandoffs", func(t *testing.T) {
// Create a base dialer that simulates slow handoffs
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
time.Sleep(50 * time.Millisecond) // Simulate network delay
return &mockNetConn{addr: addr}, nil
}
processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(10),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
var wg sync.WaitGroup
// Start multiple concurrent handoffs
for i := 0; i < 5; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
// Get connection
conn, err := testPool.Get(ctx)
if err != nil {
t.Errorf("Failed to get conn[%d]: %v", id, err)
return
}
// Set initialization function
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
return nil
}
conn.SetInitConnFunc(initConnFunc)
// Mark for handoff
conn.MarkForHandoff("new-endpoint:6379", int64(id))
// Return to pool (starts async handoff)
testPool.Put(ctx, conn)
}(i)
}
wg.Wait()
// Wait for all handoffs to complete
time.Sleep(300 * time.Millisecond)
// Verify pool is still functional
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err)
}
testPool.Put(ctx, conn)
})
t.Run("HandoffFailureRecovery", func(t *testing.T) {
// Create a failing base dialer
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}}
}
processor := hitless.NewPoolHook(failingDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(3),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
// Get connection and mark for handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
conn.MarkForHandoff("unreachable-endpoint:6379", 12345)
// Return to pool (starts async handoff that will fail)
testPool.Put(ctx, conn)
// Wait for handoff to fail
time.Sleep(200 * time.Millisecond)
// Connection should be removed from pending map after failed handoff
if processor.IsHandoffPending(conn) {
t.Error("Connection should be removed from pending map after failed handoff")
}
// Pool should still be functional
conn2, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Pool should still be functional: %v", err)
}
// In event-driven approach, the original connection remains in pool
// even after failed handoff (it's still a valid connection)
// We might get the same connection or a different one
testPool.Put(ctx, conn2)
})
t.Run("GracefulShutdown", func(t *testing.T) {
// Create a slow base dialer
slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
time.Sleep(100 * time.Millisecond)
return &mockNetConn{addr: addr}, nil
}
processor := hitless.NewPoolHook(slowDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(2),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
// Start a handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function with delay to ensure handoff is pending
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
return nil
})
testPool.Put(ctx, conn)
// Give the on-demand worker a moment to start and begin processing
// The handoff should be pending because the slowDialer takes 100ms
time.Sleep(10 * time.Millisecond)
// Verify handoff was queued and is being processed
if !processor.IsHandoffPending(conn) {
t.Error("Handoff should be queued in pending map")
}
// Give the handoff a moment to start processing
time.Sleep(50 * time.Millisecond)
// Shutdown processor gracefully
// Use a longer timeout to account for slow dialer (100ms) plus processing overhead
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err = processor.Shutdown(shutdownCtx)
if err != nil {
t.Errorf("Graceful shutdown should succeed: %v", err)
}
// Handoff should have completed (removed from pending map)
if processor.IsHandoffPending(conn) {
t.Error("Handoff should have completed and been removed from pending map after shutdown")
}
})
}
func init() {
logging.Disable()
}


@@ -1,316 +0,0 @@
package redis
import (
"context"
"fmt"
"io"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/proto"
)
var ctx = context.TODO()
type ClientStub struct {
Cmdable
resp []byte
}
var initHello = []byte("%1\r\n+proto\r\n:3\r\n")
func NewClientStub(resp []byte) *ClientStub {
stub := &ClientStub{
resp: resp,
}
stub.Cmdable = NewClient(&Options{
PoolSize: 128,
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
return stub.stubConn(initHello), nil
},
DisableIdentity: true,
})
return stub
}
func NewClusterClientStub(resp []byte) *ClientStub {
stub := &ClientStub{
resp: resp,
}
client := NewClusterClient(&ClusterOptions{
PoolSize: 128,
Addrs: []string{":6379"},
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
return stub.stubConn(initHello), nil
},
DisableIdentity: true,
ClusterSlots: func(_ context.Context) ([]ClusterSlot, error) {
return []ClusterSlot{
{
Start: 0,
End: 16383,
Nodes: []ClusterNode{{Addr: "127.0.0.1:6379"}},
},
}, nil
},
})
stub.Cmdable = client
return stub
}
func (c *ClientStub) stubConn(init []byte) *ConnStub {
return &ConnStub{
init: init,
resp: c.resp,
}
}
type ConnStub struct {
init []byte
resp []byte
pos int
}
func (c *ConnStub) Read(b []byte) (n int, err error) {
// Return conn.init()
if len(c.init) > 0 {
n = copy(b, c.init)
c.init = c.init[n:]
return n, nil
}
if len(c.resp) == 0 {
return 0, io.EOF
}
if c.pos >= len(c.resp) {
c.pos = 0
}
n = copy(b, c.resp[c.pos:])
c.pos += n
return n, nil
}
func (c *ConnStub) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *ConnStub) Close() error { return nil }
func (c *ConnStub) LocalAddr() net.Addr { return nil }
func (c *ConnStub) RemoteAddr() net.Addr { return nil }
func (c *ConnStub) SetDeadline(_ time.Time) error { return nil }
func (c *ConnStub) SetReadDeadline(_ time.Time) error { return nil }
func (c *ConnStub) SetWriteDeadline(_ time.Time) error { return nil }
type ClientStubFunc func([]byte) *ClientStub
func BenchmarkDecode(b *testing.B) {
type Benchmark struct {
name string
stub ClientStubFunc
}
benchmarks := []Benchmark{
{"server", NewClientStub},
{"cluster", NewClusterClientStub},
}
for _, bench := range benchmarks {
b.Run(fmt.Sprintf("RespError-%s", bench.name), func(b *testing.B) {
respError(b, bench.stub)
})
b.Run(fmt.Sprintf("RespStatus-%s", bench.name), func(b *testing.B) {
respStatus(b, bench.stub)
})
b.Run(fmt.Sprintf("RespInt-%s", bench.name), func(b *testing.B) {
respInt(b, bench.stub)
})
b.Run(fmt.Sprintf("RespString-%s", bench.name), func(b *testing.B) {
respString(b, bench.stub)
})
b.Run(fmt.Sprintf("RespArray-%s", bench.name), func(b *testing.B) {
respArray(b, bench.stub)
})
b.Run(fmt.Sprintf("RespPipeline-%s", bench.name), func(b *testing.B) {
respPipeline(b, bench.stub)
})
b.Run(fmt.Sprintf("RespTxPipeline-%s", bench.name), func(b *testing.B) {
respTxPipeline(b, bench.stub)
})
// goroutine
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=5", bench.name), func(b *testing.B) {
dynamicGoroutine(b, bench.stub, 5)
})
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=20", bench.name), func(b *testing.B) {
dynamicGoroutine(b, bench.stub, 20)
})
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=50", bench.name), func(b *testing.B) {
dynamicGoroutine(b, bench.stub, 50)
})
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=100", bench.name), func(b *testing.B) {
dynamicGoroutine(b, bench.stub, 100)
})
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=5", bench.name), func(b *testing.B) {
staticGoroutine(b, bench.stub, 5)
})
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=20", bench.name), func(b *testing.B) {
staticGoroutine(b, bench.stub, 20)
})
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=50", bench.name), func(b *testing.B) {
staticGoroutine(b, bench.stub, 50)
})
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=100", bench.name), func(b *testing.B) {
staticGoroutine(b, bench.stub, 100)
})
}
}
func respError(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("-ERR test error\r\n"))
respErr := proto.RedisError("ERR test error")
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := rdb.Get(ctx, "key").Err(); err != respErr {
b.Fatalf("response error, got %q, want %q", err, respErr)
}
}
}
func respStatus(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("+OK\r\n"))
var val string
b.ResetTimer()
for i := 0; i < b.N; i++ {
if val = rdb.Set(ctx, "key", "value", 0).Val(); val != "OK" {
b.Fatalf("response error, got %q, want OK", val)
}
}
}
func respInt(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte(":10\r\n"))
var val int64
b.ResetTimer()
for i := 0; i < b.N; i++ {
if val = rdb.Incr(ctx, "key").Val(); val != 10 {
b.Fatalf("response error, got %q, want 10", val)
}
}
}
func respString(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("$5\r\nhello\r\n"))
var val string
b.ResetTimer()
for i := 0; i < b.N; i++ {
if val = rdb.Get(ctx, "key").Val(); val != "hello" {
b.Fatalf("response error, got %q, want hello", val)
}
}
}
func respArray(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("*3\r\n$5\r\nhello\r\n:10\r\n+OK\r\n"))
var val []interface{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if val = rdb.MGet(ctx, "key").Val(); len(val) != 3 {
b.Fatalf("response error, got len(%d), want len(3)", len(val))
}
}
}
func respPipeline(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("+OK\r\n$5\r\nhello\r\n:1\r\n"))
var pipe Pipeliner
b.ResetTimer()
for i := 0; i < b.N; i++ {
pipe = rdb.Pipeline()
set := pipe.Set(ctx, "key", "value", 0)
get := pipe.Get(ctx, "key")
del := pipe.Del(ctx, "key")
_, err := pipe.Exec(ctx)
if err != nil {
b.Fatalf("response error, got %q, want nil", err)
}
if set.Val() != "OK" || get.Val() != "hello" || del.Val() != 1 {
b.Fatal("response error")
}
}
}
func respTxPipeline(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("+OK\r\n+QUEUED\r\n+QUEUED\r\n+QUEUED\r\n*3\r\n+OK\r\n$5\r\nhello\r\n:1\r\n"))
b.ResetTimer()
for i := 0; i < b.N; i++ {
var set *StatusCmd
var get *StringCmd
var del *IntCmd
_, err := rdb.TxPipelined(ctx, func(pipe Pipeliner) error {
set = pipe.Set(ctx, "key", "value", 0)
get = pipe.Get(ctx, "key")
del = pipe.Del(ctx, "key")
return nil
})
if err != nil {
b.Fatalf("response error, got %q, want nil", err)
}
if set.Val() != "OK" || get.Val() != "hello" || del.Val() != 1 {
b.Fatal("response error")
}
}
}
func dynamicGoroutine(b *testing.B, stub ClientStubFunc, concurrency int) {
rdb := stub([]byte("$5\r\nhello\r\n"))
c := make(chan struct{}, concurrency)
b.ResetTimer()
for i := 0; i < b.N; i++ {
c <- struct{}{}
go func() {
if val := rdb.Get(ctx, "key").Val(); val != "hello" {
panic(fmt.Sprintf("response error, got %q, want hello", val))
}
<-c
}()
}
// We no longer wait here for all goroutines to complete; this does not affect the benchmark results.
close(c)
}
func staticGoroutine(b *testing.B, stub ClientStubFunc, concurrency int) {
rdb := stub([]byte("$5\r\nhello\r\n"))
c := make(chan struct{}, concurrency)
b.ResetTimer()
for i := 0; i < concurrency; i++ {
go func() {
for {
_, ok := <-c
if !ok {
return
}
if val := rdb.Get(ctx, "key").Val(); val != "hello" {
panic(fmt.Sprintf("response error, got %q, want hello", val))
}
}
}()
}
for i := 0; i < b.N; i++ {
c <- struct{}{}
}
close(c)
}


@@ -193,6 +193,7 @@ type Cmdable interface {
ClientID(ctx context.Context) *IntCmd
ClientUnblock(ctx context.Context, id int64) *IntCmd
ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
ConfigResetStat(ctx context.Context) *StatusCmd
ConfigSet(ctx context.Context, parameter, value string) *StatusCmd
@@ -519,6 +520,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
return cmd
}
// ClientMaintNotifications enables or disables maintenance notifications for hitless upgrades.
// When enabled, the client will receive push notifications about Redis maintenance events.
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
args := []interface{}{"client", "maint_notifications"}
if enabled {
if endpointType == "" {
endpointType = "none"
}
args = append(args, "on", "moving-endpoint-type", endpointType)
} else {
args = append(args, "off")
}
cmd := NewStatusCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// ------------------------------------------------------------------------------------------------
func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {
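
For reference, a small sketch of calling the new helper directly (the endpoint type string matches `EndpointTypeInternalFQDN` from `hitless/config.go`; in normal operation the client issues this command itself during the handshake, as the instrumentation example output further below shows):

```go
// Ask the server to send maintenance push notifications and to report
// internal FQDNs in MOVING notifications ("internal-fqdn" matches
// EndpointTypeInternalFQDN in hitless/config.go).
status, err := rdb.ClientMaintNotifications(ctx, true, "internal-fqdn").Result()
if err != nil {
	// A server without CLIENT MAINT_NOTIFICATIONS support returns an error here.
	return err
}
_ = status // "OK" on success
```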


@@ -3019,7 +3019,8 @@ var _ = Describe("Commands", func() {
res, err = client.HPTTL(ctx, "myhash", "key1", "key2", "key200").Result()
Expect(err).NotTo(HaveOccurred())
Expect(res[0]).To(BeNumerically("~", 10*time.Second.Milliseconds(), 1))
// overhead of the push notification check is about 1-2ms for 100 commands
Expect(res[0]).To(BeNumerically("~", 10*time.Second.Milliseconds(), 2))
})
It("should HGETDEL", Label("hash", "HGETDEL"), func() {

12
example/pubsub/go.mod Normal file

@@ -0,0 +1,12 @@
module github.com/redis/go-redis/example/pubsub
go 1.18
replace github.com/redis/go-redis/v9 => ../..
require github.com/redis/go-redis/v9 v9.11.0
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)

6
example/pubsub/go.sum Normal file

@@ -0,0 +1,6 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=

171
example/pubsub/main.go Normal file

@@ -0,0 +1,171 @@
package main
import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/logging"
)
var ctx = context.Background()
var cntErrors atomic.Int64
var cntSuccess atomic.Int64
var startTime = time.Now()
// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management.
// It was used to find regressions in pool management in hitless mode.
// Please don't use it as a reference for how to use pubsub.
func main() {
startTime = time.Now()
wg := &sync.WaitGroup{}
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
Mode: hitless.MaintNotificationsEnabled,
},
})
_ = rdb.FlushDB(ctx).Err()
hitlessManager := rdb.GetHitlessManager()
if hitlessManager == nil {
panic("hitless manager is nil")
}
loggingHook := hitless.NewLoggingHook(logging.LogLevelDebug)
hitlessManager.AddNotificationHook(loggingHook)
go func() {
for {
time.Sleep(2 * time.Second)
fmt.Printf("pool stats: %+v\n", rdb.PoolStats())
}
}()
err := rdb.Ping(ctx).Err()
if err != nil {
panic(err)
}
if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil {
panic(err)
}
fmt.Println("published", rdb.Get(ctx, "published").Val())
fmt.Println("received", rdb.Get(ctx, "received").Val())
subCtx, cancelSubCtx := context.WithCancel(ctx)
pubCtx, cancelPublishers := context.WithCancel(ctx)
for i := 0; i < 10; i++ {
wg.Add(1)
go subscribe(subCtx, rdb, "test", i, wg)
}
time.Sleep(time.Second)
cancelSubCtx()
time.Sleep(time.Second)
subCtx, cancelSubCtx = context.WithCancel(ctx)
for i := 0; i < 10; i++ {
if err := rdb.Incr(ctx, "publishers").Err(); err != nil {
fmt.Println("incr error:", err)
cntErrors.Add(1)
}
wg.Add(1)
go floodThePool(pubCtx, rdb, wg)
}
for i := 0; i < 500; i++ {
if err := rdb.Incr(ctx, "subscribers").Err(); err != nil {
fmt.Println("incr error:", err)
cntErrors.Add(1)
}
wg.Add(1)
go subscribe(subCtx, rdb, "test2", i, wg)
}
time.Sleep(120 * time.Second)
fmt.Println("canceling publishers")
cancelPublishers()
time.Sleep(10 * time.Second)
fmt.Println("canceling subscribers")
cancelSubCtx()
wg.Wait()
published, err := rdb.Get(ctx, "published").Result()
received, err := rdb.Get(ctx, "received").Result()
publishers, err := rdb.Get(ctx, "publishers").Result()
subscribers, err := rdb.Get(ctx, "subscribers").Result()
fmt.Printf("publishers: %s\n", publishers)
fmt.Printf("published: %s\n", published)
fmt.Printf("subscribers: %s\n", subscribers)
fmt.Printf("received: %s\n", received)
publishedInt, err := rdb.Get(ctx, "published").Int()
subscribersInt, err := rdb.Get(ctx, "subscribers").Int()
fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt)
time.Sleep(2 * time.Second)
fmt.Println("errors:", cntErrors.Load())
fmt.Println("success:", cntSuccess.Load())
fmt.Println("time:", time.Since(startTime))
}
func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
}
err := rdb.Publish(ctx, "test2", "hello").Err()
if err != nil {
if err.Error() != "context canceled" {
log.Println("publish error:", err)
cntErrors.Add(1)
}
}
err = rdb.Incr(ctx, "published").Err()
if err != nil {
if err.Error() != "context canceled" {
log.Println("incr error:", err)
cntErrors.Add(1)
}
}
time.Sleep(10 * time.Nanosecond)
}
}
func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) {
defer wg.Done()
rec := rdb.Subscribe(ctx, topic)
recChan := rec.Channel()
for {
select {
case <-ctx.Done():
rec.Close()
return
default:
select {
case <-ctx.Done():
rec.Close()
return
case msg := <-recChan:
err := rdb.Incr(ctx, "received").Err()
if err != nil {
if err.Error() != "context canceled" {
log.Printf("%s\n", err.Error())
cntErrors.Add(1)
}
}
_ = msg // Use the message to avoid unused variable warning
}
}
}
}


@@ -57,6 +57,8 @@ func Example_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[ping]>
}
@@ -78,6 +80,8 @@ func ExamplePipeline_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// pipeline finished processing: [[ping] [ping]]
}
@@ -99,6 +103,8 @@ func ExampleClient_Watch_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[watch foo]>
// starting processing: <[ping]>
// finished processing: <[ping]>

98
hitless/README.md Normal file

@@ -0,0 +1,98 @@
# Hitless Upgrades
Seamless Redis connection handoffs during cluster changes without dropping connections.
## Quick Start
```go
client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Protocol: 3, // RESP3 required
HitlessUpgrades: &hitless.Config{
Mode: hitless.MaintNotificationsEnabled,
},
})
```
## Modes
- **`MaintNotificationsDisabled`** - Hitless upgrades disabled
- **`MaintNotificationsEnabled`** - Forcefully enabled (fails if the server doesn't support maintenance notifications)
- **`MaintNotificationsAuto`** - Auto-detect server support (default)
## Configuration
```go
&hitless.Config{
Mode: hitless.MaintNotificationsAuto,
EndpointType: hitless.EndpointTypeAuto,
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxHandoffRetries: 3,
MaxWorkers: 0, // Auto-calculated
HandoffQueueSize: 0, // Auto-calculated
PostHandoffRelaxedDuration: 0, // 2 * RelaxedTimeout
LogLevel: logging.LogLevelError,
}
```
### Endpoint Types
- **`EndpointTypeAuto`** - Auto-detect based on connection (default)
- **`EndpointTypeInternalIP`** - Internal IP address
- **`EndpointTypeInternalFQDN`** - Internal FQDN
- **`EndpointTypeExternalIP`** - External IP address
- **`EndpointTypeExternalFQDN`** - External FQDN
- **`EndpointTypeNone`** - No endpoint (reconnect with current config)
### Auto-Scaling
- **Workers**: `min(PoolSize/2, max(10, PoolSize/3))` when auto-calculated
- **Queue**: `max(20×Workers, PoolSize)`, capped by `MaxActiveConns+1` (if set) or `5×PoolSize` (a sketch of this arithmetic follows the examples below)
**Examples:**
- Pool 100: 33 workers, 660 queue (capped at 500)
- Pool 100 + MaxActiveConns 150: 33 workers, 151 queue
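
A small, illustrative re-derivation of the sizing rules above (this is not the library's code; it only reproduces the documented formulas):

```go
package main

import "fmt"

// autoWorkers and autoQueue reproduce the documented defaults; they are
// illustrative helpers, not the client's actual implementation.
func autoWorkers(poolSize int) int {
	w := poolSize / 3
	if w < 10 {
		w = 10
	}
	if w > poolSize/2 {
		w = poolSize / 2
	}
	return w // min(PoolSize/2, max(10, PoolSize/3))
}

func autoQueue(poolSize, maxActiveConns, workers int) int {
	q := 20 * workers
	if q < poolSize {
		q = poolSize // max(20×Workers, PoolSize)
	}
	limit := 5 * poolSize
	if maxActiveConns > 0 {
		limit = maxActiveConns + 1
	}
	if q > limit {
		q = limit
	}
	return q
}

func main() {
	w := autoWorkers(100)
	fmt.Println(w, autoQueue(100, 0, w))   // 33 500 (660, capped at 5×PoolSize)
	fmt.Println(w, autoQueue(100, 150, w)) // 33 151 (capped at MaxActiveConns+1)
}
```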
## How It Works
1. Redis sends push notifications about cluster changes
2. Client creates new connections to updated endpoints
3. Active operations transfer to new connections
4. Old connections close gracefully
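
To observe these steps in practice, the `example/pubsub` program in this commit attaches a logging hook to the hitless manager; a trimmed sketch of that pattern:

```go
// Names follow example/pubsub/main.go; assumes rdb was created with
// hitless upgrades enabled as in the Quick Start above.
if manager := rdb.GetHitlessManager(); manager != nil {
	manager.AddNotificationHook(hitless.NewLoggingHook(logging.LogLevelDebug))
}
```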
## Supported Notifications
- `MOVING` - Slot moving to new node
- `MIGRATING` - Slot in migration state
- `MIGRATED` - Migration completed
- `FAILING_OVER` - Node failing over
- `FAILED_OVER` - Failover completed
## Hooks (Optional)
Monitor and customize hitless operations:
```go
type NotificationHook interface {
PreHook(ctx, notificationCtx, notificationType, notification) ([]interface{}, bool)
PostHook(ctx, notificationCtx, notificationType, notification, result)
}
// Add custom hook
manager.AddNotificationHook(&MyHook{})
```
### Metrics Hook Example
```go
// Create metrics hook
metricsHook := hitless.NewMetricsHook()
manager.AddNotificationHook(metricsHook)
// Access collected metrics
metrics := metricsHook.GetMetrics()
fmt.Printf("Notification counts: %v\n", metrics["notification_counts"])
fmt.Printf("Processing times: %v\n", metrics["processing_times"])
fmt.Printf("Error counts: %v\n", metrics["error_counts"])
```

360
hitless/circuit_breaker.go Normal file

@@ -0,0 +1,360 @@
package hitless
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
)
// CircuitBreakerState represents the state of a circuit breaker
type CircuitBreakerState int32
const (
// CircuitBreakerClosed - normal operation, requests allowed
CircuitBreakerClosed CircuitBreakerState = iota
// CircuitBreakerOpen - failing fast, requests rejected
CircuitBreakerOpen
// CircuitBreakerHalfOpen - testing if service recovered
CircuitBreakerHalfOpen
)
func (s CircuitBreakerState) String() string {
switch s {
case CircuitBreakerClosed:
return "closed"
case CircuitBreakerOpen:
return "open"
case CircuitBreakerHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling
type CircuitBreaker struct {
// Configuration
failureThreshold int // Number of failures before opening
resetTimeout time.Duration // How long to stay open before testing
maxRequests int // Max requests allowed in half-open state
// State tracking (atomic for lock-free access)
state atomic.Int32 // CircuitBreakerState
failures atomic.Int64 // Current failure count
successes atomic.Int64 // Success count in half-open state
requests atomic.Int64 // Request count in half-open state
lastFailureTime atomic.Int64 // Unix timestamp of last failure
lastSuccessTime atomic.Int64 // Unix timestamp of last success
// Endpoint identification
endpoint string
config *Config
}
// newCircuitBreaker creates a new circuit breaker for an endpoint
func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker {
// Use configuration values with sensible defaults
failureThreshold := 5
resetTimeout := 60 * time.Second
maxRequests := 3
if config != nil {
failureThreshold = config.CircuitBreakerFailureThreshold
resetTimeout = config.CircuitBreakerResetTimeout
maxRequests = config.CircuitBreakerMaxRequests
}
return &CircuitBreaker{
failureThreshold: failureThreshold,
resetTimeout: resetTimeout,
maxRequests: maxRequests,
endpoint: endpoint,
config: config,
state: atomic.Int32{}, // Defaults to CircuitBreakerClosed (0)
}
}
// IsOpen returns true if the circuit breaker is open (rejecting requests)
func (cb *CircuitBreaker) IsOpen() bool {
state := CircuitBreakerState(cb.state.Load())
return state == CircuitBreakerOpen
}
// shouldAttemptReset checks if enough time has passed to attempt reset
func (cb *CircuitBreaker) shouldAttemptReset() bool {
lastFailure := time.Unix(cb.lastFailureTime.Load(), 0)
return time.Since(lastFailure) >= cb.resetTimeout
}
// Execute runs the given function with circuit breaker protection
func (cb *CircuitBreaker) Execute(fn func() error) error {
// Single atomic state load for consistency
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerOpen:
if cb.shouldAttemptReset() {
// Attempt transition to half-open
if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
cb.requests.Store(0)
cb.successes.Store(0)
if cb.config != nil && cb.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker for %s transitioning to half-open", cb.endpoint)
}
// Fall through to half-open logic
} else {
return ErrCircuitBreakerOpen
}
} else {
return ErrCircuitBreakerOpen
}
fallthrough
case CircuitBreakerHalfOpen:
requests := cb.requests.Add(1)
if requests > int64(cb.maxRequests) {
cb.requests.Add(-1) // Revert the increment
return ErrCircuitBreakerOpen
}
}
// Execute the function with consistent state
err := fn()
if err != nil {
cb.recordFailure()
return err
}
cb.recordSuccess()
return nil
}
// recordFailure records a failure and potentially opens the circuit
func (cb *CircuitBreaker) recordFailure() {
cb.lastFailureTime.Store(time.Now().Unix())
failures := cb.failures.Add(1)
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerClosed:
if failures >= int64(cb.failureThreshold) {
if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) {
if cb.config != nil && cb.config.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker opened for endpoint %s after %d failures",
cb.endpoint, failures)
}
}
}
case CircuitBreakerHalfOpen:
// Any failure in half-open state immediately opens the circuit
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) {
if cb.config != nil && cb.config.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker reopened for endpoint %s due to failure in half-open state",
cb.endpoint)
}
}
}
}
// recordSuccess records a success and potentially closes the circuit
func (cb *CircuitBreaker) recordSuccess() {
cb.lastSuccessTime.Store(time.Now().Unix())
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerClosed:
// Reset failure count on success in closed state
cb.failures.Store(0)
case CircuitBreakerHalfOpen:
successes := cb.successes.Add(1)
// If we've had enough successful requests, close the circuit
if successes >= int64(cb.maxRequests) {
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
cb.failures.Store(0)
if cb.config != nil && cb.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker closed for endpoint %s after %d successful requests",
cb.endpoint, successes)
}
}
}
}
}
// GetState returns the current state of the circuit breaker
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
return CircuitBreakerState(cb.state.Load())
}
// GetStats returns current statistics for monitoring
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
return CircuitBreakerStats{
Endpoint: cb.endpoint,
State: cb.GetState(),
Failures: cb.failures.Load(),
Successes: cb.successes.Load(),
Requests: cb.requests.Load(),
LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0),
LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0),
}
}
// CircuitBreakerStats provides statistics about a circuit breaker
type CircuitBreakerStats struct {
Endpoint string
State CircuitBreakerState
Failures int64
Successes int64
Requests int64
LastFailureTime time.Time
LastSuccessTime time.Time
}
// CircuitBreakerEntry wraps a circuit breaker with access tracking
type CircuitBreakerEntry struct {
breaker *CircuitBreaker
lastAccess atomic.Int64 // Unix timestamp
created time.Time
}
// CircuitBreakerManager manages circuit breakers for multiple endpoints
type CircuitBreakerManager struct {
breakers sync.Map // map[string]*CircuitBreakerEntry
config *Config
cleanupStop chan struct{}
cleanupMu sync.Mutex
lastCleanup atomic.Int64 // Unix timestamp
}
// newCircuitBreakerManager creates a new circuit breaker manager
func newCircuitBreakerManager(config *Config) *CircuitBreakerManager {
cbm := &CircuitBreakerManager{
config: config,
cleanupStop: make(chan struct{}),
}
cbm.lastCleanup.Store(time.Now().Unix())
// Start background cleanup goroutine
go cbm.cleanupLoop()
return cbm
}
// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary
func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker {
now := time.Now().Unix()
if entry, ok := cbm.breakers.Load(endpoint); ok {
cbEntry := entry.(*CircuitBreakerEntry)
cbEntry.lastAccess.Store(now)
return cbEntry.breaker
}
// Create new circuit breaker with metadata
newBreaker := newCircuitBreaker(endpoint, cbm.config)
newEntry := &CircuitBreakerEntry{
breaker: newBreaker,
created: time.Now(),
}
newEntry.lastAccess.Store(now)
actual, _ := cbm.breakers.LoadOrStore(endpoint, newEntry)
return actual.(*CircuitBreakerEntry).breaker
}
// GetAllStats returns statistics for all circuit breakers
func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats {
var stats []CircuitBreakerStats
cbm.breakers.Range(func(key, value interface{}) bool {
entry := value.(*CircuitBreakerEntry)
stats = append(stats, entry.breaker.GetStats())
return true
})
return stats
}
// cleanupLoop runs background cleanup of unused circuit breakers
func (cbm *CircuitBreakerManager) cleanupLoop() {
ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
defer ticker.Stop()
for {
select {
case <-ticker.C:
cbm.cleanup()
case <-cbm.cleanupStop:
return
}
}
}
// cleanup removes circuit breakers that haven't been accessed recently
func (cbm *CircuitBreakerManager) cleanup() {
// Prevent concurrent cleanups
if !cbm.cleanupMu.TryLock() {
return
}
defer cbm.cleanupMu.Unlock()
now := time.Now()
cutoff := now.Add(-30 * time.Minute).Unix() // 30 minute TTL
var toDelete []string
count := 0
cbm.breakers.Range(func(key, value interface{}) bool {
endpoint := key.(string)
entry := value.(*CircuitBreakerEntry)
count++
// Remove if not accessed recently
if entry.lastAccess.Load() < cutoff {
toDelete = append(toDelete, endpoint)
}
return true
})
// Delete expired entries
for _, endpoint := range toDelete {
cbm.breakers.Delete(endpoint)
}
// Log cleanup results
if len(toDelete) > 0 && cbm.config != nil && cbm.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker cleanup removed %d/%d entries", len(toDelete), count)
}
cbm.lastCleanup.Store(now.Unix())
}
// Shutdown stops the cleanup goroutine
func (cbm *CircuitBreakerManager) Shutdown() {
close(cbm.cleanupStop)
}
// Reset resets all circuit breakers (useful for testing)
func (cbm *CircuitBreakerManager) Reset() {
cbm.breakers.Range(func(key, value interface{}) bool {
entry := value.(*CircuitBreakerEntry)
breaker := entry.breaker
breaker.state.Store(int32(CircuitBreakerClosed))
breaker.failures.Store(0)
breaker.successes.Store(0)
breaker.requests.Store(0)
breaker.lastFailureTime.Store(0)
breaker.lastSuccessTime.Store(0)
return true
})
}
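
For orientation, a compact sketch of how the manager and `Execute` fit together (illustrative only; `dialNewEndpoint` is a hypothetical callback, and the real call sites live in the handoff worker code):

```go
// Illustrative sketch (inside the hitless package): wrap a per-endpoint
// operation so repeated failures trip the breaker for that endpoint.
cbm := newCircuitBreakerManager(DefaultConfig())
defer cbm.Shutdown()

cb := cbm.GetCircuitBreaker("new-endpoint:6379")
err := cb.Execute(func() error {
	return dialNewEndpoint() // hypothetical dial/handoff attempt
})
if err == ErrCircuitBreakerOpen {
	// Endpoint is failing fast; skip further attempts until the reset timeout.
}
```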


@@ -0,0 +1,356 @@
package hitless
import (
"errors"
"testing"
"time"
"github.com/redis/go-redis/v9/logging"
)
func TestCircuitBreaker(t *testing.T) {
config := &Config{
LogLevel: logging.LogLevelError, // Reduce noise in tests
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
}
t.Run("InitialState", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
if cb.IsOpen() {
t.Error("Circuit breaker should start in closed state")
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("SuccessfulExecution", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
err := cb.Execute(func() error {
return nil // Success
})
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("FailureThreshold", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Fail 4 times (below threshold of 5)
for i := 0; i < 4; i++ {
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Circuit should still be closed after %d failures", i+1)
}
}
// 5th failure should open the circuit
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
}
})
t.Run("OpenCircuitFailsFast", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Now it should fail fast
err := cb.Execute(func() error {
t.Error("Function should not be called when circuit is open")
return nil
})
if err != ErrCircuitBreakerOpen {
t.Errorf("Expected ErrCircuitBreakerOpen, got %v", err)
}
})
t.Run("HalfOpenTransition", func(t *testing.T) {
testConfig := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 100 * time.Millisecond, // Short timeout for testing
CircuitBreakerMaxRequests: 3,
}
cb := newCircuitBreaker("test-endpoint:6379", testConfig)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
if cb.GetState() != CircuitBreakerOpen {
t.Error("Circuit should be open")
}
// Wait for reset timeout
time.Sleep(150 * time.Millisecond)
// Next call should transition to half-open
executed := false
err := cb.Execute(func() error {
executed = true
return nil // Success
})
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if !executed {
t.Error("Function should have been executed in half-open state")
}
})
t.Run("HalfOpenToClosedTransition", func(t *testing.T) {
testConfig := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 50 * time.Millisecond,
CircuitBreakerMaxRequests: 3,
}
cb := newCircuitBreaker("test-endpoint:6379", testConfig)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Wait for reset timeout
time.Sleep(100 * time.Millisecond)
// Execute successful requests in half-open state
for i := 0; i < 3; i++ {
err := cb.Execute(func() error {
return nil // Success
})
if err != nil {
t.Errorf("Expected no error on attempt %d, got %v", i+1, err)
}
}
// Circuit should now be closed
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("HalfOpenToOpenOnFailure", func(t *testing.T) {
testConfig := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 50 * time.Millisecond,
CircuitBreakerMaxRequests: 3,
}
cb := newCircuitBreaker("test-endpoint:6379", testConfig)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Wait for reset timeout
time.Sleep(100 * time.Millisecond)
// First request in half-open state fails
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
// Circuit should be open again
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
}
})
t.Run("Stats", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Execute some operations
cb.Execute(func() error { return testError }) // Failure
cb.Execute(func() error { return testError }) // Failure
stats := cb.GetStats()
if stats.Endpoint != "test-endpoint:6379" {
t.Errorf("Expected endpoint 'test-endpoint:6379', got %s", stats.Endpoint)
}
if stats.Failures != 2 {
t.Errorf("Expected 2 failures, got %d", stats.Failures)
}
if stats.State != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, stats.State)
}
// Test that success resets failure count
cb.Execute(func() error { return nil }) // Success
stats = cb.GetStats()
if stats.Failures != 0 {
t.Errorf("Expected 0 failures after success, got %d", stats.Failures)
}
})
}
func TestCircuitBreakerManager(t *testing.T) {
config := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
}
t.Run("GetCircuitBreaker", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
cb1 := manager.GetCircuitBreaker("endpoint1:6379")
cb2 := manager.GetCircuitBreaker("endpoint2:6379")
cb3 := manager.GetCircuitBreaker("endpoint1:6379") // Same as cb1
if cb1 == cb2 {
t.Error("Different endpoints should have different circuit breakers")
}
if cb1 != cb3 {
t.Error("Same endpoint should return the same circuit breaker")
}
})
t.Run("GetAllStats", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
// Create circuit breakers for different endpoints
cb1 := manager.GetCircuitBreaker("endpoint1:6379")
cb2 := manager.GetCircuitBreaker("endpoint2:6379")
// Execute some operations
cb1.Execute(func() error { return nil })
cb2.Execute(func() error { return errors.New("test error") })
stats := manager.GetAllStats()
if len(stats) != 2 {
t.Errorf("Expected 2 circuit breaker stats, got %d", len(stats))
}
// Check that we have stats for both endpoints
endpoints := make(map[string]bool)
for _, stat := range stats {
endpoints[stat.Endpoint] = true
}
if !endpoints["endpoint1:6379"] || !endpoints["endpoint2:6379"] {
t.Error("Missing stats for expected endpoints")
}
})
t.Run("Reset", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
testError := errors.New("test error")
cb := manager.GetCircuitBreaker("test-endpoint:6379")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
if cb.GetState() != CircuitBreakerOpen {
t.Error("Circuit should be open")
}
// Reset all circuit breakers
manager.Reset()
if cb.GetState() != CircuitBreakerClosed {
t.Error("Circuit should be closed after reset")
}
if cb.failures.Load() != 0 {
t.Error("Failure count should be reset to 0")
}
})
t.Run("ConfigurableParameters", func(t *testing.T) {
config := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 10,
CircuitBreakerResetTimeout: 30 * time.Second,
CircuitBreakerMaxRequests: 5,
}
cb := newCircuitBreaker("test-endpoint:6379", config)
// Test that configuration values are used
if cb.failureThreshold != 10 {
t.Errorf("Expected failureThreshold=10, got %d", cb.failureThreshold)
}
if cb.resetTimeout != 30*time.Second {
t.Errorf("Expected resetTimeout=30s, got %v", cb.resetTimeout)
}
if cb.maxRequests != 5 {
t.Errorf("Expected maxRequests=5, got %d", cb.maxRequests)
}
// Test that circuit opens after configured threshold
testError := errors.New("test error")
for i := 0; i < 9; i++ {
err := cb.Execute(func() error { return testError })
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Circuit should still be closed after %d failures", i+1)
}
}
// 10th failure should open the circuit
err := cb.Execute(func() error { return testError })
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
}
})
}

472
hitless/config.go Normal file

@@ -0,0 +1,472 @@
package hitless
import (
"context"
"net"
"runtime"
"strings"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/logging"
)
// MaintNotificationsMode represents the maintenance notifications mode
type MaintNotificationsMode string
// Constants for maintenance push notifications modes
const (
MaintNotificationsDisabled MaintNotificationsMode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
MaintNotificationsEnabled MaintNotificationsMode = "enabled" // Client forcefully sends command, interrupts connection on error
MaintNotificationsAuto MaintNotificationsMode = "auto" // Client tries to send command, disables feature on error
)
// IsValid returns true if the maintenance notifications mode is valid
func (m MaintNotificationsMode) IsValid() bool {
switch m {
case MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto:
return true
default:
return false
}
}
// String returns the string representation of the mode
func (m MaintNotificationsMode) String() string {
return string(m)
}
// EndpointType represents the type of endpoint to request in MOVING notifications
type EndpointType string
// Constants for endpoint types
const (
EndpointTypeAuto EndpointType = "auto" // Auto-detect based on connection
EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address
EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN
EndpointTypeExternalIP EndpointType = "external-ip" // External IP address
EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN
EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config)
)
// IsValid returns true if the endpoint type is valid
func (e EndpointType) IsValid() bool {
switch e {
case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone:
return true
default:
return false
}
}
// String returns the string representation of the endpoint type
func (e EndpointType) String() string {
return string(e)
}
// Config provides configuration options for hitless upgrades.
type Config struct {
// Mode controls how client maintenance notifications are handled.
// Valid values: MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto
// Default: MaintNotificationsAuto
Mode MaintNotificationsMode
// EndpointType specifies the type of endpoint to request in MOVING notifications.
// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
// EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
// Default: EndpointTypeAuto
EndpointType EndpointType
// RelaxedTimeout is the concrete timeout value to use during
// MIGRATING/FAILING_OVER states to accommodate increased latency.
// This applies to both read and write timeouts.
// Default: 10 seconds
RelaxedTimeout time.Duration
// HandoffTimeout is the maximum time to wait for connection handoff to complete.
// If handoff takes longer than this, the old connection will be forcibly closed.
// Default: 15 seconds (matches server-side eviction timeout)
HandoffTimeout time.Duration
// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
// Workers are created on-demand and automatically cleaned up when idle.
// If zero, defaults to min(10, PoolSize/2) to handle bursts effectively.
// If explicitly set, enforces minimum of PoolSize/2
//
// Default: min(PoolSize/2, max(10, PoolSize/3)), Minimum when set: PoolSize/2
MaxWorkers int
// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
// If the queue is full, new handoff requests will be rejected.
// Scales with both worker count and pool size for better burst handling.
//
// Default: max(20×MaxWorkers, PoolSize), capped by MaxActiveConns+1 (if set) or 5×PoolSize
// When set: minimum 200, capped by MaxActiveConns+1 (if set) or 5×PoolSize
HandoffQueueSize int
// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
// after a handoff completes. This provides additional resilience during cluster transitions.
// Default: 2 * RelaxedTimeout
PostHandoffRelaxedDuration time.Duration
// LogLevel controls the verbosity of hitless upgrade logging.
// LogLevelError (0) = errors only, LogLevelWarn (1) = warnings, LogLevelInfo (2) = info, LogLevelDebug (3) = debug
// Default: logging.LogLevelError(0)
LogLevel logging.LogLevel
// Circuit breaker configuration for endpoint failure handling
// CircuitBreakerFailureThreshold is the number of failures before opening the circuit.
// Default: 5
CircuitBreakerFailureThreshold int
// CircuitBreakerResetTimeout is how long to wait before testing if the endpoint recovered.
// Default: 60 seconds
CircuitBreakerResetTimeout time.Duration
// CircuitBreakerMaxRequests is the maximum number of requests allowed in half-open state.
// Default: 3
CircuitBreakerMaxRequests int
// MaxHandoffRetries is the maximum number of times to retry a failed handoff.
// After this many retries, the connection will be removed from the pool.
// Default: 3
MaxHandoffRetries int
}
func (c *Config) IsEnabled() bool {
return c != nil && c.Mode != MaintNotificationsDisabled
}
// DefaultConfig returns a Config with sensible defaults.
func DefaultConfig() *Config {
return &Config{
Mode: MaintNotificationsAuto, // Enable by default for Redis Cloud
EndpointType: EndpointTypeAuto, // Auto-detect based on connection
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: 0, // Auto-calculated based on pool size
HandoffQueueSize: 0, // Auto-calculated based on max workers
PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
LogLevel: logging.LogLevelError,
// Circuit breaker configuration
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
// Connection Handoff Configuration
MaxHandoffRetries: 3,
}
}
// Validate checks if the configuration is valid.
func (c *Config) Validate() error {
if c.RelaxedTimeout <= 0 {
return ErrInvalidRelaxedTimeout
}
if c.HandoffTimeout <= 0 {
return ErrInvalidHandoffTimeout
}
// Validate worker configuration
// Allow 0 for auto-calculation, but negative values are invalid
if c.MaxWorkers < 0 {
return ErrInvalidHandoffWorkers
}
// HandoffQueueSize validation - allow 0 for auto-calculation
if c.HandoffQueueSize < 0 {
return ErrInvalidHandoffQueueSize
}
if c.PostHandoffRelaxedDuration < 0 {
return ErrInvalidPostHandoffRelaxedDuration
}
if !c.LogLevel.IsValid() {
return ErrInvalidLogLevel
}
// Circuit breaker validation
if c.CircuitBreakerFailureThreshold < 1 {
return ErrInvalidCircuitBreakerFailureThreshold
}
if c.CircuitBreakerResetTimeout < 0 {
return ErrInvalidCircuitBreakerResetTimeout
}
if c.CircuitBreakerMaxRequests < 1 {
return ErrInvalidCircuitBreakerMaxRequests
}
// Validate Mode (maintenance notifications mode)
if !c.Mode.IsValid() {
return ErrInvalidMaintNotifications
}
// Validate EndpointType
if !c.EndpointType.IsValid() {
return ErrInvalidEndpointType
}
// Validate configuration fields
if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
return ErrInvalidHandoffRetries
}
return nil
}
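// Validation sketch (illustrative): out-of-range values map to the sentinel
// errors defined in errors.go of this package.
//
//    cfg := DefaultConfig().ApplyDefaults()
//    cfg.MaxHandoffRetries = 11
//    err := cfg.Validate() // err == ErrInvalidHandoffRetries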
// ApplyDefaults applies default values to any zero-value fields in the configuration.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaults() *Config {
return c.ApplyDefaultsWithPoolSize(0)
}
// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
// using the provided pool size to calculate worker defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
return c.ApplyDefaultsWithPoolConfig(poolSize, 0)
}
// ApplyDefaultsWithPoolConfig applies default values to any zero-value fields in the configuration,
// using the provided pool size and max active connections to calculate worker and queue defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *Config {
if c == nil {
return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
}
defaults := DefaultConfig()
result := &Config{}
// Apply defaults for enum fields (empty/zero means not set)
result.Mode = defaults.Mode
if c.Mode != "" {
result.Mode = c.Mode
}
result.EndpointType = defaults.EndpointType
if c.EndpointType != "" {
result.EndpointType = c.EndpointType
}
// Apply defaults for duration fields (zero means not set)
result.RelaxedTimeout = defaults.RelaxedTimeout
if c.RelaxedTimeout > 0 {
result.RelaxedTimeout = c.RelaxedTimeout
}
result.HandoffTimeout = defaults.HandoffTimeout
if c.HandoffTimeout > 0 {
result.HandoffTimeout = c.HandoffTimeout
}
// Copy worker configuration
result.MaxWorkers = c.MaxWorkers
// Apply worker defaults based on pool size
result.applyWorkerDefaults(poolSize)
// Apply queue size defaults with new scaling approach
// Default: max(20x workers, PoolSize), capped by maxActiveConns+1 (if set) or 5x pool size
workerBasedSize := result.MaxWorkers * 20
poolBasedSize := poolSize
result.HandoffQueueSize = util.Max(workerBasedSize, poolBasedSize)
if c.HandoffQueueSize > 0 {
// When explicitly set: enforce minimum of 200
result.HandoffQueueSize = util.Max(200, c.HandoffQueueSize)
}
// Cap queue size: use maxActiveConns+1 if set, otherwise 5x pool size
var queueCap int
if maxActiveConns > 0 {
queueCap = maxActiveConns + 1
// Ensure queue cap is at least 2 for very small maxActiveConns
if queueCap < 2 {
queueCap = 2
}
} else {
queueCap = poolSize * 5
}
result.HandoffQueueSize = util.Min(result.HandoffQueueSize, queueCap)
// Ensure minimum queue size of 2 (fallback for very small pools)
if result.HandoffQueueSize < 2 {
result.HandoffQueueSize = 2
}
result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
if c.PostHandoffRelaxedDuration > 0 {
result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
}
// LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set
// We'll use the provided value as-is, since 0 is valid
result.LogLevel = c.LogLevel
// Apply defaults for configuration fields
result.MaxHandoffRetries = defaults.MaxHandoffRetries
if c.MaxHandoffRetries > 0 {
result.MaxHandoffRetries = c.MaxHandoffRetries
}
// Circuit breaker configuration
result.CircuitBreakerFailureThreshold = defaults.CircuitBreakerFailureThreshold
if c.CircuitBreakerFailureThreshold > 0 {
result.CircuitBreakerFailureThreshold = c.CircuitBreakerFailureThreshold
}
result.CircuitBreakerResetTimeout = defaults.CircuitBreakerResetTimeout
if c.CircuitBreakerResetTimeout > 0 {
result.CircuitBreakerResetTimeout = c.CircuitBreakerResetTimeout
}
result.CircuitBreakerMaxRequests = defaults.CircuitBreakerMaxRequests
if c.CircuitBreakerMaxRequests > 0 {
result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests
}
if result.LogLevel.DebugOrAbove() {
internal.Logger.Printf(context.Background(), "hitless: debug logging enabled")
internal.Logger.Printf(context.Background(), "hitless: config: %+v", result)
}
return result
}
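// Worked example of the derivation above (poolSize=100, maxActiveConns=0,
// nothing set explicitly):
//
//    MaxWorkers                 = min(100/2, max(10, 100/3)) = min(50, 33)   = 33
//    HandoffQueueSize           = max(20*33, 100) = 660, capped at 5*100     = 500
//    PostHandoffRelaxedDuration = 2 * RelaxedTimeout                         = 20s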
// Clone creates a deep copy of the configuration.
func (c *Config) Clone() *Config {
if c == nil {
return DefaultConfig()
}
return &Config{
Mode: c.Mode,
EndpointType: c.EndpointType,
RelaxedTimeout: c.RelaxedTimeout,
HandoffTimeout: c.HandoffTimeout,
MaxWorkers: c.MaxWorkers,
HandoffQueueSize: c.HandoffQueueSize,
PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
LogLevel: c.LogLevel,
// Circuit breaker configuration
CircuitBreakerFailureThreshold: c.CircuitBreakerFailureThreshold,
CircuitBreakerResetTimeout: c.CircuitBreakerResetTimeout,
CircuitBreakerMaxRequests: c.CircuitBreakerMaxRequests,
// Configuration fields
MaxHandoffRetries: c.MaxHandoffRetries,
}
}
// applyWorkerDefaults calculates and applies worker defaults based on pool size
func (c *Config) applyWorkerDefaults(poolSize int) {
// Calculate defaults based on pool size
if poolSize <= 0 {
poolSize = 10 * runtime.GOMAXPROCS(0)
}
// When not set: min(poolSize/2, max(10, poolSize/3)) - balanced scaling approach
originalMaxWorkers := c.MaxWorkers
c.MaxWorkers = util.Min(poolSize/2, util.Max(10, poolSize/3))
if originalMaxWorkers != 0 {
// When explicitly set: max(poolSize/2, set_value) - ensure at least poolSize/2 workers
c.MaxWorkers = util.Max(poolSize/2, originalMaxWorkers)
}
// Ensure minimum of 1 worker (fallback for very small pools)
if c.MaxWorkers < 1 {
c.MaxWorkers = 1
}
}
// DetectEndpointType automatically detects the appropriate endpoint type
// based on the connection address and TLS configuration.
//
// For IP addresses:
// - If TLS is enabled: requests FQDN for proper certificate validation
// - If TLS is disabled: requests IP for better performance
//
// For hostnames:
// - Internal-looking hostnames request an internal FQDN
// - All other hostnames request an external FQDN (TLS does not change this)
//
// Internal vs External detection:
// - For IPs: uses private IP range detection
// - For hostnames: uses heuristics based on common internal naming patterns
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
// Extract host from "host:port" format
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr // Assume no port
}
// Check if the host is an IP address or hostname
ip := net.ParseIP(host)
isIPAddress := ip != nil
var endpointType EndpointType
if isIPAddress {
// Address is an IP - determine if it's private or public
isPrivate := ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
if tlsEnabled {
// TLS with IP addresses - still prefer FQDN for certificate validation
if isPrivate {
endpointType = EndpointTypeInternalFQDN
} else {
endpointType = EndpointTypeExternalFQDN
}
} else {
// No TLS - can use IP addresses directly
if isPrivate {
endpointType = EndpointTypeInternalIP
} else {
endpointType = EndpointTypeExternalIP
}
}
} else {
// Address is a hostname
isInternalHostname := isInternalHostname(host)
if isInternalHostname {
endpointType = EndpointTypeInternalFQDN
} else {
endpointType = EndpointTypeExternalFQDN
}
}
return endpointType
}
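// Illustrative results of the rules above (addresses are placeholders):
//
//    DetectEndpointType("10.0.0.1:6379", false)         // EndpointTypeInternalIP
//    DetectEndpointType("203.0.113.10:6379", true)      // EndpointTypeExternalFQDN
//    DetectEndpointType("redis.example.com:6379", true)  // EndpointTypeExternalFQDN
//    DetectEndpointType("redis-1:6379", false)           // EndpointTypeInternalFQDN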
// isInternalHostname determines if a hostname appears to be internal/private.
// This is a heuristic based on common naming patterns.
func isInternalHostname(hostname string) bool {
// Convert to lowercase for comparison
hostname = strings.ToLower(hostname)
// Common internal hostname patterns
internalPatterns := []string{
"localhost",
".local",
".internal",
".corp",
".lan",
".intranet",
".private",
}
// Check for exact match or suffix match
for _, pattern := range internalPatterns {
if hostname == pattern || strings.HasSuffix(hostname, pattern) {
return true
}
}
// Check for unqualified, single-label hostnames (e.g., redis-1, db-server, etc.)
// If the hostname doesn't contain dots, it's likely internal
if !strings.Contains(hostname, ".") {
return true
}
// Default to external for fully qualified domain names
return false
}
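// Heuristic examples (illustrative): "localhost", "cache.internal" and the
// single-label "redis-1" are treated as internal; "db.example.com" is external.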

490
hitless/config_test.go Normal file
View File

@@ -0,0 +1,490 @@
package hitless
import (
"context"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/logging"
)
func TestConfig(t *testing.T) {
t.Run("DefaultConfig", func(t *testing.T) {
config := DefaultConfig()
// MaxWorkers should be 0 in default config (auto-calculated)
if config.MaxWorkers != 0 {
t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers)
}
// HandoffQueueSize should be 0 in default config (auto-calculated)
if config.HandoffQueueSize != 0 {
t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize)
}
if config.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s, got %v", config.RelaxedTimeout)
}
// Test configuration fields have proper defaults
if config.MaxHandoffRetries != 3 {
t.Errorf("Expected MaxHandoffRetries to be 3, got %d", config.MaxHandoffRetries)
}
// Circuit breaker defaults
if config.CircuitBreakerFailureThreshold != 5 {
t.Errorf("Expected CircuitBreakerFailureThreshold=5, got %d", config.CircuitBreakerFailureThreshold)
}
if config.CircuitBreakerResetTimeout != 60*time.Second {
t.Errorf("Expected CircuitBreakerResetTimeout=60s, got %v", config.CircuitBreakerResetTimeout)
}
if config.CircuitBreakerMaxRequests != 3 {
t.Errorf("Expected CircuitBreakerMaxRequests=3, got %d", config.CircuitBreakerMaxRequests)
}
if config.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout)
}
if config.PostHandoffRelaxedDuration != 0 {
t.Errorf("Expected PostHandoffRelaxedDuration to be 0 (auto-calculated), got %v", config.PostHandoffRelaxedDuration)
}
// Test that defaults are applied correctly
configWithDefaults := config.ApplyDefaultsWithPoolSize(100)
if configWithDefaults.PostHandoffRelaxedDuration != 20*time.Second {
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout) after applying defaults, got %v", configWithDefaults.PostHandoffRelaxedDuration)
}
})
t.Run("ConfigValidation", func(t *testing.T) {
// Valid config with applied defaults
config := DefaultConfig().ApplyDefaults()
if err := config.Validate(); err != nil {
t.Errorf("Default config with applied defaults should be valid: %v", err)
}
// Invalid worker configuration (negative MaxWorkers)
config = &Config{
RelaxedTimeout: 30 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: -1, // This should be invalid
HandoffQueueSize: 100,
PostHandoffRelaxedDuration: 10 * time.Second,
LogLevel: 1,
MaxHandoffRetries: 3, // Add required field
}
if err := config.Validate(); err != ErrInvalidHandoffWorkers {
t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err)
}
// Invalid HandoffQueueSize
config = DefaultConfig().ApplyDefaults()
config.HandoffQueueSize = -1
if err := config.Validate(); err != ErrInvalidHandoffQueueSize {
t.Errorf("Expected ErrInvalidHandoffQueueSize, got %v", err)
}
// Invalid PostHandoffRelaxedDuration
config = DefaultConfig().ApplyDefaults()
config.PostHandoffRelaxedDuration = -1 * time.Second
if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration {
t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err)
}
})
t.Run("ConfigClone", func(t *testing.T) {
original := DefaultConfig()
original.MaxWorkers = 20
original.HandoffQueueSize = 200
cloned := original.Clone()
if cloned.MaxWorkers != 20 {
t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers)
}
if cloned.HandoffQueueSize != 200 {
t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize)
}
// Modify original to ensure clone is independent
original.MaxWorkers = 2
if cloned.MaxWorkers != 20 {
t.Error("Clone should be independent of original")
}
})
}
func TestApplyDefaults(t *testing.T) {
t.Run("NilConfig", func(t *testing.T) {
var config *Config
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// With nil config, should get default config with auto-calculated workers
if result.MaxWorkers <= 0 {
t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers)
}
// HandoffQueueSize should be auto-calculated with hybrid scaling
workerBasedSize := result.MaxWorkers * 20
poolSize := 100 // Pool size passed to ApplyDefaultsWithPoolSize
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
}
})
t.Run("PartialConfig", func(t *testing.T) {
config := &Config{
MaxWorkers: 60, // Set this field explicitly (> poolSize/2 = 50)
// Leave other fields as zero values
}
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Should keep the explicitly set values when > poolSize/2
if result.MaxWorkers != 60 {
t.Errorf("Expected MaxWorkers to be 60 (explicitly set), got %d", result.MaxWorkers)
}
// Should apply default for unset fields (auto-calculated queue size with hybrid scaling)
workerBasedSize := result.MaxWorkers * 20
poolSize := 100 // Pool size passed to ApplyDefaultsWithPoolSize
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
}
// Test explicit queue size capping by 5x pool size
configWithLargeQueue := &Config{
MaxWorkers: 5,
HandoffQueueSize: 1000, // Much larger than 5x pool size
}
resultCapped := configWithLargeQueue.ApplyDefaultsWithPoolSize(20) // Small pool size
expectedCap := 20 * 5 // 5x pool size = 100
if resultCapped.HandoffQueueSize != expectedCap {
t.Errorf("Expected HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedCap, resultCapped.HandoffQueueSize)
}
// Test explicit queue size minimum enforcement
configWithSmallQueue := &Config{
MaxWorkers: 5,
HandoffQueueSize: 10, // Below minimum of 200
}
resultMinimum := configWithSmallQueue.ApplyDefaultsWithPoolSize(100) // Large pool size
if resultMinimum.HandoffQueueSize != 200 {
t.Errorf("Expected HandoffQueueSize to be enforced minimum (200), got %d", resultMinimum.HandoffQueueSize)
}
// Test that large explicit values are capped by 5x pool size
configWithVeryLargeQueue := &Config{
MaxWorkers: 5,
HandoffQueueSize: 1000, // Much larger than 5x pool size
}
resultVeryLarge := configWithVeryLargeQueue.ApplyDefaultsWithPoolSize(100) // Pool size 100
expectedVeryLargeCap := 100 * 5 // 5x pool size = 500
if resultVeryLarge.HandoffQueueSize != expectedVeryLargeCap {
t.Errorf("Expected very large HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedVeryLargeCap, resultVeryLarge.HandoffQueueSize)
}
if result.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
}
if result.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", result.HandoffTimeout)
}
})
t.Run("ZeroValues", func(t *testing.T) {
config := &Config{
MaxWorkers: 0, // Zero value should get auto-calculated defaults
HandoffQueueSize: 0, // Zero value should get default
RelaxedTimeout: 0, // Zero value should get default
LogLevel: 0, // Zero is valid for LogLevel (errors only)
}
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Zero values should get auto-calculated defaults
if result.MaxWorkers <= 0 {
t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers)
}
// HandoffQueueSize should be auto-calculated with hybrid scaling
workerBasedSize := result.MaxWorkers * 20
poolSize := 100 // Pool size passed to ApplyDefaultsWithPoolSize
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
}
if result.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
}
// LogLevel 0 should be preserved (it's a valid value)
if result.LogLevel != 0 {
t.Errorf("Expected LogLevel to be 0 (preserved), got %d", result.LogLevel)
}
})
}
func TestProcessorWithConfig(t *testing.T) {
t.Run("ProcessorUsesConfigValues", func(t *testing.T) {
config := &Config{
MaxWorkers: 5,
HandoffQueueSize: 50,
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 5 * time.Second,
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// The processor should be created successfully with custom config
if processor == nil {
t.Error("Processor should be created with custom config")
}
})
t.Run("ProcessorWithPartialConfig", func(t *testing.T) {
config := &Config{
MaxWorkers: 7, // Only set worker field
// Other fields will get defaults
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Should work with partial config (defaults applied)
if processor == nil {
t.Error("Processor should be created with partial config")
}
})
t.Run("ProcessorWithNilConfig", func(t *testing.T) {
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Should use default config when nil is passed
if processor == nil {
t.Error("Processor should be created with nil config (using defaults)")
}
})
}
func TestIntegrationWithApplyDefaults(t *testing.T) {
t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) {
// Create a partial config with only some fields set
partialConfig := &Config{
MaxWorkers: 15, // Custom value (below poolSize/2, so the enforced minimum applies)
LogLevel: logging.LogLevelInfo, // Custom value
// Other fields left as zero values - should get defaults
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
// Create processor - should apply defaults to missing fields
processor := NewPoolHook(baseDialer, "tcp", partialConfig, nil)
defer processor.Shutdown(context.Background())
// Processor should be created successfully
if processor == nil {
t.Error("Processor should be created with partial config")
}
// Test that the ApplyDefaults method worked correctly by creating the same config
// and applying defaults manually
expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Explicit MaxWorkers below poolSize/2 is raised to the enforced minimum
if expectedConfig.MaxWorkers != 50 { // max(poolSize/2, 15) = max(50, 15) = 50
t.Errorf("Expected MaxWorkers to be 50, got %d", expectedConfig.MaxWorkers)
}
if expectedConfig.LogLevel != 2 {
t.Errorf("Expected LogLevel to be 2, got %d", expectedConfig.LogLevel)
}
// Should apply defaults for missing fields (auto-calculated queue size with hybrid scaling)
workerBasedSize := expectedConfig.MaxWorkers * 20
poolSize := 100 // Pool size passed to ApplyDefaultsWithPoolSize
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if expectedConfig.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, expectedConfig.HandoffQueueSize)
}
// Test that queue size is always capped by 5x pool size
if expectedConfig.HandoffQueueSize > poolSize*5 {
t.Errorf("HandoffQueueSize (%d) should never exceed 5x pool size (%d)",
expectedConfig.HandoffQueueSize, poolSize*5)
}
if expectedConfig.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", expectedConfig.RelaxedTimeout)
}
if expectedConfig.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout)
}
if expectedConfig.PostHandoffRelaxedDuration != 20*time.Second {
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout), got %v", expectedConfig.PostHandoffRelaxedDuration)
}
})
}
func TestEnhancedConfigValidation(t *testing.T) {
t.Run("ValidateFields", func(t *testing.T) {
config := DefaultConfig()
config = config.ApplyDefaultsWithPoolSize(100) // Apply defaults with pool size 100
// Should pass validation with default values
if err := config.Validate(); err != nil {
t.Errorf("Default config should be valid, got error: %v", err)
}
// Test invalid MaxHandoffRetries
config.MaxHandoffRetries = 0
if err := config.Validate(); err == nil {
t.Error("Expected validation error for MaxHandoffRetries = 0")
}
config.MaxHandoffRetries = 11
if err := config.Validate(); err == nil {
t.Error("Expected validation error for MaxHandoffRetries = 11")
}
config.MaxHandoffRetries = 3 // Reset to valid value
// Test circuit breaker validation
config.CircuitBreakerFailureThreshold = 0
if err := config.Validate(); err != ErrInvalidCircuitBreakerFailureThreshold {
t.Errorf("Expected ErrInvalidCircuitBreakerFailureThreshold, got %v", err)
}
config.CircuitBreakerFailureThreshold = 5 // Reset to valid value
config.CircuitBreakerResetTimeout = -1 * time.Second
if err := config.Validate(); err != ErrInvalidCircuitBreakerResetTimeout {
t.Errorf("Expected ErrInvalidCircuitBreakerResetTimeout, got %v", err)
}
config.CircuitBreakerResetTimeout = 60 * time.Second // Reset to valid value
config.CircuitBreakerMaxRequests = 0
if err := config.Validate(); err != ErrInvalidCircuitBreakerMaxRequests {
t.Errorf("Expected ErrInvalidCircuitBreakerMaxRequests, got %v", err)
}
config.CircuitBreakerMaxRequests = 3 // Reset to valid value
// Should pass validation again
if err := config.Validate(); err != nil {
t.Errorf("Config should be valid after reset, got error: %v", err)
}
})
}
func TestConfigClone(t *testing.T) {
original := DefaultConfig()
original.MaxHandoffRetries = 7
original.HandoffTimeout = 8 * time.Second
cloned := original.Clone()
// Test that values are copied
if cloned.MaxHandoffRetries != 7 {
t.Errorf("Expected cloned MaxHandoffRetries to be 7, got %d", cloned.MaxHandoffRetries)
}
if cloned.HandoffTimeout != 8*time.Second {
t.Errorf("Expected cloned HandoffTimeout to be 8s, got %v", cloned.HandoffTimeout)
}
// Test that modifying clone doesn't affect original
cloned.MaxHandoffRetries = 10
if original.MaxHandoffRetries != 7 {
t.Errorf("Modifying clone should not affect original, original MaxHandoffRetries changed to %d", original.MaxHandoffRetries)
}
}
func TestMaxWorkersLogic(t *testing.T) {
t.Run("AutoCalculatedMaxWorkers", func(t *testing.T) {
testCases := []struct {
poolSize int
expectedWorkers int
description string
}{
{6, 3, "Small pool: min(6/2, max(10, 6/3)) = min(3, max(10, 2)) = min(3, 10) = 3"},
{15, 7, "Medium pool: min(15/2, max(10, 15/3)) = min(7, max(10, 5)) = min(7, 10) = 7"},
{30, 10, "Large pool: min(30/2, max(10, 30/3)) = min(15, max(10, 10)) = min(15, 10) = 10"},
{60, 20, "Very large pool: min(60/2, max(10, 60/3)) = min(30, max(10, 20)) = min(30, 20) = 20"},
{120, 40, "Huge pool: min(120/2, max(10, 120/3)) = min(60, max(10, 40)) = min(60, 40) = 40"},
}
for _, tc := range testCases {
config := &Config{} // MaxWorkers = 0 (not set)
result := config.ApplyDefaultsWithPoolSize(tc.poolSize)
if result.MaxWorkers != tc.expectedWorkers {
t.Errorf("PoolSize=%d: expected MaxWorkers=%d, got %d (%s)",
tc.poolSize, tc.expectedWorkers, result.MaxWorkers, tc.description)
}
}
})
t.Run("ExplicitlySetMaxWorkers", func(t *testing.T) {
testCases := []struct {
setValue int
expectedWorkers int
description string
}{
{1, 50, "Set 1: max(poolSize/2, 1) = max(50, 1) = 50 (enforced minimum)"},
{5, 50, "Set 5: max(poolSize/2, 5) = max(50, 5) = 50 (enforced minimum)"},
{8, 50, "Set 8: max(poolSize/2, 8) = max(50, 8) = 50 (enforced minimum)"},
{10, 50, "Set 10: max(poolSize/2, 10) = max(50, 10) = 50 (enforced minimum)"},
{15, 50, "Set 15: max(poolSize/2, 15) = max(50, 15) = 50 (enforced minimum)"},
{60, 60, "Set 60: max(poolSize/2, 60) = max(50, 60) = 60 (respects user choice)"},
}
for _, tc := range testCases {
config := &Config{
MaxWorkers: tc.setValue, // Explicitly set
}
result := config.ApplyDefaultsWithPoolSize(100) // Pool size 100 enforces a minimum of poolSize/2 = 50 workers
if result.MaxWorkers != tc.expectedWorkers {
t.Errorf("Set MaxWorkers=%d: expected %d, got %d (%s)",
tc.setValue, tc.expectedWorkers, result.MaxWorkers, tc.description)
}
}
})
}

105
hitless/errors.go Normal file
View File

@@ -0,0 +1,105 @@
package hitless
import (
"errors"
"fmt"
"time"
)
// Configuration errors
var (
ErrInvalidRelaxedTimeout = errors.New("hitless: relaxed timeout must be greater than 0")
ErrInvalidHandoffTimeout = errors.New("hitless: handoff timeout must be greater than 0")
ErrInvalidHandoffWorkers = errors.New("hitless: MaxWorkers must be greater than or equal to 0")
ErrInvalidHandoffQueueSize = errors.New("hitless: handoff queue size must be greater than or equal to 0")
ErrInvalidPostHandoffRelaxedDuration = errors.New("hitless: post-handoff relaxed duration must be greater than or equal to 0")
ErrInvalidLogLevel = errors.New("hitless: log level must be LogLevelError (0), LogLevelWarn (1), LogLevelInfo (2), or LogLevelDebug (3)")
ErrInvalidEndpointType = errors.New("hitless: invalid endpoint type")
ErrInvalidMaintNotifications = errors.New("hitless: invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')")
ErrMaxHandoffRetriesReached = errors.New("hitless: max handoff retries reached")
// Configuration validation errors
ErrInvalidHandoffRetries = errors.New("hitless: MaxHandoffRetries must be between 1 and 10")
)
// Integration errors
var (
ErrInvalidClient = errors.New("hitless: invalid client type")
)
// Handoff errors
var (
ErrHandoffQueueFull = errors.New("hitless: handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration")
)
// Notification errors
var (
ErrInvalidNotification = errors.New("hitless: invalid notification format")
)
// connection handoff errors
var (
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
// and should not be used until the handoff is complete
ErrConnectionMarkedForHandoff = errors.New("hitless: connection marked for handoff")
// ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff
ErrConnectionInvalidHandoffState = errors.New("hitless: connection is in invalid state for handoff")
)
// general errors
var (
ErrShutdown = errors.New("hitless: shutdown")
)
// circuit breaker errors
var (
ErrCircuitBreakerOpen = errors.New("hitless: circuit breaker is open, failing fast")
)
// CircuitBreakerError provides detailed context for circuit breaker failures
type CircuitBreakerError struct {
Endpoint string
State string
Failures int64
LastFailure time.Time
NextAttempt time.Time
Message string
}
func (e *CircuitBreakerError) Error() string {
if e.NextAttempt.IsZero() {
return fmt.Sprintf("hitless: circuit breaker %s for %s (failures: %d, last: %v): %s",
e.State, e.Endpoint, e.Failures, e.LastFailure, e.Message)
}
return fmt.Sprintf("hitless: circuit breaker %s for %s (failures: %d, last: %v, next attempt: %v): %s",
e.State, e.Endpoint, e.Failures, e.LastFailure, e.NextAttempt, e.Message)
}
// HandoffError provides detailed context for connection handoff failures
type HandoffError struct {
ConnectionID uint64
SourceEndpoint string
TargetEndpoint string
Attempt int
MaxAttempts int
Duration time.Duration
FinalError error
Message string
}
func (e *HandoffError) Error() string {
return fmt.Sprintf("hitless: handoff failed for conn[%d] %s→%s (attempt %d/%d, duration: %v): %s",
e.ConnectionID, e.SourceEndpoint, e.TargetEndpoint,
e.Attempt, e.MaxAttempts, e.Duration, e.Message)
}
func (e *HandoffError) Unwrap() error {
return e.FinalError
}
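// Callers can inspect these failures with the standard errors helpers; a sketch:
//
//    var hoErr *HandoffError
//    if errors.As(err, &hoErr) {
//        // hoErr.TargetEndpoint, hoErr.Attempt and errors.Unwrap(err) give the details
//    }
//    if errors.Is(err, ErrCircuitBreakerOpen) {
//        // endpoint is failing fast; back off before retrying
//    }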
// circuit breaker configuration errors
var (
ErrInvalidCircuitBreakerFailureThreshold = errors.New("hitless: circuit breaker failure threshold must be >= 1")
ErrInvalidCircuitBreakerResetTimeout = errors.New("hitless: circuit breaker reset timeout must be >= 0")
ErrInvalidCircuitBreakerMaxRequests = errors.New("hitless: circuit breaker max requests must be >= 1")
)

100
hitless/example_hooks.go Normal file
View File

@@ -0,0 +1,100 @@
package hitless
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// contextKey is a custom type for context keys to avoid collisions
type contextKey string
const (
startTimeKey contextKey = "notif_hitless_start_time"
)
// MetricsHook collects metrics about notification processing.
type MetricsHook struct {
NotificationCounts map[string]int64
ProcessingTimes map[string]time.Duration
ErrorCounts map[string]int64
HandoffCounts int64 // Total handoffs initiated
HandoffSuccesses int64 // Successful handoffs
HandoffFailures int64 // Failed handoffs
}
// NewMetricsHook creates a new metrics collection hook.
func NewMetricsHook() *MetricsHook {
return &MetricsHook{
NotificationCounts: make(map[string]int64),
ProcessingTimes: make(map[string]time.Duration),
ErrorCounts: make(map[string]int64),
}
}
// PreHook records the start time for processing metrics.
func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
mh.NotificationCounts[notificationType]++
// Log connection information if available
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
internal.Logger.Printf(ctx, "hitless: metrics hook processing %s notification on conn[%d]", notificationType, conn.GetID())
}
// Record the start time for duration calculation. The derived context is discarded
// because PreHook cannot return a modified context, so PostHook will only see this
// value if the caller propagates the same context.
startTime := time.Now()
_ = context.WithValue(ctx, startTimeKey, startTime) // derived context intentionally discarded
return notification, true
}
// PostHook records processing completion and any errors.
func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
// Calculate processing duration
if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
duration := time.Since(startTime)
mh.ProcessingTimes[notificationType] = duration
}
// Record errors
if result != nil {
mh.ErrorCounts[notificationType]++
// Log error details with connection information
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
internal.Logger.Printf(ctx, "hitless: metrics hook recorded error for %s notification on conn[%d]: %v", notificationType, conn.GetID(), result)
}
}
}
// GetMetrics returns a summary of collected metrics.
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
return map[string]interface{}{
"notification_counts": mh.NotificationCounts,
"processing_times": mh.ProcessingTimes,
"error_counts": mh.ErrorCounts,
}
}
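// Wiring sketch (illustrative; assumes a *HitlessManager named manager has
// already been created elsewhere):
//
//    metrics := NewMetricsHook()
//    manager.AddNotificationHook(metrics)
//    // ... after some traffic ...
//    fmt.Printf("%+v\n", metrics.GetMetrics())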
// ExampleCircuitBreakerMonitor demonstrates how to monitor circuit breaker status
func ExampleCircuitBreakerMonitor(poolHook *PoolHook) {
// Get circuit breaker statistics
stats := poolHook.GetCircuitBreakerStats()
for _, stat := range stats {
fmt.Printf("Circuit Breaker for %s:\n", stat.Endpoint)
fmt.Printf(" State: %s\n", stat.State)
fmt.Printf(" Failures: %d\n", stat.Failures)
fmt.Printf(" Last Failure: %v\n", stat.LastFailureTime)
fmt.Printf(" Last Success: %v\n", stat.LastSuccessTime)
// Alert if circuit breaker is open
if stat.State.String() == "open" {
fmt.Printf(" ⚠️ ALERT: Circuit breaker is OPEN for %s\n", stat.Endpoint)
}
}
}

468
hitless/handoff_worker.go Normal file
View File

@@ -0,0 +1,468 @@
package hitless
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
)
// handoffWorkerManager manages background workers and queue for connection handoffs
type handoffWorkerManager struct {
// Event-driven handoff support
handoffQueue chan HandoffRequest // Queue for handoff requests
shutdown chan struct{} // Shutdown signal
shutdownOnce sync.Once // Ensure clean shutdown
workerWg sync.WaitGroup // Track worker goroutines
// On-demand worker management
maxWorkers int
activeWorkers atomic.Int32
workerTimeout time.Duration // How long workers wait for work before exiting
workersScaling atomic.Bool
// Simple state tracking
pending sync.Map // map[uint64]int64 (connID -> seqID)
// Configuration for the hitless upgrade
config *Config
// Pool hook reference for handoff processing
poolHook *PoolHook
// Circuit breaker manager for endpoint failure handling
circuitBreakerManager *CircuitBreakerManager
}
// newHandoffWorkerManager creates a new handoff worker manager
func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
return &handoffWorkerManager{
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
shutdown: make(chan struct{}),
maxWorkers: config.MaxWorkers,
activeWorkers: atomic.Int32{}, // Start with no workers - create on demand
workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
config: config,
poolHook: poolHook,
circuitBreakerManager: newCircuitBreakerManager(config),
}
}
// getCurrentWorkers returns the current number of active workers (for testing)
func (hwm *handoffWorkerManager) getCurrentWorkers() int {
return int(hwm.activeWorkers.Load())
}
// getPendingMap returns the pending map for testing purposes
func (hwm *handoffWorkerManager) getPendingMap() *sync.Map {
return &hwm.pending
}
// getMaxWorkers returns the max workers for testing purposes
func (hwm *handoffWorkerManager) getMaxWorkers() int {
return hwm.maxWorkers
}
// getHandoffQueue returns the handoff queue for testing purposes
func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
return hwm.handoffQueue
}
// getCircuitBreakerStats returns circuit breaker statistics for monitoring
func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats {
return hwm.circuitBreakerManager.GetAllStats()
}
// resetCircuitBreakers resets all circuit breakers (useful for testing)
func (hwm *handoffWorkerManager) resetCircuitBreakers() {
hwm.circuitBreakerManager.Reset()
}
// isHandoffPending returns true if the given connection has a pending handoff
func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
_, pending := hwm.pending.Load(conn.GetID())
return pending
}
// ensureWorkerAvailable ensures at least one worker is available to process requests
// Creates a new worker if needed and under the max limit
func (hwm *handoffWorkerManager) ensureWorkerAvailable() {
select {
case <-hwm.shutdown:
return
default:
if hwm.workersScaling.CompareAndSwap(false, true) {
defer hwm.workersScaling.Store(false)
// Check if we need a new worker
currentWorkers := hwm.activeWorkers.Load()
workersWas := currentWorkers
for currentWorkers < int32(hwm.maxWorkers) {
hwm.workerWg.Add(1)
go hwm.onDemandWorker()
currentWorkers++
}
// workersWas is always <= currentWorkers. currentWorkers has reached maxWorkers here,
// but a worker may have exited while we were creating new ones, so only add the
// number of workers we actually started (currentWorkers - workersWas).
hwm.activeWorkers.Add(currentWorkers - workersWas)
}
}
}
// onDemandWorker processes handoff requests and exits when idle
func (hwm *handoffWorkerManager) onDemandWorker() {
defer func() {
// Handle panics to ensure proper cleanup
if r := recover(); r != nil {
internal.Logger.Printf(context.Background(),
"hitless: worker panic recovered: %v", r)
}
// Decrement active worker count when exiting
hwm.activeWorkers.Add(-1)
hwm.workerWg.Done()
}()
// Create reusable timer to prevent timer leaks
timer := time.NewTimer(hwm.workerTimeout)
defer timer.Stop()
for {
// Reset timer for next iteration
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(hwm.workerTimeout)
select {
case <-hwm.shutdown:
return
case <-timer.C:
// Worker has been idle for too long, exit to save resources
if hwm.config != nil && hwm.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: worker exiting due to inactivity timeout (%v)", hwm.workerTimeout)
}
return
case request := <-hwm.handoffQueue:
// Check for shutdown before processing
select {
case <-hwm.shutdown:
// Clean up the request before exiting
hwm.pending.Delete(request.ConnID)
return
default:
// Process the request
hwm.processHandoffRequest(request)
}
}
}
}
// processHandoffRequest processes a single handoff request
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
// Remove from pending map
defer hwm.pending.Delete(request.Conn.GetID())
internal.Logger.Printf(context.Background(), "hitless: conn[%d] Processing handoff request start", request.Conn.GetID())
// Create a context with handoff timeout from config
handoffTimeout := 15 * time.Second // Default timeout
if hwm.config != nil && hwm.config.HandoffTimeout > 0 {
handoffTimeout = hwm.config.HandoffTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
defer cancel()
// Create a context that also respects the shutdown signal
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
defer shutdownCancel()
// Monitor shutdown signal in a separate goroutine
go func() {
select {
case <-hwm.shutdown:
shutdownCancel()
case <-shutdownCtx.Done():
}
}()
// Perform the handoff with cancellable context
shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn)
minRetryBackoff := 500 * time.Millisecond
if err != nil {
if shouldRetry {
now := time.Now()
deadline, ok := shutdownCtx.Deadline()
thirdOfTimeout := handoffTimeout / 3
if !ok || deadline.Before(now) {
// wait a third of the timeout before retrying if there is no deadline or it has passed
deadline = now.Add(thirdOfTimeout)
}
afterTime := deadline.Sub(now)
if afterTime < minRetryBackoff {
afterTime = minRetryBackoff
}
internal.Logger.Printf(context.Background(), "Handoff failed for conn[%d] WILL RETRY After %v: %v", request.ConnID, afterTime, err)
time.AfterFunc(afterTime, func() {
if err := hwm.queueHandoff(request.Conn); err != nil {
internal.Logger.Printf(context.Background(), "can't queue handoff for retry: %v", err)
hwm.closeConnFromRequest(context.Background(), request, err)
}
})
return
} else {
go hwm.closeConnFromRequest(ctx, request, err)
}
// Clear handoff state if not returned for retry
seqID := request.Conn.GetMovingSeqID()
connID := request.Conn.GetID()
if hwm.poolHook.hitlessManager != nil {
hwm.poolHook.hitlessManager.UntrackOperationWithConnID(seqID, connID)
}
}
}
// queueHandoff queues a handoff request for processing
// If an error is returned, the connection will be removed from the pool.
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
// Get handoff info atomically to prevent race conditions
shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
if !shouldHandoff {
return errors.New("connection is not marked for handoff")
}
// Create handoff request with atomically retrieved data
request := HandoffRequest{
Conn: conn,
ConnID: conn.GetID(),
Endpoint: endpoint,
SeqID: seqID,
Pool: hwm.poolHook.pool, // Include pool for connection removal on failure
}
select {
// priority to shutdown
case <-hwm.shutdown:
return ErrShutdown
default:
select {
case <-hwm.shutdown:
return ErrShutdown
case hwm.handoffQueue <- request:
// Store in pending map
hwm.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
hwm.ensureWorkerAvailable()
return nil
default:
select {
case <-hwm.shutdown:
return ErrShutdown
case hwm.handoffQueue <- request:
// Store in pending map
hwm.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
hwm.ensureWorkerAvailable()
return nil
case <-time.After(100 * time.Millisecond): // give workers a chance to process
// Queue is full - log and attempt scaling
queueLen := len(hwm.handoffQueue)
queueCap := cap(hwm.handoffQueue)
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(context.Background(),
"hitless: handoff queue is full (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration",
queueLen, queueCap)
}
}
}
}
// Ensure we have workers available to handle the load
hwm.ensureWorkerAvailable()
return ErrHandoffQueueFull
}
// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete
func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error {
hwm.shutdownOnce.Do(func() {
close(hwm.shutdown)
// workers will exit when they finish their current request
// Shutdown circuit breaker manager cleanup goroutine
if hwm.circuitBreakerManager != nil {
hwm.circuitBreakerManager.Shutdown()
}
})
// Wait for workers to complete
done := make(chan struct{})
go func() {
hwm.workerWg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// performConnectionHandoff performs the actual connection handoff
// When an error is returned, the handoff should be retried only if shouldRetry is true
// (e.g. not for ErrMaxHandoffRetriesReached or an open circuit breaker)
func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) {
// Resolve the connection ID and target endpoint for this handoff
connID := conn.GetID()
newEndpoint := conn.GetHandoffEndpoint()
if newEndpoint == "" {
return false, ErrConnectionInvalidHandoffState
}
// Use circuit breaker to protect against failing endpoints
circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint)
// Check if circuit breaker is open before attempting handoff
if circuitBreaker.IsOpen() {
internal.Logger.Printf(ctx, "hitless: conn[%d] handoff to %s failed fast due to circuit breaker", connID, newEndpoint)
return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open
}
// Perform the handoff
shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID)
// Update circuit breaker based on result
if err != nil {
// Only track dial/network errors in circuit breaker, not initialization errors
if shouldRetry {
circuitBreaker.recordFailure()
}
return shouldRetry, err
}
// Success - record in circuit breaker
circuitBreaker.recordSuccess()
return false, nil
}
// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) {
retries := conn.IncrementAndGetHandoffRetries(1)
internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", connID, retries, newEndpoint, conn.RemoteAddr().String())
maxRetries := 3 // Default fallback
if hwm.config != nil {
maxRetries = hwm.config.MaxHandoffRetries
}
if retries > maxRetries {
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: reached max retries (%d) for handoff of conn[%d] to %s",
maxRetries, connID, newEndpoint)
}
// won't retry on ErrMaxHandoffRetriesReached
return false, ErrMaxHandoffRetriesReached
}
// Create endpoint-specific dialer
endpointDialer := hwm.createEndpointDialer(newEndpoint)
// Create new connection to the new endpoint
newNetConn, err := endpointDialer(ctx)
if err != nil {
internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", connID, newEndpoint, err)
// hitless: will retry
// Maybe a network error - retry after a delay
return true, err
}
// Get the old connection
oldConn := conn.GetNetConn()
// Apply relaxed timeout to the new connection for the configured post-handoff duration
// This gives the new connection more time to handle operations during cluster transition
// Setting this here (before initing the connection) ensures that the connection is going
// to use the relaxed timeout for the first operation (auth/ACL select)
if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 {
relaxedTimeout := hwm.config.RelaxedTimeout
// Set relaxed timeout with deadline - no background goroutine needed
deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration)
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
if hwm.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: conn[%d] applied post-handoff relaxed timeout (%v) until %v",
connID, relaxedTimeout, deadline.Format("15:04:05.000"))
}
}
// Replace the connection and execute initialization
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
if err != nil {
// hitless: won't retry
// Initialization failed - remove the connection
return false, err
}
defer func() {
if oldConn != nil {
oldConn.Close()
}
}()
conn.ClearHandoffState()
internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", connID, newEndpoint)
return false, nil
}
// createEndpointDialer creates a dialer function that connects to a specific endpoint
func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
// Parse endpoint to extract host and port
host, port, err := net.SplitHostPort(endpoint)
if err != nil {
// If no port specified, assume default Redis port
host = endpoint
if port == "" {
port = "6379"
}
}
// Use the base dialer to connect to the new endpoint
return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port))
}
}
// closeConnFromRequest closes the connection and logs the reason
func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
pooler := request.Pool
conn := request.Conn
if pooler != nil {
pooler.Remove(ctx, conn, err)
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: removed conn[%d] from pool due: %v",
conn.GetID(), err)
}
} else {
conn.Close()
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: no pool provided for conn[%d], cannot remove due to: %v",
conn.GetID(), err)
}
}
}

318
hitless/hitless_manager.go Normal file
View File

@@ -0,0 +1,318 @@
package hitless
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// Push notification type constants for hitless upgrades
const (
NotificationMoving = "MOVING"
NotificationMigrating = "MIGRATING"
NotificationMigrated = "MIGRATED"
NotificationFailingOver = "FAILING_OVER"
NotificationFailedOver = "FAILED_OVER"
)
// hitlessNotificationTypes contains all notification types that hitless upgrades handles
var hitlessNotificationTypes = []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
// NotificationHook is called before and after notification processing
// PreHook can modify the notification and return false to skip processing
// PostHook is called after successful processing
type NotificationHook interface {
PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool)
PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error)
}
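// Minimal hook sketch (illustrative only): a filter that drops FAILED_OVER
// notifications and ignores results.
//
//    type dropFailedOverHook struct{}
//
//    func (dropFailedOverHook) PreHook(ctx context.Context, nctx push.NotificationHandlerContext,
//        notificationType string, notification []interface{}) ([]interface{}, bool) {
//        return notification, notificationType != NotificationFailedOver
//    }
//
//    func (dropFailedOverHook) PostHook(ctx context.Context, nctx push.NotificationHandlerContext,
//        notificationType string, notification []interface{}, result error) {
//    }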
// MovingOperationKey provides a unique key for tracking MOVING operations
// that combines sequence ID with connection identifier to handle duplicate
// sequence IDs across multiple connections to the same node.
type MovingOperationKey struct {
SeqID int64 // Sequence ID from MOVING notification
ConnID uint64 // Unique connection identifier
}
// String returns a string representation of the key for debugging
func (k MovingOperationKey) String() string {
return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
}
// HitlessManager provides a simplified hitless upgrade functionality with hooks and atomic state.
type HitlessManager struct {
client interfaces.ClientInterface
config *Config
options interfaces.OptionsInterface
pool pool.Pooler
// MOVING operation tracking - using sync.Map for better concurrent performance
activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation
// Atomic state tracking - no locks needed for state queries
activeOperationCount atomic.Int64 // Number of active operations
closed atomic.Bool // Manager closed state
// Notification hooks for extensibility
hooks []NotificationHook
hooksMu sync.RWMutex // Protects hooks slice
poolHooksRef *PoolHook
}
// MovingOperation tracks an active MOVING operation.
type MovingOperation struct {
SeqID int64
NewEndpoint string
StartTime time.Time
Deadline time.Time
}
// NewHitlessManager creates a new simplified hitless manager.
func NewHitlessManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*HitlessManager, error) {
if client == nil {
return nil, ErrInvalidClient
}
hm := &HitlessManager{
client: client,
pool: pool,
options: client.GetOptions(),
config: config.Clone(),
hooks: make([]NotificationHook, 0),
}
// Set up push notification handling
if err := hm.setupPushNotifications(); err != nil {
return nil, err
}
return hm, nil
}
// InitPoolHook creates a pool hook with the given base dialer and registers it with the pool.
func (hm *HitlessManager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
poolHook := hm.createPoolHook(baseDialer)
hm.pool.AddPoolHook(poolHook)
}
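// Wiring sketch (illustrative; in go-redis this is done internally by the
// client, and client/connPool/netDialer are placeholders for values the
// caller already has):
//
//    manager, err := NewHitlessManager(client, connPool, DefaultConfig())
//    if err != nil {
//        // handle error
//    }
//    manager.InitPoolHook(netDialer)
//    defer manager.Close()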
// setupPushNotifications sets up push notification handling by registering with the client's processor.
func (hm *HitlessManager) setupPushNotifications() error {
processor := hm.client.GetPushProcessor()
if processor == nil {
return ErrInvalidClient // Client doesn't support push notifications
}
// Create our notification handler
handler := &NotificationHandler{manager: hm}
// Register handlers for all hitless upgrade notifications with the client's processor
for _, notificationType := range hitlessNotificationTypes {
if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
return fmt.Errorf("failed to register handler for %s: %w", notificationType, err)
}
}
return nil
}
// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Create MOVING operation record
movingOp := &MovingOperation{
SeqID: seqID,
NewEndpoint: newEndpoint,
StartTime: time.Now(),
Deadline: deadline,
}
// Use LoadOrStore for atomic check-and-set operation
if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
// Duplicate MOVING notification, ignore
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Duplicate MOVING operation ignored: %s", connID, seqID, key.String())
}
return nil
}
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Tracking MOVING operation: %s", connID, seqID, key.String())
}
// Increment active operation count atomically
hm.activeOperationCount.Add(1)
return nil
}
// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64) {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Remove from active operations atomically
if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Untracking MOVING operation: %s", connID, seqID, key.String())
}
// Decrement active operation count only if operation existed
hm.activeOperationCount.Add(-1)
} else {
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Operation not found for untracking: %s", connID, seqID, key.String())
}
}
}
// GetActiveMovingOperations returns active operations with composite keys.
// WARNING: This method creates a new map and copies all operations on every call.
// Use sparingly, especially in hot paths or high-frequency logging.
func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
result := make(map[MovingOperationKey]*MovingOperation)
// Iterate over sync.Map to build result
hm.activeMovingOps.Range(func(key, value interface{}) bool {
k := key.(MovingOperationKey)
op := value.(*MovingOperation)
// Create a copy to avoid sharing references
result[k] = &MovingOperation{
SeqID: op.SeqID,
NewEndpoint: op.NewEndpoint,
StartTime: op.StartTime,
Deadline: op.Deadline,
}
return true // Continue iteration
})
return result
}
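// Monitoring sketch (illustrative; manager is an existing *HitlessManager):
//
//    if manager.IsHandoffInProgress() {
//        for key, op := range manager.GetActiveMovingOperations() {
//            fmt.Printf("%s -> %s (deadline %s)\n", key, op.NewEndpoint, op.Deadline)
//        }
//    }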
// IsHandoffInProgress returns true if any handoff is in progress.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) IsHandoffInProgress() bool {
return hm.activeOperationCount.Load() > 0
}
// GetActiveOperationCount returns the number of active operations.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) GetActiveOperationCount() int64 {
return hm.activeOperationCount.Load()
}
// Close closes the hitless manager.
func (hm *HitlessManager) Close() error {
// Use atomic operation for thread-safe close check
if !hm.closed.CompareAndSwap(false, true) {
return nil // Already closed
}
// Shutdown the pool hook if it exists
if hm.poolHooksRef != nil {
// Use a timeout to prevent hanging indefinitely
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := hm.poolHooksRef.Shutdown(shutdownCtx)
if err != nil {
// could not shut down the pool hook; revert the closed state
hm.closed.Store(false)
return err
}
// Remove the pool hook from the pool
if hm.pool != nil {
hm.pool.RemovePoolHook(hm.poolHooksRef)
}
}
// Clear all active operations
hm.activeMovingOps.Range(func(key, value interface{}) bool {
hm.activeMovingOps.Delete(key)
return true
})
// Reset counter
hm.activeOperationCount.Store(0)
return nil
}
// GetState returns current state using atomic counter for lock-free operation.
func (hm *HitlessManager) GetState() State {
if hm.activeOperationCount.Load() > 0 {
return StateMoving
}
return StateIdle
}
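// logActiveOperations is a minimal monitoring sketch, assuming only the accessors
// defined above: it logs a snapshot of in-flight MOVING operations. Note that
// GetActiveMovingOperations copies the map, so avoid calling this from hot paths.
func logActiveOperations(ctx context.Context, hm *HitlessManager) {
if hm.GetState() != StateMoving {
return
}
internal.Logger.Printf(ctx, "hitless: %d active MOVING operation(s)", hm.GetActiveOperationCount())
for key, op := range hm.GetActiveMovingOperations() {
internal.Logger.Printf(ctx, "hitless: %s -> %s (deadline in %v)", key.String(), op.NewEndpoint, time.Until(op.Deadline))
}
}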
// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
currentNotification := notification
for _, hook := range hm.hooks {
modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification)
if !shouldContinue {
return modifiedNotification, false
}
currentNotification = modifiedNotification
}
return currentNotification, true
}
// processPostHooks calls all post-hooks with the processing result.
func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
for _, hook := range hm.hooks {
hook.PostHook(ctx, notificationCtx, notificationType, notification, result)
}
}
// createPoolHook creates a pool hook with this manager already set.
func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
if hm.poolHooksRef != nil {
return hm.poolHooksRef
}
// Get pool size and network from client options for better worker defaults,
// guarding against nil options
poolSize := 0
network := "tcp" // assumed fallback when no options are available
if hm.options != nil {
poolSize = hm.options.GetPoolSize()
network = hm.options.GetNetwork()
}
hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, network, hm.config, hm, poolSize)
hm.poolHooksRef.SetPool(hm.pool)
return hm.poolHooksRef
}
func (hm *HitlessManager) AddNotificationHook(notificationHook NotificationHook) {
hm.hooksMu.Lock()
defer hm.hooksMu.Unlock()
hm.hooks = append(hm.hooks, notificationHook)
}

View File

@@ -0,0 +1,260 @@
package hitless
import (
"context"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/interfaces"
)
// MockClient implements interfaces.ClientInterface for testing
type MockClient struct {
options interfaces.OptionsInterface
}
func (mc *MockClient) GetOptions() interfaces.OptionsInterface {
return mc.options
}
func (mc *MockClient) GetPushProcessor() interfaces.NotificationProcessor {
return &MockPushProcessor{}
}
// MockPushProcessor implements interfaces.NotificationProcessor for testing
type MockPushProcessor struct{}
func (mpp *MockPushProcessor) RegisterHandler(notificationType string, handler interface{}, protected bool) error {
return nil
}
func (mpp *MockPushProcessor) UnregisterHandler(pushNotificationName string) error {
return nil
}
func (mpp *MockPushProcessor) GetHandler(pushNotificationName string) interface{} {
return nil
}
// MockOptions implements interfaces.OptionsInterface for testing
type MockOptions struct{}
func (mo *MockOptions) GetReadTimeout() time.Duration {
return 5 * time.Second
}
func (mo *MockOptions) GetWriteTimeout() time.Duration {
return 5 * time.Second
}
func (mo *MockOptions) GetAddr() string {
return "localhost:6379"
}
func (mo *MockOptions) IsTLSEnabled() bool {
return false
}
func (mo *MockOptions) GetProtocol() int {
return 3 // RESP3
}
func (mo *MockOptions) GetPoolSize() int {
return 10
}
func (mo *MockOptions) GetNetwork() string {
return "tcp"
}
func (mo *MockOptions) NewDialer() func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
return nil, nil
}
}
func TestHitlessManagerRefactoring(t *testing.T) {
t.Run("AtomicStateTracking", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
// Test initial state
if manager.IsHandoffInProgress() {
t.Error("Expected no handoff in progress initially")
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateIdle {
t.Errorf("Expected StateIdle, got %v", manager.GetState())
}
// Add an operation
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
err = manager.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Failed to track operation: %v", err)
}
// Test state after adding operation
if !manager.IsHandoffInProgress() {
t.Error("Expected handoff in progress after adding operation")
}
if manager.GetActiveOperationCount() != 1 {
t.Errorf("Expected 1 active operation, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateMoving {
t.Errorf("Expected StateMoving, got %v", manager.GetState())
}
// Remove the operation
manager.UntrackOperationWithConnID(12345, 1)
// Test state after removing operation
if manager.IsHandoffInProgress() {
t.Error("Expected no handoff in progress after removing operation")
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateIdle {
t.Errorf("Expected StateIdle, got %v", manager.GetState())
}
})
t.Run("SyncMapPerformance", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
// Test concurrent operations
const numOps = 100
for i := 0; i < numOps; i++ {
err := manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, int64(i), uint64(i))
if err != nil {
t.Fatalf("Failed to track operation %d: %v", i, err)
}
}
if manager.GetActiveOperationCount() != numOps {
t.Errorf("Expected %d active operations, got %d", numOps, manager.GetActiveOperationCount())
}
// Test GetActiveMovingOperations
operations := manager.GetActiveMovingOperations()
if len(operations) != numOps {
t.Errorf("Expected %d operations in map, got %d", numOps, len(operations))
}
// Remove all operations
for i := 0; i < numOps; i++ {
manager.UntrackOperationWithConnID(int64(i), uint64(i))
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations after cleanup, got %d", manager.GetActiveOperationCount())
}
})
t.Run("DuplicateOperationHandling", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
// Add operation
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Failed to track operation: %v", err)
}
// Try to add duplicate operation
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Duplicate operation should not return error: %v", err)
}
// Should still have only 1 operation
if manager.GetActiveOperationCount() != 1 {
t.Errorf("Expected 1 active operation after duplicate, got %d", manager.GetActiveOperationCount())
}
})
t.Run("NotificationTypeConstants", func(t *testing.T) {
// Test that constants are properly defined
expectedTypes := []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
if len(hitlessNotificationTypes) != len(expectedTypes) {
t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(hitlessNotificationTypes))
}
// Test that all expected types are present
typeMap := make(map[string]bool)
for _, t := range hitlessNotificationTypes {
typeMap[t] = true
}
for _, expected := range expectedTypes {
if !typeMap[expected] {
t.Errorf("Expected notification type %s not found in hitlessNotificationTypes", expected)
}
}
// Test that hitlessNotificationTypes contains all expected constants
expectedConstants := []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
for _, expected := range expectedConstants {
found := false
for _, actual := range hitlessNotificationTypes {
if actual == expected {
found = true
break
}
}
if !found {
t.Errorf("Expected constant %s not found in hitlessNotificationTypes", expected)
}
}
})
}

47
hitless/hooks.go Normal file
View File

@@ -0,0 +1,47 @@
package hitless
import (
"context"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/push"
)
// LoggingHook is an example hook implementation that logs all notifications.
type LoggingHook struct {
LogLevel logging.LogLevel
}
// PreHook logs the notification before processing and allows modification.
func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
if lh.LogLevel.InfoOrAbove() { // Info level
// Log the notification type and content
connID := uint64(0)
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
connID = conn.GetID()
}
internal.Logger.Printf(ctx, "hitless: conn[%d] processing %s notification: %v", connID, notificationType, notification)
}
return notification, true // Continue processing with unmodified notification
}
// PostHook logs the result after processing.
func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
connID := uint64(0)
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
connID = conn.GetID()
}
if result != nil && lh.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processing failed: %v - %v", connID, notificationType, result, notification)
} else if lh.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processed successfully", connID, notificationType)
}
}
// NewLoggingHook creates a new logging hook with the specified log level.
// Log levels: LogLevelError=errors, LogLevelWarn=warnings, LogLevelInfo=info, LogLevelDebug=debug
func NewLoggingHook(logLevel logging.LogLevel) *LoggingHook {
return &LoggingHook{LogLevel: logLevel}
}
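// attachLoggingHook is a minimal usage sketch, assuming a *HitlessManager created
// elsewhere in this package: it registers a LoggingHook so every hitless notification
// is logged before and after it is handled.
func attachLoggingHook(manager *HitlessManager) {
manager.AddNotificationHook(NewLoggingHook(logging.LogLevelInfo))
}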

179
hitless/pool_hook.go Normal file
View File

@@ -0,0 +1,179 @@
package hitless
import (
"context"
"net"
"sync"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
)
// HitlessManagerInterface defines the interface for completing handoff operations
type HitlessManagerInterface interface {
TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
UntrackOperationWithConnID(seqID int64, connID uint64)
}
// HandoffRequest represents a request to handoff a connection to a new endpoint
type HandoffRequest struct {
Conn *pool.Conn
ConnID uint64 // Unique connection identifier
Endpoint string
SeqID int64
Pool pool.Pooler // Pool to remove connection from on failure
}
// PoolHook implements pool.PoolHook for Redis-specific connection handling
// with hitless upgrade support.
type PoolHook struct {
// Base dialer for creating connections to new endpoints during handoffs
// args are network and address
baseDialer func(context.Context, string, string) (net.Conn, error)
// Network type (e.g., "tcp", "unix")
network string
// Worker manager for background handoff processing
workerManager *handoffWorkerManager
// Configuration for the hitless upgrade
config *Config
// Hitless manager for operation completion tracking
hitlessManager HitlessManagerInterface
// Pool interface for removing connections on handoff failure
pool pool.Pooler
}
// NewPoolHook creates a new pool hook
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface) *PoolHook {
return NewPoolHookWithPoolSize(baseDialer, network, config, hitlessManager, 0)
}
// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface, poolSize int) *PoolHook {
// Apply defaults when no config is provided; ApplyDefaultsWithPoolSize handles a nil *Config
if config == nil {
config = config.ApplyDefaultsWithPoolSize(poolSize)
}
ph := &PoolHook{
// baseDialer is used to create connections to new endpoints during handoffs
baseDialer: baseDialer,
network: network,
config: config,
// Hitless manager for operation completion tracking
hitlessManager: hitlessManager,
}
// Create worker manager
ph.workerManager = newHandoffWorkerManager(config, ph)
return ph
}
// SetPool sets the pool interface for removing connections on handoff failure
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
ph.pool = pooler
}
// GetCurrentWorkers returns the current number of active workers (for testing)
func (ph *PoolHook) GetCurrentWorkers() int {
return ph.workerManager.getCurrentWorkers()
}
// IsHandoffPending returns true if the given connection has a pending handoff
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
return ph.workerManager.isHandoffPending(conn)
}
// GetPendingMap returns the pending map for testing purposes
func (ph *PoolHook) GetPendingMap() *sync.Map {
return ph.workerManager.getPendingMap()
}
// GetMaxWorkers returns the max workers for testing purposes
func (ph *PoolHook) GetMaxWorkers() int {
return ph.workerManager.getMaxWorkers()
}
// GetHandoffQueue returns the handoff queue for testing purposes
func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest {
return ph.workerManager.getHandoffQueue()
}
// GetCircuitBreakerStats returns circuit breaker statistics for monitoring
func (ph *PoolHook) GetCircuitBreakerStats() []CircuitBreakerStats {
return ph.workerManager.getCircuitBreakerStats()
}
// ResetCircuitBreakers resets all circuit breakers (useful for testing)
func (ph *PoolHook) ResetCircuitBreakers() {
ph.workerManager.resetCircuitBreakers()
}
// OnGet is called when a connection is retrieved from the pool
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error {
// NOTE: Two checks ensure we never return a connection that is marked for handoff
// or is currently being handed off.
// Check that the connection is usable (i.e. not in a handoff state). This should not
// happen in practice, since the pool does not return connections that are not usable.
if !conn.IsUsable() {
return ErrConnectionMarkedForHandoff
}
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
if conn.ShouldHandoff() {
return ErrConnectionMarkedForHandoff
}
return nil
}
// OnPut is called when a connection is returned to the pool
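// The (shouldPool, shouldRemove) return values follow the contract used below:
// (true, false, nil) keeps the connection in the pool, while (false, true, nil)
// removes it without surfacing an error to the caller.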
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
// first check if we should handoff for faster rejection
if !conn.ShouldHandoff() {
// Default behavior (no handoff): pool the connection
return true, false, nil
}
// check pending handoff to not queue the same connection twice
if ph.workerManager.isHandoffPending(conn) {
// Default behavior (pending handoff): pool the connection
return true, false, nil
}
if err := ph.workerManager.queueHandoff(conn); err != nil {
// Failed to queue handoff, remove the connection
internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err)
// Don't pool, remove connection, no error to caller
return false, true, nil
}
// Check if handoff was already processed by a worker before we can mark it as queued
if !conn.ShouldHandoff() {
// Handoff was already processed - this is normal and the connection should be pooled
return true, false, nil
}
if err := conn.MarkQueuedForHandoff(); err != nil {
// If marking fails, check if handoff was processed in the meantime
if !conn.ShouldHandoff() {
// Handoff was processed - this is normal, pool the connection
return true, false, nil
}
// Other error - remove the connection
return false, true, nil
}
return true, false, nil
}
// Shutdown gracefully shuts down the processor, waiting for workers to complete
func (ph *PoolHook) Shutdown(ctx context.Context) error {
return ph.workerManager.shutdownWorkers(ctx)
}
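// attachPoolHook is a minimal wiring sketch, assuming an existing pool.Pooler and a
// base dialer; the client normally wires this up through the hitless manager instead
// (see createPoolHook). The hook is registered on the pool and given a pool reference
// so it can remove connections whose handoff ultimately fails.
func attachPoolHook(p pool.Pooler, dialer func(context.Context, string, string) (net.Conn, error), cfg *Config, hm HitlessManagerInterface) *PoolHook {
hook := NewPoolHook(dialer, "tcp", cfg, hm)
hook.SetPool(p)
p.AddPoolHook(hook)
return hook
}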

964
hitless/pool_hook_test.go Normal file
View File

@@ -0,0 +1,964 @@
package hitless
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/pool"
)
// mockNetConn implements net.Conn for testing
type mockNetConn struct {
addr string
shouldFailInit bool
}
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (m *mockNetConn) Close() error { return nil }
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
type mockAddr struct {
addr string
}
func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string { return m.addr }
// createMockPoolConnection creates a mock pool connection for testing
func createMockPoolConnection() *pool.Conn {
mockNetConn := &mockNetConn{addr: "test:6379"}
conn := pool.NewConn(mockNetConn)
conn.SetUsable(true) // Make connection usable for testing
return conn
}
// mockPool implements pool.Pooler for testing
type mockPool struct {
removedConnections map[uint64]bool
mu sync.Mutex
}
func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) {
return nil, errors.New("not implemented")
}
func (mp *mockPool) CloseConn(conn *pool.Conn) error {
return nil
}
func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) {
return nil, errors.New("not implemented")
}
func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) {
// Not implemented for testing
}
func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) {
mp.mu.Lock()
defer mp.mu.Unlock()
// Use pool.Conn directly - no adapter needed
mp.removedConnections[conn.GetID()] = true
}
// WasRemoved safely checks if a connection was removed from the pool
func (mp *mockPool) WasRemoved(connID uint64) bool {
mp.mu.Lock()
defer mp.mu.Unlock()
return mp.removedConnections[connID]
}
func (mp *mockPool) Len() int {
return 0
}
func (mp *mockPool) IdleLen() int {
return 0
}
func (mp *mockPool) Stats() *pool.Stats {
return &pool.Stats{}
}
func (mp *mockPool) AddPoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
func (mp *mockPool) RemovePoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
func (mp *mockPool) Close() error {
return nil
}
// TestConnectionHook tests the Redis connection processor functionality
func TestConnectionHook(t *testing.T) {
// Create a base dialer for testing
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) {
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 1, // Use only 1 worker to ensure synchronization
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Verify connection is marked for handoff
if !conn.ShouldHandoff() {
t.Fatal("Connection should be marked for handoff")
}
// Set a mock initialization function with synchronization
initConnCalled := make(chan bool, 1)
proceedWithInit := make(chan bool, 1)
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
select {
case initConnCalled <- true:
default:
}
// Wait for test to proceed
<-proceedWithInit
return nil
}
conn.SetInitConnFunc(initConnFunc)
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
// Should pool the connection immediately (handoff queued)
if !shouldPool {
t.Error("Connection should be pooled immediately with event-driven handoff")
}
if shouldRemove {
t.Error("Connection should not be removed when queuing handoff")
}
// Wait for initialization to be called (indicates handoff started)
select {
case <-initConnCalled:
// Good, initialization was called
case <-time.After(1 * time.Second):
t.Fatal("Timeout waiting for initialization function to be called")
}
// Connection should be in pending map while initialization is blocked
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
t.Error("Connection should be in pending handoffs map")
}
// Allow initialization to proceed
proceedWithInit <- true
// Wait for handoff to complete with proper timeout and polling
timeout := time.After(2 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify handoff completed (removed from pending map)
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map after handoff")
}
// Verify connection is usable again
if !conn.IsUsable() {
t.Error("Connection should be usable after successful handoff")
}
// Verify handoff state is cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after completion")
}
})
t.Run("HandoffNotNeeded", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
// Don't mark for handoff
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error when handoff not needed: %v", err)
}
// Should pool the connection normally
if !shouldPool {
t.Error("Connection should be pooled when no handoff needed")
}
if shouldRemove {
t.Error("Connection should not be removed when no handoff needed")
}
})
t.Run("EmptyEndpoint", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error with empty endpoint: %v", err)
}
// Should pool the connection (empty endpoint clears state)
if !shouldPool {
t.Error("Connection should be pooled after clearing empty endpoint")
}
if shouldRemove {
t.Error("Connection should not be removed after clearing empty endpoint")
}
// State should be cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after clearing empty endpoint")
}
})
t.Run("EventDrivenHandoffDialerError", func(t *testing.T) {
// Create a failing base dialer
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("dial failed")
}
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 2, // Reduced retries for faster test
HandoffTimeout: 500 * time.Millisecond, // Shorter timeout for faster test
LogLevel: 2,
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not return error to caller: %v", err)
}
// Should pool the connection initially (handoff queued)
if !shouldPool {
t.Error("Connection should be pooled initially with event-driven handoff")
}
if shouldRemove {
t.Error("Connection should not be removed when queuing handoff")
}
// Wait for handoff to complete and fail with proper timeout and polling
timeout := time.After(3 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
// wait for handoff to start
time.Sleep(50 * time.Millisecond)
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for failed handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Connection should be removed from pending map after failed handoff
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map after failed handoff")
}
// Wait for retries to complete (with MaxHandoffRetries=2, it will retry twice then give up)
// Each retry has a delay of handoffTimeout/2 = 250ms, so wait for all retries to complete
time.Sleep(800 * time.Millisecond)
// After max retries are reached, the connection should be removed from pool
// and handoff state should be cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after max retries reached")
}
t.Logf("EventDrivenHandoffDialerError test completed successfully")
})
t.Run("BufferedDataRESP2", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
// For this test, we'll just verify the logic works for connections without buffered data
// The actual buffered data detection is handled by the pool's connection health check
// which is outside the scope of the Redis connection processor
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
// Should pool the connection normally (no buffered data in mock)
if !shouldPool {
t.Error("Connection should be pooled when no buffered data")
}
if shouldRemove {
t.Error("Connection should not be removed when no buffered data")
}
})
t.Run("OnGet", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should not error for normal connection: %v", err)
}
})
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
// Simulate a pending handoff by marking for handoff and queuing
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
// Clean up
processor.GetPendingMap().Delete(conn.GetID())
})
t.Run("EventDrivenStateManagement", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
// Test initial state - no pending handoffs
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("New connection should not have pending handoffs")
}
// Test adding to pending map
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
t.Error("Connection should be in pending map")
}
// Test OnGet with pending handoff
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
}
// Test removing from pending map and clearing handoff state
processor.GetPendingMap().Delete(conn.GetID())
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map")
}
// Clear handoff state to simulate completed handoff
conn.ClearHandoffState()
conn.SetUsable(true) // Make connection usable again
// Test OnGet without pending handoff
err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("Should not return error for non-pending connection: %v", err)
}
})
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
// Create processor with small queue to test optimization features
config := &Config{
MaxWorkers: 3,
HandoffQueueSize: 2, // Small queue to trigger optimizations
MaxHandoffRetries: 3,
LogLevel: 3, // Debug level to see optimization logs
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Add small delay to simulate network latency
time.Sleep(10 * time.Millisecond)
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create multiple connections that need handoff to fill the queue
connections := make([]*pool.Conn, 5)
for i := 0; i < 5; i++ {
connections[i] = createMockPoolConnection()
if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil {
t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err)
}
// Set a mock initialization function
connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
}
ctx := context.Background()
successCount := 0
// Process connections - should trigger scaling and timeout logic
for _, conn := range connections {
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Logf("OnPut returned error (expected with timeout): %v", err)
}
if shouldPool && !shouldRemove {
successCount++
}
}
// With timeout and scaling, most handoffs should eventually succeed
if successCount == 0 {
t.Error("Should have queued some handoffs with timeout and scaling")
}
t.Logf("Successfully queued %d handoffs with optimization features", successCount)
// Give time for workers to process and scaling to occur
time.Sleep(100 * time.Millisecond)
})
t.Run("WorkerScalingBehavior", func(t *testing.T) {
// Create processor with small queue to test scaling behavior
config := &Config{
MaxWorkers: 15, // Set to >= 10 to test explicit value preservation
HandoffQueueSize: 1, // Very small queue to force scaling
MaxHandoffRetries: 3,
LogLevel: 2, // Info level to see scaling logs
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Verify initial worker count (should be 0 with on-demand workers)
if processor.GetCurrentWorkers() != 0 {
t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers())
}
if processor.GetMaxWorkers() != 15 {
t.Errorf("Expected maxWorkers=15, got %d", processor.GetMaxWorkers())
}
// The on-demand worker behavior creates workers only when needed
// This test just verifies the basic configuration is correct
t.Logf("On-demand worker configuration verified - Max: %d, Current: %d",
processor.GetMaxWorkers(), processor.GetCurrentWorkers())
})
t.Run("PassiveTimeoutRestoration", func(t *testing.T) {
// Create processor with fast post-handoff duration for testing
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3, // Allow retries for successful handoff
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing
RelaxedTimeout: 5 * time.Second,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
ctx := context.Background()
// Create a connection and trigger handoff
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("Handoff should succeed: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after handoff")
}
// Wait for handoff to complete with proper timeout and polling
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify relaxed timeout is set with deadline
if !conn.HasRelaxedTimeout() {
t.Error("Connection should have relaxed timeout after handoff")
}
// Test that timeout is still active before deadline
// We'll use HasRelaxedTimeout which internally checks the deadline
if !conn.HasRelaxedTimeout() {
t.Error("Connection should still have active relaxed timeout before deadline")
}
// Wait for deadline to pass
time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer
// Test that timeout is automatically restored after deadline
// HasRelaxedTimeout should return false after deadline passes
if conn.HasRelaxedTimeout() {
t.Error("Connection should not have active relaxed timeout after deadline")
}
// Additional verification: calling HasRelaxedTimeout again should still return false
// and should have cleared the internal timeout values
if conn.HasRelaxedTimeout() {
t.Error("Connection should not have relaxed timeout after deadline (second check)")
}
t.Logf("Passive timeout restoration test completed successfully")
})
t.Run("UsableFlagBehavior", func(t *testing.T) {
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
ctx := context.Background()
// Create a new connection without setting it usable
mockNetConn := &mockNetConn{addr: "test:6379"}
conn := pool.NewConn(mockNetConn)
// Initially, connection should not be usable (not initialized)
if conn.IsUsable() {
t.Error("New connection should not be usable before initialization")
}
// Simulate initialization by setting usable to true
conn.SetUsable(true)
if !conn.IsUsable() {
t.Error("Connection should be usable after initialization")
}
// OnGet should succeed for usable connection
err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed for usable connection: %v", err)
}
// Mark connection for handoff
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
// Connection should still be usable until queued, but marked for handoff
if !conn.IsUsable() {
t.Error("Connection should still be usable after being marked for handoff (until queued)")
}
if !conn.ShouldHandoff() {
t.Error("Connection should be marked for handoff")
}
// OnGet should fail for connection marked for handoff
err = processor.OnGet(ctx, conn, false)
if err == nil {
t.Error("OnGet should fail for connection marked for handoff")
}
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should succeed: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after handoff")
}
// Wait for handoff to complete
time.Sleep(50 * time.Millisecond)
// After handoff completion, connection should be usable again
if !conn.IsUsable() {
t.Error("Connection should be usable after handoff completion")
}
// OnGet should succeed again
err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed after handoff completion: %v", err)
}
t.Logf("Usable flag behavior test completed successfully")
})
t.Run("StaticQueueBehavior", func(t *testing.T) {
config := &Config{
MaxWorkers: 3,
HandoffQueueSize: 50, // Explicit static queue size
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100
defer processor.Shutdown(context.Background())
// Verify queue capacity matches configured size
queueCapacity := cap(processor.GetHandoffQueue())
if queueCapacity != 50 {
t.Errorf("Expected queue capacity 50, got %d", queueCapacity)
}
// Test that queue size is static regardless of pool size
// (No dynamic resizing should occur)
ctx := context.Background()
// Fill part of the queue
for i := 0; i < 10; i++ {
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil {
t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("Failed to queue handoff %d: %v", i, err)
}
if !shouldPool || shouldRemove {
t.Errorf("conn[%d] should be pooled after handoff (shouldPool=%v, shouldRemove=%v)",
i, shouldPool, shouldRemove)
}
}
// Verify queue capacity remains static (the main purpose of this test)
finalCapacity := cap(processor.GetHandoffQueue())
if finalCapacity != 50 {
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
}
// Note: We don't check queue size here because workers process items quickly
// The important thing is that the capacity remains static regardless of pool size
})
t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) {
// Create a failing dialer that will cause handoff initialization to fail
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Return a connection that will fail during initialization
return &mockNetConn{addr: addr, shouldFailInit: true}, nil
}
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create a mock pool that tracks removals
mockPool := &mockPool{removedConnections: make(map[uint64]bool)}
processor.SetPool(mockPool)
ctx := context.Background()
// Create a connection and mark it for handoff
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a failing initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return fmt.Errorf("initialization failed")
})
// Process the connection - handoff should fail and connection should be removed
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after failed handoff attempt")
}
// Wait for handoff to be attempted and fail
time.Sleep(100 * time.Millisecond)
// Verify that the connection was removed from the pool
if !mockPool.WasRemoved(conn.GetID()) {
t.Errorf("conn[%d] should have been removed from pool after handoff failure", conn.GetID())
}
t.Logf("Connection removal on handoff failure test completed successfully")
})
t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) {
// Create config with short post-handoff duration for testing
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3, // Allow retries for successful handoff
RelaxedTimeout: 5 * time.Second,
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Fatalf("OnPut failed: %v", err)
}
if !shouldPool {
t.Error("Connection should be pooled after successful handoff")
}
if shouldRemove {
t.Error("Connection should not be removed after successful handoff")
}
// Wait for the handoff to complete (it happens asynchronously)
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify that relaxed timeout was applied to the new connection
if !conn.HasRelaxedTimeout() {
t.Error("New connection should have relaxed timeout applied after handoff")
}
// Wait for the post-handoff duration to expire
time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration
// Verify that relaxed timeout was automatically cleared
if conn.HasRelaxedTimeout() {
t.Error("Relaxed timeout should be automatically cleared after post-handoff duration")
}
})
t.Run("MarkForHandoff returns error when already marked", func(t *testing.T) {
conn := createMockPoolConnection()
// First mark should succeed
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("First MarkForHandoff should succeed: %v", err)
}
// Second mark should fail
if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil {
t.Fatal("Second MarkForHandoff should return error")
} else if err.Error() != "connection is already marked for handoff" {
t.Fatalf("Expected specific error message, got: %v", err)
}
// Verify original handoff data is preserved
if !conn.ShouldHandoff() {
t.Fatal("Connection should still be marked for handoff")
}
if conn.GetHandoffEndpoint() != "new-endpoint:6379" {
t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint())
}
if conn.GetMovingSeqID() != 1 {
t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID())
}
})
t.Run("HandoffTimeoutConfiguration", func(t *testing.T) {
// Test that HandoffTimeout from config is actually used
customTimeout := 2 * time.Second
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
HandoffTimeout: customTimeout, // Custom timeout
MaxHandoffRetries: 1, // Single retry to speed up test
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create a connection that will test the timeout
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("test-endpoint:6379", 123); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a dialer that will check the context timeout
var timeoutVerified int32 // Use atomic for thread safety
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
// Check that the context has the expected timeout
deadline, ok := ctx.Deadline()
if !ok {
t.Error("Context should have a deadline")
return errors.New("no deadline")
}
// The deadline should be approximately customTimeout from now
expectedDeadline := time.Now().Add(customTimeout)
timeDiff := deadline.Sub(expectedDeadline)
if timeDiff < -500*time.Millisecond || timeDiff > 500*time.Millisecond {
t.Errorf("Context deadline not as expected. Expected around %v, got %v (diff: %v)",
expectedDeadline, deadline, timeDiff)
} else {
atomic.StoreInt32(&timeoutVerified, 1)
}
return nil // Successful handoff
})
// Trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(context.Background(), conn)
if err != nil {
t.Errorf("OnPut should not return error: %v", err)
}
// Connection should be queued for handoff
if !shouldPool || shouldRemove {
t.Errorf("Connection should be pooled for handoff processing")
}
// Wait for handoff to complete
time.Sleep(500 * time.Millisecond)
if atomic.LoadInt32(&timeoutVerified) == 0 {
t.Error("HandoffTimeout was not properly applied to context")
}
t.Logf("HandoffTimeout configuration test completed successfully")
})
}

View File

@@ -0,0 +1,276 @@
package hitless
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// NotificationHandler handles push notifications for the simplified manager.
type NotificationHandler struct {
manager *HitlessManager
}
// HandlePushNotification processes push notifications with hook support.
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) == 0 {
internal.Logger.Printf(ctx, "hitless: invalid notification format: %v", notification)
return ErrInvalidNotification
}
notificationType, ok := notification[0].(string)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid notification type format: %v", notification[0])
return ErrInvalidNotification
}
// Process pre-hooks - they can modify the notification or skip processing
modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, handlerCtx, notificationType, notification)
if !shouldContinue {
return nil // Hooks decided to skip processing
}
var err error
switch notificationType {
case NotificationMoving:
err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
case NotificationMigrating:
err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
case NotificationMigrated:
err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
case NotificationFailingOver:
err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
case NotificationFailedOver:
err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
default:
// Ignore other notification types (e.g., pub/sub messages)
err = nil
}
// Process post-hooks with the result
snh.manager.processPostHooks(ctx, handlerCtx, notificationType, modifiedNotification, err)
return err
}
// handleMoving processes MOVING notifications.
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 3 {
internal.Logger.Printf(ctx, "hitless: invalid MOVING notification: %v", notification)
return ErrInvalidNotification
}
seqID, ok := notification[1].(int64)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid seqID in MOVING notification: %v", notification[1])
return ErrInvalidNotification
}
// Extract timeS
timeS, ok := notification[2].(int64)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid timeS in MOVING notification: %v", notification[2])
return ErrInvalidNotification
}
newEndpoint := ""
if len(notification) > 3 {
// Extract new endpoint
newEndpoint, ok = notification[3].(string)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid newEndpoint in MOVING notification: %v", notification[3])
return ErrInvalidNotification
}
}
// Get the connection that received this notification
conn := handlerCtx.Conn
if conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MOVING notification")
return ErrInvalidNotification
}
// Type assert to get the underlying pool connection
var poolConn *pool.Conn
if pc, ok := conn.(*pool.Conn); ok {
poolConn = pc
} else {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MOVING notification - %T %#v", conn, handlerCtx)
return ErrInvalidNotification
}
// If the connection is closed or not pooled, we can ignore the notification
// this connection won't be remembered by the pool and will be garbage collected
// Keep pubsub connections around since they are not pooled but are long-lived
// and should be allowed to handoff (the pubsub instance will reconnect and change
// the underlying *pool.Conn)
if (poolConn.IsClosed() || !poolConn.IsPooled()) && !poolConn.IsPubSub() {
return nil
}
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
if newEndpoint == "" || newEndpoint == internal.RedisNull {
if snh.manager.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(ctx, "hitless: conn[%d] scheduling handoff to current endpoint in %v seconds",
poolConn.GetID(), timeS/2)
}
// same as current endpoint
newEndpoint = snh.manager.options.GetAddr()
// delay the handoff for timeS/2 seconds to the same endpoint
// do this in a goroutine to avoid blocking the notification handler
// NOTE: This timer is started while parsing the notification, so the connection is not marked for handoff
// and there should be no possibility of a race condition or double handoff.
time.AfterFunc(time.Duration(timeS/2)*time.Second, func() {
if poolConn == nil || poolConn.IsClosed() {
return
}
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
// Log error but don't fail the goroutine - use background context since original may be cancelled
internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
}
})
return nil
}
return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
}
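// Example: ["MOVING", 42, 30, ""] received at t=0 schedules a handoff to the currently
// configured endpoint at t=15s (timeS/2) with a 30s deadline, while
// ["MOVING", 42, 30, "other:6379"] marks the connection for an immediate handoff to
// other:6379 with the same 30s deadline.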
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
// Connection is already marked for handoff, which is acceptable
// This can happen if multiple MOVING notifications are received for the same connection
return nil
}
// Optionally track in hitless manager for monitoring/debugging
if snh.manager != nil {
connID := conn.GetID()
// Track the operation (ignore errors since this is optional)
_ = snh.manager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
} else {
return fmt.Errorf("hitless: manager not initialized")
}
return nil
}
// handleMigrating processes MIGRATING notifications.
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATING notifications indicate that a connection is about to be migrated
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid MIGRATING notification: %v", notification)
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATING notification")
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATING notification")
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
internal.Logger.Printf(ctx, "hitless: conn[%d] applying relaxed timeout (%v) for MIGRATING notification",
conn.GetID(),
snh.manager.config.RelaxedTimeout)
}
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleMigrated processes MIGRATED notifications.
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATED notifications indicate that a connection migration has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid MIGRATED notification: %v", notification)
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATED notification")
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATED notification")
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
connID := conn.GetID()
internal.Logger.Printf(ctx, "hitless: conn[%d] clearing relaxed timeout for MIGRATED notification", connID)
}
conn.ClearRelaxedTimeout()
return nil
}
// handleFailingOver processes FAILING_OVER notifications.
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILING_OVER notifications indicate that a connection is about to failover
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid FAILING_OVER notification: %v", notification)
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILING_OVER notification")
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILING_OVER notification")
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
connID := conn.GetID()
internal.Logger.Printf(ctx, "hitless: conn[%d] applying relaxed timeout (%v) for FAILING_OVER notification", connID, snh.manager.config.RelaxedTimeout)
}
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleFailedOver processes FAILED_OVER notifications.
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILED_OVER notifications indicate that a connection failover has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid FAILED_OVER notification: %v", notification)
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILED_OVER notification")
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILED_OVER notification")
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
connID := conn.GetID()
internal.Logger.Printf(ctx, "hitless: conn[%d] clearing relaxed timeout for FAILED_OVER notification", connID)
}
conn.ClearRelaxedTimeout()
return nil
}
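// Summary of the handlers above:
//   - MOVING:       mark the receiving connection for handoff (possibly delayed by timeS/2).
//   - MIGRATING:    apply the configured relaxed timeouts to the receiving connection.
//   - MIGRATED:     clear the relaxed timeouts on the receiving connection.
//   - FAILING_OVER: apply the configured relaxed timeouts to the receiving connection.
//   - FAILED_OVER:  clear the relaxed timeouts on the receiving connection.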

24
hitless/state.go Normal file
View File

@@ -0,0 +1,24 @@
package hitless
// State represents the current state of a hitless upgrade operation.
type State int
const (
// StateIdle indicates no upgrade is in progress
StateIdle State = iota
// StateMoving indicates a connection handoff (MOVING) is in progress
StateMoving
)
// String returns a string representation of the state.
func (s State) String() string {
switch s {
case StateIdle:
return "idle"
case StateMoving:
return "moving"
default:
return "unknown"
}
}

245
hset_benchmark_test.go Normal file
View File

@@ -0,0 +1,245 @@
package redis_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
)
// HSET Benchmark Tests
//
// This file contains benchmark tests for Redis HSET operations with different scales:
// 1, 10, 100, 1000, 10000, 100000 operations
//
// Prerequisites:
// - Redis server running on localhost:6379
// - No authentication required
//
// Usage:
// go test -bench=BenchmarkHSET -v ./hset_benchmark_test.go
// go test -bench=BenchmarkHSETPipelined -v ./hset_benchmark_test.go
// go test -bench=. -v ./hset_benchmark_test.go # Run all benchmarks
//
// Example output:
// BenchmarkHSET/HSET_1_operations-8 5000 250000 ns/op 1000000.00 ops/sec
// BenchmarkHSET/HSET_100_operations-8 100 10000000 ns/op 100000.00 ops/sec
//
// The benchmarks test two approaches, each with RESP3 (default) and RESP2 variants:
// 1. Individual HSET commands (BenchmarkHSET, BenchmarkHSET_RESP2)
// 2. Pipelined HSET commands (BenchmarkHSETPipelined, BenchmarkHSETPipelined_RESP2)
// BenchmarkHSET benchmarks HSET operations with different scales
func BenchmarkHSET(b *testing.B) {
ctx := context.Background()
// Setup Redis client
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
DB: 0,
})
defer rdb.Close()
// Test connection
if err := rdb.Ping(ctx).Err(); err != nil {
b.Skipf("Redis server not available: %v", err)
}
// Clean up before and after tests
defer func() {
rdb.FlushDB(ctx)
}()
scales := []int{1, 10, 100, 1000, 10000, 100000}
for _, scale := range scales {
b.Run(fmt.Sprintf("HSET_%d_operations", scale), func(b *testing.B) {
benchmarkHSETOperations(b, rdb, ctx, scale)
})
}
}
// benchmarkHSETOperations performs the actual HSET benchmark for a given scale
func benchmarkHSETOperations(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) {
hashKey := fmt.Sprintf("benchmark_hash_%d", operations)
b.ResetTimer()
b.StartTimer()
totalTimes := []time.Duration{}
for i := 0; i < b.N; i++ {
b.StopTimer()
// Clean up the hash before each iteration
rdb.Del(ctx, hashKey)
b.StartTimer()
startTime := time.Now()
// Perform the specified number of HSET operations
for j := 0; j < operations; j++ {
field := fmt.Sprintf("field_%d", j)
value := fmt.Sprintf("value_%d", j)
err := rdb.HSet(ctx, hashKey, field, value).Err()
if err != nil {
b.Fatalf("HSET operation failed: %v", err)
}
}
totalTimes = append(totalTimes, time.Since(startTime))
}
// Stop the timer to calculate metrics
b.StopTimer()
// Report operations per second
opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds()
b.ReportMetric(opsPerSec, "ops/sec")
// Report average time per operation
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
b.ReportMetric(float64(avgTimePerOp), "ns/op")
// Report the average per-iteration wall time in milliseconds from totalTimes
var totalDur time.Duration
for _, d := range totalTimes {
totalDur += d
}
b.ReportMetric(float64(totalDur.Milliseconds())/float64(len(totalTimes)), "ms")
}
// BenchmarkHSETPipelined benchmarks HSET operations using pipelining for better performance
func BenchmarkHSETPipelined(b *testing.B) {
ctx := context.Background()
// Setup Redis client
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
DB: 0,
})
defer rdb.Close()
// Test connection
if err := rdb.Ping(ctx).Err(); err != nil {
b.Skipf("Redis server not available: %v", err)
}
// Clean up before and after tests
defer func() {
rdb.FlushDB(ctx)
}()
scales := []int{1, 10, 100, 1000, 10000, 100000}
for _, scale := range scales {
b.Run(fmt.Sprintf("HSET_Pipelined_%d_operations", scale), func(b *testing.B) {
benchmarkHSETPipelined(b, rdb, ctx, scale)
})
}
}
// benchmarkHSETPipelined performs HSET benchmark using pipelining
func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) {
hashKey := fmt.Sprintf("benchmark_hash_pipelined_%d", operations)
b.ResetTimer()
b.StartTimer()
totalTimes := []time.Duration{}
for i := 0; i < b.N; i++ {
b.StopTimer()
// Clean up the hash before each iteration
rdb.Del(ctx, hashKey)
b.StartTimer()
startTime := time.Now()
// Use pipelining for better performance
pipe := rdb.Pipeline()
// Add all HSET operations to the pipeline
for j := 0; j < operations; j++ {
field := fmt.Sprintf("field_%d", j)
value := fmt.Sprintf("value_%d", j)
pipe.HSet(ctx, hashKey, field, value)
}
// Execute all operations at once
_, err := pipe.Exec(ctx)
if err != nil {
b.Fatalf("Pipeline execution failed: %v", err)
}
totalTimes = append(totalTimes, time.Since(startTime))
}
b.StopTimer()
// Report operations per second
opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds()
b.ReportMetric(opsPerSec, "ops/sec")
// Report average time per operation
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
b.ReportMetric(float64(avgTimePerOp), "ns/op")
// Report the average per-iteration wall time in milliseconds from totalTimes
var totalDur time.Duration
for _, d := range totalTimes {
totalDur += d
}
b.ReportMetric(float64(totalDur.Milliseconds())/float64(len(totalTimes)), "ms")
}
// add same tests but with RESP2
func BenchmarkHSET_RESP2(b *testing.B) {
ctx := context.Background()
// Setup Redis client
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "", // no password docs
DB: 0, // use default DB
Protocol: 2,
})
defer rdb.Close()
// Test connection
if err := rdb.Ping(ctx).Err(); err != nil {
b.Skipf("Redis server not available: %v", err)
}
// Clean up before and after tests
defer func() {
rdb.FlushDB(ctx)
}()
scales := []int{1, 10, 100, 1000, 10000, 100000}
for _, scale := range scales {
b.Run(fmt.Sprintf("HSET_RESP2_%d_operations", scale), func(b *testing.B) {
benchmarkHSETOperations(b, rdb, ctx, scale)
})
}
}
func BenchmarkHSETPipelined_RESP2(b *testing.B) {
ctx := context.Background()
// Setup Redis client
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "", // no password docs
DB: 0, // use default DB
Protocol: 2,
})
defer rdb.Close()
// Test connection
if err := rdb.Ping(ctx).Err(); err != nil {
b.Skipf("Redis server not available: %v", err)
}
// Clean up before and after tests
defer func() {
rdb.FlushDB(ctx)
}()
scales := []int{1, 10, 100, 1000, 10000, 100000}
for _, scale := range scales {
b.Run(fmt.Sprintf("HSET_Pipelined_RESP2_%d_operations", scale), func(b *testing.B) {
benchmarkHSETPipelined(b, rdb, ctx, scale)
})
}
}

View File

@@ -0,0 +1,54 @@
// Package interfaces provides shared interfaces used by both the main redis package
// and the hitless upgrade package to avoid circular dependencies.
package interfaces
import (
"context"
"net"
"time"
)
// NotificationProcessor is the subset of push.NotificationProcessor needed here,
// declared locally (a forward declaration) to avoid circular imports.
type NotificationProcessor interface {
RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
UnregisterHandler(pushNotificationName string) error
GetHandler(pushNotificationName string) interface{}
}
// ClientInterface defines the interface that clients must implement for hitless upgrades.
type ClientInterface interface {
// GetOptions returns the client options.
GetOptions() OptionsInterface
// GetPushProcessor returns the client's push notification processor.
GetPushProcessor() NotificationProcessor
}
// OptionsInterface defines the interface for client options.
// Uses an adapter pattern to avoid circular dependencies.
type OptionsInterface interface {
// GetReadTimeout returns the read timeout.
GetReadTimeout() time.Duration
// GetWriteTimeout returns the write timeout.
GetWriteTimeout() time.Duration
// GetNetwork returns the network type.
GetNetwork() string
// GetAddr returns the connection address.
GetAddr() string
// IsTLSEnabled returns true if TLS is enabled.
IsTLSEnabled() bool
// GetProtocol returns the protocol version.
GetProtocol() int
// GetPoolSize returns the connection pool size.
GetPoolSize() int
// NewDialer returns a new dialer function for the connection.
NewDialer() func(context.Context) (net.Conn, error)
}
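To make the adapter pattern concrete, here is a minimal hand-rolled stub that satisfies OptionsInterface (handy in tests). All values are made up, and the import path assumes code living inside this module, since internal packages are not importable from outside:

package interfaces_test

import (
	"context"
	"net"
	"time"

	"github.com/redis/go-redis/v9/internal/interfaces"
)

// stubOptions is an illustrative test double for OptionsInterface.
type stubOptions struct{}

func (stubOptions) GetReadTimeout() time.Duration  { return 3 * time.Second }
func (stubOptions) GetWriteTimeout() time.Duration { return 3 * time.Second }
func (stubOptions) GetNetwork() string             { return "tcp" }
func (stubOptions) GetAddr() string                { return "localhost:6379" }
func (stubOptions) IsTLSEnabled() bool             { return false }
func (stubOptions) GetProtocol() int               { return 3 }
func (stubOptions) GetPoolSize() int               { return 10 }

func (stubOptions) NewDialer() func(context.Context) (net.Conn, error) {
	return func(ctx context.Context) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, "tcp", "localhost:6379")
	}
}

// Compile-time check that the stub satisfies the interface.
var _ interfaces.OptionsInterface = stubOptions{}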

View File

@@ -7,20 +7,27 @@ import (
"os"
)
// TODO (ned): Revisit logging
// Add a more standardized approach with log levels and configurability
type Logging interface {
Printf(ctx context.Context, format string, v ...interface{})
}
type logger struct {
type DefaultLogger struct {
log *log.Logger
}
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
func (l *DefaultLogger) Printf(ctx context.Context, format string, v ...interface{}) {
_ = l.log.Output(2, fmt.Sprintf(format, v...))
}
func NewDefaultLogger() Logging {
return &DefaultLogger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
}
}
// Logger calls Output to print to stderr.
// Arguments are handled in the manner of fmt.Print.
var Logger Logging = &logger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
}
var Logger Logging = NewDefaultLogger()
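With the logger now exported as DefaultLogger and built via NewDefaultLogger, swapping in a custom Logging implementation looks roughly like this. A sketch only: it assumes the package's public SetLogger helper (which assigns to internal.Logger), and the prefix string is arbitrary:

package main

import (
	"context"
	"log"
	"os"

	"github.com/redis/go-redis/v9"
)

// appLogger is an illustrative Logging implementation that writes to stdout.
type appLogger struct {
	l *log.Logger
}

func (a *appLogger) Printf(ctx context.Context, format string, v ...interface{}) {
	a.l.Printf(format, v...)
}

func main() {
	// Route go-redis internal logging through the custom logger.
	redis.SetLogger(&appLogger{
		l: log.New(os.Stdout, "my-app redis: ", log.LstdFlags),
	})
}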

View File

@@ -2,6 +2,7 @@ package pool_test
import (
"context"
"errors"
"fmt"
"testing"
"time"
@@ -31,7 +32,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
@@ -75,7 +76,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
@@ -89,7 +90,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
if err != nil {
b.Fatal(err)
}
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("Bench test remove"))
}
})
})

View File

@@ -26,7 +26,7 @@ var _ = Describe("Buffer Size Configuration", func() {
It("should use default buffer sizes when not specified", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
})
@@ -48,7 +48,7 @@ var _ = Describe("Buffer Size Configuration", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: customReadSize,
WriteBufferSize: customWriteSize,
@@ -69,7 +69,7 @@ var _ = Describe("Buffer Size Configuration", func() {
It("should handle zero buffer sizes by using defaults", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: 0, // Should use default
WriteBufferSize: 0, // Should use default
@@ -105,7 +105,7 @@ var _ = Describe("Buffer Size Configuration", func() {
// without setting ReadBufferSize and WriteBufferSize
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
// ReadBufferSize and WriteBufferSize are not set (will be 0)
})

View File

@@ -3,7 +3,10 @@ package pool
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
@@ -12,17 +15,74 @@ import (
var noDeadline = time.Time{}
// Global atomic counter for connection IDs
var connIDCounter uint64
// HandoffState represents the atomic state for connection handoffs
// This struct is stored atomically to prevent race conditions between
// checking handoff status and reading handoff parameters
type HandoffState struct {
ShouldHandoff bool // Whether connection should be handed off
Endpoint string // New endpoint for handoff
SeqID int64 // Sequence ID from MOVING notification
}
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
type atomicNetConn struct {
conn net.Conn
}
// generateConnID generates a fast unique identifier for a connection with zero allocations
func generateConnID() uint64 {
return atomic.AddUint64(&connIDCounter, 1)
}
type Conn struct {
usedAt int64 // atomic
netConn net.Conn
// Lock-free netConn access using atomic.Value
// Contains *atomicNetConn wrapper, accessed atomically for better performance
netConnAtomic atomic.Value // stores *atomicNetConn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
Inited bool
// Lightweight mutex to protect reader operations during handoff
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
readerMu sync.RWMutex
Inited atomic.Bool
pooled bool
pubsub bool
closed atomic.Bool
createdAt time.Time
expiresAt time.Time
// Hitless upgrade support: relaxed timeouts during migrations/failovers
// Using atomic operations for lock-free access to avoid mutex contention
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
// Counter tracking nested relaxed-timeout setters.
// It is decremented when ClearRelaxedTimeout is called or the deadline is reached;
// when it reaches 0, the relaxed timeouts are cleared.
relaxedCounter atomic.Int32
// Connection initialization function for reconnections
initConnFunc func(context.Context, *Conn) error
// Connection identifier for unique tracking across handoffs
id uint64 // Unique numeric identifier for this connection
// Handoff state - using atomic operations for lock-free access
usableAtomic atomic.Bool // Connection usability state
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
// Atomic handoff state to prevent race conditions
// Stores *HandoffState to ensure atomic updates of all handoff-related fields
handoffStateAtomic atomic.Value // stores *HandoffState
onClose func() error
}
@@ -33,8 +93,8 @@ func NewConn(netConn net.Conn) *Conn {
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
cn := &Conn{
netConn: netConn,
createdAt: time.Now(),
id: generateConnID(), // Generate unique ID for this connection
}
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
@@ -50,6 +110,21 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
}
// Store netConn atomically for lock-free access using wrapper
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
// Initialize atomic state
cn.usableAtomic.Store(false) // false initially, set to true after initialization
cn.handoffRetriesAtomic.Store(0) // 0 initially
// Initialize handoff state atomically
initialHandoffState := &HandoffState{
ShouldHandoff: false,
Endpoint: "",
SeqID: 0,
}
cn.handoffStateAtomic.Store(initialHandoffState)
cn.wr = proto.NewWriter(cn.bw)
cn.SetUsedAt(time.Now())
return cn
@@ -64,23 +139,430 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix())
}
// getNetConn returns the current network connection using atomic load (lock-free).
// This is the fast path for accessing netConn without mutex overhead.
func (cn *Conn) getNetConn() net.Conn {
if v := cn.netConnAtomic.Load(); v != nil {
if wrapper, ok := v.(*atomicNetConn); ok {
return wrapper.conn
}
}
return nil
}
// setNetConn stores the network connection atomically (lock-free).
// This is used for the fast path of connection replacement.
func (cn *Conn) setNetConn(netConn net.Conn) {
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
}
// Lock-free helper methods for handoff state management
// isUsable returns true if the connection is safe to use (lock-free).
func (cn *Conn) isUsable() bool {
return cn.usableAtomic.Load()
}
// setUsable sets the usable flag atomically (lock-free).
func (cn *Conn) setUsable(usable bool) {
cn.usableAtomic.Store(usable)
}
// getHandoffState returns the current handoff state atomically (lock-free).
func (cn *Conn) getHandoffState() *HandoffState {
state := cn.handoffStateAtomic.Load()
if state == nil {
// Return default state if not initialized
return &HandoffState{
ShouldHandoff: false,
Endpoint: "",
SeqID: 0,
}
}
return state.(*HandoffState)
}
// setHandoffState sets the handoff state atomically (lock-free).
func (cn *Conn) setHandoffState(state *HandoffState) {
cn.handoffStateAtomic.Store(state)
}
// shouldHandoff returns true if connection needs handoff (lock-free).
func (cn *Conn) shouldHandoff() bool {
return cn.getHandoffState().ShouldHandoff
}
// getMovingSeqID returns the sequence ID atomically (lock-free).
func (cn *Conn) getMovingSeqID() int64 {
return cn.getHandoffState().SeqID
}
// getNewEndpoint returns the new endpoint atomically (lock-free).
func (cn *Conn) getNewEndpoint() string {
return cn.getHandoffState().Endpoint
}
// setHandoffRetries sets the retry count atomically (lock-free).
func (cn *Conn) setHandoffRetries(retries int) {
cn.handoffRetriesAtomic.Store(uint32(retries))
}
// incrementHandoffRetries atomically increments and returns the new retry count (lock-free).
func (cn *Conn) incrementHandoffRetries(delta int) int {
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
}
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
func (cn *Conn) IsUsable() bool {
return cn.isUsable()
}
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
func (cn *Conn) IsPooled() bool {
return cn.pooled
}
// IsPubSub returns true if the connection is used for PubSub.
func (cn *Conn) IsPubSub() bool {
return cn.pubsub
}
func (cn *Conn) IsInited() bool {
return cn.Inited.Load()
}
// SetUsable sets the usable flag for the connection (lock-free).
func (cn *Conn) SetUsable(usable bool) {
cn.setUsable(usable)
}
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
// These timeouts will be used for all subsequent commands until the deadline expires.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
cn.relaxedCounter.Add(1)
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
}
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
cn.SetRelaxedTimeout(readTimeout, writeTimeout)
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
}
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
// Uses atomic operations for lock-free access.
func (cn *Conn) ClearRelaxedTimeout() {
// Atomically decrement counter and check if we should clear
newCount := cn.relaxedCounter.Add(-1)
if newCount <= 0 {
// Use atomic load to get current value for CAS to avoid stale value race
current := cn.relaxedCounter.Load()
if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) {
cn.clearRelaxedTimeout()
}
}
}
func (cn *Conn) clearRelaxedTimeout() {
cn.relaxedReadTimeoutNs.Store(0)
cn.relaxedWriteTimeoutNs.Store(0)
cn.relaxedDeadlineNs.Store(0)
cn.relaxedCounter.Store(0)
}
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
// This checks both the timeout values and the deadline (if set).
// Uses atomic operations for lock-free access.
func (cn *Conn) HasRelaxedTimeout() bool {
// Fast path: no relaxed timeouts are set
if cn.relaxedCounter.Load() <= 0 {
return false
}
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// If no relaxed timeouts are set, return false
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
return false
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, relaxed timeouts are active
if deadlineNs == 0 {
return true
}
// If deadline is set, check if it's still in the future
return time.Now().UnixNano() < deadlineNs
}
// getEffectiveReadTimeout returns the timeout to use for read operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
// Fast path: no relaxed timeout set
if readTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(readTimeoutNs)
}
nowNs := time.Now().UnixNano()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(readTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
cn.relaxedCounter.Add(-1)
if cn.relaxedCounter.Load() <= 0 {
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
// getEffectiveWriteTimeout returns the timeout to use for write operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// Fast path: no relaxed timeout set
if writeTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(writeTimeoutNs)
}
nowNs := time.Now().UnixNano()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(writeTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
cn.relaxedCounter.Add(-1)
if cn.relaxedCounter.Load() <= 0 {
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
func (cn *Conn) SetOnClose(fn func() error) {
cn.onClose = fn
}
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
cn.initConnFunc = fn
}
// ExecuteInitConn runs the stored connection initialization function if available.
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
if cn.initConnFunc != nil {
return cn.initConnFunc(ctx, cn)
}
return fmt.Errorf("redis: no initConnFunc set for conn[%d]", cn.GetID())
}
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
// Store the new connection atomically first (lock-free)
cn.setNetConn(netConn)
// Protect reader reset operations to avoid data races
// Use write lock since we're modifying the reader state
cn.readerMu.Lock()
cn.rd.Reset(netConn)
cn.readerMu.Unlock()
cn.bw.Reset(netConn)
}
// GetNetConn safely returns the current network connection using atomic load (lock-free).
// This method is used by the pool for health checks and provides better performance.
func (cn *Conn) GetNetConn() net.Conn {
return cn.getNetConn()
}
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
// New connection is not initialized yet
cn.Inited.Store(false)
// Replace the underlying connection
cn.SetNetConn(netConn)
return cn.ExecuteInitConn(ctx)
}
// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
// Returns an error if the connection is already marked for handoff.
// This method uses atomic compare-and-swap to ensure all handoff state is updated atomically.
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
const maxRetries = 50
const baseDelay = time.Microsecond
for attempt := 0; attempt < maxRetries; attempt++ {
currentState := cn.getHandoffState()
// Check if already marked for handoff
if currentState.ShouldHandoff {
return errors.New("connection is already marked for handoff")
}
// Create new state with handoff enabled
newState := &HandoffState{
ShouldHandoff: true,
Endpoint: newEndpoint,
SeqID: seqID,
}
// Atomic compare-and-swap to update entire state
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
return nil
}
// If CAS failed, add exponential backoff to reduce contention
if attempt < maxRetries-1 {
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
}
}
return fmt.Errorf("failed to mark connection for handoff after %d attempts due to high contention", maxRetries)
}
func (cn *Conn) MarkQueuedForHandoff() error {
const maxRetries = 50
const baseDelay = time.Microsecond
for attempt := 0; attempt < maxRetries; attempt++ {
currentState := cn.getHandoffState()
// Check if marked for handoff
if !currentState.ShouldHandoff {
return errors.New("connection was not marked for handoff")
}
// Create new state with handoff disabled (queued)
newState := &HandoffState{
ShouldHandoff: false,
Endpoint: currentState.Endpoint, // Preserve endpoint for handoff processing
SeqID: currentState.SeqID, // Preserve seqID for handoff processing
}
// Atomic compare-and-swap to update state
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
cn.setUsable(false)
return nil
}
// If CAS failed, add exponential backoff to reduce contention
if attempt < maxRetries-1 {
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
}
}
return fmt.Errorf("failed to mark connection as queued for handoff after %d attempts due to high contention", maxRetries)
}
// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
func (cn *Conn) ShouldHandoff() bool {
return cn.shouldHandoff()
}
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
func (cn *Conn) GetHandoffEndpoint() string {
return cn.getNewEndpoint()
}
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
func (cn *Conn) GetMovingSeqID() int64 {
return cn.getMovingSeqID()
}
// GetHandoffInfo returns all handoff information atomically (lock-free).
// This method prevents race conditions by returning all handoff state in a single atomic operation.
// Returns (shouldHandoff, endpoint, seqID).
func (cn *Conn) GetHandoffInfo() (bool, string, int64) {
state := cn.getHandoffState()
return state.ShouldHandoff, state.Endpoint, state.SeqID
}
// GetID returns the unique identifier for this connection.
func (cn *Conn) GetID() uint64 {
return cn.id
}
// ClearHandoffState clears the handoff state after successful handoff (lock-free).
func (cn *Conn) ClearHandoffState() {
// Create clean state
cleanState := &HandoffState{
ShouldHandoff: false,
Endpoint: "",
SeqID: 0,
}
// Atomically set clean state
cn.setHandoffState(cleanState)
cn.setHandoffRetries(0)
cn.setUsable(true) // Connection is safe to use again after handoff completes
}
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
return cn.incrementHandoffRetries(n)
}
// HasBufferedData safely checks if the connection has buffered data.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) HasBufferedData() bool {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
return cn.rd.Buffered() > 0
}
// PeekReplyTypeSafe safely peeks at the reply type.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
if cn.rd.Buffered() <= 0 {
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
}
return cn.rd.PeekReplyType()
}
func (cn *Conn) Write(b []byte) (int, error) {
return cn.netConn.Write(b)
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Write(b)
}
return 0, net.ErrClosed
}
func (cn *Conn) RemoteAddr() net.Addr {
if cn.netConn != nil {
return cn.netConn.RemoteAddr()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.RemoteAddr()
}
return nil
}
@@ -89,7 +571,16 @@ func (cn *Conn) WithReader(
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
// Get the connection directly from atomic storage
netConn := cn.getNetConn()
if netConn == nil {
return fmt.Errorf("redis: connection not available")
}
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
}
@@ -100,13 +591,26 @@ func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
// Set the write deadline when a connection is available; if it is not,
// fail fast below instead of letting the write block indefinitely
if netConn := cn.getNetConn(); netConn != nil {
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
} else {
// If getNetConn() returns nil, we still need to respect the timeout
// Return an error to prevent indefinite blocking
return fmt.Errorf("redis: connection not available for write operation")
}
}
if cn.bw.Buffered() > 0 {
cn.bw.Reset(cn.netConn)
if netConn := cn.getNetConn(); netConn != nil {
cn.bw.Reset(netConn)
}
}
if err := fn(cn.wr); err != nil {
@@ -116,12 +620,33 @@ func (cn *Conn) WithWriter(
return cn.bw.Flush()
}
func (cn *Conn) IsClosed() bool {
return cn.closed.Load()
}
func (cn *Conn) Close() error {
cn.closed.Store(true)
if cn.onClose != nil {
// ignore error
_ = cn.onClose()
}
return cn.netConn.Close()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Close()
}
return nil
}
// MaybeHasData tries to peek at the next byte in the socket without consuming it.
// This is used to check whether push notifications are available.
// Important: this works only on platforms where connCheck is implemented (e.g. Linux);
// on others (e.g. Windows) it conservatively assumes data may be present.
func (cn *Conn) MaybeHasData() bool {
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return maybeHasData(netConn)
}
return false
}
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {

View File

@@ -12,6 +12,9 @@ import (
var errUnexpectedRead = errors.New("unexpected read from socket")
// connCheck checks if the connection is still alive and if there is data in the socket
// it will try to peek at the next byte without consuming it since we may want to work with it
// later on (e.g. push notifications)
func connCheck(conn net.Conn) error {
// Reset previous timeout.
_ = conn.SetDeadline(time.Time{})
@@ -29,7 +32,9 @@ func connCheck(conn net.Conn) error {
if err := rawConn.Read(func(fd uintptr) bool {
var buf [1]byte
n, err := syscall.Read(int(fd), buf[:])
// Use MSG_PEEK to peek at data without consuming it
n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK|syscall.MSG_DONTWAIT)
switch {
case n == 0 && err == nil:
sysErr = io.EOF
@@ -47,3 +52,8 @@ func connCheck(conn net.Conn) error {
return sysErr
}
// maybeHasData checks if there is data in the socket without consuming it
func maybeHasData(conn net.Conn) bool {
return connCheck(conn) == errUnexpectedRead
}

View File

@@ -7,3 +7,8 @@ import "net"
func connCheck(conn net.Conn) error {
return nil
}
// since we can't check for data on the socket, we just assume there is some
func maybeHasData(conn net.Conn) bool {
return true
}

View File

@@ -0,0 +1,92 @@
package pool
import (
"net"
"sync"
"testing"
"time"
)
// TestConcurrentRelaxedTimeoutClearing tests the race condition fix in ClearRelaxedTimeout
func TestConcurrentRelaxedTimeoutClearing(t *testing.T) {
// Create a dummy connection for testing
netConn := &net.TCPConn{}
cn := NewConn(netConn)
defer cn.Close()
// Set relaxed timeout multiple times to increase counter
cn.SetRelaxedTimeout(time.Second, time.Second)
cn.SetRelaxedTimeout(time.Second, time.Second)
cn.SetRelaxedTimeout(time.Second, time.Second)
// Verify counter is 3
if count := cn.relaxedCounter.Load(); count != 3 {
t.Errorf("Expected relaxed counter to be 3, got %d", count)
}
// Clear timeouts concurrently to test race condition fix
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
cn.ClearRelaxedTimeout()
}()
}
wg.Wait()
// Verify counter is 0 and timeouts are cleared
if count := cn.relaxedCounter.Load(); count != 0 {
t.Errorf("Expected relaxed counter to be 0 after clearing, got %d", count)
}
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed read timeout to be 0, got %d", timeout)
}
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed write timeout to be 0, got %d", timeout)
}
}
// TestRelaxedTimeoutCounterRaceCondition tests the specific race condition scenario
func TestRelaxedTimeoutCounterRaceCondition(t *testing.T) {
netConn := &net.TCPConn{}
cn := NewConn(netConn)
defer cn.Close()
// Set relaxed timeout once
cn.SetRelaxedTimeout(time.Second, time.Second)
// Verify counter is 1
if count := cn.relaxedCounter.Load(); count != 1 {
t.Errorf("Expected relaxed counter to be 1, got %d", count)
}
// Test concurrent clearing with race condition scenario
var wg sync.WaitGroup
// Multiple goroutines try to clear simultaneously
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
cn.ClearRelaxedTimeout()
}()
}
wg.Wait()
// Verify final state is consistent
if count := cn.relaxedCounter.Load(); count != 0 {
t.Errorf("Expected relaxed counter to be 0 after concurrent clearing, got %d", count)
}
// Verify timeouts are actually cleared
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed read timeout to be cleared, got %d", timeout)
}
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed write timeout to be cleared, got %d", timeout)
}
if deadline := cn.relaxedDeadlineNs.Load(); deadline != 0 {
t.Errorf("Expected relaxed deadline to be cleared, got %d", deadline)
}
}

View File

@@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) {
}
func (cn *Conn) NetConn() net.Conn {
return cn.netConn
return cn.getNetConn()
}
func (p *ConnPool) CheckMinIdleConns() {

114
internal/pool/hooks.go Normal file
View File

@@ -0,0 +1,114 @@
package pool
import (
"context"
"sync"
)
// PoolHook defines the interface for connection lifecycle hooks.
type PoolHook interface {
// OnGet is called when a connection is retrieved from the pool.
// It can modify the connection or return an error to prevent its use.
// It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool)
// The flag can be used for gathering metrics on pool hit/miss ratio.
OnGet(ctx context.Context, conn *Conn, isNewConn bool) error
// OnPut is called when a connection is returned to the pool.
// It returns whether the connection should be pooled and whether it should be removed.
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
}
// PoolHookManager manages multiple pool hooks.
type PoolHookManager struct {
hooks []PoolHook
hooksMu sync.RWMutex
}
// NewPoolHookManager creates a new pool hook manager.
func NewPoolHookManager() *PoolHookManager {
return &PoolHookManager{
hooks: make([]PoolHook, 0),
}
}
// AddHook adds a pool hook to the manager.
// Hooks are called in the order they were added.
func (phm *PoolHookManager) AddHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
phm.hooks = append(phm.hooks, hook)
}
// RemoveHook removes a pool hook from the manager.
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
for i, h := range phm.hooks {
if h == hook {
// Remove hook by swapping with last element and truncating
phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
phm.hooks = phm.hooks[:len(phm.hooks)-1]
break
}
}
}
// ProcessOnGet calls all OnGet hooks in order.
// If any hook returns an error, processing stops and the error is returned.
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
for _, hook := range phm.hooks {
if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
return err
}
}
return nil
}
// ProcessOnPut calls all OnPut hooks in order.
// The first hook that returns shouldRemove=true or shouldPool=false will stop processing.
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
shouldPool = true // Default to pooling the connection
for _, hook := range phm.hooks {
hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)
if hookErr != nil {
return false, true, hookErr
}
// If any hook says to remove or not pool, respect that decision
if hookShouldRemove {
return false, true, nil
}
if !hookShouldPool {
shouldPool = false
}
}
return shouldPool, false, nil
}
// GetHookCount returns the number of registered hooks (for testing).
func (phm *PoolHookManager) GetHookCount() int {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
return len(phm.hooks)
}
// GetHooks returns a copy of all registered hooks.
func (phm *PoolHookManager) GetHooks() []PoolHook {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
hooks := make([]PoolHook, len(phm.hooks))
copy(hooks, phm.hooks)
return hooks
}
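As a usage sketch (not part of this change), a hook that only observes pool traffic might look like the following; the type name and fields are invented, and the import path again assumes in-module code:

package pool_test

import (
	"context"
	"sync/atomic"

	"github.com/redis/go-redis/v9/internal/pool"
)

// metricsHook counts pool hits (reused idle connections) versus new dials.
type metricsHook struct {
	hits  atomic.Int64
	dials atomic.Int64
}

func (m *metricsHook) OnGet(ctx context.Context, cn *pool.Conn, isNewConn bool) error {
	if isNewConn {
		m.dials.Add(1)
	} else {
		m.hits.Add(1)
	}
	return nil // never reject a connection
}

func (m *metricsHook) OnPut(ctx context.Context, cn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
	// Observe only: keep the default pooling behaviour.
	return true, false, nil
}

Such a hook would be registered with AddPoolHook and removed with RemovePoolHook, mirroring the tests below.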

213
internal/pool/hooks_test.go Normal file
View File

@@ -0,0 +1,213 @@
package pool
import (
"context"
"errors"
"net"
"testing"
"time"
)
// TestHook for testing hook functionality
type TestHook struct {
OnGetCalled int
OnPutCalled int
GetError error
PutError error
ShouldPool bool
ShouldRemove bool
}
func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
th.OnGetCalled++
return th.GetError
}
func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
th.OnPutCalled++
return th.ShouldPool, th.ShouldRemove, th.PutError
}
func TestPoolHookManager(t *testing.T) {
manager := NewPoolHookManager()
// Test initial state
if manager.GetHookCount() != 0 {
t.Errorf("Expected 0 hooks initially, got %d", manager.GetHookCount())
}
// Add hooks
hook1 := &TestHook{ShouldPool: true}
hook2 := &TestHook{ShouldPool: true}
manager.AddHook(hook1)
manager.AddHook(hook2)
if manager.GetHookCount() != 2 {
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
}
// Test ProcessOnGet
ctx := context.Background()
conn := &Conn{} // Mock connection
err := manager.ProcessOnGet(ctx, conn, false)
if err != nil {
t.Errorf("ProcessOnGet should not error: %v", err)
}
if hook1.OnGetCalled != 1 {
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
}
if hook2.OnGetCalled != 1 {
t.Errorf("Expected hook2.OnGetCalled to be 1, got %d", hook2.OnGetCalled)
}
// Test ProcessOnPut
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
if err != nil {
t.Errorf("ProcessOnPut should not error: %v", err)
}
if !shouldPool {
t.Error("Expected shouldPool to be true")
}
if shouldRemove {
t.Error("Expected shouldRemove to be false")
}
if hook1.OnPutCalled != 1 {
t.Errorf("Expected hook1.OnPutCalled to be 1, got %d", hook1.OnPutCalled)
}
if hook2.OnPutCalled != 1 {
t.Errorf("Expected hook2.OnPutCalled to be 1, got %d", hook2.OnPutCalled)
}
// Remove a hook
manager.RemoveHook(hook1)
if manager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
}
}
func TestHookErrorHandling(t *testing.T) {
manager := NewPoolHookManager()
// Hook that returns error on Get
errorHook := &TestHook{
GetError: errors.New("test error"),
ShouldPool: true,
}
normalHook := &TestHook{ShouldPool: true}
manager.AddHook(errorHook)
manager.AddHook(normalHook)
ctx := context.Background()
conn := &Conn{}
// Test that error stops processing
err := manager.ProcessOnGet(ctx, conn, false)
if err == nil {
t.Error("Expected error from ProcessOnGet")
}
if errorHook.OnGetCalled != 1 {
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
}
// normalHook should not be called due to error
if normalHook.OnGetCalled != 0 {
t.Errorf("Expected normalHook.OnGetCalled to be 0, got %d", normalHook.OnGetCalled)
}
}
func TestHookShouldRemove(t *testing.T) {
manager := NewPoolHookManager()
// Hook that says to remove connection
removeHook := &TestHook{
ShouldPool: false,
ShouldRemove: true,
}
normalHook := &TestHook{ShouldPool: true}
manager.AddHook(removeHook)
manager.AddHook(normalHook)
ctx := context.Background()
conn := &Conn{}
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
if err != nil {
t.Errorf("ProcessOnPut should not error: %v", err)
}
if shouldPool {
t.Error("Expected shouldPool to be false")
}
if !shouldRemove {
t.Error("Expected shouldRemove to be true")
}
if removeHook.OnPutCalled != 1 {
t.Errorf("Expected removeHook.OnPutCalled to be 1, got %d", removeHook.OnPutCalled)
}
// normalHook should not be called due to early return
if normalHook.OnPutCalled != 0 {
t.Errorf("Expected normalHook.OnPutCalled to be 0, got %d", normalHook.OnPutCalled)
}
}
func TestPoolWithHooks(t *testing.T) {
// Create a pool with hooks
hookManager := NewPoolHookManager()
testHook := &TestHook{ShouldPool: true}
hookManager.AddHook(testHook)
opt := &Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil // Mock connection
},
PoolSize: 1,
DialTimeout: time.Second,
}
pool := NewConnPool(opt)
defer pool.Close()
// Add hook to pool after creation
pool.AddPoolHook(testHook)
// Verify hooks are initialized
if pool.hookManager == nil {
t.Error("Expected hookManager to be initialized")
}
if pool.hookManager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount())
}
// Test adding hook to pool
additionalHook := &TestHook{ShouldPool: true}
pool.AddPoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 2 {
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount())
}
// Test removing hook from pool
pool.RemovePoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount())
}
}

View File

@@ -9,6 +9,8 @@ import (
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/util"
)
var (
@@ -21,6 +23,23 @@ var (
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
ErrPoolTimeout = errors.New("redis: connection pool timeout")
// popAttempts is the maximum number of attempts to find a usable connection
// when popping from the idle connection pool. This handles cases where connections
// are temporarily marked as unusable (e.g., during hitless upgrades or network issues).
// Value of 50 provides sufficient resilience without excessive overhead.
// This is capped by the idle connection count, so we won't loop excessively.
popAttempts = 50
// getAttempts is the maximum number of attempts to get a connection that passes
// hook validation (e.g., hitless upgrade hooks). This protects against race conditions
// where hooks might temporarily reject connections during cluster transitions.
// Value of 3 balances resilience with performance - most hook rejections resolve quickly.
getAttempts = 3
minTime = time.Unix(-2208988800, 0) // Jan 1, 1900
maxTime = minTime.Add(1<<63 - 1)
noExpiration = maxTime
)
var timers = sync.Pool{
@@ -37,11 +56,14 @@ type Stats struct {
Misses uint32 // number of times free connection was NOT found in the pool
Timeouts uint32 // number of times a wait timeout occurred
WaitCount uint32 // number of times a connection was waited
Unusable uint32 // number of times a connection was found to be unusable
WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds
TotalConns uint32 // number of total connections in the pool
IdleConns uint32 // number of idle connections in the pool
StaleConns uint32 // number of stale connections removed from the pool
PubSubStats PubSubStats
}
type Pooler interface {
@@ -56,24 +78,35 @@ type Pooler interface {
IdleLen() int
Stats() *Stats
AddPoolHook(hook PoolHook)
RemovePoolHook(hook PoolHook)
Close() error
}
type Options struct {
Dialer func(context.Context) (net.Conn, error)
PoolFIFO bool
PoolSize int
DialTimeout time.Duration
PoolTimeout time.Duration
MinIdleConns int
MaxIdleConns int
MaxActiveConns int
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
ReadBufferSize int
WriteBufferSize int
PoolFIFO bool
PoolSize int32
DialTimeout time.Duration
PoolTimeout time.Duration
MinIdleConns int32
MaxIdleConns int32
MaxActiveConns int32
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
PushNotificationsEnabled bool
// DialerRetries is the maximum number of retry attempts when dialing fails.
// Default: 5
DialerRetries int
// DialerRetryTimeout is the backoff duration between retry attempts.
// Default: 100ms
DialerRetryTimeout time.Duration
}
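For reference, a minimal sketch of wiring up the new dialer-retry knobs on the internal pool options; the values are arbitrary, and note that PoolSize and the other count fields are now int32:

package pool_test

import (
	"context"
	"net"
	"time"

	"github.com/redis/go-redis/v9/internal/pool"
)

func newSketchPool() *pool.ConnPool {
	return pool.NewConnPool(&pool.Options{
		Dialer: func(ctx context.Context) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, "tcp", "localhost:6379")
		},
		PoolSize:           int32(10),
		PoolTimeout:        time.Second,
		DialTimeout:        5 * time.Second,
		DialerRetries:      5,                      // retry each dial up to 5 times
		DialerRetryTimeout: 100 * time.Millisecond, // backoff between attempts
	})
}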
type lastDialErrorWrap struct {
@@ -89,16 +122,21 @@ type ConnPool struct {
queue chan struct{}
connsMu sync.Mutex
conns []*Conn
conns map[uint64]*Conn
idleConns []*Conn
poolSize int
idleConnsLen int
poolSize atomic.Int32
idleConnsLen atomic.Int32
idleCheckInProgress atomic.Bool
stats Stats
waitDurationNs atomic.Int64
_closed uint32 // atomic
// Pool hooks manager for flexible connection processing
hookManagerMu sync.RWMutex
hookManager *PoolHookManager
}
var _ Pooler = (*ConnPool)(nil)
@@ -108,34 +146,69 @@ func NewConnPool(opt *Options) *ConnPool {
cfg: opt,
queue: make(chan struct{}, opt.PoolSize),
conns: make([]*Conn, 0, opt.PoolSize),
conns: make(map[uint64]*Conn),
idleConns: make([]*Conn, 0, opt.PoolSize),
}
// Only create MinIdleConns if explicitly requested (> 0)
// This avoids creating connections during pool initialization for tests
if opt.MinIdleConns > 0 {
p.connsMu.Lock()
p.checkMinIdleConns()
p.connsMu.Unlock()
}
return p
}
// initializeHooks sets up the pool hooks system.
func (p *ConnPool) initializeHooks() {
p.hookManager = NewPoolHookManager()
}
// AddPoolHook adds a pool hook to the pool.
func (p *ConnPool) AddPoolHook(hook PoolHook) {
p.hookManagerMu.Lock()
defer p.hookManagerMu.Unlock()
if p.hookManager == nil {
p.initializeHooks()
}
p.hookManager.AddHook(hook)
}
// RemovePoolHook removes a pool hook from the pool.
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
p.hookManagerMu.Lock()
defer p.hookManagerMu.Unlock()
if p.hookManager != nil {
p.hookManager.RemoveHook(hook)
}
}
func (p *ConnPool) checkMinIdleConns() {
if !p.idleCheckInProgress.CompareAndSwap(false, true) {
return
}
defer p.idleCheckInProgress.Store(false)
if p.cfg.MinIdleConns == 0 {
return
}
for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
// Only create idle connections if we haven't reached the total pool size limit
// MinIdleConns should be a subset of PoolSize, not additional connections
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
select {
case p.queue <- struct{}{}:
p.poolSize++
p.idleConnsLen++
p.poolSize.Add(1)
p.idleConnsLen.Add(1)
go func() {
defer func() {
if err := recover(); err != nil {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
p.freeTurn()
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
@@ -144,12 +217,9 @@ func (p *ConnPool) checkMinIdleConns() {
err := p.addIdleConn()
if err != nil && err != ErrClosed {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
}
p.freeTurn()
}()
default:
@@ -166,6 +236,9 @@ func (p *ConnPool) addIdleConn() error {
if err != nil {
return err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
p.connsMu.Lock()
defer p.connsMu.Unlock()
@@ -176,11 +249,15 @@ func (p *ConnPool) addIdleConn() error {
return ErrClosed
}
p.conns = append(p.conns, cn)
p.conns[cn.GetID()] = cn
p.idleConns = append(p.idleConns, cn)
return nil
}
// NewConn creates a new connection and returns it to the user.
// It still obeys MaxActiveConns, but the connection is not added to the pool and does not increase the pool size.
//
// NOTE: Connections obtained this way are not pooled and do not support hitless upgrades.
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.newConn(ctx, false)
}
@@ -190,33 +267,44 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, ErrClosed
}
p.connsMu.Lock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
p.connsMu.Unlock()
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
return nil, ErrPoolExhausted
}
p.connsMu.Unlock()
cn, err := p.dialConn(ctx, pooled)
dialCtx, cancel := context.WithTimeout(ctx, p.cfg.DialTimeout)
defer cancel()
cn, err := p.dialConn(dialCtx, pooled)
if err != nil {
return nil, err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
_ = cn.Close()
return nil, ErrPoolExhausted
}
p.conns = append(p.conns, cn)
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.closed() {
_ = cn.Close()
return nil, ErrClosed
}
// Check if pool was closed while we were waiting for the lock
if p.conns == nil {
p.conns = make(map[uint64]*Conn)
}
p.conns[cn.GetID()] = cn
if pooled {
// If pool is full remove the cn on next Put.
if p.poolSize >= p.cfg.PoolSize {
currentPoolSize := p.poolSize.Load()
if currentPoolSize >= p.cfg.PoolSize {
cn.pooled = false
} else {
p.poolSize++
p.poolSize.Add(1)
}
}
@@ -232,18 +320,57 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, p.getLastDialError()
}
// Retry dialing with backoff
// The dial timeout is already enforced by the context passed in,
// so we may never reach the maximum retries; higher values don't hurt.
maxRetries := p.cfg.DialerRetries
if maxRetries <= 0 {
maxRetries = 5 // Default value
}
backoffDuration := p.cfg.DialerRetryTimeout
if backoffDuration <= 0 {
backoffDuration = 100 * time.Millisecond // Default value
}
var lastErr error
shouldLoop := true
// When the timeout is reached, stop retrying but keep lastErr so the caller
// gets the underlying dial error instead of a generic
// context deadline exceeded error.
for attempt := 0; (attempt < maxRetries) && shouldLoop; attempt++ {
netConn, err := p.cfg.Dialer(ctx)
if err != nil {
p.setLastDialError(err)
lastErr = err
// Back off before the next attempt unless the context is done
// (the backoff runs after every failed dial, including the first)
select {
case <-ctx.Done():
shouldLoop = false
case <-time.After(backoffDuration):
// Continue with retry
}
continue
}
// Success - create connection
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
cn.pooled = pooled
if p.cfg.ConnMaxLifetime > 0 {
cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime)
} else {
cn.expiresAt = noExpiration
}
return cn, nil
}
internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr)
// All retries failed - handle error tracking
p.setLastDialError(lastErr)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
go p.tryDial()
}
return nil, err
}
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
cn.pooled = pooled
return cn, nil
return nil, lastErr
}
func (p *ConnPool) tryDial() {
@@ -283,6 +410,14 @@ func (p *ConnPool) getLastDialError() error {
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return p.getConn(ctx)
}
// getConn returns a connection from the pool.
func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
var cn *Conn
var err error
if p.closed() {
return nil, ErrClosed
}
@@ -291,9 +426,17 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return nil, err
}
now := time.Now()
attempts := 0
for {
if attempts >= getAttempts {
internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts)
break
}
attempts++
p.connsMu.Lock()
cn, err := p.popIdle()
cn, err = p.popIdle()
p.connsMu.Unlock()
if err != nil {
@@ -305,11 +448,25 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
break
}
if !p.isHealthyConn(cn) {
if !p.isHealthyConn(cn, now) {
_ = p.CloseConn(cn)
continue
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
// Failed to process connection, discard it
_ = p.CloseConn(cn)
continue
}
}
atomic.AddUint32(&p.stats.Hits, 1)
return cn, nil
}
@@ -322,6 +479,19 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return nil, err
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
// Failed to process connection, discard it
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err)
_ = p.CloseConn(newcn)
return nil, err
}
}
return newcn, nil
}
@@ -350,7 +520,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error {
}
return ctx.Err()
case p.queue <- struct{}{}:
p.waitDurationNs.Add(time.Since(start).Nanoseconds())
p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
atomic.AddUint32(&p.stats.WaitCount, 1)
if !timer.Stop() {
<-timer.C
@@ -370,52 +540,130 @@ func (p *ConnPool) popIdle() (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
defer p.checkMinIdleConns()
n := len(p.idleConns)
if n == 0 {
return nil, nil
}
var cn *Conn
attempts := 0
maxAttempts := util.Min(popAttempts, n)
for attempts < maxAttempts {
if len(p.idleConns) == 0 {
return nil, nil
}
if p.cfg.PoolFIFO {
cn = p.idleConns[0]
copy(p.idleConns, p.idleConns[1:])
p.idleConns = p.idleConns[:n-1]
p.idleConns = p.idleConns[:len(p.idleConns)-1]
} else {
idx := n - 1
idx := len(p.idleConns) - 1
cn = p.idleConns[idx]
p.idleConns = p.idleConns[:idx]
}
p.idleConnsLen--
p.checkMinIdleConns()
attempts++
if cn.IsUsable() {
p.idleConnsLen.Add(-1)
break
}
// Connection is not usable, put it back in the pool
if p.cfg.PoolFIFO {
// FIFO: put at end (will be picked up last since we pop from front)
p.idleConns = append(p.idleConns, cn)
} else {
// LIFO: put at beginning (will be picked up last since we pop from end)
p.idleConns = append([]*Conn{cn}, p.idleConns...)
}
cn = nil
}
// If we exhausted all attempts without finding a usable connection, return nil
if attempts > 1 && attempts >= maxAttempts && int32(attempts) >= p.poolSize.Load() {
internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts)
return nil, nil
}
return cn, nil
}
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
if cn.rd.Buffered() > 0 {
internal.Logger.Printf(ctx, "Conn has unread data")
p.Remove(ctx, cn, BadConnError{})
// Process connection using the hooks system
shouldPool := true
shouldRemove := false
var err error
if cn.HasBufferedData() {
// Peek at the reply type to check if it's a push notification
if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush {
// Not a push notification or error peeking, remove connection
internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it")
p.Remove(ctx, cn, err)
}
// It's a push notification, allow pooling (client will handle it)
}
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
if err != nil {
internal.Logger.Printf(ctx, "Connection hook error: %v", err)
p.Remove(ctx, cn, err)
return
}
}
// If hooks say to remove the connection, do so
if shouldRemove {
p.Remove(ctx, cn, errors.New("hook requested removal"))
return
}
// If processor says not to pool the connection, remove it
if !shouldPool {
p.Remove(ctx, cn, errors.New("hook requested no pooling"))
return
}
if !cn.pooled {
p.Remove(ctx, cn, nil)
p.Remove(ctx, cn, errors.New("connection not pooled"))
return
}
var shouldCloseConn bool
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
// unusable conns are expected to become usable at some point (background process is reconnecting them)
// put them at the opposite end of the queue
if !cn.IsUsable() {
if p.cfg.PoolFIFO {
p.connsMu.Lock()
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns {
p.idleConns = append(p.idleConns, cn)
p.idleConnsLen++
p.connsMu.Unlock()
} else {
p.removeConn(cn)
p.connsMu.Lock()
p.idleConns = append([]*Conn{cn}, p.idleConns...)
p.connsMu.Unlock()
}
} else {
p.connsMu.Lock()
p.idleConns = append(p.idleConns, cn)
p.connsMu.Unlock()
}
p.idleConnsLen.Add(1)
} else {
p.removeConnWithLock(cn)
shouldCloseConn = true
}
p.connsMu.Unlock()
p.freeTurn()
if shouldCloseConn {
@@ -425,8 +673,13 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
p.removeConnWithLock(cn)
p.freeTurn()
_ = p.closeConn(cn)
// Check if we need to create new idle connections to maintain MinIdleConns
p.checkMinIdleConns()
}
func (p *ConnPool) CloseConn(cn *Conn) error {
@@ -441,17 +694,23 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) {
}
func (p *ConnPool) removeConn(cn *Conn) {
for i, c := range p.conns {
if c == cn {
p.conns = append(p.conns[:i], p.conns[i+1:]...)
cid := cn.GetID()
delete(p.conns, cid)
atomic.AddUint32(&p.stats.StaleConns, 1)
// Decrement pool size counter when removing a connection
if cn.pooled {
p.poolSize--
p.checkMinIdleConns()
}
p.poolSize.Add(-1)
// the removed connection may also be in the idle list; remove it from there as well
for idx, ic := range p.idleConns {
if ic.GetID() == cid {
internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid)
p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...)
p.idleConnsLen.Add(-1)
break
}
}
atomic.AddUint32(&p.stats.StaleConns, 1)
}
}
func (p *ConnPool) closeConn(cn *Conn) error {
@@ -469,9 +728,9 @@ func (p *ConnPool) Len() int {
// IdleLen returns number of idle connections.
func (p *ConnPool) IdleLen() int {
p.connsMu.Lock()
n := p.idleConnsLen
n := p.idleConnsLen.Load()
p.connsMu.Unlock()
return n
return int(n)
}
func (p *ConnPool) Stats() *Stats {
@@ -480,6 +739,7 @@ func (p *ConnPool) Stats() *Stats {
Misses: atomic.LoadUint32(&p.stats.Misses),
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
WaitCount: atomic.LoadUint32(&p.stats.WaitCount),
Unusable: atomic.LoadUint32(&p.stats.Unusable),
WaitDurationNs: p.waitDurationNs.Load(),
TotalConns: uint32(p.Len()),
@@ -520,28 +780,45 @@ func (p *ConnPool) Close() error {
}
}
p.conns = nil
p.poolSize = 0
p.poolSize.Store(0)
p.idleConns = nil
p.idleConnsLen = 0
p.idleConnsLen.Store(0)
p.connsMu.Unlock()
return firstErr
}
func (p *ConnPool) isHealthyConn(cn *Conn) bool {
now := time.Now()
if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
// slight optimization, check expiresAt first.
if cn.expiresAt.Before(now) {
return false
}
// Check if connection has exceeded idle timeout
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
return false
}
if connCheck(cn.netConn) != nil {
cn.SetUsedAt(now)
// Check basic connection health
// Use GetNetConn() to safely access netConn and avoid data races
if err := connCheck(cn.getNetConn()); err != nil {
// If there's unexpected data, it might be push notifications (RESP3)
// However, push notification processing is now handled by the client
// before WithReader to ensure proper context is available to handlers
if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
// we know that there is something in the buffer, so peek at the next reply type without
// the potential to block
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
// For RESP3 connections with push notifications, we allow some buffered data
// The client will process these notifications before using the connection
internal.Logger.Printf(context.Background(), "push: connection has buffered data, likely push notifications - will be processed by client")
return true // Connection is healthy, client will handle notifications
}
return false // Unexpected data, not push notifications, connection is unhealthy
} else {
return false
}
cn.SetUsedAt(now)
}
return true
}


@@ -1,6 +1,8 @@
package pool
import "context"
import (
"context"
)
type SingleConnPool struct {
pool Pooler
@@ -56,3 +58,7 @@ func (p *SingleConnPool) IdleLen() int {
func (p *SingleConnPool) Stats() *Stats {
return &Stats{}
}
func (p *SingleConnPool) AddPoolHook(hook PoolHook) {}
func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {}
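The no-op AddPoolHook/RemovePoolHook implementations above, together with the ProcessOnGet/ProcessOnPut calls in the ConnPool Get and Put paths, outline the new hook surface for observing connections as they enter and leave the pool. Below is a minimal counting hook, written as if it lived inside the pool package; it assumes the PoolHook interface exposes OnGet/OnPut methods mirroring those call sites, which is an inference from the diff rather than something shown verbatim in it.

package pool

import (
    "context"
    "sync/atomic"
)

// countingHook is an illustrative sketch only (not part of this commit).
// The method set is assumed from the ProcessOnGet/ProcessOnPut call sites
// shown in Get and Put above.
type countingHook struct {
    gets atomic.Int64
    puts atomic.Int64
}

func (h *countingHook) OnGet(ctx context.Context, cn *Conn, isNewConn bool) error {
    h.gets.Add(1)
    return nil // a non-nil error would make the pool discard the connection
}

func (h *countingHook) OnPut(ctx context.Context, cn *Conn) (shouldPool, shouldRemove bool, err error) {
    h.puts.Add(1)
    return true, false, nil // keep the connection in the pool
}

With that interface shape, p.AddPoolHook(&countingHook{}) would register the hook on a ConnPool and RemovePoolHook would detach it, while SingleConnPool and StickyConnPool deliberately no-op both methods.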


@@ -199,3 +199,7 @@ func (p *StickyConnPool) IdleLen() int {
func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}
func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}
func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}


@@ -2,15 +2,17 @@ package pool_test
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/logging"
)
var _ = Describe("ConnPool", func() {
@@ -20,7 +22,7 @@ var _ = Describe("ConnPool", func() {
BeforeEach(func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Hour,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
@@ -45,11 +47,11 @@ var _ = Describe("ConnPool", func() {
<-closedChan
return &net.TCPConn{}, nil
},
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Hour,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
MinIdleConns: minIdleConns,
MinIdleConns: int32(minIdleConns),
})
wg.Wait()
Expect(connPool.Close()).NotTo(HaveOccurred())
@@ -105,7 +107,7 @@ var _ = Describe("ConnPool", func() {
// ok
}
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
// Check that Get is unblocked.
select {
@@ -130,8 +132,8 @@ var _ = Describe("MinIdleConns", func() {
newConnPool := func() *pool.ConnPool {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: poolSize,
MinIdleConns: minIdleConns,
PoolSize: int32(poolSize),
MinIdleConns: int32(minIdleConns),
PoolTimeout: 100 * time.Millisecond,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: -1,
@@ -168,7 +170,7 @@ var _ = Describe("MinIdleConns", func() {
Context("after Remove", func() {
BeforeEach(func() {
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
})
It("has idle connections", func() {
@@ -245,7 +247,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() {
perform(len(cns), func(i int) {
mu.RLock()
connPool.Remove(ctx, cns[i], nil)
connPool.Remove(ctx, cns[i], errors.New("test"))
mu.RUnlock()
})
@@ -309,7 +311,7 @@ var _ = Describe("race", func() {
It("does not happen on Get, Put, and Remove", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Minute,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
@@ -328,7 +330,7 @@ var _ = Describe("race", func() {
cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
}
}
})
@@ -339,15 +341,15 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
},
PoolSize: 1000,
MinIdleConns: 50,
PoolSize: int32(1000),
MinIdleConns: int32(50),
PoolTimeout: 3 * time.Second,
DialTimeout: 1 * time.Second,
}
p := pool.NewConnPool(opt)
var wg sync.WaitGroup
for i := 0; i < opt.PoolSize; i++ {
for i := int32(0); i < opt.PoolSize; i++ {
wg.Add(1)
go func() {
defer wg.Done()
@@ -366,8 +368,8 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
panic("test panic")
},
PoolSize: 100,
MinIdleConns: 30,
PoolSize: int32(100),
MinIdleConns: int32(30),
}
p := pool.NewConnPool(opt)
@@ -384,7 +386,7 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
},
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 3 * time.Second,
}
p := pool.NewConnPool(opt)
@@ -415,7 +417,7 @@ var _ = Describe("race", func() {
return &net.TCPConn{}, nil
},
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: testPoolTimeout,
}
p := pool.NewConnPool(opt)
@@ -435,3 +437,73 @@ var _ = Describe("race", func() {
Expect(stats.Timeouts).To(Equal(uint32(1)))
})
})
// TestDialerRetryConfiguration tests the new DialerRetries and DialerRetryTimeout options
func TestDialerRetryConfiguration(t *testing.T) {
ctx := context.Background()
t.Run("CustomDialerRetries", func(t *testing.T) {
var attempts int64
failingDialer := func(ctx context.Context) (net.Conn, error) {
atomic.AddInt64(&attempts, 1)
return nil, errors.New("dial failed")
}
connPool := pool.NewConnPool(&pool.Options{
Dialer: failingDialer,
PoolSize: 1,
PoolTimeout: time.Second,
DialTimeout: time.Second,
DialerRetries: 3, // Custom retry count
DialerRetryTimeout: 10 * time.Millisecond, // Fast retries for testing
})
defer connPool.Close()
_, err := connPool.Get(ctx)
if err == nil {
t.Error("Expected error from failing dialer")
}
// Should have attempted at least 3 times (DialerRetries = 3)
// There might be additional attempts due to pool logic
finalAttempts := atomic.LoadInt64(&attempts)
if finalAttempts < 3 {
t.Errorf("Expected at least 3 dial attempts, got %d", finalAttempts)
}
if finalAttempts > 6 {
t.Errorf("Expected around 3 dial attempts, got %d (too many)", finalAttempts)
}
})
t.Run("DefaultDialerRetries", func(t *testing.T) {
var attempts int64
failingDialer := func(ctx context.Context) (net.Conn, error) {
atomic.AddInt64(&attempts, 1)
return nil, errors.New("dial failed")
}
connPool := pool.NewConnPool(&pool.Options{
Dialer: failingDialer,
PoolSize: 1,
PoolTimeout: time.Second,
DialTimeout: time.Second,
// DialerRetries and DialerRetryTimeout not set - should use defaults
})
defer connPool.Close()
_, err := connPool.Get(ctx)
if err == nil {
t.Error("Expected error from failing dialer")
}
// Should have attempted 5 times (default DialerRetries = 5)
finalAttempts := atomic.LoadInt64(&attempts)
if finalAttempts != 5 {
t.Errorf("Expected 5 dial attempts (default), got %d", finalAttempts)
}
})
}
func init() {
logging.Disable()
}

internal/pool/pubsub.go (new file)

@@ -0,0 +1,78 @@
package pool
import (
"context"
"net"
"sync"
"sync/atomic"
)
type PubSubStats struct {
Created uint32
Untracked uint32
Active uint32
}
// PubSubPool manages a pool of PubSub connections.
type PubSubPool struct {
opt *Options
netDialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Map to track active PubSub connections
activeConns sync.Map // map[uint64]*Conn (connID -> conn)
closed atomic.Bool
stats PubSubStats
}
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
return &PubSubPool{
opt: opt,
netDialer: netDialer,
}
}
func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
if p.closed.Load() {
return nil, ErrClosed
}
netConn, err := p.netDialer(ctx, network, addr)
if err != nil {
return nil, err
}
cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
cn.pubsub = true
atomic.AddUint32(&p.stats.Created, 1)
return cn, nil
}
func (p *PubSubPool) TrackConn(cn *Conn) {
atomic.AddUint32(&p.stats.Active, 1)
p.activeConns.Store(cn.GetID(), cn)
}
func (p *PubSubPool) UntrackConn(cn *Conn) {
atomic.AddUint32(&p.stats.Active, ^uint32(0))
atomic.AddUint32(&p.stats.Untracked, 1)
p.activeConns.Delete(cn.GetID())
}
func (p *PubSubPool) Close() error {
p.closed.Store(true)
p.activeConns.Range(func(key, value interface{}) bool {
cn := value.(*Conn)
_ = cn.Close()
return true
})
return nil
}
func (p *PubSubPool) Stats() *PubSubStats {
// load stats atomically
return &PubSubStats{
Created: atomic.LoadUint32(&p.stats.Created),
Untracked: atomic.LoadUint32(&p.stats.Untracked),
Active: atomic.LoadUint32(&p.stats.Active),
}
}
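PubSubPool only dials and tracks dedicated connections; nothing is recycled, and Close simply closes whatever is still tracked. The intended call sequence, which also shows up in the osscluster pubSub changes later in this commit, is roughly the sketch below (written as if inside the pool package; the address and channel are placeholders).

// Illustrative sketch of driving PubSubPool (placeholders only, not part of this commit).
func examplePubSubPoolUsage(ctx context.Context, opt *Options, dial func(context.Context, string, string) (net.Conn, error)) error {
    pubsubPool := NewPubSubPool(opt, dial)

    cn, err := pubsubPool.NewConn(ctx, "tcp", "localhost:6379", []string{"mychannel"})
    if err != nil {
        return err
    }
    pubsubPool.TrackConn(cn) // count the connection as active

    // ... perform SUBSCRIBE and run the receive loop on cn ...

    pubsubPool.UntrackConn(cn) // stop tracking before closing
    return cn.Close()
}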


@@ -0,0 +1,614 @@
package proto
import (
"bytes"
"fmt"
"math/rand"
"strings"
"testing"
)
// TestPeekPushNotificationName tests the updated PeekPushNotificationName method
func TestPeekPushNotificationName(t *testing.T) {
t.Run("ValidPushNotifications", func(t *testing.T) {
testCases := []struct {
name string
notification string
expected string
}{
{"MOVING", "MOVING", "MOVING"},
{"MIGRATING", "MIGRATING", "MIGRATING"},
{"MIGRATED", "MIGRATED", "MIGRATED"},
{"FAILING_OVER", "FAILING_OVER", "FAILING_OVER"},
{"FAILED_OVER", "FAILED_OVER", "FAILED_OVER"},
{"message", "message", "message"},
{"pmessage", "pmessage", "pmessage"},
{"subscribe", "subscribe", "subscribe"},
{"unsubscribe", "unsubscribe", "unsubscribe"},
{"psubscribe", "psubscribe", "psubscribe"},
{"punsubscribe", "punsubscribe", "punsubscribe"},
{"smessage", "smessage", "smessage"},
{"ssubscribe", "ssubscribe", "ssubscribe"},
{"sunsubscribe", "sunsubscribe", "sunsubscribe"},
{"custom", "custom", "custom"},
{"short", "a", "a"},
{"empty", "", ""},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
buf := createValidPushNotification(tc.notification, "data")
reader := NewReader(buf)
// Prime the buffer by peeking first
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error for valid notification: %v", err)
}
if name != tc.expected {
t.Errorf("Expected notification name '%s', got '%s'", tc.expected, name)
}
})
}
})
t.Run("NotificationWithMultipleArguments", func(t *testing.T) {
// Create push notification with multiple arguments
buf := createPushNotificationWithArgs("MOVING", "slot", "123", "from", "node1", "to", "node2")
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error: %v", err)
}
if name != "MOVING" {
t.Errorf("Expected 'MOVING', got '%s'", name)
}
})
t.Run("SingleElementNotification", func(t *testing.T) {
// Create push notification with single element
buf := createSingleElementPushNotification("TEST")
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error: %v", err)
}
if name != "TEST" {
t.Errorf("Expected 'TEST', got '%s'", name)
}
})
t.Run("ErrorDetection", func(t *testing.T) {
t.Run("NotPushNotification", func(t *testing.T) {
// Test with regular array instead of push notification
buf := &bytes.Buffer{}
buf.WriteString("*2\r\n$6\r\nMOVING\r\n$4\r\ndata\r\n")
reader := NewReader(buf)
_, err := reader.PeekPushNotificationName()
if err == nil {
t.Error("PeekPushNotificationName should error for non-push notification")
}
// The error might be "no data available" or "can't parse push notification"
if !strings.Contains(err.Error(), "can't peek push notification name") {
t.Errorf("Error should mention push notification parsing, got: %v", err)
}
})
t.Run("InsufficientData", func(t *testing.T) {
// Test with buffer smaller than peek size - this might panic due to bounds checking
buf := &bytes.Buffer{}
buf.WriteString(">")
reader := NewReader(buf)
func() {
defer func() {
if r := recover(); r != nil {
t.Logf("PeekPushNotificationName panicked as expected for insufficient data: %v", r)
}
}()
_, err := reader.PeekPushNotificationName()
if err == nil {
t.Error("PeekPushNotificationName should error for insufficient data")
}
}()
})
t.Run("EmptyBuffer", func(t *testing.T) {
buf := &bytes.Buffer{}
reader := NewReader(buf)
_, err := reader.PeekPushNotificationName()
if err == nil {
t.Error("PeekPushNotificationName should error for empty buffer")
}
})
t.Run("DifferentRESPTypes", func(t *testing.T) {
// Test with different RESP types that should be rejected
respTypes := []byte{'+', '-', ':', '$', '*', '%', '~', '|', '('}
for _, respType := range respTypes {
t.Run(fmt.Sprintf("Type_%c", respType), func(t *testing.T) {
buf := &bytes.Buffer{}
buf.WriteByte(respType)
buf.WriteString("test data that fills the buffer completely")
reader := NewReader(buf)
_, err := reader.PeekPushNotificationName()
if err == nil {
t.Errorf("PeekPushNotificationName should error for RESP type '%c'", respType)
}
// The error might be "no data available" or "can't parse push notification"
if !strings.Contains(err.Error(), "can't peek push notification name") {
t.Errorf("Error should mention push notification parsing, got: %v", err)
}
})
}
})
})
t.Run("EdgeCases", func(t *testing.T) {
t.Run("ZeroLengthArray", func(t *testing.T) {
// Create push notification with zero elements: >0\r\n
buf := &bytes.Buffer{}
buf.WriteString(">0\r\npadding_data_to_fill_buffer_completely")
reader := NewReader(buf)
_, err := reader.PeekPushNotificationName()
if err == nil {
t.Error("PeekPushNotificationName should error for zero-length array")
}
})
t.Run("EmptyNotificationName", func(t *testing.T) {
// Create push notification with empty name: >1\r\n$0\r\n\r\n
buf := createValidPushNotification("", "data")
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error for empty name: %v", err)
}
if name != "" {
t.Errorf("Expected empty notification name, got '%s'", name)
}
})
t.Run("CorruptedData", func(t *testing.T) {
corruptedCases := []struct {
name string
data string
}{
{"CorruptedLength", ">abc\r\n$6\r\nMOVING\r\n"},
{"MissingCRLF", ">2$6\r\nMOVING\r\n$4\r\ndata\r\n"},
{"InvalidStringLength", ">2\r\n$abc\r\nMOVING\r\n$4\r\ndata\r\n"},
{"NegativeStringLength", ">2\r\n$-1\r\n$4\r\ndata\r\n"},
{"IncompleteString", ">1\r\n$6\r\nMOV"},
}
for _, tc := range corruptedCases {
t.Run(tc.name, func(t *testing.T) {
buf := &bytes.Buffer{}
buf.WriteString(tc.data)
reader := NewReader(buf)
// Some corrupted data might not error but return unexpected results
// This is acceptable behavior for malformed input
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Logf("PeekPushNotificationName errored for corrupted data %s: %v (DATA: %s)", tc.name, err, tc.data)
} else {
t.Logf("PeekPushNotificationName returned '%s' for corrupted data NAME: %s, DATA: %s", name, tc.name, tc.data)
}
})
}
})
})
t.Run("BoundaryConditions", func(t *testing.T) {
t.Run("ExactlyPeekSize", func(t *testing.T) {
// Create buffer that is exactly 36 bytes (the peek window size)
buf := &bytes.Buffer{}
// ">1\r\n$4\r\nTEST\r\n" = 14 bytes, need 22 more
buf.WriteString(">1\r\n$4\r\nTEST\r\n1234567890123456789012")
if buf.Len() != 36 {
t.Errorf("Expected buffer length 36, got %d", buf.Len())
}
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should work for exact peek size: %v", err)
}
if name != "TEST" {
t.Errorf("Expected 'TEST', got '%s'", name)
}
})
t.Run("LessThanPeekSize", func(t *testing.T) {
// Create buffer smaller than 36 bytes but with complete notification
buf := createValidPushNotification("TEST", "")
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should work for complete notification: %v", err)
}
if name != "TEST" {
t.Errorf("Expected 'TEST', got '%s'", name)
}
})
t.Run("LongNotificationName", func(t *testing.T) {
// Test with notification name that might exceed peek window
longName := strings.Repeat("A", 20) // 20 character name (safe size)
buf := createValidPushNotification(longName, "data")
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should work for long name: %v", err)
}
if name != longName {
t.Errorf("Expected '%s', got '%s'", longName, name)
}
})
})
}
// Helper functions to create test data
// createValidPushNotification creates a valid RESP3 push notification
func createValidPushNotification(notificationName, data string) *bytes.Buffer {
buf := &bytes.Buffer{}
simpleOrString := rand.Intn(2) == 0
if data == "" {
// Single element notification
buf.WriteString(">1\r\n")
if simpleOrString {
buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
} else {
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
}
} else {
// Two element notification
buf.WriteString(">2\r\n")
if simpleOrString {
buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
buf.WriteString(fmt.Sprintf("+%s\r\n", data))
} else {
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
}
}
return buf
}
// createReaderWithPrimedBuffer creates a reader and primes the buffer
func createReaderWithPrimedBuffer(buf *bytes.Buffer) *Reader {
reader := NewReader(buf)
// Prime the buffer by peeking first
_, _ = reader.rd.Peek(1)
return reader
}
// createPushNotificationWithArgs creates a push notification with multiple arguments
func createPushNotificationWithArgs(notificationName string, args ...string) *bytes.Buffer {
buf := &bytes.Buffer{}
totalElements := 1 + len(args)
buf.WriteString(fmt.Sprintf(">%d\r\n", totalElements))
// Write notification name
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
// Write arguments
for _, arg := range args {
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(arg), arg))
}
return buf
}
// createSingleElementPushNotification creates a push notification with single element
func createSingleElementPushNotification(notificationName string) *bytes.Buffer {
buf := &bytes.Buffer{}
buf.WriteString(">1\r\n")
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
return buf
}
// BenchmarkPeekPushNotificationName benchmarks the method performance
func BenchmarkPeekPushNotificationName(b *testing.B) {
testCases := []struct {
name string
notification string
}{
{"Short", "TEST"},
{"Medium", "MOVING_NOTIFICATION"},
{"Long", "VERY_LONG_NOTIFICATION_NAME_FOR_TESTING"},
}
for _, tc := range testCases {
b.Run(tc.name, func(b *testing.B) {
buf := createValidPushNotification(tc.notification, "data")
data := buf.Bytes()
b.ResetTimer()
for i := 0; i < b.N; i++ {
reader := NewReader(bytes.NewReader(data))
_, err := reader.PeekPushNotificationName()
if err != nil {
b.Errorf("PeekPushNotificationName should not error: %v", err)
}
}
})
}
}
// TestPeekPushNotificationNameSpecialCases tests special cases and realistic scenarios
func TestPeekPushNotificationNameSpecialCases(t *testing.T) {
t.Run("RealisticNotifications", func(t *testing.T) {
// Test realistic Redis push notifications
realisticCases := []struct {
name string
notification []string
expected string
}{
{"MovingSlot", []string{"MOVING", "slot", "123", "from", "127.0.0.1:7000", "to", "127.0.0.1:7001"}, "MOVING"},
{"MigratingSlot", []string{"MIGRATING", "slot", "456", "from", "127.0.0.1:7001", "to", "127.0.0.1:7002"}, "MIGRATING"},
{"MigratedSlot", []string{"MIGRATED", "slot", "789", "from", "127.0.0.1:7002", "to", "127.0.0.1:7000"}, "MIGRATED"},
{"FailingOver", []string{"FAILING_OVER", "node", "127.0.0.1:7000"}, "FAILING_OVER"},
{"FailedOver", []string{"FAILED_OVER", "node", "127.0.0.1:7000"}, "FAILED_OVER"},
{"PubSubMessage", []string{"message", "mychannel", "hello world"}, "message"},
{"PubSubPMessage", []string{"pmessage", "pattern*", "mychannel", "hello world"}, "pmessage"},
{"Subscribe", []string{"subscribe", "mychannel", "1"}, "subscribe"},
{"Unsubscribe", []string{"unsubscribe", "mychannel", "0"}, "unsubscribe"},
}
for _, tc := range realisticCases {
t.Run(tc.name, func(t *testing.T) {
buf := createPushNotificationWithArgs(tc.notification[0], tc.notification[1:]...)
reader := createReaderWithPrimedBuffer(buf)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error for %s: %v", tc.name, err)
}
if name != tc.expected {
t.Errorf("Expected '%s', got '%s'", tc.expected, name)
}
})
}
})
t.Run("SpecialCharactersInName", func(t *testing.T) {
specialCases := []struct {
name string
notification string
}{
{"WithUnderscore", "test_notification"},
{"WithDash", "test-notification"},
{"WithNumbers", "test123"},
{"WithDots", "test.notification"},
{"WithColon", "test:notification"},
{"WithSlash", "test/notification"},
{"MixedCase", "TestNotification"},
{"AllCaps", "TESTNOTIFICATION"},
{"AllLower", "testnotification"},
{"Unicode", "tëst"},
}
for _, tc := range specialCases {
t.Run(tc.name, func(t *testing.T) {
buf := createValidPushNotification(tc.notification, "data")
reader := createReaderWithPrimedBuffer(buf)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error for '%s': %v", tc.notification, err)
}
if name != tc.notification {
t.Errorf("Expected '%s', got '%s'", tc.notification, name)
}
})
}
})
t.Run("IdempotentPeek", func(t *testing.T) {
// Test that multiple peeks return the same result
buf := createValidPushNotification("MOVING", "data")
reader := createReaderWithPrimedBuffer(buf)
// First peek
name1, err1 := reader.PeekPushNotificationName()
if err1 != nil {
t.Errorf("First PeekPushNotificationName should not error: %v", err1)
}
// Second peek should return the same result
name2, err2 := reader.PeekPushNotificationName()
if err2 != nil {
t.Errorf("Second PeekPushNotificationName should not error: %v", err2)
}
if name1 != name2 {
t.Errorf("Peek should be idempotent: first='%s', second='%s'", name1, name2)
}
if name1 != "MOVING" {
t.Errorf("Expected 'MOVING', got '%s'", name1)
}
})
}
// TestPeekPushNotificationNamePerformance tests performance characteristics
func TestPeekPushNotificationNamePerformance(t *testing.T) {
t.Run("RepeatedCalls", func(t *testing.T) {
// Test that repeated calls work correctly
buf := createValidPushNotification("TEST", "data")
reader := createReaderWithPrimedBuffer(buf)
// Call multiple times
for i := 0; i < 10; i++ {
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error on call %d: %v", i, err)
}
if name != "TEST" {
t.Errorf("Expected 'TEST' on call %d, got '%s'", i, name)
}
}
})
t.Run("LargeNotifications", func(t *testing.T) {
// Test with large notification data
largeData := strings.Repeat("x", 1000)
buf := createValidPushNotification("LARGE", largeData)
reader := createReaderWithPrimedBuffer(buf)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error for large notification: %v", err)
}
if name != "LARGE" {
t.Errorf("Expected 'LARGE', got '%s'", name)
}
})
}
// TestPeekPushNotificationNameBehavior documents the method's behavior
func TestPeekPushNotificationNameBehavior(t *testing.T) {
t.Run("MethodBehavior", func(t *testing.T) {
// Test that the method works as intended:
// 1. Peek at the buffer without consuming it
// 2. Detect push notifications (RESP type '>')
// 3. Extract the notification name from the first element
// 4. Return the name for filtering decisions
buf := createValidPushNotification("MOVING", "slot_data")
reader := createReaderWithPrimedBuffer(buf)
// Peek should not consume the buffer
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error: %v", err)
}
if name != "MOVING" {
t.Errorf("Expected 'MOVING', got '%s'", name)
}
// Buffer should still be available for normal reading
replyType, err := reader.PeekReplyType()
if err != nil {
t.Errorf("PeekReplyType should work after PeekPushNotificationName: %v", err)
}
if replyType != RespPush {
t.Errorf("Expected RespPush, got %v", replyType)
}
})
t.Run("BufferNotConsumed", func(t *testing.T) {
// Verify that peeking doesn't consume the buffer
buf := createValidPushNotification("TEST", "data")
originalData := buf.Bytes()
reader := createReaderWithPrimedBuffer(buf)
// Peek the notification name
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error: %v", err)
}
if name != "TEST" {
t.Errorf("Expected 'TEST', got '%s'", name)
}
// Read the actual notification
reply, err := reader.ReadReply()
if err != nil {
t.Errorf("ReadReply should work after peek: %v", err)
}
// Verify we got the complete notification
if replySlice, ok := reply.([]interface{}); ok {
if len(replySlice) != 2 {
t.Errorf("Expected 2 elements, got %d", len(replySlice))
}
if replySlice[0] != "TEST" {
t.Errorf("Expected 'TEST', got %v", replySlice[0])
}
} else {
t.Errorf("Expected slice reply, got %T", reply)
}
// Verify buffer was properly consumed
if buf.Len() != 0 {
t.Errorf("Buffer should be empty after reading, but has %d bytes: %q", buf.Len(), buf.Bytes())
}
t.Logf("Original buffer size: %d bytes", len(originalData))
t.Logf("Successfully peeked and then read complete notification")
})
t.Run("ImplementationSuccess", func(t *testing.T) {
// Document that the implementation is now working correctly
t.Log("PeekPushNotificationName implementation status:")
t.Log("1. ✅ Correctly parses RESP3 push notifications")
t.Log("2. ✅ Extracts notification names properly")
t.Log("3. ✅ Handles buffer peeking without consumption")
t.Log("4. ✅ Works with various notification types")
t.Log("5. ✅ Supports empty notification names")
t.Log("")
t.Log("RESP3 format parsing:")
t.Log(">2\\r\\n$6\\r\\nMOVING\\r\\n$4\\r\\ndata\\r\\n")
t.Log("✅ Correctly identifies push notification marker (>)")
t.Log("✅ Skips array length (2)")
t.Log("✅ Parses string marker ($) and length (6)")
t.Log("✅ Extracts notification name (MOVING)")
t.Log("✅ Returns name without consuming buffer")
t.Log("")
t.Log("Note: Buffer must be primed with a peek operation first")
})
}


@@ -99,6 +99,92 @@ func (r *Reader) PeekReplyType() (byte, error) {
return b[0], nil
}
func (r *Reader) PeekPushNotificationName() (string, error) {
// "prime" the buffer by peeking at the next byte
c, err := r.Peek(1)
if err != nil {
return "", err
}
if c[0] != RespPush {
return "", fmt.Errorf("redis: can't peek push notification name, next reply is not a push notification")
}
// peek 36 bytes at most, should be enough to read the push notification name
toPeek := 36
buffered := r.Buffered()
if buffered == 0 {
return "", fmt.Errorf("redis: can't peek push notification name, no data available")
}
if buffered < toPeek {
toPeek = buffered
}
buf, err := r.rd.Peek(toPeek)
if err != nil {
return "", err
}
if buf[0] != RespPush {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
if len(buf) < 3 {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
// remove push notification type
buf = buf[1:]
// remove first line - e.g. >2\r\n
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
buf = buf[i+2:]
break
} else {
if buf[i] < '0' || buf[i] > '9' {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
}
}
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
// next line should be the notification name: $<length>\r\n<name>\r\n (bulk string) or +<name>\r\n (simple string)
// it carries the type of the push notification name and, for bulk strings, its length
if buf[0] != RespString && buf[0] != RespStatus {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
typeOfName := buf[0]
// remove the type of the push notification name
buf = buf[1:]
if typeOfName == RespString {
// remove the length of the string
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
buf = buf[i+2:]
break
} else {
if buf[i] < '0' || buf[i] > '9' {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
}
}
}
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
// keep only the notification name
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
buf = buf[:i]
break
}
}
return util.BytesToString(buf), nil
}
// ReadLine Return a valid reply, it will check the protocol or redis error,
// and discard the attribute type.
func (r *Reader) ReadLine() ([]byte, error) {

internal/redis.go (new file)

@@ -0,0 +1,3 @@
package internal
const RedisNull = "null"


@@ -28,3 +28,14 @@ func MustParseFloat(s string) float64 {
}
return f
}
// SafeIntToInt32 safely converts an int to int32, returning an error if overflow would occur.
func SafeIntToInt32(value int, fieldName string) (int32, error) {
if value > math.MaxInt32 {
return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32)
}
if value < math.MinInt32 {
return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32)
}
return int32(value), nil
}

internal/util/math.go (new file)

@@ -0,0 +1,17 @@
package util
// Max returns the maximum of two integers
func Max(a, b int) int {
if a > b {
return a
}
return b
}
// Min returns the minimum of two integers
func Min(a, b int) int {
if a < b {
return a
}
return b
}


@@ -16,6 +16,8 @@ import (
. "github.com/bsm/gomega"
)
var ctx = context.TODO()
var _ = Describe("newClusterState", func() {
var state *clusterState

logging/logging.go (new file)

@@ -0,0 +1,121 @@
// Package logging provides logging level constants and utilities for the go-redis library.
// This package centralizes logging configuration to ensure consistency across all components.
package logging
import (
"context"
"fmt"
"strings"
"github.com/redis/go-redis/v9/internal"
)
// LogLevel represents the logging level
type LogLevel int
// Log level constants for the entire go-redis library
const (
LogLevelError LogLevel = iota // 0 - errors only
LogLevelWarn // 1 - warnings and errors
LogLevelInfo // 2 - info, warnings, and errors
LogLevelDebug // 3 - debug, info, warnings, and errors
)
// String returns the string representation of the log level
func (l LogLevel) String() string {
switch l {
case LogLevelError:
return "ERROR"
case LogLevelWarn:
return "WARN"
case LogLevelInfo:
return "INFO"
case LogLevelDebug:
return "DEBUG"
default:
return "UNKNOWN"
}
}
// IsValid returns true if the log level is valid
func (l LogLevel) IsValid() bool {
return l >= LogLevelError && l <= LogLevelDebug
}
func (l LogLevel) WarnOrAbove() bool {
return l >= LogLevelWarn
}
func (l LogLevel) InfoOrAbove() bool {
return l >= LogLevelInfo
}
func (l LogLevel) DebugOrAbove() bool {
return l >= LogLevelDebug
}
// VoidLogger is a logger that does nothing.
// Used to disable logging and thus speed up the library.
type VoidLogger struct{}
func (v *VoidLogger) Printf(_ context.Context, _ string, _ ...interface{}) {
// do nothing
}
// Disable disables logging by setting the internal logger to a void logger.
// This can be used to speed up the library if logging is not needed.
// It will override any custom logger that was set before and set the VoidLogger.
func Disable() {
internal.Logger = &VoidLogger{}
}
// Enable enables logging by setting the internal logger to the default logger.
// This is the default behavior.
// You can use redis.SetLogger to set a custom logger.
//
// NOTE: This function is not thread-safe.
// It will override any custom logger that was set before and set the DefaultLogger.
func Enable() {
internal.Logger = internal.NewDefaultLogger()
}
// NewBlacklistLogger returns a new logger that filters out messages containing any of the substrings.
// This can be used to filter out messages containing sensitive information.
func NewBlacklistLogger(substr []string) internal.Logging {
l := internal.NewDefaultLogger()
return &filterLogger{logger: l, substr: substr, blacklist: true}
}
// NewWhitelistLogger returns a new logger that only logs messages containing any of the substrings.
// This can be used to only log messages related to specific commands or patterns.
func NewWhitelistLogger(substr []string) internal.Logging {
l := internal.NewDefaultLogger()
return &filterLogger{logger: l, substr: substr, blacklist: false}
}
type filterLogger struct {
logger internal.Logging
blacklist bool
substr []string
}
func (l *filterLogger) Printf(ctx context.Context, format string, v ...interface{}) {
msg := fmt.Sprintf(format, v...)
found := false
for _, substr := range l.substr {
if strings.Contains(msg, substr) {
found = true
if l.blacklist {
return
}
}
}
// whitelist, only log if one of the substrings is present
if !l.blacklist && !found {
return
}
if l.logger != nil {
l.logger.Printf(ctx, format, v...)
return
}
}
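The filter loggers wrap the default logger and are installed the same way as any custom logger, via redis.SetLogger (as the Enable comment above notes). A brief usage sketch follows; the filter substrings are examples only.

// Illustrative wiring of the new logging package (not part of this commit).
package main

import (
    "github.com/redis/go-redis/v9"
    "github.com/redis/go-redis/v9/logging"
)

func main() {
    // Silence library logging entirely (the test suites in this commit do this):
    logging.Disable()

    // Or only log messages mentioning push notifications or hitless upgrades:
    redis.SetLogger(logging.NewWhitelistLogger([]string{"push:", "hitless"}))

    // Or drop messages that might contain sensitive values:
    redis.SetLogger(logging.NewBlacklistLogger([]string{"AUTH"}))
}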

logging/logging_test.go (new file)

@@ -0,0 +1,59 @@
package logging
import "testing"
func TestLogLevel_String(t *testing.T) {
tests := []struct {
level LogLevel
expected string
}{
{LogLevelError, "ERROR"},
{LogLevelWarn, "WARN"},
{LogLevelInfo, "INFO"},
{LogLevelDebug, "DEBUG"},
{LogLevel(99), "UNKNOWN"},
}
for _, test := range tests {
if got := test.level.String(); got != test.expected {
t.Errorf("LogLevel(%d).String() = %q, want %q", test.level, got, test.expected)
}
}
}
func TestLogLevel_IsValid(t *testing.T) {
tests := []struct {
level LogLevel
expected bool
}{
{LogLevelError, true},
{LogLevelWarn, true},
{LogLevelInfo, true},
{LogLevelDebug, true},
{LogLevel(-1), false},
{LogLevel(4), false},
{LogLevel(99), false},
}
for _, test := range tests {
if got := test.level.IsValid(); got != test.expected {
t.Errorf("LogLevel(%d).IsValid() = %v, want %v", test.level, got, test.expected)
}
}
}
func TestLogLevelConstants(t *testing.T) {
// Test that constants have expected values
if LogLevelError != 0 {
t.Errorf("LogLevelError = %d, want 0", LogLevelError)
}
if LogLevelWarn != 1 {
t.Errorf("LogLevelWarn = %d, want 1", LogLevelWarn)
}
if LogLevelInfo != 2 {
t.Errorf("LogLevelInfo = %d, want 2", LogLevelInfo)
}
if LogLevelDebug != 3 {
t.Errorf("LogLevelDebug = %d, want 3", LogLevelDebug)
}
}


@@ -13,6 +13,7 @@ import (
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/logging"
)
const (
@@ -102,6 +103,7 @@ var _ = BeforeSuite(func() {
fmt.Printf("RCEDocker: %v\n", RCEDocker)
fmt.Printf("REDIS_VERSION: %.1f\n", RedisVersion)
fmt.Printf("CLIENT_LIBS_TEST_IMAGE: %v\n", os.Getenv("CLIENT_LIBS_TEST_IMAGE"))
logging.Disable()
if RedisVersion < 7.0 || RedisVersion > 9 {
panic("incorrect or not supported redis version")


@@ -14,8 +14,11 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/push"
)
// Limiter is the interface of a rate limiter or a circuit breaker.
@@ -109,6 +112,16 @@ type Options struct {
// default: 5 seconds
DialTimeout time.Duration
// DialerRetries is the maximum number of retry attempts when dialing fails.
//
// default: 5
DialerRetries int
// DialerRetryTimeout is the backoff duration between retry attempts.
//
// default: 100 milliseconds
DialerRetryTimeout time.Duration
// ReadTimeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking. Supported values:
//
@@ -152,6 +165,7 @@ type Options struct {
//
// Note that FIFO has slightly higher overhead compared to LIFO,
// but it helps closing idle connections faster reducing the pool size.
// default: false
PoolFIFO bool
// PoolSize is the base number of socket connections.
@@ -232,12 +246,30 @@ type Options struct {
// When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult
UnstableResp3 bool
// Push notifications are always enabled for RESP3 connections (Protocol: 3)
// and are not available for RESP2 connections. No configuration option is needed.
// PushNotificationProcessor is the processor for handling push notifications.
// If nil, a default processor will be created for RESP3 connections.
PushNotificationProcessor push.NotificationProcessor
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
// HitlessUpgradeConfig provides configuration options for hitless upgrades.
// This is an alias to hitless.Config for convenience.
type HitlessUpgradeConfig = hitless.Config
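Since the dialer-retry and hitless-upgrade settings are plain Options fields, enabling them is just a matter of setting them at construction time. A hedged sketch using only fields visible in this diff (defaults taken from init below; leaving HitlessUpgradeConfig nil selects the "auto" mode described above):

// Illustrative client configuration for the options added in this commit.
rdb := redis.NewClient(&redis.Options{
    Addr:                 "localhost:6379",
    Protocol:             3,                     // RESP3, required for push notifications
    DialerRetries:        3,                     // default is 5
    DialerRetryTimeout:   50 * time.Millisecond, // default is 100ms
    HitlessUpgradeConfig: nil,                   // nil = "auto": enabled if the server supports it
})
defer rdb.Close()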
func (opt *Options) init() {
if opt.Addr == "" {
opt.Addr = "localhost:6379"
@@ -255,6 +287,12 @@ func (opt *Options) init() {
if opt.DialTimeout == 0 {
opt.DialTimeout = 5 * time.Second
}
if opt.DialerRetries == 0 {
opt.DialerRetries = 5
}
if opt.DialerRetryTimeout == 0 {
opt.DialerRetryTimeout = 100 * time.Millisecond
}
if opt.Dialer == nil {
opt.Dialer = NewDialer(opt)
}
@@ -312,13 +350,36 @@ func (opt *Options) init() {
case 0:
opt.MaxRetryBackoff = 512 * time.Millisecond
}
opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns)
// auto-detect endpoint type if not specified
endpointType := opt.HitlessUpgradeConfig.EndpointType
if endpointType == "" || endpointType == hitless.EndpointTypeAuto {
// Auto-detect endpoint type if not specified
endpointType = hitless.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
}
opt.HitlessUpgradeConfig.EndpointType = endpointType
}
func (opt *Options) clone() *Options {
clone := *opt
// Deep clone HitlessUpgradeConfig to avoid sharing between clients
if opt.HitlessUpgradeConfig != nil {
configClone := *opt.HitlessUpgradeConfig
clone.HitlessUpgradeConfig = &configClone
}
return &clone
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
return NewDialer(opt)
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
@@ -604,21 +665,84 @@ func getUserPassword(u *url.URL) (string, string) {
func newConnPool(
opt *Options,
dialer func(ctx context.Context, network, addr string) (net.Conn, error),
) *pool.ConnPool {
) (*pool.ConnPool, error) {
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
if err != nil {
return nil, err
}
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
if err != nil {
return nil, err
}
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
if err != nil {
return nil, err
}
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
if err != nil {
return nil, err
}
return pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return dialer(ctx, opt.Network, opt.Addr)
},
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
PoolSize: poolSize,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
MinIdleConns: opt.MinIdleConns,
MaxIdleConns: opt.MaxIdleConns,
MaxActiveConns: opt.MaxActiveConns,
DialerRetries: opt.DialerRetries,
DialerRetryTimeout: opt.DialerRetryTimeout,
MinIdleConns: minIdleConns,
MaxIdleConns: maxIdleConns,
MaxActiveConns: maxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
})
PushNotificationsEnabled: opt.Protocol == 3,
}), nil
}
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
) (*pool.PubSubPool, error) {
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
if err != nil {
return nil, err
}
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
if err != nil {
return nil, err
}
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
if err != nil {
return nil, err
}
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
if err != nil {
return nil, err
}
return pool.NewPubSubPool(&pool.Options{
PoolFIFO: opt.PoolFIFO,
PoolSize: poolSize,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
DialerRetries: opt.DialerRetries,
DialerRetryTimeout: opt.DialerRetryTimeout,
MinIdleConns: minIdleConns,
MaxIdleConns: maxIdleConns,
MaxActiveConns: maxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: 32 * 1024,
WriteBufferSize: 32 * 1024,
PushNotificationsEnabled: opt.Protocol == 3,
}, dialer), nil
}


@@ -20,6 +20,7 @@ import (
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/rand"
"github.com/redis/go-redis/v9/push"
)
const (
@@ -38,6 +39,7 @@ type ClusterOptions struct {
ClientName string
// NewClient creates a cluster node client with provided name and options.
// If NewClient is set by the user, the user is responsible for handling hitless upgrades and push notifications.
NewClient func(opt *Options) *Client
// The maximum number of retries before giving up. Command is retried
@@ -125,10 +127,22 @@ type ClusterOptions struct {
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
UnstableResp3 bool
// PushNotificationProcessor is the processor for handling push notifications.
// If nil, a default processor will be created for RESP3 connections.
PushNotificationProcessor push.NotificationProcessor
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
// The ClusterClient does not handle hitless upgrades itself; the per-node clients in the Nodes map do.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
func (opt *ClusterOptions) init() {
@@ -319,6 +333,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
}
func (opt *ClusterOptions) clientOptions() *Options {
// Clone HitlessUpgradeConfig to avoid sharing between cluster node clients
var hitlessConfig *HitlessUpgradeConfig
if opt.HitlessUpgradeConfig != nil {
configClone := *opt.HitlessUpgradeConfig
hitlessConfig = &configClone
}
return &Options{
ClientName: opt.ClientName,
Dialer: opt.Dialer,
@@ -362,6 +383,8 @@ func (opt *ClusterOptions) clientOptions() *Options {
// situations in the options below will prevent that from happening.
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
HitlessUpgradeConfig: hitlessConfig,
PushNotificationProcessor: opt.PushNotificationProcessor,
}
}
@@ -1664,7 +1687,7 @@ func (c *ClusterClient) processTxPipelineNode(
}
func (c *ClusterClient) processTxPipelineNodeConn(
ctx context.Context, _ *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
) error {
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
@@ -1682,7 +1705,7 @@ func (c *ClusterClient) processTxPipelineNodeConn(
trimmedCmds := cmds[1 : len(cmds)-1]
if err := c.txPipelineReadQueued(
ctx, rd, statusCmd, trimmedCmds, failedCmds,
ctx, node, cn, rd, statusCmd, trimmedCmds, failedCmds,
); err != nil {
setCmdsErr(cmds, err)
@@ -1694,23 +1717,37 @@ func (c *ClusterClient) processTxPipelineNodeConn(
return err
}
return pipelineReadCmds(rd, trimmedCmds)
return node.Client.pipelineReadCmds(ctx, cn, rd, trimmedCmds)
})
}
func (c *ClusterClient) txPipelineReadQueued(
ctx context.Context,
node *clusterNode,
cn *pool.Conn,
rd *proto.Reader,
statusCmd *StatusCmd,
cmds []Cmder,
failedCmds *cmdsMap,
) error {
// Parse queued replies.
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
if err := statusCmd.readReply(rd); err != nil {
return err
}
for _, cmd := range cmds {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
err := statusCmd.readReply(rd)
if err != nil {
if c.checkMovedErr(ctx, cmd, err, failedCmds) {
@@ -1724,6 +1761,12 @@ func (c *ClusterClient) txPipelineReadQueued(
}
}
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
// Parse number of replies.
line, err := rd.ReadLine()
if err != nil {
@@ -1829,12 +1872,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
return err
}
// hitless won't work here for now
func (c *ClusterClient) pubSub() *PubSub {
var node *clusterNode
pubsub := &PubSub{
opt: c.opt.clientOptions(),
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
if node != nil {
panic("node != nil")
}
@@ -1868,18 +1911,25 @@ func (c *ClusterClient) pubSub() *PubSub {
return nil, err
}
}
cn, err := node.Client.newConn(context.TODO())
cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
if err != nil {
node = nil
return nil, err
}
// will return nil if already initialized
err = node.Client.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
node = nil
return nil, err
}
node.Client.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: func(cn *pool.Conn) error {
err := node.Client.connPool.CloseConn(cn)
// Untrack connection from PubSubPool
node.Client.pubSubPool.UntrackConn(cn)
err := cn.Close()
node = nil
return err
},

pool_pubsub_bench_test.go (new file)

@@ -0,0 +1,375 @@
// Pool and PubSub Benchmark Suite
//
// This file contains comprehensive benchmarks for both pool operations and PubSub initialization.
// It's designed to be run against different branches to compare performance.
//
// Usage Examples:
// # Run all benchmarks
// go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go
//
// # Run only pool benchmarks
// go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go
//
// # Run only PubSub benchmarks
// go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go
//
// # Compare between branches
// git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt
// git checkout branch2 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch2.txt
// benchcmp branch1.txt branch2.txt
//
// # Run with memory profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go
//
// # Run with CPU profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go
package redis_test
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/internal/pool"
)
// dummyDialer creates a mock connection for benchmarking
func dummyDialer(ctx context.Context) (net.Conn, error) {
return &dummyConn{}, nil
}
// dummyConn implements net.Conn for benchmarking
type dummyConn struct{}
func (c *dummyConn) Read(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Close() error { return nil }
func (c *dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} }
func (c *dummyConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379}
}
func (c *dummyConn) SetDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetReadDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil }
// =============================================================================
// POOL BENCHMARKS
// =============================================================================
// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations
func BenchmarkPoolGetPut(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(poolSize),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0), // Start with no idle connections
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with MinIdleConns
func BenchmarkPoolGetPutWithMinIdle(b *testing.B) {
ctx := context.Background()
configs := []struct {
poolSize int
minIdleConns int
}{
{8, 2},
{16, 4},
{32, 8},
{64, 16},
}
for _, config := range configs {
b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(config.poolSize),
MinIdleConns: int32(config.minIdleConns),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency
func BenchmarkPoolConcurrentGetPut(b *testing.B) {
ctx := context.Background()
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(32),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0),
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
// Test with different levels of concurrency
concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64}
for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.SetParallelism(concurrency)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// =============================================================================
// PUBSUB BENCHMARKS
// =============================================================================
// benchmarkClient creates a Redis client for benchmarking with mock dialer
func benchmarkClient(poolSize int) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: "localhost:6379", // Mock address
DialTimeout: time.Second,
ReadTimeout: time.Second,
WriteTimeout: time.Second,
PoolSize: poolSize,
MinIdleConns: 0, // Start with no idle connections for consistent benchmarks
})
}
// BenchmarkPubSubCreation benchmarks PubSub creation and subscription
func BenchmarkPubSubCreation(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 4, 8, 16, 32}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}
}
// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription
func BenchmarkPubSubPatternCreation(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 4, 8, 16, 32}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.PSubscribe(ctx, "test-*")
pubsub.Close()
}
})
}
}
// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation
func BenchmarkPubSubConcurrentCreation(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
concurrencyLevels := []int{1, 2, 4, 8, 16}
for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
var wg sync.WaitGroup
semaphore := make(chan struct{}, concurrency)
for i := 0; i < b.N; i++ {
wg.Add(1)
semaphore <- struct{}{}
go func() {
defer wg.Done()
defer func() { <-semaphore }()
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}()
}
wg.Wait()
})
}
}
// BenchmarkPubSubMultipleChannels benchmarks subscribing to multiple channels
func BenchmarkPubSubMultipleChannels(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()
channelCounts := []int{1, 5, 10, 25, 50, 100}
for _, channelCount := range channelCounts {
b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) {
// Prepare channel names
channels := make([]string, channelCount)
for i := 0; i < channelCount; i++ {
channels[i] = fmt.Sprintf("channel-%d", i)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, channels...)
pubsub.Close()
}
})
}
}
// BenchmarkPubSubReuse benchmarks reusing PubSub connections
func BenchmarkPubSubReuse(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Benchmark just the creation and closing of PubSub connections
// This simulates reuse patterns without requiring actual Redis operations
pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i))
pubsub.Close()
}
}
// =============================================================================
// COMBINED BENCHMARKS
// =============================================================================
// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations
func BenchmarkPoolAndPubSubMixed(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// Mix of pool stats collection and PubSub creation
if pb.Next() {
// Pool stats operation
stats := client.PoolStats()
_ = stats.Hits + stats.Misses // Use the stats to prevent optimization
}
if pb.Next() {
// PubSub operation
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
}
})
}
// BenchmarkPoolStatsCollection benchmarks pool statistics collection
func BenchmarkPoolStatsCollection(b *testing.B) {
client := benchmarkClient(16)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
stats := client.PoolStats()
_ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization
}
}
// BenchmarkPoolHighContention tests pool performance under high contention
func BenchmarkPoolHighContention(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// High contention Get/Put operations
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/push"
)
// PubSub implements Pub/Sub commands as described in
@@ -21,7 +22,7 @@ import (
type PubSub struct {
opt *Options
newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
newConn func(ctx context.Context, addr string, channels []string) (*pool.Conn, error)
closeConn func(*pool.Conn) error
mu sync.Mutex
@@ -38,6 +39,12 @@ type PubSub struct {
chOnce sync.Once
msgCh *channel
allCh *channel
// Push notification processor for handling generic push notifications
pushProcessor push.NotificationProcessor
// Cleanup callback for hitless upgrade tracking
onClose func()
}
func (c *PubSub) init() {
@@ -69,10 +76,18 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er
return c.cn, nil
}
if c.opt.Addr == "" {
// TODO(hitless):
// this is probably cluster client
// c.newConn will ignore the addr argument
// will be changed when we have hitless upgrades for cluster clients
c.opt.Addr = internal.RedisNull
}
channels := mapKeys(c.channels)
channels = append(channels, newChannels...)
cn, err := c.newConn(ctx, channels)
cn, err := c.newConn(ctx, c.opt.Addr, channels)
if err != nil {
return nil, err
}
@@ -153,12 +168,31 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo
if c.cn != cn {
return
}
if !cn.IsUsable() || cn.ShouldHandoff() {
c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable"))
}
if isBadConn(err, allowTimeout, c.opt.Addr) {
c.reconnect(ctx, err)
}
}
func (c *PubSub) reconnect(ctx context.Context, reason error) {
if c.cn != nil && c.cn.ShouldHandoff() {
newEndpoint := c.cn.GetHandoffEndpoint()
// If new endpoint is NULL, use the original address
if newEndpoint == internal.RedisNull {
newEndpoint = c.opt.Addr
}
if newEndpoint != "" {
// Update the address in the options
oldAddr := c.cn.RemoteAddr().String()
c.opt.Addr = newEndpoint
internal.Logger.Printf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr)
}
}
_ = c.closeTheCn(reason)
_, _ = c.conn(ctx, nil)
}
@@ -167,9 +201,6 @@ func (c *PubSub) closeTheCn(reason error) error {
if c.cn == nil {
return nil
}
if !c.closed {
internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
}
err := c.closeConn(c.cn)
c.cn = nil
return err
@@ -185,6 +216,11 @@ func (c *PubSub) Close() error {
c.closed = true
close(c.exit)
// Call cleanup callback if set
if c.onClose != nil {
c.onClose()
}
return c.closeTheCn(pool.ErrClosed)
}
@@ -436,9 +472,14 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
}
err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err)
}
return c.cmd.readReply(rd)
})
c.releaseConnWithLock(ctx, cn, err, timeout > 0)
if err != nil {
@@ -451,6 +492,12 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
// Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
// Receive returns a message as a Subscription, Message, Pong, or an error.
// See PubSub example for details. This is a low-level API and in most cases
// Channel should be used instead.
// This method blocks until a message is received or an error occurs.
// It may return early with an error if the context is canceled, the connection fails,
// or other internal errors occur.
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
return c.ReceiveTimeout(ctx, 0)
}
@@ -532,6 +579,27 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac
return c.allCh.allCh
}
func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
// Only process push notifications for RESP3 connections with a processor
if c.opt.Protocol != 3 || c.pushProcessor == nil {
return nil
}
// Create handler context with client, connection pool, and connection information
handlerCtx := c.pushNotificationHandlerContext(cn)
return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd)
}
func (c *PubSub) pushNotificationHandlerContext(cn *pool.Conn) push.NotificationHandlerContext {
// PubSub doesn't have a client or connection pool, so we pass nil for those
// PubSub connections are blocking
return push.NotificationHandlerContext{
PubSub: c,
Conn: cn,
IsBlocking: true,
}
}
type ChannelOption func(c *channel)
// WithChannelSize specifies the Go chan size that is used to buffer incoming messages.

170
push/errors.go Normal file
View File

@@ -0,0 +1,170 @@
package push
import (
"errors"
"fmt"
)
// Push notification error definitions
// This file contains all error types and messages used by the push notification system
// Error reason constants
const (
// HandlerReasons
ReasonHandlerNil = "handler cannot be nil"
ReasonHandlerExists = "cannot overwrite existing handler"
ReasonHandlerProtected = "handler is protected"
// ProcessorReasons
ReasonPushNotificationsDisabled = "push notifications are disabled"
)
// ProcessorType represents the type of processor involved in the error
// defined as a custom type for better readability and easier maintenance
type ProcessorType string
const (
// ProcessorTypes
ProcessorTypeProcessor = ProcessorType("processor")
ProcessorTypeVoidProcessor = ProcessorType("void_processor")
ProcessorTypeCustom = ProcessorType("custom")
)
// ProcessorOperation represents the operation being performed by the processor
// defined as a custom type for better readability and easier maintenance
type ProcessorOperation string
const (
// ProcessorOperations
ProcessorOperationProcess = ProcessorOperation("process")
ProcessorOperationRegister = ProcessorOperation("register")
ProcessorOperationUnregister = ProcessorOperation("unregister")
ProcessorOperationUnknown = ProcessorOperation("unknown")
)
// Common error variables for reuse
var (
// ErrHandlerNil is returned when attempting to register a nil handler
ErrHandlerNil = errors.New(ReasonHandlerNil)
)
// Registry errors
// ErrHandlerExists creates an error for when attempting to overwrite an existing handler
func ErrHandlerExists(pushNotificationName string) error {
return NewHandlerError(ProcessorOperationRegister, pushNotificationName, ReasonHandlerExists, nil)
}
// ErrProtectedHandler creates an error for when attempting to unregister a protected handler
func ErrProtectedHandler(pushNotificationName string) error {
return NewHandlerError(ProcessorOperationUnregister, pushNotificationName, ReasonHandlerProtected, nil)
}
// VoidProcessor errors
// ErrVoidProcessorRegister creates an error for when attempting to register a handler on void processor
func ErrVoidProcessorRegister(pushNotificationName string) error {
return NewProcessorError(ProcessorTypeVoidProcessor, ProcessorOperationRegister, pushNotificationName, ReasonPushNotificationsDisabled, nil)
}
// ErrVoidProcessorUnregister creates an error for when attempting to unregister a handler on void processor
func ErrVoidProcessorUnregister(pushNotificationName string) error {
return NewProcessorError(ProcessorTypeVoidProcessor, ProcessorOperationUnregister, pushNotificationName, ReasonPushNotificationsDisabled, nil)
}
// Error type definitions for advanced error handling
// HandlerError represents errors related to handler operations
type HandlerError struct {
Operation ProcessorOperation
PushNotificationName string
Reason string
Err error
}
func (e *HandlerError) Error() string {
if e.Err != nil {
return fmt.Sprintf("handler %s failed for '%s': %s (%v)", e.Operation, e.PushNotificationName, e.Reason, e.Err)
}
return fmt.Sprintf("handler %s failed for '%s': %s", e.Operation, e.PushNotificationName, e.Reason)
}
func (e *HandlerError) Unwrap() error {
return e.Err
}
// NewHandlerError creates a new HandlerError
func NewHandlerError(operation ProcessorOperation, pushNotificationName, reason string, err error) *HandlerError {
return &HandlerError{
Operation: operation,
PushNotificationName: pushNotificationName,
Reason: reason,
Err: err,
}
}
// ProcessorError represents errors related to processor operations
type ProcessorError struct {
ProcessorType ProcessorType // "processor", "void_processor"
Operation ProcessorOperation // "process", "register", "unregister"
PushNotificationName string // Name of the push notification involved
Reason string
Err error
}
func (e *ProcessorError) Error() string {
notifInfo := ""
if e.PushNotificationName != "" {
notifInfo = fmt.Sprintf(" for '%s'", e.PushNotificationName)
}
if e.Err != nil {
return fmt.Sprintf("%s %s failed%s: %s (%v)", e.ProcessorType, e.Operation, notifInfo, e.Reason, e.Err)
}
return fmt.Sprintf("%s %s failed%s: %s", e.ProcessorType, e.Operation, notifInfo, e.Reason)
}
func (e *ProcessorError) Unwrap() error {
return e.Err
}
// NewProcessorError creates a new ProcessorError
func NewProcessorError(processorType ProcessorType, operation ProcessorOperation, pushNotificationName, reason string, err error) *ProcessorError {
return &ProcessorError{
ProcessorType: processorType,
Operation: operation,
PushNotificationName: pushNotificationName,
Reason: reason,
Err: err,
}
}
// Helper functions for common error scenarios
// IsHandlerNilError checks if an error is due to a nil handler
func IsHandlerNilError(err error) bool {
return errors.Is(err, ErrHandlerNil)
}
// IsHandlerExistsError checks if an error is due to attempting to overwrite an existing handler
func IsHandlerExistsError(err error) bool {
if handlerErr, ok := err.(*HandlerError); ok {
return handlerErr.Operation == ProcessorOperationRegister && handlerErr.Reason == ReasonHandlerExists
}
return false
}
// IsProtectedHandlerError checks if an error is due to attempting to unregister a protected handler
func IsProtectedHandlerError(err error) bool {
if handlerErr, ok := err.(*HandlerError); ok {
return handlerErr.Operation == ProcessorOperationUnregister && handlerErr.Reason == ReasonHandlerProtected
}
return false
}
// IsVoidProcessorError checks if an error is due to void processor operations
func IsVoidProcessorError(err error) bool {
if procErr, ok := err.(*ProcessorError); ok {
return procErr.ProcessorType == ProcessorTypeVoidProcessor && procErr.Reason == ReasonPushNotificationsDisabled
}
return false
}

14
push/handler.go Normal file
View File

@@ -0,0 +1,14 @@
package push
import (
"context"
)
// NotificationHandler defines the interface for push notification handlers.
type NotificationHandler interface {
// HandlePushNotification processes a push notification with context information.
// The handlerCtx provides information about the client, connection pool, and connection
// on which the notification was received, allowing handlers to make informed decisions.
// Returns an error if the notification could not be handled.
HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error
}
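For illustration, here is a minimal sketch of a custom handler implementing this interface; the handler name and logging behaviour are assumptions for the example, not part of this commit:

package example

import (
	"context"
	"log"

	"github.com/redis/go-redis/v9/push"
)

// loggingHandler is a hypothetical handler that simply logs every notification it receives.
type loggingHandler struct{}

func (loggingHandler) HandlePushNotification(
	ctx context.Context,
	handlerCtx push.NotificationHandlerContext,
	notification []interface{},
) error {
	// notification[0] holds the notification name; the remaining elements are the payload.
	log.Printf("push notification received: %v", notification)
	return nil
}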

44
push/handler_context.go Normal file
View File

@@ -0,0 +1,44 @@
package push
// No imports needed for this file
// NotificationHandlerContext provides context information about where a push notification was received.
// This struct allows handlers to make informed decisions based on the source of the notification
// with strongly typed access to different client types using concrete types.
type NotificationHandlerContext struct {
// Client is the Redis client instance that received the notification.
// It is an interface both to allow for future expansion and to avoid
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *redis.baseClient
// - *redis.Client
// - *redis.ClusterClient
// - *redis.Conn
Client interface{}
// ConnPool is the connection pool from which the connection was obtained.
// It is an interface both to allow for future expansion and to avoid
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *pool.ConnPool
// - *pool.SingleConnPool
// - *pool.StickyConnPool
ConnPool interface{}
// PubSub is the PubSub instance that received the notification.
// It is an interface both to allow for future expansion and to avoid
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *redis.PubSub
PubSub interface{}
// Conn is the specific connection on which the notification was received.
// It is an interface both to allow for future expansion and to avoid
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *pool.Conn
Conn interface{}
// IsBlocking indicates if the notification was received on a blocking connection.
IsBlocking bool
}
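Since these fields are typed as interface{}, handlers recover concrete types via type assertion or simple nil checks. A minimal, hypothetical helper (assuming the push package is imported):

// describeSource is a hypothetical helper showing how a handler can inspect
// the context it was invoked with.
func describeSource(handlerCtx push.NotificationHandlerContext) string {
	switch {
	case handlerCtx.PubSub != nil:
		return "pubsub connection"
	case handlerCtx.IsBlocking:
		return "blocking connection"
	default:
		return "pooled connection"
	}
}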

203
push/processor.go Normal file
View File

@@ -0,0 +1,203 @@
package push
import (
"context"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/proto"
)
// NotificationProcessor defines the interface for push notification processors.
type NotificationProcessor interface {
// GetHandler returns the handler for a specific push notification name.
GetHandler(pushNotificationName string) NotificationHandler
// ProcessPendingNotifications checks for and processes any pending push notifications.
// To be used when it is known that there are notifications on the socket.
// It will try to read from the socket and may block if the socket is empty.
ProcessPendingNotifications(ctx context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error
// RegisterHandler registers a handler for a specific push notification name.
RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error
// UnregisterHandler removes a handler for a specific push notification name.
UnregisterHandler(pushNotificationName string) error
}
// Processor handles push notifications with a registry of handlers
type Processor struct {
registry *Registry
}
// NewProcessor creates a new push notification processor
func NewProcessor() *Processor {
return &Processor{
registry: NewRegistry(),
}
}
// GetHandler returns the handler for a specific push notification name
func (p *Processor) GetHandler(pushNotificationName string) NotificationHandler {
return p.registry.GetHandler(pushNotificationName)
}
// RegisterHandler registers a handler for a specific push notification name
func (p *Processor) RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error {
return p.registry.RegisterHandler(pushNotificationName, handler, protected)
}
// UnregisterHandler removes a handler for a specific push notification name
func (p *Processor) UnregisterHandler(pushNotificationName string) error {
return p.registry.UnregisterHandler(pushNotificationName)
}
// ProcessPendingNotifications checks for and processes any pending push notifications
// This method should be called by the client in WithReader before reading the reply
// It will try to read from the socket and may block if the socket is empty.
func (p *Processor) ProcessPendingNotifications(ctx context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error {
if rd == nil {
return nil
}
for {
// Check if there's data available to read
replyType, err := rd.PeekReplyType()
if err != nil {
// No more data available or error reading
// if timeout, it will be handled by the caller
break
}
// Only process push notifications (arrays starting with >)
if replyType != proto.RespPush {
break
}
// see if we should skip this notification
notificationName, err := rd.PeekPushNotificationName()
if err != nil {
break
}
if willHandleNotificationInClient(notificationName) {
break
}
// Read the push notification
reply, err := rd.ReadReply()
if err != nil {
internal.Logger.Printf(ctx, "push: error reading push notification: %v", err)
break
}
// Convert to slice of interfaces
notification, ok := reply.([]interface{})
if !ok {
break
}
// Handle the notification directly
if len(notification) > 0 {
// Extract the notification type (first element)
if notificationType, ok := notification[0].(string); ok {
// Get the handler for this notification type
if handler := p.registry.GetHandler(notificationType); handler != nil {
// Handle the notification
err := handler.HandlePushNotification(ctx, handlerCtx, notification)
if err != nil {
internal.Logger.Printf(ctx, "push: error handling push notification: %v", err)
}
}
}
}
}
return nil
}
// VoidProcessor discards all push notifications without processing them
type VoidProcessor struct{}
// NewVoidProcessor creates a new void push notification processor
func NewVoidProcessor() *VoidProcessor {
return &VoidProcessor{}
}
// GetHandler returns nil for void processor since it doesn't maintain handlers
func (v *VoidProcessor) GetHandler(_ string) NotificationHandler {
return nil
}
// RegisterHandler returns an error for void processor since it doesn't maintain handlers
func (v *VoidProcessor) RegisterHandler(pushNotificationName string, _ NotificationHandler, _ bool) error {
return ErrVoidProcessorRegister(pushNotificationName)
}
// UnregisterHandler returns an error for void processor since it doesn't maintain handlers
func (v *VoidProcessor) UnregisterHandler(pushNotificationName string) error {
return ErrVoidProcessorUnregister(pushNotificationName)
}
// ProcessPendingNotifications for VoidProcessor maintains no handlers and dispatches nothing,
// since this processor is used for RESP2 connections (or to disable push notification
// handling on RESP3 connections).
// It does, however, read and discard all buffered push notifications so that they are not
// interpreted as replies.
// This method should be called by the client in WithReader before reading the reply
// to be sure there are no buffered push notifications.
// It will try to read from the socket and may block if the socket is empty.
func (v *VoidProcessor) ProcessPendingNotifications(_ context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error {
// read and discard all push notifications
if rd == nil {
return nil
}
for {
// Check if there's data available to read
replyType, err := rd.PeekReplyType()
if err != nil {
// No more data available or error reading
// if timeout, it will be handled by the caller
break
}
// Only process push notifications (arrays starting with >)
if replyType != proto.RespPush {
break
}
// see if we should skip this notification
notificationName, err := rd.PeekPushNotificationName()
if err != nil {
break
}
if willHandleNotificationInClient(notificationName) {
break
}
// Read the push notification
_, err = rd.ReadReply()
if err != nil {
internal.Logger.Printf(context.Background(), "push: error reading push notification: %v", err)
return nil
}
}
return nil
}
// willHandleNotificationInClient checks if a notification type should be ignored by the push notification
// processor and handled by other specialized systems instead (pub/sub, streams, keyspace, etc.).
func willHandleNotificationInClient(notificationType string) bool {
switch notificationType {
// Pub/Sub notifications - handled by pub/sub system
case "message", // Regular pub/sub message
"pmessage", // Pattern pub/sub message
"subscribe", // Subscription confirmation
"unsubscribe", // Unsubscription confirmation
"psubscribe", // Pattern subscription confirmation
"punsubscribe", // Pattern unsubscription confirmation
"smessage", // Sharded pub/sub message (Redis 7.0+)
"ssubscribe", // Sharded subscription confirmation
"sunsubscribe": // Sharded unsubscription confirmation
return true
default:
return false
}
}

315
push/processor_unit_test.go Normal file
View File

@@ -0,0 +1,315 @@
package push
import (
"context"
"testing"
)
// TestProcessorCreation tests processor creation and initialization
func TestProcessorCreation(t *testing.T) {
t.Run("NewProcessor", func(t *testing.T) {
processor := NewProcessor()
if processor == nil {
t.Fatal("NewProcessor should not return nil")
}
if processor.registry == nil {
t.Error("Processor should have a registry")
}
})
t.Run("NewVoidProcessor", func(t *testing.T) {
voidProcessor := NewVoidProcessor()
if voidProcessor == nil {
t.Fatal("NewVoidProcessor should not return nil")
}
})
}
// TestProcessorHandlerManagement tests handler registration and retrieval
func TestProcessorHandlerManagement(t *testing.T) {
processor := NewProcessor()
handler := &UnitTestHandler{name: "test-handler"}
t.Run("RegisterHandler", func(t *testing.T) {
err := processor.RegisterHandler("TEST", handler, false)
if err != nil {
t.Errorf("RegisterHandler should not error: %v", err)
}
// Verify handler is registered
retrievedHandler := processor.GetHandler("TEST")
if retrievedHandler != handler {
t.Error("GetHandler should return the registered handler")
}
})
t.Run("RegisterProtectedHandler", func(t *testing.T) {
protectedHandler := &UnitTestHandler{name: "protected-handler"}
err := processor.RegisterHandler("PROTECTED", protectedHandler, true)
if err != nil {
t.Errorf("RegisterHandler should not error for protected handler: %v", err)
}
// Verify handler is registered
retrievedHandler := processor.GetHandler("PROTECTED")
if retrievedHandler != protectedHandler {
t.Error("GetHandler should return the protected handler")
}
})
t.Run("GetNonExistentHandler", func(t *testing.T) {
handler := processor.GetHandler("NONEXISTENT")
if handler != nil {
t.Error("GetHandler should return nil for non-existent handler")
}
})
t.Run("UnregisterHandler", func(t *testing.T) {
err := processor.UnregisterHandler("TEST")
if err != nil {
t.Errorf("UnregisterHandler should not error: %v", err)
}
// Verify handler is removed
retrievedHandler := processor.GetHandler("TEST")
if retrievedHandler != nil {
t.Error("GetHandler should return nil after unregistering")
}
})
t.Run("UnregisterProtectedHandler", func(t *testing.T) {
err := processor.UnregisterHandler("PROTECTED")
if err == nil {
t.Error("UnregisterHandler should error for protected handler")
}
// Verify handler is still there
retrievedHandler := processor.GetHandler("PROTECTED")
if retrievedHandler == nil {
t.Error("Protected handler should not be removed")
}
})
}
// TestVoidProcessorBehavior tests void processor behavior
func TestVoidProcessorBehavior(t *testing.T) {
voidProcessor := NewVoidProcessor()
handler := &UnitTestHandler{name: "test-handler"}
t.Run("GetHandler", func(t *testing.T) {
retrievedHandler := voidProcessor.GetHandler("ANY")
if retrievedHandler != nil {
t.Error("VoidProcessor GetHandler should always return nil")
}
})
t.Run("RegisterHandler", func(t *testing.T) {
err := voidProcessor.RegisterHandler("TEST", handler, false)
if err == nil {
t.Error("VoidProcessor RegisterHandler should return error")
}
// Check error type
if !IsVoidProcessorError(err) {
t.Error("Error should be a VoidProcessorError")
}
})
t.Run("UnregisterHandler", func(t *testing.T) {
err := voidProcessor.UnregisterHandler("TEST")
if err == nil {
t.Error("VoidProcessor UnregisterHandler should return error")
}
// Check error type
if !IsVoidProcessorError(err) {
t.Error("Error should be a VoidProcessorError")
}
})
}
// TestProcessPendingNotificationsNilReader tests handling of nil reader
func TestProcessPendingNotificationsNilReader(t *testing.T) {
t.Run("ProcessorWithNilReader", func(t *testing.T) {
processor := NewProcessor()
ctx := context.Background()
handlerCtx := NotificationHandlerContext{}
err := processor.ProcessPendingNotifications(ctx, handlerCtx, nil)
if err != nil {
t.Errorf("ProcessPendingNotifications should not error with nil reader: %v", err)
}
})
t.Run("VoidProcessorWithNilReader", func(t *testing.T) {
voidProcessor := NewVoidProcessor()
ctx := context.Background()
handlerCtx := NotificationHandlerContext{}
err := voidProcessor.ProcessPendingNotifications(ctx, handlerCtx, nil)
if err != nil {
t.Errorf("VoidProcessor ProcessPendingNotifications should not error with nil reader: %v", err)
}
})
}
// TestWillHandleNotificationInClient tests the notification filtering logic
func TestWillHandleNotificationInClient(t *testing.T) {
testCases := []struct {
name string
notificationType string
shouldHandle bool
}{
// Pub/Sub notifications (should be handled in client)
{"message", "message", true},
{"pmessage", "pmessage", true},
{"subscribe", "subscribe", true},
{"unsubscribe", "unsubscribe", true},
{"psubscribe", "psubscribe", true},
{"punsubscribe", "punsubscribe", true},
{"smessage", "smessage", true},
{"ssubscribe", "ssubscribe", true},
{"sunsubscribe", "sunsubscribe", true},
// Push notifications (should be handled by processor)
{"MOVING", "MOVING", false},
{"MIGRATING", "MIGRATING", false},
{"MIGRATED", "MIGRATED", false},
{"FAILING_OVER", "FAILING_OVER", false},
{"FAILED_OVER", "FAILED_OVER", false},
{"custom", "custom", false},
{"unknown", "unknown", false},
{"empty", "", false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := willHandleNotificationInClient(tc.notificationType)
if result != tc.shouldHandle {
t.Errorf("willHandleNotificationInClient(%q) = %v, want %v", tc.notificationType, result, tc.shouldHandle)
}
})
}
}
// TestProcessorErrorHandlingUnit tests error handling scenarios
func TestProcessorErrorHandlingUnit(t *testing.T) {
processor := NewProcessor()
t.Run("RegisterNilHandler", func(t *testing.T) {
err := processor.RegisterHandler("TEST", nil, false)
if err == nil {
t.Error("RegisterHandler should error with nil handler")
}
// Check error type
if !IsHandlerNilError(err) {
t.Error("Error should be a HandlerNilError")
}
})
t.Run("RegisterDuplicateHandler", func(t *testing.T) {
handler1 := &UnitTestHandler{name: "handler1"}
handler2 := &UnitTestHandler{name: "handler2"}
// Register first handler
err := processor.RegisterHandler("DUPLICATE", handler1, false)
if err != nil {
t.Errorf("First RegisterHandler should not error: %v", err)
}
// Try to register second handler with same name
err = processor.RegisterHandler("DUPLICATE", handler2, false)
if err == nil {
t.Error("RegisterHandler should error when registering duplicate handler")
}
// Verify original handler is still there
retrievedHandler := processor.GetHandler("DUPLICATE")
if retrievedHandler != handler1 {
t.Error("Original handler should remain after failed duplicate registration")
}
})
t.Run("UnregisterNonExistentHandler", func(t *testing.T) {
err := processor.UnregisterHandler("NONEXISTENT")
if err != nil {
t.Errorf("UnregisterHandler should not error for non-existent handler: %v", err)
}
})
}
// TestProcessorConcurrentAccess tests concurrent access to processor
func TestProcessorConcurrentAccess(t *testing.T) {
processor := NewProcessor()
t.Run("ConcurrentRegisterAndGet", func(t *testing.T) {
done := make(chan bool, 2)
// Goroutine 1: Register handlers
go func() {
defer func() { done <- true }()
for i := 0; i < 100; i++ {
handler := &UnitTestHandler{name: "concurrent-handler"}
processor.RegisterHandler("CONCURRENT", handler, false)
processor.UnregisterHandler("CONCURRENT")
}
}()
// Goroutine 2: Get handlers
go func() {
defer func() { done <- true }()
for i := 0; i < 100; i++ {
processor.GetHandler("CONCURRENT")
}
}()
// Wait for both goroutines to complete
<-done
<-done
})
}
// TestProcessorInterfaceCompliance tests interface compliance
func TestProcessorInterfaceCompliance(t *testing.T) {
t.Run("ProcessorImplementsInterface", func(t *testing.T) {
var _ NotificationProcessor = (*Processor)(nil)
})
t.Run("VoidProcessorImplementsInterface", func(t *testing.T) {
var _ NotificationProcessor = (*VoidProcessor)(nil)
})
}
// UnitTestHandler is a test implementation of NotificationHandler
type UnitTestHandler struct {
name string
lastNotification []interface{}
errorToReturn error
callCount int
}
func (h *UnitTestHandler) HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error {
h.callCount++
h.lastNotification = notification
return h.errorToReturn
}
// Helper methods for UnitTestHandler
func (h *UnitTestHandler) GetCallCount() int {
return h.callCount
}
func (h *UnitTestHandler) GetLastNotification() []interface{} {
return h.lastNotification
}
func (h *UnitTestHandler) SetErrorToReturn(err error) {
h.errorToReturn = err
}
func (h *UnitTestHandler) Reset() {
h.callCount = 0
h.lastNotification = nil
h.errorToReturn = nil
}

7
push/push.go Normal file
View File

@@ -0,0 +1,7 @@
// Package push provides push notifications for Redis.
// This is an EXPERIMENTAL API for handling push notifications from Redis.
// It is not yet stable and may change in the future.
// Although this is in a public package, in its current form public use is not advised.
// Pending push notifications should be processed before executing any readReply from the connection
// as per RESP3 specification push notifications can be sent at any time.
package push

1713
push/push_test.go Normal file

File diff suppressed because it is too large

61
push/registry.go Normal file
View File

@@ -0,0 +1,61 @@
package push
import (
"sync"
)
// Registry manages push notification handlers
type Registry struct {
mu sync.RWMutex
handlers map[string]NotificationHandler
protected map[string]bool
}
// NewRegistry creates a new push notification registry
func NewRegistry() *Registry {
return &Registry{
handlers: make(map[string]NotificationHandler),
protected: make(map[string]bool),
}
}
// RegisterHandler registers a handler for a specific push notification name
func (r *Registry) RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error {
if handler == nil {
return ErrHandlerNil
}
r.mu.Lock()
defer r.mu.Unlock()
// Check if handler already exists
if _, exists := r.protected[pushNotificationName]; exists {
return ErrHandlerExists(pushNotificationName)
}
r.handlers[pushNotificationName] = handler
r.protected[pushNotificationName] = protected
return nil
}
// GetHandler returns the handler for a specific push notification name
func (r *Registry) GetHandler(pushNotificationName string) NotificationHandler {
r.mu.RLock()
defer r.mu.RUnlock()
return r.handlers[pushNotificationName]
}
// UnregisterHandler removes a handler for a specific push notification name
func (r *Registry) UnregisterHandler(pushNotificationName string) error {
r.mu.Lock()
defer r.mu.Unlock()
// Check if handler is protected
if protected, exists := r.protected[pushNotificationName]; exists && protected {
return ErrProtectedHandler(pushNotificationName)
}
delete(r.handlers, pushNotificationName)
delete(r.protected, pushNotificationName)
return nil
}
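A short sketch of the registration semantics above; the notification name and the no-op handler are illustrative assumptions:

package example

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9/push"
)

type noopHandler struct{}

func (noopHandler) HandlePushNotification(context.Context, push.NotificationHandlerContext, []interface{}) error {
	return nil
}

func registryExample() {
	reg := push.NewRegistry()

	// Register a protected handler; it cannot be overwritten or removed later.
	if err := reg.RegisterHandler("EXAMPLE", noopHandler{}, true); err != nil {
		fmt.Println("register failed:", err)
	}

	// A second registration under the same name is rejected.
	if err := reg.RegisterHandler("EXAMPLE", noopHandler{}, false); push.IsHandlerExistsError(err) {
		fmt.Println("handler already registered:", err)
	}

	// Protected handlers cannot be unregistered.
	if err := reg.UnregisterHandler("EXAMPLE"); push.IsProtectedHandlerError(err) {
		fmt.Println("handler is protected:", err)
	}
}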

21
push_notifications.go Normal file
View File

@@ -0,0 +1,21 @@
package redis
import (
"github.com/redis/go-redis/v9/push"
)
// NewPushNotificationProcessor creates a new push notification processor
// This processor maintains a registry of handlers and processes push notifications
// It is used for RESP3 connections where push notifications are available
func NewPushNotificationProcessor() push.NotificationProcessor {
return push.NewProcessor()
}
// NewVoidPushNotificationProcessor creates a new void push notification processor
// This processor does not maintain any handlers and always returns nil for all operations
// It is used for RESP2 connections where push notifications are not available
// It can also be used to disable push notifications for RESP3 connections, where
// it will discard all push notifications without processing them
func NewVoidPushNotificationProcessor() push.NotificationProcessor {
return push.NewVoidProcessor()
}
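For example, a RESP3 client that should silently drop all push notifications can be handed the void processor explicitly; a minimal sketch, assuming the Options.PushNotificationProcessor field wired up later in this commit:

rdb := redis.NewClient(&redis.Options{
	Addr:     "localhost:6379",
	Protocol: 3,
	// Discard push notifications instead of dispatching them to handlers.
	PushNotificationProcessor: redis.NewVoidPushNotificationProcessor(),
})
defer rdb.Close()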

401
redis.go
View File

@@ -10,10 +10,12 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/hscan"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/push"
)
// Scanner internal/hscan.Scanner exposed interface.
@@ -23,6 +25,7 @@ type Scanner = hscan.Scanner
const Nil = proto.Nil
// SetLogger sets a custom logger.
// Use with VoidLogger to disable logging.
func SetLogger(logger internal.Logging) {
internal.Logger = logger
}
@@ -203,15 +206,34 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
type baseClient struct {
opt *Options
optLock sync.RWMutex
connPool pool.Pooler
pubSubPool *pool.PubSubPool
hooksMixin
onClose func() error // hook called when client is closed
// Push notification processing
pushProcessor push.NotificationProcessor
// Hitless upgrade manager
hitlessManager *hitless.HitlessManager
hitlessManagerLock sync.RWMutex
}
func (c *baseClient) clone() *baseClient {
clone := *c
return &clone
c.hitlessManagerLock.RLock()
hitlessManager := c.hitlessManager
c.hitlessManagerLock.RUnlock()
clone := &baseClient{
opt: c.opt,
connPool: c.connPool,
onClose: c.onClose,
pushProcessor: c.pushProcessor,
hitlessManager: hitlessManager,
}
return clone
}
func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
@@ -229,21 +251,6 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}
func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
cn, err := c.connPool.NewConn(ctx)
if err != nil {
return nil, err
}
err = c.initConn(ctx, cn)
if err != nil {
_ = c.connPool.CloseConn(cn)
return nil, err
}
return cn, nil
}
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
if c.opt.Limiter != nil {
err := c.opt.Limiter.Allow()
@@ -269,7 +276,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return nil, err
}
if cn.Inited {
if cn.IsInited() {
return cn, nil
}
@@ -351,12 +358,10 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
}
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
if !cn.Inited.CompareAndSwap(false, true) {
return nil
}
var err error
cn.Inited = true
connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool, &c.hooksMixin)
@@ -425,6 +430,51 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return fmt.Errorf("failed to initialize connection options: %w", err)
}
// Enable maintenance notifications if hitless upgrades are configured
c.optLock.RLock()
hitlessEnabled := c.opt.HitlessUpgradeConfig != nil && c.opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled
protocol := c.opt.Protocol
endpointType := c.opt.HitlessUpgradeConfig.EndpointType
c.optLock.RUnlock()
var hitlessHandshakeErr error
if hitlessEnabled && protocol == 3 {
hitlessHandshakeErr = conn.ClientMaintNotifications(
ctx,
true,
endpointType.String(),
).Err()
if hitlessHandshakeErr != nil {
if !isRedisError(hitlessHandshakeErr) {
// if not redis error, fail the connection
return hitlessHandshakeErr
}
c.optLock.Lock()
// handshake failed - check and modify config atomically
switch c.opt.HitlessUpgradeConfig.Mode {
case hitless.MaintNotificationsEnabled:
// enabled mode, fail the connection
c.optLock.Unlock()
return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr)
default: // will handle auto and any other
internal.Logger.Printf(ctx, "hitless: auto mode fallback: hitless upgrades disabled due to handshake error: %v", hitlessHandshakeErr)
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsDisabled
c.optLock.Unlock()
// auto mode, disable hitless upgrades and continue
if err := c.disableHitlessUpgrades(); err != nil {
// Log error but continue - auto mode should be resilient
internal.Logger.Printf(ctx, "hitless: failed to disable hitless upgrades in auto mode: %v", err)
}
}
} else {
// handshake was executed successfully
// to make sure that the handshake will be executed on other connections as well if it was successfully
// executed on this connection, we will force the handshake to be executed on all connections
c.optLock.Lock()
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsEnabled
c.optLock.Unlock()
}
}
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
libName := ""
libVer := Version()
@@ -441,6 +491,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
}
}
cn.SetUsable(true)
cn.Inited.Store(true)
// Set the connection initialization function for potential reconnections
cn.SetInitConnFunc(c.createInitConnFunc())
if c.opt.OnConnect != nil {
return c.opt.OnConnect(ctx, conn)
}
@@ -456,6 +512,10 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
if isBadConn(err, false, c.opt.Addr) {
c.connPool.Remove(ctx, cn, err)
} else {
// process any pending push notifications before returning the connection to the pool
if err := c.processPushNotifications(ctx, cn); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err)
}
c.connPool.Put(ctx, cn)
}
}
@@ -497,16 +557,16 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
return lastErr
}
func (c *baseClient) assertUnstableCommand(cmd Cmder) bool {
func (c *baseClient) assertUnstableCommand(cmd Cmder) (bool, error) {
switch cmd.(type) {
case *AggregateCmd, *FTInfoCmd, *FTSpellCheckCmd, *FTSearchCmd, *FTSynDumpCmd:
if c.opt.UnstableResp3 {
return true
return true, nil
} else {
panic("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3 . See the [README](https://github.com/redis/go-redis/blob/master/README.md) and the release notes for guidance.")
return false, fmt.Errorf("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3. See the README and the release notes for guidance")
}
default:
return false
return false, nil
}
}
@@ -519,6 +579,11 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
retryTimeout := uint32(0)
if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
// Process any pending push notifications before executing the command
if err := c.processPushNotifications(ctx, cn); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err)
}
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
}); err != nil {
@@ -527,10 +592,22 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
}
readReplyFunc := cmd.readReply
// Apply unstable RESP3 search module.
if c.opt.Protocol != 2 && c.assertUnstableCommand(cmd) {
if c.opt.Protocol != 2 {
useRawReply, err := c.assertUnstableCommand(cmd)
if err != nil {
return err
}
if useRawReply {
readReplyFunc = cmd.readRawReply
}
if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), readReplyFunc); err != nil {
}
if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
return readReplyFunc(rd)
}); err != nil {
if cmd.readTimeout() == nil {
atomic.StoreUint32(&retryTimeout, 1)
} else {
@@ -573,20 +650,77 @@ func (c *baseClient) context(ctx context.Context) context.Context {
return context.Background()
}
// createInitConnFunc creates a connection initialization function that can be used for reconnections.
func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error {
return func(ctx context.Context, cn *pool.Conn) error {
return c.initConn(ctx, cn)
}
}
// enableHitlessUpgrades initializes the hitless upgrade manager and pool hook.
// This function is called during client initialization.
// It registers push notification handlers for all hitless upgrade events and
// starts background workers for handoff processing in the pool hook.
func (c *baseClient) enableHitlessUpgrades() error {
// Create client adapter
clientAdapterInstance := newClientAdapter(c)
// Create hitless manager directly
manager, err := hitless.NewHitlessManager(clientAdapterInstance, c.connPool, c.opt.HitlessUpgradeConfig)
if err != nil {
return err
}
// Set the manager reference and initialize pool hook
c.hitlessManagerLock.Lock()
c.hitlessManager = manager
c.hitlessManagerLock.Unlock()
// Initialize pool hook (safe to call without lock since manager is now set)
manager.InitPoolHook(c.dialHook)
return nil
}
func (c *baseClient) disableHitlessUpgrades() error {
c.hitlessManagerLock.Lock()
defer c.hitlessManagerLock.Unlock()
// Close the hitless manager
if c.hitlessManager != nil {
// Closing the manager will also shutdown the pool hook
// and remove it from the pool
c.hitlessManager.Close()
c.hitlessManager = nil
}
return nil
}
// Close closes the client, releasing any open resources.
//
// It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
var firstErr error
// Close hitless manager first
if err := c.disableHitlessUpgrades(); err != nil {
firstErr = err
}
if c.onClose != nil {
if err := c.onClose(); err != nil {
if err := c.onClose(); err != nil && firstErr == nil {
firstErr = err
}
}
if c.connPool != nil {
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if c.pubSubPool != nil {
if err := c.pubSubPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
@@ -625,6 +759,10 @@ func (c *baseClient) generalProcessPipeline(
// Enable retries by default to retry dial errors returned by withConn.
canRetry := true
lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
// Process any pending push notifications before executing the pipeline
if err := c.processPushNotifications(ctx, cn); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before processing pipeline: %v", err)
}
var err error
canRetry, err = p(ctx, cn, cmds)
return err
@@ -640,6 +778,11 @@ func (c *baseClient) generalProcessPipeline(
func (c *baseClient) pipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
// Process any pending push notifications before executing the pipeline
if err := c.processPushNotifications(ctx, cn); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before writing pipeline: %v", err)
}
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
}); err != nil {
@@ -648,7 +791,8 @@ func (c *baseClient) pipelineProcessCmds(
}
if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
return pipelineReadCmds(rd, cmds)
// read all replies
return c.pipelineReadCmds(ctx, cn, rd, cmds)
}); err != nil {
return true, err
}
@@ -656,8 +800,12 @@ func (c *baseClient) pipelineProcessCmds(
return false, nil
}
func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
func (c *baseClient) pipelineReadCmds(ctx context.Context, cn *pool.Conn, rd *proto.Reader, cmds []Cmder) error {
for i, cmd := range cmds {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
err := cmd.readReply(rd)
cmd.SetErr(err)
if err != nil && !isRedisError(err) {
@@ -672,6 +820,11 @@ func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
func (c *baseClient) txPipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
// Process any pending push notifications before executing the transaction pipeline
if err := c.processPushNotifications(ctx, cn); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err)
}
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
}); err != nil {
@@ -684,12 +837,13 @@ func (c *baseClient) txPipelineProcessCmds(
// Trim multi and exec.
trimmedCmds := cmds[1 : len(cmds)-1]
if err := txPipelineReadQueued(rd, statusCmd, trimmedCmds); err != nil {
if err := c.txPipelineReadQueued(ctx, cn, rd, statusCmd, trimmedCmds); err != nil {
setCmdsErr(cmds, err)
return err
}
return pipelineReadCmds(rd, trimmedCmds)
// Read replies.
return c.pipelineReadCmds(ctx, cn, rd, trimmedCmds)
}); err != nil {
return false, err
}
@@ -697,7 +851,13 @@ func (c *baseClient) txPipelineProcessCmds(
return false, nil
}
func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
// txPipelineReadQueued reads queued replies from the Redis server.
// It returns an error if the server returns an error or if the number of replies does not match the number of commands.
func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
// Parse +OK.
if err := statusCmd.readReply(rd); err != nil {
return err
@@ -705,6 +865,10 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
// Parse +QUEUED.
for _, cmd := range cmds {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
if err := statusCmd.readReply(rd); err != nil {
cmd.SetErr(err)
if !isRedisError(err) {
@@ -713,6 +877,10 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
}
}
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
// Parse number of replies.
line, err := rd.ReadLine()
if err != nil {
@@ -746,15 +914,56 @@ func NewClient(opt *Options) *Client {
if opt == nil {
panic("redis: NewClient nil options")
}
// clone to not share options with the caller
opt = opt.clone()
opt.init()
// Push notifications are always enabled for RESP3 (cannot be disabled)
c := Client{
baseClient: &baseClient{
opt: opt,
},
}
c.init()
c.connPool = newConnPool(opt, c.dialHook)
// Initialize push notification processor using shared helper
// Use void processor for RESP2 connections (push notifications not available)
c.pushProcessor = initializePushProcessor(opt)
// set opt push processor for child clients
c.opt.PushNotificationProcessor = c.pushProcessor
// Create connection pools
var err error
c.connPool, err = newConnPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
}
c.pubSubPool, err = newPubSubPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}
// Initialize hitless upgrades first if enabled and protocol is RESP3
if opt.HitlessUpgradeConfig != nil && opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled && opt.Protocol == 3 {
err := c.enableHitlessUpgrades()
if err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
if opt.HitlessUpgradeConfig.Mode == hitless.MaintNotificationsEnabled {
/*
Design decision: panic here to fail fast if hitless upgrades cannot be enabled when explicitly requested.
We choose to panic instead of returning an error to avoid breaking the existing client API, which does not expect
an error from NewClient. This ensures that misconfiguration or critical initialization failures are surfaced
immediately, rather than allowing the client to continue in a partially initialized or inconsistent state.
Clients relying on hitless upgrades should be aware that initialization errors will cause a panic, and should
handle this accordingly (e.g., via recover or by validating configuration before calling NewClient).
This approach is only used when HitlessUpgradeConfig.Mode is MaintNotificationsEnabled, indicating that hitless
upgrades are required for correct operation. In other modes, initialization failures are logged but do not panic.
*/
panic(fmt.Errorf("failed to enable hitless upgrades: %w", err))
}
}
}
return &c
}
@@ -791,11 +1000,51 @@ func (c *Client) Options() *Options {
return c.opt
}
// GetHitlessManager returns the hitless manager instance for monitoring and control.
// Returns nil if hitless upgrades are not enabled.
func (c *Client) GetHitlessManager() *hitless.HitlessManager {
c.hitlessManagerLock.RLock()
defer c.hitlessManagerLock.RUnlock()
return c.hitlessManager
}
// initializePushProcessor initializes the push notification processor for any client type.
// This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient.
func initializePushProcessor(opt *Options) push.NotificationProcessor {
// Always use custom processor if provided
if opt.PushNotificationProcessor != nil {
return opt.PushNotificationProcessor
}
// Push notifications are always enabled for RESP3, disabled for RESP2
if opt.Protocol == 3 {
// Create default processor for RESP3 connections
return NewPushNotificationProcessor()
}
// Create void processor for RESP2 connections (push notifications not available)
return NewVoidPushNotificationProcessor()
}
// RegisterPushNotificationHandler registers a handler for a specific push notification name.
// Returns an error if a handler is already registered for this push notification name.
// If protected is true, the handler cannot be unregistered.
func (c *Client) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error {
return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected)
}
// GetPushNotificationHandler returns the handler for a specific push notification name.
// Returns nil if no handler is registered for the given name.
func (c *Client) GetPushNotificationHandler(pushNotificationName string) push.NotificationHandler {
return c.pushProcessor.GetHandler(pushNotificationName)
}
type PoolStats pool.Stats
// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
stats := c.connPool.Stats()
stats.PubSubStats = *(c.pubSubPool.Stats())
return (*PoolStats)(stats)
}
@@ -830,13 +1079,31 @@ func (c *Client) TxPipeline() Pipeliner {
func (c *Client) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: c.connPool.CloseConn,
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
pushProcessor: c.pushProcessor,
}
pubsub.init()
return pubsub
}
@@ -920,6 +1187,10 @@ func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn
c.hooksMixin = parentHooks.clone()
}
// Initialize push notification processor using shared helper
// Use void processor for RESP2 connections (push notifications not available)
c.pushProcessor = initializePushProcessor(opt)
c.cmdable = c.Process
c.statefulCmdable = c.Process
c.initHooks(hooks{
@@ -938,6 +1209,13 @@ func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
return err
}
// RegisterPushNotificationHandler registers a handler for a specific push notification name.
// Returns an error if a handler is already registered for this push notification name.
// If protected is true, the handler cannot be unregistered.
func (c *Conn) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error {
return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected)
}
func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
@@ -965,3 +1243,50 @@ func (c *Conn) TxPipeline() Pipeliner {
pipe.init()
return &pipe
}
// processPushNotifications processes all pending push notifications on a connection
// This ensures that cluster topology changes are handled immediately before the connection is used
// This method should be called by the client before using WithReader for command execution
func (c *baseClient) processPushNotifications(ctx context.Context, cn *pool.Conn) error {
// Only process push notifications for RESP3 connections with a processor
// Also check if there is any data to read before processing.
// This is an optimization on UNIX systems, where MaybeHasData is a syscall;
// on Windows, MaybeHasData always returns true, so the check is a no-op.
if c.opt.Protocol != 3 || c.pushProcessor == nil || !cn.MaybeHasData() {
return nil
}
// Use WithReader to access the reader and process push notifications
// This is critical for hitless upgrades to work properly
// NOTE: only a very short read timeout is used here, so this should not block
// longer than necessary; 10us should be plenty of time to read any push
// notifications already waiting on the socket.
return cn.WithReader(ctx, 10*time.Microsecond, func(rd *proto.Reader) error {
// Create handler context with client, connection pool, and connection information
handlerCtx := c.pushNotificationHandlerContext(cn)
return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd)
})
}
// processPendingPushNotificationWithReader processes all pending push notifications on a connection
// This method should be called by the client in WithReader before reading the reply
func (c *baseClient) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
// Since we already hold the reader, there is no need to check for data on the
// socket: we are waiting for either a reply or a push notification, so we can
// block until one arrives or the timeout is reached.
if c.opt.Protocol != 3 || c.pushProcessor == nil {
return nil
}
// Create handler context with client, connection pool, and connection information
handlerCtx := c.pushNotificationHandlerContext(cn)
return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd)
}
// pushNotificationHandlerContext creates a handler context for push notification processing
func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.NotificationHandlerContext {
return push.NotificationHandlerContext{
Client: c,
ConnPool: c.connPool,
Conn: cn, // the connection the notification was received on
}
}

View File

@@ -12,7 +12,6 @@ import (
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/auth"
)

View File

@@ -3407,14 +3407,16 @@ var _ = Describe("RediSearch commands Resp 3", Label("search"), func() {
Expect(rawValResults[0]).To(Or(BeEquivalentTo(results[0]), BeEquivalentTo(results[1])))
Expect(rawValResults[1]).To(Or(BeEquivalentTo(results[0]), BeEquivalentTo(results[1])))
// Test with UnstableResp3 false
Expect(func() {
// Test with UnstableResp3 false - should return error instead of panic
options = &redis.FTAggregateOptions{Apply: []redis.FTAggregateApply{{Field: "@CreatedDateTimeUTC * 10", As: "CreatedDateTimeUTC"}}}
rawRes, _ := client2.FTAggregateWithArgs(ctx, "idx1", "*", options).RawResult()
rawVal = client2.FTAggregateWithArgs(ctx, "idx1", "*", options).RawVal()
rawRes, err := client2.FTAggregateWithArgs(ctx, "idx1", "*", options).RawResult()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("RESP3 responses for this command are disabled"))
Expect(rawRes).To(BeNil())
rawVal = client2.FTAggregateWithArgs(ctx, "idx1", "*", options).RawVal()
Expect(client2.FTAggregateWithArgs(ctx, "idx1", "*", options).Err()).To(HaveOccurred())
Expect(rawVal).To(BeNil())
}).Should(Panic())
})
@@ -3435,13 +3437,15 @@ var _ = Describe("RediSearch commands Resp 3", Label("search"), func() {
flags = attributes[0].(map[interface{}]interface{})["flags"].([]interface{})
Expect(flags).To(ConsistOf("SORTABLE", "NOSTEM"))
// Test with UnstableResp3 false
Expect(func() {
rawResInfo, _ := client2.FTInfo(ctx, "idx1").RawResult()
rawValInfo := client2.FTInfo(ctx, "idx1").RawVal()
// Test with UnstableResp3 false - should return error instead of panic
rawResInfo, err := client2.FTInfo(ctx, "idx1").RawResult()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("RESP3 responses for this command are disabled"))
Expect(rawResInfo).To(BeNil())
rawValInfo := client2.FTInfo(ctx, "idx1").RawVal()
Expect(client2.FTInfo(ctx, "idx1").Err()).To(HaveOccurred())
Expect(rawValInfo).To(BeNil())
}).Should(Panic())
})
It("should handle FTSpellCheck with Unstable RESP3 Search Module and without stability", Label("search", "ftcreate", "ftspellcheck"), func() {
@@ -3462,13 +3466,15 @@ var _ = Describe("RediSearch commands Resp 3", Label("search"), func() {
results := resSpellCheck.(map[interface{}]interface{})["results"].(map[interface{}]interface{})
Expect(results["impornant"].([]interface{})[0].(map[interface{}]interface{})["important"]).To(BeEquivalentTo(0.5))
// Test with UnstableResp3 false
Expect(func() {
rawResSpellCheck, _ := client2.FTSpellCheck(ctx, "idx1", "impornant").RawResult()
rawValSpellCheck := client2.FTSpellCheck(ctx, "idx1", "impornant").RawVal()
// Test with UnstableResp3 false - should return error instead of panic
rawResSpellCheck, err := client2.FTSpellCheck(ctx, "idx1", "impornant").RawResult()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("RESP3 responses for this command are disabled"))
Expect(rawResSpellCheck).To(BeNil())
rawValSpellCheck := client2.FTSpellCheck(ctx, "idx1", "impornant").RawVal()
Expect(client2.FTSpellCheck(ctx, "idx1", "impornant").Err()).To(HaveOccurred())
Expect(rawValSpellCheck).To(BeNil())
}).Should(Panic())
})
It("should handle FTSearch with Unstable RESP3 Search Module and without stability", Label("search", "ftcreate", "ftsearch"), func() {
@@ -3489,13 +3495,15 @@ var _ = Describe("RediSearch commands Resp 3", Label("search"), func() {
totalResults2 := res2.(map[interface{}]interface{})["total_results"]
Expect(totalResults2).To(BeEquivalentTo(int64(1)))
// Test with UnstableResp3 false
Expect(func() {
rawRes2, _ := client2.FTSearchWithArgs(ctx, "txt", "foo bar hello world", &redis.FTSearchOptions{NoContent: true}).RawResult()
rawVal2 := client2.FTSearchWithArgs(ctx, "txt", "foo bar hello world", &redis.FTSearchOptions{NoContent: true}).RawVal()
// Test with UnstableResp3 false - should return error instead of panic
rawRes2, err := client2.FTSearchWithArgs(ctx, "txt", "foo bar hello world", &redis.FTSearchOptions{NoContent: true}).RawResult()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("RESP3 responses for this command are disabled"))
Expect(rawRes2).To(BeNil())
rawVal2 := client2.FTSearchWithArgs(ctx, "txt", "foo bar hello world", &redis.FTSearchOptions{NoContent: true}).RawVal()
Expect(client2.FTSearchWithArgs(ctx, "txt", "foo bar hello world", &redis.FTSearchOptions{NoContent: true}).Err()).To(HaveOccurred())
Expect(rawVal2).To(BeNil())
}).Should(Panic())
})
It("should handle FTSynDump with Unstable RESP3 Search Module and without stability", Label("search", "ftsyndump"), func() {
text1 := &redis.FieldSchema{FieldName: "title", FieldType: redis.SearchFieldTypeText}
@@ -3523,13 +3531,15 @@ var _ = Describe("RediSearch commands Resp 3", Label("search"), func() {
Expect(valSynDump).To(BeEquivalentTo(resSynDump))
Expect(resSynDump.(map[interface{}]interface{})["baby"]).To(BeEquivalentTo([]interface{}{"id1"}))
// Test with UnstableResp3 false
Expect(func() {
rawResSynDump, _ := client2.FTSynDump(ctx, "idx1").RawResult()
rawValSynDump := client2.FTSynDump(ctx, "idx1").RawVal()
// Test with UnstableResp3 false - should return error instead of panic
rawResSynDump, err := client2.FTSynDump(ctx, "idx1").RawResult()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("RESP3 responses for this command are disabled"))
Expect(rawResSynDump).To(BeNil())
rawValSynDump := client2.FTSynDump(ctx, "idx1").RawVal()
Expect(client2.FTSynDump(ctx, "idx1").Err()).To(HaveOccurred())
Expect(rawValSynDump).To(BeNil())
}).Should(Panic())
})
It("should test not affected Resp 3 Search method - FTExplain", Label("search", "ftexplain"), func() {

View File

@@ -17,6 +17,7 @@ import (
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/rand"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/push"
)
//------------------------------------------------------------------------------
@@ -62,6 +63,8 @@ type FailoverOptions struct {
Protocol int
Username string
Password string
// Push notifications are always enabled for RESP3 connections
// CredentialsProvider allows the username and password to be updated
// before reconnecting. It should return the current username and password.
CredentialsProvider func() (username string, password string)
@@ -136,6 +139,14 @@ type FailoverOptions struct {
FailingTimeoutSeconds int
UnstableResp3 bool
// Hitless upgrades are not supported for FailoverClient at the moment.
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// upgrade notifications gracefully and manage connection/pool state transitions
// seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are disabled.
//HitlessUpgradeConfig *HitlessUpgradeConfig
}
func (opt *FailoverOptions) clientOptions() *Options {
@@ -454,8 +465,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
opt.Dialer = masterReplicaDialer(failover)
opt.init()
var connPool *pool.ConnPool
rdb := &Client{
baseClient: &baseClient{
opt: opt,
@@ -463,16 +472,30 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
}
rdb.init()
connPool = newConnPool(opt, rdb.dialHook)
rdb.connPool = connPool
// Initialize push notification processor using shared helper
// Use void processor by default for RESP2 connections
rdb.pushProcessor = initializePushProcessor(opt)
var err error
rdb.connPool, err = newConnPool(opt, rdb.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
}
rdb.pubSubPool, err = newPubSubPool(opt, rdb.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}
rdb.onClose = rdb.wrappedOnClose(failover.Close)
failover.mu.Lock()
failover.onFailover = func(ctx context.Context, addr string) {
if connPool, ok := rdb.connPool.(*pool.ConnPool); ok {
_ = connPool.Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
})
}
}
failover.mu.Unlock()
return rdb
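A minimal sketch of creating a failover client (master name and sentinel address are placeholders); the onFailover hook above then prunes pooled connections that no longer point at the promoted master:

package main

import (
	"context"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewFailoverClient(&redis.FailoverOptions{
		MasterName:    "mymaster",
		SentinelAddrs: []string{"localhost:26379"},
	})
	defer rdb.Close()

	// On failover, the onFailover hook drops pooled connections whose
	// remote address no longer matches the promoted master.
	_ = rdb.Ping(ctx).Err()
}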
@@ -529,15 +552,40 @@ func NewSentinelClient(opt *Options) *SentinelClient {
},
}
// Initialize push notification processor using shared helper
// Use void processor for Sentinel clients
c.pushProcessor = NewVoidPushNotificationProcessor()
c.initHooks(hooks{
dial: c.baseClient.dial,
process: c.baseClient.process,
})
c.connPool = newConnPool(opt, c.dialHook)
var err error
c.connPool, err = newConnPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
}
c.pubSubPool, err = newPubSubPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}
return c
}
// GetPushNotificationHandler returns the handler for a specific push notification name.
// Returns nil if no handler is registered for the given name.
func (c *SentinelClient) GetPushNotificationHandler(pushNotificationName string) push.NotificationHandler {
return c.pushProcessor.GetHandler(pushNotificationName)
}
// RegisterPushNotificationHandler registers a handler for a specific push notification name.
// Returns an error if a handler is already registered for this push notification name.
// If protected is true, the handler cannot be unregistered.
func (c *SentinelClient) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error {
return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected)
}
func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
err := c.processHook(ctx, cmd)
cmd.SetErr(err)
@@ -547,13 +595,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
func (c *SentinelClient) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: c.connPool.CloseConn,
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
pushProcessor: c.pushProcessor,
}
pubsub.init()
return pubsub
}

3
tx.go
View File

@@ -24,9 +24,10 @@ type Tx struct {
func (c *Client) newTx() *Tx {
tx := Tx{
baseClient: baseClient{
opt: c.opt,
opt: c.opt.clone(), // Clone options to avoid sharing mutable state between transaction and parent client
connPool: pool.NewStickyConnPool(c.connPool),
hooksMixin: c.hooksMixin.clone(),
pushProcessor: c.pushProcessor, // Copy push processor from parent client
},
}
tx.init()

View File

@@ -122,6 +122,9 @@ type UniversalOptions struct {
// IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint).
IsClusterMode bool
// HitlessUpgradeConfig provides configuration for hitless upgrades.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
// Cluster returns cluster options created from the universal options.
@@ -177,6 +180,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions {
IdentitySuffix: o.IdentitySuffix,
FailingTimeoutSeconds: o.FailingTimeoutSeconds,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
}
}
@@ -237,6 +241,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions {
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
// Note: HitlessUpgradeConfig not supported for FailoverOptions
}
}
@@ -288,6 +293,7 @@ func (o *UniversalOptions) Simple() *Options {
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
}
}
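A sketch of wiring the new option through UniversalOptions, assuming HitlessUpgradeConfig is an exported struct in the redis package as the field declaration above suggests; its concrete fields (e.g. a Mode setting, per the FailoverOptions comment) are defined in the hitless package and are not shown in this diff, so a zero-value config is used purely for illustration:

package main

import "github.com/redis/go-redis/v9"

func main() {
	uc := redis.NewUniversalClient(&redis.UniversalOptions{
		Addrs:    []string{"localhost:6379"},
		Protocol: 3, // RESP3 is required for push notifications
		// Zero-value config shown only as a placeholder; real deployments
		// would set the fields documented in the hitless package.
		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{},
	})
	defer uc.Close()
}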