[CAE-1072] Hitless Upgrades (#3447)
feat(hitless): Introduce handlers for hitless upgrades

This commit includes all the work on hitless upgrades, with the addition of:
- Pubsub pool
- Examples
- Refactor of push
- Refactor of pool (using atomics for most things)
- Introduction of hooks in the pool

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
.gitignore (vendored, 3 changes)
@@ -9,3 +9,6 @@ coverage.txt
**/coverage.txt
.vscode
tmp/*

# Hitless upgrade documentation (temporary)
hitless/docs/
adapters.go (new file, 111 lines)
@@ -0,0 +1,111 @@
package redis

import (
	"context"
	"errors"
	"net"
	"time"

	"github.com/redis/go-redis/v9/internal/interfaces"
	"github.com/redis/go-redis/v9/push"
)

// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand.
var ErrInvalidCommand = errors.New("invalid command type")

// ErrInvalidPool is returned when the pool type is not supported.
var ErrInvalidPool = errors.New("invalid pool type")

// newClientAdapter creates a new client adapter for regular Redis clients.
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
	return &clientAdapter{client: client}
}

// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
type clientAdapter struct {
	client *baseClient
}

// GetOptions returns the client options.
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
	return &optionsAdapter{options: ca.client.opt}
}

// GetPushProcessor returns the client's push notification processor.
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
	return &pushProcessorAdapter{processor: ca.client.pushProcessor}
}

// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
type optionsAdapter struct {
	options *Options
}

// GetReadTimeout returns the read timeout.
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
	return oa.options.ReadTimeout
}

// GetWriteTimeout returns the write timeout.
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
	return oa.options.WriteTimeout
}

// GetNetwork returns the network type.
func (oa *optionsAdapter) GetNetwork() string {
	return oa.options.Network
}

// GetAddr returns the connection address.
func (oa *optionsAdapter) GetAddr() string {
	return oa.options.Addr
}

// IsTLSEnabled returns true if TLS is enabled.
func (oa *optionsAdapter) IsTLSEnabled() bool {
	return oa.options.TLSConfig != nil
}

// GetProtocol returns the protocol version.
func (oa *optionsAdapter) GetProtocol() int {
	return oa.options.Protocol
}

// GetPoolSize returns the connection pool size.
func (oa *optionsAdapter) GetPoolSize() int {
	return oa.options.PoolSize
}

// NewDialer returns a new dialer function for the connection.
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
	baseDialer := oa.options.NewDialer()
	return func(ctx context.Context) (net.Conn, error) {
		// Extract network and address from the options
		network := oa.options.Network
		addr := oa.options.Addr
		return baseDialer(ctx, network, addr)
	}
}

// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
type pushProcessorAdapter struct {
	processor push.NotificationProcessor
}

// RegisterHandler registers a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
	if pushHandler, ok := handler.(push.NotificationHandler); ok {
		return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
	}
	return errors.New("handler must implement push.NotificationHandler")
}

// UnregisterHandler removes a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
	return ppa.processor.UnregisterHandler(pushNotificationName)
}

// GetHandler returns the handler for a specific push notification name.
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
	return ppa.processor.GetHandler(pushNotificationName)
}
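A brief hedged sketch of how these adapters are consumed (inside package redis; `client` stands for an existing *baseClient, and `describeClient` is a hypothetical helper, not part of this diff):

```go
// Sketch: expose a *baseClient to the hitless machinery through the adapter,
// then read connection details via the interface shown above.
func describeClient(client *baseClient) (network, addr string) {
	var ci interfaces.ClientInterface = newClientAdapter(client)
	opts := ci.GetOptions()
	return opts.GetNetwork(), opts.GetAddr()
}
```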
async_handoff_integration_test.go (new file, 353 lines)
@@ -0,0 +1,353 @@
package redis

import (
	"context"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/hitless"
	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/logging"
)

// mockNetConn implements net.Conn for testing
type mockNetConn struct {
	addr string
}

func (m *mockNetConn) Read(b []byte) (n int, err error)   { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error)  { return len(b), nil }
func (m *mockNetConn) Close() error                       { return nil }
func (m *mockNetConn) LocalAddr() net.Addr                { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr               { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error      { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error  { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }

type mockAddr struct {
	addr string
}

func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string  { return m.addr }

// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow
func TestEventDrivenHandoffIntegration(t *testing.T) {
	t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) {
		// Create a base dialer for testing
		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return &mockNetConn{addr: addr}, nil
		}

		// Create processor with event-driven handoff support
		processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
		defer processor.Shutdown(context.Background())

		// Create a test pool with hooks
		hookManager := pool.NewPoolHookManager()
		hookManager.AddHook(processor)

		testPool := pool.NewConnPool(&pool.Options{
			Dialer: func(ctx context.Context) (net.Conn, error) {
				return &mockNetConn{addr: "original:6379"}, nil
			},
			PoolSize:    int32(5),
			PoolTimeout: time.Second,
		})

		// Add the hook to the pool after creation
		testPool.AddPoolHook(processor)
		defer testPool.Close()

		// Set the pool reference in the processor for connection removal on handoff failure
		processor.SetPool(testPool)

		ctx := context.Background()

		// Get a connection and mark it for handoff
		conn, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Failed to get connection: %v", err)
		}

		// Set initialization function with a small delay to ensure handoff is pending
		initConnCalled := false
		initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
			time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
			initConnCalled = true
			return nil
		}
		conn.SetInitConnFunc(initConnFunc)

		// Mark connection for handoff
		err = conn.MarkForHandoff("new-endpoint:6379", 12345)
		if err != nil {
			t.Fatalf("Failed to mark connection for handoff: %v", err)
		}

		// Return connection to pool - this should queue handoff
		testPool.Put(ctx, conn)

		// Give the on-demand worker a moment to start processing
		time.Sleep(10 * time.Millisecond)

		// Verify handoff was queued
		if !processor.IsHandoffPending(conn) {
			t.Error("Handoff should be queued in pending map")
		}

		// Try to get the same connection - should be skipped due to pending handoff
		conn2, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Failed to get second connection: %v", err)
		}

		// Should get a different connection (the pending one should be skipped)
		if conn == conn2 {
			t.Error("Should have gotten a different connection while handoff is pending")
		}

		// Return the second connection
		testPool.Put(ctx, conn2)

		// Wait for handoff to complete
		time.Sleep(200 * time.Millisecond)

		// Verify handoff completed (removed from pending map)
		if processor.IsHandoffPending(conn) {
			t.Error("Handoff should have completed and been removed from pending map")
		}

		if !initConnCalled {
			t.Error("InitConn should have been called during handoff")
		}

		// Now the original connection should be available again
		conn3, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Failed to get third connection: %v", err)
		}

		// Could be the original connection (now handed off) or a new one
		testPool.Put(ctx, conn3)
	})

	t.Run("ConcurrentHandoffs", func(t *testing.T) {
		// Create a base dialer that simulates slow handoffs
		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			time.Sleep(50 * time.Millisecond) // Simulate network delay
			return &mockNetConn{addr: addr}, nil
		}

		processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
		defer processor.Shutdown(context.Background())

		// Create hooks manager and add processor as hook
		hookManager := pool.NewPoolHookManager()
		hookManager.AddHook(processor)

		testPool := pool.NewConnPool(&pool.Options{
			Dialer: func(ctx context.Context) (net.Conn, error) {
				return &mockNetConn{addr: "original:6379"}, nil
			},

			PoolSize:    int32(10),
			PoolTimeout: time.Second,
		})
		defer testPool.Close()

		// Add the hook to the pool after creation
		testPool.AddPoolHook(processor)

		// Set the pool reference in the processor
		processor.SetPool(testPool)

		ctx := context.Background()
		var wg sync.WaitGroup

		// Start multiple concurrent handoffs
		for i := 0; i < 5; i++ {
			wg.Add(1)
			go func(id int) {
				defer wg.Done()

				// Get connection
				conn, err := testPool.Get(ctx)
				if err != nil {
					t.Errorf("Failed to get conn[%d]: %v", id, err)
					return
				}

				// Set initialization function
				initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
					return nil
				}
				conn.SetInitConnFunc(initConnFunc)

				// Mark for handoff
				conn.MarkForHandoff("new-endpoint:6379", int64(id))

				// Return to pool (starts async handoff)
				testPool.Put(ctx, conn)
			}(i)
		}

		wg.Wait()

		// Wait for all handoffs to complete
		time.Sleep(300 * time.Millisecond)

		// Verify pool is still functional
		conn, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err)
		}
		testPool.Put(ctx, conn)
	})

	t.Run("HandoffFailureRecovery", func(t *testing.T) {
		// Create a failing base dialer
		failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}}
		}

		processor := hitless.NewPoolHook(failingDialer, "tcp", nil, nil)
		defer processor.Shutdown(context.Background())

		// Create hooks manager and add processor as hook
		hookManager := pool.NewPoolHookManager()
		hookManager.AddHook(processor)

		testPool := pool.NewConnPool(&pool.Options{
			Dialer: func(ctx context.Context) (net.Conn, error) {
				return &mockNetConn{addr: "original:6379"}, nil
			},

			PoolSize:    int32(3),
			PoolTimeout: time.Second,
		})
		defer testPool.Close()

		// Add the hook to the pool after creation
		testPool.AddPoolHook(processor)

		// Set the pool reference in the processor
		processor.SetPool(testPool)

		ctx := context.Background()

		// Get connection and mark for handoff
		conn, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Failed to get connection: %v", err)
		}

		conn.MarkForHandoff("unreachable-endpoint:6379", 12345)

		// Return to pool (starts async handoff that will fail)
		testPool.Put(ctx, conn)

		// Wait for handoff to fail
		time.Sleep(200 * time.Millisecond)

		// Connection should be removed from pending map after failed handoff
		if processor.IsHandoffPending(conn) {
			t.Error("Connection should be removed from pending map after failed handoff")
		}

		// Pool should still be functional
		conn2, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Pool should still be functional: %v", err)
		}

		// In event-driven approach, the original connection remains in pool
		// even after failed handoff (it's still a valid connection)
		// We might get the same connection or a different one
		testPool.Put(ctx, conn2)
	})

	t.Run("GracefulShutdown", func(t *testing.T) {
		// Create a slow base dialer
		slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			time.Sleep(100 * time.Millisecond)
			return &mockNetConn{addr: addr}, nil
		}

		processor := hitless.NewPoolHook(slowDialer, "tcp", nil, nil)
		defer processor.Shutdown(context.Background())

		// Create hooks manager and add processor as hook
		hookManager := pool.NewPoolHookManager()
		hookManager.AddHook(processor)

		testPool := pool.NewConnPool(&pool.Options{
			Dialer: func(ctx context.Context) (net.Conn, error) {
				return &mockNetConn{addr: "original:6379"}, nil
			},

			PoolSize:    int32(2),
			PoolTimeout: time.Second,
		})
		defer testPool.Close()

		// Add the hook to the pool after creation
		testPool.AddPoolHook(processor)

		// Set the pool reference in the processor
		processor.SetPool(testPool)

		ctx := context.Background()

		// Start a handoff
		conn, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Failed to get connection: %v", err)
		}

		if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
			t.Fatalf("Failed to mark connection for handoff: %v", err)
		}

		// Set a mock initialization function with delay to ensure handoff is pending
		conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
			time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
			return nil
		})

		testPool.Put(ctx, conn)

		// Give the on-demand worker a moment to start and begin processing
		// The handoff should be pending because the slowDialer takes 100ms
		time.Sleep(10 * time.Millisecond)

		// Verify handoff was queued and is being processed
		if !processor.IsHandoffPending(conn) {
			t.Error("Handoff should be queued in pending map")
		}

		// Give the handoff a moment to start processing
		time.Sleep(50 * time.Millisecond)

		// Shutdown processor gracefully
		// Use a longer timeout to account for slow dialer (100ms) plus processing overhead
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()

		err = processor.Shutdown(shutdownCtx)
		if err != nil {
			t.Errorf("Graceful shutdown should succeed: %v", err)
		}

		// Handoff should have completed (removed from pending map)
		if processor.IsHandoffPending(conn) {
			t.Error("Handoff should have completed and been removed from pending map after shutdown")
		}
	})
}

func init() {
	logging.Disable()
}
commands.go (18 changes)
@@ -193,6 +193,7 @@ type Cmdable interface {
	ClientID(ctx context.Context) *IntCmd
	ClientUnblock(ctx context.Context, id int64) *IntCmd
	ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
	ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
	ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
	ConfigResetStat(ctx context.Context) *StatusCmd
	ConfigSet(ctx context.Context, parameter, value string) *StatusCmd

@@ -518,6 +519,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
	return cmd
}

// ClientMaintNotifications enables or disables maintenance notifications for hitless upgrades.
// When enabled, the client will receive push notifications about Redis maintenance events.
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
	args := []interface{}{"client", "maint_notifications"}
	if enabled {
		if endpointType == "" {
			endpointType = "none"
		}
		args = append(args, "on", "moving-endpoint-type", endpointType)
	} else {
		args = append(args, "off")
	}
	cmd := NewStatusCmd(ctx, args...)
	_ = c(ctx, cmd)
	return cmd
}

// ------------------------------------------------------------------------------------------------

func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {
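A hedged usage sketch of the new command wrapper above (illustrative only; `localhost:6379` and the `internal-fqdn` endpoint type are example values):

```go
package main

import (
	"context"
	"log"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Protocol: 3})

	// Ask the server to send maintenance push notifications and to report
	// internal FQDNs in MOVING notifications.
	if err := rdb.ClientMaintNotifications(ctx, true, "internal-fqdn").Err(); err != nil {
		log.Printf("CLIENT MAINT_NOTIFICATIONS failed: %v", err)
	}
}
```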
example/pubsub/go.mod (new file, 12 lines)
@@ -0,0 +1,12 @@
module github.com/redis/go-redis/example/pubsub

go 1.18

replace github.com/redis/go-redis/v9 => ../..

require github.com/redis/go-redis/v9 v9.11.0

require (
	github.com/cespare/xxhash/v2 v2.3.0 // indirect
	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)
example/pubsub/go.sum (new file, 6 lines)
@@ -0,0 +1,6 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
example/pubsub/main.go (new file, 171 lines)
@@ -0,0 +1,171 @@
package main

import (
	"context"
	"fmt"
	"log"
	"sync"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/hitless"
	"github.com/redis/go-redis/v9/logging"
)

var ctx = context.Background()
var cntErrors atomic.Int64
var cntSuccess atomic.Int64
var startTime = time.Now()

// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management.
// It was used to find regressions in pool management in hitless mode.
// Please don't use it as a reference for how to use pubsub.
func main() {
	startTime = time.Now()
	wg := &sync.WaitGroup{}
	rdb := redis.NewClient(&redis.Options{
		Addr: ":6379",
		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
			Mode: hitless.MaintNotificationsEnabled,
		},
	})
	_ = rdb.FlushDB(ctx).Err()
	hitlessManager := rdb.GetHitlessManager()
	if hitlessManager == nil {
		panic("hitless manager is nil")
	}
	loggingHook := hitless.NewLoggingHook(logging.LogLevelDebug)
	hitlessManager.AddNotificationHook(loggingHook)

	go func() {
		for {
			time.Sleep(2 * time.Second)
			fmt.Printf("pool stats: %+v\n", rdb.PoolStats())
		}
	}()
	err := rdb.Ping(ctx).Err()
	if err != nil {
		panic(err)
	}
	if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil {
		panic(err)
	}
	if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil {
		panic(err)
	}
	if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil {
		panic(err)
	}
	if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil {
		panic(err)
	}
	fmt.Println("published", rdb.Get(ctx, "published").Val())
	fmt.Println("received", rdb.Get(ctx, "received").Val())
	subCtx, cancelSubCtx := context.WithCancel(ctx)
	pubCtx, cancelPublishers := context.WithCancel(ctx)
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go subscribe(subCtx, rdb, "test", i, wg)
	}
	time.Sleep(time.Second)
	cancelSubCtx()
	time.Sleep(time.Second)
	subCtx, cancelSubCtx = context.WithCancel(ctx)
	for i := 0; i < 10; i++ {
		if err := rdb.Incr(ctx, "publishers").Err(); err != nil {
			fmt.Println("incr error:", err)
			cntErrors.Add(1)
		}
		wg.Add(1)
		go floodThePool(pubCtx, rdb, wg)
	}

	for i := 0; i < 500; i++ {
		if err := rdb.Incr(ctx, "subscribers").Err(); err != nil {
			fmt.Println("incr error:", err)
			cntErrors.Add(1)
		}

		wg.Add(1)
		go subscribe(subCtx, rdb, "test2", i, wg)
	}
	time.Sleep(120 * time.Second)
	fmt.Println("canceling publishers")
	cancelPublishers()
	time.Sleep(10 * time.Second)
	fmt.Println("canceling subscribers")
	cancelSubCtx()
	wg.Wait()
	published, err := rdb.Get(ctx, "published").Result()
	received, err := rdb.Get(ctx, "received").Result()
	publishers, err := rdb.Get(ctx, "publishers").Result()
	subscribers, err := rdb.Get(ctx, "subscribers").Result()
	fmt.Printf("publishers: %s\n", publishers)
	fmt.Printf("published: %s\n", published)
	fmt.Printf("subscribers: %s\n", subscribers)
	fmt.Printf("received: %s\n", received)
	publishedInt, err := rdb.Get(ctx, "published").Int()
	subscribersInt, err := rdb.Get(ctx, "subscribers").Int()
	fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt)

	time.Sleep(2 * time.Second)
	fmt.Println("errors:", cntErrors.Load())
	fmt.Println("success:", cntSuccess.Load())
	fmt.Println("time:", time.Since(startTime))
}

func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
	defer wg.Done()
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}
		err := rdb.Publish(ctx, "test2", "hello").Err()
		if err != nil {
			if err.Error() != "context canceled" {
				log.Println("publish error:", err)
				cntErrors.Add(1)
			}
		}

		err = rdb.Incr(ctx, "published").Err()
		if err != nil {
			if err.Error() != "context canceled" {
				log.Println("incr error:", err)
				cntErrors.Add(1)
			}
		}
		time.Sleep(10 * time.Nanosecond)
	}
}

func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) {
	defer wg.Done()
	rec := rdb.Subscribe(ctx, topic)
	recChan := rec.Channel()
	for {
		select {
		case <-ctx.Done():
			rec.Close()
			return
		default:
			select {
			case <-ctx.Done():
				rec.Close()
				return
			case msg := <-recChan:
				err := rdb.Incr(ctx, "received").Err()
				if err != nil {
					if err.Error() != "context canceled" {
						log.Printf("%s\n", err.Error())
						cntErrors.Add(1)
					}
				}
				_ = msg // Use the message to avoid unused variable warning
			}
		}
	}
}
@@ -57,6 +57,8 @@ func Example_instrumentation() {
	// finished dialing tcp :6379
	// starting processing: <[hello 3]>
	// finished processing: <[hello 3]>
	// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
	// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
	// finished processing: <[ping]>
}

@@ -78,6 +80,8 @@ func ExamplePipeline_instrumentation() {
	// finished dialing tcp :6379
	// starting processing: <[hello 3]>
	// finished processing: <[hello 3]>
	// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
	// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
	// pipeline finished processing: [[ping] [ping]]
}

@@ -99,6 +103,8 @@ func ExampleClient_Watch_instrumentation() {
	// finished dialing tcp :6379
	// starting processing: <[hello 3]>
	// finished processing: <[hello 3]>
	// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
	// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
	// finished processing: <[watch foo]>
	// starting processing: <[ping]>
	// finished processing: <[ping]>
hitless/README.md (new file, 98 lines)
@@ -0,0 +1,98 @@
# Hitless Upgrades

Seamless Redis connection handoffs during cluster changes without dropping connections.

## Quick Start

```go
client := redis.NewClient(&redis.Options{
    Addr:     "localhost:6379",
    Protocol: 3, // RESP3 required
    HitlessUpgrades: &hitless.Config{
        Mode: hitless.MaintNotificationsEnabled,
    },
})
```

## Modes

- **`MaintNotificationsDisabled`** - Hitless upgrades disabled
- **`MaintNotificationsEnabled`** - Forcefully enabled (fails if the server doesn't support it)
- **`MaintNotificationsAuto`** - Auto-detect server support (default)

## Configuration

```go
&hitless.Config{
    Mode:                       hitless.MaintNotificationsAuto,
    EndpointType:               hitless.EndpointTypeAuto,
    RelaxedTimeout:             10 * time.Second,
    HandoffTimeout:             15 * time.Second,
    MaxHandoffRetries:          3,
    MaxWorkers:                 0, // Auto-calculated
    HandoffQueueSize:           0, // Auto-calculated
    PostHandoffRelaxedDuration: 0, // 2 * RelaxedTimeout
    LogLevel:                   logging.LogLevelError,
}
```

### Endpoint Types

- **`EndpointTypeAuto`** - Auto-detect based on connection (default)
- **`EndpointTypeInternalIP`** - Internal IP address
- **`EndpointTypeInternalFQDN`** - Internal FQDN
- **`EndpointTypeExternalIP`** - External IP address
- **`EndpointTypeExternalFQDN`** - External FQDN
- **`EndpointTypeNone`** - No endpoint (reconnect with current config)

### Auto-Scaling

**Workers**: `min(PoolSize/2, max(10, PoolSize/3))` when auto-calculated
**Queue**: `max(20×Workers, PoolSize)`, capped by `MaxActiveConns+1` or `5×PoolSize`

**Examples** (see the sketch below):
- Pool 100: 33 workers, 660 queue (capped at 500)
- Pool 100 + MaxActiveConns 150: 33 workers, 151 queue
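The formulas above can be read as the following sketch. It mirrors the documented defaults rather than the library's actual implementation:

```go
// Illustrative only. maxActiveConns <= 0 means "not configured".
// Uses the Go 1.21+ built-in min/max.
func autoScale(poolSize, maxActiveConns int) (workers, queue int) {
    workers = min(poolSize/2, max(10, poolSize/3))
    queue = max(20*workers, poolSize)
    queueCap := 5 * poolSize
    if maxActiveConns > 0 {
        queueCap = maxActiveConns + 1
    }
    queue = min(queue, queueCap)
    return workers, queue
}
```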
## How It Works

1. Redis sends push notifications about cluster changes
2. Client creates new connections to updated endpoints
3. Active operations transfer to new connections
4. Old connections close gracefully

## Supported Notifications

- `MOVING` - Slot moving to new node
- `MIGRATING` - Slot in migration state
- `MIGRATED` - Migration completed
- `FAILING_OVER` - Node failing over
- `FAILED_OVER` - Failover completed

## Hooks (Optional)

Monitor and customize hitless operations:

```go
type NotificationHook interface {
    PreHook(ctx, notificationCtx, notificationType, notification) ([]interface{}, bool)
    PostHook(ctx, notificationCtx, notificationType, notification, result)
}

// Add custom hook
manager.AddNotificationHook(&MyHook{})
```
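As a sketch, a custom hook could simply count notifications by type. The parameter types below are assumptions for illustration, since the interface above is shown with abbreviated signatures; check the `hitless` package for the exact ones. Assumes `context` and `sync` are imported.

```go
// CountingHook is a hypothetical hook that counts notifications by type.
type CountingHook struct {
    mu     sync.Mutex
    counts map[string]int
}

func (h *CountingHook) PreHook(ctx context.Context, notificationCtx interface{}, notificationType string, notification []interface{}) ([]interface{}, bool) {
    h.mu.Lock()
    defer h.mu.Unlock()
    if h.counts == nil {
        h.counts = make(map[string]int)
    }
    h.counts[notificationType]++
    return notification, true // continue processing the notification unchanged
}

func (h *CountingHook) PostHook(ctx context.Context, notificationCtx interface{}, notificationType string, notification []interface{}, result interface{}) {
    // No-op here; results or errors could be logged for observability.
}
```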
### Metrics Hook Example

```go
// Create metrics hook
metricsHook := hitless.NewMetricsHook()
manager.AddNotificationHook(metricsHook)

// Access collected metrics
metrics := metricsHook.GetMetrics()
fmt.Printf("Notification counts: %v\n", metrics["notification_counts"])
fmt.Printf("Processing times: %v\n", metrics["processing_times"])
fmt.Printf("Error counts: %v\n", metrics["error_counts"])
```
hitless/circuit_breaker.go (new file, 360 lines)
@@ -0,0 +1,360 @@
package hitless

import (
	"context"
	"sync"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9/internal"
)

// CircuitBreakerState represents the state of a circuit breaker
type CircuitBreakerState int32

const (
	// CircuitBreakerClosed - normal operation, requests allowed
	CircuitBreakerClosed CircuitBreakerState = iota
	// CircuitBreakerOpen - failing fast, requests rejected
	CircuitBreakerOpen
	// CircuitBreakerHalfOpen - testing if service recovered
	CircuitBreakerHalfOpen
)

func (s CircuitBreakerState) String() string {
	switch s {
	case CircuitBreakerClosed:
		return "closed"
	case CircuitBreakerOpen:
		return "open"
	case CircuitBreakerHalfOpen:
		return "half-open"
	default:
		return "unknown"
	}
}

// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling
type CircuitBreaker struct {
	// Configuration
	failureThreshold int           // Number of failures before opening
	resetTimeout     time.Duration // How long to stay open before testing
	maxRequests      int           // Max requests allowed in half-open state

	// State tracking (atomic for lock-free access)
	state           atomic.Int32 // CircuitBreakerState
	failures        atomic.Int64 // Current failure count
	successes       atomic.Int64 // Success count in half-open state
	requests        atomic.Int64 // Request count in half-open state
	lastFailureTime atomic.Int64 // Unix timestamp of last failure
	lastSuccessTime atomic.Int64 // Unix timestamp of last success

	// Endpoint identification
	endpoint string
	config   *Config
}

// newCircuitBreaker creates a new circuit breaker for an endpoint
func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker {
	// Use configuration values with sensible defaults
	failureThreshold := 5
	resetTimeout := 60 * time.Second
	maxRequests := 3

	if config != nil {
		failureThreshold = config.CircuitBreakerFailureThreshold
		resetTimeout = config.CircuitBreakerResetTimeout
		maxRequests = config.CircuitBreakerMaxRequests
	}

	return &CircuitBreaker{
		failureThreshold: failureThreshold,
		resetTimeout:     resetTimeout,
		maxRequests:      maxRequests,
		endpoint:         endpoint,
		config:           config,
		state:            atomic.Int32{}, // Defaults to CircuitBreakerClosed (0)
	}
}

// IsOpen returns true if the circuit breaker is open (rejecting requests)
func (cb *CircuitBreaker) IsOpen() bool {
	state := CircuitBreakerState(cb.state.Load())
	return state == CircuitBreakerOpen
}

// shouldAttemptReset checks if enough time has passed to attempt reset
func (cb *CircuitBreaker) shouldAttemptReset() bool {
	lastFailure := time.Unix(cb.lastFailureTime.Load(), 0)
	return time.Since(lastFailure) >= cb.resetTimeout
}

// Execute runs the given function with circuit breaker protection
func (cb *CircuitBreaker) Execute(fn func() error) error {
	// Single atomic state load for consistency
	state := CircuitBreakerState(cb.state.Load())

	switch state {
	case CircuitBreakerOpen:
		if cb.shouldAttemptReset() {
			// Attempt transition to half-open
			if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
				cb.requests.Store(0)
				cb.successes.Store(0)
				if cb.config != nil && cb.config.LogLevel.InfoOrAbove() {
					internal.Logger.Printf(context.Background(),
						"hitless: circuit breaker for %s transitioning to half-open", cb.endpoint)
				}
				// Fall through to half-open logic
			} else {
				return ErrCircuitBreakerOpen
			}
		} else {
			return ErrCircuitBreakerOpen
		}
		fallthrough
	case CircuitBreakerHalfOpen:
		requests := cb.requests.Add(1)
		if requests > int64(cb.maxRequests) {
			cb.requests.Add(-1) // Revert the increment
			return ErrCircuitBreakerOpen
		}
	}

	// Execute the function with consistent state
	err := fn()

	if err != nil {
		cb.recordFailure()
		return err
	}

	cb.recordSuccess()
	return nil
}

// recordFailure records a failure and potentially opens the circuit
func (cb *CircuitBreaker) recordFailure() {
	cb.lastFailureTime.Store(time.Now().Unix())
	failures := cb.failures.Add(1)

	state := CircuitBreakerState(cb.state.Load())

	switch state {
	case CircuitBreakerClosed:
		if failures >= int64(cb.failureThreshold) {
			if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) {
				if cb.config != nil && cb.config.LogLevel.WarnOrAbove() {
					internal.Logger.Printf(context.Background(),
						"hitless: circuit breaker opened for endpoint %s after %d failures",
						cb.endpoint, failures)
				}
			}
		}
	case CircuitBreakerHalfOpen:
		// Any failure in half-open state immediately opens the circuit
		if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) {
			if cb.config != nil && cb.config.LogLevel.WarnOrAbove() {
				internal.Logger.Printf(context.Background(),
					"hitless: circuit breaker reopened for endpoint %s due to failure in half-open state",
					cb.endpoint)
			}
		}
	}
}

// recordSuccess records a success and potentially closes the circuit
func (cb *CircuitBreaker) recordSuccess() {
	cb.lastSuccessTime.Store(time.Now().Unix())

	state := CircuitBreakerState(cb.state.Load())

	switch state {
	case CircuitBreakerClosed:
		// Reset failure count on success in closed state
		cb.failures.Store(0)
	case CircuitBreakerHalfOpen:
		successes := cb.successes.Add(1)

		// If we've had enough successful requests, close the circuit
		if successes >= int64(cb.maxRequests) {
			if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
				cb.failures.Store(0)
				if cb.config != nil && cb.config.LogLevel.InfoOrAbove() {
					internal.Logger.Printf(context.Background(),
						"hitless: circuit breaker closed for endpoint %s after %d successful requests",
						cb.endpoint, successes)
				}
			}
		}
	}
}

// GetState returns the current state of the circuit breaker
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
	return CircuitBreakerState(cb.state.Load())
}

// GetStats returns current statistics for monitoring
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
	return CircuitBreakerStats{
		Endpoint:        cb.endpoint,
		State:           cb.GetState(),
		Failures:        cb.failures.Load(),
		Successes:       cb.successes.Load(),
		Requests:        cb.requests.Load(),
		LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0),
		LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0),
	}
}

// CircuitBreakerStats provides statistics about a circuit breaker
type CircuitBreakerStats struct {
	Endpoint        string
	State           CircuitBreakerState
	Failures        int64
	Successes       int64
	Requests        int64
	LastFailureTime time.Time
	LastSuccessTime time.Time
}

// CircuitBreakerEntry wraps a circuit breaker with access tracking
type CircuitBreakerEntry struct {
	breaker    *CircuitBreaker
	lastAccess atomic.Int64 // Unix timestamp
	created    time.Time
}

// CircuitBreakerManager manages circuit breakers for multiple endpoints
type CircuitBreakerManager struct {
	breakers    sync.Map // map[string]*CircuitBreakerEntry
	config      *Config
	cleanupStop chan struct{}
	cleanupMu   sync.Mutex
	lastCleanup atomic.Int64 // Unix timestamp
}

// newCircuitBreakerManager creates a new circuit breaker manager
func newCircuitBreakerManager(config *Config) *CircuitBreakerManager {
	cbm := &CircuitBreakerManager{
		config:      config,
		cleanupStop: make(chan struct{}),
	}
	cbm.lastCleanup.Store(time.Now().Unix())

	// Start background cleanup goroutine
	go cbm.cleanupLoop()

	return cbm
}

// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary
func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker {
	now := time.Now().Unix()

	if entry, ok := cbm.breakers.Load(endpoint); ok {
		cbEntry := entry.(*CircuitBreakerEntry)
		cbEntry.lastAccess.Store(now)
		return cbEntry.breaker
	}

	// Create new circuit breaker with metadata
	newBreaker := newCircuitBreaker(endpoint, cbm.config)
	newEntry := &CircuitBreakerEntry{
		breaker: newBreaker,
		created: time.Now(),
	}
	newEntry.lastAccess.Store(now)

	actual, _ := cbm.breakers.LoadOrStore(endpoint, newEntry)
	return actual.(*CircuitBreakerEntry).breaker
}

// GetAllStats returns statistics for all circuit breakers
func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats {
	var stats []CircuitBreakerStats
	cbm.breakers.Range(func(key, value interface{}) bool {
		entry := value.(*CircuitBreakerEntry)
		stats = append(stats, entry.breaker.GetStats())
		return true
	})
	return stats
}

// cleanupLoop runs background cleanup of unused circuit breakers
func (cbm *CircuitBreakerManager) cleanupLoop() {
	ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			cbm.cleanup()
		case <-cbm.cleanupStop:
			return
		}
	}
}

// cleanup removes circuit breakers that haven't been accessed recently
func (cbm *CircuitBreakerManager) cleanup() {
	// Prevent concurrent cleanups
	if !cbm.cleanupMu.TryLock() {
		return
	}
	defer cbm.cleanupMu.Unlock()

	now := time.Now()
	cutoff := now.Add(-30 * time.Minute).Unix() // 30 minute TTL

	var toDelete []string
	count := 0

	cbm.breakers.Range(func(key, value interface{}) bool {
		endpoint := key.(string)
		entry := value.(*CircuitBreakerEntry)

		count++

		// Remove if not accessed recently
		if entry.lastAccess.Load() < cutoff {
			toDelete = append(toDelete, endpoint)
		}

		return true
	})

	// Delete expired entries
	for _, endpoint := range toDelete {
		cbm.breakers.Delete(endpoint)
	}

	// Log cleanup results
	if len(toDelete) > 0 && cbm.config != nil && cbm.config.LogLevel.InfoOrAbove() {
		internal.Logger.Printf(context.Background(),
			"hitless: circuit breaker cleanup removed %d/%d entries", len(toDelete), count)
	}

	cbm.lastCleanup.Store(now.Unix())
}

// Shutdown stops the cleanup goroutine
func (cbm *CircuitBreakerManager) Shutdown() {
	close(cbm.cleanupStop)
}

// Reset resets all circuit breakers (useful for testing)
func (cbm *CircuitBreakerManager) Reset() {
	cbm.breakers.Range(func(key, value interface{}) bool {
		entry := value.(*CircuitBreakerEntry)
		breaker := entry.breaker
		breaker.state.Store(int32(CircuitBreakerClosed))
		breaker.failures.Store(0)
		breaker.successes.Store(0)
		breaker.requests.Store(0)
		breaker.lastFailureTime.Store(0)
		breaker.lastSuccessTime.Store(0)
		return true
	})
}
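A hedged illustration of how the breaker above is meant to be used inside the package; `dialEndpoint` is a hypothetical stand-in for the real handoff dial logic:

```go
// Sketch (package hitless): wrap an endpoint dial attempt in its breaker.
// Once the endpoint exceeds its failure threshold, Execute fails fast with
// ErrCircuitBreakerOpen instead of dialing again.
func dialWithBreaker(cbm *CircuitBreakerManager, endpoint string, dialEndpoint func() error) error {
	cb := cbm.GetCircuitBreaker(endpoint)
	return cb.Execute(dialEndpoint)
}
```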
hitless/circuit_breaker_test.go (new file, 356 lines)
@@ -0,0 +1,356 @@
package hitless

import (
	"errors"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/logging"
)

func TestCircuitBreaker(t *testing.T) {
	config := &Config{
		LogLevel:                       logging.LogLevelError, // Reduce noise in tests
		CircuitBreakerFailureThreshold: 5,
		CircuitBreakerResetTimeout:     60 * time.Second,
		CircuitBreakerMaxRequests:      3,
	}

	t.Run("InitialState", func(t *testing.T) {
		cb := newCircuitBreaker("test-endpoint:6379", config)

		if cb.IsOpen() {
			t.Error("Circuit breaker should start in closed state")
		}

		if cb.GetState() != CircuitBreakerClosed {
			t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
		}
	})

	t.Run("SuccessfulExecution", func(t *testing.T) {
		cb := newCircuitBreaker("test-endpoint:6379", config)

		err := cb.Execute(func() error {
			return nil // Success
		})

		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}

		if cb.GetState() != CircuitBreakerClosed {
			t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
		}
	})

	t.Run("FailureThreshold", func(t *testing.T) {
		cb := newCircuitBreaker("test-endpoint:6379", config)
		testError := errors.New("test error")

		// Fail 4 times (below threshold of 5)
		for i := 0; i < 4; i++ {
			err := cb.Execute(func() error {
				return testError
			})
			if err != testError {
				t.Errorf("Expected test error, got %v", err)
			}
			if cb.GetState() != CircuitBreakerClosed {
				t.Errorf("Circuit should still be closed after %d failures", i+1)
			}
		}

		// 5th failure should open the circuit
		err := cb.Execute(func() error {
			return testError
		})
		if err != testError {
			t.Errorf("Expected test error, got %v", err)
		}

		if cb.GetState() != CircuitBreakerOpen {
			t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
		}
	})

	t.Run("OpenCircuitFailsFast", func(t *testing.T) {
		cb := newCircuitBreaker("test-endpoint:6379", config)
		testError := errors.New("test error")

		// Force circuit to open
		for i := 0; i < 5; i++ {
			cb.Execute(func() error { return testError })
		}

		// Now it should fail fast
		err := cb.Execute(func() error {
			t.Error("Function should not be called when circuit is open")
			return nil
		})

		if err != ErrCircuitBreakerOpen {
			t.Errorf("Expected ErrCircuitBreakerOpen, got %v", err)
		}
	})

	t.Run("HalfOpenTransition", func(t *testing.T) {
		testConfig := &Config{
			LogLevel:                       logging.LogLevelError,
			CircuitBreakerFailureThreshold: 5,
			CircuitBreakerResetTimeout:     100 * time.Millisecond, // Short timeout for testing
			CircuitBreakerMaxRequests:      3,
		}
		cb := newCircuitBreaker("test-endpoint:6379", testConfig)
		testError := errors.New("test error")

		// Force circuit to open
		for i := 0; i < 5; i++ {
			cb.Execute(func() error { return testError })
		}

		if cb.GetState() != CircuitBreakerOpen {
			t.Error("Circuit should be open")
		}

		// Wait for reset timeout
		time.Sleep(150 * time.Millisecond)

		// Next call should transition to half-open
		executed := false
		err := cb.Execute(func() error {
			executed = true
			return nil // Success
		})

		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}

		if !executed {
			t.Error("Function should have been executed in half-open state")
		}
	})

	t.Run("HalfOpenToClosedTransition", func(t *testing.T) {
		testConfig := &Config{
			LogLevel:                       logging.LogLevelError,
			CircuitBreakerFailureThreshold: 5,
			CircuitBreakerResetTimeout:     50 * time.Millisecond,
			CircuitBreakerMaxRequests:      3,
		}
		cb := newCircuitBreaker("test-endpoint:6379", testConfig)
		testError := errors.New("test error")

		// Force circuit to open
		for i := 0; i < 5; i++ {
			cb.Execute(func() error { return testError })
		}

		// Wait for reset timeout
		time.Sleep(100 * time.Millisecond)

		// Execute successful requests in half-open state
		for i := 0; i < 3; i++ {
			err := cb.Execute(func() error {
				return nil // Success
			})
			if err != nil {
				t.Errorf("Expected no error on attempt %d, got %v", i+1, err)
			}
		}

		// Circuit should now be closed
		if cb.GetState() != CircuitBreakerClosed {
			t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
		}
	})

	t.Run("HalfOpenToOpenOnFailure", func(t *testing.T) {
		testConfig := &Config{
			LogLevel:                       logging.LogLevelError,
			CircuitBreakerFailureThreshold: 5,
			CircuitBreakerResetTimeout:     50 * time.Millisecond,
			CircuitBreakerMaxRequests:      3,
		}
		cb := newCircuitBreaker("test-endpoint:6379", testConfig)
		testError := errors.New("test error")

		// Force circuit to open
		for i := 0; i < 5; i++ {
			cb.Execute(func() error { return testError })
		}

		// Wait for reset timeout
		time.Sleep(100 * time.Millisecond)

		// First request in half-open state fails
		err := cb.Execute(func() error {
			return testError
		})

		if err != testError {
			t.Errorf("Expected test error, got %v", err)
		}

		// Circuit should be open again
		if cb.GetState() != CircuitBreakerOpen {
			t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
		}
	})

	t.Run("Stats", func(t *testing.T) {
		cb := newCircuitBreaker("test-endpoint:6379", config)
		testError := errors.New("test error")

		// Execute some operations
		cb.Execute(func() error { return testError }) // Failure
		cb.Execute(func() error { return testError }) // Failure

		stats := cb.GetStats()

		if stats.Endpoint != "test-endpoint:6379" {
			t.Errorf("Expected endpoint 'test-endpoint:6379', got %s", stats.Endpoint)
		}

		if stats.Failures != 2 {
			t.Errorf("Expected 2 failures, got %d", stats.Failures)
		}

		if stats.State != CircuitBreakerClosed {
			t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, stats.State)
		}

		// Test that success resets failure count
		cb.Execute(func() error { return nil }) // Success
		stats = cb.GetStats()

		if stats.Failures != 0 {
			t.Errorf("Expected 0 failures after success, got %d", stats.Failures)
		}
	})
}

func TestCircuitBreakerManager(t *testing.T) {
	config := &Config{
		LogLevel:                       logging.LogLevelError,
		CircuitBreakerFailureThreshold: 5,
		CircuitBreakerResetTimeout:     60 * time.Second,
		CircuitBreakerMaxRequests:      3,
	}

	t.Run("GetCircuitBreaker", func(t *testing.T) {
		manager := newCircuitBreakerManager(config)

		cb1 := manager.GetCircuitBreaker("endpoint1:6379")
		cb2 := manager.GetCircuitBreaker("endpoint2:6379")
		cb3 := manager.GetCircuitBreaker("endpoint1:6379") // Same as cb1

		if cb1 == cb2 {
			t.Error("Different endpoints should have different circuit breakers")
		}

		if cb1 != cb3 {
			t.Error("Same endpoint should return the same circuit breaker")
		}
	})

	t.Run("GetAllStats", func(t *testing.T) {
		manager := newCircuitBreakerManager(config)

		// Create circuit breakers for different endpoints
		cb1 := manager.GetCircuitBreaker("endpoint1:6379")
		cb2 := manager.GetCircuitBreaker("endpoint2:6379")

		// Execute some operations
		cb1.Execute(func() error { return nil })
		cb2.Execute(func() error { return errors.New("test error") })

		stats := manager.GetAllStats()

		if len(stats) != 2 {
			t.Errorf("Expected 2 circuit breaker stats, got %d", len(stats))
		}

		// Check that we have stats for both endpoints
		endpoints := make(map[string]bool)
		for _, stat := range stats {
			endpoints[stat.Endpoint] = true
		}

		if !endpoints["endpoint1:6379"] || !endpoints["endpoint2:6379"] {
			t.Error("Missing stats for expected endpoints")
		}
	})

	t.Run("Reset", func(t *testing.T) {
		manager := newCircuitBreakerManager(config)
		testError := errors.New("test error")

		cb := manager.GetCircuitBreaker("test-endpoint:6379")

		// Force circuit to open
		for i := 0; i < 5; i++ {
			cb.Execute(func() error { return testError })
		}

		if cb.GetState() != CircuitBreakerOpen {
			t.Error("Circuit should be open")
		}

		// Reset all circuit breakers
		manager.Reset()

		if cb.GetState() != CircuitBreakerClosed {
			t.Error("Circuit should be closed after reset")
		}

		if cb.failures.Load() != 0 {
			t.Error("Failure count should be reset to 0")
		}
	})

	t.Run("ConfigurableParameters", func(t *testing.T) {
		config := &Config{
			LogLevel:                       logging.LogLevelError,
			CircuitBreakerFailureThreshold: 10,
			CircuitBreakerResetTimeout:     30 * time.Second,
			CircuitBreakerMaxRequests:      5,
		}

		cb := newCircuitBreaker("test-endpoint:6379", config)

		// Test that configuration values are used
		if cb.failureThreshold != 10 {
			t.Errorf("Expected failureThreshold=10, got %d", cb.failureThreshold)
		}
		if cb.resetTimeout != 30*time.Second {
			t.Errorf("Expected resetTimeout=30s, got %v", cb.resetTimeout)
		}
		if cb.maxRequests != 5 {
			t.Errorf("Expected maxRequests=5, got %d", cb.maxRequests)
		}

		// Test that circuit opens after configured threshold
		testError := errors.New("test error")
		for i := 0; i < 9; i++ {
			err := cb.Execute(func() error { return testError })
			if err != testError {
				t.Errorf("Expected test error, got %v", err)
			}
			if cb.GetState() != CircuitBreakerClosed {
				t.Errorf("Circuit should still be closed after %d failures", i+1)
			}
		}

		// 10th failure should open the circuit
		err := cb.Execute(func() error { return testError })
		if err != testError {
			t.Errorf("Expected test error, got %v", err)
		}

		if cb.GetState() != CircuitBreakerOpen {
			t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
		}
	})
}
hitless/config.go (new file, 472 lines)
@@ -0,0 +1,472 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
)
|
||||
|
||||
// MaintNotificationsMode represents the maintenance notifications mode
|
||||
type MaintNotificationsMode string
|
||||
|
||||
// Constants for maintenance push notifications modes
|
||||
const (
|
||||
MaintNotificationsDisabled MaintNotificationsMode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
|
||||
MaintNotificationsEnabled MaintNotificationsMode = "enabled" // Client forcefully sends command, interrupts connection on error
|
||||
MaintNotificationsAuto MaintNotificationsMode = "auto" // Client tries to send command, disables feature on error
|
||||
)
|
||||
|
||||
// IsValid returns true if the maintenance notifications mode is valid
|
||||
func (m MaintNotificationsMode) IsValid() bool {
|
||||
switch m {
|
||||
case MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the string representation of the mode
|
||||
func (m MaintNotificationsMode) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
// EndpointType represents the type of endpoint to request in MOVING notifications
|
||||
type EndpointType string
|
||||
|
||||
// Constants for endpoint types
|
||||
const (
|
||||
EndpointTypeAuto EndpointType = "auto" // Auto-detect based on connection
|
||||
EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address
|
||||
EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN
|
||||
EndpointTypeExternalIP EndpointType = "external-ip" // External IP address
|
||||
EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN
|
||||
EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config)
|
||||
)
|
||||
|
||||
// IsValid returns true if the endpoint type is valid
|
||||
func (e EndpointType) IsValid() bool {
|
||||
switch e {
|
||||
case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
|
||||
EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the string representation of the endpoint type
|
||||
func (e EndpointType) String() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// Config provides configuration options for hitless upgrades.
|
||||
type Config struct {
|
||||
// Mode controls how client maintenance notifications are handled.
|
||||
// Valid values: MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto
|
||||
// Default: MaintNotificationsAuto
|
||||
Mode MaintNotificationsMode
|
||||
|
||||
// EndpointType specifies the type of endpoint to request in MOVING notifications.
|
||||
// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
|
||||
// EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
|
||||
// Default: EndpointTypeAuto
|
||||
EndpointType EndpointType
|
||||
|
||||
// RelaxedTimeout is the concrete timeout value to use during
|
||||
// MIGRATING/FAILING_OVER states to accommodate increased latency.
|
||||
// This applies to both read and write timeouts.
|
||||
// Default: 10 seconds
|
||||
RelaxedTimeout time.Duration
|
||||
|
||||
// HandoffTimeout is the maximum time to wait for connection handoff to complete.
|
||||
// If handoff takes longer than this, the old connection will be forcibly closed.
|
||||
// Default: 15 seconds (matches server-side eviction timeout)
|
||||
HandoffTimeout time.Duration
|
||||
|
||||
// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
|
||||
// Workers are created on-demand and automatically cleaned up when idle.
|
||||
// If zero, defaults to min(PoolSize/2, max(10, PoolSize/3)) to handle bursts effectively.
|
||||
// If explicitly set, enforces minimum of PoolSize/2
|
||||
//
|
||||
// Default: min(PoolSize/2, max(10, PoolSize/3)), Minimum when set: PoolSize/2
|
||||
MaxWorkers int
|
||||
|
||||
// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
|
||||
// If the queue is full, new handoff requests will be rejected.
|
||||
// Scales with both worker count and pool size for better burst handling.
|
||||
//
|
||||
// Default: max(20×MaxWorkers, PoolSize), capped by MaxActiveConns+1 (if set) or 5×PoolSize
|
||||
// When set: minimum 200, capped by MaxActiveConns+1 (if set) or 5×PoolSize
|
||||
HandoffQueueSize int
|
||||
|
||||
// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
|
||||
// after a handoff completes. This provides additional resilience during cluster transitions.
|
||||
// Default: 2 * RelaxedTimeout
|
||||
PostHandoffRelaxedDuration time.Duration
|
||||
|
||||
// LogLevel controls the verbosity of hitless upgrade logging.
|
||||
// LogLevelError (0) = errors only, LogLevelWarn (1) = warnings, LogLevelInfo (2) = info, LogLevelDebug (3) = debug
|
||||
// Default: logging.LogLevelError(0)
|
||||
LogLevel logging.LogLevel
|
||||
|
||||
// Circuit breaker configuration for endpoint failure handling
|
||||
// CircuitBreakerFailureThreshold is the number of failures before opening the circuit.
|
||||
// Default: 5
|
||||
CircuitBreakerFailureThreshold int
|
||||
|
||||
// CircuitBreakerResetTimeout is how long to wait before testing if the endpoint recovered.
|
||||
// Default: 60 seconds
|
||||
CircuitBreakerResetTimeout time.Duration
|
||||
|
||||
// CircuitBreakerMaxRequests is the maximum number of requests allowed in half-open state.
|
||||
// Default: 3
|
||||
CircuitBreakerMaxRequests int
|
||||
|
||||
// MaxHandoffRetries is the maximum number of times to retry a failed handoff.
|
||||
// After this many retries, the connection will be removed from the pool.
|
||||
// Default: 3
|
||||
MaxHandoffRetries int
|
||||
}
|
||||
|
||||
func (c *Config) IsEnabled() bool {
|
||||
return c != nil && c.Mode != MaintNotificationsDisabled
|
||||
}
|
||||
|
||||
// DefaultConfig returns a Config with sensible defaults.
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Mode: MaintNotificationsAuto, // Enable by default for Redis Cloud
|
||||
EndpointType: EndpointTypeAuto, // Auto-detect based on connection
|
||||
RelaxedTimeout: 10 * time.Second,
|
||||
HandoffTimeout: 15 * time.Second,
|
||||
MaxWorkers: 0, // Auto-calculated based on pool size
|
||||
HandoffQueueSize: 0, // Auto-calculated based on max workers
|
||||
PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
|
||||
LogLevel: logging.LogLevelError,
|
||||
|
||||
// Circuit breaker configuration
|
||||
CircuitBreakerFailureThreshold: 5,
|
||||
CircuitBreakerResetTimeout: 60 * time.Second,
|
||||
CircuitBreakerMaxRequests: 3,
|
||||
|
||||
// Connection Handoff Configuration
|
||||
MaxHandoffRetries: 3,
|
||||
}
|
||||
}
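
For orientation, here is a minimal sketch of how a caller might combine a partially specified Config with these defaults. It assumes the package is importable as github.com/redis/go-redis/v9/hitless and uses only the exported identifiers introduced in this file; it is an illustration, not part of the diff.

package main

import (
	"fmt"
	"time"

	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	// Start from a partial config; zero-valued fields are filled in below.
	cfg := &hitless.Config{
		Mode:           hitless.MaintNotificationsAuto,
		RelaxedTimeout: 5 * time.Second,
	}

	// Assume a pool of 100 connections; MaxWorkers and HandoffQueueSize
	// are auto-calculated from this pool size.
	cfg = cfg.ApplyDefaultsWithPoolSize(100)

	if err := cfg.Validate(); err != nil {
		panic(err)
	}

	fmt.Println(cfg.MaxWorkers, cfg.HandoffQueueSize, cfg.PostHandoffRelaxedDuration)
}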
|
||||
|
||||
// Validate checks if the configuration is valid.
|
||||
func (c *Config) Validate() error {
|
||||
if c.RelaxedTimeout <= 0 {
|
||||
return ErrInvalidRelaxedTimeout
|
||||
}
|
||||
if c.HandoffTimeout <= 0 {
|
||||
return ErrInvalidHandoffTimeout
|
||||
}
|
||||
// Validate worker configuration
|
||||
// Allow 0 for auto-calculation, but negative values are invalid
|
||||
if c.MaxWorkers < 0 {
|
||||
return ErrInvalidHandoffWorkers
|
||||
}
|
||||
// HandoffQueueSize validation - allow 0 for auto-calculation
|
||||
if c.HandoffQueueSize < 0 {
|
||||
return ErrInvalidHandoffQueueSize
|
||||
}
|
||||
if c.PostHandoffRelaxedDuration < 0 {
|
||||
return ErrInvalidPostHandoffRelaxedDuration
|
||||
}
|
||||
if !c.LogLevel.IsValid() {
|
||||
return ErrInvalidLogLevel
|
||||
}
|
||||
|
||||
// Circuit breaker validation
|
||||
if c.CircuitBreakerFailureThreshold < 1 {
|
||||
return ErrInvalidCircuitBreakerFailureThreshold
|
||||
}
|
||||
if c.CircuitBreakerResetTimeout < 0 {
|
||||
return ErrInvalidCircuitBreakerResetTimeout
|
||||
}
|
||||
if c.CircuitBreakerMaxRequests < 1 {
|
||||
return ErrInvalidCircuitBreakerMaxRequests
|
||||
}
|
||||
|
||||
// Validate Mode (maintenance notifications mode)
|
||||
if !c.Mode.IsValid() {
|
||||
return ErrInvalidMaintNotifications
|
||||
}
|
||||
|
||||
// Validate EndpointType
|
||||
if !c.EndpointType.IsValid() {
|
||||
return ErrInvalidEndpointType
|
||||
}
|
||||
|
||||
// Validate configuration fields
|
||||
if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
|
||||
return ErrInvalidHandoffRetries
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyDefaults applies default values to any zero-value fields in the configuration.
|
||||
// This ensures that partially configured structs get sensible defaults for missing fields.
|
||||
func (c *Config) ApplyDefaults() *Config {
|
||||
return c.ApplyDefaultsWithPoolSize(0)
|
||||
}
|
||||
|
||||
// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
|
||||
// using the provided pool size to calculate worker defaults.
|
||||
// This ensures that partially configured structs get sensible defaults for missing fields.
|
||||
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
|
||||
return c.ApplyDefaultsWithPoolConfig(poolSize, 0)
|
||||
}
|
||||
|
||||
// ApplyDefaultsWithPoolConfig applies default values to any zero-value fields in the configuration,
|
||||
// using the provided pool size and max active connections to calculate worker and queue defaults.
|
||||
// This ensures that partially configured structs get sensible defaults for missing fields.
|
||||
func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *Config {
|
||||
if c == nil {
|
||||
return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
|
||||
}
|
||||
|
||||
defaults := DefaultConfig()
|
||||
result := &Config{}
|
||||
|
||||
// Apply defaults for enum fields (empty/zero means not set)
|
||||
result.Mode = defaults.Mode
|
||||
if c.Mode != "" {
|
||||
result.Mode = c.Mode
|
||||
}
|
||||
|
||||
result.EndpointType = defaults.EndpointType
|
||||
if c.EndpointType != "" {
|
||||
result.EndpointType = c.EndpointType
|
||||
}
|
||||
|
||||
// Apply defaults for duration fields (zero means not set)
|
||||
result.RelaxedTimeout = defaults.RelaxedTimeout
|
||||
if c.RelaxedTimeout > 0 {
|
||||
result.RelaxedTimeout = c.RelaxedTimeout
|
||||
}
|
||||
|
||||
result.HandoffTimeout = defaults.HandoffTimeout
|
||||
if c.HandoffTimeout > 0 {
|
||||
result.HandoffTimeout = c.HandoffTimeout
|
||||
}
|
||||
|
||||
// Copy worker configuration
|
||||
result.MaxWorkers = c.MaxWorkers
|
||||
|
||||
// Apply worker defaults based on pool size
|
||||
result.applyWorkerDefaults(poolSize)
|
||||
|
||||
// Apply queue size defaults with new scaling approach
|
||||
// Default: max(20x workers, PoolSize), capped by maxActiveConns or 5x pool size
|
||||
workerBasedSize := result.MaxWorkers * 20
|
||||
poolBasedSize := poolSize
|
||||
result.HandoffQueueSize = util.Max(workerBasedSize, poolBasedSize)
|
||||
if c.HandoffQueueSize > 0 {
|
||||
// When explicitly set: enforce minimum of 200
|
||||
result.HandoffQueueSize = util.Max(200, c.HandoffQueueSize)
|
||||
}
|
||||
|
||||
// Cap queue size: use maxActiveConns+1 if set, otherwise 5x pool size
|
||||
var queueCap int
|
||||
if maxActiveConns > 0 {
|
||||
queueCap = maxActiveConns + 1
|
||||
// Ensure queue cap is at least 2 for very small maxActiveConns
|
||||
if queueCap < 2 {
|
||||
queueCap = 2
|
||||
}
|
||||
} else {
|
||||
queueCap = poolSize * 5
|
||||
}
|
||||
result.HandoffQueueSize = util.Min(result.HandoffQueueSize, queueCap)
|
||||
|
||||
// Ensure minimum queue size of 2 (fallback for very small pools)
|
||||
if result.HandoffQueueSize < 2 {
|
||||
result.HandoffQueueSize = 2
|
||||
}
|
||||
|
||||
result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
|
||||
if c.PostHandoffRelaxedDuration > 0 {
|
||||
result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
|
||||
}
|
||||
|
||||
// LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set
|
||||
// We'll use the provided value as-is, since 0 is valid
|
||||
result.LogLevel = c.LogLevel
|
||||
|
||||
// Apply defaults for configuration fields
|
||||
result.MaxHandoffRetries = defaults.MaxHandoffRetries
|
||||
if c.MaxHandoffRetries > 0 {
|
||||
result.MaxHandoffRetries = c.MaxHandoffRetries
|
||||
}
|
||||
|
||||
// Circuit breaker configuration
|
||||
result.CircuitBreakerFailureThreshold = defaults.CircuitBreakerFailureThreshold
|
||||
if c.CircuitBreakerFailureThreshold > 0 {
|
||||
result.CircuitBreakerFailureThreshold = c.CircuitBreakerFailureThreshold
|
||||
}
|
||||
|
||||
result.CircuitBreakerResetTimeout = defaults.CircuitBreakerResetTimeout
|
||||
if c.CircuitBreakerResetTimeout > 0 {
|
||||
result.CircuitBreakerResetTimeout = c.CircuitBreakerResetTimeout
|
||||
}
|
||||
|
||||
result.CircuitBreakerMaxRequests = defaults.CircuitBreakerMaxRequests
|
||||
if c.CircuitBreakerMaxRequests > 0 {
|
||||
result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests
|
||||
}
|
||||
|
||||
if result.LogLevel.DebugOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), "hitless: debug logging enabled")
|
||||
internal.Logger.Printf(context.Background(), "hitless: config: %+v", result)
|
||||
}
|
||||
return result
|
||||
}
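
As a concrete walk-through of the sizing logic above: with poolSize = 100, MaxActiveConns unset, and all fields left at zero, applyWorkerDefaults yields MaxWorkers = min(100/2, max(10, 100/3)) = min(50, 33) = 33; the queue default is then max(20×33, 100) = 660, which the 5×PoolSize cap reduces to 500, so HandoffQueueSize ends up at 500. Had the caller explicitly set HandoffQueueSize to 50, the enforced minimum of 200 would apply first and the same cap would leave it at 200.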
|
||||
|
||||
// Clone creates a deep copy of the configuration.
|
||||
func (c *Config) Clone() *Config {
|
||||
if c == nil {
|
||||
return DefaultConfig()
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Mode: c.Mode,
|
||||
EndpointType: c.EndpointType,
|
||||
RelaxedTimeout: c.RelaxedTimeout,
|
||||
HandoffTimeout: c.HandoffTimeout,
|
||||
MaxWorkers: c.MaxWorkers,
|
||||
HandoffQueueSize: c.HandoffQueueSize,
|
||||
PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
|
||||
LogLevel: c.LogLevel,
|
||||
|
||||
// Circuit breaker configuration
|
||||
CircuitBreakerFailureThreshold: c.CircuitBreakerFailureThreshold,
|
||||
CircuitBreakerResetTimeout: c.CircuitBreakerResetTimeout,
|
||||
CircuitBreakerMaxRequests: c.CircuitBreakerMaxRequests,
|
||||
|
||||
// Configuration fields
|
||||
MaxHandoffRetries: c.MaxHandoffRetries,
|
||||
}
|
||||
}
|
||||
|
||||
// applyWorkerDefaults calculates and applies worker defaults based on pool size
|
||||
func (c *Config) applyWorkerDefaults(poolSize int) {
|
||||
// Calculate defaults based on pool size
|
||||
if poolSize <= 0 {
|
||||
poolSize = 10 * runtime.GOMAXPROCS(0)
|
||||
}
|
||||
|
||||
// When not set: min(poolSize/2, max(10, poolSize/3)) - balanced scaling approach
|
||||
originalMaxWorkers := c.MaxWorkers
|
||||
c.MaxWorkers = util.Min(poolSize/2, util.Max(10, poolSize/3))
|
||||
if originalMaxWorkers != 0 {
|
||||
// When explicitly set: max(poolSize/2, set_value) - ensure at least poolSize/2 workers
|
||||
c.MaxWorkers = util.Max(poolSize/2, originalMaxWorkers)
|
||||
}
|
||||
|
||||
// Ensure minimum of 1 worker (fallback for very small pools)
|
||||
if c.MaxWorkers < 1 {
|
||||
c.MaxWorkers = 1
|
||||
}
|
||||
}
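
For example, with poolSize = 40 and MaxWorkers unset this yields min(20, max(10, 13)) = 13 workers; if the caller sets MaxWorkers = 5, the poolSize/2 floor raises it to 20; and when the pool size is unknown (<= 0), the 10×GOMAXPROCS fallback is substituted before the same formula is applied.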
|
||||
|
||||
// DetectEndpointType automatically detects the appropriate endpoint type
|
||||
// based on the connection address and TLS configuration.
|
||||
//
|
||||
// For IP addresses:
|
||||
// - If TLS is enabled: requests FQDN for proper certificate validation
|
||||
// - If TLS is disabled: requests IP for better performance
|
||||
//
|
||||
// For hostnames:
|
||||
// - If TLS is enabled: always requests FQDN for proper certificate validation
|
||||
// - If TLS is disabled: requests IP for better performance
|
||||
//
|
||||
// Internal vs External detection:
|
||||
// - For IPs: uses private IP range detection
|
||||
// - For hostnames: uses heuristics based on common internal naming patterns
|
||||
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
|
||||
// Extract host from "host:port" format
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr // Assume no port
|
||||
}
|
||||
|
||||
// Check if the host is an IP address or hostname
|
||||
ip := net.ParseIP(host)
|
||||
isIPAddress := ip != nil
|
||||
var endpointType EndpointType
|
||||
|
||||
if isIPAddress {
|
||||
// Address is an IP - determine if it's private or public
|
||||
isPrivate := ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
|
||||
|
||||
if tlsEnabled {
|
||||
// TLS with IP addresses - still prefer FQDN for certificate validation
|
||||
if isPrivate {
|
||||
endpointType = EndpointTypeInternalFQDN
|
||||
} else {
|
||||
endpointType = EndpointTypeExternalFQDN
|
||||
}
|
||||
} else {
|
||||
// No TLS - can use IP addresses directly
|
||||
if isPrivate {
|
||||
endpointType = EndpointTypeInternalIP
|
||||
} else {
|
||||
endpointType = EndpointTypeExternalIP
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Address is a hostname
|
||||
isInternalHostname := isInternalHostname(host)
|
||||
if isInternalHostname {
|
||||
endpointType = EndpointTypeInternalFQDN
|
||||
} else {
|
||||
endpointType = EndpointTypeExternalFQDN
|
||||
}
|
||||
}
|
||||
|
||||
return endpointType
|
||||
}
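
A few illustrative inputs, assuming the package is importable as github.com/redis/go-redis/v9/hitless; the expected results follow directly from the branches above.

package main

import (
	"fmt"

	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	fmt.Println(hitless.DetectEndpointType("10.0.0.5:6379", false))         // internal-ip: private IP, no TLS
	fmt.Println(hitless.DetectEndpointType("10.0.0.5:6379", true))          // internal-fqdn: TLS prefers FQDN even for IPs
	fmt.Println(hitless.DetectEndpointType("203.0.113.7:6379", false))      // external-ip: public IP, no TLS
	fmt.Println(hitless.DetectEndpointType("redis-1:6379", false))          // internal-fqdn: dotless hostname treated as internal
	fmt.Println(hitless.DetectEndpointType("redis.example.com:6379", true)) // external-fqdn
}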
|
||||
|
||||
// isInternalHostname determines if a hostname appears to be internal/private.
|
||||
// This is a heuristic based on common naming patterns.
|
||||
func isInternalHostname(hostname string) bool {
|
||||
// Convert to lowercase for comparison
|
||||
hostname = strings.ToLower(hostname)
|
||||
|
||||
// Common internal hostname patterns
|
||||
internalPatterns := []string{
|
||||
"localhost",
|
||||
".local",
|
||||
".internal",
|
||||
".corp",
|
||||
".lan",
|
||||
".intranet",
|
||||
".private",
|
||||
}
|
||||
|
||||
// Check for exact match or suffix match
|
||||
for _, pattern := range internalPatterns {
|
||||
if hostname == pattern || strings.HasSuffix(hostname, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for RFC 1918 style hostnames (e.g., redis-1, db-server, etc.)
|
||||
// If hostname doesn't contain dots, it's likely internal
|
||||
if !strings.Contains(hostname, ".") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Default to external for fully qualified domain names
|
||||
return false
|
||||
}
|
490
hitless/config_test.go
Normal file
490
hitless/config_test.go
Normal file
@@ -0,0 +1,490 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
)
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
t.Run("DefaultConfig", func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
|
||||
// MaxWorkers should be 0 in default config (auto-calculated)
|
||||
if config.MaxWorkers != 0 {
|
||||
t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers)
|
||||
}
|
||||
|
||||
// HandoffQueueSize should be 0 in default config (auto-calculated)
|
||||
if config.HandoffQueueSize != 0 {
|
||||
t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize)
|
||||
}
|
||||
|
||||
if config.RelaxedTimeout != 10*time.Second {
|
||||
t.Errorf("Expected RelaxedTimeout to be 10s, got %v", config.RelaxedTimeout)
|
||||
}
|
||||
|
||||
// Test configuration fields have proper defaults
|
||||
if config.MaxHandoffRetries != 3 {
|
||||
t.Errorf("Expected MaxHandoffRetries to be 3, got %d", config.MaxHandoffRetries)
|
||||
}
|
||||
|
||||
// Circuit breaker defaults
|
||||
if config.CircuitBreakerFailureThreshold != 5 {
|
||||
t.Errorf("Expected CircuitBreakerFailureThreshold=5, got %d", config.CircuitBreakerFailureThreshold)
|
||||
}
|
||||
if config.CircuitBreakerResetTimeout != 60*time.Second {
|
||||
t.Errorf("Expected CircuitBreakerResetTimeout=60s, got %v", config.CircuitBreakerResetTimeout)
|
||||
}
|
||||
if config.CircuitBreakerMaxRequests != 3 {
|
||||
t.Errorf("Expected CircuitBreakerMaxRequests=3, got %d", config.CircuitBreakerMaxRequests)
|
||||
}
|
||||
|
||||
if config.HandoffTimeout != 15*time.Second {
|
||||
t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout)
|
||||
}
|
||||
|
||||
if config.PostHandoffRelaxedDuration != 0 {
|
||||
t.Errorf("Expected PostHandoffRelaxedDuration to be 0 (auto-calculated), got %v", config.PostHandoffRelaxedDuration)
|
||||
}
|
||||
|
||||
// Test that defaults are applied correctly
|
||||
configWithDefaults := config.ApplyDefaultsWithPoolSize(100)
|
||||
if configWithDefaults.PostHandoffRelaxedDuration != 20*time.Second {
|
||||
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout) after applying defaults, got %v", configWithDefaults.PostHandoffRelaxedDuration)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConfigValidation", func(t *testing.T) {
|
||||
// Valid config with applied defaults
|
||||
config := DefaultConfig().ApplyDefaults()
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Errorf("Default config with applied defaults should be valid: %v", err)
|
||||
}
|
||||
|
||||
// Invalid worker configuration (negative MaxWorkers)
|
||||
config = &Config{
|
||||
RelaxedTimeout: 30 * time.Second,
|
||||
HandoffTimeout: 15 * time.Second,
|
||||
MaxWorkers: -1, // This should be invalid
|
||||
HandoffQueueSize: 100,
|
||||
PostHandoffRelaxedDuration: 10 * time.Second,
|
||||
LogLevel: 1,
|
||||
MaxHandoffRetries: 3, // Add required field
|
||||
}
|
||||
if err := config.Validate(); err != ErrInvalidHandoffWorkers {
|
||||
t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err)
|
||||
}
|
||||
|
||||
// Invalid HandoffQueueSize
|
||||
config = DefaultConfig().ApplyDefaults()
|
||||
config.HandoffQueueSize = -1
|
||||
if err := config.Validate(); err != ErrInvalidHandoffQueueSize {
|
||||
t.Errorf("Expected ErrInvalidHandoffQueueSize, got %v", err)
|
||||
}
|
||||
|
||||
// Invalid PostHandoffRelaxedDuration
|
||||
config = DefaultConfig().ApplyDefaults()
|
||||
config.PostHandoffRelaxedDuration = -1 * time.Second
|
||||
if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration {
|
||||
t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConfigClone", func(t *testing.T) {
|
||||
original := DefaultConfig()
|
||||
original.MaxWorkers = 20
|
||||
original.HandoffQueueSize = 200
|
||||
|
||||
cloned := original.Clone()
|
||||
|
||||
if cloned.MaxWorkers != 20 {
|
||||
t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers)
|
||||
}
|
||||
|
||||
if cloned.HandoffQueueSize != 200 {
|
||||
t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize)
|
||||
}
|
||||
|
||||
// Modify original to ensure clone is independent
|
||||
original.MaxWorkers = 2
|
||||
if cloned.MaxWorkers != 20 {
|
||||
t.Error("Clone should be independent of original")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestApplyDefaults(t *testing.T) {
|
||||
t.Run("NilConfig", func(t *testing.T) {
|
||||
var config *Config
|
||||
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
|
||||
|
||||
// With nil config, should get default config with auto-calculated workers
|
||||
if result.MaxWorkers <= 0 {
|
||||
t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers)
|
||||
}
|
||||
|
||||
// HandoffQueueSize should be auto-calculated with hybrid scaling
|
||||
workerBasedSize := result.MaxWorkers * 20
|
||||
poolSize := 100 // Default pool size used in ApplyDefaults
|
||||
poolBasedSize := poolSize
|
||||
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
|
||||
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
|
||||
if result.HandoffQueueSize != expectedQueueSize {
|
||||
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
|
||||
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PartialConfig", func(t *testing.T) {
|
||||
config := &Config{
|
||||
MaxWorkers: 60, // Set this field explicitly (> poolSize/2 = 50)
|
||||
// Leave other fields as zero values
|
||||
}
|
||||
|
||||
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
|
||||
|
||||
// Should keep the explicitly set values when > poolSize/2
|
||||
if result.MaxWorkers != 60 {
|
||||
t.Errorf("Expected MaxWorkers to be 60 (explicitly set), got %d", result.MaxWorkers)
|
||||
}
|
||||
|
||||
// Should apply default for unset fields (auto-calculated queue size with hybrid scaling)
|
||||
workerBasedSize := result.MaxWorkers * 20
|
||||
poolSize := 100 // Default pool size used in ApplyDefaults
|
||||
poolBasedSize := poolSize
|
||||
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
|
||||
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
|
||||
if result.HandoffQueueSize != expectedQueueSize {
|
||||
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
|
||||
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
|
||||
}
|
||||
|
||||
// Test explicit queue size capping by 5x pool size
|
||||
configWithLargeQueue := &Config{
|
||||
MaxWorkers: 5,
|
||||
HandoffQueueSize: 1000, // Much larger than 5x pool size
|
||||
}
|
||||
|
||||
resultCapped := configWithLargeQueue.ApplyDefaultsWithPoolSize(20) // Small pool size
|
||||
expectedCap := 20 * 5 // 5x pool size = 100
|
||||
if resultCapped.HandoffQueueSize != expectedCap {
|
||||
t.Errorf("Expected HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedCap, resultCapped.HandoffQueueSize)
|
||||
}
|
||||
|
||||
// Test explicit queue size minimum enforcement
|
||||
configWithSmallQueue := &Config{
|
||||
MaxWorkers: 5,
|
||||
HandoffQueueSize: 10, // Below minimum of 200
|
||||
}
|
||||
|
||||
resultMinimum := configWithSmallQueue.ApplyDefaultsWithPoolSize(100) // Large pool size
|
||||
if resultMinimum.HandoffQueueSize != 200 {
|
||||
t.Errorf("Expected HandoffQueueSize to be enforced minimum (200), got %d", resultMinimum.HandoffQueueSize)
|
||||
}
|
||||
|
||||
// Test that large explicit values are capped by 5x pool size
|
||||
configWithVeryLargeQueue := &Config{
|
||||
MaxWorkers: 5,
|
||||
HandoffQueueSize: 1000, // Much larger than 5x pool size
|
||||
}
|
||||
|
||||
resultVeryLarge := configWithVeryLargeQueue.ApplyDefaultsWithPoolSize(100) // Pool size 100
|
||||
expectedVeryLargeCap := 100 * 5 // 5x pool size = 500
|
||||
if resultVeryLarge.HandoffQueueSize != expectedVeryLargeCap {
|
||||
t.Errorf("Expected very large HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedVeryLargeCap, resultVeryLarge.HandoffQueueSize)
|
||||
}
|
||||
|
||||
if result.RelaxedTimeout != 10*time.Second {
|
||||
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
|
||||
}
|
||||
|
||||
if result.HandoffTimeout != 15*time.Second {
|
||||
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", result.HandoffTimeout)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ZeroValues", func(t *testing.T) {
|
||||
config := &Config{
|
||||
MaxWorkers: 0, // Zero value should get auto-calculated defaults
|
||||
HandoffQueueSize: 0, // Zero value should get default
|
||||
RelaxedTimeout: 0, // Zero value should get default
|
||||
LogLevel: 0, // Zero is valid for LogLevel (errors only)
|
||||
}
|
||||
|
||||
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
|
||||
|
||||
// Zero values should get auto-calculated defaults
|
||||
if result.MaxWorkers <= 0 {
|
||||
t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers)
|
||||
}
|
||||
|
||||
// HandoffQueueSize should be auto-calculated with hybrid scaling
|
||||
workerBasedSize := result.MaxWorkers * 20
|
||||
poolSize := 100 // Default pool size used in ApplyDefaults
|
||||
poolBasedSize := poolSize
|
||||
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
|
||||
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
|
||||
if result.HandoffQueueSize != expectedQueueSize {
|
||||
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
|
||||
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
|
||||
}
|
||||
|
||||
if result.RelaxedTimeout != 10*time.Second {
|
||||
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
|
||||
}
|
||||
|
||||
// LogLevel 0 should be preserved (it's a valid value)
|
||||
if result.LogLevel != 0 {
|
||||
t.Errorf("Expected LogLevel to be 0 (preserved), got %d", result.LogLevel)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestProcessorWithConfig(t *testing.T) {
|
||||
t.Run("ProcessorUsesConfigValues", func(t *testing.T) {
|
||||
config := &Config{
|
||||
MaxWorkers: 5,
|
||||
HandoffQueueSize: 50,
|
||||
RelaxedTimeout: 10 * time.Second,
|
||||
HandoffTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// The processor should be created successfully with custom config
|
||||
if processor == nil {
|
||||
t.Error("Processor should be created with custom config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ProcessorWithPartialConfig", func(t *testing.T) {
|
||||
config := &Config{
|
||||
MaxWorkers: 7, // Only set worker field
|
||||
// Other fields will get defaults
|
||||
}
|
||||
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Should work with partial config (defaults applied)
|
||||
if processor == nil {
|
||||
t.Error("Processor should be created with partial config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ProcessorWithNilConfig", func(t *testing.T) {
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Should use default config when nil is passed
|
||||
if processor == nil {
|
||||
t.Error("Processor should be created with nil config (using defaults)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationWithApplyDefaults(t *testing.T) {
|
||||
t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) {
|
||||
// Create a partial config with only some fields set
|
||||
partialConfig := &Config{
|
||||
MaxWorkers: 15, // Custom value (>= 10 to test preservation)
|
||||
LogLevel: logging.LogLevelInfo, // Custom value
|
||||
// Other fields left as zero values - should get defaults
|
||||
}
|
||||
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
// Create processor - should apply defaults to missing fields
|
||||
processor := NewPoolHook(baseDialer, "tcp", partialConfig, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Processor should be created successfully
|
||||
if processor == nil {
|
||||
t.Error("Processor should be created with partial config")
|
||||
}
|
||||
|
||||
// Test that the ApplyDefaults method worked correctly by creating the same config
|
||||
// and applying defaults manually
|
||||
expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
|
||||
|
||||
// Should preserve custom values (when >= poolSize/2)
|
||||
if expectedConfig.MaxWorkers != 50 { // max(poolSize/2, 15) = max(50, 15) = 50
|
||||
t.Errorf("Expected MaxWorkers to be 50, got %d", expectedConfig.MaxWorkers)
|
||||
}
|
||||
|
||||
if expectedConfig.LogLevel != 2 {
|
||||
t.Errorf("Expected LogLevel to be 2, got %d", expectedConfig.LogLevel)
|
||||
}
|
||||
|
||||
// Should apply defaults for missing fields (auto-calculated queue size with hybrid scaling)
|
||||
workerBasedSize := expectedConfig.MaxWorkers * 20
|
||||
poolSize := 100 // Default pool size used in ApplyDefaults
|
||||
poolBasedSize := poolSize
|
||||
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
|
||||
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
|
||||
if expectedConfig.HandoffQueueSize != expectedQueueSize {
|
||||
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
|
||||
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, expectedConfig.HandoffQueueSize)
|
||||
}
|
||||
|
||||
// Test that queue size is always capped by 5x pool size
|
||||
if expectedConfig.HandoffQueueSize > poolSize*5 {
|
||||
t.Errorf("HandoffQueueSize (%d) should never exceed 5x pool size (%d)",
|
||||
expectedConfig.HandoffQueueSize, poolSize*5)
|
||||
}
|
||||
|
||||
if expectedConfig.RelaxedTimeout != 10*time.Second {
|
||||
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", expectedConfig.RelaxedTimeout)
|
||||
}
|
||||
|
||||
if expectedConfig.HandoffTimeout != 15*time.Second {
|
||||
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout)
|
||||
}
|
||||
|
||||
if expectedConfig.PostHandoffRelaxedDuration != 20*time.Second {
|
||||
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout), got %v", expectedConfig.PostHandoffRelaxedDuration)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnhancedConfigValidation(t *testing.T) {
|
||||
t.Run("ValidateFields", func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config = config.ApplyDefaultsWithPoolSize(100) // Apply defaults with pool size 100 (the method returns a new Config)
|
||||
|
||||
// Should pass validation with default values
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Errorf("Default config should be valid, got error: %v", err)
|
||||
}
|
||||
|
||||
// Test invalid MaxHandoffRetries
|
||||
config.MaxHandoffRetries = 0
|
||||
if err := config.Validate(); err == nil {
|
||||
t.Error("Expected validation error for MaxHandoffRetries = 0")
|
||||
}
|
||||
config.MaxHandoffRetries = 11
|
||||
if err := config.Validate(); err == nil {
|
||||
t.Error("Expected validation error for MaxHandoffRetries = 11")
|
||||
}
|
||||
config.MaxHandoffRetries = 3 // Reset to valid value
|
||||
|
||||
// Test circuit breaker validation
|
||||
config.CircuitBreakerFailureThreshold = 0
|
||||
if err := config.Validate(); err != ErrInvalidCircuitBreakerFailureThreshold {
|
||||
t.Errorf("Expected ErrInvalidCircuitBreakerFailureThreshold, got %v", err)
|
||||
}
|
||||
config.CircuitBreakerFailureThreshold = 5 // Reset to valid value
|
||||
|
||||
config.CircuitBreakerResetTimeout = -1 * time.Second
|
||||
if err := config.Validate(); err != ErrInvalidCircuitBreakerResetTimeout {
|
||||
t.Errorf("Expected ErrInvalidCircuitBreakerResetTimeout, got %v", err)
|
||||
}
|
||||
config.CircuitBreakerResetTimeout = 60 * time.Second // Reset to valid value
|
||||
|
||||
config.CircuitBreakerMaxRequests = 0
|
||||
if err := config.Validate(); err != ErrInvalidCircuitBreakerMaxRequests {
|
||||
t.Errorf("Expected ErrInvalidCircuitBreakerMaxRequests, got %v", err)
|
||||
}
|
||||
config.CircuitBreakerMaxRequests = 3 // Reset to valid value
|
||||
|
||||
// Should pass validation again
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Errorf("Config should be valid after reset, got error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigClone(t *testing.T) {
|
||||
original := DefaultConfig()
|
||||
original.MaxHandoffRetries = 7
|
||||
original.HandoffTimeout = 8 * time.Second
|
||||
|
||||
cloned := original.Clone()
|
||||
|
||||
// Test that values are copied
|
||||
if cloned.MaxHandoffRetries != 7 {
|
||||
t.Errorf("Expected cloned MaxHandoffRetries to be 7, got %d", cloned.MaxHandoffRetries)
|
||||
}
|
||||
if cloned.HandoffTimeout != 8*time.Second {
|
||||
t.Errorf("Expected cloned HandoffTimeout to be 8s, got %v", cloned.HandoffTimeout)
|
||||
}
|
||||
|
||||
// Test that modifying clone doesn't affect original
|
||||
cloned.MaxHandoffRetries = 10
|
||||
if original.MaxHandoffRetries != 7 {
|
||||
t.Errorf("Modifying clone should not affect original, original MaxHandoffRetries changed to %d", original.MaxHandoffRetries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxWorkersLogic(t *testing.T) {
|
||||
t.Run("AutoCalculatedMaxWorkers", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
poolSize int
|
||||
expectedWorkers int
|
||||
description string
|
||||
}{
|
||||
{6, 3, "Small pool: min(6/2, max(10, 6/3)) = min(3, max(10, 2)) = min(3, 10) = 3"},
|
||||
{15, 7, "Medium pool: min(15/2, max(10, 15/3)) = min(7, max(10, 5)) = min(7, 10) = 7"},
|
||||
{30, 10, "Large pool: min(30/2, max(10, 30/3)) = min(15, max(10, 10)) = min(15, 10) = 10"},
|
||||
{60, 20, "Very large pool: min(60/2, max(10, 60/3)) = min(30, max(10, 20)) = min(30, 20) = 20"},
|
||||
{120, 40, "Huge pool: min(120/2, max(10, 120/3)) = min(60, max(10, 40)) = min(60, 40) = 40"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
config := &Config{} // MaxWorkers = 0 (not set)
|
||||
result := config.ApplyDefaultsWithPoolSize(tc.poolSize)
|
||||
|
||||
if result.MaxWorkers != tc.expectedWorkers {
|
||||
t.Errorf("PoolSize=%d: expected MaxWorkers=%d, got %d (%s)",
|
||||
tc.poolSize, tc.expectedWorkers, result.MaxWorkers, tc.description)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ExplicitlySetMaxWorkers", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
setValue int
|
||||
expectedWorkers int
|
||||
description string
|
||||
}{
|
||||
{1, 50, "Set 1: max(poolSize/2, 1) = max(50, 1) = 50 (enforced minimum)"},
|
||||
{5, 50, "Set 5: max(poolSize/2, 5) = max(50, 5) = 50 (enforced minimum)"},
|
||||
{8, 50, "Set 8: max(poolSize/2, 8) = max(50, 8) = 50 (enforced minimum)"},
|
||||
{10, 50, "Set 10: max(poolSize/2, 10) = max(50, 10) = 50 (enforced minimum)"},
|
||||
{15, 50, "Set 15: max(poolSize/2, 15) = max(50, 15) = 50 (enforced minimum)"},
|
||||
{60, 60, "Set 60: max(poolSize/2, 60) = max(50, 60) = 60 (respects user choice)"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
config := &Config{
|
||||
MaxWorkers: tc.setValue, // Explicitly set
|
||||
}
|
||||
result := config.ApplyDefaultsWithPoolSize(100) // Pool size doesn't affect explicit values
|
||||
|
||||
if result.MaxWorkers != tc.expectedWorkers {
|
||||
t.Errorf("Set MaxWorkers=%d: expected %d, got %d (%s)",
|
||||
tc.setValue, tc.expectedWorkers, result.MaxWorkers, tc.description)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
105
hitless/errors.go
Normal file
105
hitless/errors.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Configuration errors
|
||||
var (
|
||||
ErrInvalidRelaxedTimeout = errors.New("hitless: relaxed timeout must be greater than 0")
|
||||
ErrInvalidHandoffTimeout = errors.New("hitless: handoff timeout must be greater than 0")
|
||||
ErrInvalidHandoffWorkers = errors.New("hitless: MaxWorkers must be greater than or equal to 0")
|
||||
ErrInvalidHandoffQueueSize = errors.New("hitless: handoff queue size must be greater than or equal to 0")
|
||||
ErrInvalidPostHandoffRelaxedDuration = errors.New("hitless: post-handoff relaxed duration must be greater than or equal to 0")
|
||||
ErrInvalidLogLevel = errors.New("hitless: log level must be LogLevelError (0), LogLevelWarn (1), LogLevelInfo (2), or LogLevelDebug (3)")
|
||||
ErrInvalidEndpointType = errors.New("hitless: invalid endpoint type")
|
||||
ErrInvalidMaintNotifications = errors.New("hitless: invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')")
|
||||
ErrMaxHandoffRetriesReached = errors.New("hitless: max handoff retries reached")
|
||||
|
||||
// Configuration validation errors
|
||||
ErrInvalidHandoffRetries = errors.New("hitless: MaxHandoffRetries must be between 1 and 10")
|
||||
)
|
||||
|
||||
// Integration errors
|
||||
var (
|
||||
ErrInvalidClient = errors.New("hitless: invalid client type")
|
||||
)
|
||||
|
||||
// Handoff errors
|
||||
var (
|
||||
ErrHandoffQueueFull = errors.New("hitless: handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration")
|
||||
)
|
||||
|
||||
// Notification errors
|
||||
var (
|
||||
ErrInvalidNotification = errors.New("hitless: invalid notification format")
|
||||
)
|
||||
|
||||
// connection handoff errors
|
||||
var (
|
||||
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
|
||||
// and should not be used until the handoff is complete
|
||||
ErrConnectionMarkedForHandoff = errors.New("hitless: connection marked for handoff")
|
||||
// ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff
|
||||
ErrConnectionInvalidHandoffState = errors.New("hitless: connection is in invalid state for handoff")
|
||||
)
|
||||
|
||||
// general errors
|
||||
var (
|
||||
ErrShutdown = errors.New("hitless: shutdown")
|
||||
)
|
||||
|
||||
// circuit breaker errors
|
||||
var (
|
||||
ErrCircuitBreakerOpen = errors.New("hitless: circuit breaker is open, failing fast")
|
||||
)
|
||||
|
||||
// CircuitBreakerError provides detailed context for circuit breaker failures
|
||||
type CircuitBreakerError struct {
|
||||
Endpoint string
|
||||
State string
|
||||
Failures int64
|
||||
LastFailure time.Time
|
||||
NextAttempt time.Time
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *CircuitBreakerError) Error() string {
|
||||
if e.NextAttempt.IsZero() {
|
||||
return fmt.Sprintf("hitless: circuit breaker %s for %s (failures: %d, last: %v): %s",
|
||||
e.State, e.Endpoint, e.Failures, e.LastFailure, e.Message)
|
||||
}
|
||||
return fmt.Sprintf("hitless: circuit breaker %s for %s (failures: %d, last: %v, next attempt: %v): %s",
|
||||
e.State, e.Endpoint, e.Failures, e.LastFailure, e.NextAttempt, e.Message)
|
||||
}
|
||||
|
||||
// HandoffError provides detailed context for connection handoff failures
|
||||
type HandoffError struct {
|
||||
ConnectionID uint64
|
||||
SourceEndpoint string
|
||||
TargetEndpoint string
|
||||
Attempt int
|
||||
MaxAttempts int
|
||||
Duration time.Duration
|
||||
FinalError error
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *HandoffError) Error() string {
|
||||
return fmt.Sprintf("hitless: handoff failed for conn[%d] %s→%s (attempt %d/%d, duration: %v): %s",
|
||||
e.ConnectionID, e.SourceEndpoint, e.TargetEndpoint,
|
||||
e.Attempt, e.MaxAttempts, e.Duration, e.Message)
|
||||
}
|
||||
|
||||
func (e *HandoffError) Unwrap() error {
|
||||
return e.FinalError
|
||||
}
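
As a sketch of how a caller might inspect these errors, assuming err originated in this package and that HandoffError wraps its FinalError as defined above:

package main

import (
	"errors"
	"log"

	"github.com/redis/go-redis/v9/hitless"
)

func inspect(err error) {
	var hoErr *hitless.HandoffError
	if errors.As(err, &hoErr) {
		log.Printf("handoff %s -> %s failed on attempt %d/%d after %v: %v",
			hoErr.SourceEndpoint, hoErr.TargetEndpoint,
			hoErr.Attempt, hoErr.MaxAttempts, hoErr.Duration, hoErr.FinalError)
	}
	if errors.Is(err, hitless.ErrCircuitBreakerOpen) {
		// The target endpoint is failing fast; back off before retrying.
		log.Printf("circuit breaker open: %v", err)
	}
}

func main() {
	inspect(&hitless.HandoffError{
		SourceEndpoint: "old:6379",
		TargetEndpoint: "new:6379",
		Attempt:        3,
		MaxAttempts:    3,
		FinalError:     hitless.ErrCircuitBreakerOpen,
	})
}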
|
||||
|
||||
// circuit breaker configuration errors
|
||||
var (
|
||||
ErrInvalidCircuitBreakerFailureThreshold = errors.New("hitless: circuit breaker failure threshold must be >= 1")
|
||||
ErrInvalidCircuitBreakerResetTimeout = errors.New("hitless: circuit breaker reset timeout must be >= 0")
|
||||
ErrInvalidCircuitBreakerMaxRequests = errors.New("hitless: circuit breaker max requests must be >= 1")
|
||||
)
|
100
hitless/example_hooks.go
Normal file
100
hitless/example_hooks.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
startTimeKey contextKey = "notif_hitless_start_time"
|
||||
)
|
||||
|
||||
// MetricsHook collects metrics about notification processing.
|
||||
type MetricsHook struct {
|
||||
NotificationCounts map[string]int64
|
||||
ProcessingTimes map[string]time.Duration
|
||||
ErrorCounts map[string]int64
|
||||
HandoffCounts int64 // Total handoffs initiated
|
||||
HandoffSuccesses int64 // Successful handoffs
|
||||
HandoffFailures int64 // Failed handoffs
|
||||
}
|
||||
|
||||
// NewMetricsHook creates a new metrics collection hook.
|
||||
func NewMetricsHook() *MetricsHook {
|
||||
return &MetricsHook{
|
||||
NotificationCounts: make(map[string]int64),
|
||||
ProcessingTimes: make(map[string]time.Duration),
|
||||
ErrorCounts: make(map[string]int64),
|
||||
}
|
||||
}
|
||||
|
||||
// PreHook records the start time for processing metrics.
|
||||
func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
mh.NotificationCounts[notificationType]++
|
||||
|
||||
// Log connection information if available
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
internal.Logger.Printf(ctx, "hitless: metrics hook processing %s notification on conn[%d]", notificationType, conn.GetID())
|
||||
}
|
||||
|
||||
// Store the start time for duration calculation. Note that PreHook cannot
// return a modified context, so the derived context is discarded here and the
// value will not be visible to PostHook via ctx; an implementation that needs
// accurate durations could instead track start times on the hook, keyed by connection ID.
startTime := time.Now()
_ = context.WithValue(ctx, startTimeKey, startTime) // derived context intentionally unused
|
||||
|
||||
return notification, true
|
||||
}
|
||||
|
||||
// PostHook records processing completion and any errors.
|
||||
func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
|
||||
// Calculate processing duration
|
||||
if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
|
||||
duration := time.Since(startTime)
|
||||
mh.ProcessingTimes[notificationType] = duration
|
||||
}
|
||||
|
||||
// Record errors
|
||||
if result != nil {
|
||||
mh.ErrorCounts[notificationType]++
|
||||
|
||||
// Log error details with connection information
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
internal.Logger.Printf(ctx, "hitless: metrics hook recorded error for %s notification on conn[%d]: %v", notificationType, conn.GetID(), result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns a summary of collected metrics.
|
||||
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"notification_counts": mh.NotificationCounts,
|
||||
"processing_times": mh.ProcessingTimes,
|
||||
"error_counts": mh.ErrorCounts,
|
||||
}
|
||||
}
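
A rough sketch of driving this hook directly, for example from a test; it assumes only the PreHook/PostHook signatures defined above, passes a zero-value push.NotificationHandlerContext (so the connection-specific logging branches are skipped), and uses an invented, purely illustrative notification payload.

package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/redis/go-redis/v9/hitless"
	"github.com/redis/go-redis/v9/push"
)

func main() {
	ctx := context.Background()
	hook := hitless.NewMetricsHook()

	// Illustrative payload; the real notification format is defined elsewhere.
	notification := []interface{}{"MOVING", int64(1), "10.0.0.5:6379"}
	var notifCtx push.NotificationHandlerContext // zero value: no *pool.Conn attached

	// Simulate one processed notification that ends in an error.
	modified, proceed := hook.PreHook(ctx, notifCtx, "MOVING", notification)
	if proceed {
		hook.PostHook(ctx, notifCtx, "MOVING", modified, errors.New("simulated failure"))
	}

	fmt.Println(hook.GetMetrics()) // notification_counts["MOVING"] == 1, error_counts["MOVING"] == 1
}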
|
||||
|
||||
// ExampleCircuitBreakerMonitor demonstrates how to monitor circuit breaker status
|
||||
func ExampleCircuitBreakerMonitor(poolHook *PoolHook) {
|
||||
// Get circuit breaker statistics
|
||||
stats := poolHook.GetCircuitBreakerStats()
|
||||
|
||||
for _, stat := range stats {
|
||||
fmt.Printf("Circuit Breaker for %s:\n", stat.Endpoint)
|
||||
fmt.Printf(" State: %s\n", stat.State)
|
||||
fmt.Printf(" Failures: %d\n", stat.Failures)
|
||||
fmt.Printf(" Last Failure: %v\n", stat.LastFailureTime)
|
||||
fmt.Printf(" Last Success: %v\n", stat.LastSuccessTime)
|
||||
|
||||
// Alert if circuit breaker is open
|
||||
if stat.State.String() == "open" {
|
||||
fmt.Printf(" ⚠️ ALERT: Circuit breaker is OPEN for %s\n", stat.Endpoint)
|
||||
}
|
||||
}
|
||||
}
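
In a long-running service one might poll these statistics periodically rather than on demand. A hedged sketch, using the same PoolHook value and GetCircuitBreakerStats method as above; the polling interval is arbitrary.

// monitorCircuitBreakers polls circuit breaker stats until stop is closed.
func monitorCircuitBreakers(poolHook *PoolHook, stop <-chan struct{}) {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-stop:
			return
		case <-ticker.C:
			for _, stat := range poolHook.GetCircuitBreakerStats() {
				if stat.State.String() == "open" {
					fmt.Printf("circuit breaker OPEN for %s (failures: %d)\n", stat.Endpoint, stat.Failures)
				}
			}
		}
	}
}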
|
455
hitless/handoff_worker.go
Normal file
455
hitless/handoff_worker.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// handoffWorkerManager manages background workers and queue for connection handoffs
|
||||
type handoffWorkerManager struct {
|
||||
// Event-driven handoff support
|
||||
handoffQueue chan HandoffRequest // Queue for handoff requests
|
||||
shutdown chan struct{} // Shutdown signal
|
||||
shutdownOnce sync.Once // Ensure clean shutdown
|
||||
workerWg sync.WaitGroup // Track worker goroutines
|
||||
|
||||
// On-demand worker management
|
||||
maxWorkers int
|
||||
activeWorkers atomic.Int32
|
||||
workerTimeout time.Duration // How long workers wait for work before exiting
|
||||
workersScaling atomic.Bool
|
||||
|
||||
// Simple state tracking
|
||||
pending sync.Map // map[uint64]int64 (connID -> seqID)
|
||||
|
||||
// Configuration for the hitless upgrade
|
||||
config *Config
|
||||
|
||||
// Pool hook reference for handoff processing
|
||||
poolHook *PoolHook
|
||||
|
||||
// Circuit breaker manager for endpoint failure handling
|
||||
circuitBreakerManager *CircuitBreakerManager
|
||||
}
|
||||
|
||||
// newHandoffWorkerManager creates a new handoff worker manager
|
||||
func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
|
||||
return &handoffWorkerManager{
|
||||
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
|
||||
shutdown: make(chan struct{}),
|
||||
maxWorkers: config.MaxWorkers,
|
||||
activeWorkers: atomic.Int32{}, // Start with no workers - create on demand
|
||||
workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
|
||||
config: config,
|
||||
poolHook: poolHook,
|
||||
circuitBreakerManager: newCircuitBreakerManager(config),
|
||||
}
|
||||
}
|
||||
|
||||
// getCurrentWorkers returns the current number of active workers (for testing)
|
||||
func (hwm *handoffWorkerManager) getCurrentWorkers() int {
|
||||
return int(hwm.activeWorkers.Load())
|
||||
}
|
||||
|
||||
// getPendingMap returns the pending map for testing purposes
|
||||
func (hwm *handoffWorkerManager) getPendingMap() *sync.Map {
|
||||
return &hwm.pending
|
||||
}
|
||||
|
||||
// getMaxWorkers returns the max workers for testing purposes
|
||||
func (hwm *handoffWorkerManager) getMaxWorkers() int {
|
||||
return hwm.maxWorkers
|
||||
}
|
||||
|
||||
// getHandoffQueue returns the handoff queue for testing purposes
|
||||
func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
|
||||
return hwm.handoffQueue
|
||||
}
|
||||
|
||||
// getCircuitBreakerStats returns circuit breaker statistics for monitoring
|
||||
func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats {
|
||||
return hwm.circuitBreakerManager.GetAllStats()
|
||||
}
|
||||
|
||||
// resetCircuitBreakers resets all circuit breakers (useful for testing)
|
||||
func (hwm *handoffWorkerManager) resetCircuitBreakers() {
|
||||
hwm.circuitBreakerManager.Reset()
|
||||
}
|
||||
|
||||
// isHandoffPending returns true if the given connection has a pending handoff
|
||||
func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
|
||||
_, pending := hwm.pending.Load(conn.GetID())
|
||||
return pending
|
||||
}
|
||||
|
||||
// ensureWorkerAvailable ensures at least one worker is available to process requests
|
||||
// Creates a new worker if needed and under the max limit
|
||||
func (hwm *handoffWorkerManager) ensureWorkerAvailable() {
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
return
|
||||
default:
|
||||
if hwm.workersScaling.CompareAndSwap(false, true) {
|
||||
defer hwm.workersScaling.Store(false)
|
||||
// Check if we need a new worker
|
||||
currentWorkers := hwm.activeWorkers.Load()
|
||||
workersWas := currentWorkers
|
||||
for currentWorkers < int32(hwm.maxWorkers) {
|
||||
hwm.workerWg.Add(1)
|
||||
go hwm.onDemandWorker()
|
||||
currentWorkers++
|
||||
}
|
||||
// workersWas is always <= currentWorkers. currentWorkers has now reached
// maxWorkers, but a worker may have exited while we were creating new ones,
// so add only the number of workers actually created here: the difference
// between currentWorkers and the count observed at the start.
hwm.activeWorkers.Add(currentWorkers - workersWas)
|
||||
}
|
||||
}
|
||||
}
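
The CompareAndSwap guard above is a common pattern for letting exactly one caller perform a scaling pass while concurrent callers return immediately. A standalone illustration of the idiom, not tied to this code:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var scaling atomic.Bool
	var passes atomic.Int32
	var wg sync.WaitGroup

	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Only the goroutine that flips false -> true does the work.
			if scaling.CompareAndSwap(false, true) {
				defer scaling.Store(false)
				passes.Add(1) // stand-in for "create missing workers"
			}
		}()
	}
	wg.Wait()
	fmt.Println("scaling passes performed:", passes.Load()) // at least 1, typically well under 8
}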
|
||||
|
||||
// onDemandWorker processes handoff requests and exits when idle
|
||||
func (hwm *handoffWorkerManager) onDemandWorker() {
|
||||
defer func() {
|
||||
// Decrement active worker count when exiting
|
||||
hwm.activeWorkers.Add(-1)
|
||||
hwm.workerWg.Done()
|
||||
}()
|
||||
|
||||
// Create reusable timer to prevent timer leaks
|
||||
timer := time.NewTimer(hwm.workerTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
// Reset timer for next iteration
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(hwm.workerTimeout)
|
||||
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
return
|
||||
case <-timer.C:
|
||||
// Worker has been idle for too long, exit to save resources
|
||||
if hwm.config != nil && hwm.config.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(),
|
||||
"hitless: worker exiting due to inactivity timeout (%v)", hwm.workerTimeout)
|
||||
}
|
||||
return
|
||||
case request := <-hwm.handoffQueue:
|
||||
// Check for shutdown before processing
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
// Clean up the request before exiting
|
||||
hwm.pending.Delete(request.ConnID)
|
||||
return
|
||||
default:
|
||||
// Process the request
|
||||
hwm.processHandoffRequest(request)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
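
The Stop/drain/Reset sequence at the top of the loop is the standard idiom for reusing a single time.Timer as an idle timeout. Isolated, it looks like this; a generic sketch, not specific to this worker:

package main

import (
	"fmt"
	"time"
)

func main() {
	work := make(chan int)
	go func() { work <- 1; work <- 2 }()

	idle := 200 * time.Millisecond
	timer := time.NewTimer(idle)
	defer timer.Stop()

	for {
		// Drain a possibly-fired timer before resetting it for this iteration.
		if !timer.Stop() {
			select {
			case <-timer.C:
			default:
			}
		}
		timer.Reset(idle)

		select {
		case v := <-work:
			fmt.Println("processed", v)
		case <-timer.C:
			fmt.Println("idle timeout, worker exits")
			return
		}
	}
}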
|
||||
|
||||
// processHandoffRequest processes a single handoff request
|
||||
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
|
||||
// Remove from pending map
|
||||
defer hwm.pending.Delete(request.Conn.GetID())
|
||||
internal.Logger.Printf(context.Background(), "hitless: conn[%d] Processing handoff request start", request.Conn.GetID())
|
||||
|
||||
// Create a context with handoff timeout from config
|
||||
handoffTimeout := 15 * time.Second // Default timeout
|
||||
if hwm.config != nil && hwm.config.HandoffTimeout > 0 {
|
||||
handoffTimeout = hwm.config.HandoffTimeout
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Create a context that also respects the shutdown signal
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
|
||||
defer shutdownCancel()
|
||||
|
||||
// Monitor shutdown signal in a separate goroutine
|
||||
go func() {
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
shutdownCancel()
|
||||
case <-shutdownCtx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
// Perform the handoff with cancellable context
|
||||
shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn)
|
||||
minRetryBackoff := 500 * time.Millisecond
|
||||
if err != nil {
|
||||
if shouldRetry {
|
||||
now := time.Now()
|
||||
deadline, ok := shutdownCtx.Deadline()
|
||||
thirdOfTimeout := handoffTimeout / 3
|
||||
if !ok || deadline.Before(now) {
|
||||
// wait a third of the handoff timeout before retrying if there is no deadline or it has already passed
|
||||
deadline = now.Add(thirdOfTimeout)
|
||||
}
|
||||
afterTime := deadline.Sub(now)
|
||||
if afterTime < minRetryBackoff {
|
||||
afterTime = minRetryBackoff
|
||||
}
|
||||
|
||||
internal.Logger.Printf(context.Background(), "Handoff failed for conn[%d] WILL RETRY After %v: %v", request.ConnID, afterTime, err)
|
||||
time.AfterFunc(afterTime, func() {
|
||||
if err := hwm.queueHandoff(request.Conn); err != nil {
|
||||
internal.Logger.Printf(context.Background(), "can't queue handoff for retry: %v", err)
|
||||
hwm.closeConnFromRequest(context.Background(), request, err)
|
||||
}
|
||||
})
|
||||
return
|
||||
} else {
|
||||
go hwm.closeConnFromRequest(ctx, request, err)
|
||||
}
|
||||
|
||||
// Clear handoff state if not returned for retry
|
||||
seqID := request.Conn.GetMovingSeqID()
|
||||
connID := request.Conn.GetID()
|
||||
if hwm.poolHook.hitlessManager != nil {
|
||||
hwm.poolHook.hitlessManager.UntrackOperationWithConnID(seqID, connID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// queueHandoff queues a handoff request for processing
|
||||
// if err is returned, connection will be removed from pool
|
||||
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
|
||||
// Create handoff request
|
||||
request := HandoffRequest{
|
||||
Conn: conn,
|
||||
ConnID: conn.GetID(),
|
||||
Endpoint: conn.GetHandoffEndpoint(),
|
||||
SeqID: conn.GetMovingSeqID(),
|
||||
Pool: hwm.poolHook.pool, // Include pool for connection removal on failure
|
||||
}
|
||||
|
||||
select {
|
||||
// priority to shutdown
|
||||
case <-hwm.shutdown:
|
||||
return ErrShutdown
|
||||
default:
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
return ErrShutdown
|
||||
case hwm.handoffQueue <- request:
|
||||
// Store in pending map
|
||||
hwm.pending.Store(request.ConnID, request.SeqID)
|
||||
// Ensure we have a worker to process this request
|
||||
hwm.ensureWorkerAvailable()
|
||||
return nil
|
||||
default:
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
return ErrShutdown
|
||||
case hwm.handoffQueue <- request:
|
||||
// Store in pending map
|
||||
hwm.pending.Store(request.ConnID, request.SeqID)
|
||||
// Ensure we have a worker to process this request
|
||||
hwm.ensureWorkerAvailable()
|
||||
return nil
|
||||
case <-time.After(100 * time.Millisecond): // give workers a chance to process
|
||||
// Queue is full - log and attempt scaling
|
||||
queueLen := len(hwm.handoffQueue)
|
||||
queueCap := cap(hwm.handoffQueue)
|
||||
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
|
||||
internal.Logger.Printf(context.Background(),
|
||||
"hitless: handoff queue is full (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration",
|
||||
queueLen, queueCap)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we have workers available to handle the load
|
||||
hwm.ensureWorkerAvailable()
|
||||
return ErrHandoffQueueFull
|
||||
}
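
The nested selects above implement a "shutdown first, then fast-path enqueue, then bounded wait" priority. Reduced to a flattened skeleton, the pattern looks roughly like this generic sketch:

package main

import (
	"errors"
	"fmt"
	"time"
)

var errShuttingDown = errors.New("shutting down")
var errQueueFull = errors.New("queue full")

// trySend gives priority to shutdown, then attempts a non-blocking send,
// then waits up to a short grace period before reporting a full queue.
func trySend(queue chan int, item int, shutdown <-chan struct{}, grace time.Duration) error {
	select {
	case <-shutdown:
		return errShuttingDown
	default:
	}
	select {
	case <-shutdown:
		return errShuttingDown
	case queue <- item:
		return nil
	default:
	}
	select {
	case <-shutdown:
		return errShuttingDown
	case queue <- item:
		return nil
	case <-time.After(grace):
		return errQueueFull
	}
}

func main() {
	q := make(chan int, 1)
	stop := make(chan struct{})
	fmt.Println(trySend(q, 1, stop, 50*time.Millisecond)) // <nil>
	fmt.Println(trySend(q, 2, stop, 50*time.Millisecond)) // queue full
}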
|
||||
|
||||
// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete
|
||||
func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error {
|
||||
hwm.shutdownOnce.Do(func() {
|
||||
close(hwm.shutdown)
|
||||
// workers will exit when they finish their current request
|
||||
|
||||
// Shutdown circuit breaker manager cleanup goroutine
|
||||
if hwm.circuitBreakerManager != nil {
|
||||
hwm.circuitBreakerManager.Shutdown()
|
||||
}
|
||||
})
|
||||
|
||||
// Wait for workers to complete
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
hwm.workerWg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// performConnectionHandoff performs the actual connection handoff
|
||||
// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached
|
||||
func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) {
|
||||
// Capture the connection ID for logging and operation tracking
|
||||
connID := conn.GetID()
|
||||
|
||||
newEndpoint := conn.GetHandoffEndpoint()
|
||||
if newEndpoint == "" {
|
||||
return false, ErrConnectionInvalidHandoffState
|
||||
}
|
||||
|
||||
// Use circuit breaker to protect against failing endpoints
|
||||
circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint)
|
||||
|
||||
// Check if circuit breaker is open before attempting handoff
|
||||
if circuitBreaker.IsOpen() {
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] handoff to %s failed fast due to circuit breaker", connID, newEndpoint)
|
||||
return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open
|
||||
}
|
||||
|
||||
// Perform the handoff
|
||||
shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID)
|
||||
|
||||
// Update circuit breaker based on result
|
||||
if err != nil {
|
||||
// Only track dial/network errors in circuit breaker, not initialization errors
|
||||
if shouldRetry {
|
||||
circuitBreaker.recordFailure()
|
||||
}
|
||||
return shouldRetry, err
|
||||
}
|
||||
|
||||
// Success - record in circuit breaker
|
||||
circuitBreaker.recordSuccess()
|
||||
return false, nil
|
||||
}
|
||||
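// The flow above follows the usual circuit breaker shape: fail fast while the
// breaker is open, record failures only for retryable (dial/network) errors,
// and record successes so the breaker can close again. A compressed sketch of
// that control flow (illustrative; cbManager stands in for the
// circuitBreakerManager used above, doWork for performHandoffInternal):
//
//	cb := cbManager.GetCircuitBreaker(endpoint)
//	if cb.IsOpen() {
//		return false, ErrCircuitBreakerOpen // fail fast, no retry
//	}
//	retry, err := doWork()
//	if err != nil {
//		if retry {
//			cb.recordFailure() // only network-level errors trip the breaker
//		}
//		return retry, err
//	}
//	cb.recordSuccess()
//	return false, nil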
|
||||
// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
|
||||
func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) {
|
||||
|
||||
retries := conn.IncrementAndGetHandoffRetries(1)
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", connID, retries, newEndpoint, conn.RemoteAddr().String())
|
||||
maxRetries := 3 // Default fallback
|
||||
if hwm.config != nil {
|
||||
maxRetries = hwm.config.MaxHandoffRetries
|
||||
}
|
||||
|
||||
if retries > maxRetries {
|
||||
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
|
||||
internal.Logger.Printf(ctx,
|
||||
"hitless: reached max retries (%d) for handoff of conn[%d] to %s",
|
||||
maxRetries, connID, newEndpoint)
|
||||
}
|
||||
// won't retry on ErrMaxHandoffRetriesReached
|
||||
return false, ErrMaxHandoffRetriesReached
|
||||
}
|
||||
|
||||
// Create endpoint-specific dialer
|
||||
endpointDialer := hwm.createEndpointDialer(newEndpoint)
|
||||
|
||||
// Create new connection to the new endpoint
|
||||
newNetConn, err := endpointDialer(ctx)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", connID, newEndpoint, err)
|
||||
// hitless: will retry
|
||||
// Likely a transient network error - retry after a delay
|
||||
return true, err
|
||||
}
|
||||
|
||||
// Get the old connection
|
||||
oldConn := conn.GetNetConn()
|
||||
|
||||
// Apply relaxed timeout to the new connection for the configured post-handoff duration
|
||||
// This gives the new connection more time to handle operations during cluster transition
|
||||
// Setting this here (before initializing the connection) ensures that the connection is going
|
||||
// to use the relaxed timeout for the first operation (auth/ACL select)
|
||||
if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 {
|
||||
relaxedTimeout := hwm.config.RelaxedTimeout
|
||||
// Set relaxed timeout with deadline - no background goroutine needed
|
||||
deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration)
|
||||
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
|
||||
|
||||
if hwm.config.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(),
|
||||
"hitless: conn[%d] applied post-handoff relaxed timeout (%v) until %v",
|
||||
connID, relaxedTimeout, deadline.Format("15:04:05.000"))
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the connection and execute initialization
|
||||
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
|
||||
if err != nil {
|
||||
// hitless: won't retry
|
||||
// Initialization failed - remove the connection
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
if oldConn != nil {
|
||||
oldConn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
conn.ClearHandoffState()
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", connID, newEndpoint)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
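// The post-handoff relaxed timeout is scoped by a wall-clock deadline rather
// than a background timer: readers simply compare time.Now() against the
// deadline. A minimal sketch of that passive-expiry idea (illustrative only;
// relaxedState and its fields are hypothetical, not the actual pool.Conn
// implementation):
//
//	type relaxedState struct {
//		timeout  time.Duration
//		deadline time.Time
//	}
//
//	func (s relaxedState) effectiveTimeout(configured time.Duration) time.Duration {
//		if s.timeout > 0 && time.Now().Before(s.deadline) {
//			return s.timeout // still inside the post-handoff window
//		}
//		return configured // deadline passed: fall back to the normal timeout
//	}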
|
||||
// createEndpointDialer creates a dialer function that connects to a specific endpoint
|
||||
func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
// Parse endpoint to extract host and port
|
||||
host, port, err := net.SplitHostPort(endpoint)
|
||||
if err != nil {
|
||||
// If no port specified, assume default Redis port
|
||||
host = endpoint
|
||||
if port == "" {
|
||||
port = "6379"
|
||||
}
|
||||
}
|
||||
|
||||
// Use the base dialer to connect to the new endpoint
|
||||
return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port))
|
||||
}
|
||||
}
|
||||
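// Endpoints may arrive without a port; the dialer falls back to the default
// Redis port in that case. For example (hypothetical addresses):
//
//	dial := hwm.createEndpointDialer("10.0.0.5:6380") // explicit port is kept as-is
//	conn, err := dial(ctx)
//
//	dial = hwm.createEndpointDialer("10.0.0.5") // no port: dials 10.0.0.5:6379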
|
||||
// closeConnFromRequest removes the connection from the pool (or closes it directly when no pool is available) and logs the reason
|
||||
func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
|
||||
pooler := request.Pool
|
||||
conn := request.Conn
|
||||
if pooler != nil {
|
||||
pooler.Remove(ctx, conn, err)
|
||||
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
|
||||
internal.Logger.Printf(ctx,
|
||||
"hitless: removed conn[%d] from pool due: %v",
|
||||
conn.GetID(), err)
|
||||
}
|
||||
} else {
|
||||
conn.Close()
|
||||
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
|
||||
internal.Logger.Printf(ctx,
|
||||
"hitless: no pool provided for conn[%d], cannot remove due to: %v",
|
||||
conn.GetID(), err)
|
||||
}
|
||||
}
|
||||
}
|
318
hitless/hitless_manager.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/interfaces"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// Push notification type constants for hitless upgrades
|
||||
const (
|
||||
NotificationMoving = "MOVING"
|
||||
NotificationMigrating = "MIGRATING"
|
||||
NotificationMigrated = "MIGRATED"
|
||||
NotificationFailingOver = "FAILING_OVER"
|
||||
NotificationFailedOver = "FAILED_OVER"
|
||||
)
|
||||
|
||||
// hitlessNotificationTypes contains all notification types that hitless upgrades handles
|
||||
var hitlessNotificationTypes = []string{
|
||||
NotificationMoving,
|
||||
NotificationMigrating,
|
||||
NotificationMigrated,
|
||||
NotificationFailingOver,
|
||||
NotificationFailedOver,
|
||||
}
|
||||
|
||||
// NotificationHook is called before and after notification processing
|
||||
// PreHook can modify the notification and return false to skip processing
|
||||
// PostHook is called after successful processing
|
||||
type NotificationHook interface {
|
||||
PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool)
|
||||
PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error)
|
||||
}
|
||||
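// A hook can observe, rewrite, or veto a notification without touching the
// core processing path. A minimal sketch of a custom hook (countingHook is a
// hypothetical example; LoggingHook in hooks.go is the shipped one):
//
//	type countingHook struct{ seen atomic.Int64 }
//
//	func (h *countingHook) PreHook(ctx context.Context, nCtx push.NotificationHandlerContext,
//		notificationType string, notification []interface{}) ([]interface{}, bool) {
//		h.seen.Add(1)
//		return notification, true // continue processing, notification unchanged
//	}
//
//	func (h *countingHook) PostHook(ctx context.Context, nCtx push.NotificationHandlerContext,
//		notificationType string, notification []interface{}, result error) {
//		// no-op: result inspection could go here
//	}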
|
||||
// MovingOperationKey provides a unique key for tracking MOVING operations
|
||||
// that combines sequence ID with connection identifier to handle duplicate
|
||||
// sequence IDs across multiple connections to the same node.
|
||||
type MovingOperationKey struct {
|
||||
SeqID int64 // Sequence ID from MOVING notification
|
||||
ConnID uint64 // Unique connection identifier
|
||||
}
|
||||
|
||||
// String returns a string representation of the key for debugging
|
||||
func (k MovingOperationKey) String() string {
|
||||
return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
|
||||
}
|
||||
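// Because the key pairs SeqID with ConnID, the same MOVING sequence number
// seen on two different connections produces two distinct entries. For
// example (hypothetical values):
//
//	MovingOperationKey{SeqID: 42, ConnID: 1}.String() // "seq:42-conn:1"
//	MovingOperationKey{SeqID: 42, ConnID: 2}.String() // "seq:42-conn:2"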
|
||||
// HitlessManager provides a simplified hitless upgrade functionality with hooks and atomic state.
|
||||
type HitlessManager struct {
|
||||
client interfaces.ClientInterface
|
||||
config *Config
|
||||
options interfaces.OptionsInterface
|
||||
pool pool.Pooler
|
||||
|
||||
// MOVING operation tracking - using sync.Map for better concurrent performance
|
||||
activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation
|
||||
|
||||
// Atomic state tracking - no locks needed for state queries
|
||||
activeOperationCount atomic.Int64 // Number of active operations
|
||||
closed atomic.Bool // Manager closed state
|
||||
|
||||
// Notification hooks for extensibility
|
||||
hooks []NotificationHook
|
||||
hooksMu sync.RWMutex // Protects hooks slice
|
||||
poolHooksRef *PoolHook
|
||||
}
|
||||
|
||||
// MovingOperation tracks an active MOVING operation.
|
||||
type MovingOperation struct {
|
||||
SeqID int64
|
||||
NewEndpoint string
|
||||
StartTime time.Time
|
||||
Deadline time.Time
|
||||
}
|
||||
|
||||
// NewHitlessManager creates a new simplified hitless manager.
|
||||
func NewHitlessManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*HitlessManager, error) {
|
||||
if client == nil {
|
||||
return nil, ErrInvalidClient
|
||||
}
|
||||
|
||||
hm := &HitlessManager{
|
||||
client: client,
|
||||
pool: pool,
|
||||
options: client.GetOptions(),
|
||||
config: config.Clone(),
|
||||
hooks: make([]NotificationHook, 0),
|
||||
}
|
||||
|
||||
// Set up push notification handling
|
||||
if err := hm.setupPushNotifications(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
// InitPoolHook creates a pool hook with a custom dialer and registers it with the pool.
|
||||
func (hm *HitlessManager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
|
||||
poolHook := hm.createPoolHook(baseDialer)
|
||||
hm.pool.AddPoolHook(poolHook)
|
||||
}
|
||||
|
||||
// setupPushNotifications sets up push notification handling by registering with the client's processor.
|
||||
func (hm *HitlessManager) setupPushNotifications() error {
|
||||
processor := hm.client.GetPushProcessor()
|
||||
if processor == nil {
|
||||
return ErrInvalidClient // Client doesn't support push notifications
|
||||
}
|
||||
|
||||
// Create our notification handler
|
||||
handler := &NotificationHandler{manager: hm}
|
||||
|
||||
// Register handlers for all hitless upgrade notifications with the client's processor
|
||||
for _, notificationType := range hitlessNotificationTypes {
|
||||
if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
|
||||
return fmt.Errorf("failed to register handler for %s: %w", notificationType, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
|
||||
func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
|
||||
// Create composite key
|
||||
key := MovingOperationKey{
|
||||
SeqID: seqID,
|
||||
ConnID: connID,
|
||||
}
|
||||
|
||||
// Create MOVING operation record
|
||||
movingOp := &MovingOperation{
|
||||
SeqID: seqID,
|
||||
NewEndpoint: newEndpoint,
|
||||
StartTime: time.Now(),
|
||||
Deadline: deadline,
|
||||
}
|
||||
|
||||
// Use LoadOrStore for atomic check-and-set operation
|
||||
if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
|
||||
// Duplicate MOVING notification, ignore
|
||||
if hm.config.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Duplicate MOVING operation ignored: %s", connID, seqID, key.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if hm.config.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Tracking MOVING operation: %s", connID, seqID, key.String())
|
||||
}
|
||||
|
||||
// Increment active operation count atomically
|
||||
hm.activeOperationCount.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
|
||||
func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64) {
|
||||
// Create composite key
|
||||
key := MovingOperationKey{
|
||||
SeqID: seqID,
|
||||
ConnID: connID,
|
||||
}
|
||||
|
||||
// Remove from active operations atomically
|
||||
if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
|
||||
if hm.config.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Untracking MOVING operation: %s", connID, seqID, key.String())
|
||||
}
|
||||
// Decrement active operation count only if operation existed
|
||||
hm.activeOperationCount.Add(-1)
|
||||
} else {
|
||||
if hm.config.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Operation not found for untracking: %s", connID, seqID, key.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
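// Track and untrack are paired per (seqID, connID): an operation is tracked
// when a MOVING handoff is queued and untracked once the handoff completes or
// the connection is removed. A hedged usage sketch (the endpoint and IDs are
// hypothetical):
//
//	deadline := time.Now().Add(30 * time.Second)
//	_ = hm.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 42, 7)
//	// ... handoff runs ...
//	hm.UntrackOperationWithConnID(42, 7) // idempotent: unknown keys are just logged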
|
||||
// GetActiveMovingOperations returns active operations with composite keys.
|
||||
// WARNING: This method creates a new map and copies all operations on every call.
|
||||
// Use sparingly, especially in hot paths or high-frequency logging.
|
||||
func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
|
||||
result := make(map[MovingOperationKey]*MovingOperation)
|
||||
|
||||
// Iterate over sync.Map to build result
|
||||
hm.activeMovingOps.Range(func(key, value interface{}) bool {
|
||||
k := key.(MovingOperationKey)
|
||||
op := value.(*MovingOperation)
|
||||
|
||||
// Create a copy to avoid sharing references
|
||||
result[k] = &MovingOperation{
|
||||
SeqID: op.SeqID,
|
||||
NewEndpoint: op.NewEndpoint,
|
||||
StartTime: op.StartTime,
|
||||
Deadline: op.Deadline,
|
||||
}
|
||||
return true // Continue iteration
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// IsHandoffInProgress returns true if any handoff is in progress.
|
||||
// Uses atomic counter for lock-free operation.
|
||||
func (hm *HitlessManager) IsHandoffInProgress() bool {
|
||||
return hm.activeOperationCount.Load() > 0
|
||||
}
|
||||
|
||||
// GetActiveOperationCount returns the number of active operations.
|
||||
// Uses atomic counter for lock-free operation.
|
||||
func (hm *HitlessManager) GetActiveOperationCount() int64 {
|
||||
return hm.activeOperationCount.Load()
|
||||
}
|
||||
|
||||
// Close closes the hitless manager.
|
||||
func (hm *HitlessManager) Close() error {
|
||||
// Use atomic operation for thread-safe close check
|
||||
if !hm.closed.CompareAndSwap(false, true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
// Shutdown the pool hook if it exists
|
||||
if hm.poolHooksRef != nil {
|
||||
// Use a timeout to prevent hanging indefinitely
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := hm.poolHooksRef.Shutdown(shutdownCtx)
|
||||
if err != nil {
|
||||
// The pool hook could not be shut down; revert the closed state so Close can be retried
|
||||
hm.closed.Store(false)
|
||||
return err
|
||||
}
|
||||
// Remove the pool hook from the pool
|
||||
if hm.pool != nil {
|
||||
hm.pool.RemovePoolHook(hm.poolHooksRef)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear all active operations
|
||||
hm.activeMovingOps.Range(func(key, value interface{}) bool {
|
||||
hm.activeMovingOps.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
// Reset counter
|
||||
hm.activeOperationCount.Store(0)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetState returns current state using atomic counter for lock-free operation.
|
||||
func (hm *HitlessManager) GetState() State {
|
||||
if hm.activeOperationCount.Load() > 0 {
|
||||
return StateMoving
|
||||
}
|
||||
return StateIdle
|
||||
}
|
||||
|
||||
// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
|
||||
func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
hm.hooksMu.RLock()
|
||||
defer hm.hooksMu.RUnlock()
|
||||
|
||||
currentNotification := notification
|
||||
|
||||
for _, hook := range hm.hooks {
|
||||
modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification)
|
||||
if !shouldContinue {
|
||||
return modifiedNotification, false
|
||||
}
|
||||
currentNotification = modifiedNotification
|
||||
}
|
||||
|
||||
return currentNotification, true
|
||||
}
|
||||
|
||||
// processPostHooks calls all post-hooks with the processing result.
|
||||
func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
|
||||
hm.hooksMu.RLock()
|
||||
defer hm.hooksMu.RUnlock()
|
||||
|
||||
for _, hook := range hm.hooks {
|
||||
hook.PostHook(ctx, notificationCtx, notificationType, notification, result)
|
||||
}
|
||||
}
|
||||
|
||||
// createPoolHook creates a pool hook with this manager already set.
|
||||
func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
|
||||
if hm.poolHooksRef != nil {
|
||||
return hm.poolHooksRef
|
||||
}
|
||||
// Get pool size from client options for better worker defaults
|
||||
poolSize := 0
|
||||
if hm.options != nil {
|
||||
poolSize = hm.options.GetPoolSize()
|
||||
}
|
||||
|
||||
hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
|
||||
hm.poolHooksRef.SetPool(hm.pool)
|
||||
|
||||
return hm.poolHooksRef
|
||||
}
|
||||
|
||||
func (hm *HitlessManager) AddNotificationHook(notificationHook NotificationHook) {
|
||||
hm.hooksMu.Lock()
|
||||
defer hm.hooksMu.Unlock()
|
||||
hm.hooks = append(hm.hooks, notificationHook)
|
||||
}
|
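// Hooks run in registration order via processPreHooks/processPostHooks. A
// typical (hedged) usage is to attach the LoggingHook shipped in hooks.go:
//
//	hm.AddNotificationHook(NewLoggingHook(logging.LogLevelDebug))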
260
hitless/hitless_manager_test.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/interfaces"
|
||||
)
|
||||
|
||||
// MockClient implements interfaces.ClientInterface for testing
|
||||
type MockClient struct {
|
||||
options interfaces.OptionsInterface
|
||||
}
|
||||
|
||||
func (mc *MockClient) GetOptions() interfaces.OptionsInterface {
|
||||
return mc.options
|
||||
}
|
||||
|
||||
func (mc *MockClient) GetPushProcessor() interfaces.NotificationProcessor {
|
||||
return &MockPushProcessor{}
|
||||
}
|
||||
|
||||
// MockPushProcessor implements interfaces.NotificationProcessor for testing
|
||||
type MockPushProcessor struct{}
|
||||
|
||||
func (mpp *MockPushProcessor) RegisterHandler(notificationType string, handler interface{}, protected bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mpp *MockPushProcessor) UnregisterHandler(pushNotificationName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mpp *MockPushProcessor) GetHandler(pushNotificationName string) interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockOptions implements interfaces.OptionsInterface for testing
|
||||
type MockOptions struct{}
|
||||
|
||||
func (mo *MockOptions) GetReadTimeout() time.Duration {
|
||||
return 5 * time.Second
|
||||
}
|
||||
|
||||
func (mo *MockOptions) GetWriteTimeout() time.Duration {
|
||||
return 5 * time.Second
|
||||
}
|
||||
|
||||
func (mo *MockOptions) GetAddr() string {
|
||||
return "localhost:6379"
|
||||
}
|
||||
|
||||
func (mo *MockOptions) IsTLSEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (mo *MockOptions) GetProtocol() int {
|
||||
return 3 // RESP3
|
||||
}
|
||||
|
||||
func (mo *MockOptions) GetPoolSize() int {
|
||||
return 10
|
||||
}
|
||||
|
||||
func (mo *MockOptions) GetNetwork() string {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func (mo *MockOptions) NewDialer() func(context.Context) (net.Conn, error) {
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestHitlessManagerRefactoring(t *testing.T) {
|
||||
t.Run("AtomicStateTracking", func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
client := &MockClient{options: &MockOptions{}}
|
||||
|
||||
manager, err := NewHitlessManager(client, nil, config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create hitless manager: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// Test initial state
|
||||
if manager.IsHandoffInProgress() {
|
||||
t.Error("Expected no handoff in progress initially")
|
||||
}
|
||||
|
||||
if manager.GetActiveOperationCount() != 0 {
|
||||
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
|
||||
}
|
||||
|
||||
if manager.GetState() != StateIdle {
|
||||
t.Errorf("Expected StateIdle, got %v", manager.GetState())
|
||||
}
|
||||
|
||||
// Add an operation
|
||||
ctx := context.Background()
|
||||
deadline := time.Now().Add(30 * time.Second)
|
||||
err = manager.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 12345, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to track operation: %v", err)
|
||||
}
|
||||
|
||||
// Test state after adding operation
|
||||
if !manager.IsHandoffInProgress() {
|
||||
t.Error("Expected handoff in progress after adding operation")
|
||||
}
|
||||
|
||||
if manager.GetActiveOperationCount() != 1 {
|
||||
t.Errorf("Expected 1 active operation, got %d", manager.GetActiveOperationCount())
|
||||
}
|
||||
|
||||
if manager.GetState() != StateMoving {
|
||||
t.Errorf("Expected StateMoving, got %v", manager.GetState())
|
||||
}
|
||||
|
||||
// Remove the operation
|
||||
manager.UntrackOperationWithConnID(12345, 1)
|
||||
|
||||
// Test state after removing operation
|
||||
if manager.IsHandoffInProgress() {
|
||||
t.Error("Expected no handoff in progress after removing operation")
|
||||
}
|
||||
|
||||
if manager.GetActiveOperationCount() != 0 {
|
||||
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
|
||||
}
|
||||
|
||||
if manager.GetState() != StateIdle {
|
||||
t.Errorf("Expected StateIdle, got %v", manager.GetState())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SyncMapPerformance", func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
client := &MockClient{options: &MockOptions{}}
|
||||
|
||||
manager, err := NewHitlessManager(client, nil, config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create hitless manager: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
deadline := time.Now().Add(30 * time.Second)
|
||||
|
||||
// Test concurrent operations
|
||||
const numOps = 100
|
||||
for i := 0; i < numOps; i++ {
|
||||
err := manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, int64(i), uint64(i))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to track operation %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
if manager.GetActiveOperationCount() != numOps {
|
||||
t.Errorf("Expected %d active operations, got %d", numOps, manager.GetActiveOperationCount())
|
||||
}
|
||||
|
||||
// Test GetActiveMovingOperations
|
||||
operations := manager.GetActiveMovingOperations()
|
||||
if len(operations) != numOps {
|
||||
t.Errorf("Expected %d operations in map, got %d", numOps, len(operations))
|
||||
}
|
||||
|
||||
// Remove all operations
|
||||
for i := 0; i < numOps; i++ {
|
||||
manager.UntrackOperationWithConnID(int64(i), uint64(i))
|
||||
}
|
||||
|
||||
if manager.GetActiveOperationCount() != 0 {
|
||||
t.Errorf("Expected 0 active operations after cleanup, got %d", manager.GetActiveOperationCount())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DuplicateOperationHandling", func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
client := &MockClient{options: &MockOptions{}}
|
||||
|
||||
manager, err := NewHitlessManager(client, nil, config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create hitless manager: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
deadline := time.Now().Add(30 * time.Second)
|
||||
|
||||
// Add operation
|
||||
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to track operation: %v", err)
|
||||
}
|
||||
|
||||
// Try to add duplicate operation
|
||||
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Duplicate operation should not return error: %v", err)
|
||||
}
|
||||
|
||||
// Should still have only 1 operation
|
||||
if manager.GetActiveOperationCount() != 1 {
|
||||
t.Errorf("Expected 1 active operation after duplicate, got %d", manager.GetActiveOperationCount())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NotificationTypeConstants", func(t *testing.T) {
|
||||
// Test that constants are properly defined
|
||||
expectedTypes := []string{
|
||||
NotificationMoving,
|
||||
NotificationMigrating,
|
||||
NotificationMigrated,
|
||||
NotificationFailingOver,
|
||||
NotificationFailedOver,
|
||||
}
|
||||
|
||||
if len(hitlessNotificationTypes) != len(expectedTypes) {
|
||||
t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(hitlessNotificationTypes))
|
||||
}
|
||||
|
||||
// Test that all expected types are present
|
||||
typeMap := make(map[string]bool)
|
||||
for _, t := range hitlessNotificationTypes {
|
||||
typeMap[t] = true
|
||||
}
|
||||
|
||||
for _, expected := range expectedTypes {
|
||||
if !typeMap[expected] {
|
||||
t.Errorf("Expected notification type %s not found in hitlessNotificationTypes", expected)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that hitlessNotificationTypes contains all expected constants
|
||||
expectedConstants := []string{
|
||||
NotificationMoving,
|
||||
NotificationMigrating,
|
||||
NotificationMigrated,
|
||||
NotificationFailingOver,
|
||||
NotificationFailedOver,
|
||||
}
|
||||
|
||||
for _, expected := range expectedConstants {
|
||||
found := false
|
||||
for _, actual := range hitlessNotificationTypes {
|
||||
if actual == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected constant %s not found in hitlessNotificationTypes", expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
47
hitless/hooks.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// LoggingHook is an example hook implementation that logs all notifications.
|
||||
type LoggingHook struct {
|
||||
LogLevel logging.LogLevel
|
||||
}
|
||||
|
||||
// PreHook logs the notification before processing and allows modification.
|
||||
func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
if lh.LogLevel.InfoOrAbove() { // Info level
|
||||
// Log the notification type and content
|
||||
connID := uint64(0)
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
connID = conn.GetID()
|
||||
}
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] processing %s notification: %v", connID, notificationType, notification)
|
||||
}
|
||||
return notification, true // Continue processing with unmodified notification
|
||||
}
|
||||
|
||||
// PostHook logs the result after processing.
|
||||
func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
|
||||
connID := uint64(0)
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
connID = conn.GetID()
|
||||
}
|
||||
if result != nil && lh.LogLevel.WarnOrAbove() { // Warning level
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processing failed: %v - %v", connID, notificationType, result, notification)
|
||||
} else if lh.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processed successfully", connID, notificationType)
|
||||
}
|
||||
}
|
||||
|
||||
// NewLoggingHook creates a new logging hook with the specified log level.
|
||||
// Log levels: LogLevelError=errors, LogLevelWarn=warnings, LogLevelInfo=info, LogLevelDebug=debug
|
||||
func NewLoggingHook(logLevel logging.LogLevel) *LoggingHook {
|
||||
return &LoggingHook{LogLevel: logLevel}
|
||||
}
|
179
hitless/pool_hook.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// HitlessManagerInterface defines the interface for completing handoff operations
|
||||
type HitlessManagerInterface interface {
|
||||
TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
|
||||
UntrackOperationWithConnID(seqID int64, connID uint64)
|
||||
}
|
||||
|
||||
// HandoffRequest represents a request to handoff a connection to a new endpoint
|
||||
type HandoffRequest struct {
|
||||
Conn *pool.Conn
|
||||
ConnID uint64 // Unique connection identifier
|
||||
Endpoint string
|
||||
SeqID int64
|
||||
Pool pool.Pooler // Pool to remove connection from on failure
|
||||
}
|
||||
|
||||
// PoolHook implements pool.PoolHook for Redis-specific connection handling
|
||||
// with hitless upgrade support.
|
||||
type PoolHook struct {
|
||||
// Base dialer for creating connections to new endpoints during handoffs
|
||||
// args are network and address
|
||||
baseDialer func(context.Context, string, string) (net.Conn, error)
|
||||
|
||||
// Network type (e.g., "tcp", "unix")
|
||||
network string
|
||||
|
||||
// Worker manager for background handoff processing
|
||||
workerManager *handoffWorkerManager
|
||||
|
||||
// Configuration for the hitless upgrade
|
||||
config *Config
|
||||
|
||||
// Hitless manager for operation completion tracking
|
||||
hitlessManager HitlessManagerInterface
|
||||
|
||||
// Pool interface for removing connections on handoff failure
|
||||
pool pool.Pooler
|
||||
}
|
||||
|
||||
// NewPoolHook creates a new pool hook
|
||||
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface) *PoolHook {
|
||||
return NewPoolHookWithPoolSize(baseDialer, network, config, hitlessManager, 0)
|
||||
}
|
||||
|
||||
// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
|
||||
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface, poolSize int) *PoolHook {
|
||||
// Apply defaults if no config was provided
|
||||
if config == nil {
|
||||
config = config.ApplyDefaultsWithPoolSize(poolSize)
|
||||
}
|
||||
|
||||
ph := &PoolHook{
|
||||
// baseDialer is used to create connections to new endpoints during handoffs
|
||||
baseDialer: baseDialer,
|
||||
network: network,
|
||||
config: config,
|
||||
// Hitless manager for operation completion tracking
|
||||
hitlessManager: hitlessManager,
|
||||
}
|
||||
|
||||
// Create worker manager
|
||||
ph.workerManager = newHandoffWorkerManager(config, ph)
|
||||
|
||||
return ph
|
||||
}
|
||||
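// A hedged construction sketch: the base dialer has the same shape as
// net.Dialer.DialContext, a nil config gets defaults applied, and the pool is
// attached separately so failed handoffs can evict their connection (p is a
// hypothetical pool.Pooler):
//
//	var d net.Dialer
//	hook := NewPoolHook(d.DialContext, "tcp", nil, nil)
//	hook.SetPool(p)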
|
||||
// SetPool sets the pool interface for removing connections on handoff failure
|
||||
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
|
||||
ph.pool = pooler
|
||||
}
|
||||
|
||||
// GetCurrentWorkers returns the current number of active workers (for testing)
|
||||
func (ph *PoolHook) GetCurrentWorkers() int {
|
||||
return ph.workerManager.getCurrentWorkers()
|
||||
}
|
||||
|
||||
// IsHandoffPending returns true if the given connection has a pending handoff
|
||||
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
|
||||
return ph.workerManager.isHandoffPending(conn)
|
||||
}
|
||||
|
||||
// GetPendingMap returns the pending map for testing purposes
|
||||
func (ph *PoolHook) GetPendingMap() *sync.Map {
|
||||
return ph.workerManager.getPendingMap()
|
||||
}
|
||||
|
||||
// GetMaxWorkers returns the max workers for testing purposes
|
||||
func (ph *PoolHook) GetMaxWorkers() int {
|
||||
return ph.workerManager.getMaxWorkers()
|
||||
}
|
||||
|
||||
// GetHandoffQueue returns the handoff queue for testing purposes
|
||||
func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest {
|
||||
return ph.workerManager.getHandoffQueue()
|
||||
}
|
||||
|
||||
// GetCircuitBreakerStats returns circuit breaker statistics for monitoring
|
||||
func (ph *PoolHook) GetCircuitBreakerStats() []CircuitBreakerStats {
|
||||
return ph.workerManager.getCircuitBreakerStats()
|
||||
}
|
||||
|
||||
// ResetCircuitBreakers resets all circuit breakers (useful for testing)
|
||||
func (ph *PoolHook) ResetCircuitBreakers() {
|
||||
ph.workerManager.resetCircuitBreakers()
|
||||
}
|
||||
|
||||
// OnGet is called when a connection is retrieved from the pool
|
||||
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error {
|
||||
// NOTE: Two checks below ensure we never return a connection that is either
// marked for handoff or currently in a handoff state.
|
||||
|
||||
// Check if connection is usable (not in a handoff state)
|
||||
// Should not happen since the pool will not return a connection that is not usable.
|
||||
if !conn.IsUsable() {
|
||||
return ErrConnectionMarkedForHandoff
|
||||
}
|
||||
|
||||
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
|
||||
if conn.ShouldHandoff() {
|
||||
return ErrConnectionMarkedForHandoff
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnPut is called when a connection is returned to the pool
|
||||
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
// first check if we should handoff for faster rejection
|
||||
if !conn.ShouldHandoff() {
|
||||
// Default behavior (no handoff): pool the connection
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
// check pending handoff to not queue the same connection twice
|
||||
if ph.workerManager.isHandoffPending(conn) {
|
||||
// Default behavior (pending handoff): pool the connection
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
if err := ph.workerManager.queueHandoff(conn); err != nil {
|
||||
// Failed to queue handoff, remove the connection
|
||||
internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err)
|
||||
// Don't pool, remove connection, no error to caller
|
||||
return false, true, nil
|
||||
}
|
||||
|
||||
// Check if handoff was already processed by a worker before we can mark it as queued
|
||||
if !conn.ShouldHandoff() {
|
||||
// Handoff was already processed - this is normal and the connection should be pooled
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
if err := conn.MarkQueuedForHandoff(); err != nil {
|
||||
// If marking fails, check if handoff was processed in the meantime
|
||||
if !conn.ShouldHandoff() {
|
||||
// Handoff was processed - this is normal, pool the connection
|
||||
return true, false, nil
|
||||
}
|
||||
// Other error - remove the connection
|
||||
return false, true, nil
|
||||
}
|
||||
return true, false, nil
|
||||
}
|
||||
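// Summary of the OnPut decisions above (descriptive only):
//
//	no handoff needed               -> pool
//	handoff already pending         -> pool (avoid double-queueing)
//	queueing the handoff fails      -> remove from pool
//	handoff processed in the window -> pool
//	marking as queued fails         -> remove from pool
//	queued successfully             -> pool (a worker swaps the net.Conn later)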
|
||||
// Shutdown gracefully shuts down the processor, waiting for workers to complete
|
||||
func (ph *PoolHook) Shutdown(ctx context.Context) error {
|
||||
return ph.workerManager.shutdownWorkers(ctx)
|
||||
}
|
964
hitless/pool_hook_test.go
Normal file
@@ -0,0 +1,964 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// mockNetConn implements net.Conn for testing
|
||||
type mockNetConn struct {
|
||||
addr string
|
||||
shouldFailInit bool
|
||||
}
|
||||
|
||||
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
|
||||
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
|
||||
func (m *mockNetConn) Close() error { return nil }
|
||||
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
|
||||
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
|
||||
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
type mockAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (m *mockAddr) Network() string { return "tcp" }
|
||||
func (m *mockAddr) String() string { return m.addr }
|
||||
|
||||
// createMockPoolConnection creates a mock pool connection for testing
|
||||
func createMockPoolConnection() *pool.Conn {
|
||||
mockNetConn := &mockNetConn{addr: "test:6379"}
|
||||
conn := pool.NewConn(mockNetConn)
|
||||
conn.SetUsable(true) // Make connection usable for testing
|
||||
return conn
|
||||
}
|
||||
|
||||
// mockPool implements pool.Pooler for testing
|
||||
type mockPool struct {
|
||||
removedConnections map[uint64]bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (mp *mockPool) CloseConn(conn *pool.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) {
|
||||
// Not implemented for testing
|
||||
}
|
||||
|
||||
func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
// Use pool.Conn directly - no adapter needed
|
||||
mp.removedConnections[conn.GetID()] = true
|
||||
}
|
||||
|
||||
// WasRemoved safely checks if a connection was removed from the pool
|
||||
func (mp *mockPool) WasRemoved(connID uint64) bool {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
return mp.removedConnections[connID]
|
||||
}
|
||||
|
||||
func (mp *mockPool) Len() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (mp *mockPool) IdleLen() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (mp *mockPool) Stats() *pool.Stats {
|
||||
return &pool.Stats{}
|
||||
}
|
||||
|
||||
func (mp *mockPool) AddPoolHook(hook pool.PoolHook) {
|
||||
// Mock implementation - do nothing
|
||||
}
|
||||
|
||||
func (mp *mockPool) RemovePoolHook(hook pool.PoolHook) {
|
||||
// Mock implementation - do nothing
|
||||
}
|
||||
|
||||
func (mp *mockPool) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestConnectionHook tests the Redis connection processor functionality
|
||||
func TestConnectionHook(t *testing.T) {
|
||||
// Create a base dialer for testing
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) {
|
||||
config := &Config{
|
||||
Mode: MaintNotificationsAuto,
|
||||
EndpointType: EndpointTypeAuto,
|
||||
MaxWorkers: 1, // Use only 1 worker to ensure synchronization
|
||||
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
|
||||
MaxHandoffRetries: 3,
|
||||
LogLevel: 2,
|
||||
}
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Verify connection is marked for handoff
|
||||
if !conn.ShouldHandoff() {
|
||||
t.Fatal("Connection should be marked for handoff")
|
||||
}
|
||||
// Set a mock initialization function with synchronization
|
||||
initConnCalled := make(chan bool, 1)
|
||||
proceedWithInit := make(chan bool, 1)
|
||||
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
|
||||
select {
|
||||
case initConnCalled <- true:
|
||||
default:
|
||||
}
|
||||
// Wait for test to proceed
|
||||
<-proceedWithInit
|
||||
return nil
|
||||
}
|
||||
conn.SetInitConnFunc(initConnFunc)
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection immediately (handoff queued)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled immediately with event-driven handoff")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when queuing handoff")
|
||||
}
|
||||
|
||||
// Wait for initialization to be called (indicates handoff started)
|
||||
select {
|
||||
case <-initConnCalled:
|
||||
// Good, initialization was called
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("Timeout waiting for initialization function to be called")
|
||||
}
|
||||
|
||||
// Connection should be in pending map while initialization is blocked
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||
t.Error("Connection should be in pending handoffs map")
|
||||
}
|
||||
|
||||
// Allow initialization to proceed
|
||||
proceedWithInit <- true
|
||||
|
||||
// Wait for handoff to complete with proper timeout and polling
|
||||
timeout := time.After(2 * time.Second)
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for handoff to complete")
|
||||
case <-ticker.C:
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify handoff completed (removed from pending map)
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
|
||||
t.Error("Connection should be removed from pending map after handoff")
|
||||
}
|
||||
|
||||
// Verify connection is usable again
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should be usable after successful handoff")
|
||||
}
|
||||
|
||||
// Verify handoff state is cleared
|
||||
if conn.ShouldHandoff() {
|
||||
t.Error("Connection should not be marked for handoff after completion")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HandoffNotNeeded", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
// Don't mark for handoff
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error when handoff not needed: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection normally
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled when no handoff needed")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when no handoff needed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyEndpoint", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error with empty endpoint: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection (empty endpoint clears state)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled after clearing empty endpoint")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed after clearing empty endpoint")
|
||||
}
|
||||
|
||||
// State should be cleared
|
||||
if conn.ShouldHandoff() {
|
||||
t.Error("Connection should not be marked for handoff after clearing empty endpoint")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EventDrivenHandoffDialerError", func(t *testing.T) {
|
||||
// Create a failing base dialer
|
||||
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return nil, errors.New("dial failed")
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Mode: MaintNotificationsAuto,
|
||||
EndpointType: EndpointTypeAuto,
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 2, // Reduced retries for faster test
|
||||
HandoffTimeout: 500 * time.Millisecond, // Shorter timeout for faster test
|
||||
LogLevel: 2,
|
||||
}
|
||||
processor := NewPoolHook(failingDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not return error to caller: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection initially (handoff queued)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled initially with event-driven handoff")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when queuing handoff")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete and fail with proper timeout and polling
|
||||
timeout := time.After(3 * time.Second)
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
// wait for handoff to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for failed handoff to complete")
|
||||
case <-ticker.C:
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Connection should be removed from pending map after failed handoff
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
|
||||
t.Error("Connection should be removed from pending map after failed handoff")
|
||||
}
|
||||
|
||||
// Wait for retries to complete (with MaxHandoffRetries=2, it will retry twice then give up)
|
||||
// Each retry has a delay of handoffTimeout/2 = 250ms, so wait for all retries to complete
|
||||
time.Sleep(800 * time.Millisecond)
|
||||
|
||||
// After max retries are reached, the connection should be removed from pool
|
||||
// and handoff state should be cleared
|
||||
if conn.ShouldHandoff() {
|
||||
t.Error("Connection should not be marked for handoff after max retries reached")
|
||||
}
|
||||
|
||||
t.Logf("EventDrivenHandoffDialerError test completed successfully")
|
||||
})
|
||||
|
||||
t.Run("BufferedDataRESP2", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// For this test, we'll just verify the logic works for connections without buffered data
|
||||
// The actual buffered data detection is handled by the pool's connection health check
|
||||
// which is outside the scope of the Redis connection processor
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection normally (no buffered data in mock)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled when no buffered data")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when no buffered data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnGet", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should not error for normal connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
|
||||
config := &Config{
|
||||
Mode: MaintNotificationsAuto,
|
||||
EndpointType: EndpointTypeAuto,
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize:  10, // Explicit queue size to avoid a 0-size queue
MaxHandoffRetries: 3,
|
||||
LogLevel: 2,
|
||||
}
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// Simulate a pending handoff by marking for handoff and queuing
|
||||
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
processor.GetPendingMap().Delete(conn.GetID())
|
||||
})
|
||||
|
||||
t.Run("EventDrivenStateManagement", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// Test initial state - no pending handoffs
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
|
||||
t.Error("New connection should not have pending handoffs")
|
||||
}
|
||||
|
||||
// Test adding to pending map
|
||||
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
|
||||
t.Error("Connection should be in pending map")
|
||||
}
|
||||
|
||||
// Test OnGet with pending handoff
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
|
||||
}
|
||||
|
||||
// Test removing from pending map and clearing handoff state
|
||||
processor.GetPendingMap().Delete(conn.GetID())
|
||||
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
|
||||
t.Error("Connection should be removed from pending map")
|
||||
}
|
||||
|
||||
// Clear handoff state to simulate completed handoff
|
||||
conn.ClearHandoffState()
|
||||
conn.SetUsable(true) // Make connection usable again
|
||||
|
||||
// Test OnGet without pending handoff
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("Should not return error for non-pending connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
|
||||
// Create processor with small queue to test optimization features
|
||||
config := &Config{
|
||||
MaxWorkers: 3,
|
||||
HandoffQueueSize:  2, // Small queue to trigger optimizations
MaxHandoffRetries: 3,
|
||||
LogLevel: 3, // Debug level to see optimization logs
|
||||
}
|
||||
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// Add small delay to simulate network latency
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create multiple connections that need handoff to fill the queue
|
||||
connections := make([]*pool.Conn, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
connections[i] = createMockPoolConnection()
|
||||
if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil {
|
||||
t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err)
|
||||
}
|
||||
// Set a mock initialization function
|
||||
connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
successCount := 0
|
||||
|
||||
// Process connections - should trigger scaling and timeout logic
|
||||
for _, conn := range connections {
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Logf("OnPut returned error (expected with timeout): %v", err)
|
||||
}
|
||||
|
||||
if shouldPool && !shouldRemove {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
// With timeout and scaling, most handoffs should eventually succeed
|
||||
if successCount == 0 {
|
||||
t.Error("Should have queued some handoffs with timeout and scaling")
|
||||
}
|
||||
|
||||
t.Logf("Successfully queued %d handoffs with optimization features", successCount)
|
||||
|
||||
// Give time for workers to process and scaling to occur
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("WorkerScalingBehavior", func(t *testing.T) {
|
||||
// Create processor with small queue to test scaling behavior
|
||||
config := &Config{
|
||||
MaxWorkers: 15, // Set to >= 10 to test explicit value preservation
|
||||
HandoffQueueSize:  1, // Very small queue to force scaling
MaxHandoffRetries: 3,
|
||||
LogLevel: 2, // Info level to see scaling logs
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Verify initial worker count (should be 0 with on-demand workers)
|
||||
if processor.GetCurrentWorkers() != 0 {
|
||||
t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers())
|
||||
}
|
||||
if processor.GetMaxWorkers() != 15 {
|
||||
t.Errorf("Expected maxWorkers=15, got %d", processor.GetMaxWorkers())
|
||||
}
|
||||
|
||||
// The on-demand worker behavior creates workers only when needed
|
||||
// This test just verifies the basic configuration is correct
|
||||
t.Logf("On-demand worker configuration verified - Max: %d, Current: %d",
|
||||
processor.GetMaxWorkers(), processor.GetCurrentWorkers())
|
||||
})
|
||||
|
||||
t.Run("PassiveTimeoutRestoration", func(t *testing.T) {
|
||||
// Create processor with fast post-handoff duration for testing
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 3, // Allow retries for successful handoff
|
||||
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing
|
||||
RelaxedTimeout: 5 * time.Second,
|
||||
LogLevel: 2,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a connection and trigger handoff
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Process the connection to trigger handoff
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("Handoff should succeed: %v", err)
|
||||
}
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Error("Connection should be pooled after handoff")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete with proper timeout and polling
|
||||
timeout := time.After(1 * time.Second)
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for handoff to complete")
|
||||
case <-ticker.C:
|
||||
if _, pending := processor.GetPendingMap().Load(conn); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify relaxed timeout is set with deadline
|
||||
if !conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should have relaxed timeout after handoff")
|
||||
}
|
||||
|
||||
// Test that timeout is still active before deadline
|
||||
// We'll use HasRelaxedTimeout which internally checks the deadline
|
||||
if !conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should still have active relaxed timeout before deadline")
|
||||
}
|
||||
|
||||
// Wait for deadline to pass
|
||||
time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer
|
||||
|
||||
// Test that timeout is automatically restored after deadline
|
||||
// HasRelaxedTimeout should return false after deadline passes
|
||||
if conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should not have active relaxed timeout after deadline")
|
||||
}
|
||||
|
||||
// Additional verification: calling HasRelaxedTimeout again should still return false
|
||||
// and should have cleared the internal timeout values
|
||||
if conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should not have relaxed timeout after deadline (second check)")
|
||||
}
|
||||
|
||||
t.Logf("Passive timeout restoration test completed successfully")
|
||||
})
|
||||
|
||||
t.Run("UsableFlagBehavior", func(t *testing.T) {
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 3,
|
||||
LogLevel: 2,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a new connection without setting it usable
|
||||
mockNetConn := &mockNetConn{addr: "test:6379"}
|
||||
conn := pool.NewConn(mockNetConn)
|
||||
|
||||
// Initially, connection should not be usable (not initialized)
|
||||
if conn.IsUsable() {
|
||||
t.Error("New connection should not be usable before initialization")
|
||||
}
|
||||
|
||||
// Simulate initialization by setting usable to true
|
||||
conn.SetUsable(true)
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should be usable after initialization")
|
||||
}
|
||||
|
||||
// OnGet should succeed for usable connection
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should succeed for usable connection: %v", err)
|
||||
}
|
||||
|
||||
// Mark connection for handoff
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Connection should still be usable until queued, but marked for handoff
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should still be usable after being marked for handoff (until queued)")
|
||||
}
|
||||
if !conn.ShouldHandoff() {
|
||||
t.Error("Connection should be marked for handoff")
|
||||
}
|
||||
|
||||
// OnGet should fail for connection marked for handoff
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
if err == nil {
|
||||
t.Error("OnGet should fail for connection marked for handoff")
|
||||
}
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||
}
|
||||
|
||||
// Process the connection to trigger handoff
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should succeed: %v", err)
|
||||
}
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Error("Connection should be pooled after handoff")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// After handoff completion, connection should be usable again
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should be usable after handoff completion")
|
||||
}
|
||||
|
||||
// OnGet should succeed again
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should succeed after handoff completion: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Usable flag behavior test completed successfully")
|
||||
})
|
||||
|
||||
t.Run("StaticQueueBehavior", func(t *testing.T) {
|
||||
config := &Config{
|
||||
MaxWorkers: 3,
|
||||
HandoffQueueSize: 50, // Explicit static queue size
|
||||
MaxHandoffRetries: 3,
|
||||
LogLevel: 2,
|
||||
}
|
||||
|
||||
processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Verify queue capacity matches configured size
|
||||
queueCapacity := cap(processor.GetHandoffQueue())
|
||||
if queueCapacity != 50 {
|
||||
t.Errorf("Expected queue capacity 50, got %d", queueCapacity)
|
||||
}
|
||||
|
||||
// Test that queue size is static regardless of pool size
|
||||
// (No dynamic resizing should occur)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill part of the queue
|
||||
for i := 0; i < 10; i++ {
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil {
|
||||
t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err)
|
||||
}
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to queue handoff %d: %v", i, err)
|
||||
}
|
||||
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Errorf("conn[%d] should be pooled after handoff (shouldPool=%v, shouldRemove=%v)",
|
||||
i, shouldPool, shouldRemove)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify queue capacity remains static (the main purpose of this test)
|
||||
finalCapacity := cap(processor.GetHandoffQueue())
|
||||
|
||||
if finalCapacity != 50 {
|
||||
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
|
||||
}
|
||||
|
||||
// Note: We don't check queue size here because workers process items quickly
|
||||
// The important thing is that the capacity remains static regardless of pool size
|
||||
})
|
||||
|
||||
t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) {
|
||||
// Create a failing dialer that will cause handoff initialization to fail
|
||||
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// Return a connection that will fail during initialization
|
||||
return &mockNetConn{addr: addr, shouldFailInit: true}, nil
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 3,
|
||||
LogLevel: 2,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(failingDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create a mock pool that tracks removals
|
||||
mockPool := &mockPool{removedConnections: make(map[uint64]bool)}
|
||||
processor.SetPool(mockPool)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a connection and mark it for handoff
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a failing initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return fmt.Errorf("initialization failed")
|
||||
})
|
||||
|
||||
// Process the connection - handoff should fail and connection should be removed
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error: %v", err)
|
||||
}
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Error("Connection should be pooled after failed handoff attempt")
|
||||
}
|
||||
|
||||
// Wait for handoff to be attempted and fail
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify that the connection was removed from the pool
|
||||
if !mockPool.WasRemoved(conn.GetID()) {
|
||||
t.Errorf("conn[%d] should have been removed from pool after handoff failure", conn.GetID())
|
||||
}
|
||||
|
||||
t.Logf("Connection removal on handoff failure test completed successfully")
|
||||
})
|
||||
|
||||
t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) {
|
||||
// Create config with short post-handoff duration for testing
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 3, // Allow retries for successful handoff
|
||||
RelaxedTimeout: 5 * time.Second,
|
||||
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing
|
||||
}
|
||||
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("OnPut failed: %v", err)
|
||||
}
|
||||
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled after successful handoff")
|
||||
}
|
||||
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed after successful handoff")
|
||||
}
|
||||
|
||||
// Wait for the handoff to complete (it happens asynchronously)
|
||||
timeout := time.After(1 * time.Second)
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for handoff to complete")
|
||||
case <-ticker.C:
|
||||
if _, pending := processor.GetPendingMap().Load(conn); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that relaxed timeout was applied to the new connection
|
||||
if !conn.HasRelaxedTimeout() {
|
||||
t.Error("New connection should have relaxed timeout applied after handoff")
|
||||
}
|
||||
|
||||
// Wait for the post-handoff duration to expire
|
||||
time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration
|
||||
|
||||
// Verify that relaxed timeout was automatically cleared
|
||||
if conn.HasRelaxedTimeout() {
|
||||
t.Error("Relaxed timeout should be automatically cleared after post-handoff duration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MarkForHandoff returns error when already marked", func(t *testing.T) {
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// First mark should succeed
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("First MarkForHandoff should succeed: %v", err)
|
||||
}
|
||||
|
||||
// Second mark should fail
|
||||
if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil {
|
||||
t.Fatal("Second MarkForHandoff should return error")
|
||||
} else if err.Error() != "connection is already marked for handoff" {
|
||||
t.Fatalf("Expected specific error message, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify original handoff data is preserved
|
||||
if !conn.ShouldHandoff() {
|
||||
t.Fatal("Connection should still be marked for handoff")
|
||||
}
|
||||
if conn.GetHandoffEndpoint() != "new-endpoint:6379" {
|
||||
t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint())
|
||||
}
|
||||
if conn.GetMovingSeqID() != 1 {
|
||||
t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HandoffTimeoutConfiguration", func(t *testing.T) {
|
||||
// Test that HandoffTimeout from config is actually used
|
||||
customTimeout := 2 * time.Second
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
HandoffTimeout: customTimeout, // Custom timeout
|
||||
MaxHandoffRetries: 1, // Single retry to speed up test
|
||||
LogLevel: 2,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create a connection that will test the timeout
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("test-endpoint:6379", 123); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a dialer that will check the context timeout
|
||||
var timeoutVerified int32 // Use atomic for thread safety
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
// Check that the context has the expected timeout
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Error("Context should have a deadline")
|
||||
return errors.New("no deadline")
|
||||
}
|
||||
|
||||
// The deadline should be approximately customTimeout from now
|
||||
expectedDeadline := time.Now().Add(customTimeout)
|
||||
timeDiff := deadline.Sub(expectedDeadline)
|
||||
if timeDiff < -500*time.Millisecond || timeDiff > 500*time.Millisecond {
|
||||
t.Errorf("Context deadline not as expected. Expected around %v, got %v (diff: %v)",
|
||||
expectedDeadline, deadline, timeDiff)
|
||||
} else {
|
||||
atomic.StoreInt32(&timeoutVerified, 1)
|
||||
}
|
||||
|
||||
return nil // Successful handoff
|
||||
})
|
||||
|
||||
// Trigger handoff
|
||||
shouldPool, shouldRemove, err := processor.OnPut(context.Background(), conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not return error: %v", err)
|
||||
}
|
||||
|
||||
// Connection should be queued for handoff
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Errorf("Connection should be pooled for handoff processing")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
if atomic.LoadInt32(&timeoutVerified) == 0 {
|
||||
t.Error("HandoffTimeout was not properly applied to context")
|
||||
}
|
||||
|
||||
t.Logf("HandoffTimeout configuration test completed successfully")
|
||||
})
|
||||
}
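// Illustrative sketch (not part of the diff): a single Config combining the
// knobs exercised by the subtests above. Field names are taken from the test
// configs in this file; the concrete values are arbitrary.
func exampleHitlessConfig() *Config {
    return &Config{
        MaxWorkers:                 4,                // upper bound for on-demand handoff workers
        HandoffQueueSize:           32,               // static capacity of the handoff queue
        MaxHandoffRetries:          3,                // attempts before the connection is removed
        HandoffTimeout:             5 * time.Second,  // context deadline for each handoff attempt
        RelaxedTimeout:             10 * time.Second, // timeouts applied during migrations/failovers
        PostHandoffRelaxedDuration: 2 * time.Second,  // how long relaxed timeouts linger after a handoff
        LogLevel:                   2,                // info
    }
}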
|
276
hitless/push_notification_handler.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// NotificationHandler handles push notifications for the hitless upgrade manager.
|
||||
type NotificationHandler struct {
|
||||
manager *HitlessManager
|
||||
}
|
||||
|
||||
// HandlePushNotification processes push notifications with hook support.
|
||||
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) == 0 {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid notification format: %v", notification)
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
notificationType, ok := notification[0].(string)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid notification type format: %v", notification[0])
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Process pre-hooks - they can modify the notification or skip processing
|
||||
modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, handlerCtx, notificationType, notification)
|
||||
if !shouldContinue {
|
||||
return nil // Hooks decided to skip processing
|
||||
}
|
||||
|
||||
var err error
|
||||
switch notificationType {
|
||||
case NotificationMoving:
|
||||
err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationMigrating:
|
||||
err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationMigrated:
|
||||
err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationFailingOver:
|
||||
err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationFailedOver:
|
||||
err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
|
||||
default:
|
||||
// Ignore other notification types (e.g., pub/sub messages)
|
||||
err = nil
|
||||
}
|
||||
|
||||
// Process post-hooks with the result
|
||||
snh.manager.processPostHooks(ctx, handlerCtx, notificationType, modifiedNotification, err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// handleMoving processes MOVING notifications.
|
||||
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
|
||||
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) < 3 {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid MOVING notification: %v", notification)
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
seqID, ok := notification[1].(int64)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid seqID in MOVING notification: %v", notification[1])
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Extract timeS
|
||||
timeS, ok := notification[2].(int64)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid timeS in MOVING notification: %v", notification[2])
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
newEndpoint := ""
|
||||
if len(notification) > 3 {
|
||||
// Extract new endpoint
|
||||
newEndpoint, ok = notification[3].(string)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid newEndpoint in MOVING notification: %v", notification[3])
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
}
|
||||
|
||||
// Get the connection that received this notification
|
||||
conn := handlerCtx.Conn
|
||||
if conn == nil {
|
||||
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MOVING notification")
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Type assert to get the underlying pool connection
|
||||
var poolConn *pool.Conn
|
||||
if pc, ok := conn.(*pool.Conn); ok {
|
||||
poolConn = pc
|
||||
} else {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MOVING notification - %T %#v", conn, handlerCtx)
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// If the connection is closed or not pooled, we can ignore the notification
|
||||
// this connection won't be remembered by the pool and will be garbage collected
|
||||
// Keep pubsub connections around since they are not pooled but are long-lived
|
||||
// and should be allowed to handoff (the pubsub instance will reconnect and change
|
||||
// the underlying *pool.Conn)
|
||||
if (poolConn.IsClosed() || !poolConn.IsPooled()) && !poolConn.IsPubSub() {
|
||||
return nil
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
|
||||
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
|
||||
if newEndpoint == "" || newEndpoint == internal.RedisNull {
|
||||
if snh.manager.config.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] scheduling handoff to current endpoint in %v seconds",
|
||||
poolConn.GetID(), timeS/2)
|
||||
}
|
||||
// same as current endpoint
|
||||
newEndpoint = snh.manager.options.GetAddr()
|
||||
// delay the handoff for timeS/2 seconds to the same endpoint
|
||||
// do this in a goroutine to avoid blocking the notification handler
|
||||
// NOTE: This timer is started while parsing the notification, so the connection is not marked for handoff
|
||||
// and there should be no possibility of a race condition or double handoff.
|
||||
time.AfterFunc(time.Duration(timeS/2)*time.Second, func() {
|
||||
if poolConn == nil || poolConn.IsClosed() {
|
||||
return
|
||||
}
|
||||
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
|
||||
// Log error but don't fail the goroutine - use background context since original may be cancelled
|
||||
internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
|
||||
}
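// Illustrative sketch (not part of the diff): decoding the fields of a MOVING
// notification the same way handleMoving does above. The names here are local
// to the example; the payload layout is ["MOVING", seqNum, timeS, endpoint],
// where the endpoint may be absent or RESP null.
func exampleDecodeMoving(notification []interface{}) (seqID, timeS int64, endpoint string, ok bool) {
    if len(notification) < 3 {
        return 0, 0, "", false
    }
    if seqID, ok = notification[1].(int64); !ok {
        return 0, 0, "", false
    }
    if timeS, ok = notification[2].(int64); !ok {
        return 0, 0, "", false
    }
    if len(notification) > 3 {
        // A missing or non-string endpoint means "hand off to the current endpoint".
        endpoint, _ = notification[3].(string)
    }
    return seqID, timeS, endpoint, true
}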
|
||||
|
||||
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
|
||||
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
|
||||
internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
|
||||
// Connection is already marked for handoff, which is acceptable
|
||||
// This can happen if multiple MOVING notifications are received for the same connection
|
||||
return nil
|
||||
}
|
||||
// Optionally track in hitless manager for monitoring/debugging
|
||||
if snh.manager != nil {
|
||||
connID := conn.GetID()
|
||||
// Track the operation (ignore errors since this is optional)
|
||||
_ = snh.manager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
|
||||
} else {
|
||||
return fmt.Errorf("hitless: manager not initialized")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMigrating processes MIGRATING notifications.
|
||||
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// MIGRATING notifications indicate that a connection is about to be migrated
|
||||
// Apply relaxed timeouts to the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid MIGRATING notification: %v", notification)
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATING notification")
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATING notification")
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Apply relaxed timeout to this specific connection
|
||||
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] applying relaxed timeout (%v) for MIGRATING notification",
|
||||
conn.GetID(),
|
||||
snh.manager.config.RelaxedTimeout)
|
||||
}
|
||||
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMigrated processes MIGRATED notifications.
|
||||
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// MIGRATED notifications indicate that a connection migration has completed
|
||||
// Restore normal timeouts for the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid MIGRATED notification: %v", notification)
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATED notification")
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATED notification")
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Clear relaxed timeout for this specific connection
|
||||
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] clearing relaxed timeout for MIGRATED notification", connID)
|
||||
}
|
||||
conn.ClearRelaxedTimeout()
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailingOver processes FAILING_OVER notifications.
|
||||
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// FAILING_OVER notifications indicate that a connection is about to failover
|
||||
// Apply relaxed timeouts to the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid FAILING_OVER notification: %v", notification)
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILING_OVER notification")
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILING_OVER notification")
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Apply relaxed timeout to this specific connection
|
||||
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] applying relaxed timeout (%v) for FAILING_OVER notification", connID, snh.manager.config.RelaxedTimeout)
|
||||
}
|
||||
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailedOver processes FAILED_OVER notifications.
|
||||
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// FAILED_OVER notifications indicate that a connection failover has completed
|
||||
// Restore normal timeouts for the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid FAILED_OVER notification: %v", notification)
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILED_OVER notification")
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILED_OVER notification")
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Clear relaxed timeout for this specific connection
|
||||
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, "hitless: conn[%d] clearing relaxed timeout for FAILED_OVER notification", connID)
|
||||
}
|
||||
conn.ClearRelaxedTimeout()
|
||||
return nil
|
||||
}
|
24
hitless/state.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package hitless
|
||||
|
||||
// State represents the current state of a hitless upgrade operation.
|
||||
type State int
|
||||
|
||||
const (
|
||||
// StateIdle indicates no upgrade is in progress
|
||||
StateIdle State = iota
|
||||
|
||||
// StateMoving indicates a connection handoff (MOVING) is in progress
|
||||
StateMoving
|
||||
)
|
||||
|
||||
// String returns a string representation of the state.
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateIdle:
|
||||
return "idle"
|
||||
case StateMoving:
|
||||
return "moving"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
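// Illustrative usage (not part of the diff): State implements fmt.Stringer, so
// values can be passed straight to loggers and format verbs.
func exampleStateString() string {
    return StateMoving.String() // "moving"
}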
|
54
internal/interfaces/interfaces.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// Package interfaces provides shared interfaces used by both the main redis package
|
||||
// and the hitless upgrade package to avoid circular dependencies.
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NotificationProcessor mirrors push.NotificationProcessor; it is declared here as a
|
||||
// forward declaration to avoid circular imports
|
||||
type NotificationProcessor interface {
|
||||
RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
|
||||
UnregisterHandler(pushNotificationName string) error
|
||||
GetHandler(pushNotificationName string) interface{}
|
||||
}
|
||||
|
||||
// ClientInterface defines the interface that clients must implement for hitless upgrades.
|
||||
type ClientInterface interface {
|
||||
// GetOptions returns the client options.
|
||||
GetOptions() OptionsInterface
|
||||
|
||||
// GetPushProcessor returns the client's push notification processor.
|
||||
GetPushProcessor() NotificationProcessor
|
||||
}
|
||||
|
||||
// OptionsInterface defines the interface for client options.
|
||||
// Uses an adapter pattern to avoid circular dependencies.
|
||||
type OptionsInterface interface {
|
||||
// GetReadTimeout returns the read timeout.
|
||||
GetReadTimeout() time.Duration
|
||||
|
||||
// GetWriteTimeout returns the write timeout.
|
||||
GetWriteTimeout() time.Duration
|
||||
|
||||
// GetNetwork returns the network type.
|
||||
GetNetwork() string
|
||||
|
||||
// GetAddr returns the connection address.
|
||||
GetAddr() string
|
||||
|
||||
// IsTLSEnabled returns true if TLS is enabled.
|
||||
IsTLSEnabled() bool
|
||||
|
||||
// GetProtocol returns the protocol version.
|
||||
GetProtocol() int
|
||||
|
||||
// GetPoolSize returns the connection pool size.
|
||||
GetPoolSize() int
|
||||
|
||||
// NewDialer returns a new dialer function for the connection.
|
||||
NewDialer() func(context.Context) (net.Conn, error)
|
||||
}
|
@@ -14,26 +14,20 @@ type Logging interface {
|
||||
Printf(ctx context.Context, format string, v ...interface{})
|
||||
}
|
||||
|
||||
type logger struct {
|
||||
type DefaultLogger struct {
|
||||
log *log.Logger
|
||||
}
|
||||
|
||||
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
func (l *DefaultLogger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
_ = l.log.Output(2, fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
func NewDefaultLogger() Logging {
|
||||
return &DefaultLogger{
|
||||
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
|
||||
}
|
||||
}
|
||||
|
||||
// Logger calls Output to print to the stderr.
|
||||
// Arguments are handled in the manner of fmt.Print.
|
||||
var Logger Logging = &logger{
|
||||
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
|
||||
}
|
||||
|
||||
// VoidLogger is a logger that does nothing.
|
||||
// Used to disable logging and thus speed up the library.
|
||||
type VoidLogger struct{}
|
||||
|
||||
func (v *VoidLogger) Printf(_ context.Context, _ string, _ ...interface{}) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
var _ Logging = (*VoidLogger)(nil)
|
||||
var Logger Logging = NewDefaultLogger()
|
||||
|
@@ -2,6 +2,7 @@ package pool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -31,7 +32,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
|
||||
b.Run(bm.String(), func(b *testing.B) {
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: bm.poolSize,
|
||||
PoolSize: int32(bm.poolSize),
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Hour,
|
||||
@@ -75,7 +76,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
|
||||
b.Run(bm.String(), func(b *testing.B) {
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: bm.poolSize,
|
||||
PoolSize: int32(bm.poolSize),
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Hour,
|
||||
@@ -89,7 +90,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
connPool.Remove(ctx, cn, nil)
|
||||
connPool.Remove(ctx, cn, errors.New("Bench test remove"))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
@@ -26,7 +26,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
It("should use default buffer sizes when not specified", func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: 1000,
|
||||
})
|
||||
|
||||
@@ -48,7 +48,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: 1000,
|
||||
ReadBufferSize: customReadSize,
|
||||
WriteBufferSize: customWriteSize,
|
||||
@@ -69,7 +69,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
It("should handle zero buffer sizes by using defaults", func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: 1000,
|
||||
ReadBufferSize: 0, // Should use default
|
||||
WriteBufferSize: 0, // Should use default
|
||||
@@ -105,7 +105,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
// without setting ReadBufferSize and WriteBufferSize
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: 1000,
|
||||
// ReadBufferSize and WriteBufferSize are not set (will be 0)
|
||||
})
|
||||
|
@@ -3,7 +3,10 @@ package pool
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -12,17 +15,65 @@ import (
|
||||
|
||||
var noDeadline = time.Time{}
|
||||
|
||||
// Global atomic counter for connection IDs
|
||||
var connIDCounter uint64
|
||||
|
||||
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
|
||||
type atomicNetConn struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// generateConnID generates a fast unique identifier for a connection with zero allocations
|
||||
func generateConnID() uint64 {
|
||||
return atomic.AddUint64(&connIDCounter, 1)
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
usedAt int64 // atomic
|
||||
netConn net.Conn
|
||||
|
||||
// Lock-free netConn access using atomic.Value
|
||||
// Contains *atomicNetConn wrapper, accessed atomically for better performance
|
||||
netConnAtomic atomic.Value // stores *atomicNetConn
|
||||
|
||||
rd *proto.Reader
|
||||
bw *bufio.Writer
|
||||
wr *proto.Writer
|
||||
|
||||
Inited bool
|
||||
// Lightweight mutex to protect reader operations during handoff
|
||||
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
|
||||
readerMu sync.RWMutex
|
||||
|
||||
Inited atomic.Bool
|
||||
pooled bool
|
||||
pubsub bool
|
||||
closed atomic.Bool
|
||||
createdAt time.Time
|
||||
expiresAt time.Time
|
||||
|
||||
// Hitless upgrade support: relaxed timeouts during migrations/failovers
|
||||
// Using atomic operations for lock-free access to avoid mutex contention
|
||||
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
|
||||
|
||||
// Counter to track multiple relaxed timeout setters if we have nested calls
|
||||
// will be decremented when ClearRelaxedTimeout is called or deadline is reached
|
||||
// if counter reaches 0, we clear the relaxed timeouts
|
||||
relaxedCounter atomic.Int32
|
||||
|
||||
// Connection initialization function for reconnections
|
||||
initConnFunc func(context.Context, *Conn) error
|
||||
|
||||
// Connection identifier for unique tracking across handoffs
|
||||
id uint64 // Unique numeric identifier for this connection
|
||||
|
||||
// Handoff state - using atomic operations for lock-free access
|
||||
usableAtomic atomic.Bool // Connection usability state
|
||||
shouldHandoffAtomic atomic.Bool // Whether connection should be handed off
|
||||
movingSeqIDAtomic atomic.Int64 // Sequence ID from MOVING notification
|
||||
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
|
||||
// newEndpointAtomic needs special handling as it's a string
|
||||
newEndpointAtomic atomic.Value // stores string
|
||||
|
||||
onClose func() error
|
||||
}
|
||||
@@ -33,8 +84,8 @@ func NewConn(netConn net.Conn) *Conn {
|
||||
|
||||
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
|
||||
cn := &Conn{
|
||||
netConn: netConn,
|
||||
createdAt: time.Now(),
|
||||
id: generateConnID(), // Generate unique ID for this connection
|
||||
}
|
||||
|
||||
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
|
||||
@@ -50,6 +101,16 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
|
||||
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
|
||||
}
|
||||
|
||||
// Store netConn atomically for lock-free access using wrapper
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
|
||||
// Initialize atomic handoff state
|
||||
cn.usableAtomic.Store(false) // false initially, set to true after initialization
|
||||
cn.shouldHandoffAtomic.Store(false) // false initially
|
||||
cn.movingSeqIDAtomic.Store(0) // 0 initially
|
||||
cn.handoffRetriesAtomic.Store(0) // 0 initially
|
||||
cn.newEndpointAtomic.Store("") // empty string initially
|
||||
|
||||
cn.wr = proto.NewWriter(cn.bw)
|
||||
cn.SetUsedAt(time.Now())
|
||||
return cn
|
||||
@@ -64,23 +125,366 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
|
||||
atomic.StoreInt64(&cn.usedAt, tm.Unix())
|
||||
}
|
||||
|
||||
// getNetConn returns the current network connection using atomic load (lock-free).
|
||||
// This is the fast path for accessing netConn without mutex overhead.
|
||||
func (cn *Conn) getNetConn() net.Conn {
|
||||
if v := cn.netConnAtomic.Load(); v != nil {
|
||||
if wrapper, ok := v.(*atomicNetConn); ok {
|
||||
return wrapper.conn
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setNetConn stores the network connection atomically (lock-free).
|
||||
// This is used for the fast path of connection replacement.
|
||||
func (cn *Conn) setNetConn(netConn net.Conn) {
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
}
|
||||
|
||||
// Lock-free helper methods for handoff state management
|
||||
|
||||
// isUsable returns true if the connection is safe to use (lock-free).
|
||||
func (cn *Conn) isUsable() bool {
|
||||
return cn.usableAtomic.Load()
|
||||
}
|
||||
|
||||
// setUsable sets the usable flag atomically (lock-free).
|
||||
func (cn *Conn) setUsable(usable bool) {
|
||||
cn.usableAtomic.Store(usable)
|
||||
}
|
||||
|
||||
// shouldHandoff returns true if connection needs handoff (lock-free).
|
||||
func (cn *Conn) shouldHandoff() bool {
|
||||
return cn.shouldHandoffAtomic.Load()
|
||||
}
|
||||
|
||||
// setShouldHandoff sets the handoff flag atomically (lock-free).
|
||||
func (cn *Conn) setShouldHandoff(should bool) {
|
||||
cn.shouldHandoffAtomic.Store(should)
|
||||
}
|
||||
|
||||
// getMovingSeqID returns the sequence ID atomically (lock-free).
|
||||
func (cn *Conn) getMovingSeqID() int64 {
|
||||
return cn.movingSeqIDAtomic.Load()
|
||||
}
|
||||
|
||||
// setMovingSeqID sets the sequence ID atomically (lock-free).
|
||||
func (cn *Conn) setMovingSeqID(seqID int64) {
|
||||
cn.movingSeqIDAtomic.Store(seqID)
|
||||
}
|
||||
|
||||
// getNewEndpoint returns the new endpoint atomically (lock-free).
|
||||
func (cn *Conn) getNewEndpoint() string {
|
||||
if endpoint := cn.newEndpointAtomic.Load(); endpoint != nil {
|
||||
return endpoint.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// setNewEndpoint sets the new endpoint atomically (lock-free).
|
||||
func (cn *Conn) setNewEndpoint(endpoint string) {
|
||||
cn.newEndpointAtomic.Store(endpoint)
|
||||
}
|
||||
|
||||
// setHandoffRetries sets the retry count atomically (lock-free).
|
||||
func (cn *Conn) setHandoffRetries(retries int) {
|
||||
cn.handoffRetriesAtomic.Store(uint32(retries))
|
||||
}
|
||||
|
||||
// incrementHandoffRetries atomically increments and returns the new retry count (lock-free).
|
||||
func (cn *Conn) incrementHandoffRetries(delta int) int {
|
||||
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
|
||||
}
|
||||
|
||||
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
|
||||
func (cn *Conn) IsUsable() bool {
|
||||
return cn.isUsable()
|
||||
}
|
||||
|
||||
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
|
||||
func (cn *Conn) IsPooled() bool {
|
||||
return cn.pooled
|
||||
}
|
||||
|
||||
// IsPubSub returns true if the connection is used for PubSub.
|
||||
func (cn *Conn) IsPubSub() bool {
|
||||
return cn.pubsub
|
||||
}
|
||||
|
||||
func (cn *Conn) IsInited() bool {
|
||||
return cn.Inited.Load()
|
||||
}
|
||||
|
||||
// SetUsable sets the usable flag for the connection (lock-free).
|
||||
func (cn *Conn) SetUsable(usable bool) {
|
||||
cn.setUsable(usable)
|
||||
}
|
||||
|
||||
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
|
||||
// These timeouts will be used for all subsequent commands until the deadline expires.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
|
||||
cn.relaxedCounter.Add(1)
|
||||
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
|
||||
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
|
||||
}
|
||||
|
||||
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
|
||||
// After the deadline, timeouts automatically revert to normal values.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
|
||||
cn.SetRelaxedTimeout(readTimeout, writeTimeout)
|
||||
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
|
||||
}
|
||||
|
||||
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) ClearRelaxedTimeout() {
|
||||
// Atomically decrement counter and check if we should clear
|
||||
newCount := cn.relaxedCounter.Add(-1)
|
||||
if newCount <= 0 {
|
||||
// Use atomic load to get current value for CAS to avoid stale value race
|
||||
current := cn.relaxedCounter.Load()
|
||||
if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) {
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *Conn) clearRelaxedTimeout() {
|
||||
cn.relaxedReadTimeoutNs.Store(0)
|
||||
cn.relaxedWriteTimeoutNs.Store(0)
|
||||
cn.relaxedDeadlineNs.Store(0)
|
||||
cn.relaxedCounter.Store(0)
|
||||
}
|
||||
|
||||
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
|
||||
// This checks both the timeout values and the deadline (if set).
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) HasRelaxedTimeout() bool {
|
||||
// Fast path: no relaxed timeouts are set
|
||||
if cn.relaxedCounter.Load() <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||
|
||||
// If no relaxed timeouts are set, return false
|
||||
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, relaxed timeouts are active
|
||||
if deadlineNs == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// If deadline is set, check if it's still in the future
|
||||
return time.Now().UnixNano() < deadlineNs
|
||||
}
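// Illustrative sketch (not part of the diff): the relaxed-timeout state is
// reference-counted, so nested SetRelaxedTimeout calls must be balanced by the
// same number of ClearRelaxedTimeout calls before the timeouts revert.
func exampleNestedRelaxedTimeouts(cn *Conn) bool {
    cn.SetRelaxedTimeout(10*time.Second, 10*time.Second) // counter = 1
    cn.SetRelaxedTimeout(10*time.Second, 10*time.Second) // counter = 2
    cn.ClearRelaxedTimeout()                             // counter = 1, still relaxed
    stillRelaxed := cn.HasRelaxedTimeout()               // true
    cn.ClearRelaxedTimeout()                             // counter = 0, timeouts cleared
    return stillRelaxed && !cn.HasRelaxedTimeout()
}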
|
||||
|
||||
// getEffectiveReadTimeout returns the timeout to use for read operations.
|
||||
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
|
||||
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||
|
||||
// Fast path: no relaxed timeout set
|
||||
if readTimeoutNs <= 0 {
|
||||
return normalTimeout
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, use relaxed timeout
|
||||
if deadlineNs == 0 {
|
||||
return time.Duration(readTimeoutNs)
|
||||
}
|
||||
|
||||
nowNs := time.Now().UnixNano()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
return time.Duration(readTimeoutNs)
|
||||
} else {
|
||||
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||
cn.relaxedCounter.Add(-1)
|
||||
if cn.relaxedCounter.Load() <= 0 {
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
return normalTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// getEffectiveWriteTimeout returns the timeout to use for write operations.
|
||||
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
|
||||
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||
|
||||
// Fast path: no relaxed timeout set
|
||||
if writeTimeoutNs <= 0 {
|
||||
return normalTimeout
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, use relaxed timeout
|
||||
if deadlineNs == 0 {
|
||||
return time.Duration(writeTimeoutNs)
|
||||
}
|
||||
|
||||
nowNs := time.Now().UnixNano()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
return time.Duration(writeTimeoutNs)
|
||||
} else {
|
||||
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||
cn.relaxedCounter.Add(-1)
|
||||
if cn.relaxedCounter.Load() <= 0 {
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
return normalTimeout
|
||||
}
|
||||
}
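// Illustrative sketch (not part of the diff): with SetRelaxedTimeoutWithDeadline
// the relaxed values apply only until the deadline; afterwards the effective
// timeouts silently revert to the caller-supplied values, as the helpers above
// show.
func exampleDeadlineRelaxedTimeout(cn *Conn) {
    cn.SetRelaxedTimeoutWithDeadline(5*time.Second, 5*time.Second, time.Now().Add(100*time.Millisecond))
    _ = cn.HasRelaxedTimeout() // true while the deadline is in the future
    time.Sleep(150 * time.Millisecond)
    _ = cn.HasRelaxedTimeout() // false once the deadline has passed
}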
|
||||
|
||||
func (cn *Conn) SetOnClose(fn func() error) {
|
||||
cn.onClose = fn
|
||||
}
|
||||
|
||||
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
|
||||
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
|
||||
cn.initConnFunc = fn
|
||||
}
|
||||
|
||||
// ExecuteInitConn runs the stored connection initialization function if available.
|
||||
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
|
||||
if cn.initConnFunc != nil {
|
||||
return cn.initConnFunc(ctx, cn)
|
||||
}
|
||||
return fmt.Errorf("redis: no initConnFunc set for conn[%d]", cn.GetID())
|
||||
}
|
||||
|
||||
func (cn *Conn) SetNetConn(netConn net.Conn) {
|
||||
cn.netConn = netConn
|
||||
// Store the new connection atomically first (lock-free)
|
||||
cn.setNetConn(netConn)
|
||||
// Protect reader reset operations to avoid data races
|
||||
// Use write lock since we're modifying the reader state
|
||||
cn.readerMu.Lock()
|
||||
cn.rd.Reset(netConn)
|
||||
cn.readerMu.Unlock()
|
||||
|
||||
cn.bw.Reset(netConn)
|
||||
}
|
||||
|
||||
// GetNetConn safely returns the current network connection using atomic load (lock-free).
|
||||
// This method is used by the pool for health checks and provides better performance.
|
||||
func (cn *Conn) GetNetConn() net.Conn {
|
||||
return cn.getNetConn()
|
||||
}
|
||||
|
||||
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
|
||||
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
|
||||
// New connection is not initialized yet
|
||||
cn.Inited.Store(false)
|
||||
// Replace the underlying connection
|
||||
cn.SetNetConn(netConn)
|
||||
return cn.ExecuteInitConn(ctx)
|
||||
}
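// Illustrative sketch (not part of the diff): how a handoff worker might swap
// in a freshly dialed connection. The dial function is a hypothetical stand-in
// for whatever dialer the pool hook was configured with.
func exampleHandoffDial(ctx context.Context, cn *Conn, dial func(context.Context, string, string) (net.Conn, error)) error {
    netConn, err := dial(ctx, "tcp", cn.GetHandoffEndpoint())
    if err != nil {
        return err
    }
    // Re-runs the stored initConnFunc (HELLO/AUTH/SELECT, etc.) on the new socket.
    return cn.SetNetConnAndInitConn(ctx, netConn)
}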
|
||||
|
||||
// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
|
||||
// Returns an error if the connection is already marked for handoff.
|
||||
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
|
||||
// Use single atomic CAS operation for state transition
|
||||
if !cn.shouldHandoffAtomic.CompareAndSwap(false, true) {
|
||||
return errors.New("connection is already marked for handoff")
|
||||
}
|
||||
|
||||
cn.setNewEndpoint(newEndpoint)
|
||||
cn.setMovingSeqID(seqID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cn *Conn) MarkQueuedForHandoff() error {
|
||||
// Use single atomic CAS operation for state transition
|
||||
if !cn.shouldHandoffAtomic.CompareAndSwap(true, false) {
|
||||
return errors.New("connection was not marked for handoff")
|
||||
}
|
||||
cn.setUsable(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
|
||||
func (cn *Conn) ShouldHandoff() bool {
|
||||
return cn.shouldHandoff()
|
||||
}
|
||||
|
||||
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
|
||||
func (cn *Conn) GetHandoffEndpoint() string {
|
||||
return cn.getNewEndpoint()
|
||||
}
|
||||
|
||||
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
|
||||
func (cn *Conn) GetMovingSeqID() int64 {
|
||||
return cn.getMovingSeqID()
|
||||
}
|
||||
|
||||
// GetID returns the unique identifier for this connection.
|
||||
func (cn *Conn) GetID() uint64 {
|
||||
return cn.id
|
||||
}
|
||||
|
||||
// ClearHandoffState clears the handoff state after successful handoff (lock-free).
|
||||
func (cn *Conn) ClearHandoffState() {
|
||||
// clear handoff state
|
||||
cn.setShouldHandoff(false)
|
||||
cn.setNewEndpoint("")
|
||||
cn.setMovingSeqID(0)
|
||||
cn.setHandoffRetries(0)
|
||||
cn.setUsable(true) // Connection is safe to use again after handoff completes
|
||||
}
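// Illustrative sketch (not part of the diff): the handoff state transitions
// driven by a MOVING notification. MarkForHandoff flags the connection,
// MarkQueuedForHandoff makes it temporarily unusable while the worker runs,
// and ClearHandoffState restores it once the replacement connection is ready.
func exampleHandoffLifecycle(cn *Conn) error {
    if err := cn.MarkForHandoff("new-endpoint:6379", 42); err != nil {
        return err // already marked by an earlier notification
    }
    if err := cn.MarkQueuedForHandoff(); err != nil {
        return err // was not marked for handoff
    }
    // ... dial the new endpoint and call SetNetConnAndInitConn, then:
    cn.ClearHandoffState() // usable again; endpoint, seqID and retries reset
    return nil
}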
|
||||
|
||||
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
|
||||
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
|
||||
return cn.incrementHandoffRetries(n)
|
||||
}
|
||||
|
||||
// HasBufferedData safely checks if the connection has buffered data.
|
||||
// This method is used to avoid data races when checking for push notifications.
|
||||
func (cn *Conn) HasBufferedData() bool {
|
||||
// Use read lock for concurrent access to reader state
|
||||
cn.readerMu.RLock()
|
||||
defer cn.readerMu.RUnlock()
|
||||
return cn.rd.Buffered() > 0
|
||||
}
|
||||
|
||||
// PeekReplyTypeSafe safely peeks at the reply type.
|
||||
// This method is used to avoid data races when checking for push notifications.
|
||||
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
|
||||
// Use read lock for concurrent access to reader state
|
||||
cn.readerMu.RLock()
|
||||
defer cn.readerMu.RUnlock()
|
||||
|
||||
if cn.rd.Buffered() <= 0 {
|
||||
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
|
||||
}
|
||||
return cn.rd.PeekReplyType()
|
||||
}
|
||||
|
||||
func (cn *Conn) Write(b []byte) (int, error) {
|
||||
return cn.netConn.Write(b)
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.Write(b)
|
||||
}
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
func (cn *Conn) RemoteAddr() net.Addr {
|
||||
if cn.netConn != nil {
|
||||
return cn.netConn.RemoteAddr()
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.RemoteAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -89,7 +493,16 @@ func (cn *Conn) WithReader(
|
||||
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
|
||||
) error {
|
||||
if timeout >= 0 {
|
||||
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
|
||||
// Use relaxed timeout if set, otherwise use provided timeout
|
||||
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
|
||||
|
||||
// Get the connection directly from atomic storage
|
||||
netConn := cn.getNetConn()
|
||||
if netConn == nil {
|
||||
return fmt.Errorf("redis: connection not available")
|
||||
}
|
||||
|
||||
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -100,13 +513,26 @@ func (cn *Conn) WithWriter(
|
||||
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
|
||||
) error {
|
||||
if timeout >= 0 {
|
||||
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
|
||||
// Use relaxed timeout if set, otherwise use provided timeout
|
||||
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
|
||||
|
||||
// Set the write deadline when the underlying connection is available;
|
||||
// otherwise fail fast so write operations cannot hang indefinitely
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// If getNetConn() returns nil, we still need to respect the timeout
|
||||
// Return an error to prevent indefinite blocking
|
||||
return fmt.Errorf("redis: connection not available for write operation")
|
||||
}
|
||||
}
|
||||
|
||||
if cn.bw.Buffered() > 0 {
|
||||
cn.bw.Reset(cn.netConn)
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
cn.bw.Reset(netConn)
|
||||
}
|
||||
}
|
||||
|
||||
if err := fn(cn.wr); err != nil {
|
||||
@@ -116,19 +542,33 @@ func (cn *Conn) WithWriter(
|
||||
return cn.bw.Flush()
|
||||
}
|
||||
|
||||
func (cn *Conn) IsClosed() bool {
|
||||
return cn.closed.Load()
|
||||
}
|
||||
|
||||
func (cn *Conn) Close() error {
|
||||
cn.closed.Store(true)
|
||||
if cn.onClose != nil {
|
||||
// ignore error
|
||||
_ = cn.onClose()
|
||||
}
|
||||
return cn.netConn.Close()
|
||||
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaybeHasData tries to peek at the next byte in the socket without consuming it
|
||||
// This is used to check if there are push notifications available
|
||||
// Important: This will work on Linux, but not on Windows
|
||||
func (cn *Conn) MaybeHasData() bool {
|
||||
return maybeHasData(cn.netConn)
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return maybeHasData(netConn)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
|
||||
|
92
internal/pool/conn_relaxed_timeout_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestConcurrentRelaxedTimeoutClearing tests the race condition fix in ClearRelaxedTimeout
|
||||
func TestConcurrentRelaxedTimeoutClearing(t *testing.T) {
|
||||
// Create a dummy connection for testing
|
||||
netConn := &net.TCPConn{}
|
||||
cn := NewConn(netConn)
|
||||
defer cn.Close()
|
||||
|
||||
// Set relaxed timeout multiple times to increase counter
|
||||
cn.SetRelaxedTimeout(time.Second, time.Second)
|
||||
cn.SetRelaxedTimeout(time.Second, time.Second)
|
||||
cn.SetRelaxedTimeout(time.Second, time.Second)
|
||||
|
||||
// Verify counter is 3
|
||||
if count := cn.relaxedCounter.Load(); count != 3 {
|
||||
t.Errorf("Expected relaxed counter to be 3, got %d", count)
|
||||
}
|
||||
|
||||
// Clear timeouts concurrently to test race condition fix
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cn.ClearRelaxedTimeout()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Verify counter is 0 and timeouts are cleared
|
||||
if count := cn.relaxedCounter.Load(); count != 0 {
|
||||
t.Errorf("Expected relaxed counter to be 0 after clearing, got %d", count)
|
||||
}
|
||||
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
|
||||
t.Errorf("Expected relaxed read timeout to be 0, got %d", timeout)
|
||||
}
|
||||
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
|
||||
t.Errorf("Expected relaxed write timeout to be 0, got %d", timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelaxedTimeoutCounterRaceCondition tests the specific race condition scenario
|
||||
func TestRelaxedTimeoutCounterRaceCondition(t *testing.T) {
|
||||
netConn := &net.TCPConn{}
|
||||
cn := NewConn(netConn)
|
||||
defer cn.Close()
|
||||
|
||||
// Set relaxed timeout once
|
||||
cn.SetRelaxedTimeout(time.Second, time.Second)
|
||||
|
||||
// Verify counter is 1
|
||||
if count := cn.relaxedCounter.Load(); count != 1 {
|
||||
t.Errorf("Expected relaxed counter to be 1, got %d", count)
|
||||
}
|
||||
|
||||
// Test concurrent clearing with race condition scenario
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Multiple goroutines try to clear simultaneously
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cn.ClearRelaxedTimeout()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Verify final state is consistent
|
||||
if count := cn.relaxedCounter.Load(); count != 0 {
|
||||
t.Errorf("Expected relaxed counter to be 0 after concurrent clearing, got %d", count)
|
||||
}
|
||||
|
||||
// Verify timeouts are actually cleared
|
||||
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
|
||||
t.Errorf("Expected relaxed read timeout to be cleared, got %d", timeout)
|
||||
}
|
||||
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
|
||||
t.Errorf("Expected relaxed write timeout to be cleared, got %d", timeout)
|
||||
}
|
||||
if deadline := cn.relaxedDeadlineNs.Load(); deadline != 0 {
|
||||
t.Errorf("Expected relaxed deadline to be cleared, got %d", deadline)
|
||||
}
|
||||
}
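The tests above drive SetRelaxedTimeout and ClearRelaxedTimeout directly. As a rough sketch of how a caller might use the same pair around a maintenance window (illustrative only: internal/pool is not importable by user code, and the surrounding helper is hypothetical):

package example

import (
	"time"

	"github.com/redis/go-redis/v9/internal/pool"
)

// relaxDuringUpgrade widens a connection's read/write deadlines while an
// upgrade is in progress, then restores the configured timeouts.
func relaxDuringUpgrade(cn *pool.Conn, relaxed time.Duration, upgrade func() error) error {
	cn.SetRelaxedTimeout(relaxed, relaxed) // relax both read and write timeouts
	defer cn.ClearRelaxedTimeout()         // drop back to the normal timeouts
	return upgrade()
}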
|
@@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) {
|
||||
}
|
||||
|
||||
func (cn *Conn) NetConn() net.Conn {
|
||||
return cn.netConn
|
||||
return cn.getNetConn()
|
||||
}
|
||||
|
||||
func (p *ConnPool) CheckMinIdleConns() {
|
||||
|
114
internal/pool/hooks.go
Normal file
@@ -0,0 +1,114 @@
package pool

import (
	"context"
	"sync"
)

// PoolHook defines the interface for connection lifecycle hooks.
type PoolHook interface {
	// OnGet is called when a connection is retrieved from the pool.
	// It can modify the connection or return an error to prevent its use.
	// The isNewConn flag indicates whether the connection was newly dialed
	// rather than taken from the idle pool; it can be used to gather metrics
	// on the pool hit/miss ratio.
	OnGet(ctx context.Context, conn *Conn, isNewConn bool) error

	// OnPut is called when a connection is returned to the pool.
	// It returns whether the connection should be pooled and whether it should be removed.
	OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
}

// PoolHookManager manages multiple pool hooks.
type PoolHookManager struct {
	hooks   []PoolHook
	hooksMu sync.RWMutex
}

// NewPoolHookManager creates a new pool hook manager.
func NewPoolHookManager() *PoolHookManager {
	return &PoolHookManager{
		hooks: make([]PoolHook, 0),
	}
}

// AddHook adds a pool hook to the manager.
// Hooks are called in the order they were added.
func (phm *PoolHookManager) AddHook(hook PoolHook) {
	phm.hooksMu.Lock()
	defer phm.hooksMu.Unlock()
	phm.hooks = append(phm.hooks, hook)
}

// RemoveHook removes a pool hook from the manager.
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
	phm.hooksMu.Lock()
	defer phm.hooksMu.Unlock()

	for i, h := range phm.hooks {
		if h == hook {
			// Remove hook by swapping with last element and truncating
			phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
			phm.hooks = phm.hooks[:len(phm.hooks)-1]
			break
		}
	}
}

// ProcessOnGet calls all OnGet hooks in order.
// If any hook returns an error, processing stops and the error is returned.
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
	phm.hooksMu.RLock()
	defer phm.hooksMu.RUnlock()

	for _, hook := range phm.hooks {
		if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
			return err
		}
	}
	return nil
}

// ProcessOnPut calls all OnPut hooks in order.
// The first hook that returns an error or shouldRemove=true stops processing;
// a shouldPool=false result is recorded, but later hooks still run.
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
	phm.hooksMu.RLock()
	defer phm.hooksMu.RUnlock()

	shouldPool = true // Default to pooling the connection

	for _, hook := range phm.hooks {
		hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)

		if hookErr != nil {
			return false, true, hookErr
		}

		// If any hook says to remove or not pool, respect that decision
		if hookShouldRemove {
			return false, true, nil
		}

		if !hookShouldPool {
			shouldPool = false
		}
	}

	return shouldPool, false, nil
}

// GetHookCount returns the number of registered hooks (for testing).
func (phm *PoolHookManager) GetHookCount() int {
	phm.hooksMu.RLock()
	defer phm.hooksMu.RUnlock()
	return len(phm.hooks)
}

// GetHooks returns a copy of all registered hooks.
func (phm *PoolHookManager) GetHooks() []PoolHook {
	phm.hooksMu.RLock()
	defer phm.hooksMu.RUnlock()

	hooks := make([]PoolHook, len(phm.hooks))
	copy(hooks, phm.hooks)
	return hooks
}
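A custom hook only needs to implement the two methods above. As a minimal sketch of how the interface composes with the pool, assuming nothing beyond the types in this file and the AddPoolHook method added to Pooler later in this diff (the metrics hook itself is hypothetical):

package example

import (
	"context"
	"sync/atomic"

	"github.com/redis/go-redis/v9/internal/pool"
)

// hitMissHook counts pool hits and misses using the isNewConn flag.
// It never rejects or removes a connection.
type hitMissHook struct {
	hits, misses atomic.Int64
}

func (h *hitMissHook) OnGet(ctx context.Context, cn *pool.Conn, isNewConn bool) error {
	if isNewConn {
		h.misses.Add(1) // the connection had to be dialed
	} else {
		h.hits.Add(1) // the connection came from the idle list
	}
	return nil
}

func (h *hitMissHook) OnPut(ctx context.Context, cn *pool.Conn) (shouldPool, shouldRemove bool, err error) {
	return true, false, nil // always return the connection to the pool
}

// Wiring it up (AddPoolHook is part of the Pooler interface later in this diff):
//
//	p := pool.NewConnPool(opt)
//	p.AddPoolHook(&hitMissHook{})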
213
internal/pool/hooks_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestHook for testing hook functionality
|
||||
type TestHook struct {
|
||||
OnGetCalled int
|
||||
OnPutCalled int
|
||||
GetError error
|
||||
PutError error
|
||||
ShouldPool bool
|
||||
ShouldRemove bool
|
||||
}
|
||||
|
||||
func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
|
||||
th.OnGetCalled++
|
||||
return th.GetError
|
||||
}
|
||||
|
||||
func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
th.OnPutCalled++
|
||||
return th.ShouldPool, th.ShouldRemove, th.PutError
|
||||
}
|
||||
|
||||
func TestPoolHookManager(t *testing.T) {
|
||||
manager := NewPoolHookManager()
|
||||
|
||||
// Test initial state
|
||||
if manager.GetHookCount() != 0 {
|
||||
t.Errorf("Expected 0 hooks initially, got %d", manager.GetHookCount())
|
||||
}
|
||||
|
||||
// Add hooks
|
||||
hook1 := &TestHook{ShouldPool: true}
|
||||
hook2 := &TestHook{ShouldPool: true}
|
||||
|
||||
manager.AddHook(hook1)
|
||||
manager.AddHook(hook2)
|
||||
|
||||
if manager.GetHookCount() != 2 {
|
||||
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test ProcessOnGet
|
||||
ctx := context.Background()
|
||||
conn := &Conn{} // Mock connection
|
||||
|
||||
err := manager.ProcessOnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("ProcessOnGet should not error: %v", err)
|
||||
}
|
||||
|
||||
if hook1.OnGetCalled != 1 {
|
||||
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
|
||||
}
|
||||
|
||||
if hook2.OnGetCalled != 1 {
|
||||
t.Errorf("Expected hook2.OnGetCalled to be 1, got %d", hook2.OnGetCalled)
|
||||
}
|
||||
|
||||
// Test ProcessOnPut
|
||||
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("ProcessOnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
if !shouldPool {
|
||||
t.Error("Expected shouldPool to be true")
|
||||
}
|
||||
|
||||
if shouldRemove {
|
||||
t.Error("Expected shouldRemove to be false")
|
||||
}
|
||||
|
||||
if hook1.OnPutCalled != 1 {
|
||||
t.Errorf("Expected hook1.OnPutCalled to be 1, got %d", hook1.OnPutCalled)
|
||||
}
|
||||
|
||||
if hook2.OnPutCalled != 1 {
|
||||
t.Errorf("Expected hook2.OnPutCalled to be 1, got %d", hook2.OnPutCalled)
|
||||
}
|
||||
|
||||
// Remove a hook
|
||||
manager.RemoveHook(hook1)
|
||||
|
||||
if manager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookErrorHandling(t *testing.T) {
|
||||
manager := NewPoolHookManager()
|
||||
|
||||
// Hook that returns error on Get
|
||||
errorHook := &TestHook{
|
||||
GetError: errors.New("test error"),
|
||||
ShouldPool: true,
|
||||
}
|
||||
|
||||
normalHook := &TestHook{ShouldPool: true}
|
||||
|
||||
manager.AddHook(errorHook)
|
||||
manager.AddHook(normalHook)
|
||||
|
||||
ctx := context.Background()
|
||||
conn := &Conn{}
|
||||
|
||||
// Test that error stops processing
|
||||
err := manager.ProcessOnGet(ctx, conn, false)
|
||||
if err == nil {
|
||||
t.Error("Expected error from ProcessOnGet")
|
||||
}
|
||||
|
||||
if errorHook.OnGetCalled != 1 {
|
||||
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
|
||||
}
|
||||
|
||||
// normalHook should not be called due to error
|
||||
if normalHook.OnGetCalled != 0 {
|
||||
t.Errorf("Expected normalHook.OnGetCalled to be 0, got %d", normalHook.OnGetCalled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookShouldRemove(t *testing.T) {
|
||||
manager := NewPoolHookManager()
|
||||
|
||||
// Hook that says to remove connection
|
||||
removeHook := &TestHook{
|
||||
ShouldPool: false,
|
||||
ShouldRemove: true,
|
||||
}
|
||||
|
||||
normalHook := &TestHook{ShouldPool: true}
|
||||
|
||||
manager.AddHook(removeHook)
|
||||
manager.AddHook(normalHook)
|
||||
|
||||
ctx := context.Background()
|
||||
conn := &Conn{}
|
||||
|
||||
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("ProcessOnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
if shouldPool {
|
||||
t.Error("Expected shouldPool to be false")
|
||||
}
|
||||
|
||||
if !shouldRemove {
|
||||
t.Error("Expected shouldRemove to be true")
|
||||
}
|
||||
|
||||
if removeHook.OnPutCalled != 1 {
|
||||
t.Errorf("Expected removeHook.OnPutCalled to be 1, got %d", removeHook.OnPutCalled)
|
||||
}
|
||||
|
||||
// normalHook should not be called due to early return
|
||||
if normalHook.OnPutCalled != 0 {
|
||||
t.Errorf("Expected normalHook.OnPutCalled to be 0, got %d", normalHook.OnPutCalled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolWithHooks(t *testing.T) {
|
||||
// Create a pool with hooks
|
||||
hookManager := NewPoolHookManager()
|
||||
testHook := &TestHook{ShouldPool: true}
|
||||
hookManager.AddHook(testHook)
|
||||
|
||||
opt := &Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &net.TCPConn{}, nil // Mock connection
|
||||
},
|
||||
PoolSize: 1,
|
||||
DialTimeout: time.Second,
|
||||
}
|
||||
|
||||
pool := NewConnPool(opt)
|
||||
defer pool.Close()
|
||||
|
||||
// Add hook to pool after creation
|
||||
pool.AddPoolHook(testHook)
|
||||
|
||||
// Verify hooks are initialized
|
||||
if pool.hookManager == nil {
|
||||
t.Error("Expected hookManager to be initialized")
|
||||
}
|
||||
|
||||
if pool.hookManager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test adding hook to pool
|
||||
additionalHook := &TestHook{ShouldPool: true}
|
||||
pool.AddPoolHook(additionalHook)
|
||||
|
||||
if pool.hookManager.GetHookCount() != 2 {
|
||||
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test removing hook from pool
|
||||
pool.RemovePoolHook(additionalHook)
|
||||
|
||||
if pool.hookManager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount())
|
||||
}
|
||||
}
|
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
)
|
||||
|
||||
var (
@@ -22,6 +23,23 @@

// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
ErrPoolTimeout = errors.New("redis: connection pool timeout")

// popAttempts is the maximum number of attempts to find a usable connection
// when popping from the idle connection pool. This handles cases where connections
// are temporarily marked as unusable (e.g., during hitless upgrades or network issues).
// A value of 50 provides sufficient resilience without excessive overhead.
// It is capped by the idle connection count, so we won't loop excessively.
popAttempts = 50

// getAttempts is the maximum number of attempts to get a connection that passes
// hook validation (e.g., hitless upgrade hooks). This protects against race conditions
// where hooks might temporarily reject connections during cluster transitions.
// A value of 3 balances resilience with performance; most hook rejections resolve quickly.
getAttempts = 3

minTime      = time.Unix(-2208988800, 0) // Jan 1, 1900
maxTime      = minTime.Add(1<<63 - 1)
noExpiration = maxTime
)
|
||||
|
||||
var timers = sync.Pool{
|
||||
@@ -38,11 +56,14 @@ type Stats struct {
|
||||
Misses uint32 // number of times free connection was NOT found in the pool
|
||||
Timeouts uint32 // number of times a wait timeout occurred
|
||||
WaitCount uint32 // number of times a connection was waited
|
||||
Unusable uint32 // number of times a connection was found to be unusable
|
||||
WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds
|
||||
|
||||
TotalConns uint32 // number of total connections in the pool
|
||||
IdleConns uint32 // number of idle connections in the pool
|
||||
StaleConns uint32 // number of stale connections removed from the pool
|
||||
|
||||
PubSubStats PubSubStats
|
||||
}
|
||||
|
||||
type Pooler interface {
|
||||
@@ -57,29 +78,35 @@ type Pooler interface {
|
||||
IdleLen() int
|
||||
Stats() *Stats
|
||||
|
||||
AddPoolHook(hook PoolHook)
|
||||
RemovePoolHook(hook PoolHook)
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Dialer func(context.Context) (net.Conn, error)
|
||||
|
||||
PoolFIFO bool
|
||||
PoolSize int
|
||||
DialTimeout time.Duration
|
||||
PoolTimeout time.Duration
|
||||
MinIdleConns int
|
||||
MaxIdleConns int
|
||||
MaxActiveConns int
|
||||
ConnMaxIdleTime time.Duration
|
||||
ConnMaxLifetime time.Duration
|
||||
|
||||
|
||||
// Protocol version for optimization (3 = RESP3 with push notifications, 2 = RESP2 without)
|
||||
Protocol int
|
||||
|
||||
ReadBufferSize int
|
||||
WriteBufferSize int
|
||||
|
||||
PoolFIFO bool
|
||||
PoolSize int32
|
||||
DialTimeout time.Duration
|
||||
PoolTimeout time.Duration
|
||||
MinIdleConns int32
|
||||
MaxIdleConns int32
|
||||
MaxActiveConns int32
|
||||
ConnMaxIdleTime time.Duration
|
||||
ConnMaxLifetime time.Duration
|
||||
PushNotificationsEnabled bool
|
||||
|
||||
// DialerRetries is the maximum number of retry attempts when dialing fails.
|
||||
// Default: 5
|
||||
DialerRetries int
|
||||
|
||||
// DialerRetryTimeout is the backoff duration between retry attempts.
|
||||
// Default: 100ms
|
||||
DialerRetryTimeout time.Duration
|
||||
}
|
||||
|
||||
type lastDialErrorWrap struct {
|
||||
@@ -95,16 +122,21 @@ type ConnPool struct {
|
||||
queue chan struct{}
|
||||
|
||||
connsMu sync.Mutex
|
||||
conns []*Conn
|
||||
conns map[uint64]*Conn
|
||||
idleConns []*Conn
|
||||
|
||||
poolSize int
|
||||
idleConnsLen int
|
||||
poolSize atomic.Int32
|
||||
idleConnsLen atomic.Int32
|
||||
idleCheckInProgress atomic.Bool
|
||||
|
||||
stats Stats
|
||||
waitDurationNs atomic.Int64
|
||||
|
||||
_closed uint32 // atomic
|
||||
|
||||
// Pool hooks manager for flexible connection processing
|
||||
hookManagerMu sync.RWMutex
|
||||
hookManager *PoolHookManager
|
||||
}
|
||||
|
||||
var _ Pooler = (*ConnPool)(nil)
|
||||
@@ -114,34 +146,69 @@ func NewConnPool(opt *Options) *ConnPool {
|
||||
cfg: opt,
|
||||
|
||||
queue: make(chan struct{}, opt.PoolSize),
|
||||
conns: make([]*Conn, 0, opt.PoolSize),
|
||||
conns: make(map[uint64]*Conn),
|
||||
idleConns: make([]*Conn, 0, opt.PoolSize),
|
||||
}
|
||||
|
||||
// Only create MinIdleConns if explicitly requested (> 0)
|
||||
// This avoids creating connections during pool initialization for tests
|
||||
if opt.MinIdleConns > 0 {
|
||||
p.connsMu.Lock()
|
||||
p.checkMinIdleConns()
|
||||
p.connsMu.Unlock()
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// initializeHooks sets up the pool hooks system.
|
||||
func (p *ConnPool) initializeHooks() {
|
||||
p.hookManager = NewPoolHookManager()
|
||||
}
|
||||
|
||||
// AddPoolHook adds a pool hook to the pool.
|
||||
func (p *ConnPool) AddPoolHook(hook PoolHook) {
|
||||
p.hookManagerMu.Lock()
|
||||
defer p.hookManagerMu.Unlock()
|
||||
|
||||
if p.hookManager == nil {
|
||||
p.initializeHooks()
|
||||
}
|
||||
p.hookManager.AddHook(hook)
|
||||
}
|
||||
|
||||
// RemovePoolHook removes a pool hook from the pool.
|
||||
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
|
||||
p.hookManagerMu.Lock()
|
||||
defer p.hookManagerMu.Unlock()
|
||||
|
||||
if p.hookManager != nil {
|
||||
p.hookManager.RemoveHook(hook)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ConnPool) checkMinIdleConns() {
|
||||
if !p.idleCheckInProgress.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
defer p.idleCheckInProgress.Store(false)
|
||||
|
||||
if p.cfg.MinIdleConns == 0 {
|
||||
return
|
||||
}
|
||||
for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
|
||||
|
||||
// Only create idle connections if we haven't reached the total pool size limit
|
||||
// MinIdleConns should be a subset of PoolSize, not additional connections
|
||||
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
|
||||
select {
|
||||
case p.queue <- struct{}{}:
|
||||
p.poolSize++
|
||||
p.idleConnsLen++
|
||||
|
||||
p.poolSize.Add(1)
|
||||
p.idleConnsLen.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
p.connsMu.Lock()
|
||||
p.poolSize--
|
||||
p.idleConnsLen--
|
||||
p.connsMu.Unlock()
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
|
||||
p.freeTurn()
|
||||
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
|
||||
@@ -150,12 +217,9 @@ func (p *ConnPool) checkMinIdleConns() {
|
||||
|
||||
err := p.addIdleConn()
|
||||
if err != nil && err != ErrClosed {
|
||||
p.connsMu.Lock()
|
||||
p.poolSize--
|
||||
p.idleConnsLen--
|
||||
p.connsMu.Unlock()
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
}
|
||||
|
||||
p.freeTurn()
|
||||
}()
|
||||
default:
|
||||
@@ -172,6 +236,9 @@ func (p *ConnPool) addIdleConn() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
@@ -182,11 +249,15 @@ func (p *ConnPool) addIdleConn() error {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
p.conns = append(p.conns, cn)
|
||||
p.conns[cn.GetID()] = cn
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewConn creates a new connection and returns it to the user.
|
||||
// This will still obey MaxActiveConns but will not include it in the pool and won't increase the pool size.
|
||||
//
|
||||
// NOTE: If you directly get a connection from the pool, it won't be pooled and won't support hitless upgrades.
|
||||
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
|
||||
return p.newConn(ctx, false)
|
||||
}
|
||||
@@ -196,33 +267,44 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
p.connsMu.Lock()
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
|
||||
p.connsMu.Unlock()
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
p.connsMu.Unlock()
|
||||
|
||||
cn, err := p.dialConn(ctx, pooled)
|
||||
dialCtx, cancel := context.WithTimeout(ctx, p.cfg.DialTimeout)
|
||||
defer cancel()
|
||||
cn, err := p.dialConn(dialCtx, pooled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
|
||||
_ = cn.Close()
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
|
||||
p.conns = append(p.conns, cn)
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
if p.closed() {
|
||||
_ = cn.Close()
|
||||
return nil, ErrClosed
|
||||
}
|
||||
// Check if pool was closed while we were waiting for the lock
|
||||
if p.conns == nil {
|
||||
p.conns = make(map[uint64]*Conn)
|
||||
}
|
||||
p.conns[cn.GetID()] = cn
|
||||
|
||||
if pooled {
|
||||
// If pool is full remove the cn on next Put.
|
||||
if p.poolSize >= p.cfg.PoolSize {
|
||||
currentPoolSize := p.poolSize.Load()
|
||||
if currentPoolSize >= p.cfg.PoolSize {
|
||||
cn.pooled = false
|
||||
} else {
|
||||
p.poolSize++
|
||||
p.poolSize.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,18 +320,57 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, p.getLastDialError()
}

// Retry dialing with backoff.
// The dial timeout is already enforced by the context passed in, so the loop
// may stop before maxRetries is reached; larger values do no harm.
maxRetries := p.cfg.DialerRetries
if maxRetries <= 0 {
maxRetries = 5 // Default value
}
backoffDuration := p.cfg.DialerRetryTimeout
if backoffDuration <= 0 {
backoffDuration = 100 * time.Millisecond // Default value
}

var lastErr error
shouldLoop := true
// When the context deadline is reached we stop retrying but keep lastErr, so the
// caller sees the real dial error instead of a generic context deadline exceeded error.
for attempt := 0; (attempt < maxRetries) && shouldLoop; attempt++ {
netConn, err := p.cfg.Dialer(ctx)
if err != nil {
p.setLastDialError(err)
lastErr = err
// Back off before the next retry attempt (at least one dial attempt is always made).
select {
case <-ctx.Done():
shouldLoop = false
case <-time.After(backoffDuration):
// Continue with retry
}
continue
}

// Success - create the connection
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
cn.pooled = pooled
if p.cfg.ConnMaxLifetime > 0 {
cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime)
} else {
cn.expiresAt = noExpiration
}

return cn, nil
}

internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr)
// All retries failed - handle error tracking
p.setLastDialError(lastErr)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
go p.tryDial()
}
return nil, err

cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
cn.pooled = pooled
return cn, nil
return nil, lastErr
}
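The retry knobs above come from the pool options introduced earlier in this diff. A minimal sketch of a pool configured to use them (internal/pool is not importable by user code, so this is illustrative only; the dialer and address are placeholders):

package example

import (
	"context"
	"net"
	"time"

	"github.com/redis/go-redis/v9/internal/pool"
)

func newRetryingPool() *pool.ConnPool {
	return pool.NewConnPool(&pool.Options{
		Dialer: func(ctx context.Context) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, "tcp", "localhost:6379")
		},
		PoolSize:           int32(10),
		PoolTimeout:        time.Second,
		DialTimeout:        time.Second,
		DialerRetries:      3,                     // up to 3 dial attempts
		DialerRetryTimeout: 50 * time.Millisecond, // fixed backoff between attempts
	})
}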
|
||||
|
||||
func (p *ConnPool) tryDial() {
|
||||
@@ -289,6 +410,14 @@ func (p *ConnPool) getLastDialError() error {
|
||||
|
||||
// Get returns existed connection from the pool or creates a new one.
|
||||
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
return p.getConn(ctx)
|
||||
}
|
||||
|
||||
// getConn returns a connection from the pool.
|
||||
func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
var cn *Conn
|
||||
var err error
|
||||
|
||||
if p.closed() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
@@ -297,9 +426,17 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
attempts := 0
|
||||
for {
|
||||
if attempts >= getAttempts {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts)
|
||||
break
|
||||
}
|
||||
attempts++
|
||||
|
||||
p.connsMu.Lock()
|
||||
cn, err := p.popIdle()
|
||||
cn, err = p.popIdle()
|
||||
p.connsMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
@@ -311,11 +448,25 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
break
|
||||
}
|
||||
|
||||
if !p.isHealthyConn(cn) {
|
||||
if !p.isHealthyConn(cn, now) {
|
||||
_ = p.CloseConn(cn)
|
||||
continue
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
|
||||
// Failed to process connection, discard it
|
||||
_ = p.CloseConn(cn)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddUint32(&p.stats.Hits, 1)
|
||||
return cn, nil
|
||||
}
|
||||
@@ -328,6 +479,19 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
|
||||
// Failed to process connection, discard it
|
||||
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err)
|
||||
_ = p.CloseConn(newcn)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return newcn, nil
|
||||
}
|
||||
|
||||
@@ -356,7 +520,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error {
|
||||
}
|
||||
return ctx.Err()
|
||||
case p.queue <- struct{}{}:
|
||||
p.waitDurationNs.Add(time.Since(start).Nanoseconds())
|
||||
p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
|
||||
atomic.AddUint32(&p.stats.WaitCount, 1)
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
@@ -376,68 +540,130 @@ func (p *ConnPool) popIdle() (*Conn, error) {
|
||||
if p.closed() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
defer p.checkMinIdleConns()
|
||||
|
||||
n := len(p.idleConns)
|
||||
if n == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var cn *Conn
|
||||
attempts := 0
|
||||
|
||||
maxAttempts := util.Min(popAttempts, n)
|
||||
for attempts < maxAttempts {
|
||||
if len(p.idleConns) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if p.cfg.PoolFIFO {
|
||||
cn = p.idleConns[0]
|
||||
copy(p.idleConns, p.idleConns[1:])
|
||||
p.idleConns = p.idleConns[:n-1]
|
||||
p.idleConns = p.idleConns[:len(p.idleConns)-1]
|
||||
} else {
|
||||
idx := n - 1
|
||||
idx := len(p.idleConns) - 1
|
||||
cn = p.idleConns[idx]
|
||||
p.idleConns = p.idleConns[:idx]
|
||||
}
|
||||
p.idleConnsLen--
|
||||
p.checkMinIdleConns()
|
||||
attempts++
|
||||
|
||||
if cn.IsUsable() {
|
||||
p.idleConnsLen.Add(-1)
|
||||
break
|
||||
}
|
||||
|
||||
// Connection is not usable, put it back in the pool
|
||||
if p.cfg.PoolFIFO {
|
||||
// FIFO: put at end (will be picked up last since we pop from front)
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
} else {
|
||||
// LIFO: put at beginning (will be picked up last since we pop from end)
|
||||
p.idleConns = append([]*Conn{cn}, p.idleConns...)
|
||||
}
|
||||
cn = nil
|
||||
}
|
||||
|
||||
// If we exhausted all attempts without finding a usable connection, return nil
|
||||
if attempts > 1 && attempts >= maxAttempts && int32(attempts) >= p.poolSize.Load() {
|
||||
internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
||||
// Process connection using the hooks system
|
||||
shouldPool := true
|
||||
shouldRemove := false
|
||||
if cn.rd.Buffered() > 0 {
|
||||
// Check if this might be push notification data
|
||||
if p.cfg.Protocol == 3 {
|
||||
// we know that there is something in the buffer, so peek at the next reply type without
|
||||
// the potential to block and check if it's a push notification
|
||||
if replyType, err := cn.rd.PeekReplyType(); err != nil || replyType != proto.RespPush {
|
||||
shouldRemove = true
|
||||
var err error
|
||||
|
||||
if cn.HasBufferedData() {
|
||||
// Peek at the reply type to check if it's a push notification
|
||||
if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush {
|
||||
// Not a push notification or error peeking, remove connection
|
||||
internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it")
|
||||
p.Remove(ctx, cn, err)
|
||||
}
|
||||
} else {
|
||||
// not a push notification since protocol 2 doesn't support them
|
||||
shouldRemove = true
|
||||
// It's a push notification, allow pooling (client will handle it)
|
||||
}
|
||||
|
||||
if shouldRemove {
|
||||
// For non-RESP3 or data that is not a push notification, buffered data is unexpected
|
||||
internal.Logger.Printf(ctx, "Conn has unread data, closing it")
|
||||
p.Remove(ctx, cn, BadConnError{})
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "Connection hook error: %v", err)
|
||||
p.Remove(ctx, cn, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If hooks say to remove the connection, do so
|
||||
if shouldRemove {
|
||||
p.Remove(ctx, cn, errors.New("hook requested removal"))
|
||||
return
|
||||
}
|
||||
|
||||
// If processor says not to pool the connection, remove it
|
||||
if !shouldPool {
|
||||
p.Remove(ctx, cn, errors.New("hook requested no pooling"))
|
||||
return
|
||||
}
|
||||
|
||||
if !cn.pooled {
|
||||
p.Remove(ctx, cn, nil)
|
||||
p.Remove(ctx, cn, errors.New("connection not pooled"))
|
||||
return
|
||||
}
|
||||
|
||||
var shouldCloseConn bool
|
||||
|
||||
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
|
||||
// unusable conns are expected to become usable at some point (background process is reconnecting them)
|
||||
// put them at the opposite end of the queue
|
||||
if !cn.IsUsable() {
|
||||
if p.cfg.PoolFIFO {
|
||||
p.connsMu.Lock()
|
||||
|
||||
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns {
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
p.idleConnsLen++
|
||||
p.connsMu.Unlock()
|
||||
} else {
|
||||
p.removeConn(cn)
|
||||
p.connsMu.Lock()
|
||||
p.idleConns = append([]*Conn{cn}, p.idleConns...)
|
||||
p.connsMu.Unlock()
|
||||
}
|
||||
} else {
|
||||
p.connsMu.Lock()
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
p.connsMu.Unlock()
|
||||
}
|
||||
p.idleConnsLen.Add(1)
|
||||
} else {
|
||||
p.removeConnWithLock(cn)
|
||||
shouldCloseConn = true
|
||||
}
|
||||
|
||||
p.connsMu.Unlock()
|
||||
|
||||
p.freeTurn()
|
||||
|
||||
if shouldCloseConn {
|
||||
@@ -447,8 +673,13 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
||||
|
||||
func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
|
||||
p.removeConnWithLock(cn)
|
||||
|
||||
p.freeTurn()
|
||||
|
||||
_ = p.closeConn(cn)
|
||||
|
||||
// Check if we need to create new idle connections to maintain MinIdleConns
|
||||
p.checkMinIdleConns()
|
||||
}
|
||||
|
||||
func (p *ConnPool) CloseConn(cn *Conn) error {
|
||||
@@ -463,17 +694,23 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) {
|
||||
}
|
||||
|
||||
func (p *ConnPool) removeConn(cn *Conn) {
|
||||
for i, c := range p.conns {
|
||||
if c == cn {
|
||||
p.conns = append(p.conns[:i], p.conns[i+1:]...)
|
||||
cid := cn.GetID()
|
||||
delete(p.conns, cid)
|
||||
atomic.AddUint32(&p.stats.StaleConns, 1)
|
||||
|
||||
// Decrement pool size counter when removing a connection
|
||||
if cn.pooled {
|
||||
p.poolSize--
|
||||
p.checkMinIdleConns()
|
||||
}
|
||||
p.poolSize.Add(-1)
|
||||
// this can be idle conn
|
||||
for idx, ic := range p.idleConns {
|
||||
if ic.GetID() == cid {
|
||||
internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid)
|
||||
p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...)
|
||||
p.idleConnsLen.Add(-1)
|
||||
break
|
||||
}
|
||||
}
|
||||
atomic.AddUint32(&p.stats.StaleConns, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ConnPool) closeConn(cn *Conn) error {
|
||||
@@ -491,9 +728,9 @@ func (p *ConnPool) Len() int {
|
||||
// IdleLen returns number of idle connections.
|
||||
func (p *ConnPool) IdleLen() int {
|
||||
p.connsMu.Lock()
|
||||
n := p.idleConnsLen
|
||||
n := p.idleConnsLen.Load()
|
||||
p.connsMu.Unlock()
|
||||
return n
|
||||
return int(n)
|
||||
}
|
||||
|
||||
func (p *ConnPool) Stats() *Stats {
|
||||
@@ -502,6 +739,7 @@ func (p *ConnPool) Stats() *Stats {
|
||||
Misses: atomic.LoadUint32(&p.stats.Misses),
|
||||
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
|
||||
WaitCount: atomic.LoadUint32(&p.stats.WaitCount),
|
||||
Unusable: atomic.LoadUint32(&p.stats.Unusable),
|
||||
WaitDurationNs: p.waitDurationNs.Load(),
|
||||
|
||||
TotalConns: uint32(p.Len()),
|
||||
@@ -542,30 +780,33 @@ func (p *ConnPool) Close() error {
|
||||
}
|
||||
}
|
||||
p.conns = nil
|
||||
p.poolSize = 0
|
||||
p.poolSize.Store(0)
|
||||
p.idleConns = nil
|
||||
p.idleConnsLen = 0
|
||||
p.idleConnsLen.Store(0)
|
||||
p.connsMu.Unlock()
|
||||
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (p *ConnPool) isHealthyConn(cn *Conn) bool {
|
||||
now := time.Now()
|
||||
|
||||
if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
|
||||
func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
|
||||
// slight optimization, check expiresAt first.
|
||||
if cn.expiresAt.Before(now) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if connection has exceeded idle timeout
|
||||
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check connection health, but be aware of push notifications
|
||||
if err := connCheck(cn.netConn); err != nil {
|
||||
cn.SetUsedAt(now)
|
||||
// Check basic connection health
|
||||
// Use GetNetConn() to safely access netConn and avoid data races
|
||||
if err := connCheck(cn.getNetConn()); err != nil {
|
||||
// If there's unexpected data, it might be push notifications (RESP3)
|
||||
// However, push notification processing is now handled by the client
|
||||
// before WithReader to ensure proper context is available to handlers
|
||||
if err == errUnexpectedRead && p.cfg.Protocol == 3 {
|
||||
if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
|
||||
// we know that there is something in the buffer, so peek at the next reply type without
|
||||
// the potential to block
|
||||
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
|
||||
@@ -579,7 +820,5 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
cn.SetUsedAt(now)
|
||||
return true
|
||||
}
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package pool
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type SingleConnPool struct {
|
||||
pool Pooler
|
||||
@@ -56,3 +58,7 @@ func (p *SingleConnPool) IdleLen() int {
|
||||
func (p *SingleConnPool) Stats() *Stats {
|
||||
return &Stats{}
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) AddPoolHook(hook PoolHook) {}
|
||||
|
||||
func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {}
|
||||
|
@@ -199,3 +199,7 @@ func (p *StickyConnPool) IdleLen() int {
|
||||
func (p *StickyConnPool) Stats() *Stats {
|
||||
return &Stats{}
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}
|
||||
|
||||
func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}
|
||||
|
@@ -2,15 +2,17 @@ package pool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/bsm/ginkgo/v2"
|
||||
. "github.com/bsm/gomega"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
)
|
||||
|
||||
var _ = Describe("ConnPool", func() {
|
||||
@@ -20,7 +22,7 @@ var _ = Describe("ConnPool", func() {
|
||||
BeforeEach(func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 10,
|
||||
PoolSize: int32(10),
|
||||
PoolTimeout: time.Hour,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Millisecond,
|
||||
@@ -45,11 +47,11 @@ var _ = Describe("ConnPool", func() {
|
||||
<-closedChan
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 10,
|
||||
PoolSize: int32(10),
|
||||
PoolTimeout: time.Hour,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Millisecond,
|
||||
MinIdleConns: minIdleConns,
|
||||
MinIdleConns: int32(minIdleConns),
|
||||
})
|
||||
wg.Wait()
|
||||
Expect(connPool.Close()).NotTo(HaveOccurred())
|
||||
@@ -105,7 +107,7 @@ var _ = Describe("ConnPool", func() {
|
||||
// ok
|
||||
}
|
||||
|
||||
connPool.Remove(ctx, cn, nil)
|
||||
connPool.Remove(ctx, cn, errors.New("test"))
|
||||
|
||||
// Check that Get is unblocked.
|
||||
select {
|
||||
@@ -130,8 +132,8 @@ var _ = Describe("MinIdleConns", func() {
|
||||
newConnPool := func() *pool.ConnPool {
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: poolSize,
|
||||
MinIdleConns: minIdleConns,
|
||||
PoolSize: int32(poolSize),
|
||||
MinIdleConns: int32(minIdleConns),
|
||||
PoolTimeout: 100 * time.Millisecond,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: -1,
|
||||
@@ -168,7 +170,7 @@ var _ = Describe("MinIdleConns", func() {
|
||||
|
||||
Context("after Remove", func() {
|
||||
BeforeEach(func() {
|
||||
connPool.Remove(ctx, cn, nil)
|
||||
connPool.Remove(ctx, cn, errors.New("test"))
|
||||
})
|
||||
|
||||
It("has idle connections", func() {
|
||||
@@ -245,7 +247,7 @@ var _ = Describe("MinIdleConns", func() {
|
||||
BeforeEach(func() {
|
||||
perform(len(cns), func(i int) {
|
||||
mu.RLock()
|
||||
connPool.Remove(ctx, cns[i], nil)
|
||||
connPool.Remove(ctx, cns[i], errors.New("test"))
|
||||
mu.RUnlock()
|
||||
})
|
||||
|
||||
@@ -309,7 +311,7 @@ var _ = Describe("race", func() {
|
||||
It("does not happen on Get, Put, and Remove", func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 10,
|
||||
PoolSize: int32(10),
|
||||
PoolTimeout: time.Minute,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Millisecond,
|
||||
@@ -328,7 +330,7 @@ var _ = Describe("race", func() {
|
||||
cn, err := connPool.Get(ctx)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
if err == nil {
|
||||
connPool.Remove(ctx, cn, nil)
|
||||
connPool.Remove(ctx, cn, errors.New("test"))
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -339,15 +341,15 @@ var _ = Describe("race", func() {
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 1000,
|
||||
MinIdleConns: 50,
|
||||
PoolSize: int32(1000),
|
||||
MinIdleConns: int32(50),
|
||||
PoolTimeout: 3 * time.Second,
|
||||
DialTimeout: 1 * time.Second,
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < opt.PoolSize; i++ {
|
||||
for i := int32(0); i < opt.PoolSize; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
@@ -366,8 +368,8 @@ var _ = Describe("race", func() {
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
panic("test panic")
|
||||
},
|
||||
PoolSize: 100,
|
||||
MinIdleConns: 30,
|
||||
PoolSize: int32(100),
|
||||
MinIdleConns: int32(30),
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
|
||||
@@ -384,7 +386,7 @@ var _ = Describe("race", func() {
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: 3 * time.Second,
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
@@ -415,7 +417,7 @@ var _ = Describe("race", func() {
|
||||
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: testPoolTimeout,
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
@@ -435,3 +437,73 @@ var _ = Describe("race", func() {
|
||||
Expect(stats.Timeouts).To(Equal(uint32(1)))
|
||||
})
|
||||
})
|
||||
|
||||
// TestDialerRetryConfiguration tests the new DialerRetries and DialerRetryTimeout options
|
||||
func TestDialerRetryConfiguration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CustomDialerRetries", func(t *testing.T) {
|
||||
var attempts int64
|
||||
failingDialer := func(ctx context.Context) (net.Conn, error) {
|
||||
atomic.AddInt64(&attempts, 1)
|
||||
return nil, errors.New("dial failed")
|
||||
}
|
||||
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: failingDialer,
|
||||
PoolSize: 1,
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: time.Second,
|
||||
DialerRetries: 3, // Custom retry count
|
||||
DialerRetryTimeout: 10 * time.Millisecond, // Fast retries for testing
|
||||
})
|
||||
defer connPool.Close()
|
||||
|
||||
_, err := connPool.Get(ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected error from failing dialer")
|
||||
}
|
||||
|
||||
// Should have attempted at least 3 times (DialerRetries = 3)
|
||||
// There might be additional attempts due to pool logic
|
||||
finalAttempts := atomic.LoadInt64(&attempts)
|
||||
if finalAttempts < 3 {
|
||||
t.Errorf("Expected at least 3 dial attempts, got %d", finalAttempts)
|
||||
}
|
||||
if finalAttempts > 6 {
|
||||
t.Errorf("Expected around 3 dial attempts, got %d (too many)", finalAttempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DefaultDialerRetries", func(t *testing.T) {
|
||||
var attempts int64
|
||||
failingDialer := func(ctx context.Context) (net.Conn, error) {
|
||||
atomic.AddInt64(&attempts, 1)
|
||||
return nil, errors.New("dial failed")
|
||||
}
|
||||
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: failingDialer,
|
||||
PoolSize: 1,
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: time.Second,
|
||||
// DialerRetries and DialerRetryTimeout not set - should use defaults
|
||||
})
|
||||
defer connPool.Close()
|
||||
|
||||
_, err := connPool.Get(ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected error from failing dialer")
|
||||
}
|
||||
|
||||
// Should have attempted 5 times (default DialerRetries = 5)
|
||||
finalAttempts := atomic.LoadInt64(&attempts)
|
||||
if finalAttempts != 5 {
|
||||
t.Errorf("Expected 5 dial attempts (default), got %d", finalAttempts)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func init() {
|
||||
logging.Disable()
|
||||
}
|
||||
|
78
internal/pool/pubsub.go
Normal file
@@ -0,0 +1,78 @@
package pool

import (
	"context"
	"net"
	"sync"
	"sync/atomic"
)

type PubSubStats struct {
	Created   uint32
	Untracked uint32
	Active    uint32
}

// PubSubPool manages a pool of PubSub connections.
type PubSubPool struct {
	opt       *Options
	netDialer func(ctx context.Context, network, addr string) (net.Conn, error)

	// Map to track active PubSub connections
	activeConns sync.Map // map[uint64]*Conn (connID -> conn)
	closed      atomic.Bool
	stats       PubSubStats
}

func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
	return &PubSubPool{
		opt:       opt,
		netDialer: netDialer,
	}
}

func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
	if p.closed.Load() {
		return nil, ErrClosed
	}

	netConn, err := p.netDialer(ctx, network, addr)
	if err != nil {
		return nil, err
	}
	cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
	cn.pubsub = true
	atomic.AddUint32(&p.stats.Created, 1)
	return cn, nil
}

func (p *PubSubPool) TrackConn(cn *Conn) {
	atomic.AddUint32(&p.stats.Active, 1)
	p.activeConns.Store(cn.GetID(), cn)
}

func (p *PubSubPool) UntrackConn(cn *Conn) {
	atomic.AddUint32(&p.stats.Active, ^uint32(0))
	atomic.AddUint32(&p.stats.Untracked, 1)
	p.activeConns.Delete(cn.GetID())
}

func (p *PubSubPool) Close() error {
	p.closed.Store(true)
	p.activeConns.Range(func(key, value interface{}) bool {
		cn := value.(*Conn)
		_ = cn.Close()
		return true
	})
	return nil
}

func (p *PubSubPool) Stats() *PubSubStats {
	// load stats atomically
	return &PubSubStats{
		Created:   atomic.LoadUint32(&p.stats.Created),
		Untracked: atomic.LoadUint32(&p.stats.Untracked),
		Active:    atomic.LoadUint32(&p.stats.Active),
	}
}
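A rough sketch of the intended create/track/untrack lifecycle, assuming only the methods defined above (the dialer and address are placeholders, and internal/pool is not importable by user code):

package example

import (
	"context"
	"net"

	"github.com/redis/go-redis/v9/internal/pool"
)

// usePubSubPool shows the lifecycle of a tracked PubSub connection.
func usePubSubPool(ctx context.Context, opt *pool.Options) error {
	dialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, network, addr)
	}

	pp := pool.NewPubSubPool(opt, dialer)
	defer pp.Close() // closes any connections still being tracked

	cn, err := pp.NewConn(ctx, "tcp", "localhost:6379", nil)
	if err != nil {
		return err
	}
	pp.TrackConn(cn)         // counted in Stats().Active
	defer pp.UntrackConn(cn) // removed from tracking when the subscriber is done

	_ = pp.Stats() // Created / Untracked / Active counters
	return nil
}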
|
3
internal/redis.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package internal
|
||||
|
||||
const RedisNull = "null"
|
@@ -28,3 +28,14 @@ func MustParseFloat(s string) float64 {
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// SafeIntToInt32 safely converts an int to int32, returning an error if overflow would occur.
func SafeIntToInt32(value int, fieldName string) (int32, error) {
	if value > math.MaxInt32 {
		return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32)
	}
	if value < math.MinInt32 {
		return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32)
	}
	return int32(value), nil
}
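For reference, this helper is consumed later in this diff when newConnPool converts the public int options into the pool's int32 fields; the sketch below just restates that call pattern (the wrapper function is hypothetical):

package example

import "github.com/redis/go-redis/v9/internal/util"

// toInt32PoolOptions mirrors how newConnPool (later in this diff) converts the
// public int options into the pool's int32 fields, surfacing overflow as an error.
func toInt32PoolOptions(poolSize, minIdleConns int) (ps, mic int32, err error) {
	if ps, err = util.SafeIntToInt32(poolSize, "PoolSize"); err != nil {
		return 0, 0, err
	}
	if mic, err = util.SafeIntToInt32(minIdleConns, "MinIdleConns"); err != nil {
		return 0, 0, err
	}
	return ps, mic, nil
}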
|
||||
|
17
internal/util/math.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package util
|
||||
|
||||
// Max returns the maximum of two integers
|
||||
func Max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Min returns the minimum of two integers
|
||||
func Min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
121
logging/logging.go
Normal file
@@ -0,0 +1,121 @@
|
||||
// Package logging provides logging level constants and utilities for the go-redis library.
|
||||
// This package centralizes logging configuration to ensure consistency across all components.
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
)
|
||||
|
||||
// LogLevel represents the logging level
|
||||
type LogLevel int
|
||||
|
||||
// Log level constants for the entire go-redis library
|
||||
const (
|
||||
LogLevelError LogLevel = iota // 0 - errors only
|
||||
LogLevelWarn // 1 - warnings and errors
|
||||
LogLevelInfo // 2 - info, warnings, and errors
|
||||
LogLevelDebug // 3 - debug, info, warnings, and errors
|
||||
)
|
||||
|
||||
// String returns the string representation of the log level
|
||||
func (l LogLevel) String() string {
|
||||
switch l {
|
||||
case LogLevelError:
|
||||
return "ERROR"
|
||||
case LogLevelWarn:
|
||||
return "WARN"
|
||||
case LogLevelInfo:
|
||||
return "INFO"
|
||||
case LogLevelDebug:
|
||||
return "DEBUG"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// IsValid returns true if the log level is valid
|
||||
func (l LogLevel) IsValid() bool {
|
||||
return l >= LogLevelError && l <= LogLevelDebug
|
||||
}
|
||||
|
||||
func (l LogLevel) WarnOrAbove() bool {
|
||||
return l >= LogLevelWarn
|
||||
}
|
||||
|
||||
func (l LogLevel) InfoOrAbove() bool {
|
||||
return l >= LogLevelInfo
|
||||
}
|
||||
|
||||
func (l LogLevel) DebugOrAbove() bool {
|
||||
return l >= LogLevelDebug
|
||||
}
|
||||
|
||||
// VoidLogger is a logger that does nothing.
|
||||
// Used to disable logging and thus speed up the library.
|
||||
type VoidLogger struct{}
|
||||
|
||||
func (v *VoidLogger) Printf(_ context.Context, _ string, _ ...interface{}) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
// Disable disables logging by setting the internal logger to a void logger.
|
||||
// This can be used to speed up the library if logging is not needed.
|
||||
// It will override any custom logger that was set before and set the VoidLogger.
|
||||
func Disable() {
|
||||
internal.Logger = &VoidLogger{}
|
||||
}
|
||||
|
||||
// Enable enables logging by setting the internal logger to the default logger.
|
||||
// This is the default behavior.
|
||||
// You can use redis.SetLogger to set a custom logger.
|
||||
//
|
||||
// NOTE: This function is not thread-safe.
|
||||
// It will override any custom logger that was set before and set the DefaultLogger.
|
||||
func Enable() {
|
||||
internal.Logger = internal.NewDefaultLogger()
|
||||
}
|
||||
|
||||
// NewBlacklistLogger returns a new logger that filters out messages containing any of the substrings.
|
||||
// This can be used to filter out messages containing sensitive information.
|
||||
func NewBlacklistLogger(substr []string) internal.Logging {
|
||||
l := internal.NewDefaultLogger()
|
||||
return &filterLogger{logger: l, substr: substr, blacklist: true}
|
||||
}
|
||||
|
||||
// NewWhitelistLogger returns a new logger that only logs messages containing any of the substrings.
|
||||
// This can be used to only log messages related to specific commands or patterns.
|
||||
func NewWhitelistLogger(substr []string) internal.Logging {
|
||||
l := internal.NewDefaultLogger()
|
||||
return &filterLogger{logger: l, substr: substr, blacklist: false}
|
||||
}
|
||||
|
||||
type filterLogger struct {
|
||||
logger internal.Logging
|
||||
blacklist bool
|
||||
substr []string
|
||||
}
|
||||
|
||||
func (l *filterLogger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
found := false
|
||||
for _, substr := range l.substr {
|
||||
if strings.Contains(msg, substr) {
|
||||
found = true
|
||||
if l.blacklist {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// whitelist, only log if one of the substrings is present
|
||||
if !l.blacklist && !found {
|
||||
return
|
||||
}
|
||||
if l.logger != nil {
|
||||
l.logger.Printf(ctx, format, v...)
|
||||
return
|
||||
}
|
||||
}
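These filter loggers plug into the client through the existing redis.SetLogger hook; a minimal usage sketch (the filter substrings are examples, not defaults):

package example

import (
	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/logging"
)

func configureLogging() {
	// Silence library logging entirely (installs the VoidLogger).
	logging.Disable()

	// Or keep logging but drop lines that mention sensitive values.
	redis.SetLogger(logging.NewBlacklistLogger([]string{"AUTH", "password"}))

	// Or only log messages related to the connection pool.
	redis.SetLogger(logging.NewWhitelistLogger([]string{"connection pool"}))
}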
|
59
logging/logging_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package logging
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLogLevel_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
level LogLevel
|
||||
expected string
|
||||
}{
|
||||
{LogLevelError, "ERROR"},
|
||||
{LogLevelWarn, "WARN"},
|
||||
{LogLevelInfo, "INFO"},
|
||||
{LogLevelDebug, "DEBUG"},
|
||||
{LogLevel(99), "UNKNOWN"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if got := test.level.String(); got != test.expected {
|
||||
t.Errorf("LogLevel(%d).String() = %q, want %q", test.level, got, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLevel_IsValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
level LogLevel
|
||||
expected bool
|
||||
}{
|
||||
{LogLevelError, true},
|
||||
{LogLevelWarn, true},
|
||||
{LogLevelInfo, true},
|
||||
{LogLevelDebug, true},
|
||||
{LogLevel(-1), false},
|
||||
{LogLevel(4), false},
|
||||
{LogLevel(99), false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if got := test.level.IsValid(); got != test.expected {
|
||||
t.Errorf("LogLevel(%d).IsValid() = %v, want %v", test.level, got, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLevelConstants(t *testing.T) {
|
||||
// Test that constants have expected values
|
||||
if LogLevelError != 0 {
|
||||
t.Errorf("LogLevelError = %d, want 0", LogLevelError)
|
||||
}
|
||||
if LogLevelWarn != 1 {
|
||||
t.Errorf("LogLevelWarn = %d, want 1", LogLevelWarn)
|
||||
}
|
||||
if LogLevelInfo != 2 {
|
||||
t.Errorf("LogLevelInfo = %d, want 2", LogLevelInfo)
|
||||
}
|
||||
if LogLevelDebug != 3 {
|
||||
t.Errorf("LogLevelDebug = %d, want 3", LogLevelDebug)
|
||||
}
|
||||
}
|
@@ -13,6 +13,7 @@ import (
|
||||
. "github.com/bsm/ginkgo/v2"
|
||||
. "github.com/bsm/gomega"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -102,6 +103,7 @@ var _ = BeforeSuite(func() {
|
||||
fmt.Printf("RCEDocker: %v\n", RCEDocker)
|
||||
fmt.Printf("REDIS_VERSION: %.1f\n", RedisVersion)
|
||||
fmt.Printf("CLIENT_LIBS_TEST_IMAGE: %v\n", os.Getenv("CLIENT_LIBS_TEST_IMAGE"))
|
||||
logging.Disable()
|
||||
|
||||
if RedisVersion < 7.0 || RedisVersion > 9 {
|
||||
panic("incorrect or not supported redis version")
|
||||
|
136
options.go
@@ -14,9 +14,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/hitless"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// Limiter is the interface of a rate limiter or a circuit breaker.
|
||||
@@ -107,9 +109,19 @@

// DialTimeout for establishing new connections.
//
// default: 5 seconds
// default: 10 seconds
DialTimeout time.Duration

// DialerRetries is the maximum number of retry attempts when dialing fails.
//
// default: 5
DialerRetries int

// DialerRetryTimeout is the backoff duration between retry attempts.
//
// default: 100 milliseconds
DialerRetryTimeout time.Duration
|
||||
|
||||
// ReadTimeout for socket reads. If reached, commands will fail
|
||||
// with a timeout instead of blocking. Supported values:
|
||||
//
|
||||
@@ -153,6 +165,7 @@ type Options struct {
|
||||
//
|
||||
// Note that FIFO has slightly higher overhead compared to LIFO,
|
||||
// but it helps closing idle connections faster reducing the pool size.
|
||||
// default: false
|
||||
PoolFIFO bool
|
||||
|
||||
// PoolSize is the base number of socket connections.
|
||||
@@ -244,8 +257,19 @@ type Options struct {
|
||||
// When a node is marked as failing, it will be avoided for this duration.
|
||||
// Default is 15 seconds.
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
|
||||
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
|
||||
// cluster upgrade notifications gracefully and manage connection/pool state
|
||||
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
|
||||
HitlessUpgradeConfig *HitlessUpgradeConfig
|
||||
}
|
||||
|
||||
// HitlessUpgradeConfig provides configuration options for hitless upgrades.
|
||||
// This is an alias to hitless.Config for convenience.
|
||||
type HitlessUpgradeConfig = hitless.Config
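// Illustrative sketch, not part of this change: opting in to hitless upgrades
// explicitly on a RESP3 client. Mode, EndpointType and the constants shown are
// the ones referenced elsewhere in this diff; the timeout/retry values simply
// restate the new defaults introduced above.
//
//	rdb := redis.NewClient(&redis.Options{
//		Addr:               "localhost:6379",
//		Protocol:           3, // RESP3 is required for push notifications
//		DialTimeout:        10 * time.Second,       // new default
//		DialerRetries:      5,                      // new default
//		DialerRetryTimeout: 100 * time.Millisecond, // new default
//		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
//			Mode:         hitless.MaintNotificationsEnabled,
//			EndpointType: hitless.EndpointTypeAuto,
//		},
//	})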
|
||||
|
||||
func (opt *Options) init() {
|
||||
if opt.Addr == "" {
|
||||
opt.Addr = "localhost:6379"
|
||||
@@ -261,7 +285,13 @@ func (opt *Options) init() {
|
||||
opt.Protocol = 3
|
||||
}
|
||||
if opt.DialTimeout == 0 {
|
||||
opt.DialTimeout = 5 * time.Second
|
||||
opt.DialTimeout = 10 * time.Second
|
||||
}
|
||||
if opt.DialerRetries == 0 {
|
||||
opt.DialerRetries = 5
|
||||
}
|
||||
if opt.DialerRetryTimeout == 0 {
|
||||
opt.DialerRetryTimeout = 100 * time.Millisecond
|
||||
}
|
||||
if opt.Dialer == nil {
|
||||
opt.Dialer = NewDialer(opt)
|
||||
@@ -320,13 +350,36 @@ func (opt *Options) init() {
|
||||
case 0:
|
||||
opt.MaxRetryBackoff = 512 * time.Millisecond
|
||||
}
|
||||
|
||||
opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns)
|
||||
|
||||
// auto-detect endpoint type if not specified
|
||||
endpointType := opt.HitlessUpgradeConfig.EndpointType
|
||||
if endpointType == "" || endpointType == hitless.EndpointTypeAuto {
|
||||
// Auto-detect endpoint type if not specified
|
||||
endpointType = hitless.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
|
||||
}
|
||||
opt.HitlessUpgradeConfig.EndpointType = endpointType
|
||||
}
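// Illustrative sketch (assumption: DetectEndpointType picks an endpoint type
// from the address and TLS flag, exactly as it is used in init above; what the
// returned value looks like is not shown in this diff):
//
//	et := hitless.DetectEndpointType("localhost:6379", false) // false: no TLS
//	cfg := &redis.HitlessUpgradeConfig{EndpointType: et}
//	_ = cfg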
|
||||
|
||||
func (opt *Options) clone() *Options {
|
||||
clone := *opt
|
||||
|
||||
// Deep clone HitlessUpgradeConfig to avoid sharing between clients
|
||||
if opt.HitlessUpgradeConfig != nil {
|
||||
configClone := *opt.HitlessUpgradeConfig
|
||||
clone.HitlessUpgradeConfig = &configClone
|
||||
}
|
||||
|
||||
return &clone
|
||||
}
|
||||
|
||||
// NewDialer returns a function that will be used as the default dialer
|
||||
// when none is specified in Options.Dialer.
|
||||
func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
|
||||
return NewDialer(opt)
|
||||
}
|
||||
|
||||
// NewDialer returns a function that will be used as the default dialer
|
||||
// when none is specified in Options.Dialer.
|
||||
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
|
||||
@@ -612,23 +665,84 @@ func getUserPassword(u *url.URL) (string, string) {
|
||||
func newConnPool(
|
||||
opt *Options,
|
||||
dialer func(ctx context.Context, network, addr string) (net.Conn, error),
|
||||
) *pool.ConnPool {
|
||||
) (*pool.ConnPool, error) {
|
||||
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pool.NewConnPool(&pool.Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return dialer(ctx, opt.Network, opt.Addr)
|
||||
},
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: opt.PoolSize,
|
||||
PoolSize: poolSize,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
MinIdleConns: opt.MinIdleConns,
|
||||
MaxIdleConns: opt.MaxIdleConns,
|
||||
MaxActiveConns: opt.MaxActiveConns,
|
||||
DialerRetries: opt.DialerRetries,
|
||||
DialerRetryTimeout: opt.DialerRetryTimeout,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxActiveConns: maxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
// Pass protocol version for push notification optimization
|
||||
Protocol: opt.Protocol,
|
||||
ReadBufferSize: opt.ReadBufferSize,
|
||||
WriteBufferSize: opt.WriteBufferSize,
|
||||
})
|
||||
PushNotificationsEnabled: opt.Protocol == 3,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
|
||||
) (*pool.PubSubPool, error) {
|
||||
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pool.NewPubSubPool(&pool.Options{
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: poolSize,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
DialerRetries: opt.DialerRetries,
|
||||
DialerRetryTimeout: opt.DialerRetryTimeout,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxActiveConns: maxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: 32 * 1024,
|
||||
WriteBufferSize: 32 * 1024,
|
||||
PushNotificationsEnabled: opt.Protocol == 3,
|
||||
}, dialer), nil
|
||||
}
|
||||
|
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/internal/rand"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -38,6 +39,7 @@ type ClusterOptions struct {
|
||||
ClientName string
|
||||
|
||||
// NewClient creates a cluster node client with provided name and options.
|
||||
// If NewClient is set by the user, the user is responsible for handling hitless upgrades and push notifications.
|
||||
NewClient func(opt *Options) *Client
|
||||
|
||||
// The maximum number of retries before giving up. Command is retried
|
||||
@@ -125,10 +127,22 @@ type ClusterOptions struct {
|
||||
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
|
||||
UnstableResp3 bool
|
||||
|
||||
// PushNotificationProcessor is the processor for handling push notifications.
|
||||
// If nil, a default processor will be created for RESP3 connections.
|
||||
PushNotificationProcessor push.NotificationProcessor
|
||||
|
||||
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
|
||||
// When a node is marked as failing, it will be avoided for this duration.
|
||||
// Default is 15 seconds.
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
|
||||
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
|
||||
// cluster upgrade notifications gracefully and manage connection/pool state
|
||||
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
|
||||
// The ClusterClient does not handle hitless upgrades directly; the node clients it creates (the Nodes map) are the ones that act on this config.
|
||||
HitlessUpgradeConfig *HitlessUpgradeConfig
|
||||
}
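// Illustrative sketch, not part of this change: a ClusterOptions value carrying
// a hitless config. As noted above, the cluster client only forwards the config;
// the per-node clients created via clientOptions() act on it.
//
//	copt := &redis.ClusterOptions{
//		Addrs:    []string{"localhost:7000"},
//		Protocol: 3, // RESP3 is required for push notifications
//		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
//			Mode: hitless.MaintNotificationsEnabled,
//		},
//	}
//	cc := redis.NewClusterClient(copt)
//	_ = cc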
|
||||
|
||||
func (opt *ClusterOptions) init() {
|
||||
@@ -319,6 +333,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
|
||||
}
|
||||
|
||||
func (opt *ClusterOptions) clientOptions() *Options {
|
||||
// Clone HitlessUpgradeConfig to avoid sharing between cluster node clients
|
||||
var hitlessConfig *HitlessUpgradeConfig
|
||||
if opt.HitlessUpgradeConfig != nil {
|
||||
configClone := *opt.HitlessUpgradeConfig
|
||||
hitlessConfig = &configClone
|
||||
}
|
||||
|
||||
return &Options{
|
||||
ClientName: opt.ClientName,
|
||||
Dialer: opt.Dialer,
|
||||
@@ -362,6 +383,8 @@ func (opt *ClusterOptions) clientOptions() *Options {
|
||||
// situations in the options below will prevent that from happening.
|
||||
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
|
||||
UnstableResp3: opt.UnstableResp3,
|
||||
HitlessUpgradeConfig: hitlessConfig,
|
||||
PushNotificationProcessor: opt.PushNotificationProcessor,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1830,12 +1853,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
|
||||
return err
|
||||
}
|
||||
|
||||
// hitless won't work here for now
|
||||
func (c *ClusterClient) pubSub() *PubSub {
|
||||
var node *clusterNode
|
||||
pubsub := &PubSub{
|
||||
opt: c.opt.clientOptions(),
|
||||
|
||||
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
|
||||
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
|
||||
if node != nil {
|
||||
panic("node != nil")
|
||||
}
|
||||
@@ -1850,18 +1873,25 @@ func (c *ClusterClient) pubSub() *PubSub {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cn, err := node.Client.newConn(context.TODO())
|
||||
cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
|
||||
if err != nil {
|
||||
node = nil
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// will return nil if already initialized
|
||||
err = node.Client.initConn(ctx, cn)
|
||||
if err != nil {
|
||||
_ = cn.Close()
|
||||
node = nil
|
||||
return nil, err
|
||||
}
|
||||
node.Client.pubSubPool.TrackConn(cn)
|
||||
return cn, nil
|
||||
},
|
||||
closeConn: func(cn *pool.Conn) error {
|
||||
err := node.Client.connPool.CloseConn(cn)
|
||||
// Untrack connection from PubSubPool
|
||||
node.Client.pubSubPool.UntrackConn(cn)
|
||||
err := cn.Close()
|
||||
node = nil
|
||||
return err
|
||||
},
|
||||
|
375
pool_pubsub_bench_test.go
Normal file
@@ -0,0 +1,375 @@
|
||||
// Pool and PubSub Benchmark Suite
|
||||
//
|
||||
// This file contains comprehensive benchmarks for both pool operations and PubSub initialization.
|
||||
// It's designed to be run against different branches to compare performance.
|
||||
//
|
||||
// Usage Examples:
|
||||
// # Run all benchmarks
|
||||
// go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go
|
||||
//
|
||||
// # Run only pool benchmarks
|
||||
// go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go
|
||||
//
|
||||
// # Run only PubSub benchmarks
|
||||
// go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go
|
||||
//
|
||||
// # Compare between branches
|
||||
// git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt
|
||||
// git checkout branch2 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch2.txt
|
||||
// benchcmp branch1.txt branch2.txt
|
||||
//
|
||||
// # Run with memory profiling
|
||||
// go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go
|
||||
//
|
||||
// # Run with CPU profiling
|
||||
// go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go
|
||||
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// dummyDialer creates a mock connection for benchmarking
|
||||
func dummyDialer(ctx context.Context) (net.Conn, error) {
|
||||
return &dummyConn{}, nil
|
||||
}
|
||||
|
||||
// dummyConn implements net.Conn for benchmarking
|
||||
type dummyConn struct{}
|
||||
|
||||
func (c *dummyConn) Read(b []byte) (n int, err error) { return len(b), nil }
|
||||
func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil }
|
||||
func (c *dummyConn) Close() error { return nil }
|
||||
func (c *dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} }
|
||||
func (c *dummyConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379}
|
||||
}
|
||||
func (c *dummyConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (c *dummyConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
// =============================================================================
|
||||
// POOL BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations
|
||||
func BenchmarkPoolGetPut(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128}
|
||||
|
||||
for _, poolSize := range poolSizes {
|
||||
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: int32(poolSize),
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: time.Second,
|
||||
ConnMaxIdleTime: time.Hour,
|
||||
MinIdleConns: int32(0), // Start with no idle connections
|
||||
})
|
||||
defer connPool.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cn, err := connPool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
connPool.Put(ctx, cn)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with MinIdleConns
|
||||
func BenchmarkPoolGetPutWithMinIdle(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
configs := []struct {
|
||||
poolSize int
|
||||
minIdleConns int
|
||||
}{
|
||||
{8, 2},
|
||||
{16, 4},
|
||||
{32, 8},
|
||||
{64, 16},
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) {
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: int32(config.poolSize),
|
||||
MinIdleConns: int32(config.minIdleConns),
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: time.Second,
|
||||
ConnMaxIdleTime: time.Hour,
|
||||
})
|
||||
defer connPool.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cn, err := connPool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
connPool.Put(ctx, cn)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency
|
||||
func BenchmarkPoolConcurrentGetPut(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: int32(32),
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: time.Second,
|
||||
ConnMaxIdleTime: time.Hour,
|
||||
MinIdleConns: int32(0),
|
||||
})
|
||||
defer connPool.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
// Test with different levels of concurrency
|
||||
concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64}
|
||||
|
||||
for _, concurrency := range concurrencyLevels {
|
||||
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
|
||||
b.SetParallelism(concurrency)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cn, err := connPool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
connPool.Put(ctx, cn)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// PUBSUB BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
// benchmarkClient creates a Redis client for benchmarking with mock dialer
|
||||
func benchmarkClient(poolSize int) *redis.Client {
|
||||
return redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379", // Mock address
|
||||
DialTimeout: time.Second,
|
||||
ReadTimeout: time.Second,
|
||||
WriteTimeout: time.Second,
|
||||
PoolSize: poolSize,
|
||||
MinIdleConns: 0, // Start with no idle connections for consistent benchmarks
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkPubSubCreation benchmarks PubSub creation and subscription
|
||||
func BenchmarkPubSubCreation(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
poolSizes := []int{1, 4, 8, 16, 32}
|
||||
|
||||
for _, poolSize := range poolSizes {
|
||||
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
|
||||
client := benchmarkClient(poolSize)
|
||||
defer client.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
pubsub := client.Subscribe(ctx, "test-channel")
|
||||
pubsub.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription
|
||||
func BenchmarkPubSubPatternCreation(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
poolSizes := []int{1, 4, 8, 16, 32}
|
||||
|
||||
for _, poolSize := range poolSizes {
|
||||
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
|
||||
client := benchmarkClient(poolSize)
|
||||
defer client.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
pubsub := client.PSubscribe(ctx, "test-*")
|
||||
pubsub.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation
|
||||
func BenchmarkPubSubConcurrentCreation(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
client := benchmarkClient(32)
|
||||
defer client.Close()
|
||||
|
||||
concurrencyLevels := []int{1, 2, 4, 8, 16}
|
||||
|
||||
for _, concurrency := range concurrencyLevels {
|
||||
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, concurrency)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
pubsub := client.Subscribe(ctx, "test-channel")
|
||||
pubsub.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPubSubMultipleChannels benchmarks subscribing to multiple channels
|
||||
func BenchmarkPubSubMultipleChannels(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
client := benchmarkClient(16)
|
||||
defer client.Close()
|
||||
|
||||
channelCounts := []int{1, 5, 10, 25, 50, 100}
|
||||
|
||||
for _, channelCount := range channelCounts {
|
||||
b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) {
|
||||
// Prepare channel names
|
||||
channels := make([]string, channelCount)
|
||||
for i := 0; i < channelCount; i++ {
|
||||
channels[i] = fmt.Sprintf("channel-%d", i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
pubsub := client.Subscribe(ctx, channels...)
|
||||
pubsub.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPubSubReuse benchmarks reusing PubSub connections
|
||||
func BenchmarkPubSubReuse(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
client := benchmarkClient(16)
|
||||
defer client.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Benchmark just the creation and closing of PubSub connections
|
||||
// This simulates reuse patterns without requiring actual Redis operations
|
||||
pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i))
|
||||
pubsub.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// COMBINED BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations
|
||||
func BenchmarkPoolAndPubSubMixed(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
client := benchmarkClient(32)
|
||||
defer client.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
// Mix of pool stats collection and PubSub creation
|
||||
if pb.Next() {
|
||||
// Pool stats operation
|
||||
stats := client.PoolStats()
|
||||
_ = stats.Hits + stats.Misses // Use the stats to prevent optimization
|
||||
}
|
||||
|
||||
if pb.Next() {
|
||||
// PubSub operation
|
||||
pubsub := client.Subscribe(ctx, "test-channel")
|
||||
pubsub.Close()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkPoolStatsCollection benchmarks pool statistics collection
|
||||
func BenchmarkPoolStatsCollection(b *testing.B) {
|
||||
client := benchmarkClient(16)
|
||||
defer client.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
stats := client.PoolStats()
|
||||
_ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPoolHighContention tests pool performance under high contention
|
||||
func BenchmarkPoolHighContention(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
client := benchmarkClient(32)
|
||||
defer client.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
// High contention Get/Put operations
|
||||
pubsub := client.Subscribe(ctx, "test-channel")
|
||||
pubsub.Close()
|
||||
}
|
||||
})
|
||||
}
|
54
pubsub.go
@@ -22,7 +22,7 @@ import (
|
||||
type PubSub struct {
|
||||
opt *Options
|
||||
|
||||
newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
|
||||
newConn func(ctx context.Context, addr string, channels []string) (*pool.Conn, error)
|
||||
closeConn func(*pool.Conn) error
|
||||
|
||||
mu sync.Mutex
|
||||
@@ -42,6 +42,9 @@ type PubSub struct {
|
||||
|
||||
// Push notification processor for handling generic push notifications
|
||||
pushProcessor push.NotificationProcessor
|
||||
|
||||
// Cleanup callback for hitless upgrade tracking
|
||||
onClose func()
|
||||
}
|
||||
|
||||
func (c *PubSub) init() {
|
||||
@@ -73,10 +76,18 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er
|
||||
return c.cn, nil
|
||||
}
|
||||
|
||||
if c.opt.Addr == "" {
|
||||
// TODO(hitless):
// this is probably a cluster client; c.newConn will ignore the addr argument.
// This will change once hitless upgrades support cluster clients.
|
||||
c.opt.Addr = internal.RedisNull
|
||||
}
|
||||
|
||||
channels := mapKeys(c.channels)
|
||||
channels = append(channels, newChannels...)
|
||||
|
||||
cn, err := c.newConn(ctx, channels)
|
||||
cn, err := c.newConn(ctx, c.opt.Addr, channels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -157,12 +168,31 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo
|
||||
if c.cn != cn {
|
||||
return
|
||||
}
|
||||
|
||||
if !cn.IsUsable() || cn.ShouldHandoff() {
|
||||
c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable"))
|
||||
}
|
||||
|
||||
if isBadConn(err, allowTimeout, c.opt.Addr) {
|
||||
c.reconnect(ctx, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PubSub) reconnect(ctx context.Context, reason error) {
|
||||
if c.cn != nil && c.cn.ShouldHandoff() {
|
||||
newEndpoint := c.cn.GetHandoffEndpoint()
|
||||
// If new endpoint is NULL, use the original address
|
||||
if newEndpoint == internal.RedisNull {
|
||||
newEndpoint = c.opt.Addr
|
||||
}
|
||||
|
||||
if newEndpoint != "" {
|
||||
// Update the address in the options
|
||||
oldAddr := c.cn.RemoteAddr().String()
|
||||
c.opt.Addr = newEndpoint
|
||||
internal.Logger.Printf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr)
|
||||
}
|
||||
}
|
||||
_ = c.closeTheCn(reason)
|
||||
_, _ = c.conn(ctx, nil)
|
||||
}
|
||||
@@ -171,9 +201,6 @@ func (c *PubSub) closeTheCn(reason error) error {
|
||||
if c.cn == nil {
|
||||
return nil
|
||||
}
|
||||
if !c.closed {
|
||||
internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
|
||||
}
|
||||
err := c.closeConn(c.cn)
|
||||
c.cn = nil
|
||||
return err
|
||||
@@ -189,6 +216,11 @@ func (c *PubSub) Close() error {
|
||||
c.closed = true
|
||||
close(c.exit)
|
||||
|
||||
// Call cleanup callback if set
|
||||
if c.onClose != nil {
|
||||
c.onClose()
|
||||
}
|
||||
|
||||
return c.closeTheCn(pool.ErrClosed)
|
||||
}
|
||||
|
||||
@@ -444,11 +476,10 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
|
||||
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
// Log the error but don't fail the command execution
|
||||
// Push notification processing errors shouldn't break normal Redis operations
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
internal.Logger.Printf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err)
|
||||
}
|
||||
return c.cmd.readReply(rd)
|
||||
})
|
||||
|
||||
c.releaseConnWithLock(ctx, cn, err, timeout > 0)
|
||||
|
||||
if err != nil {
|
||||
@@ -461,6 +492,12 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
|
||||
// Receive returns a message as a Subscription, Message, Pong or error.
|
||||
// See PubSub example for details. This is low-level API and in most cases
|
||||
// Channel should be used instead.
|
||||
// Receive returns a message as a Subscription, Message, Pong, or an error.
|
||||
// See PubSub example for details. This is a low-level API and in most cases
|
||||
// Channel should be used instead.
|
||||
// This method blocks until a message is received or an error occurs.
|
||||
// It may return early with an error if the context is canceled, the connection fails,
|
||||
// or other internal errors occur.
|
||||
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
|
||||
return c.ReceiveTimeout(ctx, 0)
|
||||
}
|
||||
@@ -543,7 +580,8 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac
|
||||
}
|
||||
|
||||
func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
|
||||
if c.pushProcessor == nil {
|
||||
// Only process push notifications for RESP3 connections with a processor
|
||||
if c.opt.Protocol != 3 || c.pushProcessor == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -113,6 +113,9 @@ var _ = Describe("PubSub", func() {
|
||||
pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2")
|
||||
defer pubsub.Close()
|
||||
|
||||
// sleep a bit to make sure redis knows about the subscriptions
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
channels, err = client.PubSubShardChannels(ctx, "mychannel*").Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"}))
|
||||
|
@@ -1,8 +1,6 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
// No imports needed for this file
|
||||
|
||||
// NotificationHandlerContext provides context information about where a push notification was received.
|
||||
// This struct allows handlers to make informed decisions based on the source of the notification
|
||||
@@ -35,7 +33,11 @@ type NotificationHandlerContext struct {
|
||||
PubSub interface{}
|
||||
|
||||
// Conn is the specific connection on which the notification was received.
|
||||
Conn *pool.Conn
|
||||
// It is an interface both to allow for future expansion and to avoid
|
||||
// circular dependencies. The developer is responsible for type assertion.
|
||||
// It can be one of the following types:
|
||||
// - *pool.Conn
|
||||
Conn interface{}
|
||||
|
||||
// IsBlocking indicates if the notification was received on a blocking connection.
|
||||
IsBlocking bool
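// Illustrative sketch, not part of this change: a handler inside this module
// (e.g. the hitless package, which can import internal/pool) asserting the Conn
// field back to its concrete type. *pool.Conn is the only type listed above,
// and GetID mirrors its use elsewhere in this diff; loggingHandler is a
// hypothetical handler type.
//
//	func (h *loggingHandler) HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error {
//		if cn, ok := handlerCtx.Conn.(*pool.Conn); ok {
//			_ = cn.GetID() // identify the connection the notification arrived on
//		}
//		return nil
//	}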
|
||||
|
315
push/processor_unit_test.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestProcessorCreation tests processor creation and initialization
|
||||
func TestProcessorCreation(t *testing.T) {
|
||||
t.Run("NewProcessor", func(t *testing.T) {
|
||||
processor := NewProcessor()
|
||||
if processor == nil {
|
||||
t.Fatal("NewProcessor should not return nil")
|
||||
}
|
||||
if processor.registry == nil {
|
||||
t.Error("Processor should have a registry")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NewVoidProcessor", func(t *testing.T) {
|
||||
voidProcessor := NewVoidProcessor()
|
||||
if voidProcessor == nil {
|
||||
t.Fatal("NewVoidProcessor should not return nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestProcessorHandlerManagement tests handler registration and retrieval
|
||||
func TestProcessorHandlerManagement(t *testing.T) {
|
||||
processor := NewProcessor()
|
||||
handler := &UnitTestHandler{name: "test-handler"}
|
||||
|
||||
t.Run("RegisterHandler", func(t *testing.T) {
|
||||
err := processor.RegisterHandler("TEST", handler, false)
|
||||
if err != nil {
|
||||
t.Errorf("RegisterHandler should not error: %v", err)
|
||||
}
|
||||
|
||||
// Verify handler is registered
|
||||
retrievedHandler := processor.GetHandler("TEST")
|
||||
if retrievedHandler != handler {
|
||||
t.Error("GetHandler should return the registered handler")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RegisterProtectedHandler", func(t *testing.T) {
|
||||
protectedHandler := &UnitTestHandler{name: "protected-handler"}
|
||||
err := processor.RegisterHandler("PROTECTED", protectedHandler, true)
|
||||
if err != nil {
|
||||
t.Errorf("RegisterHandler should not error for protected handler: %v", err)
|
||||
}
|
||||
|
||||
// Verify handler is registered
|
||||
retrievedHandler := processor.GetHandler("PROTECTED")
|
||||
if retrievedHandler != protectedHandler {
|
||||
t.Error("GetHandler should return the protected handler")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetNonExistentHandler", func(t *testing.T) {
|
||||
handler := processor.GetHandler("NONEXISTENT")
|
||||
if handler != nil {
|
||||
t.Error("GetHandler should return nil for non-existent handler")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UnregisterHandler", func(t *testing.T) {
|
||||
err := processor.UnregisterHandler("TEST")
|
||||
if err != nil {
|
||||
t.Errorf("UnregisterHandler should not error: %v", err)
|
||||
}
|
||||
|
||||
// Verify handler is removed
|
||||
retrievedHandler := processor.GetHandler("TEST")
|
||||
if retrievedHandler != nil {
|
||||
t.Error("GetHandler should return nil after unregistering")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UnregisterProtectedHandler", func(t *testing.T) {
|
||||
err := processor.UnregisterHandler("PROTECTED")
|
||||
if err == nil {
|
||||
t.Error("UnregisterHandler should error for protected handler")
|
||||
}
|
||||
|
||||
// Verify handler is still there
|
||||
retrievedHandler := processor.GetHandler("PROTECTED")
|
||||
if retrievedHandler == nil {
|
||||
t.Error("Protected handler should not be removed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestVoidProcessorBehavior tests void processor behavior
|
||||
func TestVoidProcessorBehavior(t *testing.T) {
|
||||
voidProcessor := NewVoidProcessor()
|
||||
handler := &UnitTestHandler{name: "test-handler"}
|
||||
|
||||
t.Run("GetHandler", func(t *testing.T) {
|
||||
retrievedHandler := voidProcessor.GetHandler("ANY")
|
||||
if retrievedHandler != nil {
|
||||
t.Error("VoidProcessor GetHandler should always return nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RegisterHandler", func(t *testing.T) {
|
||||
err := voidProcessor.RegisterHandler("TEST", handler, false)
|
||||
if err == nil {
|
||||
t.Error("VoidProcessor RegisterHandler should return error")
|
||||
}
|
||||
|
||||
// Check error type
|
||||
if !IsVoidProcessorError(err) {
|
||||
t.Error("Error should be a VoidProcessorError")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UnregisterHandler", func(t *testing.T) {
|
||||
err := voidProcessor.UnregisterHandler("TEST")
|
||||
if err == nil {
|
||||
t.Error("VoidProcessor UnregisterHandler should return error")
|
||||
}
|
||||
|
||||
// Check error type
|
||||
if !IsVoidProcessorError(err) {
|
||||
t.Error("Error should be a VoidProcessorError")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestProcessPendingNotificationsNilReader tests handling of nil reader
|
||||
func TestProcessPendingNotificationsNilReader(t *testing.T) {
|
||||
t.Run("ProcessorWithNilReader", func(t *testing.T) {
|
||||
processor := NewProcessor()
|
||||
ctx := context.Background()
|
||||
handlerCtx := NotificationHandlerContext{}
|
||||
|
||||
err := processor.ProcessPendingNotifications(ctx, handlerCtx, nil)
|
||||
if err != nil {
|
||||
t.Errorf("ProcessPendingNotifications should not error with nil reader: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("VoidProcessorWithNilReader", func(t *testing.T) {
|
||||
voidProcessor := NewVoidProcessor()
|
||||
ctx := context.Background()
|
||||
handlerCtx := NotificationHandlerContext{}
|
||||
|
||||
err := voidProcessor.ProcessPendingNotifications(ctx, handlerCtx, nil)
|
||||
if err != nil {
|
||||
t.Errorf("VoidProcessor ProcessPendingNotifications should not error with nil reader: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestWillHandleNotificationInClient tests the notification filtering logic
|
||||
func TestWillHandleNotificationInClient(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
notificationType string
|
||||
shouldHandle bool
|
||||
}{
|
||||
// Pub/Sub notifications (should be handled in client)
|
||||
{"message", "message", true},
|
||||
{"pmessage", "pmessage", true},
|
||||
{"subscribe", "subscribe", true},
|
||||
{"unsubscribe", "unsubscribe", true},
|
||||
{"psubscribe", "psubscribe", true},
|
||||
{"punsubscribe", "punsubscribe", true},
|
||||
{"smessage", "smessage", true},
|
||||
{"ssubscribe", "ssubscribe", true},
|
||||
{"sunsubscribe", "sunsubscribe", true},
|
||||
|
||||
// Push notifications (should be handled by processor)
|
||||
{"MOVING", "MOVING", false},
|
||||
{"MIGRATING", "MIGRATING", false},
|
||||
{"MIGRATED", "MIGRATED", false},
|
||||
{"FAILING_OVER", "FAILING_OVER", false},
|
||||
{"FAILED_OVER", "FAILED_OVER", false},
|
||||
{"custom", "custom", false},
|
||||
{"unknown", "unknown", false},
|
||||
{"empty", "", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := willHandleNotificationInClient(tc.notificationType)
|
||||
if result != tc.shouldHandle {
|
||||
t.Errorf("willHandleNotificationInClient(%q) = %v, want %v", tc.notificationType, result, tc.shouldHandle)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessorErrorHandlingUnit tests error handling scenarios
|
||||
func TestProcessorErrorHandlingUnit(t *testing.T) {
|
||||
processor := NewProcessor()
|
||||
|
||||
t.Run("RegisterNilHandler", func(t *testing.T) {
|
||||
err := processor.RegisterHandler("TEST", nil, false)
|
||||
if err == nil {
|
||||
t.Error("RegisterHandler should error with nil handler")
|
||||
}
|
||||
|
||||
// Check error type
|
||||
if !IsHandlerNilError(err) {
|
||||
t.Error("Error should be a HandlerNilError")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RegisterDuplicateHandler", func(t *testing.T) {
|
||||
handler1 := &UnitTestHandler{name: "handler1"}
|
||||
handler2 := &UnitTestHandler{name: "handler2"}
|
||||
|
||||
// Register first handler
|
||||
err := processor.RegisterHandler("DUPLICATE", handler1, false)
|
||||
if err != nil {
|
||||
t.Errorf("First RegisterHandler should not error: %v", err)
|
||||
}
|
||||
|
||||
// Try to register second handler with same name
|
||||
err = processor.RegisterHandler("DUPLICATE", handler2, false)
|
||||
if err == nil {
|
||||
t.Error("RegisterHandler should error when registering duplicate handler")
|
||||
}
|
||||
|
||||
// Verify original handler is still there
|
||||
retrievedHandler := processor.GetHandler("DUPLICATE")
|
||||
if retrievedHandler != handler1 {
|
||||
t.Error("Original handler should remain after failed duplicate registration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UnregisterNonExistentHandler", func(t *testing.T) {
|
||||
err := processor.UnregisterHandler("NONEXISTENT")
|
||||
if err != nil {
|
||||
t.Errorf("UnregisterHandler should not error for non-existent handler: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestProcessorConcurrentAccess tests concurrent access to processor
|
||||
func TestProcessorConcurrentAccess(t *testing.T) {
|
||||
processor := NewProcessor()
|
||||
|
||||
t.Run("ConcurrentRegisterAndGet", func(t *testing.T) {
|
||||
done := make(chan bool, 2)
|
||||
|
||||
// Goroutine 1: Register handlers
|
||||
go func() {
|
||||
defer func() { done <- true }()
|
||||
for i := 0; i < 100; i++ {
|
||||
handler := &UnitTestHandler{name: "concurrent-handler"}
|
||||
processor.RegisterHandler("CONCURRENT", handler, false)
|
||||
processor.UnregisterHandler("CONCURRENT")
|
||||
}
|
||||
}()
|
||||
|
||||
// Goroutine 2: Get handlers
|
||||
go func() {
|
||||
defer func() { done <- true }()
|
||||
for i := 0; i < 100; i++ {
|
||||
processor.GetHandler("CONCURRENT")
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for both goroutines to complete
|
||||
<-done
|
||||
<-done
|
||||
})
|
||||
}
|
||||
|
||||
// TestProcessorInterfaceCompliance tests interface compliance
|
||||
func TestProcessorInterfaceCompliance(t *testing.T) {
|
||||
t.Run("ProcessorImplementsInterface", func(t *testing.T) {
|
||||
var _ NotificationProcessor = (*Processor)(nil)
|
||||
})
|
||||
|
||||
t.Run("VoidProcessorImplementsInterface", func(t *testing.T) {
|
||||
var _ NotificationProcessor = (*VoidProcessor)(nil)
|
||||
})
|
||||
}
|
||||
|
||||
// UnitTestHandler is a test implementation of NotificationHandler
|
||||
type UnitTestHandler struct {
|
||||
name string
|
||||
lastNotification []interface{}
|
||||
errorToReturn error
|
||||
callCount int
|
||||
}
|
||||
|
||||
func (h *UnitTestHandler) HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error {
|
||||
h.callCount++
|
||||
h.lastNotification = notification
|
||||
return h.errorToReturn
|
||||
}
|
||||
|
||||
// Helper methods for UnitTestHandler
|
||||
func (h *UnitTestHandler) GetCallCount() int {
|
||||
return h.callCount
|
||||
}
|
||||
|
||||
func (h *UnitTestHandler) GetLastNotification() []interface{} {
|
||||
return h.lastNotification
|
||||
}
|
||||
|
||||
func (h *UnitTestHandler) SetErrorToReturn(err error) {
|
||||
h.errorToReturn = err
|
||||
}
|
||||
|
||||
func (h *UnitTestHandler) Reset() {
|
||||
h.callCount = 0
|
||||
h.lastNotification = nil
|
||||
h.errorToReturn = nil
|
||||
}
|
@@ -4,24 +4,6 @@ import (
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// Push notification constants for cluster operations
|
||||
const (
|
||||
// MOVING indicates a slot is being moved to a different node
|
||||
PushNotificationMoving = "MOVING"
|
||||
|
||||
// MIGRATING indicates a slot is being migrated from this node
|
||||
PushNotificationMigrating = "MIGRATING"
|
||||
|
||||
// MIGRATED indicates a slot has been migrated to this node
|
||||
PushNotificationMigrated = "MIGRATED"
|
||||
|
||||
// FAILING_OVER indicates a failover is starting
|
||||
PushNotificationFailingOver = "FAILING_OVER"
|
||||
|
||||
// FAILED_OVER indicates a failover has completed
|
||||
PushNotificationFailedOver = "FAILED_OVER"
|
||||
)
|
||||
|
||||
// NewPushNotificationProcessor creates a new push notification processor
|
||||
// This processor maintains a registry of handlers and processes push notifications
|
||||
// It is used for RESP3 connections where push notifications are available
|
||||
|
226
redis.go
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/hitless"
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/hscan"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
@@ -205,18 +206,34 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
|
||||
|
||||
type baseClient struct {
|
||||
opt *Options
|
||||
optLock sync.RWMutex
|
||||
connPool pool.Pooler
|
||||
pubSubPool *pool.PubSubPool
|
||||
hooksMixin
|
||||
|
||||
onClose func() error // hook called when client is closed
|
||||
|
||||
// Push notification processing
|
||||
pushProcessor push.NotificationProcessor
|
||||
|
||||
// Hitless upgrade manager
|
||||
hitlessManager *hitless.HitlessManager
|
||||
hitlessManagerLock sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *baseClient) clone() *baseClient {
|
||||
clone := *c
|
||||
return &clone
|
||||
c.hitlessManagerLock.RLock()
|
||||
hitlessManager := c.hitlessManager
|
||||
c.hitlessManagerLock.RUnlock()
|
||||
|
||||
clone := &baseClient{
|
||||
opt: c.opt,
|
||||
connPool: c.connPool,
|
||||
onClose: c.onClose,
|
||||
pushProcessor: c.pushProcessor,
|
||||
hitlessManager: hitlessManager,
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
|
||||
@@ -234,21 +251,6 @@ func (c *baseClient) String() string {
|
||||
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
|
||||
}
|
||||
|
||||
func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
|
||||
cn, err := c.connPool.NewConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = c.initConn(ctx, cn)
|
||||
if err != nil {
|
||||
_ = c.connPool.CloseConn(cn)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
|
||||
if c.opt.Limiter != nil {
|
||||
err := c.opt.Limiter.Allow()
|
||||
@@ -274,7 +276,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cn.Inited {
|
||||
if cn.IsInited() {
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
@@ -356,12 +358,10 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
|
||||
}
|
||||
|
||||
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
if cn.Inited {
|
||||
if !cn.Inited.CompareAndSwap(false, true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
cn.Inited = true
|
||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||
conn := newConn(c.opt, connPool, &c.hooksMixin)
|
||||
|
||||
@@ -430,6 +430,51 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
return fmt.Errorf("failed to initialize connection options: %w", err)
|
||||
}
|
||||
|
||||
// Enable maintenance notifications if hitless upgrades are configured
|
||||
c.optLock.RLock()
|
||||
hitlessEnabled := c.opt.HitlessUpgradeConfig != nil && c.opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled
|
||||
protocol := c.opt.Protocol
|
||||
endpointType := c.opt.HitlessUpgradeConfig.EndpointType
|
||||
c.optLock.RUnlock()
|
||||
var hitlessHandshakeErr error
|
||||
if hitlessEnabled && protocol == 3 {
|
||||
hitlessHandshakeErr = conn.ClientMaintNotifications(
|
||||
ctx,
|
||||
true,
|
||||
endpointType.String(),
|
||||
).Err()
|
||||
if hitlessHandshakeErr != nil {
|
||||
if !isRedisError(hitlessHandshakeErr) {
|
||||
// if not redis error, fail the connection
|
||||
return hitlessHandshakeErr
|
||||
}
|
||||
c.optLock.Lock()
|
||||
// handshake failed - check and modify config atomically
|
||||
switch c.opt.HitlessUpgradeConfig.Mode {
|
||||
case hitless.MaintNotificationsEnabled:
|
||||
// enabled mode, fail the connection
|
||||
c.optLock.Unlock()
|
||||
return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr)
|
||||
default: // will handle auto and any other
|
||||
internal.Logger.Printf(ctx, "hitless: auto mode fallback: hitless upgrades disabled due to handshake error: %v", hitlessHandshakeErr)
|
||||
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsDisabled
|
||||
c.optLock.Unlock()
|
||||
// auto mode, disable hitless upgrades and continue
|
||||
if err := c.disableHitlessUpgrades(); err != nil {
|
||||
// Log error but continue - auto mode should be resilient
|
||||
internal.Logger.Printf(ctx, "hitless: failed to disable hitless upgrades in auto mode: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// handshake was executed successfully
// Since the handshake succeeded on this connection, force it to be executed
// on all other connections as well.
|
||||
c.optLock.Lock()
|
||||
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsEnabled
|
||||
c.optLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
|
||||
libName := ""
|
||||
libVer := Version()
|
||||
@@ -446,6 +491,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
}
|
||||
}
|
||||
|
||||
cn.SetUsable(true)
|
||||
cn.Inited.Store(true)
|
||||
|
||||
// Set the connection initialization function for potential reconnections
|
||||
cn.SetInitConnFunc(c.createInitConnFunc())
|
||||
|
||||
if c.opt.OnConnect != nil {
|
||||
return c.opt.OnConnect(ctx, conn)
|
||||
}
|
||||
@@ -512,6 +563,8 @@ func (c *baseClient) assertUnstableCommand(cmd Cmder) bool {
|
||||
if c.opt.UnstableResp3 {
|
||||
return true
|
||||
} else {
|
||||
// TODO: find the best way to remove the panic and return error here
|
||||
// The client should not panic when executing a command, only when initializing.
|
||||
panic("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3 . See the [README](https://github.com/redis/go-redis/blob/master/README.md) and the release notes for guidance.")
|
||||
}
|
||||
default:
|
||||
@@ -593,20 +646,77 @@ func (c *baseClient) context(ctx context.Context) context.Context {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// createInitConnFunc creates a connection initialization function that can be used for reconnections.
|
||||
func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error {
|
||||
return func(ctx context.Context, cn *pool.Conn) error {
|
||||
return c.initConn(ctx, cn)
|
||||
}
|
||||
}
|
||||
|
||||
// enableHitlessUpgrades initializes the hitless upgrade manager and pool hook.
|
||||
// This function is called during client initialization.
|
||||
// It registers push notification handlers for all hitless upgrade events and
// starts background workers for handoff processing in the pool hook.
|
||||
func (c *baseClient) enableHitlessUpgrades() error {
|
||||
// Create client adapter
|
||||
clientAdapterInstance := newClientAdapter(c)
|
||||
|
||||
// Create hitless manager directly
|
||||
manager, err := hitless.NewHitlessManager(clientAdapterInstance, c.connPool, c.opt.HitlessUpgradeConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Set the manager reference and initialize pool hook
|
||||
c.hitlessManagerLock.Lock()
|
||||
c.hitlessManager = manager
|
||||
c.hitlessManagerLock.Unlock()
|
||||
|
||||
// Initialize pool hook (safe to call without lock since manager is now set)
|
||||
manager.InitPoolHook(c.dialHook)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *baseClient) disableHitlessUpgrades() error {
|
||||
c.hitlessManagerLock.Lock()
|
||||
defer c.hitlessManagerLock.Unlock()
|
||||
|
||||
// Close the hitless manager
|
||||
if c.hitlessManager != nil {
|
||||
// Closing the manager will also shutdown the pool hook
|
||||
// and remove it from the pool
|
||||
c.hitlessManager.Close()
|
||||
c.hitlessManager = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the client, releasing any open resources.
|
||||
//
|
||||
// It is rare to Close a Client, as the Client is meant to be
|
||||
// long-lived and shared between many goroutines.
|
||||
func (c *baseClient) Close() error {
|
||||
var firstErr error
|
||||
|
||||
// Close hitless manager first
|
||||
if err := c.disableHitlessUpgrades(); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
|
||||
if c.onClose != nil {
|
||||
if err := c.onClose(); err != nil {
|
||||
if err := c.onClose(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if c.connPool != nil {
|
||||
if err := c.connPool.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if c.pubSubPool != nil {
|
||||
if err := c.pubSubPool.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
@@ -796,6 +906,8 @@ func NewClient(opt *Options) *Client {
|
||||
if opt == nil {
|
||||
panic("redis: NewClient nil options")
|
||||
}
|
||||
// clone so we do not share (and possibly mutate) the caller's options
|
||||
opt = opt.clone()
|
||||
opt.init()
|
||||
|
||||
// Push notifications are always enabled for RESP3 (cannot be disabled)
|
||||
@@ -810,11 +922,40 @@ func NewClient(opt *Options) *Client {
|
||||
// Initialize push notification processor using shared helper
|
||||
// Use void processor for RESP2 connections (push notifications not available)
|
||||
c.pushProcessor = initializePushProcessor(opt)
|
||||
// set opt push processor for child clients
|
||||
c.opt.PushNotificationProcessor = c.pushProcessor
|
||||
|
||||
// Update options with the initialized push processor for connection pool
|
||||
opt.PushNotificationProcessor = c.pushProcessor
|
||||
// Create connection pools
|
||||
var err error
|
||||
c.connPool, err = newConnPool(opt, c.dialHook)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
|
||||
}
|
||||
c.pubSubPool, err = newPubSubPool(opt, c.dialHook)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
|
||||
}
|
||||
|
||||
c.connPool = newConnPool(opt, c.dialHook)
|
||||
// Initialize hitless upgrades first if enabled and protocol is RESP3
|
||||
if opt.HitlessUpgradeConfig != nil && opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled && opt.Protocol == 3 {
|
||||
err := c.enableHitlessUpgrades()
|
||||
if err != nil {
|
||||
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
|
||||
if opt.HitlessUpgradeConfig.Mode == hitless.MaintNotificationsEnabled {
|
||||
/*
|
||||
Design decision: panic here to fail fast if hitless upgrades cannot be enabled when explicitly requested.
|
||||
We choose to panic instead of returning an error to avoid breaking the existing client API, which does not expect
|
||||
an error from NewClient. This ensures that misconfiguration or critical initialization failures are surfaced
|
||||
immediately, rather than allowing the client to continue in a partially initialized or inconsistent state.
|
||||
Clients relying on hitless upgrades should be aware that initialization errors will cause a panic, and should
|
||||
handle this accordingly (e.g., via recover or by validating configuration before calling NewClient).
|
||||
This approach is only used when HitlessUpgradeConfig.Mode is MaintNotificationsEnabled, indicating that hitless
|
||||
upgrades are required for correct operation. In other modes, initialization failures are logged but do not panic.
|
||||
*/
|
||||
panic(fmt.Errorf("failed to enable hitless upgrades: %w", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &c
|
||||
}
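// Illustrative sketch, not part of this change: callers that require
// MaintNotificationsEnabled and prefer an error over a panic can wrap
// NewClient, as the design note above suggests (e.g. via recover).
//
//	func newClientStrict(opt *redis.Options) (c *redis.Client, err error) {
//		defer func() {
//			if r := recover(); r != nil {
//				err = fmt.Errorf("redis: %v", r)
//			}
//		}()
//		return redis.NewClient(opt), nil
//	}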
|
||||
@@ -851,6 +992,14 @@ func (c *Client) Options() *Options {
|
||||
return c.opt
|
||||
}
|
||||
|
||||
// GetHitlessManager returns the hitless manager instance for monitoring and control.
|
||||
// Returns nil if hitless upgrades are not enabled.
|
||||
func (c *Client) GetHitlessManager() *hitless.HitlessManager {
|
||||
c.hitlessManagerLock.RLock()
|
||||
defer c.hitlessManagerLock.RUnlock()
|
||||
return c.hitlessManager
|
||||
}
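// Illustrative sketch, not part of this change: checking whether hitless
// upgrades are active on a client. Only the nil check is grounded in the doc
// comment above; any further use of the manager is up to the caller.
//
//	if mgr := rdb.GetHitlessManager(); mgr != nil {
//		// hitless upgrades are enabled; mgr can be used for monitoring/control
//	}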
|
||||
|
||||
// initializePushProcessor initializes the push notification processor for any client type.
|
||||
// This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient.
|
||||
func initializePushProcessor(opt *Options) push.NotificationProcessor {
|
||||
@@ -887,6 +1036,7 @@ type PoolStats pool.Stats
|
||||
// PoolStats returns connection pool stats.
|
||||
func (c *Client) PoolStats() *PoolStats {
|
||||
stats := c.connPool.Stats()
|
||||
stats.PubSubStats = *(c.pubSubPool.Stats())
|
||||
return (*PoolStats)(stats)
|
||||
}
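// Illustrative sketch, not part of this change: reading the combined stats.
// Hits, Misses and Timeouts are the fields exercised by the benchmarks in this
// diff; PubSubStats is the field populated above.
//
//	stats := rdb.PoolStats()
//	_ = stats.Hits + stats.Misses + stats.Timeouts
//	_ = stats.PubSubStats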
|
||||
|
||||
@@ -921,11 +1071,27 @@ func (c *Client) TxPipeline() Pipeliner {
|
||||
func (c *Client) pubSub() *PubSub {
|
||||
pubsub := &PubSub{
|
||||
opt: c.opt,
|
||||
|
||||
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
|
||||
return c.newConn(ctx)
|
||||
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
|
||||
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// will return nil if already initialized
|
||||
err = c.initConn(ctx, cn)
|
||||
if err != nil {
|
||||
_ = cn.Close()
|
||||
return nil, err
|
||||
}
|
||||
// Track connection in PubSubPool
|
||||
c.pubSubPool.TrackConn(cn)
|
||||
return cn, nil
|
||||
},
|
||||
closeConn: func(cn *pool.Conn) error {
|
||||
// Untrack connection from PubSubPool
|
||||
c.pubSubPool.UntrackConn(cn)
|
||||
_ = cn.Close()
|
||||
return nil
|
||||
},
|
||||
closeConn: c.connPool.CloseConn,
|
||||
pushProcessor: c.pushProcessor,
|
||||
}
|
||||
pubsub.init()
|
||||
@@ -1113,6 +1279,6 @@ func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.Notifica
|
||||
return push.NotificationHandlerContext{
|
||||
Client: c,
|
||||
ConnPool: c.connPool,
|
||||
Conn: cn,
|
||||
Conn: cn, // passed as interface{}; handlers type-assert to *pool.Conn when needed
|
||||
}
|
||||
}
|
||||
|
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
. "github.com/bsm/ginkgo/v2"
|
||||
. "github.com/bsm/gomega"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
)
|
||||
|
62
sentinel.go
@@ -16,8 +16,8 @@ import (
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/internal/rand"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -139,6 +139,14 @@ type FailoverOptions struct {
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
UnstableResp3 bool
|
||||
|
||||
// Hitless is not supported for FailoverClients at the moment
|
||||
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
|
||||
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
|
||||
// upgrade notifications gracefully and manage connection/pool state transitions
|
||||
// seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, hitless upgrades are disabled.
|
||||
//HitlessUpgradeConfig *HitlessUpgradeConfig
|
||||
}
|
||||
|
||||
func (opt *FailoverOptions) clientOptions() *Options {
|
||||
@@ -456,8 +464,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
	opt.Dialer = masterReplicaDialer(failover)
	opt.init()

	var connPool *pool.ConnPool

	rdb := &Client{
		baseClient: &baseClient{
			opt: opt,
@@ -469,16 +475,26 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
	// Use void processor by default for RESP2 connections
	rdb.pushProcessor = initializePushProcessor(opt)

	connPool = newConnPool(opt, rdb.dialHook)
	rdb.connPool = connPool
	var err error
	rdb.connPool, err = newConnPool(opt, rdb.dialHook)
	if err != nil {
		panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
	}
	rdb.pubSubPool, err = newPubSubPool(opt, rdb.dialHook)
	if err != nil {
		panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
	}

	rdb.onClose = rdb.wrappedOnClose(failover.Close)

	failover.mu.Lock()
	failover.onFailover = func(ctx context.Context, addr string) {
		if connPool, ok := rdb.connPool.(*pool.ConnPool); ok {
			_ = connPool.Filter(func(cn *pool.Conn) bool {
				return cn.RemoteAddr().String() != addr
			})
		}
	}
	failover.mu.Unlock()

	return rdb
@@ -543,7 +559,15 @@ func NewSentinelClient(opt *Options) *SentinelClient {
		dial:    c.baseClient.dial,
		process: c.baseClient.process,
	})
	c.connPool = newConnPool(opt, c.dialHook)
	var err error
	c.connPool, err = newConnPool(opt, c.dialHook)
	if err != nil {
		panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
	}
	c.pubSubPool, err = newPubSubPool(opt, c.dialHook)
	if err != nil {
		panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
	}

	return c
}
@@ -570,13 +594,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
func (c *SentinelClient) pubSub() *PubSub {
	pubsub := &PubSub{
		opt: c.opt,

		newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
			return c.newConn(ctx)
		newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
			cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
			if err != nil {
				return nil, err
			}
			// will return nil if already initialized
			err = c.initConn(ctx, cn)
			if err != nil {
				_ = cn.Close()
				return nil, err
			}
			// Track connection in PubSubPool
			c.pubSubPool.TrackConn(cn)
			return cn, nil
		},
		closeConn: c.connPool.CloseConn,
		closeConn: func(cn *pool.Conn) error {
			// Untrack connection from PubSubPool
			c.pubSubPool.UntrackConn(cn)
			_ = cn.Close()
			return nil
		},
		pushProcessor: c.pushProcessor,
	}
	pubsub.init()

	return pubsub
}
2
tx.go
@@ -24,7 +24,7 @@ type Tx struct {
func (c *Client) newTx() *Tx {
	tx := Tx{
		baseClient: baseClient{
			opt:           c.opt,
			opt:           c.opt.clone(), // Clone options to avoid sharing mutable state between transaction and parent client
			connPool:      pool.NewStickyConnPool(c.connPool),
			hooksMixin:    c.hooksMixin.clone(),
			pushProcessor: c.pushProcessor, // Copy push processor from parent client
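The only functional change in this hunk is that each transaction works on a cloned copy of the parent client's options rather than sharing the same struct. Existing Watch/Tx code is unaffected; a standard usage sketch, reusing rdb and ctx from the sketches above:

```go
err := rdb.Watch(ctx, func(tx *redis.Tx) error {
	// Each Tx now carries its own cloned options, so nothing it mutates
	// can leak back into the parent client.
	n, err := tx.Get(ctx, "counter").Int()
	if err != nil && err != redis.Nil {
		return err
	}
	_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.Set(ctx, "counter", n+1, 0)
		return nil
	})
	return err
}, "counter")
if err != nil {
	panic(err)
}
```
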
@@ -122,6 +122,9 @@ type UniversalOptions struct {

	// IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint).
	IsClusterMode bool

	// HitlessUpgradeConfig provides configuration for hitless upgrades.
	HitlessUpgradeConfig *HitlessUpgradeConfig
}

// Cluster returns cluster options created from the universal options.
@@ -177,6 +180,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions {
		IdentitySuffix:        o.IdentitySuffix,
		FailingTimeoutSeconds: o.FailingTimeoutSeconds,
		UnstableResp3:         o.UnstableResp3,
		HitlessUpgradeConfig:  o.HitlessUpgradeConfig,
	}
}

@@ -237,6 +241,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions {
		DisableIndentity: o.DisableIndentity,
		IdentitySuffix:   o.IdentitySuffix,
		UnstableResp3:    o.UnstableResp3,
		// Note: HitlessUpgradeConfig not supported for FailoverOptions
	}
}

@@ -288,6 +293,7 @@ func (o *UniversalOptions) Simple() *Options {
		DisableIndentity:     o.DisableIndentity,
		IdentitySuffix:       o.IdentitySuffix,
		UnstableResp3:        o.UnstableResp3,
		HitlessUpgradeConfig: o.HitlessUpgradeConfig,
	}
}
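Taken together, the three constructors show where the new option travels: a HitlessUpgradeConfig set on UniversalOptions is forwarded by Cluster() and Simple() but intentionally dropped by Failover(). A hedged sketch; as above, the non-nil zero-value config is an assumption about the default mode.

```go
client := redis.NewUniversalClient(&redis.UniversalOptions{
	Addrs:    []string{"localhost:6379"},
	Protocol: 3, // RESP3 is required for push notifications
	// Forwarded to cluster and standalone options, but not to failover options.
	HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{},
})
defer client.Close()
```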