Mirror of https://github.com/redis/go-redis.git, synced 2025-09-02 22:01:16 +03:00
feat(hitless): Introduce handlers for hitless upgrades
This commit includes all the work on hitless upgrades, with the addition of:

- Pubsub pool
- Examples
- Refactor of push
- Refactor of pool (using atomics for most things)
- Introduction of hooks in the pool
3  .gitignore  (vendored)

@@ -9,3 +9,6 @@ coverage.txt
**/coverage.txt
.vscode
tmp/*

# Hitless upgrade documentation (temporary)
hitless/docs/
149  adapters.go  (new file)

@@ -0,0 +1,149 @@
package redis

import (
	"context"
	"errors"
	"net"
	"time"

	"github.com/redis/go-redis/v9/internal/interfaces"
	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/push"
)

// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand.
var ErrInvalidCommand = errors.New("invalid command type")

// ErrInvalidPool is returned when the pool type is not supported.
var ErrInvalidPool = errors.New("invalid pool type")

// newClientAdapter creates a new client adapter for regular Redis clients.
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
	return &clientAdapter{client: client}
}

// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
type clientAdapter struct {
	client *baseClient
}

// GetOptions returns the client options.
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
	return &optionsAdapter{options: ca.client.opt}
}

// GetPushProcessor returns the client's push notification processor.
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
	return &pushProcessorAdapter{processor: ca.client.pushProcessor}
}

// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
type optionsAdapter struct {
	options *Options
}

// GetReadTimeout returns the read timeout.
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
	return oa.options.ReadTimeout
}

// GetWriteTimeout returns the write timeout.
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
	return oa.options.WriteTimeout
}

// GetNetwork returns the network type.
func (oa *optionsAdapter) GetNetwork() string {
	return oa.options.Network
}

// GetAddr returns the connection address.
func (oa *optionsAdapter) GetAddr() string {
	return oa.options.Addr
}

// IsTLSEnabled returns true if TLS is enabled.
func (oa *optionsAdapter) IsTLSEnabled() bool {
	return oa.options.TLSConfig != nil
}

// GetProtocol returns the protocol version.
func (oa *optionsAdapter) GetProtocol() int {
	return oa.options.Protocol
}

// GetPoolSize returns the connection pool size.
func (oa *optionsAdapter) GetPoolSize() int {
	return oa.options.PoolSize
}

// NewDialer returns a new dialer function for the connection.
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
	baseDialer := oa.options.NewDialer()
	return func(ctx context.Context) (net.Conn, error) {
		// Extract network and address from the options
		network := oa.options.Network
		addr := oa.options.Addr
		return baseDialer(ctx, network, addr)
	}
}

// connectionAdapter adapts a Redis connection to interfaces.ConnectionWithRelaxedTimeout
type connectionAdapter struct {
	conn *pool.Conn
}

// Close closes the connection.
func (ca *connectionAdapter) Close() error {
	return ca.conn.Close()
}

// IsUsable returns true if the connection is safe to use for new commands.
func (ca *connectionAdapter) IsUsable() bool {
	return ca.conn.IsUsable()
}

// GetPoolConnection returns the underlying pool connection.
func (ca *connectionAdapter) GetPoolConnection() *pool.Conn {
	return ca.conn
}

// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
// These timeouts remain active until explicitly cleared.
func (ca *connectionAdapter) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
	ca.conn.SetRelaxedTimeout(readTimeout, writeTimeout)
}

// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
func (ca *connectionAdapter) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
	ca.conn.SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout, deadline)
}

// ClearRelaxedTimeout clears relaxed timeouts for this connection.
func (ca *connectionAdapter) ClearRelaxedTimeout() {
	ca.conn.ClearRelaxedTimeout()
}

// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
type pushProcessorAdapter struct {
	processor push.NotificationProcessor
}

// RegisterHandler registers a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
	if pushHandler, ok := handler.(push.NotificationHandler); ok {
		return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
	}
	return errors.New("handler must implement push.NotificationHandler")
}

// UnregisterHandler removes a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
	return ppa.processor.UnregisterHandler(pushNotificationName)
}

// GetHandler returns the handler for a specific push notification name.
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
	return ppa.processor.GetHandler(pushNotificationName)
}
348  async_handoff_integration_test.go  (new file)

@@ -0,0 +1,348 @@
package redis

import (
	"context"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/hitless"
	"github.com/redis/go-redis/v9/internal/pool"
)

// mockNetConn implements net.Conn for testing
type mockNetConn struct {
	addr string
}

func (m *mockNetConn) Read(b []byte) (n int, err error)   { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error)  { return len(b), nil }
func (m *mockNetConn) Close() error                       { return nil }
func (m *mockNetConn) LocalAddr() net.Addr                { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr               { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error      { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error  { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }

type mockAddr struct {
	addr string
}

func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string  { return m.addr }

// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow
func TestEventDrivenHandoffIntegration(t *testing.T) {
	t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) {
		// Create a base dialer for testing
		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return &mockNetConn{addr: addr}, nil
		}

		// Create processor with event-driven handoff support
		processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
		defer processor.Shutdown(context.Background())

		// Create a test pool with hooks
		hookManager := pool.NewPoolHookManager()
		hookManager.AddHook(processor)

		testPool := pool.NewConnPool(&pool.Options{
			Dialer: func(ctx context.Context) (net.Conn, error) {
				return &mockNetConn{addr: "original:6379"}, nil
			},
			PoolSize:    int32(5),
			PoolTimeout: time.Second,
		})

		// Add the hook to the pool after creation
		testPool.AddPoolHook(processor)
		defer testPool.Close()

		// Set the pool reference in the processor for connection removal on handoff failure
		processor.SetPool(testPool)

		ctx := context.Background()

		// Get a connection and mark it for handoff
		conn, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Failed to get connection: %v", err)
		}

		// Set initialization function with a small delay to ensure handoff is pending
		initConnCalled := false
		initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
			time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
			initConnCalled = true
			return nil
		}
		conn.SetInitConnFunc(initConnFunc)

		// Mark connection for handoff
		err = conn.MarkForHandoff("new-endpoint:6379", 12345)
		if err != nil {
			t.Fatalf("Failed to mark connection for handoff: %v", err)
		}

		// Return connection to pool - this should queue handoff
		testPool.Put(ctx, conn)

		// Give the on-demand worker a moment to start processing
		time.Sleep(10 * time.Millisecond)

		// Verify handoff was queued
		if !processor.IsHandoffPending(conn) {
			t.Error("Handoff should be queued in pending map")
		}

		// Try to get the same connection - should be skipped due to pending handoff
		conn2, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Failed to get second connection: %v", err)
		}

		// Should get a different connection (the pending one should be skipped)
		if conn == conn2 {
			t.Error("Should have gotten a different connection while handoff is pending")
		}

		// Return the second connection
		testPool.Put(ctx, conn2)

		// Wait for handoff to complete
		time.Sleep(200 * time.Millisecond)

		// Verify handoff completed (removed from pending map)
		if processor.IsHandoffPending(conn) {
			t.Error("Handoff should have completed and been removed from pending map")
		}

		if !initConnCalled {
			t.Error("InitConn should have been called during handoff")
		}

		// Now the original connection should be available again
		conn3, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Failed to get third connection: %v", err)
		}

		// Could be the original connection (now handed off) or a new one
		testPool.Put(ctx, conn3)
	})

	t.Run("ConcurrentHandoffs", func(t *testing.T) {
		// Create a base dialer that simulates slow handoffs
		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			time.Sleep(50 * time.Millisecond) // Simulate network delay
			return &mockNetConn{addr: addr}, nil
		}

		processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
		defer processor.Shutdown(context.Background())

		// Create hooks manager and add processor as hook
		hookManager := pool.NewPoolHookManager()
		hookManager.AddHook(processor)

		testPool := pool.NewConnPool(&pool.Options{
			Dialer: func(ctx context.Context) (net.Conn, error) {
				return &mockNetConn{addr: "original:6379"}, nil
			},
			PoolSize:    int32(10),
			PoolTimeout: time.Second,
		})
		defer testPool.Close()

		// Add the hook to the pool after creation
		testPool.AddPoolHook(processor)

		// Set the pool reference in the processor
		processor.SetPool(testPool)

		ctx := context.Background()
		var wg sync.WaitGroup

		// Start multiple concurrent handoffs
		for i := 0; i < 5; i++ {
			wg.Add(1)
			go func(id int) {
				defer wg.Done()

				// Get connection
				conn, err := testPool.Get(ctx)
				if err != nil {
					t.Errorf("Failed to get connection %d: %v", id, err)
					return
				}

				// Set initialization function
				initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
					return nil
				}
				conn.SetInitConnFunc(initConnFunc)

				// Mark for handoff
				conn.MarkForHandoff("new-endpoint:6379", int64(id))

				// Return to pool (starts async handoff)
				testPool.Put(ctx, conn)
			}(i)
		}

		wg.Wait()

		// Wait for all handoffs to complete
		time.Sleep(300 * time.Millisecond)

		// Verify pool is still functional
		conn, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err)
		}
		testPool.Put(ctx, conn)
	})

	t.Run("HandoffFailureRecovery", func(t *testing.T) {
		// Create a failing base dialer
		failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}}
		}

		processor := hitless.NewPoolHook(failingDialer, "tcp", nil, nil)
		defer processor.Shutdown(context.Background())

		// Create hooks manager and add processor as hook
		hookManager := pool.NewPoolHookManager()
		hookManager.AddHook(processor)

		testPool := pool.NewConnPool(&pool.Options{
			Dialer: func(ctx context.Context) (net.Conn, error) {
				return &mockNetConn{addr: "original:6379"}, nil
			},
			PoolSize:    int32(3),
			PoolTimeout: time.Second,
		})
		defer testPool.Close()

		// Add the hook to the pool after creation
		testPool.AddPoolHook(processor)

		// Set the pool reference in the processor
		processor.SetPool(testPool)

		ctx := context.Background()

		// Get connection and mark for handoff
		conn, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Failed to get connection: %v", err)
		}

		conn.MarkForHandoff("unreachable-endpoint:6379", 12345)

		// Return to pool (starts async handoff that will fail)
		testPool.Put(ctx, conn)

		// Wait for handoff to fail
		time.Sleep(200 * time.Millisecond)

		// Connection should be removed from pending map after failed handoff
		if processor.IsHandoffPending(conn) {
			t.Error("Connection should be removed from pending map after failed handoff")
		}

		// Pool should still be functional
		conn2, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Pool should still be functional: %v", err)
		}

		// In event-driven approach, the original connection remains in pool
		// even after failed handoff (it's still a valid connection)
		// We might get the same connection or a different one
		testPool.Put(ctx, conn2)
	})

	t.Run("GracefulShutdown", func(t *testing.T) {
		// Create a slow base dialer
		slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			time.Sleep(100 * time.Millisecond)
			return &mockNetConn{addr: addr}, nil
		}

		processor := hitless.NewPoolHook(slowDialer, "tcp", nil, nil)
		defer processor.Shutdown(context.Background())

		// Create hooks manager and add processor as hook
		hookManager := pool.NewPoolHookManager()
		hookManager.AddHook(processor)

		testPool := pool.NewConnPool(&pool.Options{
			Dialer: func(ctx context.Context) (net.Conn, error) {
				return &mockNetConn{addr: "original:6379"}, nil
			},
			PoolSize:    int32(2),
			PoolTimeout: time.Second,
		})
		defer testPool.Close()

		// Add the hook to the pool after creation
		testPool.AddPoolHook(processor)

		// Set the pool reference in the processor
		processor.SetPool(testPool)

		ctx := context.Background()

		// Start a handoff
		conn, err := testPool.Get(ctx)
		if err != nil {
			t.Fatalf("Failed to get connection: %v", err)
		}

		if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
			t.Fatalf("Failed to mark connection for handoff: %v", err)
		}

		// Set a mock initialization function with delay to ensure handoff is pending
		conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
			time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
			return nil
		})

		testPool.Put(ctx, conn)

		// Give the on-demand worker a moment to start and begin processing
		// The handoff should be pending because the slowDialer takes 100ms
		time.Sleep(10 * time.Millisecond)

		// Verify handoff was queued and is being processed
		if !processor.IsHandoffPending(conn) {
			t.Error("Handoff should be queued in pending map")
		}

		// Give the handoff a moment to start processing
		time.Sleep(50 * time.Millisecond)

		// Shutdown processor gracefully
		// Use a longer timeout to account for slow dialer (100ms) plus processing overhead
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()

		err = processor.Shutdown(shutdownCtx)
		if err != nil {
			t.Errorf("Graceful shutdown should succeed: %v", err)
		}

		// Handoff should have completed (removed from pending map)
		if processor.IsHandoffPending(conn) {
			t.Error("Handoff should have completed and been removed from pending map after shutdown")
		}
	})
}
18  commands.go

@@ -193,6 +193,7 @@ type Cmdable interface {
	ClientID(ctx context.Context) *IntCmd
	ClientUnblock(ctx context.Context, id int64) *IntCmd
	ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
	ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
	ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
	ConfigResetStat(ctx context.Context) *StatusCmd
	ConfigSet(ctx context.Context, parameter, value string) *StatusCmd

@@ -518,6 +519,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
	return cmd
}

// ClientMaintNotifications enables or disables maintenance notifications for hitless upgrades.
// When enabled, the client will receive push notifications about Redis maintenance events.
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
	args := []interface{}{"client", "maint_notifications"}
	if enabled {
		if endpointType == "" {
			endpointType = "none"
		}
		args = append(args, "on", "moving-endpoint-type", endpointType)
	} else {
		args = append(args, "off")
	}
	cmd := NewStatusCmd(ctx, args...)
	_ = c(ctx, cmd)
	return cmd
}

// ------------------------------------------------------------------------------------------------

func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {
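For context, a hedged usage sketch of the new command (assuming an existing `rdb *redis.Client` and `ctx`; it simply builds the arguments shown in the function above):

```go
// Illustrative only: enable maintenance notifications and ask the server to
// report external IPs in MOVING notifications. This sends
// CLIENT MAINT_NOTIFICATIONS ON MOVING-ENDPOINT-TYPE external-ip.
if err := rdb.ClientMaintNotifications(ctx, true, "external-ip").Err(); err != nil {
	// An error here typically means the server does not support the command.
}

// Disabling ignores the endpoint type and sends CLIENT MAINT_NOTIFICATIONS OFF.
_ = rdb.ClientMaintNotifications(ctx, false, "").Err()
```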
12  example/pubsub/go.mod  (new file)

@@ -0,0 +1,12 @@
module github.com/redis/go-redis/example/pubsub

go 1.18

replace github.com/redis/go-redis/v9 => ../..

require github.com/redis/go-redis/v9 v9.11.0

require (
	github.com/cespare/xxhash/v2 v2.3.0 // indirect
	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)
6  example/pubsub/go.sum  (new file)

@@ -0,0 +1,6 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
146  example/pubsub/main.go  (new file)

@@ -0,0 +1,146 @@
package main

import (
	"context"
	"fmt"
	"log"
	"sync"
	"time"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/hitless"
)

var ctx = context.Background()

// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management.
// It was used to find regressions in pool management in hitless mode.
// Please don't use it as a reference for how to use pubsub.
func main() {
	wg := &sync.WaitGroup{}
	rdb := redis.NewClient(&redis.Options{
		Addr: ":6379",
		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
			Mode: hitless.MaintNotificationsEnabled,
		},
	})
	_ = rdb.FlushDB(ctx).Err()

	go func() {
		for {
			time.Sleep(2 * time.Second)
			fmt.Printf("pool stats: %+v\n", rdb.PoolStats())
		}
	}()
	err := rdb.Ping(ctx).Err()
	if err != nil {
		panic(err)
	}
	if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil {
		panic(err)
	}
	if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil {
		panic(err)
	}
	if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil {
		panic(err)
	}
	if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil {
		panic(err)
	}
	fmt.Println("published", rdb.Get(ctx, "published").Val())
	fmt.Println("received", rdb.Get(ctx, "received").Val())
	subCtx, cancelSubCtx := context.WithCancel(ctx)
	pubCtx, cancelPublishers := context.WithCancel(ctx)
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go subscribe(subCtx, rdb, "test", i, wg)
	}
	time.Sleep(time.Second)
	cancelSubCtx()
	time.Sleep(time.Second)
	subCtx, cancelSubCtx = context.WithCancel(ctx)
	for i := 0; i < 10; i++ {
		if err := rdb.Incr(ctx, "publishers").Err(); err != nil {
			panic(err)
		}
		wg.Add(1)
		go floodThePool(pubCtx, rdb, wg)
	}

	for i := 0; i < 500; i++ {
		if err := rdb.Incr(ctx, "subscribers").Err(); err != nil {
			panic(err)
		}
		wg.Add(1)
		go subscribe(subCtx, rdb, "test2", i, wg)
	}
	time.Sleep(5 * time.Second)
	fmt.Println("canceling publishers")
	cancelPublishers()
	time.Sleep(10 * time.Second)
	fmt.Println("canceling subscribers")
	cancelSubCtx()
	wg.Wait()
	published, err := rdb.Get(ctx, "published").Result()
	received, err := rdb.Get(ctx, "received").Result()
	publishers, err := rdb.Get(ctx, "publishers").Result()
	subscribers, err := rdb.Get(ctx, "subscribers").Result()
	fmt.Printf("publishers: %s\n", publishers)
	fmt.Printf("published: %s\n", published)
	fmt.Printf("subscribers: %s\n", subscribers)
	fmt.Printf("received: %s\n", received)
	publishedInt, err := rdb.Get(ctx, "published").Int()
	subscribersInt, err := rdb.Get(ctx, "subscribers").Int()
	fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt)

	time.Sleep(2 * time.Second)
}

func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
	defer wg.Done()
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}
		err := rdb.Publish(ctx, "test2", "hello").Err()
		if err != nil {
			// noop
			//log.Println("publish error:", err)
		}

		err = rdb.Incr(ctx, "published").Err()
		if err != nil {
			// noop
			//log.Println("incr error:", err)
		}
		time.Sleep(10 * time.Nanosecond)
	}
}

func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) {
	defer wg.Done()
	rec := rdb.Subscribe(ctx, topic)
	recChan := rec.Channel()
	for {
		select {
		case <-ctx.Done():
			rec.Close()
			return
		default:
			select {
			case <-ctx.Done():
				rec.Close()
				return
			case msg := <-recChan:
				err := rdb.Incr(ctx, "received").Err()
				if err != nil {
					log.Println("incr error:", err)
				}
				_ = msg // Use the message to avoid unused variable warning
			}
		}
	}
}
@@ -57,6 +57,8 @@ func Example_instrumentation() {
	// finished dialing tcp :6379
	// starting processing: <[hello 3]>
	// finished processing: <[hello 3]>
	// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
	// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
	// finished processing: <[ping]>
}

@@ -78,6 +80,8 @@ func ExamplePipeline_instrumentation() {
	// finished dialing tcp :6379
	// starting processing: <[hello 3]>
	// finished processing: <[hello 3]>
	// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
	// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
	// pipeline finished processing: [[ping] [ping]]
}

@@ -99,6 +103,8 @@ func ExampleClient_Watch_instrumentation() {
	// finished dialing tcp :6379
	// starting processing: <[hello 3]>
	// finished processing: <[hello 3]>
	// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
	// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
	// finished processing: <[watch foo]>
	// starting processing: <[ping]>
	// finished processing: <[ping]>
72  hitless/README.md  (new file)

@@ -0,0 +1,72 @@
# Hitless Upgrades

Seamless Redis connection handoffs during topology changes without interrupting operations.

## Quick Start

```go
import "github.com/redis/go-redis/v9/hitless"

opt := &redis.Options{
    Addr:     "localhost:6379",
    Protocol: 3, // RESP3 required
    HitlessUpgrades: &redis.HitlessUpgradeConfig{
        Mode: hitless.MaintNotificationsEnabled, // or MaintNotificationsAuto
    },
}
client := redis.NewClient(opt)
```

## Modes

- **`MaintNotificationsDisabled`**: Hitless upgrades are completely disabled
- **`MaintNotificationsEnabled`**: Hitless upgrades are forcefully enabled (fails if the server doesn't support it)
- **`MaintNotificationsAuto`**: Hitless upgrades are enabled if the server supports it (default)

## Configuration

```go
import "github.com/redis/go-redis/v9/hitless"

Config: &hitless.Config{
    Mode:                       hitless.MaintNotificationsAuto, // Notification mode
    MaxHandoffRetries:          3,                // Retry failed handoffs
    HandoffTimeout:             15 * time.Second, // Handoff operation timeout
    RelaxedTimeout:             10 * time.Second, // Extended timeout during migrations
    PostHandoffRelaxedDuration: 20 * time.Second, // Keep relaxed timeout after handoff
    LogLevel:                   1,                // 0=errors, 1=warnings, 2=info, 3=debug
    MaxWorkers:                 15,               // Concurrent handoff workers
    HandoffQueueSize:           50,               // Handoff request queue size
}
```

### Worker Scaling
- **Auto-calculated**: `min(10, PoolSize/3)` - scales with pool size, capped at 10
- **Explicit values**: `max(10, set_value)` - enforces a minimum of 10 workers
- **On-demand**: Workers created when needed, cleaned up when idle

### Queue Sizing
- **Auto-calculated**: `10 × MaxWorkers`, capped by pool size
- **Always capped**: Queue size never exceeds pool size

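A minimal sketch of how the two rules above combine, assuming the formulas documented here (it mirrors the defaulting logic in `hitless/config.go` but is not the library code itself; uses the Go 1.21 `min` built-in):

```go
// Illustrative only: how MaxWorkers and HandoffQueueSize are derived
// when both are left at zero for a given pool size.
func autoSizing(poolSize int) (workers, queueSize int) {
    workers = min(10, poolSize/3) // scales with pool size, capped at 10
    if workers < 1 {
        workers = 1 // floor for very small pools
    }
    queueSize = min(10*workers, poolSize) // 10 x workers, never above pool size
    if queueSize < 2 {
        queueSize = 2 // minimum queue size
    }
    return workers, queueSize
}
```

For example, a pool of 30 connections yields 10 workers and a queue of 30.
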
## Metrics Hook Example

A metrics collection hook is available in `example_hooks.go` that demonstrates how to monitor hitless upgrade operations:

```go
import "github.com/redis/go-redis/v9/hitless"

metricsHook := hitless.NewMetricsHook()
// Use with your monitoring system
```

The metrics hook tracks:
- Handoff success/failure rates
- Handoff duration
- Queue depth
- Worker utilization
- Connection lifecycle events

## Requirements

- **RESP3 Protocol**: Required for push notifications
377  hitless/config.go  (new file)

@@ -0,0 +1,377 @@
package hitless

import (
	"net"
	"runtime"
	"time"

	"github.com/redis/go-redis/v9/internal/util"
)

// MaintNotificationsMode represents the maintenance notifications mode
type MaintNotificationsMode string

// Constants for maintenance push notifications modes
const (
	MaintNotificationsDisabled MaintNotificationsMode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
	MaintNotificationsEnabled  MaintNotificationsMode = "enabled"  // Client forcefully sends command, interrupts connection on error
	MaintNotificationsAuto     MaintNotificationsMode = "auto"     // Client tries to send command, disables feature on error
)

// IsValid returns true if the maintenance notifications mode is valid
func (m MaintNotificationsMode) IsValid() bool {
	switch m {
	case MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto:
		return true
	default:
		return false
	}
}

// String returns the string representation of the mode
func (m MaintNotificationsMode) String() string {
	return string(m)
}

// EndpointType represents the type of endpoint to request in MOVING notifications
type EndpointType string

// Constants for endpoint types
const (
	EndpointTypeAuto         EndpointType = "auto"          // Auto-detect based on connection
	EndpointTypeInternalIP   EndpointType = "internal-ip"   // Internal IP address
	EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN
	EndpointTypeExternalIP   EndpointType = "external-ip"   // External IP address
	EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN
	EndpointTypeNone         EndpointType = "none"          // No endpoint (reconnect with current config)
)

// IsValid returns true if the endpoint type is valid
func (e EndpointType) IsValid() bool {
	switch e {
	case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
		EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone:
		return true
	default:
		return false
	}
}

// String returns the string representation of the endpoint type
func (e EndpointType) String() string {
	return string(e)
}

// Config provides configuration options for hitless upgrades.
type Config struct {
	// Mode controls how client maintenance notifications are handled.
	// Valid values: MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto
	// Default: MaintNotificationsAuto
	Mode MaintNotificationsMode

	// EndpointType specifies the type of endpoint to request in MOVING notifications.
	// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
	// EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
	// Default: EndpointTypeAuto
	EndpointType EndpointType

	// RelaxedTimeout is the concrete timeout value to use during
	// MIGRATING/FAILING_OVER states to accommodate increased latency.
	// This applies to both read and write timeouts.
	// Default: 10 seconds
	RelaxedTimeout time.Duration

	// HandoffTimeout is the maximum time to wait for connection handoff to complete.
	// If handoff takes longer than this, the old connection will be forcibly closed.
	// Default: 15 seconds (matches server-side eviction timeout)
	HandoffTimeout time.Duration

	// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
	// Workers are created on-demand and automatically cleaned up when idle.
	// If zero, defaults to min(10, PoolSize/3) to handle bursts effectively.
	// If explicitly set, enforces minimum of 10 workers.
	//
	// Default: min(10, PoolSize/3), Minimum when set: 10
	MaxWorkers int

	// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
	// If the queue is full, new handoff requests will be rejected.
	// Always capped by pool size since you can't handoff more connections than exist.
	//
	// Default: 10x max workers, capped by pool size, min 2
	HandoffQueueSize int

	// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
	// after a handoff completes. This provides additional resilience during cluster transitions.
	// Default: 2 * RelaxedTimeout
	PostHandoffRelaxedDuration time.Duration

	// ScaleDownDelay is the delay before checking if workers should be scaled down.
	// This prevents expensive checks on every handoff completion and avoids rapid scaling cycles.
	// Default: 2 seconds
	ScaleDownDelay time.Duration

	// LogLevel controls the verbosity of hitless upgrade logging.
	// 0 = errors only, 1 = warnings, 2 = info, 3 = debug
	// Default: 1 (warnings)
	LogLevel int

	// MaxHandoffRetries is the maximum number of times to retry a failed handoff.
	// After this many retries, the connection will be removed from the pool.
	// Default: 3
	MaxHandoffRetries int
}

func (c *Config) IsEnabled() bool {
	return c != nil && c.Mode != MaintNotificationsDisabled
}

// DefaultConfig returns a Config with sensible defaults.
func DefaultConfig() *Config {
	return &Config{
		Mode:                       MaintNotificationsAuto, // Enable by default for Redis Cloud
		EndpointType:               EndpointTypeAuto,       // Auto-detect based on connection
		RelaxedTimeout:             10 * time.Second,
		HandoffTimeout:             15 * time.Second,
		MaxWorkers:                 0, // Auto-calculated based on pool size
		HandoffQueueSize:           0, // Auto-calculated based on max workers
		PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
		ScaleDownDelay:             2 * time.Second,
		LogLevel:                   1,

		// Connection Handoff Configuration
		MaxHandoffRetries: 3,
	}
}

// Validate checks if the configuration is valid.
func (c *Config) Validate() error {
	if c.RelaxedTimeout <= 0 {
		return ErrInvalidRelaxedTimeout
	}
	if c.HandoffTimeout <= 0 {
		return ErrInvalidHandoffTimeout
	}
	// Validate worker configuration
	// Allow 0 for auto-calculation, but negative values are invalid
	if c.MaxWorkers < 0 {
		return ErrInvalidHandoffWorkers
	}
	// HandoffQueueSize validation - allow 0 for auto-calculation
	if c.HandoffQueueSize < 0 {
		return ErrInvalidHandoffQueueSize
	}
	if c.PostHandoffRelaxedDuration < 0 {
		return ErrInvalidPostHandoffRelaxedDuration
	}
	if c.LogLevel < 0 || c.LogLevel > 3 {
		return ErrInvalidLogLevel
	}

	// Validate Mode (maintenance notifications mode)
	if !c.Mode.IsValid() {
		return ErrInvalidMaintNotifications
	}

	// Validate EndpointType
	if !c.EndpointType.IsValid() {
		return ErrInvalidEndpointType
	}

	// Validate configuration fields
	if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
		return ErrInvalidHandoffRetries
	}

	return nil
}

// ApplyDefaults applies default values to any zero-value fields in the configuration.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaults() *Config {
	return c.ApplyDefaultsWithPoolSize(0)
}

// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
// using the provided pool size to calculate worker defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
	if c == nil {
		return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
	}

	defaults := DefaultConfig()
	result := &Config{}

	// Apply defaults for enum fields (empty/zero means not set)
	if c.Mode == "" {
		result.Mode = defaults.Mode
	} else {
		result.Mode = c.Mode
	}

	if c.EndpointType == "" {
		result.EndpointType = defaults.EndpointType
	} else {
		result.EndpointType = c.EndpointType
	}

	// Apply defaults for duration fields (zero means not set)
	if c.RelaxedTimeout <= 0 {
		result.RelaxedTimeout = defaults.RelaxedTimeout
	} else {
		result.RelaxedTimeout = c.RelaxedTimeout
	}

	if c.HandoffTimeout <= 0 {
		result.HandoffTimeout = defaults.HandoffTimeout
	} else {
		result.HandoffTimeout = c.HandoffTimeout
	}

	// Apply defaults for integer fields (zero means not set)
	if c.HandoffQueueSize <= 0 {
		result.HandoffQueueSize = defaults.HandoffQueueSize
	} else {
		result.HandoffQueueSize = c.HandoffQueueSize
	}

	// Copy worker configuration
	result.MaxWorkers = c.MaxWorkers

	// Apply worker defaults based on pool size
	result.applyWorkerDefaults(poolSize)

	// Apply queue size defaults based on max workers, capped by pool size
	if c.HandoffQueueSize <= 0 {
		// Queue size: 10x max workers, but never more than pool size
		workerBasedSize := result.MaxWorkers * 10
		result.HandoffQueueSize = util.Min(workerBasedSize, poolSize)
	} else {
		result.HandoffQueueSize = c.HandoffQueueSize
	}

	// Always cap queue size by pool size - no point having more queue slots than connections
	result.HandoffQueueSize = util.Min(result.HandoffQueueSize, poolSize)

	// Ensure minimum queue size of 2
	if result.HandoffQueueSize < 2 {
		result.HandoffQueueSize = 2
	}

	if c.PostHandoffRelaxedDuration <= 0 {
		result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
	} else {
		result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
	}

	if c.ScaleDownDelay <= 0 {
		result.ScaleDownDelay = defaults.ScaleDownDelay
	} else {
		result.ScaleDownDelay = c.ScaleDownDelay
	}

	// LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set
	// We'll use the provided value as-is, since 0 is valid
	result.LogLevel = c.LogLevel

	// Apply defaults for configuration fields
	if c.MaxHandoffRetries <= 0 {
		result.MaxHandoffRetries = defaults.MaxHandoffRetries
	} else {
		result.MaxHandoffRetries = c.MaxHandoffRetries
	}

	return result
}

// Clone creates a deep copy of the configuration.
func (c *Config) Clone() *Config {
	if c == nil {
		return DefaultConfig()
	}

	return &Config{
		Mode:                       c.Mode,
		EndpointType:               c.EndpointType,
		RelaxedTimeout:             c.RelaxedTimeout,
		HandoffTimeout:             c.HandoffTimeout,
		MaxWorkers:                 c.MaxWorkers,
		HandoffQueueSize:           c.HandoffQueueSize,
		PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
		ScaleDownDelay:             c.ScaleDownDelay,
		LogLevel:                   c.LogLevel,

		// Configuration fields
		MaxHandoffRetries: c.MaxHandoffRetries,
	}
}

// applyWorkerDefaults calculates and applies worker defaults based on pool size
func (c *Config) applyWorkerDefaults(poolSize int) {
	// Calculate defaults based on pool size
	if poolSize <= 0 {
		poolSize = 10 * runtime.GOMAXPROCS(0)
	}

	if c.MaxWorkers == 0 {
		// When not set: min(10, poolSize/3) - don't exceed 10 workers for small pools
		c.MaxWorkers = util.Min(10, poolSize/3)
	} else {
		// When explicitly set: max(10, set_value) - ensure at least 10 workers
		c.MaxWorkers = util.Max(10, c.MaxWorkers)
	}

	// Ensure minimum of 1 worker (fallback for very small pools)
	if c.MaxWorkers < 1 {
		c.MaxWorkers = 1
	}
}

// DetectEndpointType automatically detects the appropriate endpoint type
// based on the connection address and TLS configuration.
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
	// Parse the address to determine if it's an IP or hostname
	isPrivate := isPrivateIP(addr)

	var endpointType EndpointType

	if tlsEnabled {
		// TLS requires FQDN for certificate validation
		if isPrivate {
			endpointType = EndpointTypeInternalFQDN
		} else {
			endpointType = EndpointTypeExternalFQDN
		}
	} else {
		// No TLS, can use IP addresses
		if isPrivate {
			endpointType = EndpointTypeInternalIP
		} else {
			endpointType = EndpointTypeExternalIP
		}
	}

	return endpointType
}

// isPrivateIP checks if the given address is in a private IP range.
func isPrivateIP(addr string) bool {
	// Extract host from "host:port" format
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		host = addr // Assume no port
	}

	ip := net.ParseIP(host)
	if ip == nil {
		return false // Not an IP address (likely hostname)
	}

	// Check for private/loopback ranges
	return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
}
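As a small usage sketch of the auto-detection above (illustrative only; the results follow directly from the branches in `DetectEndpointType`, and the addresses are arbitrary examples):

```go
// Assumes the hitless package is imported as shown in the README.
fmt.Println(hitless.DetectEndpointType("10.0.0.5:6379", false))         // internal-ip: private range, no TLS
fmt.Println(hitless.DetectEndpointType("203.0.113.10:6379", false))     // external-ip: public address, no TLS
fmt.Println(hitless.DetectEndpointType("redis.example.com:6379", true)) // external-fqdn: hostname is not an IP, TLS prefers an FQDN
```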
427  hitless/config_test.go  (new file)

@@ -0,0 +1,427 @@
package hitless

import (
	"context"
	"net"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/internal/util"
)

func TestConfig(t *testing.T) {
	t.Run("DefaultConfig", func(t *testing.T) {
		config := DefaultConfig()

		// MaxWorkers should be 0 in default config (auto-calculated)
		if config.MaxWorkers != 0 {
			t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers)
		}

		// HandoffQueueSize should be 0 in default config (auto-calculated)
		if config.HandoffQueueSize != 0 {
			t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize)
		}

		if config.RelaxedTimeout != 10*time.Second {
			t.Errorf("Expected RelaxedTimeout to be 10s, got %v", config.RelaxedTimeout)
		}

		// Test configuration fields have proper defaults
		if config.MaxHandoffRetries != 3 {
			t.Errorf("Expected MaxHandoffRetries to be 3, got %d", config.MaxHandoffRetries)
		}

		if config.HandoffTimeout != 15*time.Second {
			t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout)
		}

		if config.PostHandoffRelaxedDuration != 0 {
			t.Errorf("Expected PostHandoffRelaxedDuration to be 0 (auto-calculated), got %v", config.PostHandoffRelaxedDuration)
		}

		// Test that defaults are applied correctly
		configWithDefaults := config.ApplyDefaultsWithPoolSize(100)
		if configWithDefaults.PostHandoffRelaxedDuration != 20*time.Second {
			t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout) after applying defaults, got %v", configWithDefaults.PostHandoffRelaxedDuration)
		}
	})

	t.Run("ConfigValidation", func(t *testing.T) {
		// Valid config with applied defaults
		config := DefaultConfig().ApplyDefaults()
		if err := config.Validate(); err != nil {
			t.Errorf("Default config with applied defaults should be valid: %v", err)
		}

		// Invalid worker configuration (negative MaxWorkers)
		config = &Config{
			RelaxedTimeout:             30 * time.Second,
			HandoffTimeout:             15 * time.Second,
			MaxWorkers:                 -1, // This should be invalid
			HandoffQueueSize:           100,
			PostHandoffRelaxedDuration: 10 * time.Second,
			LogLevel:                   1,
			MaxHandoffRetries:          3, // Add required field
		}
		if err := config.Validate(); err != ErrInvalidHandoffWorkers {
			t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err)
		}

		// Invalid HandoffQueueSize
		config = DefaultConfig().ApplyDefaults()
		config.HandoffQueueSize = -1
		if err := config.Validate(); err != ErrInvalidHandoffQueueSize {
			t.Errorf("Expected ErrInvalidHandoffQueueSize, got %v", err)
		}

		// Invalid PostHandoffRelaxedDuration
		config = DefaultConfig().ApplyDefaults()
		config.PostHandoffRelaxedDuration = -1 * time.Second
		if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration {
			t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err)
		}
	})

	t.Run("ConfigClone", func(t *testing.T) {
		original := DefaultConfig()
		original.MaxWorkers = 20
		original.HandoffQueueSize = 200

		cloned := original.Clone()

		if cloned.MaxWorkers != 20 {
			t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers)
		}

		if cloned.HandoffQueueSize != 200 {
			t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize)
		}

		// Modify original to ensure clone is independent
		original.MaxWorkers = 2
		if cloned.MaxWorkers != 20 {
			t.Error("Clone should be independent of original")
		}
	})
}

func TestApplyDefaults(t *testing.T) {
	t.Run("NilConfig", func(t *testing.T) {
		var config *Config
		result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing

		// With nil config, should get default config with auto-calculated workers
		if result.MaxWorkers <= 0 {
			t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers)
		}

		// HandoffQueueSize should be auto-calculated (10 * MaxWorkers, capped by pool size)
		workerBasedSize := result.MaxWorkers * 10
		poolSize := 100 // Default pool size used in ApplyDefaults
		expectedQueueSize := util.Min(workerBasedSize, poolSize)
		if result.HandoffQueueSize != expectedQueueSize {
			t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
				expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
		}
	})

	t.Run("PartialConfig", func(t *testing.T) {
		config := &Config{
			MaxWorkers: 12, // Set this field explicitly
			// Leave other fields as zero values
		}

		result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing

		// Should keep the explicitly set values
		if result.MaxWorkers != 12 {
			t.Errorf("Expected MaxWorkers to be 12 (explicitly set), got %d", result.MaxWorkers)
		}

		// Should apply default for unset fields (auto-calculated queue size, capped by pool size)
		workerBasedSize := result.MaxWorkers * 10
		poolSize := 100 // Default pool size used in ApplyDefaults
		expectedQueueSize := util.Min(workerBasedSize, poolSize)
		if result.HandoffQueueSize != expectedQueueSize {
			t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
				expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
		}

		// Test explicit queue size capping by pool size
		configWithLargeQueue := &Config{
			MaxWorkers:       5,
			HandoffQueueSize: 1000, // Much larger than pool size
		}

		resultCapped := configWithLargeQueue.ApplyDefaultsWithPoolSize(20) // Small pool size
		if resultCapped.HandoffQueueSize != 20 {
			t.Errorf("Expected HandoffQueueSize to be capped by pool size (20), got %d", resultCapped.HandoffQueueSize)
		}

		if result.RelaxedTimeout != 10*time.Second {
			t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
		}

		if result.HandoffTimeout != 15*time.Second {
			t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", result.HandoffTimeout)
		}
	})

	t.Run("ZeroValues", func(t *testing.T) {
		config := &Config{
			MaxWorkers:       0, // Zero value should get auto-calculated defaults
			HandoffQueueSize: 0, // Zero value should get default
			RelaxedTimeout:   0, // Zero value should get default
			LogLevel:         0, // Zero is valid for LogLevel (errors only)
		}

		result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing

		// Zero values should get auto-calculated defaults
		if result.MaxWorkers <= 0 {
			t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers)
		}

		// HandoffQueueSize should be auto-calculated (10 * MaxWorkers, capped by pool size)
		workerBasedSize := result.MaxWorkers * 10
		poolSize := 100 // Default pool size used in ApplyDefaults
		expectedQueueSize := util.Min(workerBasedSize, poolSize)
		if result.HandoffQueueSize != expectedQueueSize {
			t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
				expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
		}

		if result.RelaxedTimeout != 10*time.Second {
			t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
		}

		// LogLevel 0 should be preserved (it's a valid value)
		if result.LogLevel != 0 {
			t.Errorf("Expected LogLevel to be 0 (preserved), got %d", result.LogLevel)
		}
	})
}

func TestProcessorWithConfig(t *testing.T) {
	t.Run("ProcessorUsesConfigValues", func(t *testing.T) {
		config := &Config{
			MaxWorkers:       5,
			HandoffQueueSize: 50,
			RelaxedTimeout:   10 * time.Second,
			HandoffTimeout:   5 * time.Second,
		}

		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return &mockNetConn{addr: addr}, nil
		}

		processor := NewPoolHook(baseDialer, "tcp", config, nil)
		defer processor.Shutdown(context.Background())

		// The processor should be created successfully with custom config
		if processor == nil {
			t.Error("Processor should be created with custom config")
		}
	})

	t.Run("ProcessorWithPartialConfig", func(t *testing.T) {
		config := &Config{
			MaxWorkers: 7, // Only set worker field
			// Other fields will get defaults
		}

		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return &mockNetConn{addr: addr}, nil
		}

		processor := NewPoolHook(baseDialer, "tcp", config, nil)
		defer processor.Shutdown(context.Background())

		// Should work with partial config (defaults applied)
		if processor == nil {
			t.Error("Processor should be created with partial config")
		}
	})

	t.Run("ProcessorWithNilConfig", func(t *testing.T) {
		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return &mockNetConn{addr: addr}, nil
		}

		processor := NewPoolHook(baseDialer, "tcp", nil, nil)
		defer processor.Shutdown(context.Background())

		// Should use default config when nil is passed
		if processor == nil {
			t.Error("Processor should be created with nil config (using defaults)")
		}
	})
}

func TestIntegrationWithApplyDefaults(t *testing.T) {
	t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) {
		// Create a partial config with only some fields set
		partialConfig := &Config{
			MaxWorkers: 15, // Custom value (>= 10 to test preservation)
			LogLevel:   2,  // Custom value
			// Other fields left as zero values - should get defaults
		}

		baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
			return &mockNetConn{addr: addr}, nil
		}

		// Create processor - should apply defaults to missing fields
		processor := NewPoolHook(baseDialer, "tcp", partialConfig, nil)
		defer processor.Shutdown(context.Background())

		// Processor should be created successfully
		if processor == nil {
			t.Error("Processor should be created with partial config")
		}

		// Test that the ApplyDefaults method worked correctly by creating the same config
		// and applying defaults manually
		expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing

		// Should preserve custom values (when >= 10)
		if expectedConfig.MaxWorkers != 15 {
			t.Errorf("Expected MaxWorkers to be 15, got %d", expectedConfig.MaxWorkers)
		}

		if expectedConfig.LogLevel != 2 {
			t.Errorf("Expected LogLevel to be 2, got %d", expectedConfig.LogLevel)
		}

		// Should apply defaults for missing fields (auto-calculated queue size, capped by pool size)
		workerBasedSize := expectedConfig.MaxWorkers * 10
		poolSize := 100 // Default pool size used in ApplyDefaults
		expectedQueueSize := util.Min(workerBasedSize, poolSize)
		if expectedConfig.HandoffQueueSize != expectedQueueSize {
			t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
				expectedQueueSize, workerBasedSize, poolSize, expectedConfig.HandoffQueueSize)
		}

		// Test that queue size is always capped by pool size
		if expectedConfig.HandoffQueueSize > poolSize {
			t.Errorf("HandoffQueueSize (%d) should never exceed pool size (%d)",
				expectedConfig.HandoffQueueSize, poolSize)
		}

		if expectedConfig.RelaxedTimeout != 10*time.Second {
			t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", expectedConfig.RelaxedTimeout)
		}

		if expectedConfig.HandoffTimeout != 15*time.Second {
			t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout)
		}

		if expectedConfig.PostHandoffRelaxedDuration != 20*time.Second {
			t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout), got %v", expectedConfig.PostHandoffRelaxedDuration)
		}
	})
}

func TestEnhancedConfigValidation(t *testing.T) {
	t.Run("ValidateFields", func(t *testing.T) {
		config := DefaultConfig()
		config.ApplyDefaultsWithPoolSize(100) // Apply defaults with pool size 100

		// Should pass validation with default values
		if err := config.Validate(); err != nil {
			t.Errorf("Default config should be valid, got error: %v", err)
		}

		// Test invalid MaxHandoffRetries
		config.MaxHandoffRetries = 0
		if err := config.Validate(); err == nil {
			t.Error("Expected validation error for MaxHandoffRetries = 0")
		}
		config.MaxHandoffRetries = 11
		if err := config.Validate(); err == nil {
			t.Error("Expected validation error for MaxHandoffRetries = 11")
		}
		config.MaxHandoffRetries = 3 // Reset to valid value

		// Should pass validation again
		if err := config.Validate(); err != nil {
			t.Errorf("Config should be valid after reset, got error: %v", err)
		}
	})
}

func TestConfigClone(t *testing.T) {
	original := DefaultConfig()
	original.MaxHandoffRetries = 7
	original.HandoffTimeout = 8 * time.Second

	cloned := original.Clone()

	// Test that values are copied
	if cloned.MaxHandoffRetries != 7 {
		t.Errorf("Expected cloned MaxHandoffRetries to be 7, got %d", cloned.MaxHandoffRetries)
	}
	if cloned.HandoffTimeout != 8*time.Second {
		t.Errorf("Expected cloned HandoffTimeout to be 8s, got %v", cloned.HandoffTimeout)
	}

	// Test that modifying clone doesn't affect original
	cloned.MaxHandoffRetries = 10
	if original.MaxHandoffRetries != 7 {
		t.Errorf("Modifying clone should not affect original, original MaxHandoffRetries changed to %d", original.MaxHandoffRetries)
	}
}

func TestMaxWorkersLogic(t *testing.T) {
	t.Run("AutoCalculatedMaxWorkers", func(t *testing.T) {
		testCases := []struct {
			poolSize        int
			expectedWorkers int
			description     string
		}{
			{6, 2, "Small pool: min(10, 6/3) = min(10, 2) = 2"},
			{15, 5, "Medium pool: min(10, 15/3) = min(10, 5) = 5"},
			{30, 10, "Large pool: min(10, 30/3) = min(10, 10) = 10"},
			{60, 10, "Very large pool: min(10, 60/3) = min(10, 20) = 10"},
			{120, 10, "Huge pool: min(10, 120/3) = min(10, 40) = 10"},
		}

		for _, tc := range testCases {
			config := &Config{} // MaxWorkers = 0 (not set)
			result := config.ApplyDefaultsWithPoolSize(tc.poolSize)

			if result.MaxWorkers != tc.expectedWorkers {
				t.Errorf("PoolSize=%d: expected MaxWorkers=%d, got %d (%s)",
					tc.poolSize, tc.expectedWorkers, result.MaxWorkers, tc.description)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ExplicitlySetMaxWorkers", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
setValue int
|
||||
expectedWorkers int
|
||||
description string
|
||||
}{
|
||||
{1, 10, "Set 1: max(10, 1) = 10 (enforced minimum)"},
|
||||
{5, 10, "Set 5: max(10, 5) = 10 (enforced minimum)"},
|
||||
{8, 10, "Set 8: max(10, 8) = 10 (enforced minimum)"},
|
||||
{10, 10, "Set 10: max(10, 10) = 10 (exact minimum)"},
|
||||
{15, 15, "Set 15: max(10, 15) = 15 (respects user choice)"},
|
||||
{20, 20, "Set 20: max(10, 20) = 20 (respects user choice)"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
config := &Config{
|
||||
MaxWorkers: tc.setValue, // Explicitly set
|
||||
}
|
||||
result := config.ApplyDefaultsWithPoolSize(100) // Pool size doesn't affect explicit values
|
||||
|
||||
if result.MaxWorkers != tc.expectedWorkers {
|
||||
t.Errorf("Set MaxWorkers=%d: expected %d, got %d (%s)",
|
||||
tc.setValue, tc.expectedWorkers, result.MaxWorkers, tc.description)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
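// The table-driven cases above encode the worker/queue sizing rule. As a minimal
// illustrative sketch (not part of the library API), the same rule in plain Go,
// assuming the defaults exercised by these tests:
//   - MaxWorkers unset:  min(10, poolSize/3)
//   - MaxWorkers set:    max(10, value)
//   - HandoffQueueSize:  min(10*MaxWorkers, poolSize)
func exampleWorkerSizing(poolSize, explicitWorkers int) (workers, queueSize int) {
if explicitWorkers > 0 {
workers = explicitWorkers
if workers < 10 {
workers = 10 // explicit values are raised to the enforced minimum
}
} else {
workers = poolSize / 3
if workers > 10 {
workers = 10 // auto-calculated values are capped at 10
}
}
queueSize = workers * 10
if queueSize > poolSize {
queueSize = poolSize // the queue never exceeds the pool size
}
return workers, queueSize
}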
|
76
hitless/errors.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Configuration errors
|
||||
var (
|
||||
ErrInvalidRelaxedTimeout = errors.New("hitless: relaxed timeout must be greater than 0")
|
||||
ErrInvalidHandoffTimeout = errors.New("hitless: handoff timeout must be greater than 0")
|
||||
ErrInvalidHandoffWorkers = errors.New("hitless: MaxWorkers must be greater than or equal to 0")
|
||||
ErrInvalidHandoffQueueSize = errors.New("hitless: handoff queue size must be greater than 0")
|
||||
ErrInvalidPostHandoffRelaxedDuration = errors.New("hitless: post-handoff relaxed duration must be greater than or equal to 0")
|
||||
ErrInvalidLogLevel = errors.New("hitless: log level must be between 0 and 3")
|
||||
ErrInvalidEndpointType = errors.New("hitless: invalid endpoint type")
|
||||
ErrInvalidMaintNotifications = errors.New("hitless: invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')")
|
||||
ErrMaxHandoffRetriesReached = errors.New("hitless: max handoff retries reached")
|
||||
|
||||
// Configuration validation errors
|
||||
ErrInvalidHandoffRetries = errors.New("hitless: MaxHandoffRetries must be between 1 and 10")
|
||||
ErrInvalidConnectionValidationTimeout = errors.New("hitless: ConnectionValidationTimeout must be greater than 0 and less than 30 seconds")
|
||||
ErrInvalidConnectionHealthCheckInterval = errors.New("hitless: ConnectionHealthCheckInterval must be between 0 and 1 hour")
|
||||
ErrInvalidOperationCleanupInterval = errors.New("hitless: OperationCleanupInterval must be greater than 0 and less than 1 hour")
|
||||
ErrInvalidMaxActiveOperations = errors.New("hitless: MaxActiveOperations must be between 100 and 100000")
|
||||
ErrInvalidNotificationBufferSize = errors.New("hitless: NotificationBufferSize must be between 10 and 10000")
|
||||
ErrInvalidNotificationTimeout = errors.New("hitless: NotificationTimeout must be greater than 0 and less than 30 seconds")
|
||||
)
|
||||
|
||||
// Integration errors
|
||||
var (
|
||||
ErrInvalidClient = errors.New("hitless: invalid client type")
|
||||
)
|
||||
|
||||
// Handoff errors
|
||||
var (
|
||||
ErrHandoffInProgress = errors.New("hitless: handoff already in progress")
|
||||
ErrNoHandoffInProgress = errors.New("hitless: no handoff in progress")
|
||||
ErrConnectionFailed = errors.New("hitless: failed to establish new connection")
|
||||
)
|
||||
|
||||
// Dead error variables removed - unused in simplified architecture
|
||||
|
||||
// Notification errors
|
||||
var (
|
||||
ErrInvalidNotification = errors.New("hitless: invalid notification format")
|
||||
)
|
||||
|
||||
// HandoffError represents an error that occurred during connection handoff.
|
||||
type HandoffError struct {
|
||||
Operation string
|
||||
Endpoint string
|
||||
Cause error
|
||||
}
|
||||
|
||||
func (e *HandoffError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("hitless: handoff %s failed for endpoint %s: %v", e.Operation, e.Endpoint, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("hitless: handoff %s failed for endpoint %s", e.Operation, e.Endpoint)
|
||||
}
|
||||
|
||||
func (e *HandoffError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// NewHandoffError creates a new HandoffError.
|
||||
func NewHandoffError(operation, endpoint string, cause error) *HandoffError {
|
||||
return &HandoffError{
|
||||
Operation: operation,
|
||||
Endpoint: endpoint,
|
||||
Cause: cause,
|
||||
}
|
||||
}
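// Illustrative usage sketch (not part of the API surface): because HandoffError
// implements Unwrap, wrapped sentinels remain detectable. The endpoint value below
// is hypothetical.
//
//	err := NewHandoffError("dial", "10.0.0.5:6379", ErrConnectionFailed)
//	if errors.Is(err, ErrConnectionFailed) {
//	    // the wrapped cause is still visible through Unwrap
//	}
//	var hoErr *HandoffError
//	if errors.As(err, &hoErr) {
//	    fmt.Println(hoErr.Endpoint) // "10.0.0.5:6379"
//	}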
|
63
hitless/example_hooks.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
startTimeKey contextKey = "notif_hitless_start_time"
|
||||
)
|
||||
|
||||
// MetricsHook collects metrics about notification processing.
|
||||
type MetricsHook struct {
|
||||
NotificationCounts map[string]int64
|
||||
ProcessingTimes map[string]time.Duration
|
||||
ErrorCounts map[string]int64
|
||||
}
|
||||
|
||||
// NewMetricsHook creates a new metrics collection hook.
|
||||
func NewMetricsHook() *MetricsHook {
|
||||
return &MetricsHook{
|
||||
NotificationCounts: make(map[string]int64),
|
||||
ProcessingTimes: make(map[string]time.Duration),
|
||||
ErrorCounts: make(map[string]int64),
|
||||
}
|
||||
}
|
||||
|
||||
// PreHook records the start time for processing metrics.
|
||||
func (mh *MetricsHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
mh.NotificationCounts[notificationType]++
|
||||
|
||||
// Store start time in context for duration calculation
|
||||
startTime := time.Now()
|
||||
_ = context.WithValue(ctx, startTimeKey, startTime) // Derived context is discarded: PreHook cannot return a context, so PostHook only sees the start time if the caller propagates it
|
||||
|
||||
return notification, true
|
||||
}
|
||||
|
||||
// PostHook records processing completion and any errors.
|
||||
func (mh *MetricsHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
|
||||
// Calculate processing duration
|
||||
if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
|
||||
duration := time.Since(startTime)
|
||||
mh.ProcessingTimes[notificationType] = duration
|
||||
}
|
||||
|
||||
// Record errors
|
||||
if result != nil {
|
||||
mh.ErrorCounts[notificationType]++
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns a summary of collected metrics.
|
||||
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"notification_counts": mh.NotificationCounts,
|
||||
"processing_times": mh.ProcessingTimes,
|
||||
"error_counts": mh.ErrorCounts,
|
||||
}
|
||||
}
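// Illustrative usage sketch: the manager normally invokes PreHook/PostHook around
// notification processing, but the hook can also be driven directly. The payload
// below is hypothetical.
//
//	hook := NewMetricsHook()
//	ctx := context.Background()
//	notif := []interface{}{"MIGRATING", "1"}
//	notif, _ = hook.PreHook(ctx, "MIGRATING", notif)
//	hook.PostHook(ctx, "MIGRATING", notif, nil)
//	counts := hook.GetMetrics()["notification_counts"] // counts["MIGRATING"] == 1
//	_ = counts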
|
299
hitless/hitless_manager.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/interfaces"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
|
||||
|
||||
// Push notification type constants for hitless upgrades
|
||||
const (
|
||||
NotificationMoving = "MOVING"
|
||||
NotificationMigrating = "MIGRATING"
|
||||
NotificationMigrated = "MIGRATED"
|
||||
NotificationFailingOver = "FAILING_OVER"
|
||||
NotificationFailedOver = "FAILED_OVER"
|
||||
)
|
||||
|
||||
// hitlessNotificationTypes contains all notification types that the hitless upgrade manager handles
|
||||
var hitlessNotificationTypes = []string{
|
||||
NotificationMoving,
|
||||
NotificationMigrating,
|
||||
NotificationMigrated,
|
||||
NotificationFailingOver,
|
||||
NotificationFailedOver,
|
||||
}
|
||||
|
||||
// NotificationHook is called before and after notification processing.
// PreHook can modify the notification and return false to skip processing.
// PostHook is called after processing with the result, including any error.
|
||||
type NotificationHook interface {
|
||||
PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool)
|
||||
PostHook(ctx context.Context, notificationType string, notification []interface{}, result error)
|
||||
}
|
||||
|
||||
// MovingOperationKey provides a unique key for tracking MOVING operations
|
||||
// that combines sequence ID with connection identifier to handle duplicate
|
||||
// sequence IDs across multiple connections to the same node.
|
||||
type MovingOperationKey struct {
|
||||
SeqID int64 // Sequence ID from MOVING notification
|
||||
ConnID uint64 // Unique connection identifier
|
||||
}
|
||||
|
||||
// String returns a string representation of the key for debugging
|
||||
func (k MovingOperationKey) String() string {
|
||||
return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
|
||||
}
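// Illustrative example: the composite key lets two connections that receive the
// same MOVING sequence ID be tracked as separate operations.
//
//	k1 := MovingOperationKey{SeqID: 42, ConnID: 1}
//	k2 := MovingOperationKey{SeqID: 42, ConnID: 2}
//	// k1 != k2, so both can coexist in activeMovingOps
//	// k1.String() == "seq:42-conn:1", k2.String() == "seq:42-conn:2"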
|
||||
|
||||
// HitlessManager provides simplified hitless upgrade functionality with hooks and atomic state tracking.
|
||||
type HitlessManager struct {
|
||||
client interfaces.ClientInterface
|
||||
config *Config
|
||||
options interfaces.OptionsInterface
|
||||
pool pool.Pooler
|
||||
|
||||
// MOVING operation tracking - using sync.Map for better concurrent performance
|
||||
activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation
|
||||
|
||||
// Atomic state tracking - no locks needed for state queries
|
||||
activeOperationCount atomic.Int64 // Number of active operations
|
||||
closed atomic.Bool // Manager closed state
|
||||
|
||||
// Notification hooks for extensibility
|
||||
hooks []NotificationHook
|
||||
hooksMu sync.RWMutex // Protects hooks slice
|
||||
poolHooksRef *PoolHook
|
||||
}
|
||||
|
||||
// MovingOperation tracks an active MOVING operation.
|
||||
type MovingOperation struct {
|
||||
SeqID int64
|
||||
NewEndpoint string
|
||||
StartTime time.Time
|
||||
Deadline time.Time
|
||||
}
|
||||
|
||||
// NewHitlessManager creates a new simplified hitless manager.
|
||||
func NewHitlessManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*HitlessManager, error) {
|
||||
if client == nil {
|
||||
return nil, ErrInvalidClient
|
||||
}
|
||||
|
||||
hm := &HitlessManager{
|
||||
client: client,
|
||||
pool: pool,
|
||||
options: client.GetOptions(),
|
||||
config: config.Clone(),
|
||||
hooks: make([]NotificationHook, 0),
|
||||
}
|
||||
|
||||
// Set up push notification handling
|
||||
if err := hm.setupPushNotifications(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
// InitPoolHook creates a pool hook using the given base dialer and registers it with the pool.
|
||||
func (hm *HitlessManager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
|
||||
poolHook := hm.createPoolHook(baseDialer)
|
||||
hm.pool.AddPoolHook(poolHook)
|
||||
}
|
||||
|
||||
// setupPushNotifications sets up push notification handling by registering with the client's processor.
|
||||
func (hm *HitlessManager) setupPushNotifications() error {
|
||||
processor := hm.client.GetPushProcessor()
|
||||
if processor == nil {
|
||||
return ErrInvalidClient // Client doesn't support push notifications
|
||||
}
|
||||
|
||||
// Create our notification handler
|
||||
handler := &NotificationHandler{manager: hm}
|
||||
|
||||
// Register handlers for all hitless upgrade notifications with the client's processor
|
||||
for _, notificationType := range hitlessNotificationTypes {
|
||||
if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
|
||||
return fmt.Errorf("failed to register handler for %s: %w", notificationType, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
|
||||
func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
|
||||
// Create composite key
|
||||
key := MovingOperationKey{
|
||||
SeqID: seqID,
|
||||
ConnID: connID,
|
||||
}
|
||||
|
||||
// Create MOVING operation record
|
||||
movingOp := &MovingOperation{
|
||||
SeqID: seqID,
|
||||
NewEndpoint: newEndpoint,
|
||||
StartTime: time.Now(),
|
||||
Deadline: deadline,
|
||||
}
|
||||
|
||||
// Use LoadOrStore for atomic check-and-set operation
|
||||
if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
|
||||
// Duplicate MOVING notification, ignore
|
||||
internal.Logger.Printf(ctx, "Duplicate MOVING operation ignored: %s", key.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Increment active operation count atomically
|
||||
hm.activeOperationCount.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
|
||||
func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64) {
|
||||
// Create composite key
|
||||
key := MovingOperationKey{
|
||||
SeqID: seqID,
|
||||
ConnID: connID,
|
||||
}
|
||||
|
||||
// Remove from active operations atomically
|
||||
if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
|
||||
// Decrement active operation count only if operation existed
|
||||
hm.activeOperationCount.Add(-1)
|
||||
}
|
||||
}
|
||||
|
||||
// GetActiveMovingOperations returns active operations with composite keys.
|
||||
func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
|
||||
result := make(map[MovingOperationKey]*MovingOperation)
|
||||
|
||||
// Iterate over sync.Map to build result
|
||||
hm.activeMovingOps.Range(func(key, value interface{}) bool {
|
||||
k := key.(MovingOperationKey)
|
||||
op := value.(*MovingOperation)
|
||||
|
||||
// Create a copy to avoid sharing references
|
||||
result[k] = &MovingOperation{
|
||||
SeqID: op.SeqID,
|
||||
NewEndpoint: op.NewEndpoint,
|
||||
StartTime: op.StartTime,
|
||||
Deadline: op.Deadline,
|
||||
}
|
||||
return true // Continue iteration
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// IsHandoffInProgress returns true if any handoff is in progress.
|
||||
// Uses atomic counter for lock-free operation.
|
||||
func (hm *HitlessManager) IsHandoffInProgress() bool {
|
||||
return hm.activeOperationCount.Load() > 0
|
||||
}
|
||||
|
||||
// GetActiveOperationCount returns the number of active operations.
|
||||
// Uses atomic counter for lock-free operation.
|
||||
func (hm *HitlessManager) GetActiveOperationCount() int64 {
|
||||
return hm.activeOperationCount.Load()
|
||||
}
|
||||
|
||||
// Close closes the hitless manager.
|
||||
func (hm *HitlessManager) Close() error {
|
||||
// Use atomic operation for thread-safe close check
|
||||
if !hm.closed.CompareAndSwap(false, true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
// Shutdown the pool hook if it exists
|
||||
if hm.poolHooksRef != nil {
|
||||
// Use a timeout to prevent hanging indefinitely
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := hm.poolHooksRef.Shutdown(shutdownCtx)
|
||||
if err != nil {
|
||||
// was not able to close pool hook, keep closed state false
|
||||
hm.closed.Store(false)
|
||||
return err
|
||||
}
|
||||
// Remove the pool hook from the pool
|
||||
if hm.pool != nil {
|
||||
hm.pool.RemovePoolHook(hm.poolHooksRef)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear all active operations
|
||||
hm.activeMovingOps.Range(func(key, value interface{}) bool {
|
||||
hm.activeMovingOps.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
// Reset counter
|
||||
hm.activeOperationCount.Store(0)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetState returns current state using atomic counter for lock-free operation.
|
||||
func (hm *HitlessManager) GetState() State {
|
||||
if hm.activeOperationCount.Load() > 0 {
|
||||
return StateMoving
|
||||
}
|
||||
return StateIdle
|
||||
}
|
||||
|
||||
// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
|
||||
func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
hm.hooksMu.RLock()
|
||||
defer hm.hooksMu.RUnlock()
|
||||
|
||||
currentNotification := notification
|
||||
|
||||
for _, hook := range hm.hooks {
|
||||
modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationType, currentNotification)
|
||||
if !shouldContinue {
|
||||
return modifiedNotification, false
|
||||
}
|
||||
currentNotification = modifiedNotification
|
||||
}
|
||||
|
||||
return currentNotification, true
|
||||
}
|
||||
|
||||
// processPostHooks calls all post-hooks with the processing result.
|
||||
func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationType string, notification []interface{}, result error) {
|
||||
hm.hooksMu.RLock()
|
||||
defer hm.hooksMu.RUnlock()
|
||||
|
||||
for _, hook := range hm.hooks {
|
||||
hook.PostHook(ctx, notificationType, notification, result)
|
||||
}
|
||||
}
|
||||
|
||||
// createPoolHook creates a pool hook with this manager already set.
|
||||
func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
|
||||
if hm.poolHooksRef != nil {
|
||||
return hm.poolHooksRef
|
||||
}
|
||||
// Get pool size from client options for better worker defaults
|
||||
poolSize := 0
|
||||
if hm.options != nil {
|
||||
poolSize = hm.options.GetPoolSize()
|
||||
}
|
||||
|
||||
hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
|
||||
hm.poolHooksRef.SetPool(hm.pool)
|
||||
|
||||
return hm.poolHooksRef
|
||||
}
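// Illustrative lifecycle sketch (clientAdapter, connPool and netDialer are
// assumptions; the real wiring lives in the redis package):
//
//	manager, err := NewHitlessManager(clientAdapter, connPool, DefaultConfig())
//	if err != nil {
//	    // handle error
//	}
//	manager.InitPoolHook(netDialer) // installs the pool hook on connPool
//	defer manager.Close()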
|
260
hitless/hitless_manager_test.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/interfaces"
|
||||
)
|
||||
|
||||
// MockClient implements interfaces.ClientInterface for testing
|
||||
type MockClient struct {
|
||||
options interfaces.OptionsInterface
|
||||
}
|
||||
|
||||
func (mc *MockClient) GetOptions() interfaces.OptionsInterface {
|
||||
return mc.options
|
||||
}
|
||||
|
||||
func (mc *MockClient) GetPushProcessor() interfaces.NotificationProcessor {
|
||||
return &MockPushProcessor{}
|
||||
}
|
||||
|
||||
// MockPushProcessor implements interfaces.NotificationProcessor for testing
|
||||
type MockPushProcessor struct{}
|
||||
|
||||
func (mpp *MockPushProcessor) RegisterHandler(notificationType string, handler interface{}, protected bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mpp *MockPushProcessor) UnregisterHandler(pushNotificationName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mpp *MockPushProcessor) GetHandler(pushNotificationName string) interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockOptions implements interfaces.OptionsInterface for testing
|
||||
type MockOptions struct{}
|
||||
|
||||
func (mo *MockOptions) GetReadTimeout() time.Duration {
|
||||
return 5 * time.Second
|
||||
}
|
||||
|
||||
func (mo *MockOptions) GetWriteTimeout() time.Duration {
|
||||
return 5 * time.Second
|
||||
}
|
||||
|
||||
func (mo *MockOptions) GetAddr() string {
|
||||
return "localhost:6379"
|
||||
}
|
||||
|
||||
func (mo *MockOptions) IsTLSEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (mo *MockOptions) GetProtocol() int {
|
||||
return 3 // RESP3
|
||||
}
|
||||
|
||||
func (mo *MockOptions) GetPoolSize() int {
|
||||
return 10
|
||||
}
|
||||
|
||||
func (mo *MockOptions) GetNetwork() string {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func (mo *MockOptions) NewDialer() func(context.Context) (net.Conn, error) {
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestHitlessManagerRefactoring(t *testing.T) {
|
||||
t.Run("AtomicStateTracking", func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
client := &MockClient{options: &MockOptions{}}
|
||||
|
||||
manager, err := NewHitlessManager(client, nil, config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create hitless manager: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// Test initial state
|
||||
if manager.IsHandoffInProgress() {
|
||||
t.Error("Expected no handoff in progress initially")
|
||||
}
|
||||
|
||||
if manager.GetActiveOperationCount() != 0 {
|
||||
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
|
||||
}
|
||||
|
||||
if manager.GetState() != StateIdle {
|
||||
t.Errorf("Expected StateIdle, got %v", manager.GetState())
|
||||
}
|
||||
|
||||
// Add an operation
|
||||
ctx := context.Background()
|
||||
deadline := time.Now().Add(30 * time.Second)
|
||||
err = manager.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 12345, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to track operation: %v", err)
|
||||
}
|
||||
|
||||
// Test state after adding operation
|
||||
if !manager.IsHandoffInProgress() {
|
||||
t.Error("Expected handoff in progress after adding operation")
|
||||
}
|
||||
|
||||
if manager.GetActiveOperationCount() != 1 {
|
||||
t.Errorf("Expected 1 active operation, got %d", manager.GetActiveOperationCount())
|
||||
}
|
||||
|
||||
if manager.GetState() != StateMoving {
|
||||
t.Errorf("Expected StateMoving, got %v", manager.GetState())
|
||||
}
|
||||
|
||||
// Remove the operation
|
||||
manager.UntrackOperationWithConnID(12345, 1)
|
||||
|
||||
// Test state after removing operation
|
||||
if manager.IsHandoffInProgress() {
|
||||
t.Error("Expected no handoff in progress after removing operation")
|
||||
}
|
||||
|
||||
if manager.GetActiveOperationCount() != 0 {
|
||||
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
|
||||
}
|
||||
|
||||
if manager.GetState() != StateIdle {
|
||||
t.Errorf("Expected StateIdle, got %v", manager.GetState())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SyncMapPerformance", func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
client := &MockClient{options: &MockOptions{}}
|
||||
|
||||
manager, err := NewHitlessManager(client, nil, config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create hitless manager: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
deadline := time.Now().Add(30 * time.Second)
|
||||
|
||||
// Test concurrent operations
|
||||
const numOps = 100
|
||||
for i := 0; i < numOps; i++ {
|
||||
err := manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, int64(i), uint64(i))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to track operation %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
if manager.GetActiveOperationCount() != numOps {
|
||||
t.Errorf("Expected %d active operations, got %d", numOps, manager.GetActiveOperationCount())
|
||||
}
|
||||
|
||||
// Test GetActiveMovingOperations
|
||||
operations := manager.GetActiveMovingOperations()
|
||||
if len(operations) != numOps {
|
||||
t.Errorf("Expected %d operations in map, got %d", numOps, len(operations))
|
||||
}
|
||||
|
||||
// Remove all operations
|
||||
for i := 0; i < numOps; i++ {
|
||||
manager.UntrackOperationWithConnID(int64(i), uint64(i))
|
||||
}
|
||||
|
||||
if manager.GetActiveOperationCount() != 0 {
|
||||
t.Errorf("Expected 0 active operations after cleanup, got %d", manager.GetActiveOperationCount())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DuplicateOperationHandling", func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
client := &MockClient{options: &MockOptions{}}
|
||||
|
||||
manager, err := NewHitlessManager(client, nil, config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create hitless manager: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
deadline := time.Now().Add(30 * time.Second)
|
||||
|
||||
// Add operation
|
||||
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to track operation: %v", err)
|
||||
}
|
||||
|
||||
// Try to add duplicate operation
|
||||
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Duplicate operation should not return error: %v", err)
|
||||
}
|
||||
|
||||
// Should still have only 1 operation
|
||||
if manager.GetActiveOperationCount() != 1 {
|
||||
t.Errorf("Expected 1 active operation after duplicate, got %d", manager.GetActiveOperationCount())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NotificationTypeConstants", func(t *testing.T) {
|
||||
// Test that constants are properly defined
|
||||
expectedTypes := []string{
|
||||
NotificationMoving,
|
||||
NotificationMigrating,
|
||||
NotificationMigrated,
|
||||
NotificationFailingOver,
|
||||
NotificationFailedOver,
|
||||
}
|
||||
|
||||
if len(hitlessNotificationTypes) != len(expectedTypes) {
|
||||
t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(hitlessNotificationTypes))
|
||||
}
|
||||
|
||||
// Test that all expected types are present
|
||||
typeMap := make(map[string]bool)
|
||||
for _, nt := range hitlessNotificationTypes {
typeMap[nt] = true
|
||||
}
|
||||
|
||||
for _, expected := range expectedTypes {
|
||||
if !typeMap[expected] {
|
||||
t.Errorf("Expected notification type %s not found in hitlessNotificationTypes", expected)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that hitlessNotificationTypes contains all expected constants
|
||||
expectedConstants := []string{
|
||||
NotificationMoving,
|
||||
NotificationMigrating,
|
||||
NotificationMigrated,
|
||||
NotificationFailingOver,
|
||||
NotificationFailedOver,
|
||||
}
|
||||
|
||||
for _, expected := range expectedConstants {
|
||||
found := false
|
||||
for _, actual := range hitlessNotificationTypes {
|
||||
if actual == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected constant %s not found in hitlessNotificationTypes", expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
48
hitless/hooks.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
)
|
||||
|
||||
// LoggingHook is an example hook implementation that logs all notifications.
|
||||
type LoggingHook struct {
|
||||
LogLevel int
|
||||
}
|
||||
|
||||
// PreHook logs the notification before processing and allows modification.
|
||||
func (lh *LoggingHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
if lh.LogLevel >= 2 { // Info level
|
||||
internal.Logger.Printf(ctx, "hitless: processing %s notification: %v", notificationType, notification)
|
||||
}
|
||||
return notification, true // Continue processing with unmodified notification
|
||||
}
|
||||
|
||||
// PostHook logs the result after processing.
|
||||
func (lh *LoggingHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
|
||||
if result != nil && lh.LogLevel >= 1 { // Warning level
|
||||
internal.Logger.Printf(ctx, "hitless: %s notification processing failed: %v", notificationType, result)
|
||||
} else if lh.LogLevel >= 3 { // Debug level
|
||||
internal.Logger.Printf(ctx, "hitless: %s notification processed successfully", notificationType)
|
||||
}
|
||||
}
|
||||
|
||||
// FilterHook is an example hook that can filter out certain notifications.
|
||||
type FilterHook struct {
|
||||
BlockedTypes map[string]bool
|
||||
}
|
||||
|
||||
// PreHook filters notifications based on type.
|
||||
func (fh *FilterHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
if fh.BlockedTypes[notificationType] {
|
||||
internal.Logger.Printf(ctx, "hitless: filtering out %s notification", notificationType)
|
||||
return notification, false // Skip processing
|
||||
}
|
||||
return notification, true
|
||||
}
|
||||
|
||||
// PostHook does nothing for filter hook.
|
||||
func (fh *FilterHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
|
||||
// No post-processing needed for filter hook
|
||||
}
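// Illustrative usage sketch: construct the example hooks above. How hooks are
// registered with the manager is outside this file; the manager calls
// PreHook/PostHook around each notification it processes.
//
//	logging := &LoggingHook{LogLevel: 2}
//	filter := &FilterHook{BlockedTypes: map[string]bool{"MIGRATING": true}}
//	_, ok := filter.PreHook(context.Background(), "MIGRATING", []interface{}{"MIGRATING", "1"})
//	// ok == false: MIGRATING notifications are skipped
//	_ = logging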
|
247
hitless/notification_handler.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/interfaces"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// NotificationHandler handles push notifications for the simplified manager.
|
||||
type NotificationHandler struct {
|
||||
manager *HitlessManager
|
||||
}
|
||||
|
||||
// HandlePushNotification processes push notifications with hook support.
|
||||
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) == 0 {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
notificationType, ok := notification[0].(string)
|
||||
if !ok {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Process pre-hooks - they can modify the notification or skip processing
|
||||
modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, notificationType, notification)
|
||||
if !shouldContinue {
|
||||
return nil // Hooks decided to skip processing
|
||||
}
|
||||
|
||||
var err error
|
||||
switch notificationType {
|
||||
case NotificationMoving:
|
||||
err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationMigrating:
|
||||
err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationMigrated:
|
||||
err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationFailingOver:
|
||||
err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationFailedOver:
|
||||
err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
|
||||
default:
|
||||
// Ignore other notification types (e.g., pub/sub messages)
|
||||
err = nil
|
||||
}
|
||||
|
||||
// Process post-hooks with the result
|
||||
snh.manager.processPostHooks(ctx, notificationType, modifiedNotification, err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// handleMoving processes MOVING notifications.
|
||||
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
|
||||
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) < 3 {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
seqIDStr, ok := notification[1].(string)
|
||||
if !ok {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
seqID, err := strconv.ParseInt(seqIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Extract timeS
|
||||
timeSStr, ok := notification[2].(string)
|
||||
if !ok {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
timeS, err := strconv.ParseInt(timeSStr, 10, 64)
|
||||
if err != nil {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
newEndpoint := ""
|
||||
if len(notification) > 3 {
|
||||
// Extract new endpoint
|
||||
newEndpoint, ok = notification[3].(string)
|
||||
if !ok {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
}
|
||||
|
||||
// Get the connection that received this notification
|
||||
conn := handlerCtx.Conn
|
||||
if conn == nil {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Type assert to get the underlying pool connection
|
||||
var poolConn *pool.Conn
|
||||
if connAdapter, ok := conn.(interface{ GetPoolConn() *pool.Conn }); ok {
|
||||
poolConn = connAdapter.GetPoolConn()
|
||||
} else if pc, ok := conn.(*pool.Conn); ok {
|
||||
poolConn = pc
|
||||
} else {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
|
||||
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
|
||||
if newEndpoint == "" || newEndpoint == internal.RedisNull {
|
||||
// same as current endpoint
|
||||
newEndpoint = snh.manager.options.GetAddr()
|
||||
// delay the handoff for timeS/2 seconds to the same endpoint
|
||||
// do this in a goroutine to avoid blocking the notification handler
|
||||
go func() {
|
||||
time.Sleep(time.Duration(timeS/2) * time.Second)
|
||||
if poolConn == nil || poolConn.IsClosed() {
|
||||
return
|
||||
}
|
||||
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
|
||||
// Log error but don't fail the goroutine
|
||||
internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
|
||||
}
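// Illustrative example of the MOVING payload handled above (values are hypothetical):
//
//	notification := []interface{}{"MOVING", "42", "30", "10.0.0.5:6379"}
//	// seqID = 42, timeS = 30 (grace period in seconds), new endpoint = "10.0.0.5:6379"
//	// With a missing or null endpoint, the handoff is instead scheduled to the
//	// current endpoint after timeS/2 seconds.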
|
||||
|
||||
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
|
||||
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
|
||||
// Connection is already marked for handoff, which is acceptable
|
||||
// This can happen if multiple MOVING notifications are received for the same connection
|
||||
return nil
|
||||
}
|
||||
// Optionally track in hitless manager for monitoring/debugging
|
||||
if snh.manager != nil {
|
||||
connID := conn.GetID()
|
||||
|
||||
// Track the operation (ignore errors since this is optional)
|
||||
_ = snh.manager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
|
||||
} else {
|
||||
return fmt.Errorf("hitless: manager not initialized")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMigrating processes MIGRATING notifications.
|
||||
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// MIGRATING notifications indicate that a connection is about to be migrated
|
||||
// Apply relaxed timeouts to the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Get the connection from handler context and type assert to connectionAdapter
|
||||
if handlerCtx.Conn == nil {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
|
||||
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
|
||||
if !ok {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Apply relaxed timeout to this specific connection
|
||||
connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMigrated processes MIGRATED notifications.
|
||||
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// MIGRATED notifications indicate that a connection migration has completed
|
||||
// Restore normal timeouts for the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Get the connection from handler context and type assert to connectionAdapter
|
||||
if handlerCtx.Conn == nil {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
|
||||
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
|
||||
if !ok {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Clear relaxed timeout for this specific connection
|
||||
connAdapter.ClearRelaxedTimeout()
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailingOver processes FAILING_OVER notifications.
|
||||
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// FAILING_OVER notifications indicate that a connection is about to failover
|
||||
// Apply relaxed timeouts to the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Get the connection from handler context and type assert to connectionAdapter
|
||||
if handlerCtx.Conn == nil {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
|
||||
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
|
||||
if !ok {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Apply relaxed timeout to this specific connection
|
||||
connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailedOver processes FAILED_OVER notifications.
|
||||
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// FAILED_OVER notifications indicate that a connection failover has completed
|
||||
// Restore normal timeouts for the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Get the connection from handler context and type assert to connectionAdapter
|
||||
if handlerCtx.Conn == nil {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
|
||||
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
|
||||
if !ok {
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Clear relaxed timeout for this specific connection
|
||||
connAdapter.ClearRelaxedTimeout()
|
||||
return nil
|
||||
}
|
477
hitless/pool_hook.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// HitlessManagerInterface defines the interface for completing handoff operations
|
||||
type HitlessManagerInterface interface {
|
||||
TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
|
||||
UntrackOperationWithConnID(seqID int64, connID uint64)
|
||||
}
|
||||
|
||||
// HandoffRequest represents a request to handoff a connection to a new endpoint
|
||||
type HandoffRequest struct {
|
||||
Conn *pool.Conn
|
||||
ConnID uint64 // Unique connection identifier
|
||||
Endpoint string
|
||||
SeqID int64
|
||||
Pool pool.Pooler // Pool to remove connection from on failure
|
||||
}
|
||||
|
||||
// PoolHook implements pool.PoolHook for Redis-specific connection handling
|
||||
// with hitless upgrade support.
|
||||
type PoolHook struct {
|
||||
// Base dialer for creating connections to new endpoints during handoffs
|
||||
// args are network and address
|
||||
baseDialer func(context.Context, string, string) (net.Conn, error)
|
||||
|
||||
// Network type (e.g., "tcp", "unix")
|
||||
network string
|
||||
|
||||
// Event-driven handoff support
|
||||
handoffQueue chan HandoffRequest // Queue for handoff requests
|
||||
shutdown chan struct{} // Shutdown signal
|
||||
shutdownOnce sync.Once // Ensure clean shutdown
|
||||
workerWg sync.WaitGroup // Track worker goroutines
|
||||
|
||||
// On-demand worker management
|
||||
maxWorkers int
|
||||
activeWorkers int32 // Atomic counter for active workers
|
||||
workerTimeout time.Duration // How long workers wait for work before exiting
|
||||
|
||||
// Simple state tracking
|
||||
pending sync.Map // map[uint64]int64 (connID -> seqID)
|
||||
|
||||
// Configuration for the hitless upgrade
|
||||
config *Config
|
||||
|
||||
// Hitless manager for operation completion tracking
|
||||
hitlessManager HitlessManagerInterface
|
||||
|
||||
// Pool interface for removing connections on handoff failure
|
||||
pool pool.Pooler
|
||||
}
|
||||
|
||||
// NewPoolHook creates a new pool hook
|
||||
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface) *PoolHook {
|
||||
return NewPoolHookWithPoolSize(baseDialer, network, config, hitlessManager, 0)
|
||||
}
|
||||
|
||||
// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
|
||||
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface, poolSize int) *PoolHook {
|
||||
// Apply defaults to any missing configuration fields, using pool size for worker calculations
|
||||
config = config.ApplyDefaultsWithPoolSize(poolSize)
|
||||
|
||||
ph := &PoolHook{
|
||||
// baseDialer is used to create connections to new endpoints during handoffs
|
||||
baseDialer: baseDialer,
|
||||
network: network,
|
||||
// handoffQueue is a buffered channel for queuing handoff requests
|
||||
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
|
||||
// shutdown is a channel for signaling shutdown
|
||||
shutdown: make(chan struct{}),
|
||||
maxWorkers: config.MaxWorkers,
|
||||
activeWorkers: 0, // Start with no workers - create on demand
|
||||
workerTimeout: 30 * time.Second, // Workers exit after 30s of inactivity
|
||||
config: config,
|
||||
// Hitless manager for operation completion tracking
|
||||
hitlessManager: hitlessManager,
|
||||
}
|
||||
|
||||
// No upfront worker creation - workers are created on demand
|
||||
|
||||
return ph
|
||||
}
|
||||
|
||||
// SetPool sets the pool interface for removing connections on handoff failure
|
||||
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
|
||||
ph.pool = pooler
|
||||
}
|
||||
|
||||
// GetCurrentWorkers returns the current number of active workers (for testing)
|
||||
func (ph *PoolHook) GetCurrentWorkers() int {
|
||||
return int(atomic.LoadInt32(&ph.activeWorkers))
|
||||
}
|
||||
|
||||
// GetScaleLevel returns 1 if workers are active, 0 if none (for testing compatibility)
|
||||
func (ph *PoolHook) GetScaleLevel() int {
|
||||
if atomic.LoadInt32(&ph.activeWorkers) > 0 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsHandoffPending returns true if the given connection has a pending handoff
|
||||
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
|
||||
_, pending := ph.pending.Load(conn.GetID())
|
||||
return pending
|
||||
}
|
||||
|
||||
// OnGet is called when a connection is retrieved from the pool
|
||||
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, isNewConn bool) error {
|
||||
// NOTE: The two checks below ensure we don't return a connection that should be handed off
// or that is currently in a handoff state.
|
||||
|
||||
// Check if connection is usable (not in a handoff state)
|
||||
// Should not happen since the pool will not return a connection that is not usable.
|
||||
if !conn.IsUsable() {
|
||||
return ErrConnectionMarkedForHandoff
|
||||
}
|
||||
|
||||
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
|
||||
if conn.ShouldHandoff() {
|
||||
return ErrConnectionMarkedForHandoff
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnPut is called when a connection is returned to the pool
|
||||
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
// first check if we should handoff for faster rejection
|
||||
if conn.ShouldHandoff() {
|
||||
// check pending handoff to not queue the same connection twice
|
||||
_, hasPendingHandoff := ph.pending.Load(conn.GetID())
|
||||
if !hasPendingHandoff {
|
||||
// Check for empty endpoint first (synchronous check)
|
||||
if conn.GetHandoffEndpoint() == "" {
|
||||
conn.ClearHandoffState()
|
||||
} else {
|
||||
if err := ph.queueHandoff(conn); err != nil {
|
||||
// Failed to queue handoff, remove the connection
|
||||
internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err)
|
||||
return false, true, nil // Don't pool, remove connection, no error to caller
|
||||
}
|
||||
|
||||
// Check if handoff was already processed by a worker before we can mark it as queued
|
||||
if !conn.ShouldHandoff() {
|
||||
// Handoff was already processed - this is normal and the connection should be pooled
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
if err := conn.MarkQueuedForHandoff(); err != nil {
|
||||
// If marking fails, check if handoff was processed in the meantime
|
||||
if !conn.ShouldHandoff() {
|
||||
// Handoff was processed - this is normal, pool the connection
|
||||
return true, false, nil
|
||||
}
|
||||
// Other error - remove the connection
|
||||
return false, true, nil
|
||||
}
|
||||
return true, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
// Default: pool the connection
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
// ensureWorkerAvailable ensures at least one worker is available to process requests
|
||||
// Creates a new worker if needed and under the max limit
|
||||
func (ph *PoolHook) ensureWorkerAvailable() {
|
||||
select {
|
||||
case <-ph.shutdown:
|
||||
return
|
||||
default:
|
||||
// Check if we need a new worker
|
||||
currentWorkers := atomic.LoadInt32(&ph.activeWorkers)
|
||||
if currentWorkers < int32(ph.maxWorkers) {
|
||||
// Try to create a new worker (atomic increment to prevent race)
|
||||
if atomic.CompareAndSwapInt32(&ph.activeWorkers, currentWorkers, currentWorkers+1) {
|
||||
ph.workerWg.Add(1)
|
||||
go ph.onDemandWorker()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// onDemandWorker processes handoff requests and exits when idle
|
||||
func (ph *PoolHook) onDemandWorker() {
|
||||
defer func() {
|
||||
// Decrement active worker count when exiting
|
||||
atomic.AddInt32(&ph.activeWorkers, -1)
|
||||
ph.workerWg.Done()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case request := <-ph.handoffQueue:
|
||||
// Check for shutdown before processing
|
||||
select {
|
||||
case <-ph.shutdown:
|
||||
// Clean up the request before exiting
|
||||
ph.pending.Delete(request.ConnID)
|
||||
return
|
||||
default:
|
||||
// Process the request
|
||||
ph.processHandoffRequest(request)
|
||||
}
|
||||
|
||||
case <-time.After(ph.workerTimeout):
|
||||
// Worker has been idle for too long, exit to save resources
|
||||
if ph.config != nil && ph.config.LogLevel >= 3 { // Debug level
|
||||
internal.Logger.Printf(context.Background(),
|
||||
"hitless: worker exiting due to inactivity timeout (%v)", ph.workerTimeout)
|
||||
}
|
||||
return
|
||||
|
||||
case <-ph.shutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processHandoffRequest processes a single handoff request
|
||||
func (ph *PoolHook) processHandoffRequest(request HandoffRequest) {
|
||||
// Remove from pending map
|
||||
defer ph.pending.Delete(request.Conn.GetID())
|
||||
|
||||
// Create a context with handoff timeout from config
|
||||
handoffTimeout := 30 * time.Second // Default fallback
|
||||
if ph.config != nil && ph.config.HandoffTimeout > 0 {
|
||||
handoffTimeout = ph.config.HandoffTimeout
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Create a context that also respects the shutdown signal
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
|
||||
defer shutdownCancel()
|
||||
|
||||
// Monitor shutdown signal in a separate goroutine
|
||||
go func() {
|
||||
select {
|
||||
case <-ph.shutdown:
|
||||
shutdownCancel()
|
||||
case <-shutdownCtx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
// Perform the handoff with cancellable context
|
||||
err := ph.performConnectionHandoffWithPool(shutdownCtx, request.Conn, request.Pool)
|
||||
|
||||
// If handoff failed, restore the handoff state for potential retry
|
||||
if err != nil {
|
||||
request.Conn.RestoreHandoffState()
|
||||
internal.Logger.Printf(context.Background(), "hitless: handoff failed for connection %d, will retry: %v", request.Conn.GetID(), err)
|
||||
}
|
||||
|
||||
// No need for scale down scheduling with on-demand workers
|
||||
// Workers automatically exit when idle
|
||||
}
|
||||
|
||||
// queueHandoff queues a handoff request for processing
|
||||
// if err is returned, connection will be removed from pool
|
||||
func (ph *PoolHook) queueHandoff(conn *pool.Conn) error {
|
||||
// Create handoff request
|
||||
request := HandoffRequest{
|
||||
Conn: conn,
|
||||
ConnID: conn.GetID(),
|
||||
Endpoint: conn.GetHandoffEndpoint(),
|
||||
SeqID: conn.GetMovingSeqID(),
|
||||
Pool: ph.pool, // Include pool for connection removal on failure
|
||||
}
|
||||
|
||||
select {
|
||||
// priority to shutdown
|
||||
case <-ph.shutdown:
|
||||
return errors.New("shutdown")
|
||||
default:
|
||||
select {
|
||||
case <-ph.shutdown:
|
||||
return errors.New("shutdown")
|
||||
case ph.handoffQueue <- request:
|
||||
// Store in pending map
|
||||
ph.pending.Store(request.ConnID, request.SeqID)
|
||||
// Ensure we have a worker to process this request
|
||||
ph.ensureWorkerAvailable()
|
||||
return nil
|
||||
default:
|
||||
// Queue is full - log and attempt scaling
|
||||
queueLen := len(ph.handoffQueue)
|
||||
queueCap := cap(ph.handoffQueue)
|
||||
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
|
||||
internal.Logger.Printf(context.Background(),
|
||||
"hitless: handoff queue is full (%d/%d), attempting timeout queuing and scaling workers",
|
||||
queueLen, queueCap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we have workers available to handle the load
|
||||
ph.ensureWorkerAvailable()
|
||||
return errors.New("queue full")
|
||||
}
|
||||
|
||||
// performConnectionHandoffWithPool performs the actual connection handoff with pool for connection removal on failure
|
||||
// if err is returned, connection will be removed from pool
|
||||
func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn *pool.Conn, pooler pool.Pooler) error {
|
||||
// Capture the handoff metadata before attempting the handoff.
|
||||
seqID := conn.GetMovingSeqID()
|
||||
connID := conn.GetID()
|
||||
|
||||
// Notify hitless manager of completion if available
|
||||
if ph.hitlessManager != nil {
|
||||
defer ph.hitlessManager.UntrackOperationWithConnID(seqID, connID)
|
||||
}
|
||||
|
||||
newEndpoint := conn.GetHandoffEndpoint()
|
||||
if newEndpoint == "" {
|
||||
// TODO(hitless): Handle by performing the handoff to the current endpoint in N seconds,
|
||||
// Where N is the time in the moving notification...
|
||||
// For now, clear the handoff state and return
|
||||
conn.ClearHandoffState()
|
||||
return nil
|
||||
}
|
||||
|
||||
retries := conn.IncrementAndGetHandoffRetries(1)
|
||||
maxRetries := 3 // Default fallback
|
||||
if ph.config != nil {
|
||||
maxRetries = ph.config.MaxHandoffRetries
|
||||
}
|
||||
|
||||
if retries > maxRetries {
|
||||
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
|
||||
internal.Logger.Printf(ctx,
|
||||
"hitless: reached max retries (%d) for handoff of connection %d to %s",
|
||||
maxRetries, conn.GetID(), conn.GetHandoffEndpoint())
|
||||
}
|
||||
err := ErrMaxHandoffRetriesReached
|
||||
if pooler != nil {
|
||||
go pooler.Remove(ctx, conn, err)
|
||||
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
|
||||
internal.Logger.Printf(ctx,
|
||||
"hitless: removed connection %d from pool due to max handoff retries reached",
|
||||
conn.GetID())
|
||||
}
|
||||
} else {
|
||||
go conn.Close()
|
||||
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
|
||||
internal.Logger.Printf(ctx,
|
||||
"hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v",
|
||||
conn.GetID(), err)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}

	// Create endpoint-specific dialer
	endpointDialer := ph.createEndpointDialer(newEndpoint)

	// Create new connection to the new endpoint
	newNetConn, err := endpointDialer(ctx)
	if err != nil {
		// TODO(hitless): retry
		// This is the only case where we should retry the handoff request
		// Should we do anything else other than return the error?
		return err
	}

	// Get the old connection
	oldConn := conn.GetNetConn()

	// Replace the connection and execute initialization
	err = conn.SetNetConnAndInitConn(ctx, newNetConn)
	if err != nil {
		// Remove the connection from the pool since it's in a bad state
		if pooler != nil {
			// Use pool.Pooler interface directly - no adapter needed
			go pooler.Remove(ctx, conn, err)
			if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
				internal.Logger.Printf(ctx,
					"hitless: removed connection %d from pool due to handoff initialization failure: %v",
					conn.GetID(), err)
			}
		} else {
			go conn.Close()
			if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
				internal.Logger.Printf(ctx,
					"hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v",
					conn.GetID(), err)
			}
		}

		// Keep the handoff state for retry
		return err
	}
	defer func() {
		if oldConn != nil {
			oldConn.Close()
		}
	}()

	conn.ClearHandoffState()

	// Apply relaxed timeout to the new connection for the configured post-handoff duration
	// This gives the new connection more time to handle operations during cluster transition
	if ph.config != nil && ph.config.PostHandoffRelaxedDuration > 0 {
		relaxedTimeout := ph.config.RelaxedTimeout
		postHandoffDuration := ph.config.PostHandoffRelaxedDuration

		// Set relaxed timeout with deadline - no background goroutine needed
		deadline := time.Now().Add(postHandoffDuration)
		conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)

		if ph.config.LogLevel >= 2 { // Info level
			internal.Logger.Printf(context.Background(),
				"hitless: applied post-handoff relaxed timeout (%v) until %v for connection %d",
				relaxedTimeout, deadline.Format("15:04:05.000"), connID)
		}
	}

	return nil
}

// createEndpointDialer creates a dialer function that connects to a specific endpoint
func (ph *PoolHook) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
	return func(ctx context.Context) (net.Conn, error) {
		// Parse endpoint to extract host and port
		host, port, err := net.SplitHostPort(endpoint)
		if err != nil {
			// If no port specified, assume default Redis port
			host = endpoint
			if port == "" {
				port = "6379"
			}
		}

		// Use the base dialer to connect to the new endpoint
		return ph.baseDialer(ctx, ph.network, net.JoinHostPort(host, port))
	}
}
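
For orientation, a minimal in-package sketch (illustrative only, not part of the diff; it assumes a *PoolHook value ph and a context.Context ctx in scope, and a made-up endpoint) of how the dialer returned above is consumed during a handoff; endpoints without a port fall back to 6379:

	dial := ph.createEndpointDialer("10.0.0.5:6380") // hypothetical endpoint from a MOVING notification
	newNetConn, err := dial(ctx)
	if err != nil {
		// dial failed; the handoff request remains retryable
	}
	_ = newNetConn // on success it is passed to conn.SetNetConnAndInitConn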

// Shutdown gracefully shuts down the processor, waiting for workers to complete
func (ph *PoolHook) Shutdown(ctx context.Context) error {
	ph.shutdownOnce.Do(func() {
		close(ph.shutdown)

		// No timers to clean up with on-demand workers
	})

	// Wait for workers to complete
	done := make(chan struct{})
	go func() {
		ph.workerWg.Wait()
		close(done)
	}()

	select {
	case <-done:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
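
A short usage sketch (an assumption for illustration; hook stands for a *PoolHook created elsewhere) of a bounded graceful shutdown using the context handling above:

	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := hook.Shutdown(shutdownCtx); err != nil {
		// err is shutdownCtx.Err(): handoff workers were still running when the timeout expired
	}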

// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
// and should not be used until the handoff is complete.
var ErrConnectionMarkedForHandoff = errors.New("connection marked for handoff")
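
Illustrative sketch only (the helper name and retry policy are assumptions, not part of this commit): a pool integration that calls OnGet can skip connections reporting this sentinel error and pick another one:

	func getUsableConn(ctx context.Context, p pool.Pooler, hook *PoolHook) (*pool.Conn, error) {
		for {
			cn, err := p.Get(ctx)
			if err != nil {
				return nil, err
			}
			if err := hook.OnGet(ctx, cn, false); err != nil {
				if errors.Is(err, ErrConnectionMarkedForHandoff) {
					p.Put(ctx, cn) // connection is mid-handoff; return it and try another
					continue
				}
				return nil, err
			}
			return cn, nil
		}
	}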
959
hitless/pool_hook_test.go
Normal file
@@ -0,0 +1,959 @@
package hitless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// mockNetConn implements net.Conn for testing
|
||||
type mockNetConn struct {
|
||||
addr string
|
||||
shouldFailInit bool
|
||||
}
|
||||
|
||||
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
|
||||
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
|
||||
func (m *mockNetConn) Close() error { return nil }
|
||||
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
|
||||
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
|
||||
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
type mockAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (m *mockAddr) Network() string { return "tcp" }
|
||||
func (m *mockAddr) String() string { return m.addr }
|
||||
|
||||
// createMockPoolConnection creates a mock pool connection for testing
|
||||
func createMockPoolConnection() *pool.Conn {
|
||||
mockNetConn := &mockNetConn{addr: "test:6379"}
|
||||
conn := pool.NewConn(mockNetConn)
|
||||
conn.SetUsable(true) // Make connection usable for testing
|
||||
return conn
|
||||
}
|
||||
|
||||
// mockPool implements pool.Pooler for testing
|
||||
type mockPool struct {
|
||||
removedConnections map[uint64]bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (mp *mockPool) CloseConn(conn *pool.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) {
|
||||
// Not implemented for testing
|
||||
}
|
||||
|
||||
func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
// Use pool.Conn directly - no adapter needed
|
||||
mp.removedConnections[conn.GetID()] = true
|
||||
}
|
||||
|
||||
// WasRemoved safely checks if a connection was removed from the pool
|
||||
func (mp *mockPool) WasRemoved(connID uint64) bool {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
return mp.removedConnections[connID]
|
||||
}
|
||||
|
||||
func (mp *mockPool) Len() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (mp *mockPool) IdleLen() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (mp *mockPool) Stats() *pool.Stats {
|
||||
return &pool.Stats{}
|
||||
}
|
||||
|
||||
func (mp *mockPool) AddPoolHook(hook pool.PoolHook) {
|
||||
// Mock implementation - do nothing
|
||||
}
|
||||
|
||||
func (mp *mockPool) RemovePoolHook(hook pool.PoolHook) {
|
||||
// Mock implementation - do nothing
|
||||
}
|
||||
|
||||
func (mp *mockPool) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestConnectionHook tests the Redis connection processor functionality
|
||||
func TestConnectionHook(t *testing.T) {
|
||||
// Create a base dialer for testing
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) {
|
||||
config := &Config{
|
||||
Mode: MaintNotificationsAuto,
|
||||
EndpointType: EndpointTypeAuto,
|
||||
MaxWorkers: 1, // Use only 1 worker to ensure synchronization
|
||||
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
|
||||
MaxHandoffRetries: 3,
|
||||
LogLevel: 2,
|
||||
}
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Verify connection is marked for handoff
|
||||
if !conn.ShouldHandoff() {
|
||||
t.Fatal("Connection should be marked for handoff")
|
||||
}
|
||||
// Set a mock initialization function with synchronization
|
||||
initConnCalled := make(chan bool, 1)
|
||||
proceedWithInit := make(chan bool, 1)
|
||||
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
|
||||
select {
|
||||
case initConnCalled <- true:
|
||||
default:
|
||||
}
|
||||
// Wait for test to proceed
|
||||
<-proceedWithInit
|
||||
return nil
|
||||
}
|
||||
conn.SetInitConnFunc(initConnFunc)
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection immediately (handoff queued)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled immediately with event-driven handoff")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when queuing handoff")
|
||||
}
|
||||
|
||||
// Wait for initialization to be called (indicates handoff started)
|
||||
select {
|
||||
case <-initConnCalled:
|
||||
// Good, initialization was called
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("Timeout waiting for initialization function to be called")
|
||||
}
|
||||
|
||||
// Connection should be in pending map while initialization is blocked
|
||||
if _, pending := processor.pending.Load(conn.GetID()); !pending {
|
||||
t.Error("Connection should be in pending handoffs map")
|
||||
}
|
||||
|
||||
// Allow initialization to proceed
|
||||
proceedWithInit <- true
|
||||
|
||||
// Wait for handoff to complete with proper timeout and polling
|
||||
timeout := time.After(2 * time.Second)
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for handoff to complete")
|
||||
case <-ticker.C:
|
||||
				if _, pending := processor.pending.Load(conn.GetID()); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify handoff completed (removed from pending map)
|
||||
		if _, pending := processor.pending.Load(conn.GetID()); pending {
|
||||
t.Error("Connection should be removed from pending map after handoff")
|
||||
}
|
||||
|
||||
// Verify connection is usable again
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should be usable after successful handoff")
|
||||
}
|
||||
|
||||
// Verify handoff state is cleared
|
||||
if conn.ShouldHandoff() {
|
||||
t.Error("Connection should not be marked for handoff after completion")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HandoffNotNeeded", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
// Don't mark for handoff
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error when handoff not needed: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection normally
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled when no handoff needed")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when no handoff needed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyEndpoint", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error with empty endpoint: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection (empty endpoint clears state)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled after clearing empty endpoint")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed after clearing empty endpoint")
|
||||
}
|
||||
|
||||
// State should be cleared
|
||||
if conn.ShouldHandoff() {
|
||||
t.Error("Connection should not be marked for handoff after clearing empty endpoint")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EventDrivenHandoffDialerError", func(t *testing.T) {
|
||||
// Create a failing base dialer
|
||||
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return nil, errors.New("dial failed")
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Mode: MaintNotificationsAuto,
|
||||
EndpointType: EndpointTypeAuto,
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 3,
|
||||
HandoffTimeout: 1 * time.Second, // Shorter timeout for faster test
|
||||
LogLevel: 2,
|
||||
}
|
||||
processor := NewPoolHook(failingDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not return error to caller: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection initially (handoff queued)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled initially with event-driven handoff")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when queuing handoff")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete and fail with proper timeout and polling
|
||||
// Use longer timeout to account for handoff timeout + processing time
|
||||
timeout := time.After(5 * time.Second)
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
// wait for handoff to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for failed handoff to complete")
|
||||
case <-ticker.C:
|
||||
if _, pending := processor.pending.Load(conn.GetID()); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Connection should be removed from pending map after failed handoff
|
||||
if _, pending := processor.pending.Load(conn.GetID()); pending {
|
||||
t.Error("Connection should be removed from pending map after failed handoff")
|
||||
}
|
||||
|
||||
// Handoff state should still be set (since handoff failed)
|
||||
if !conn.ShouldHandoff() {
|
||||
t.Error("Connection should still be marked for handoff after failed handoff")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("BufferedDataRESP2", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// For this test, we'll just verify the logic works for connections without buffered data
|
||||
// The actual buffered data detection is handled by the pool's connection health check
|
||||
// which is outside the scope of the Redis connection processor
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
// Should pool the connection normally (no buffered data in mock)
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled when no buffered data")
|
||||
}
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed when no buffered data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnGet", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should not error for normal connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
|
||||
config := &Config{
|
||||
Mode: MaintNotificationsAuto,
|
||||
EndpointType: EndpointTypeAuto,
|
||||
MaxWorkers: 2,
|
||||
			HandoffQueueSize:  10, // Explicit queue size to avoid 0-size queue
			MaxHandoffRetries: 3,
|
||||
LogLevel: 2,
|
||||
}
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// Simulate a pending handoff by marking for handoff and queuing
|
||||
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||
processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
		processor.pending.Delete(conn.GetID())
|
||||
})
|
||||
|
||||
t.Run("EventDrivenStateManagement", func(t *testing.T) {
|
||||
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// Test initial state - no pending handoffs
|
||||
		if _, pending := processor.pending.Load(conn.GetID()); pending {
|
||||
t.Error("New connection should not have pending handoffs")
|
||||
}
|
||||
|
||||
// Test adding to pending map
|
||||
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||
processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||
|
||||
if _, pending := processor.pending.Load(conn.GetID()); !pending {
|
||||
t.Error("Connection should be in pending map")
|
||||
}
|
||||
|
||||
// Test OnGet with pending handoff
|
||||
ctx := context.Background()
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
|
||||
}
|
||||
|
||||
// Test removing from pending map and clearing handoff state
|
||||
		processor.pending.Delete(conn.GetID())
|
||||
		if _, pending := processor.pending.Load(conn.GetID()); pending {
|
||||
t.Error("Connection should be removed from pending map")
|
||||
}
|
||||
|
||||
// Clear handoff state to simulate completed handoff
|
||||
conn.ClearHandoffState()
|
||||
conn.SetUsable(true) // Make connection usable again
|
||||
|
||||
// Test OnGet without pending handoff
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("Should not return error for non-pending connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
|
||||
// Create processor with small queue to test optimization features
|
||||
config := &Config{
|
||||
MaxWorkers: 3,
|
||||
			HandoffQueueSize:  2, // Small queue to trigger optimizations
			MaxHandoffRetries: 3,
|
||||
LogLevel: 3, // Debug level to see optimization logs
|
||||
}
|
||||
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// Add small delay to simulate network latency
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create multiple connections that need handoff to fill the queue
|
||||
connections := make([]*pool.Conn, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
connections[i] = createMockPoolConnection()
|
||||
if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil {
|
||||
t.Fatalf("Failed to mark connection %d for handoff: %v", i, err)
|
||||
}
|
||||
// Set a mock initialization function
|
||||
connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
successCount := 0
|
||||
|
||||
// Process connections - should trigger scaling and timeout logic
|
||||
for _, conn := range connections {
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Logf("OnPut returned error (expected with timeout): %v", err)
|
||||
}
|
||||
|
||||
if shouldPool && !shouldRemove {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
// With timeout and scaling, most handoffs should eventually succeed
|
||||
if successCount == 0 {
|
||||
t.Error("Should have queued some handoffs with timeout and scaling")
|
||||
}
|
||||
|
||||
t.Logf("Successfully queued %d handoffs with optimization features", successCount)
|
||||
|
||||
// Give time for workers to process and scaling to occur
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("WorkerScalingBehavior", func(t *testing.T) {
|
||||
// Create processor with small queue to test scaling behavior
|
||||
config := &Config{
|
||||
MaxWorkers: 15, // Set to >= 10 to test explicit value preservation
|
||||
			HandoffQueueSize:  1, // Very small queue to force scaling
			MaxHandoffRetries: 3,
|
||||
LogLevel: 2, // Info level to see scaling logs
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Verify initial worker count (should be 0 with on-demand workers)
|
||||
if processor.GetCurrentWorkers() != 0 {
|
||||
t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers())
|
||||
}
|
||||
if processor.GetScaleLevel() != 0 {
|
||||
t.Errorf("Processor should be at scale level 0 initially, got %d", processor.GetScaleLevel())
|
||||
}
|
||||
if processor.maxWorkers != 15 {
|
||||
t.Errorf("Expected maxWorkers=15, got %d", processor.maxWorkers)
|
||||
}
|
||||
|
||||
// The on-demand worker behavior creates workers only when needed
|
||||
// This test just verifies the basic configuration is correct
|
||||
t.Logf("On-demand worker configuration verified - Max: %d, Current: %d",
|
||||
processor.maxWorkers, processor.GetCurrentWorkers())
|
||||
})
|
||||
|
||||
t.Run("PassiveTimeoutRestoration", func(t *testing.T) {
|
||||
// Create processor with fast post-handoff duration for testing
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing
|
||||
RelaxedTimeout: 5 * time.Second,
|
||||
LogLevel: 2,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a connection and trigger handoff
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Process the connection to trigger handoff
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("Handoff should succeed: %v", err)
|
||||
}
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Error("Connection should be pooled after handoff")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete with proper timeout and polling
|
||||
timeout := time.After(1 * time.Second)
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for handoff to complete")
|
||||
case <-ticker.C:
|
||||
				if _, pending := processor.pending.Load(conn.GetID()); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify relaxed timeout is set with deadline
|
||||
if !conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should have relaxed timeout after handoff")
|
||||
}
|
||||
|
||||
// Test that timeout is still active before deadline
|
||||
// We'll use HasRelaxedTimeout which internally checks the deadline
|
||||
if !conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should still have active relaxed timeout before deadline")
|
||||
}
|
||||
|
||||
// Wait for deadline to pass
|
||||
time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer
|
||||
|
||||
// Test that timeout is automatically restored after deadline
|
||||
// HasRelaxedTimeout should return false after deadline passes
|
||||
if conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should not have active relaxed timeout after deadline")
|
||||
}
|
||||
|
||||
// Additional verification: calling HasRelaxedTimeout again should still return false
|
||||
// and should have cleared the internal timeout values
|
||||
if conn.HasRelaxedTimeout() {
|
||||
t.Error("Connection should not have relaxed timeout after deadline (second check)")
|
||||
}
|
||||
|
||||
t.Logf("Passive timeout restoration test completed successfully")
|
||||
})
|
||||
|
||||
t.Run("UsableFlagBehavior", func(t *testing.T) {
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 3,
|
||||
LogLevel: 2,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a new connection without setting it usable
|
||||
mockNetConn := &mockNetConn{addr: "test:6379"}
|
||||
conn := pool.NewConn(mockNetConn)
|
||||
|
||||
// Initially, connection should not be usable (not initialized)
|
||||
if conn.IsUsable() {
|
||||
t.Error("New connection should not be usable before initialization")
|
||||
}
|
||||
|
||||
// Simulate initialization by setting usable to true
|
||||
conn.SetUsable(true)
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should be usable after initialization")
|
||||
}
|
||||
|
||||
// OnGet should succeed for usable connection
|
||||
err := processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should succeed for usable connection: %v", err)
|
||||
}
|
||||
|
||||
// Mark connection for handoff
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Connection should still be usable until queued, but marked for handoff
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should still be usable after being marked for handoff (until queued)")
|
||||
}
|
||||
if !conn.ShouldHandoff() {
|
||||
t.Error("Connection should be marked for handoff")
|
||||
}
|
||||
|
||||
// OnGet should fail for connection marked for handoff
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
if err == nil {
|
||||
t.Error("OnGet should fail for connection marked for handoff")
|
||||
}
|
||||
if err != ErrConnectionMarkedForHandoff {
|
||||
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||
}
|
||||
|
||||
// Process the connection to trigger handoff
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should succeed: %v", err)
|
||||
}
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Error("Connection should be pooled after handoff")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// After handoff completion, connection should be usable again
|
||||
if !conn.IsUsable() {
|
||||
t.Error("Connection should be usable after handoff completion")
|
||||
}
|
||||
|
||||
// OnGet should succeed again
|
||||
err = processor.OnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("OnGet should succeed after handoff completion: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Usable flag behavior test completed successfully")
|
||||
})
|
||||
|
||||
t.Run("StaticQueueBehavior", func(t *testing.T) {
|
||||
config := &Config{
|
||||
MaxWorkers: 3,
|
||||
			HandoffQueueSize:  50, // Explicit static queue size
			MaxHandoffRetries: 3,
|
||||
LogLevel: 2,
|
||||
}
|
||||
|
||||
processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Verify queue capacity matches configured size
|
||||
queueCapacity := cap(processor.handoffQueue)
|
||||
if queueCapacity != 50 {
|
||||
t.Errorf("Expected queue capacity 50, got %d", queueCapacity)
|
||||
}
|
||||
|
||||
// Test that queue size is static regardless of pool size
|
||||
// (No dynamic resizing should occur)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill part of the queue
|
||||
for i := 0; i < 10; i++ {
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil {
|
||||
t.Fatalf("Failed to mark connection %d for handoff: %v", i, err)
|
||||
}
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to queue handoff %d: %v", i, err)
|
||||
}
|
||||
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Errorf("Connection %d should be pooled after handoff (shouldPool=%v, shouldRemove=%v)",
|
||||
i, shouldPool, shouldRemove)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify queue capacity remains static (the main purpose of this test)
|
||||
finalCapacity := cap(processor.handoffQueue)
|
||||
|
||||
if finalCapacity != 50 {
|
||||
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
|
||||
}
|
||||
|
||||
// Note: We don't check queue size here because workers process items quickly
|
||||
// The important thing is that the capacity remains static regardless of pool size
|
||||
})
|
||||
|
||||
t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) {
|
||||
// Create a failing dialer that will cause handoff initialization to fail
|
||||
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// Return a connection that will fail during initialization
|
||||
return &mockNetConn{addr: addr, shouldFailInit: true}, nil
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
MaxHandoffRetries: 3,
|
||||
LogLevel: 2,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(failingDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create a mock pool that tracks removals
|
||||
mockPool := &mockPool{removedConnections: make(map[uint64]bool)}
|
||||
processor.SetPool(mockPool)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a connection and mark it for handoff
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a failing initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return fmt.Errorf("initialization failed")
|
||||
})
|
||||
|
||||
// Process the connection - handoff should fail and connection should be removed
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not error: %v", err)
|
||||
}
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Error("Connection should be pooled after failed handoff attempt")
|
||||
}
|
||||
|
||||
// Wait for handoff to be attempted and fail
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify that the connection was removed from the pool
|
||||
if !mockPool.WasRemoved(conn.GetID()) {
|
||||
t.Errorf("Connection %d should have been removed from pool after handoff failure", conn.GetID())
|
||||
}
|
||||
|
||||
t.Logf("Connection removal on handoff failure test completed successfully")
|
||||
})
|
||||
|
||||
t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) {
|
||||
// Create config with short post-handoff duration for testing
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
RelaxedTimeout: 5 * time.Second,
|
||||
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing
|
||||
}
|
||||
|
||||
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &mockNetConn{addr: addr}, nil
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a mock initialization function
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("OnPut failed: %v", err)
|
||||
}
|
||||
|
||||
if !shouldPool {
|
||||
t.Error("Connection should be pooled after successful handoff")
|
||||
}
|
||||
|
||||
if shouldRemove {
|
||||
t.Error("Connection should not be removed after successful handoff")
|
||||
}
|
||||
|
||||
// Wait for the handoff to complete (it happens asynchronously)
|
||||
timeout := time.After(1 * time.Second)
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
handoffCompleted := false
|
||||
for !handoffCompleted {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Timeout waiting for handoff to complete")
|
||||
case <-ticker.C:
|
||||
				if _, pending := processor.pending.Load(conn.GetID()); !pending {
|
||||
handoffCompleted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that relaxed timeout was applied to the new connection
|
||||
if !conn.HasRelaxedTimeout() {
|
||||
t.Error("New connection should have relaxed timeout applied after handoff")
|
||||
}
|
||||
|
||||
// Wait for the post-handoff duration to expire
|
||||
time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration
|
||||
|
||||
// Verify that relaxed timeout was automatically cleared
|
||||
if conn.HasRelaxedTimeout() {
|
||||
t.Error("Relaxed timeout should be automatically cleared after post-handoff duration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MarkForHandoff returns error when already marked", func(t *testing.T) {
|
||||
conn := createMockPoolConnection()
|
||||
|
||||
// First mark should succeed
|
||||
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||
t.Fatalf("First MarkForHandoff should succeed: %v", err)
|
||||
}
|
||||
|
||||
// Second mark should fail
|
||||
if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil {
|
||||
t.Fatal("Second MarkForHandoff should return error")
|
||||
} else if err.Error() != "connection is already marked for handoff" {
|
||||
t.Fatalf("Expected specific error message, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify original handoff data is preserved
|
||||
if !conn.ShouldHandoff() {
|
||||
t.Fatal("Connection should still be marked for handoff")
|
||||
}
|
||||
if conn.GetHandoffEndpoint() != "new-endpoint:6379" {
|
||||
t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint())
|
||||
}
|
||||
if conn.GetMovingSeqID() != 1 {
|
||||
t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HandoffTimeoutConfiguration", func(t *testing.T) {
|
||||
// Test that HandoffTimeout from config is actually used
|
||||
customTimeout := 2 * time.Second
|
||||
config := &Config{
|
||||
MaxWorkers: 2,
|
||||
HandoffQueueSize: 10,
|
||||
HandoffTimeout: customTimeout, // Custom timeout
|
||||
MaxHandoffRetries: 1, // Single retry to speed up test
|
||||
LogLevel: 2,
|
||||
}
|
||||
|
||||
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||
defer processor.Shutdown(context.Background())
|
||||
|
||||
// Create a connection that will test the timeout
|
||||
conn := createMockPoolConnection()
|
||||
if err := conn.MarkForHandoff("test-endpoint:6379", 123); err != nil {
|
||||
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||
}
|
||||
|
||||
// Set a dialer that will check the context timeout
|
||||
var timeoutVerified int32 // Use atomic for thread safety
|
||||
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||
// Check that the context has the expected timeout
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Error("Context should have a deadline")
|
||||
return errors.New("no deadline")
|
||||
}
|
||||
|
||||
// The deadline should be approximately customTimeout from now
|
||||
expectedDeadline := time.Now().Add(customTimeout)
|
||||
timeDiff := deadline.Sub(expectedDeadline)
|
||||
if timeDiff < -500*time.Millisecond || timeDiff > 500*time.Millisecond {
|
||||
t.Errorf("Context deadline not as expected. Expected around %v, got %v (diff: %v)",
|
||||
expectedDeadline, deadline, timeDiff)
|
||||
} else {
|
||||
atomic.StoreInt32(&timeoutVerified, 1)
|
||||
}
|
||||
|
||||
return nil // Successful handoff
|
||||
})
|
||||
|
||||
// Trigger handoff
|
||||
shouldPool, shouldRemove, err := processor.OnPut(context.Background(), conn)
|
||||
if err != nil {
|
||||
t.Errorf("OnPut should not return error: %v", err)
|
||||
}
|
||||
|
||||
// Connection should be queued for handoff
|
||||
if !shouldPool || shouldRemove {
|
||||
t.Errorf("Connection should be pooled for handoff processing")
|
||||
}
|
||||
|
||||
// Wait for handoff to complete
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
if atomic.LoadInt32(&timeoutVerified) == 0 {
|
||||
t.Error("HandoffTimeout was not properly applied to context")
|
||||
}
|
||||
|
||||
t.Logf("HandoffTimeout configuration test completed successfully")
|
||||
})
|
||||
}
24
hitless/state.go
Normal file
@@ -0,0 +1,24 @@
package hitless

// State represents the current state of a hitless upgrade operation.
type State int

const (
	// StateIdle indicates no upgrade is in progress
	StateIdle State = iota

	// StateMoving indicates a connection handoff is in progress
	StateMoving
)

// String returns a string representation of the state.
func (s State) String() string {
	switch s {
	case StateIdle:
		return "idle"
	case StateMoving:
		return "moving"
	default:
		return "unknown"
	}
}
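
A tiny illustrative example (an assumption: it would live in an example or test file of the same package with fmt imported) showing how the states stringify:

	func ExampleState_String() {
		fmt.Println(StateIdle, StateMoving)
		// Output: idle moving
	}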
67
internal/interfaces/interfaces.go
Normal file
@@ -0,0 +1,67 @@
// Package interfaces provides shared interfaces used by both the main redis package
// and the hitless upgrade package to avoid circular dependencies.
package interfaces

import (
	"context"
	"net"
	"time"
)

// Forward declaration to avoid circular imports
type NotificationProcessor interface {
	RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
	UnregisterHandler(pushNotificationName string) error
	GetHandler(pushNotificationName string) interface{}
}

// ClientInterface defines the interface that clients must implement for hitless upgrades.
type ClientInterface interface {
	// GetOptions returns the client options.
	GetOptions() OptionsInterface

	// GetPushProcessor returns the client's push notification processor.
	GetPushProcessor() NotificationProcessor
}

// OptionsInterface defines the interface for client options.
type OptionsInterface interface {
	// GetReadTimeout returns the read timeout.
	GetReadTimeout() time.Duration

	// GetWriteTimeout returns the write timeout.
	GetWriteTimeout() time.Duration

	// GetNetwork returns the network type.
	GetNetwork() string

	// GetAddr returns the connection address.
	GetAddr() string

	// IsTLSEnabled returns true if TLS is enabled.
	IsTLSEnabled() bool

	// GetProtocol returns the protocol version.
	GetProtocol() int

	// GetPoolSize returns the connection pool size.
	GetPoolSize() int

	// NewDialer returns a new dialer function for the connection.
	NewDialer() func(context.Context) (net.Conn, error)
}

// ConnectionWithRelaxedTimeout defines the interface for connections that support relaxed timeout adjustment.
// This is used by the hitless upgrade system for per-connection timeout management.
type ConnectionWithRelaxedTimeout interface {
	// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
	// These timeouts remain active until explicitly cleared.
	SetRelaxedTimeout(readTimeout, writeTimeout time.Duration)

	// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
	// After the deadline, timeouts automatically revert to normal values.
	SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time)

	// ClearRelaxedTimeout clears relaxed timeouts for this connection.
	ClearRelaxedTimeout()
}
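
As a hedged illustration (the assertion below is not part of this commit), a consumer package could verify at compile time that the pool connection type extended later in this diff satisfies the relaxed-timeout contract:

	// assumes imports of internal/interfaces and internal/pool
	var _ interfaces.ConnectionWithRelaxedTimeout = (*pool.Conn)(nil)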
|
@@ -2,6 +2,7 @@ package pool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -31,7 +32,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
|
||||
b.Run(bm.String(), func(b *testing.B) {
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
-				PoolSize:        bm.poolSize,
+				PoolSize:        int32(bm.poolSize),
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Hour,
|
||||
@@ -75,7 +76,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
|
||||
b.Run(bm.String(), func(b *testing.B) {
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
-				PoolSize:        bm.poolSize,
+				PoolSize:        int32(bm.poolSize),
|
||||
PoolTimeout: time.Second,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Hour,
|
||||
@@ -89,7 +90,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
-				connPool.Remove(ctx, cn, nil)
+				connPool.Remove(ctx, cn, errors.New("Bench test remove"))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
@@ -26,7 +26,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
It("should use default buffer sizes when not specified", func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
-			PoolSize:        1,
+			PoolSize:        int32(1),
|
||||
PoolTimeout: 1000,
|
||||
})
|
||||
|
||||
@@ -48,7 +48,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
-			PoolSize:        1,
+			PoolSize:        int32(1),
|
||||
PoolTimeout: 1000,
|
||||
ReadBufferSize: customReadSize,
|
||||
WriteBufferSize: customWriteSize,
|
||||
@@ -69,7 +69,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
It("should handle zero buffer sizes by using defaults", func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
-			PoolSize:        1,
+			PoolSize:        int32(1),
|
||||
PoolTimeout: 1000,
|
||||
ReadBufferSize: 0, // Should use default
|
||||
WriteBufferSize: 0, // Should use default
|
||||
@@ -105,7 +105,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
||||
// without setting ReadBufferSize and WriteBufferSize
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
-			PoolSize:        1,
+			PoolSize:        int32(1),
|
||||
PoolTimeout: 1000,
|
||||
// ReadBufferSize and WriteBufferSize are not set (will be 0)
|
||||
})
|
||||
|
@@ -3,7 +3,10 @@ package pool
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -12,17 +15,64 @@ import (
|
||||
|
||||
var noDeadline = time.Time{}
|
||||
|
||||
// Global atomic counter for connection IDs
|
||||
var connIDCounter uint64
|
||||
|
||||
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
|
||||
type atomicNetConn struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// generateConnID generates a fast unique identifier for a connection with zero allocations
|
||||
func generateConnID() uint64 {
|
||||
return atomic.AddUint64(&connIDCounter, 1)
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
-	usedAt  int64 // atomic
-	netConn net.Conn
+	usedAt int64 // atomic
|
||||
|
||||
// Lock-free netConn access using atomic.Value
|
||||
// Contains *atomicNetConn wrapper, accessed atomically for better performance
|
||||
netConnAtomic atomic.Value // stores *atomicNetConn
|
||||
|
||||
rd *proto.Reader
|
||||
bw *bufio.Writer
|
||||
wr *proto.Writer
|
||||
|
||||
-	Inited bool
+	// Lightweight mutex to protect reader operations during handoff
+	// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
+	readerMu sync.RWMutex
+
+	Inited    atomic.Bool
|
||||
pooled bool
|
||||
closed atomic.Bool
|
||||
createdAt time.Time
|
||||
expiresAt time.Time
|
||||
|
||||
// Hitless upgrade support: relaxed timeouts during migrations/failovers
|
||||
// Using atomic operations for lock-free access to avoid mutex contention
|
||||
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
|
||||
|
||||
// Counter to track multiple relaxed timeout setters if we have nested calls
|
||||
// will be decremented when ClearRelaxedTimeout is called or deadline is reached
|
||||
// if counter reaches 0, we clear the relaxed timeouts
|
||||
relaxedCounter atomic.Int32
|
||||
|
||||
// Connection initialization function for reconnections
|
||||
initConnFunc func(context.Context, *Conn) error
|
||||
|
||||
// Connection identifier for unique tracking across handoffs
|
||||
id uint64 // Unique numeric identifier for this connection
|
||||
|
||||
// Handoff state - using atomic operations for lock-free access
|
||||
usableAtomic atomic.Bool // Connection usability state
|
||||
shouldHandoffAtomic atomic.Bool // Whether connection should be handed off
|
||||
movingSeqIDAtomic atomic.Int64 // Sequence ID from MOVING notification
|
||||
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
|
||||
// newEndpointAtomic needs special handling as it's a string
|
||||
newEndpointAtomic atomic.Value // stores string
|
||||
|
||||
onClose func() error
|
||||
}
|
||||
@@ -33,8 +83,8 @@ func NewConn(netConn net.Conn) *Conn {
|
||||
|
||||
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
|
||||
cn := &Conn{
|
||||
netConn: netConn,
|
||||
createdAt: time.Now(),
|
||||
id: generateConnID(), // Generate unique ID for this connection
|
||||
}
|
||||
|
||||
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
|
||||
@@ -50,6 +100,16 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
|
||||
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
|
||||
}
|
||||
|
||||
// Store netConn atomically for lock-free access using wrapper
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
|
||||
// Initialize atomic handoff state
|
||||
cn.usableAtomic.Store(false) // false initially, set to true after initialization
|
||||
cn.shouldHandoffAtomic.Store(false) // false initially
|
||||
cn.movingSeqIDAtomic.Store(0) // 0 initially
|
||||
cn.handoffRetriesAtomic.Store(0) // 0 initially
|
||||
cn.newEndpointAtomic.Store("") // empty string initially
|
||||
|
||||
cn.wr = proto.NewWriter(cn.bw)
|
||||
cn.SetUsedAt(time.Now())
|
||||
return cn
|
||||
@@ -64,23 +124,368 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
|
||||
atomic.StoreInt64(&cn.usedAt, tm.Unix())
|
||||
}
|
||||
|
||||
// getNetConn returns the current network connection using atomic load (lock-free).
|
||||
// This is the fast path for accessing netConn without mutex overhead.
|
||||
func (cn *Conn) getNetConn() net.Conn {
|
||||
if v := cn.netConnAtomic.Load(); v != nil {
|
||||
if wrapper, ok := v.(*atomicNetConn); ok {
|
||||
return wrapper.conn
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setNetConn stores the network connection atomically (lock-free).
|
||||
// This is used for the fast path of connection replacement.
|
||||
func (cn *Conn) setNetConn(netConn net.Conn) {
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
}
|
||||
|
||||
// Lock-free helper methods for handoff state management
|
||||
|
||||
// isUsable returns true if the connection is safe to use (lock-free).
|
||||
func (cn *Conn) isUsable() bool {
|
||||
return cn.usableAtomic.Load()
|
||||
}
|
||||
|
||||
// setUsable sets the usable flag atomically (lock-free).
|
||||
func (cn *Conn) setUsable(usable bool) {
|
||||
cn.usableAtomic.Store(usable)
|
||||
}
|
||||
|
||||
// shouldHandoff returns true if connection needs handoff (lock-free).
|
||||
func (cn *Conn) shouldHandoff() bool {
|
||||
return cn.shouldHandoffAtomic.Load()
|
||||
}
|
||||
|
||||
// setShouldHandoff sets the handoff flag atomically (lock-free).
|
||||
func (cn *Conn) setShouldHandoff(should bool) {
|
||||
cn.shouldHandoffAtomic.Store(should)
|
||||
}
|
||||
|
||||
// getMovingSeqID returns the sequence ID atomically (lock-free).
|
||||
func (cn *Conn) getMovingSeqID() int64 {
|
||||
return cn.movingSeqIDAtomic.Load()
|
||||
}
|
||||
|
||||
// setMovingSeqID sets the sequence ID atomically (lock-free).
|
||||
func (cn *Conn) setMovingSeqID(seqID int64) {
|
||||
cn.movingSeqIDAtomic.Store(seqID)
|
||||
}
|
||||
|
||||
// getNewEndpoint returns the new endpoint atomically (lock-free).
|
||||
func (cn *Conn) getNewEndpoint() string {
|
||||
if endpoint := cn.newEndpointAtomic.Load(); endpoint != nil {
|
||||
return endpoint.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// setNewEndpoint sets the new endpoint atomically (lock-free).
|
||||
func (cn *Conn) setNewEndpoint(endpoint string) {
|
||||
cn.newEndpointAtomic.Store(endpoint)
|
||||
}
|
||||
|
||||
// setHandoffRetries sets the retry count atomically (lock-free).
|
||||
func (cn *Conn) setHandoffRetries(retries int) {
|
||||
cn.handoffRetriesAtomic.Store(uint32(retries))
|
||||
}
|
||||
|
||||
// incrementHandoffRetries atomically increments and returns the new retry count (lock-free).
|
||||
func (cn *Conn) incrementHandoffRetries(delta int) int {
|
||||
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
|
||||
}
|
||||
|
||||
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
|
||||
func (cn *Conn) IsUsable() bool {
|
||||
return cn.isUsable()
|
||||
}
|
||||
|
||||
func (cn *Conn) IsInited() bool {
|
||||
return cn.Inited.Load()
|
||||
}
|
||||
|
||||
// SetUsable sets the usable flag for the connection (lock-free).
|
||||
func (cn *Conn) SetUsable(usable bool) {
|
||||
cn.setUsable(usable)
|
||||
}
|
||||
|
||||
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
|
||||
// These timeouts will be used for all subsequent commands until the deadline expires.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
|
||||
cn.relaxedCounter.Add(1)
|
||||
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
|
||||
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
|
||||
}
|
||||
|
||||
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
|
||||
// After the deadline, timeouts automatically revert to normal values.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
|
||||
cn.relaxedCounter.Add(1)
|
||||
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
|
||||
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
|
||||
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
|
||||
}
|
||||
|
||||
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) ClearRelaxedTimeout() {
|
||||
// Atomically decrement counter and check if we should clear
|
||||
newCount := cn.relaxedCounter.Add(-1)
|
||||
if newCount <= 0 {
|
||||
// Use compare-and-swap to ensure only one goroutine clears
|
||||
if cn.relaxedCounter.CompareAndSwap(newCount, 0) {
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *Conn) clearRelaxedTimeout() {
|
||||
cn.relaxedReadTimeoutNs.Store(0)
|
||||
cn.relaxedWriteTimeoutNs.Store(0)
|
||||
cn.relaxedDeadlineNs.Store(0)
|
||||
cn.relaxedCounter.Store(0)
|
||||
}
|
||||
|
||||
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
|
||||
// This checks both the timeout values and the deadline (if set).
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) HasRelaxedTimeout() bool {
|
||||
// Fast path: no relaxed timeouts are set
|
||||
if cn.relaxedCounter.Load() <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||
|
||||
// If no relaxed timeouts are set, return false
|
||||
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, relaxed timeouts are active
|
||||
if deadlineNs == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// If deadline is set, check if it's still in the future
|
||||
return time.Now().UnixNano() < deadlineNs
|
||||
}
|
||||
|
||||
// getEffectiveReadTimeout returns the timeout to use for read operations.
|
||||
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
|
||||
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||
|
||||
// Fast path: no relaxed timeout set
|
||||
if readTimeoutNs <= 0 {
|
||||
return normalTimeout
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, use relaxed timeout
|
||||
if deadlineNs == 0 {
|
||||
return time.Duration(readTimeoutNs)
|
||||
}
|
||||
|
||||
nowNs := time.Now().UnixNano()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
return time.Duration(readTimeoutNs)
|
||||
} else {
|
||||
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||
cn.relaxedCounter.Add(-1)
|
||||
if cn.relaxedCounter.Load() <= 0 {
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
return normalTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// getEffectiveWriteTimeout returns the timeout to use for write operations.
|
||||
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
|
||||
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||
|
||||
// Fast path: no relaxed timeout set
|
||||
if writeTimeoutNs <= 0 {
|
||||
return normalTimeout
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, use relaxed timeout
|
||||
if deadlineNs == 0 {
|
||||
return time.Duration(writeTimeoutNs)
|
||||
}
|
||||
|
||||
nowNs := time.Now().UnixNano()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
return time.Duration(writeTimeoutNs)
|
||||
} else {
|
||||
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||
cn.relaxedCounter.Add(-1)
|
||||
if cn.relaxedCounter.Load() <= 0 {
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
return normalTimeout
|
||||
}
|
||||
}
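The accessors above are counter-based: every setter call bumps relaxedCounter, every ClearRelaxedTimeout decrements it, and the stored values are only dropped once the last user clears them (or the deadline expires). A minimal pairing sketch follows; SetRelaxedTimeout is an assumed name for the setter whose body opens this hunk, not a confirmed API, and "time" is assumed to be imported.

// Sketch only: counter-based pairing of relaxed timeouts (SetRelaxedTimeout is an assumed name).
cn.SetRelaxedTimeout(30*time.Second, 30*time.Second) // relaxedCounter: 0 -> 1
cn.SetRelaxedTimeout(30*time.Second, 30*time.Second) // relaxedCounter: 1 -> 2

cn.ClearRelaxedTimeout() // 2 -> 1: values stay relaxed
cn.ClearRelaxedTimeout() // 1 -> 0: clearRelaxedTimeout() zeroes everything

_ = cn.HasRelaxedTimeout() // false once the counter is back to zero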
|
||||
|
||||
func (cn *Conn) SetOnClose(fn func() error) {
|
||||
cn.onClose = fn
|
||||
}
|
||||
|
||||
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
|
||||
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
|
||||
cn.initConnFunc = fn
|
||||
}
|
||||
|
||||
// ExecuteInitConn runs the stored connection initialization function if available.
|
||||
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
|
||||
if cn.initConnFunc != nil {
|
||||
return cn.initConnFunc(ctx, cn)
|
||||
}
|
||||
return fmt.Errorf("redis: no initConnFunc set for connection %d", cn.GetID())
|
||||
}
|
||||
|
||||
func (cn *Conn) SetNetConn(netConn net.Conn) {
|
||||
cn.netConn = netConn
|
||||
// Store the new connection atomically first (lock-free)
|
||||
cn.setNetConn(netConn)
|
||||
// Clear relaxed timeouts when connection is replaced
|
||||
cn.clearRelaxedTimeout()
|
||||
|
||||
// Protect reader reset operations to avoid data races
|
||||
// Use write lock since we're modifying the reader state
|
||||
cn.readerMu.Lock()
|
||||
cn.rd.Reset(netConn)
|
||||
cn.readerMu.Unlock()
|
||||
|
||||
cn.bw.Reset(netConn)
|
||||
}
|
||||
|
||||
// GetNetConn safely returns the current network connection using atomic load (lock-free).
|
||||
// This method is used by the pool for health checks and provides better performance.
|
||||
func (cn *Conn) GetNetConn() net.Conn {
|
||||
return cn.getNetConn()
|
||||
}
|
||||
|
||||
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
|
||||
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
|
||||
// New connection is not initialized yet
|
||||
cn.Inited.Store(false)
|
||||
// Replace the underlying connection
|
||||
cn.SetNetConn(netConn)
|
||||
return cn.ExecuteInitConn(ctx)
|
||||
}
|
||||
|
||||
// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
|
||||
// Returns an error if the connection is already marked for handoff.
|
||||
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
|
||||
// Use single atomic CAS operation for state transition
|
||||
if !cn.shouldHandoffAtomic.CompareAndSwap(false, true) {
|
||||
return errors.New("connection is already marked for handoff")
|
||||
}
|
||||
|
||||
cn.setNewEndpoint(newEndpoint)
|
||||
cn.setMovingSeqID(seqID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cn *Conn) MarkQueuedForHandoff() error {
|
||||
// Use single atomic CAS operation for state transition
|
||||
if !cn.shouldHandoffAtomic.CompareAndSwap(true, false) {
|
||||
return errors.New("connection was not marked for handoff")
|
||||
}
|
||||
cn.setUsable(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreHandoffState restores the handoff state after a failed handoff (lock-free).
|
||||
func (cn *Conn) RestoreHandoffState() {
|
||||
// Restore shouldHandoff flag for retry
|
||||
cn.shouldHandoffAtomic.Store(true)
|
||||
// Keep usable=false to prevent the connection from being used until handoff succeeds
|
||||
cn.setUsable(false)
|
||||
}
|
||||
|
||||
// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
|
||||
func (cn *Conn) ShouldHandoff() bool {
|
||||
return cn.shouldHandoff()
|
||||
}
|
||||
|
||||
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
|
||||
func (cn *Conn) GetHandoffEndpoint() string {
|
||||
return cn.getNewEndpoint()
|
||||
}
|
||||
|
||||
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
|
||||
func (cn *Conn) GetMovingSeqID() int64 {
|
||||
return cn.getMovingSeqID()
|
||||
}
|
||||
|
||||
// GetID returns the unique identifier for this connection.
|
||||
func (cn *Conn) GetID() uint64 {
|
||||
return cn.id
|
||||
}
|
||||
|
||||
// ClearHandoffState clears the handoff state after successful handoff (lock-free).
|
||||
func (cn *Conn) ClearHandoffState() {
|
||||
// clear handoff state
|
||||
cn.setShouldHandoff(false)
|
||||
cn.setNewEndpoint("")
|
||||
cn.setMovingSeqID(0)
|
||||
cn.setHandoffRetries(0)
|
||||
cn.setUsable(true) // Connection is safe to use again after handoff completes
|
||||
}
|
||||
|
||||
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
|
||||
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
|
||||
return cn.incrementHandoffRetries(n)
|
||||
}
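Read together, MarkForHandoff, MarkQueuedForHandoff, RestoreHandoffState, SetNetConnAndInitConn and ClearHandoffState form a small state machine. The sketch below shows one plausible ordering of those calls for a worker performing the handoff; the dial function is a placeholder and this is not the shipped hitless worker.

// Sketch only: presumed handoff flow for a connection previously marked via
// cn.MarkForHandoff(endpoint, seqID) by the MOVING notification handler.
func handOffConn(ctx context.Context, cn *pool.Conn,
	dial func(context.Context, string) (net.Conn, error)) error {
	if err := cn.MarkQueuedForHandoff(); err != nil {
		return err // not marked for handoff, nothing to do
	}
	netConn, err := dial(ctx, cn.GetHandoffEndpoint())
	if err != nil {
		cn.RestoreHandoffState() // keep it marked (and unusable) so the handoff is retried
		return err
	}
	if err := cn.SetNetConnAndInitConn(ctx, netConn); err != nil {
		cn.RestoreHandoffState()
		return err
	}
	cn.ClearHandoffState() // endpoint, seqID and retries reset; connection usable again
	return nil
}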
|
||||
|
||||
// HasBufferedData safely checks if the connection has buffered data.
|
||||
// This method is used to avoid data races when checking for push notifications.
|
||||
func (cn *Conn) HasBufferedData() bool {
|
||||
// Use read lock for concurrent access to reader state
|
||||
cn.readerMu.RLock()
|
||||
defer cn.readerMu.RUnlock()
|
||||
return cn.rd.Buffered() > 0
|
||||
}
|
||||
|
||||
// PeekReplyTypeSafe safely peeks at the reply type.
|
||||
// This method is used to avoid data races when checking for push notifications.
|
||||
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
|
||||
// Use read lock for concurrent access to reader state
|
||||
cn.readerMu.RLock()
|
||||
defer cn.readerMu.RUnlock()
|
||||
|
||||
if cn.rd.Buffered() <= 0 {
|
||||
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
|
||||
}
|
||||
return cn.rd.PeekReplyType()
|
||||
}
|
||||
|
||||
func (cn *Conn) Write(b []byte) (int, error) {
|
||||
return cn.netConn.Write(b)
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.Write(b)
|
||||
}
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
func (cn *Conn) RemoteAddr() net.Addr {
|
||||
if cn.netConn != nil {
|
||||
return cn.netConn.RemoteAddr()
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.RemoteAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -89,7 +494,16 @@ func (cn *Conn) WithReader(
|
||||
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
|
||||
) error {
|
||||
if timeout >= 0 {
|
||||
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
|
||||
// Use relaxed timeout if set, otherwise use provided timeout
|
||||
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
|
||||
|
||||
// Get the connection directly from atomic storage
|
||||
netConn := cn.getNetConn()
|
||||
if netConn == nil {
|
||||
return fmt.Errorf("redis: connection not available")
|
||||
}
|
||||
|
||||
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -100,13 +514,26 @@ func (cn *Conn) WithWriter(
|
||||
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
|
||||
) error {
|
||||
if timeout >= 0 {
|
||||
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
|
||||
return err
|
||||
// Use relaxed timeout if set, otherwise use provided timeout
|
||||
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
|
||||
|
||||
// Always set write deadline, even if getNetConn() returns nil
|
||||
// This prevents write operations from hanging indefinitely
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// If getNetConn() returns nil, we still need to respect the timeout
|
||||
// Return an error to prevent indefinite blocking
|
||||
return fmt.Errorf("redis: connection not available for write operation")
|
||||
}
|
||||
}
|
||||
|
||||
if cn.bw.Buffered() > 0 {
|
||||
cn.bw.Reset(cn.netConn)
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
cn.bw.Reset(netConn)
|
||||
}
|
||||
}
|
||||
|
||||
if err := fn(cn.wr); err != nil {
|
||||
@@ -116,19 +543,33 @@ func (cn *Conn) WithWriter(
|
||||
return cn.bw.Flush()
|
||||
}
|
||||
|
||||
func (cn *Conn) IsClosed() bool {
|
||||
return cn.closed.Load()
|
||||
}
|
||||
|
||||
func (cn *Conn) Close() error {
|
||||
cn.closed.Store(true)
|
||||
if cn.onClose != nil {
|
||||
// ignore error
|
||||
_ = cn.onClose()
|
||||
}
|
||||
return cn.netConn.Close()
|
||||
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaybeHasData tries to peek at the next byte in the socket without consuming it.
// It is used to check whether push notifications are available.
// Important: this works on Linux, but not on Windows.
|
||||
func (cn *Conn) MaybeHasData() bool {
|
||||
return maybeHasData(cn.netConn)
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return maybeHasData(netConn)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
|
||||
|
@@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) {
|
||||
}
|
||||
|
||||
func (cn *Conn) NetConn() net.Conn {
|
||||
return cn.netConn
|
||||
return cn.getNetConn()
|
||||
}
|
||||
|
||||
func (p *ConnPool) CheckMinIdleConns() {
|
||||
|
114
internal/pool/hooks.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PoolHook defines the interface for connection lifecycle hooks.
|
||||
type PoolHook interface {
|
||||
// OnGet is called when a connection is retrieved from the pool.
|
||||
// It can modify the connection or return an error to prevent its use.
|
||||
// The isNewConn flag indicates whether this is a new connection (rather than an idle one from the pool).
|
||||
// The flag can be used for gathering metrics on pool hit/miss ratio.
|
||||
OnGet(ctx context.Context, conn *Conn, isNewConn bool) error
|
||||
|
||||
// OnPut is called when a connection is returned to the pool.
|
||||
// It returns whether the connection should be pooled and whether it should be removed.
|
||||
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
|
||||
}
|
||||
|
||||
// PoolHookManager manages multiple pool hooks.
|
||||
type PoolHookManager struct {
|
||||
hooks []PoolHook
|
||||
hooksMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewPoolHookManager creates a new pool hook manager.
|
||||
func NewPoolHookManager() *PoolHookManager {
|
||||
return &PoolHookManager{
|
||||
hooks: make([]PoolHook, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddHook adds a pool hook to the manager.
|
||||
// Hooks are called in the order they were added.
|
||||
func (phm *PoolHookManager) AddHook(hook PoolHook) {
|
||||
phm.hooksMu.Lock()
|
||||
defer phm.hooksMu.Unlock()
|
||||
phm.hooks = append(phm.hooks, hook)
|
||||
}
|
||||
|
||||
// RemoveHook removes a pool hook from the manager.
|
||||
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
|
||||
phm.hooksMu.Lock()
|
||||
defer phm.hooksMu.Unlock()
|
||||
|
||||
for i, h := range phm.hooks {
|
||||
if h == hook {
|
||||
// Remove hook by swapping with last element and truncating
|
||||
phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
|
||||
phm.hooks = phm.hooks[:len(phm.hooks)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessOnGet calls all OnGet hooks in order.
|
||||
// If any hook returns an error, processing stops and the error is returned.
|
||||
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
for _, hook := range phm.hooks {
|
||||
if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessOnPut calls all OnPut hooks in order.
|
||||
// The first hook that returns shouldRemove=true or shouldPool=false will stop processing.
|
||||
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
shouldPool = true // Default to pooling the connection
|
||||
|
||||
for _, hook := range phm.hooks {
|
||||
hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)
|
||||
|
||||
if hookErr != nil {
|
||||
return false, true, hookErr
|
||||
}
|
||||
|
||||
// If any hook says to remove or not pool, respect that decision
|
||||
if hookShouldRemove {
|
||||
return false, true, nil
|
||||
}
|
||||
|
||||
if !hookShouldPool {
|
||||
shouldPool = false
|
||||
}
|
||||
}
|
||||
|
||||
return shouldPool, false, nil
|
||||
}
|
||||
|
||||
// GetHookCount returns the number of registered hooks (for testing).
|
||||
func (phm *PoolHookManager) GetHookCount() int {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
return len(phm.hooks)
|
||||
}
|
||||
|
||||
// GetHooks returns a copy of all registered hooks.
|
||||
func (phm *PoolHookManager) GetHooks() []PoolHook {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
hooks := make([]PoolHook, len(phm.hooks))
|
||||
copy(hooks, phm.hooks)
|
||||
return hooks
|
||||
}
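As a usage illustration (not part of this commit), a hook that only tracks pool hit/miss counts could look like the sketch below; it uses nothing beyond the PoolHook interface above and the AddPoolHook method added to ConnPool further down, plus "sync/atomic".

// metricsHook is a sketch of a PoolHook that counts hits and misses.
type metricsHook struct {
	hits, misses atomic.Uint64 // requires "sync/atomic"
}

func (h *metricsHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
	if isNewConn {
		h.misses.Add(1)
	} else {
		h.hits.Add(1)
	}
	return nil
}

func (h *metricsHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
	return true, false, nil // never veto pooling or force removal
}

// Registration, given p *ConnPool:
//   p.AddPoolHook(&metricsHook{})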
|
213
internal/pool/hooks_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestHook for testing hook functionality
|
||||
type TestHook struct {
|
||||
OnGetCalled int
|
||||
OnPutCalled int
|
||||
GetError error
|
||||
PutError error
|
||||
ShouldPool bool
|
||||
ShouldRemove bool
|
||||
}
|
||||
|
||||
func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
|
||||
th.OnGetCalled++
|
||||
return th.GetError
|
||||
}
|
||||
|
||||
func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
th.OnPutCalled++
|
||||
return th.ShouldPool, th.ShouldRemove, th.PutError
|
||||
}
|
||||
|
||||
func TestPoolHookManager(t *testing.T) {
|
||||
manager := NewPoolHookManager()
|
||||
|
||||
// Test initial state
|
||||
if manager.GetHookCount() != 0 {
|
||||
t.Errorf("Expected 0 hooks initially, got %d", manager.GetHookCount())
|
||||
}
|
||||
|
||||
// Add hooks
|
||||
hook1 := &TestHook{ShouldPool: true}
|
||||
hook2 := &TestHook{ShouldPool: true}
|
||||
|
||||
manager.AddHook(hook1)
|
||||
manager.AddHook(hook2)
|
||||
|
||||
if manager.GetHookCount() != 2 {
|
||||
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test ProcessOnGet
|
||||
ctx := context.Background()
|
||||
conn := &Conn{} // Mock connection
|
||||
|
||||
err := manager.ProcessOnGet(ctx, conn, false)
|
||||
if err != nil {
|
||||
t.Errorf("ProcessOnGet should not error: %v", err)
|
||||
}
|
||||
|
||||
if hook1.OnGetCalled != 1 {
|
||||
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
|
||||
}
|
||||
|
||||
if hook2.OnGetCalled != 1 {
|
||||
t.Errorf("Expected hook2.OnGetCalled to be 1, got %d", hook2.OnGetCalled)
|
||||
}
|
||||
|
||||
// Test ProcessOnPut
|
||||
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("ProcessOnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
if !shouldPool {
|
||||
t.Error("Expected shouldPool to be true")
|
||||
}
|
||||
|
||||
if shouldRemove {
|
||||
t.Error("Expected shouldRemove to be false")
|
||||
}
|
||||
|
||||
if hook1.OnPutCalled != 1 {
|
||||
t.Errorf("Expected hook1.OnPutCalled to be 1, got %d", hook1.OnPutCalled)
|
||||
}
|
||||
|
||||
if hook2.OnPutCalled != 1 {
|
||||
t.Errorf("Expected hook2.OnPutCalled to be 1, got %d", hook2.OnPutCalled)
|
||||
}
|
||||
|
||||
// Remove a hook
|
||||
manager.RemoveHook(hook1)
|
||||
|
||||
if manager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookErrorHandling(t *testing.T) {
|
||||
manager := NewPoolHookManager()
|
||||
|
||||
// Hook that returns error on Get
|
||||
errorHook := &TestHook{
|
||||
GetError: errors.New("test error"),
|
||||
ShouldPool: true,
|
||||
}
|
||||
|
||||
normalHook := &TestHook{ShouldPool: true}
|
||||
|
||||
manager.AddHook(errorHook)
|
||||
manager.AddHook(normalHook)
|
||||
|
||||
ctx := context.Background()
|
||||
conn := &Conn{}
|
||||
|
||||
// Test that error stops processing
|
||||
err := manager.ProcessOnGet(ctx, conn, false)
|
||||
if err == nil {
|
||||
t.Error("Expected error from ProcessOnGet")
|
||||
}
|
||||
|
||||
if errorHook.OnGetCalled != 1 {
|
||||
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
|
||||
}
|
||||
|
||||
// normalHook should not be called due to error
|
||||
if normalHook.OnGetCalled != 0 {
|
||||
t.Errorf("Expected normalHook.OnGetCalled to be 0, got %d", normalHook.OnGetCalled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookShouldRemove(t *testing.T) {
|
||||
manager := NewPoolHookManager()
|
||||
|
||||
// Hook that says to remove connection
|
||||
removeHook := &TestHook{
|
||||
ShouldPool: false,
|
||||
ShouldRemove: true,
|
||||
}
|
||||
|
||||
normalHook := &TestHook{ShouldPool: true}
|
||||
|
||||
manager.AddHook(removeHook)
|
||||
manager.AddHook(normalHook)
|
||||
|
||||
ctx := context.Background()
|
||||
conn := &Conn{}
|
||||
|
||||
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
|
||||
if err != nil {
|
||||
t.Errorf("ProcessOnPut should not error: %v", err)
|
||||
}
|
||||
|
||||
if shouldPool {
|
||||
t.Error("Expected shouldPool to be false")
|
||||
}
|
||||
|
||||
if !shouldRemove {
|
||||
t.Error("Expected shouldRemove to be true")
|
||||
}
|
||||
|
||||
if removeHook.OnPutCalled != 1 {
|
||||
t.Errorf("Expected removeHook.OnPutCalled to be 1, got %d", removeHook.OnPutCalled)
|
||||
}
|
||||
|
||||
// normalHook should not be called due to early return
|
||||
if normalHook.OnPutCalled != 0 {
|
||||
t.Errorf("Expected normalHook.OnPutCalled to be 0, got %d", normalHook.OnPutCalled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolWithHooks(t *testing.T) {
|
||||
// Create a pool with hooks
|
||||
hookManager := NewPoolHookManager()
|
||||
testHook := &TestHook{ShouldPool: true}
|
||||
hookManager.AddHook(testHook)
|
||||
|
||||
opt := &Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &net.TCPConn{}, nil // Mock connection
|
||||
},
|
||||
PoolSize: 1,
|
||||
DialTimeout: time.Second,
|
||||
}
|
||||
|
||||
pool := NewConnPool(opt)
|
||||
defer pool.Close()
|
||||
|
||||
// Add hook to pool after creation
|
||||
pool.AddPoolHook(testHook)
|
||||
|
||||
// Verify hooks are initialized
|
||||
if pool.hookManager == nil {
|
||||
t.Error("Expected hookManager to be initialized")
|
||||
}
|
||||
|
||||
if pool.hookManager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test adding hook to pool
|
||||
additionalHook := &TestHook{ShouldPool: true}
|
||||
pool.AddPoolHook(additionalHook)
|
||||
|
||||
if pool.hookManager.GetHookCount() != 2 {
|
||||
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount())
|
||||
}
|
||||
|
||||
// Test removing hook from pool
|
||||
pool.RemovePoolHook(additionalHook)
|
||||
|
||||
if pool.hookManager.GetHookCount() != 1 {
|
||||
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount())
|
||||
}
|
||||
}
|
@@ -3,6 +3,7 @@ package pool
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -22,6 +23,12 @@ var (
|
||||
|
||||
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
|
||||
ErrPoolTimeout = errors.New("redis: connection pool timeout")
|
||||
|
||||
popAttempts = 10
|
||||
getAttempts = 3
|
||||
minTime = time.Unix(-2208988800, 0) // Jan 1, 1900
|
||||
maxTime = minTime.Add(1<<63 - 1)
|
||||
noExpiration = maxTime
|
||||
)
|
||||
|
||||
var timers = sync.Pool{
|
||||
@@ -38,11 +45,14 @@ type Stats struct {
|
||||
Misses uint32 // number of times free connection was NOT found in the pool
|
||||
Timeouts uint32 // number of times a wait timeout occurred
|
||||
WaitCount uint32 // number of times a connection was waited
|
||||
Unusable uint32 // number of times a connection was found to be unusable
|
||||
WaitDurationNs int64 // total time spent waiting for a connection, in nanoseconds
|
||||
|
||||
TotalConns uint32 // number of total connections in the pool
|
||||
IdleConns uint32 // number of idle connections in the pool
|
||||
StaleConns uint32 // number of stale connections removed from the pool
|
||||
|
||||
PubSubStats PubSubStats
|
||||
}
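For reference, code holding a *ConnPool could surface the new counters roughly as follows (sketch only; assumes "fmt" and "time" imports, and converts WaitDurationNs back to a time.Duration for printing):

s := p.Stats()
fmt.Printf("hits=%d misses=%d unusable=%d waited=%s pubsub_active=%d\n",
	s.Hits, s.Misses, s.Unusable,
	time.Duration(s.WaitDurationNs), s.PubSubStats.Active)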
|
||||
|
||||
type Pooler interface {
|
||||
@@ -57,29 +67,27 @@ type Pooler interface {
|
||||
IdleLen() int
|
||||
Stats() *Stats
|
||||
|
||||
AddPoolHook(hook PoolHook)
|
||||
RemovePoolHook(hook PoolHook)
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Dialer func(context.Context) (net.Conn, error)
|
||||
|
||||
PoolFIFO bool
|
||||
PoolSize int
|
||||
DialTimeout time.Duration
|
||||
PoolTimeout time.Duration
|
||||
MinIdleConns int
|
||||
MaxIdleConns int
|
||||
MaxActiveConns int
|
||||
ConnMaxIdleTime time.Duration
|
||||
ConnMaxLifetime time.Duration
|
||||
|
||||
|
||||
// Protocol version for optimization (3 = RESP3 with push notifications, 2 = RESP2 without)
|
||||
Protocol int
|
||||
|
||||
Dialer func(context.Context) (net.Conn, error)
|
||||
ReadBufferSize int
|
||||
WriteBufferSize int
|
||||
|
||||
PoolFIFO bool
|
||||
PoolSize int32
|
||||
DialTimeout time.Duration
|
||||
PoolTimeout time.Duration
|
||||
MinIdleConns int32
|
||||
MaxIdleConns int32
|
||||
MaxActiveConns int32
|
||||
ConnMaxIdleTime time.Duration
|
||||
ConnMaxLifetime time.Duration
|
||||
PushNotificationsEnabled bool
|
||||
}
|
||||
|
||||
type lastDialErrorWrap struct {
|
||||
@@ -95,16 +103,21 @@ type ConnPool struct {
|
||||
queue chan struct{}
|
||||
|
||||
connsMu sync.Mutex
|
||||
conns []*Conn
|
||||
conns map[uint64]*Conn
|
||||
idleConns []*Conn
|
||||
|
||||
poolSize int
|
||||
idleConnsLen int
|
||||
poolSize atomic.Int32
|
||||
idleConnsLen atomic.Int32
|
||||
idleCheckInProgress atomic.Bool
|
||||
|
||||
stats Stats
|
||||
waitDurationNs atomic.Int64
|
||||
|
||||
_closed uint32 // atomic
|
||||
|
||||
// Pool hooks manager for flexible connection processing
|
||||
hookManagerMu sync.RWMutex
|
||||
hookManager *PoolHookManager
|
||||
}
|
||||
|
||||
var _ Pooler = (*ConnPool)(nil)
|
||||
@@ -114,34 +127,69 @@ func NewConnPool(opt *Options) *ConnPool {
|
||||
cfg: opt,
|
||||
|
||||
queue: make(chan struct{}, opt.PoolSize),
|
||||
conns: make([]*Conn, 0, opt.PoolSize),
|
||||
conns: make(map[uint64]*Conn),
|
||||
idleConns: make([]*Conn, 0, opt.PoolSize),
|
||||
}
|
||||
|
||||
p.connsMu.Lock()
|
||||
p.checkMinIdleConns()
|
||||
p.connsMu.Unlock()
|
||||
// Only create MinIdleConns if explicitly requested (> 0)
|
||||
// This avoids creating connections during pool initialization for tests
|
||||
if opt.MinIdleConns > 0 {
|
||||
p.connsMu.Lock()
|
||||
p.checkMinIdleConns()
|
||||
p.connsMu.Unlock()
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// initializeHooks sets up the pool hooks system.
|
||||
func (p *ConnPool) initializeHooks() {
|
||||
p.hookManager = NewPoolHookManager()
|
||||
}
|
||||
|
||||
// AddPoolHook adds a pool hook to the pool.
|
||||
func (p *ConnPool) AddPoolHook(hook PoolHook) {
|
||||
p.hookManagerMu.Lock()
|
||||
defer p.hookManagerMu.Unlock()
|
||||
|
||||
if p.hookManager == nil {
|
||||
p.initializeHooks()
|
||||
}
|
||||
p.hookManager.AddHook(hook)
|
||||
}
|
||||
|
||||
// RemovePoolHook removes a pool hook from the pool.
|
||||
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
|
||||
p.hookManagerMu.Lock()
|
||||
defer p.hookManagerMu.Unlock()
|
||||
|
||||
if p.hookManager != nil {
|
||||
p.hookManager.RemoveHook(hook)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ConnPool) checkMinIdleConns() {
|
||||
if !p.idleCheckInProgress.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
defer p.idleCheckInProgress.Store(false)
|
||||
|
||||
if p.cfg.MinIdleConns == 0 {
|
||||
return
|
||||
}
|
||||
for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
|
||||
|
||||
// Only create idle connections if we haven't reached the total pool size limit
|
||||
// MinIdleConns should be a subset of PoolSize, not additional connections
|
||||
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
|
||||
select {
|
||||
case p.queue <- struct{}{}:
|
||||
p.poolSize++
|
||||
p.idleConnsLen++
|
||||
|
||||
p.poolSize.Add(1)
|
||||
p.idleConnsLen.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
p.connsMu.Lock()
|
||||
p.poolSize--
|
||||
p.idleConnsLen--
|
||||
p.connsMu.Unlock()
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
|
||||
p.freeTurn()
|
||||
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
|
||||
@@ -150,12 +198,9 @@ func (p *ConnPool) checkMinIdleConns() {
|
||||
|
||||
err := p.addIdleConn()
|
||||
if err != nil && err != ErrClosed {
|
||||
p.connsMu.Lock()
|
||||
p.poolSize--
|
||||
p.idleConnsLen--
|
||||
p.connsMu.Unlock()
|
||||
p.poolSize.Add(-1)
|
||||
p.idleConnsLen.Add(-1)
|
||||
}
|
||||
|
||||
p.freeTurn()
|
||||
}()
|
||||
default:
|
||||
@@ -172,6 +217,9 @@ func (p *ConnPool) addIdleConn() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
@@ -182,11 +230,15 @@ func (p *ConnPool) addIdleConn() error {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
p.conns = append(p.conns, cn)
|
||||
p.conns[cn.GetID()] = cn
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewConn creates a new connection and returns it to the user.
|
||||
// This will still obey MaxActiveConns, but the connection will not be added to the pool and will not increase the pool size.
|
||||
//
|
||||
// NOTE: If you directly get a connection from the pool, it won't be pooled and won't support hitless upgrades.
|
||||
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
|
||||
return p.newConn(ctx, false)
|
||||
}
|
||||
@@ -196,33 +248,42 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
p.connsMu.Lock()
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
|
||||
p.connsMu.Unlock()
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
p.connsMu.Unlock()
|
||||
|
||||
cn, err := p.dialConn(ctx, pooled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Mark connection as usable after successful creation
|
||||
// This is essential for normal pool operations
|
||||
cn.SetUsable(true)
|
||||
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
|
||||
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
|
||||
_ = cn.Close()
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
|
||||
p.conns = append(p.conns, cn)
|
||||
p.connsMu.Lock()
|
||||
defer p.connsMu.Unlock()
|
||||
if p.closed() {
|
||||
_ = cn.Close()
|
||||
return nil, ErrClosed
|
||||
}
|
||||
// Check if pool was closed while we were waiting for the lock
|
||||
if p.conns == nil {
|
||||
p.conns = make(map[uint64]*Conn)
|
||||
}
|
||||
p.conns[cn.GetID()] = cn
|
||||
|
||||
if pooled {
|
||||
// If pool is full remove the cn on next Put.
|
||||
if p.poolSize >= p.cfg.PoolSize {
|
||||
currentPoolSize := p.poolSize.Load()
|
||||
if currentPoolSize >= int32(p.cfg.PoolSize) {
|
||||
cn.pooled = false
|
||||
} else {
|
||||
p.poolSize++
|
||||
p.poolSize.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,6 +310,12 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
|
||||
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
|
||||
cn.pooled = pooled
|
||||
if p.cfg.ConnMaxLifetime > 0 {
|
||||
cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime)
|
||||
} else {
|
||||
cn.expiresAt = noExpiration
|
||||
}
|
||||
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
@@ -289,6 +356,14 @@ func (p *ConnPool) getLastDialError() error {
|
||||
|
||||
// Get returns existed connection from the pool or creates a new one.
|
||||
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
return p.getConn(ctx)
|
||||
}
|
||||
|
||||
// getConn returns a connection from the pool.
|
||||
func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||
var cn *Conn
|
||||
var err error
|
||||
|
||||
if p.closed() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
@@ -297,9 +372,17 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
attempts := 0
|
||||
for {
|
||||
if attempts >= getAttempts {
|
||||
log.Printf("redis: connection pool: failed to get an connection accepted by hook after %d attempts", attempts)
|
||||
break
|
||||
}
|
||||
attempts++
|
||||
|
||||
p.connsMu.Lock()
|
||||
cn, err := p.popIdle()
|
||||
cn, err = p.popIdle()
|
||||
p.connsMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
@@ -311,11 +394,25 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
break
|
||||
}
|
||||
|
||||
if !p.isHealthyConn(cn) {
|
||||
if !p.isHealthyConn(cn, now) {
|
||||
_ = p.CloseConn(cn)
|
||||
continue
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
|
||||
log.Printf("redis: connection pool: failed to process idle connection by hook: %v", err)
|
||||
// Failed to process connection, discard it
|
||||
_ = p.CloseConn(cn)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddUint32(&p.stats.Hits, 1)
|
||||
return cn, nil
|
||||
}
|
||||
@@ -328,6 +425,20 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process connection using the hooks system
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
|
||||
// Failed to process connection, discard it
|
||||
log.Printf("redis: connection pool: failed to process new connection by hook: %v", err)
|
||||
_ = p.CloseConn(newcn)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return newcn, nil
|
||||
}
|
||||
|
||||
@@ -356,7 +467,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error {
|
||||
}
|
||||
return ctx.Err()
|
||||
case p.queue <- struct{}{}:
|
||||
p.waitDurationNs.Add(time.Since(start).Nanoseconds())
|
||||
p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
|
||||
atomic.AddUint32(&p.stats.WaitCount, 1)
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
@@ -376,68 +487,128 @@ func (p *ConnPool) popIdle() (*Conn, error) {
|
||||
if p.closed() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
n := len(p.idleConns)
|
||||
if n == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var cn *Conn
|
||||
if p.cfg.PoolFIFO {
|
||||
cn = p.idleConns[0]
|
||||
copy(p.idleConns, p.idleConns[1:])
|
||||
p.idleConns = p.idleConns[:n-1]
|
||||
} else {
|
||||
idx := n - 1
|
||||
cn = p.idleConns[idx]
|
||||
p.idleConns = p.idleConns[:idx]
|
||||
attempts := 0
|
||||
|
||||
for attempts < popAttempts {
|
||||
if len(p.idleConns) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if p.cfg.PoolFIFO {
|
||||
cn = p.idleConns[0]
|
||||
copy(p.idleConns, p.idleConns[1:])
|
||||
p.idleConns = p.idleConns[:len(p.idleConns)-1]
|
||||
} else {
|
||||
idx := len(p.idleConns) - 1
|
||||
cn = p.idleConns[idx]
|
||||
p.idleConns = p.idleConns[:idx]
|
||||
}
|
||||
attempts++
|
||||
|
||||
if cn.IsUsable() {
|
||||
p.idleConnsLen.Add(-1)
|
||||
break
|
||||
}
|
||||
|
||||
// Connection is not usable, put it back in the pool
|
||||
if p.cfg.PoolFIFO {
|
||||
// FIFO: put at end (will be picked up last since we pop from front)
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
} else {
|
||||
// LIFO: put at beginning (will be picked up last since we pop from end)
|
||||
p.idleConns = append([]*Conn{cn}, p.idleConns...)
|
||||
}
|
||||
}
|
||||
p.idleConnsLen--
|
||||
|
||||
// If we exhausted all attempts without finding a usable connection, return nil
|
||||
if attempts >= popAttempts {
|
||||
log.Printf("redis: connection pool: failed to get an usable connection after %d attempts", popAttempts)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
p.checkMinIdleConns()
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
|
||||
// Process connection using the hooks system
|
||||
shouldPool := true
|
||||
shouldRemove := false
|
||||
if cn.rd.Buffered() > 0 {
|
||||
// Check if this might be push notification data
|
||||
if p.cfg.Protocol == 3 {
|
||||
// we know that there is something in the buffer, so peek at the next reply type without
|
||||
// the potential to block and check if it's a push notification
|
||||
if replyType, err := cn.rd.PeekReplyType(); err != nil || replyType != proto.RespPush {
|
||||
shouldRemove = true
|
||||
}
|
||||
} else {
|
||||
// not a push notification since protocol 2 doesn't support them
|
||||
shouldRemove = true
|
||||
}
|
||||
var err error
|
||||
|
||||
if shouldRemove {
|
||||
// For non-RESP3 or data that is not a push notification, buffered data is unexpected
|
||||
internal.Logger.Printf(ctx, "Conn has unread data, closing it")
|
||||
p.Remove(ctx, cn, BadConnError{})
|
||||
if cn.HasBufferedData() {
|
||||
// Peek at the reply type to check if it's a push notification
|
||||
if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush {
|
||||
// Not a push notification or error peeking, remove connection
|
||||
internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it")
|
||||
p.Remove(ctx, cn, err)
|
||||
}
|
||||
// It's a push notification, allow pooling (client will handle it)
|
||||
}
|
||||
|
||||
p.hookManagerMu.RLock()
|
||||
hookManager := p.hookManager
|
||||
p.hookManagerMu.RUnlock()
|
||||
|
||||
if hookManager != nil {
|
||||
shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "Connection hook error: %v", err)
|
||||
p.Remove(ctx, cn, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If hooks say to remove the connection, do so
|
||||
if shouldRemove {
|
||||
p.Remove(ctx, cn, errors.New("hook requested removal"))
|
||||
return
|
||||
}
|
||||
|
||||
// If processor says not to pool the connection, remove it
|
||||
if !shouldPool {
|
||||
p.Remove(ctx, cn, errors.New("hook requested no pooling"))
|
||||
return
|
||||
}
|
||||
|
||||
if !cn.pooled {
|
||||
p.Remove(ctx, cn, nil)
|
||||
p.Remove(ctx, cn, errors.New("connection not pooled"))
|
||||
return
|
||||
}
|
||||
|
||||
var shouldCloseConn bool
|
||||
|
||||
p.connsMu.Lock()
|
||||
|
||||
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns {
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
p.idleConnsLen++
|
||||
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
|
||||
// Unusable conns are expected to become usable at some point (a background process is reconnecting them),
// so put them at the opposite end of the queue.
|
||||
if !cn.IsUsable() {
|
||||
if p.cfg.PoolFIFO {
|
||||
p.connsMu.Lock()
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
p.connsMu.Unlock()
|
||||
} else {
|
||||
p.connsMu.Lock()
|
||||
p.idleConns = append([]*Conn{cn}, p.idleConns...)
|
||||
p.connsMu.Unlock()
|
||||
}
|
||||
} else {
|
||||
p.connsMu.Lock()
|
||||
p.idleConns = append(p.idleConns, cn)
|
||||
p.connsMu.Unlock()
|
||||
}
|
||||
p.idleConnsLen.Add(1)
|
||||
} else {
|
||||
p.removeConn(cn)
|
||||
p.removeConnWithLock(cn)
|
||||
shouldCloseConn = true
|
||||
}
|
||||
|
||||
p.connsMu.Unlock()
|
||||
|
||||
p.freeTurn()
|
||||
|
||||
if shouldCloseConn {
|
||||
@@ -449,6 +620,9 @@ func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
|
||||
p.removeConnWithLock(cn)
|
||||
p.freeTurn()
|
||||
_ = p.closeConn(cn)
|
||||
|
||||
// Check if we need to create new idle connections to maintain MinIdleConns
|
||||
p.checkMinIdleConns()
|
||||
}
|
||||
|
||||
func (p *ConnPool) CloseConn(cn *Conn) error {
|
||||
@@ -463,17 +637,13 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) {
|
||||
}
|
||||
|
||||
func (p *ConnPool) removeConn(cn *Conn) {
|
||||
for i, c := range p.conns {
|
||||
if c == cn {
|
||||
p.conns = append(p.conns[:i], p.conns[i+1:]...)
|
||||
if cn.pooled {
|
||||
p.poolSize--
|
||||
p.checkMinIdleConns()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
delete(p.conns, cn.GetID())
|
||||
atomic.AddUint32(&p.stats.StaleConns, 1)
|
||||
|
||||
// Decrement pool size counter when removing a connection
|
||||
if cn.pooled {
|
||||
p.poolSize.Add(-1)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ConnPool) closeConn(cn *Conn) error {
|
||||
@@ -491,9 +661,9 @@ func (p *ConnPool) Len() int {
|
||||
// IdleLen returns number of idle connections.
|
||||
func (p *ConnPool) IdleLen() int {
|
||||
p.connsMu.Lock()
|
||||
n := p.idleConnsLen
|
||||
n := p.idleConnsLen.Load()
|
||||
p.connsMu.Unlock()
|
||||
return n
|
||||
return int(n)
|
||||
}
|
||||
|
||||
func (p *ConnPool) Stats() *Stats {
|
||||
@@ -502,6 +672,7 @@ func (p *ConnPool) Stats() *Stats {
|
||||
Misses: atomic.LoadUint32(&p.stats.Misses),
|
||||
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
|
||||
WaitCount: atomic.LoadUint32(&p.stats.WaitCount),
|
||||
Unusable: atomic.LoadUint32(&p.stats.Unusable),
|
||||
WaitDurationNs: p.waitDurationNs.Load(),
|
||||
|
||||
TotalConns: uint32(p.Len()),
|
||||
@@ -542,30 +713,32 @@ func (p *ConnPool) Close() error {
|
||||
}
|
||||
}
|
||||
p.conns = nil
|
||||
p.poolSize = 0
|
||||
p.poolSize.Store(0)
|
||||
p.idleConns = nil
|
||||
p.idleConnsLen = 0
|
||||
p.idleConnsLen.Store(0)
|
||||
p.connsMu.Unlock()
|
||||
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (p *ConnPool) isHealthyConn(cn *Conn) bool {
|
||||
now := time.Now()
|
||||
|
||||
if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
|
||||
func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
|
||||
// slight optimization, check expiresAt first.
|
||||
if cn.expiresAt.Before(now) {
|
||||
return false
|
||||
}
|
||||
|
||||
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check connection health, but be aware of push notifications
|
||||
if err := connCheck(cn.netConn); err != nil {
|
||||
cn.SetUsedAt(now)
|
||||
// Check basic connection health
|
||||
// Use GetNetConn() to safely access netConn and avoid data races
|
||||
if err := connCheck(cn.getNetConn()); err != nil {
|
||||
// If there's unexpected data, it might be push notifications (RESP3)
|
||||
// However, push notification processing is now handled by the client
|
||||
// before WithReader to ensure proper context is available to handlers
|
||||
if err == errUnexpectedRead && p.cfg.Protocol == 3 {
|
||||
if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
|
||||
// we know that there is something in the buffer, so peek at the next reply type without
|
||||
// the potential to block
|
||||
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
|
||||
@@ -579,7 +752,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
cn.SetUsedAt(now)
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package pool
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type SingleConnPool struct {
|
||||
pool Pooler
|
||||
@@ -56,3 +58,7 @@ func (p *SingleConnPool) IdleLen() int {
|
||||
func (p *SingleConnPool) Stats() *Stats {
|
||||
return &Stats{}
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) AddPoolHook(hook PoolHook) {}
|
||||
|
||||
func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {}
|
||||
|
@@ -199,3 +199,7 @@ func (p *StickyConnPool) IdleLen() int {
|
||||
func (p *StickyConnPool) Stats() *Stats {
|
||||
return &Stats{}
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}
|
||||
|
||||
func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}
|
||||
|
@@ -2,6 +2,7 @@ package pool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -20,7 +21,7 @@ var _ = Describe("ConnPool", func() {
|
||||
BeforeEach(func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 10,
|
||||
PoolSize: int32(10),
|
||||
PoolTimeout: time.Hour,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Millisecond,
|
||||
@@ -45,11 +46,11 @@ var _ = Describe("ConnPool", func() {
|
||||
<-closedChan
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 10,
|
||||
PoolSize: int32(10),
|
||||
PoolTimeout: time.Hour,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Millisecond,
|
||||
MinIdleConns: minIdleConns,
|
||||
MinIdleConns: int32(minIdleConns),
|
||||
})
|
||||
wg.Wait()
|
||||
Expect(connPool.Close()).NotTo(HaveOccurred())
|
||||
@@ -105,7 +106,7 @@ var _ = Describe("ConnPool", func() {
|
||||
// ok
|
||||
}
|
||||
|
||||
connPool.Remove(ctx, cn, nil)
|
||||
connPool.Remove(ctx, cn, errors.New("test"))
|
||||
|
||||
// Check that Get is unblocked.
|
||||
select {
|
||||
@@ -130,8 +131,8 @@ var _ = Describe("MinIdleConns", func() {
|
||||
newConnPool := func() *pool.ConnPool {
|
||||
connPool := pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: poolSize,
|
||||
MinIdleConns: minIdleConns,
|
||||
PoolSize: int32(poolSize),
|
||||
MinIdleConns: int32(minIdleConns),
|
||||
PoolTimeout: 100 * time.Millisecond,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: -1,
|
||||
@@ -168,7 +169,7 @@ var _ = Describe("MinIdleConns", func() {
|
||||
|
||||
Context("after Remove", func() {
|
||||
BeforeEach(func() {
|
||||
connPool.Remove(ctx, cn, nil)
|
||||
connPool.Remove(ctx, cn, errors.New("test"))
|
||||
})
|
||||
|
||||
It("has idle connections", func() {
|
||||
@@ -245,7 +246,7 @@ var _ = Describe("MinIdleConns", func() {
|
||||
BeforeEach(func() {
|
||||
perform(len(cns), func(i int) {
|
||||
mu.RLock()
|
||||
connPool.Remove(ctx, cns[i], nil)
|
||||
connPool.Remove(ctx, cns[i], errors.New("test"))
|
||||
mu.RUnlock()
|
||||
})
|
||||
|
||||
@@ -309,7 +310,7 @@ var _ = Describe("race", func() {
|
||||
It("does not happen on Get, Put, and Remove", func() {
|
||||
connPool = pool.NewConnPool(&pool.Options{
|
||||
Dialer: dummyDialer,
|
||||
PoolSize: 10,
|
||||
PoolSize: int32(10),
|
||||
PoolTimeout: time.Minute,
|
||||
DialTimeout: 1 * time.Second,
|
||||
ConnMaxIdleTime: time.Millisecond,
|
||||
@@ -328,7 +329,7 @@ var _ = Describe("race", func() {
|
||||
cn, err := connPool.Get(ctx)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
if err == nil {
|
||||
connPool.Remove(ctx, cn, nil)
|
||||
connPool.Remove(ctx, cn, errors.New("test"))
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -339,15 +340,15 @@ var _ = Describe("race", func() {
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 1000,
|
||||
MinIdleConns: 50,
|
||||
PoolSize: int32(1000),
|
||||
MinIdleConns: int32(50),
|
||||
PoolTimeout: 3 * time.Second,
|
||||
DialTimeout: 1 * time.Second,
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < opt.PoolSize; i++ {
|
||||
for i := int32(0); i < opt.PoolSize; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
@@ -366,8 +367,8 @@ var _ = Describe("race", func() {
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
panic("test panic")
|
||||
},
|
||||
PoolSize: 100,
|
||||
MinIdleConns: 30,
|
||||
PoolSize: int32(100),
|
||||
MinIdleConns: int32(30),
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
|
||||
@@ -377,14 +378,14 @@ var _ = Describe("race", func() {
|
||||
state := p.Stats()
|
||||
return state.TotalConns == 0 && state.IdleConns == 0 && p.QueueLen() == 0
|
||||
}, "3s", "50ms").Should(BeTrue())
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
It("wait", func() {
|
||||
opt := &pool.Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: 3 * time.Second,
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
@@ -415,7 +416,7 @@ var _ = Describe("race", func() {
|
||||
|
||||
return &net.TCPConn{}, nil
|
||||
},
|
||||
PoolSize: 1,
|
||||
PoolSize: int32(1),
|
||||
PoolTimeout: testPoolTimeout,
|
||||
}
|
||||
p := pool.NewConnPool(opt)
|
||||
|
77
internal/pool/pubsub.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type PubSubStats struct {
|
||||
Created uint32
|
||||
Untracked uint32
|
||||
Active uint32
|
||||
}
|
||||
|
||||
// PubSubPool manages a pool of PubSub connections.
|
||||
type PubSubPool struct {
|
||||
opt *Options
|
||||
netDialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// Map to track active PubSub connections
|
||||
activeConns sync.Map // map[uint64]*Conn (connID -> conn)
|
||||
closed atomic.Bool
|
||||
stats PubSubStats
|
||||
}
|
||||
|
||||
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
|
||||
return &PubSubPool{
|
||||
opt: opt,
|
||||
netDialer: netDialer,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
|
||||
if p.closed.Load() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
netConn, err := p.netDialer(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
|
||||
atomic.AddUint32(&p.stats.Created, 1)
|
||||
return cn, nil
|
||||
|
||||
}
|
||||
|
||||
func (p *PubSubPool) TrackConn(cn *Conn) {
|
||||
atomic.AddUint32(&p.stats.Active, 1)
|
||||
p.activeConns.Store(cn.GetID(), cn)
|
||||
}
|
||||
|
||||
func (p *PubSubPool) UntrackConn(cn *Conn) {
|
||||
atomic.AddUint32(&p.stats.Active, ^uint32(0))
|
||||
atomic.AddUint32(&p.stats.Untracked, 1)
|
||||
p.activeConns.Delete(cn.GetID())
|
||||
}
|
||||
|
||||
func (p *PubSubPool) Close() error {
|
||||
p.closed.Store(true)
|
||||
p.activeConns.Range(func(key, value interface{}) bool {
|
||||
cn := value.(*Conn)
|
||||
_ = cn.Close()
|
||||
return true
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PubSubPool) Stats() *PubSubStats {
|
||||
// load stats atomically
|
||||
return &PubSubStats{
|
||||
Created: atomic.LoadUint32(&p.stats.Created),
|
||||
Untracked: atomic.LoadUint32(&p.stats.Untracked),
|
||||
Active: atomic.LoadUint32(&p.stats.Active),
|
||||
}
|
||||
}
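Unlike ConnPool, PubSubPool does not pool connections; callers create, track and untrack each one themselves. A condensed sketch of that lifecycle (error handling trimmed; ctx, channels and pubSubPool are assumed to exist):

cn, err := pubSubPool.NewConn(ctx, "tcp", "localhost:6379", channels)
if err != nil {
	return err
}
pubSubPool.TrackConn(cn) // counted in Stats().Active

// ... SUBSCRIBE and read messages on cn ...

pubSubPool.UntrackConn(cn) // Active decremented, Untracked incremented
_ = cn.Close()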
|
3
internal/redis.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package internal
|
||||
|
||||
const RedisNull = "null"
|
17
internal/util/math.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package util
|
||||
|
||||
// Max returns the maximum of two integers
|
||||
func Max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Min returns the minimum of two integers
|
||||
func Min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
81
options.go
@@ -14,9 +14,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/hitless"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// Limiter is the interface of a rate limiter or a circuit breaker.
|
||||
@@ -153,6 +154,7 @@ type Options struct {
|
||||
//
|
||||
// Note that FIFO has slightly higher overhead compared to LIFO,
|
||||
// but it helps closing idle connections faster reducing the pool size.
|
||||
// default: false
|
||||
PoolFIFO bool
|
||||
|
||||
// PoolSize is the base number of socket connections.
|
||||
@@ -244,8 +246,19 @@ type Options struct {
|
||||
// When a node is marked as failing, it will be avoided for this duration.
|
||||
// Default is 15 seconds.
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
|
||||
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
|
||||
// cluster upgrade notifications gracefully and manage connection/pool state
|
||||
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
|
||||
HitlessUpgradeConfig *HitlessUpgradeConfig
|
||||
}
|
||||
|
||||
// HitlessUpgradeConfig provides configuration options for hitless upgrades.
|
||||
// This is an alias to hitless.Config for convenience.
|
||||
type HitlessUpgradeConfig = hitless.Config
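Given the defaults applied in init() below, the simplest opt-in is to leave the config nil and use RESP3. The sketch assumes nothing about hitless.Config fields beyond what this diff shows:

// Sketch: hitless upgrades in "auto" mode.
rdb := redis.NewClient(&redis.Options{
	Addr:     "localhost:6379",
	Protocol: 3, // RESP3 is required for the push notifications hitless relies on
	// HitlessUpgradeConfig left nil: "auto" mode, enabled when the server supports it.
})
_ = rdb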
|
||||
|
||||
func (opt *Options) init() {
|
||||
if opt.Addr == "" {
|
||||
opt.Addr = "localhost:6379"
|
||||
@@ -320,13 +333,36 @@ func (opt *Options) init() {
|
||||
case 0:
|
||||
opt.MaxRetryBackoff = 512 * time.Millisecond
|
||||
}
|
||||
|
||||
opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolSize(opt.PoolSize)
|
||||
|
||||
// auto-detect endpoint type if not specified
|
||||
endpointType := opt.HitlessUpgradeConfig.EndpointType
|
||||
if endpointType == "" || endpointType == hitless.EndpointTypeAuto {
|
||||
// Auto-detect endpoint type if not specified
|
||||
endpointType = hitless.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
|
||||
}
|
||||
opt.HitlessUpgradeConfig.EndpointType = endpointType
|
||||
}
|
||||
|
||||
func (opt *Options) clone() *Options {
|
||||
clone := *opt
|
||||
|
||||
// Deep clone HitlessUpgradeConfig to avoid sharing between clients
|
||||
if opt.HitlessUpgradeConfig != nil {
|
||||
configClone := *opt.HitlessUpgradeConfig
|
||||
clone.HitlessUpgradeConfig = &configClone
|
||||
}
|
||||
|
||||
return &clone
|
||||
}
|
||||
|
||||
// NewDialer returns a function that will be used as the default dialer
|
||||
// when none is specified in Options.Dialer.
|
||||
func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
|
||||
return NewDialer(opt)
|
||||
}
|
||||
|
||||
// NewDialer returns a function that will be used as the default dialer
|
||||
// when none is specified in Options.Dialer.
|
||||
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
|
||||
@@ -617,18 +653,35 @@ func newConnPool(
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return dialer(ctx, opt.Network, opt.Addr)
|
||||
},
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: opt.PoolSize,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
MinIdleConns: opt.MinIdleConns,
|
||||
MaxIdleConns: opt.MaxIdleConns,
|
||||
MaxActiveConns: opt.MaxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
// Pass protocol version for push notification optimization
|
||||
Protocol: opt.Protocol,
|
||||
ReadBufferSize: opt.ReadBufferSize,
|
||||
WriteBufferSize: opt.WriteBufferSize,
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: int32(opt.PoolSize),
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
MinIdleConns: int32(opt.MinIdleConns),
|
||||
MaxIdleConns: int32(opt.MaxIdleConns),
|
||||
MaxActiveConns: int32(opt.MaxActiveConns),
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: opt.ReadBufferSize,
|
||||
WriteBufferSize: opt.WriteBufferSize,
|
||||
PushNotificationsEnabled: opt.Protocol == 3,
|
||||
})
|
||||
}
|
||||
|
||||
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
|
||||
) *pool.PubSubPool {
|
||||
return pool.NewPubSubPool(&pool.Options{
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: int32(opt.PoolSize),
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
MinIdleConns: int32(opt.MinIdleConns),
|
||||
MaxIdleConns: int32(opt.MaxIdleConns),
|
||||
MaxActiveConns: int32(opt.MaxActiveConns),
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: 32 * 1024,
|
||||
WriteBufferSize: 32 * 1024,
|
||||
PushNotificationsEnabled: opt.Protocol == 3,
|
||||
}, dialer)
|
||||
}
|
||||
|
@@ -38,6 +38,7 @@ type ClusterOptions struct {
|
||||
ClientName string
|
||||
|
||||
// NewClient creates a cluster node client with provided name and options.
|
||||
// If NewClient is set by the user, the user is responsible for handling hitless upgrades and push notifications.
|
||||
NewClient func(opt *Options) *Client
|
||||
|
||||
// The maximum number of retries before giving up. Command is retried
|
||||
@@ -129,6 +130,14 @@ type ClusterOptions struct {
|
||||
// When a node is marked as failing, it will be avoided for this duration.
|
||||
// Default is 15 seconds.
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
|
||||
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
|
||||
// cluster upgrade notifications gracefully and manage connection/pool state
|
||||
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
|
||||
// The ClusterClient does not work with hitless directly; it is up to the clients in the nodes map to handle hitless upgrades.
|
||||
HitlessUpgradeConfig *HitlessUpgradeConfig
|
||||
}
func (opt *ClusterOptions) init() {
@@ -319,6 +328,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
}

func (opt *ClusterOptions) clientOptions() *Options {
// Clone HitlessUpgradeConfig to avoid sharing between cluster node clients
var hitlessConfig *HitlessUpgradeConfig
if opt.HitlessUpgradeConfig != nil {
configClone := *opt.HitlessUpgradeConfig
hitlessConfig = &configClone
}

return &Options{
ClientName: opt.ClientName,
Dialer: opt.Dialer,
@@ -360,8 +376,9 @@ func (opt *ClusterOptions) clientOptions() *Options {
// much use for ClusterSlots config). This means we cannot execute the
// READONLY command against that node -- setting readOnly to false in such
// situations in the options below will prevent that from happening.
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
HitlessUpgradeConfig: hitlessConfig,
}
}

@@ -1830,12 +1847,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
return err
}

// hitless won't work here for now
func (c *ClusterClient) pubSub() *PubSub {
var node *clusterNode
pubsub := &PubSub{
opt: c.opt.clientOptions(),

newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
if node != nil {
panic("node != nil")
}
@@ -1850,18 +1867,25 @@ func (c *ClusterClient) pubSub() *PubSub {
if err != nil {
return nil, err
}

cn, err := node.Client.newConn(context.TODO())
cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
if err != nil {
node = nil

return nil, err
}

// will return nil if already initialized
err = node.Client.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
node = nil
return nil, err
}
node.Client.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: func(cn *pool.Conn) error {
err := node.Client.connPool.CloseConn(cn)
// Untrack connection from PubSubPool
node.Client.pubSubPool.UntrackConn(cn)
err := cn.Close()
node = nil
return err
},
375 pool_pubsub_bench_test.go Normal file
@@ -0,0 +1,375 @@
// Pool and PubSub Benchmark Suite
//
// This file contains comprehensive benchmarks for both pool operations and PubSub initialization.
// It's designed to be run against different branches to compare performance.
//
// Usage Examples:
// # Run all benchmarks
// go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go
//
// # Run only pool benchmarks
// go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go
//
// # Run only PubSub benchmarks
// go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go
//
// # Compare between branches
// git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt
// git checkout branch2 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch2.txt
// benchcmp branch1.txt branch2.txt
//
// # Run with memory profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go
//
// # Run with CPU profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go

package redis_test

import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"

"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/internal/pool"
)

// dummyDialer creates a mock connection for benchmarking
func dummyDialer(ctx context.Context) (net.Conn, error) {
return &dummyConn{}, nil
}

// dummyConn implements net.Conn for benchmarking
type dummyConn struct{}

func (c *dummyConn) Read(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Close() error { return nil }
func (c *dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} }
func (c *dummyConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379}
}
func (c *dummyConn) SetDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetReadDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil }

// =============================================================================
// POOL BENCHMARKS
// =============================================================================

// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations
func BenchmarkPoolGetPut(b *testing.B) {
ctx := context.Background()

poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128}

for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(poolSize),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0), // Start with no idle connections
})
defer connPool.Close()

b.ResetTimer()
b.ReportAllocs()

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}

// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with MinIdleConns
func BenchmarkPoolGetPutWithMinIdle(b *testing.B) {
ctx := context.Background()

configs := []struct {
poolSize int
minIdleConns int
}{
{8, 2},
{16, 4},
{32, 8},
{64, 16},
}

for _, config := range configs {
b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(config.poolSize),
MinIdleConns: int32(config.minIdleConns),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
})
defer connPool.Close()

b.ResetTimer()
b.ReportAllocs()

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}

// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency
func BenchmarkPoolConcurrentGetPut(b *testing.B) {
ctx := context.Background()

connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(32),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0),
})
defer connPool.Close()

b.ResetTimer()
b.ReportAllocs()

// Test with different levels of concurrency
concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64}

for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.SetParallelism(concurrency)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}

// =============================================================================
// PUBSUB BENCHMARKS
// =============================================================================

// benchmarkClient creates a Redis client for benchmarking with mock dialer
func benchmarkClient(poolSize int) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: "localhost:6379", // Mock address
DialTimeout: time.Second,
ReadTimeout: time.Second,
WriteTimeout: time.Second,
PoolSize: poolSize,
MinIdleConns: 0, // Start with no idle connections for consistent benchmarks
})
}

// BenchmarkPubSubCreation benchmarks PubSub creation and subscription
func BenchmarkPubSubCreation(b *testing.B) {
ctx := context.Background()

poolSizes := []int{1, 4, 8, 16, 32}

for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}
}

// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription
func BenchmarkPubSubPatternCreation(b *testing.B) {
ctx := context.Background()

poolSizes := []int{1, 4, 8, 16, 32}

for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
pubsub := client.PSubscribe(ctx, "test-*")
pubsub.Close()
}
})
}
}

// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation
func BenchmarkPubSubConcurrentCreation(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()

concurrencyLevels := []int{1, 2, 4, 8, 16}

for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()

var wg sync.WaitGroup
semaphore := make(chan struct{}, concurrency)

for i := 0; i < b.N; i++ {
wg.Add(1)
semaphore <- struct{}{}

go func() {
defer wg.Done()
defer func() { <-semaphore }()

pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}()
}

wg.Wait()
})
}
}

// BenchmarkPubSubMultipleChannels benchmarks subscribing to multiple channels
func BenchmarkPubSubMultipleChannels(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()

channelCounts := []int{1, 5, 10, 25, 50, 100}

for _, channelCount := range channelCounts {
b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) {
// Prepare channel names
channels := make([]string, channelCount)
for i := 0; i < channelCount; i++ {
channels[i] = fmt.Sprintf("channel-%d", i)
}

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, channels...)
pubsub.Close()
}
})
}
}

// BenchmarkPubSubReuse benchmarks reusing PubSub connections
func BenchmarkPubSubReuse(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
// Benchmark just the creation and closing of PubSub connections
// This simulates reuse patterns without requiring actual Redis operations
pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i))
pubsub.Close()
}
}

// =============================================================================
// COMBINED BENCHMARKS
// =============================================================================

// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations
func BenchmarkPoolAndPubSubMixed(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()

b.ResetTimer()
b.ReportAllocs()

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// Mix of pool stats collection and PubSub creation
if pb.Next() {
// Pool stats operation
stats := client.PoolStats()
_ = stats.Hits + stats.Misses // Use the stats to prevent optimization
}

if pb.Next() {
// PubSub operation
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
}
})
}

// BenchmarkPoolStatsCollection benchmarks pool statistics collection
func BenchmarkPoolStatsCollection(b *testing.B) {
client := benchmarkClient(16)
defer client.Close()

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
stats := client.PoolStats()
_ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization
}
}

// BenchmarkPoolHighContention tests pool performance under high contention
func BenchmarkPoolHighContention(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()

b.ResetTimer()
b.ReportAllocs()

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// High contention Get/Put operations
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}
40 pubsub.go
@@ -22,7 +22,7 @@ import (
type PubSub struct {
opt *Options

newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
newConn func(ctx context.Context, addr string, channels []string) (*pool.Conn, error)
closeConn func(*pool.Conn) error

mu sync.Mutex
@@ -42,6 +42,9 @@ type PubSub struct {

// Push notification processor for handling generic push notifications
pushProcessor push.NotificationProcessor

// Cleanup callback for hitless upgrade tracking
onClose func()
}

func (c *PubSub) init() {
@@ -73,10 +76,18 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er
return c.cn, nil
}
if c.opt.Addr == "" {
// TODO(hitless):
// This is probably a cluster client; c.newConn will ignore the addr argument.
// This will change once hitless upgrades are supported for cluster clients.
c.opt.Addr = internal.RedisNull
}
channels := mapKeys(c.channels)
channels = append(channels, newChannels...)

cn, err := c.newConn(ctx, channels)
cn, err := c.newConn(ctx, c.opt.Addr, channels)
if err != nil {
return nil, err
}
@@ -157,12 +168,28 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo
if c.cn != cn {
return
}

if !cn.IsUsable() || cn.ShouldHandoff() {
c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable"))
}

if isBadConn(err, allowTimeout, c.opt.Addr) {
c.reconnect(ctx, err)
}
}

func (c *PubSub) reconnect(ctx context.Context, reason error) {
if c.cn != nil && c.cn.ShouldHandoff() {
newEndpoint := c.cn.GetHandoffEndpoint()
// If new endpoint is NULL, use the original address
if newEndpoint == internal.RedisNull {
newEndpoint = c.opt.Addr
}

if newEndpoint != "" {
c.opt.Addr = newEndpoint
}
}
_ = c.closeTheCn(reason)
_, _ = c.conn(ctx, nil)
}
@@ -189,6 +216,11 @@ func (c *PubSub) Close() error {
c.closed = true
close(c.exit)

// Call cleanup callback if set
if c.onClose != nil {
c.onClose()
}

return c.closeTheCn(pool.ErrClosed)
}

@@ -461,6 +493,7 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
// Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
// This will block until a message is received.
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
return c.ReceiveTimeout(ctx, 0)
}
@@ -543,7 +576,8 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac
}

func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
if c.pushProcessor == nil {
// Only process push notifications for RESP3 connections with a processor
if c.opt.Protocol != 3 || c.pushProcessor == nil {
return nil
}

@@ -1,8 +1,6 @@
package push

import (
"github.com/redis/go-redis/v9/internal/pool"
)
// No imports needed for this file
// NotificationHandlerContext provides context information about where a push notification was received.
// This struct allows handlers to make informed decisions based on the source of the notification
@@ -35,7 +33,12 @@ type NotificationHandlerContext struct {
PubSub interface{}

// Conn is the specific connection on which the notification was received.
Conn *pool.Conn
// It is an interface to both allow for future expansion and to avoid
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *pool.Conn
// - *connectionAdapter (for hitless upgrades)
Conn interface{}

// IsBlocking indicates if the notification was received on a blocking connection.
IsBlocking bool
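As a hedged sketch of what a handler looks like against this context shape (the method signature mirrors the UnitTestHandler in the test file below; everything else here is illustrative only):

    type loggingHandler struct{}

    func (loggingHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
        // Conn is deliberately an interface{}; inside this module it can be asserted
        // to *pool.Conn, or to the connection adapter used for hitless upgrades.
        if handlerCtx.Conn != nil && !handlerCtx.IsBlocking {
            // inspect or log the notification here
        }
        return nil
    }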
315 push/processor_unit_test.go Normal file
@@ -0,0 +1,315 @@
package push

import (
"context"
"testing"
)

// TestProcessorCreation tests processor creation and initialization
func TestProcessorCreation(t *testing.T) {
t.Run("NewProcessor", func(t *testing.T) {
processor := NewProcessor()
if processor == nil {
t.Fatal("NewProcessor should not return nil")
}
if processor.registry == nil {
t.Error("Processor should have a registry")
}
})

t.Run("NewVoidProcessor", func(t *testing.T) {
voidProcessor := NewVoidProcessor()
if voidProcessor == nil {
t.Fatal("NewVoidProcessor should not return nil")
}
})
}

// TestProcessorHandlerManagement tests handler registration and retrieval
func TestProcessorHandlerManagement(t *testing.T) {
processor := NewProcessor()
handler := &UnitTestHandler{name: "test-handler"}

t.Run("RegisterHandler", func(t *testing.T) {
err := processor.RegisterHandler("TEST", handler, false)
if err != nil {
t.Errorf("RegisterHandler should not error: %v", err)
}

// Verify handler is registered
retrievedHandler := processor.GetHandler("TEST")
if retrievedHandler != handler {
t.Error("GetHandler should return the registered handler")
}
})

t.Run("RegisterProtectedHandler", func(t *testing.T) {
protectedHandler := &UnitTestHandler{name: "protected-handler"}
err := processor.RegisterHandler("PROTECTED", protectedHandler, true)
if err != nil {
t.Errorf("RegisterHandler should not error for protected handler: %v", err)
}

// Verify handler is registered
retrievedHandler := processor.GetHandler("PROTECTED")
if retrievedHandler != protectedHandler {
t.Error("GetHandler should return the protected handler")
}
})

t.Run("GetNonExistentHandler", func(t *testing.T) {
handler := processor.GetHandler("NONEXISTENT")
if handler != nil {
t.Error("GetHandler should return nil for non-existent handler")
}
})

t.Run("UnregisterHandler", func(t *testing.T) {
err := processor.UnregisterHandler("TEST")
if err != nil {
t.Errorf("UnregisterHandler should not error: %v", err)
}

// Verify handler is removed
retrievedHandler := processor.GetHandler("TEST")
if retrievedHandler != nil {
t.Error("GetHandler should return nil after unregistering")
}
})

t.Run("UnregisterProtectedHandler", func(t *testing.T) {
err := processor.UnregisterHandler("PROTECTED")
if err == nil {
t.Error("UnregisterHandler should error for protected handler")
}

// Verify handler is still there
retrievedHandler := processor.GetHandler("PROTECTED")
if retrievedHandler == nil {
t.Error("Protected handler should not be removed")
}
})
}

// TestVoidProcessorBehavior tests void processor behavior
func TestVoidProcessorBehavior(t *testing.T) {
voidProcessor := NewVoidProcessor()
handler := &UnitTestHandler{name: "test-handler"}

t.Run("GetHandler", func(t *testing.T) {
retrievedHandler := voidProcessor.GetHandler("ANY")
if retrievedHandler != nil {
t.Error("VoidProcessor GetHandler should always return nil")
}
})

t.Run("RegisterHandler", func(t *testing.T) {
err := voidProcessor.RegisterHandler("TEST", handler, false)
if err == nil {
t.Error("VoidProcessor RegisterHandler should return error")
}

// Check error type
if !IsVoidProcessorError(err) {
t.Error("Error should be a VoidProcessorError")
}
})

t.Run("UnregisterHandler", func(t *testing.T) {
err := voidProcessor.UnregisterHandler("TEST")
if err == nil {
t.Error("VoidProcessor UnregisterHandler should return error")
}

// Check error type
if !IsVoidProcessorError(err) {
t.Error("Error should be a VoidProcessorError")
}
})
}

// TestProcessPendingNotificationsNilReader tests handling of nil reader
func TestProcessPendingNotificationsNilReader(t *testing.T) {
t.Run("ProcessorWithNilReader", func(t *testing.T) {
processor := NewProcessor()
ctx := context.Background()
handlerCtx := NotificationHandlerContext{}

err := processor.ProcessPendingNotifications(ctx, handlerCtx, nil)
if err != nil {
t.Errorf("ProcessPendingNotifications should not error with nil reader: %v", err)
}
})

t.Run("VoidProcessorWithNilReader", func(t *testing.T) {
voidProcessor := NewVoidProcessor()
ctx := context.Background()
handlerCtx := NotificationHandlerContext{}

err := voidProcessor.ProcessPendingNotifications(ctx, handlerCtx, nil)
if err != nil {
t.Errorf("VoidProcessor ProcessPendingNotifications should not error with nil reader: %v", err)
}
})
}

// TestWillHandleNotificationInClient tests the notification filtering logic
func TestWillHandleNotificationInClient(t *testing.T) {
testCases := []struct {
name string
notificationType string
shouldHandle bool
}{
// Pub/Sub notifications (should be handled in client)
{"message", "message", true},
{"pmessage", "pmessage", true},
{"subscribe", "subscribe", true},
{"unsubscribe", "unsubscribe", true},
{"psubscribe", "psubscribe", true},
{"punsubscribe", "punsubscribe", true},
{"smessage", "smessage", true},
{"ssubscribe", "ssubscribe", true},
{"sunsubscribe", "sunsubscribe", true},

// Push notifications (should be handled by processor)
{"MOVING", "MOVING", false},
{"MIGRATING", "MIGRATING", false},
{"MIGRATED", "MIGRATED", false},
{"FAILING_OVER", "FAILING_OVER", false},
{"FAILED_OVER", "FAILED_OVER", false},
{"custom", "custom", false},
{"unknown", "unknown", false},
{"empty", "", false},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := willHandleNotificationInClient(tc.notificationType)
if result != tc.shouldHandle {
t.Errorf("willHandleNotificationInClient(%q) = %v, want %v", tc.notificationType, result, tc.shouldHandle)
}
})
}
}

// TestProcessorErrorHandlingUnit tests error handling scenarios
func TestProcessorErrorHandlingUnit(t *testing.T) {
processor := NewProcessor()

t.Run("RegisterNilHandler", func(t *testing.T) {
err := processor.RegisterHandler("TEST", nil, false)
if err == nil {
t.Error("RegisterHandler should error with nil handler")
}

// Check error type
if !IsHandlerNilError(err) {
t.Error("Error should be a HandlerNilError")
}
})

t.Run("RegisterDuplicateHandler", func(t *testing.T) {
handler1 := &UnitTestHandler{name: "handler1"}
handler2 := &UnitTestHandler{name: "handler2"}

// Register first handler
err := processor.RegisterHandler("DUPLICATE", handler1, false)
if err != nil {
t.Errorf("First RegisterHandler should not error: %v", err)
}

// Try to register second handler with same name
err = processor.RegisterHandler("DUPLICATE", handler2, false)
if err == nil {
t.Error("RegisterHandler should error when registering duplicate handler")
}

// Verify original handler is still there
retrievedHandler := processor.GetHandler("DUPLICATE")
if retrievedHandler != handler1 {
t.Error("Original handler should remain after failed duplicate registration")
}
})

t.Run("UnregisterNonExistentHandler", func(t *testing.T) {
err := processor.UnregisterHandler("NONEXISTENT")
if err != nil {
t.Errorf("UnregisterHandler should not error for non-existent handler: %v", err)
}
})
}

// TestProcessorConcurrentAccess tests concurrent access to processor
func TestProcessorConcurrentAccess(t *testing.T) {
processor := NewProcessor()

t.Run("ConcurrentRegisterAndGet", func(t *testing.T) {
done := make(chan bool, 2)

// Goroutine 1: Register handlers
go func() {
defer func() { done <- true }()
for i := 0; i < 100; i++ {
handler := &UnitTestHandler{name: "concurrent-handler"}
processor.RegisterHandler("CONCURRENT", handler, false)
processor.UnregisterHandler("CONCURRENT")
}
}()

// Goroutine 2: Get handlers
go func() {
defer func() { done <- true }()
for i := 0; i < 100; i++ {
processor.GetHandler("CONCURRENT")
}
}()

// Wait for both goroutines to complete
<-done
<-done
})
}

// TestProcessorInterfaceCompliance tests interface compliance
func TestProcessorInterfaceCompliance(t *testing.T) {
t.Run("ProcessorImplementsInterface", func(t *testing.T) {
var _ NotificationProcessor = (*Processor)(nil)
})

t.Run("VoidProcessorImplementsInterface", func(t *testing.T) {
var _ NotificationProcessor = (*VoidProcessor)(nil)
})
}

// UnitTestHandler is a test implementation of NotificationHandler
type UnitTestHandler struct {
name string
lastNotification []interface{}
errorToReturn error
callCount int
}

func (h *UnitTestHandler) HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error {
h.callCount++
h.lastNotification = notification
return h.errorToReturn
}

// Helper methods for UnitTestHandler
func (h *UnitTestHandler) GetCallCount() int {
return h.callCount
}

func (h *UnitTestHandler) GetLastNotification() []interface{} {
return h.lastNotification
}

func (h *UnitTestHandler) SetErrorToReturn(err error) {
h.errorToReturn = err
}

func (h *UnitTestHandler) Reset() {
h.callCount = 0
h.lastNotification = nil
h.errorToReturn = nil
}
@@ -4,24 +4,6 @@ import (
"github.com/redis/go-redis/v9/push"
)

// Push notification constants for cluster operations
const (
// MOVING indicates a slot is being moved to a different node
PushNotificationMoving = "MOVING"

// MIGRATING indicates a slot is being migrated from this node
PushNotificationMigrating = "MIGRATING"

// MIGRATED indicates a slot has been migrated to this node
PushNotificationMigrated = "MIGRATED"

// FAILING_OVER indicates a failover is starting
PushNotificationFailingOver = "FAILING_OVER"

// FAILED_OVER indicates a failover has completed
PushNotificationFailedOver = "FAILED_OVER"
)

// NewPushNotificationProcessor creates a new push notification processor
// This processor maintains a registry of handlers and processes push notifications
// It is used for RESP3 connections where push notifications are available
211 redis.go
@@ -10,6 +10,7 @@ import (
"time"

"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/hscan"
"github.com/redis/go-redis/v9/internal/pool"
@@ -204,19 +205,35 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
//------------------------------------------------------------------------------

type baseClient struct {
opt *Options
connPool pool.Pooler
opt *Options
optLock sync.RWMutex
connPool pool.Pooler
pubSubPool *pool.PubSubPool
hooksMixin

onClose func() error // hook called when client is closed

// Push notification processing
pushProcessor push.NotificationProcessor

// Hitless upgrade manager
hitlessManager *hitless.HitlessManager
hitlessManagerLock sync.RWMutex
}

func (c *baseClient) clone() *baseClient {
clone := *c
return &clone
c.hitlessManagerLock.RLock()
hitlessManager := c.hitlessManager
c.hitlessManagerLock.RUnlock()

clone := &baseClient{
opt: c.opt,
connPool: c.connPool,
onClose: c.onClose,
pushProcessor: c.pushProcessor,
hitlessManager: hitlessManager,
}
return clone
}

func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
@@ -234,21 +251,6 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}

func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
cn, err := c.connPool.NewConn(ctx)
if err != nil {
return nil, err
}

err = c.initConn(ctx, cn)
if err != nil {
_ = c.connPool.CloseConn(cn)
return nil, err
}

return cn, nil
}

func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
if c.opt.Limiter != nil {
err := c.opt.Limiter.Allow()
@@ -274,7 +276,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return nil, err
}

if cn.Inited {
if cn.IsInited() {
return cn, nil
}

@@ -356,12 +358,10 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
}

func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
if !cn.Inited.CompareAndSwap(false, true) {
return nil
}

var err error
cn.Inited = true
connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool, &c.hooksMixin)

@@ -430,6 +430,50 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return fmt.Errorf("failed to initialize connection options: %w", err)
}

// Enable maintenance notifications if hitless upgrades are configured
c.optLock.RLock()
hitlessEnabled := c.opt.HitlessUpgradeConfig != nil && c.opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled
protocol := c.opt.Protocol
endpointType := c.opt.HitlessUpgradeConfig.EndpointType
c.optLock.RUnlock()
var hitlessHandshakeErr error
if hitlessEnabled && protocol == 3 {
hitlessHandshakeErr = conn.ClientMaintNotifications(
ctx,
true,
endpointType.String(),
).Err()
if hitlessHandshakeErr != nil {
if !isRedisError(hitlessHandshakeErr) {
// if not redis error, fail the connection
return hitlessHandshakeErr
}
c.optLock.Lock()
// handshake failed - check and modify config atomically
switch c.opt.HitlessUpgradeConfig.Mode {
case hitless.MaintNotificationsEnabled:
// enabled mode, fail the connection
c.optLock.Unlock()
return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr)
default: // will handle auto and any other
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsDisabled
c.optLock.Unlock()
// auto mode, disable hitless upgrades and continue
if err := c.disableHitlessUpgrades(); err != nil {
// Log error but continue - auto mode should be resilient
internal.Logger.Printf(ctx, "hitless: failed to disable hitless upgrades in auto mode: %v", err)
}
}
} else {
// The handshake succeeded on this connection, so force it to be executed
// on all other connections as well by switching the mode to "enabled".
c.optLock.Lock()
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsEnabled
c.optLock.Unlock()
}
}

if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
libName := ""
libVer := Version()
@@ -446,6 +490,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
}
}

cn.SetUsable(true)
cn.Inited.Store(true)

// Set the connection initialization function for potential reconnections
cn.SetInitConnFunc(c.createInitConnFunc())

if c.opt.OnConnect != nil {
return c.opt.OnConnect(ctx, conn)
}
@@ -593,19 +643,76 @@ func (c *baseClient) context(ctx context.Context) context.Context {
return context.Background()
}

// createInitConnFunc creates a connection initialization function that can be used for reconnections.
func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error {
return func(ctx context.Context, cn *pool.Conn) error {
return c.initConn(ctx, cn)
}
}
// enableHitlessUpgrades initializes the hitless upgrade manager and pool hook.
// It is called during client initialization. It registers push notification
// handlers for all hitless upgrade events and starts background workers for
// handoff processing in the pool hook.
func (c *baseClient) enableHitlessUpgrades() error {
// Create client adapter
clientAdapterInstance := newClientAdapter(c)

// Create hitless manager directly
manager, err := hitless.NewHitlessManager(clientAdapterInstance, c.connPool, c.opt.HitlessUpgradeConfig)
if err != nil {
return err
}
// Set the manager reference and initialize pool hook
c.hitlessManagerLock.Lock()
c.hitlessManager = manager
c.hitlessManagerLock.Unlock()

// Initialize pool hook (safe to call without lock since manager is now set)
manager.InitPoolHook(c.dialHook)
return nil
}

func (c *baseClient) disableHitlessUpgrades() error {
c.hitlessManagerLock.Lock()
defer c.hitlessManagerLock.Unlock()

// Close the hitless manager
if c.hitlessManager != nil {
// Closing the manager will also shutdown the pool hook
// and remove it from the pool
c.hitlessManager.Close()
c.hitlessManager = nil
}
return nil
}

// Close closes the client, releasing any open resources.
//
// It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
var firstErr error

// Close hitless manager first
if err := c.disableHitlessUpgrades(); err != nil {
firstErr = err
}

if c.onClose != nil {
if err := c.onClose(); err != nil {
if err := c.onClose(); err != nil && firstErr == nil {
firstErr = err
}
}
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
if c.connPool != nil {
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if c.pubSubPool != nil {
if err := c.pubSubPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
@@ -810,11 +917,24 @@ func NewClient(opt *Options) *Client {
// Initialize push notification processor using shared helper
// Use void processor for RESP2 connections (push notifications not available)
c.pushProcessor = initializePushProcessor(opt)

// Update options with the initialized push processor for connection pool
// Update options with the initialized push processor
opt.PushNotificationProcessor = c.pushProcessor

// Create connection pools
c.connPool = newConnPool(opt, c.dialHook)
c.pubSubPool = newPubSubPool(opt, c.dialHook)

// Initialize hitless upgrades first if enabled and protocol is RESP3
if opt.HitlessUpgradeConfig != nil && opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled && opt.Protocol == 3 {
err := c.enableHitlessUpgrades()
if err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
if opt.HitlessUpgradeConfig.Mode == hitless.MaintNotificationsEnabled {
// panic so we fail fast without breaking the existing client API
panic(fmt.Errorf("failed to enable hitless upgrades: %w", err))
}
}
}

return &c
}
@@ -851,6 +971,14 @@ func (c *Client) Options() *Options {
return c.opt
}

// GetHitlessManager returns the hitless manager instance for monitoring and control.
// Returns nil if hitless upgrades are not enabled.
func (c *Client) GetHitlessManager() *hitless.HitlessManager {
c.hitlessManagerLock.RLock()
defer c.hitlessManagerLock.RUnlock()
return c.hitlessManager
}
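A short usage sketch (nil simply means hitless upgrades are not active; no manager methods are assumed here):

    rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Protocol: 3})
    if mgr := rdb.GetHitlessManager(); mgr != nil {
        // hitless upgrades are enabled; mgr exposes monitoring/control hooks
    }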
// initializePushProcessor initializes the push notification processor for any client type.
// This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient.
func initializePushProcessor(opt *Options) push.NotificationProcessor {
@@ -887,6 +1015,7 @@ type PoolStats pool.Stats
// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
stats := c.connPool.Stats()
stats.PubSubStats = *(c.pubSubPool.Stats())
return (*PoolStats)(stats)
}
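A hedged usage sketch showing that PubSub pool statistics now ride along with the regular pool stats (the PubSubStats field name is taken from the line above; the print format is illustrative):

    stats := rdb.PoolStats()
    fmt.Printf("hits=%d misses=%d pubsub=%+v\n", stats.Hits, stats.Misses, stats.PubSubStats)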
@@ -921,11 +1050,27 @@ func (c *Client) TxPipeline() Pipeliner {
func (c *Client) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,

newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
closeConn: c.connPool.CloseConn,
pushProcessor: c.pushProcessor,
}
pubsub.init()
@@ -1113,6 +1258,6 @@ func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.Notifica
return push.NotificationHandlerContext{
Client: c,
ConnPool: c.connPool,
Conn: cn,
Conn: &connectionAdapter{conn: cn}, // Wrap in adapter for easier interface access
}
}
@@ -12,7 +12,6 @@ import (

. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"

"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/auth"
)
52 sentinel.go
@@ -16,8 +16,8 @@ import (
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/rand"
"github.com/redis/go-redis/v9/push"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/push"
)

//------------------------------------------------------------------------------
@@ -139,6 +139,14 @@ type FailoverOptions struct {
FailingTimeoutSeconds int

UnstableResp3 bool

// Hitless is not supported for FailoverClients at the moment
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// upgrade notifications gracefully and manage connection/pool state transitions
// seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are disabled.
//HitlessUpgradeConfig *HitlessUpgradeConfig
}

func (opt *FailoverOptions) clientOptions() *Options {
@@ -456,8 +464,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
opt.Dialer = masterReplicaDialer(failover)
opt.init()

var connPool *pool.ConnPool

rdb := &Client{
baseClient: &baseClient{
opt: opt,
@@ -469,15 +475,18 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
// Use void processor by default for RESP2 connections
rdb.pushProcessor = initializePushProcessor(opt)

connPool = newConnPool(opt, rdb.dialHook)
rdb.connPool = connPool
rdb.connPool = newConnPool(opt, rdb.dialHook)
rdb.pubSubPool = newPubSubPool(opt, rdb.dialHook)

rdb.onClose = rdb.wrappedOnClose(failover.Close)

failover.mu.Lock()
failover.onFailover = func(ctx context.Context, addr string) {
_ = connPool.Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
})
if connPool, ok := rdb.connPool.(*pool.ConnPool); ok {
_ = connPool.Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
})
}
}
failover.mu.Unlock()

@@ -544,6 +553,7 @@ func NewSentinelClient(opt *Options) *SentinelClient {
process: c.baseClient.process,
})
c.connPool = newConnPool(opt, c.dialHook)
c.pubSubPool = newPubSubPool(opt, c.dialHook)

return c
}
@@ -570,13 +580,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
func (c *SentinelClient) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,

newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: c.connPool.CloseConn,
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
pushProcessor: c.pushProcessor,
}
pubsub.init()

return pubsub
}
2 tx.go
@@ -24,7 +24,7 @@ type Tx struct {
func (c *Client) newTx() *Tx {
tx := Tx{
baseClient: baseClient{
opt: c.opt,
opt: c.opt.clone(), // Clone options to avoid sharing HitlessUpgradeConfig
connPool: pool.NewStickyConnPool(c.connPool),
hooksMixin: c.hooksMixin.clone(),
pushProcessor: c.pushProcessor, // Copy push processor from parent client
14 universal.go
@@ -122,6 +122,9 @@ type UniversalOptions struct {

// IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint).
IsClusterMode bool

// HitlessUpgradeConfig provides configuration for hitless upgrades.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
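A brief sketch of the propagation this change enables: the same config flows into Cluster() and Simple(), but not Failover() (the values below are illustrative only):

    uopts := &redis.UniversalOptions{
        Addrs:                []string{"127.0.0.1:6379"},
        Protocol:             3,
        HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{},
    }
    _ = uopts.Cluster().HitlessUpgradeConfig // carried over
    _ = uopts.Simple().HitlessUpgradeConfig  // carried over
    // uopts.Failover() intentionally drops it (not supported for failover clients)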
// Cluster returns cluster options created from the universal options.
@@ -177,6 +180,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions {
IdentitySuffix: o.IdentitySuffix,
FailingTimeoutSeconds: o.FailingTimeoutSeconds,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
}
}

@@ -237,6 +241,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions {
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
// Note: HitlessUpgradeConfig not supported for FailoverOptions
}
}

@@ -284,10 +289,11 @@ func (o *UniversalOptions) Simple() *Options {

TLSConfig: o.TLSConfig,

DisableIdentity: o.DisableIdentity,
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
DisableIdentity: o.DisableIdentity,
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
}
}