feat(hitless): Introduce handlers for hitless upgrades
This commit includes all the work on hitless upgrades, adding:
- PubSub pool
- Examples
- Refactor of push
- Refactor of pool (using atomics for most things)
- Introduction of hooks in the pool
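The pool hooks introduced here are wired in the same way the new integration test in this diff does it. A condensed sketch of that flow (the dialers are stand-ins; all other names come from files in this commit):

```go
// Condensed from async_handoff_integration_test.go below: create the hitless
// pool hook, attach it to a connection pool, and give it a pool reference so
// it can remove connections when a handoff fails.
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
    return net.Dial(network, addr) // stand-in dialer
}
processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())

testPool := pool.NewConnPool(&pool.Options{
    Dialer: func(ctx context.Context) (net.Conn, error) {
        return net.Dial("tcp", "localhost:6379") // stand-in dialer
    },
    PoolSize:    int32(5),
    PoolTimeout: time.Second,
})
defer testPool.Close()

testPool.AddPoolHook(processor) // pool consults the hook on Get/Put
processor.SetPool(testPool)     // hook can remove connections on failed handoffs
```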
.gitignore (vendored): 3 lines added

@@ -9,3 +9,6 @@ coverage.txt
 **/coverage.txt
 .vscode
 tmp/*
+
+# Hitless upgrade documentation (temporary)
+hitless/docs/
adapters.go (new file, 149 lines)

@@ -0,0 +1,149 @@
package redis

import (
    "context"
    "errors"
    "net"
    "time"

    "github.com/redis/go-redis/v9/internal/interfaces"
    "github.com/redis/go-redis/v9/internal/pool"
    "github.com/redis/go-redis/v9/push"
)

// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand.
var ErrInvalidCommand = errors.New("invalid command type")

// ErrInvalidPool is returned when the pool type is not supported.
var ErrInvalidPool = errors.New("invalid pool type")

// newClientAdapter creates a new client adapter for regular Redis clients.
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
    return &clientAdapter{client: client}
}

// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
type clientAdapter struct {
    client *baseClient
}

// GetOptions returns the client options.
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
    return &optionsAdapter{options: ca.client.opt}
}

// GetPushProcessor returns the client's push notification processor.
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
    return &pushProcessorAdapter{processor: ca.client.pushProcessor}
}

// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
type optionsAdapter struct {
    options *Options
}

// GetReadTimeout returns the read timeout.
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
    return oa.options.ReadTimeout
}

// GetWriteTimeout returns the write timeout.
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
    return oa.options.WriteTimeout
}

// GetNetwork returns the network type.
func (oa *optionsAdapter) GetNetwork() string {
    return oa.options.Network
}

// GetAddr returns the connection address.
func (oa *optionsAdapter) GetAddr() string {
    return oa.options.Addr
}

// IsTLSEnabled returns true if TLS is enabled.
func (oa *optionsAdapter) IsTLSEnabled() bool {
    return oa.options.TLSConfig != nil
}

// GetProtocol returns the protocol version.
func (oa *optionsAdapter) GetProtocol() int {
    return oa.options.Protocol
}

// GetPoolSize returns the connection pool size.
func (oa *optionsAdapter) GetPoolSize() int {
    return oa.options.PoolSize
}

// NewDialer returns a new dialer function for the connection.
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
    baseDialer := oa.options.NewDialer()
    return func(ctx context.Context) (net.Conn, error) {
        // Extract network and address from the options
        network := oa.options.Network
        addr := oa.options.Addr
        return baseDialer(ctx, network, addr)
    }
}

// connectionAdapter adapts a Redis connection to interfaces.ConnectionWithRelaxedTimeout
type connectionAdapter struct {
    conn *pool.Conn
}

// Close closes the connection.
func (ca *connectionAdapter) Close() error {
    return ca.conn.Close()
}

// IsUsable returns true if the connection is safe to use for new commands.
func (ca *connectionAdapter) IsUsable() bool {
    return ca.conn.IsUsable()
}

// GetPoolConnection returns the underlying pool connection.
func (ca *connectionAdapter) GetPoolConnection() *pool.Conn {
    return ca.conn
}

// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
// These timeouts remain active until explicitly cleared.
func (ca *connectionAdapter) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
    ca.conn.SetRelaxedTimeout(readTimeout, writeTimeout)
}

// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
func (ca *connectionAdapter) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
    ca.conn.SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout, deadline)
}

// ClearRelaxedTimeout clears relaxed timeouts for this connection.
func (ca *connectionAdapter) ClearRelaxedTimeout() {
    ca.conn.ClearRelaxedTimeout()
}

// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
type pushProcessorAdapter struct {
    processor push.NotificationProcessor
}

// RegisterHandler registers a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
    if pushHandler, ok := handler.(push.NotificationHandler); ok {
        return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
    }
    return errors.New("handler must implement push.NotificationHandler")
}

// UnregisterHandler removes a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
    return ppa.processor.UnregisterHandler(pushNotificationName)
}

// GetHandler returns the handler for a specific push notification name.
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
    return ppa.processor.GetHandler(pushNotificationName)
}
async_handoff_integration_test.go (new file, 348 lines)

@@ -0,0 +1,348 @@
package redis

import (
    "context"
    "net"
    "sync"
    "testing"
    "time"

    "github.com/redis/go-redis/v9/hitless"
    "github.com/redis/go-redis/v9/internal/pool"
)

// mockNetConn implements net.Conn for testing
type mockNetConn struct {
    addr string
}

func (m *mockNetConn) Read(b []byte) (n int, err error)   { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error)  { return len(b), nil }
func (m *mockNetConn) Close() error                       { return nil }
func (m *mockNetConn) LocalAddr() net.Addr                { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr               { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error      { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error  { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }

type mockAddr struct {
    addr string
}

func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string  { return m.addr }

// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow
func TestEventDrivenHandoffIntegration(t *testing.T) {
    t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) {
        // Create a base dialer for testing
        baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
            return &mockNetConn{addr: addr}, nil
        }

        // Create processor with event-driven handoff support
        processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
        defer processor.Shutdown(context.Background())

        // Create a test pool with hooks
        hookManager := pool.NewPoolHookManager()
        hookManager.AddHook(processor)

        testPool := pool.NewConnPool(&pool.Options{
            Dialer: func(ctx context.Context) (net.Conn, error) {
                return &mockNetConn{addr: "original:6379"}, nil
            },
            PoolSize:    int32(5),
            PoolTimeout: time.Second,
        })

        // Add the hook to the pool after creation
        testPool.AddPoolHook(processor)
        defer testPool.Close()

        // Set the pool reference in the processor for connection removal on handoff failure
        processor.SetPool(testPool)

        ctx := context.Background()

        // Get a connection and mark it for handoff
        conn, err := testPool.Get(ctx)
        if err != nil {
            t.Fatalf("Failed to get connection: %v", err)
        }

        // Set initialization function with a small delay to ensure handoff is pending
        initConnCalled := false
        initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
            time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
            initConnCalled = true
            return nil
        }
        conn.SetInitConnFunc(initConnFunc)

        // Mark connection for handoff
        err = conn.MarkForHandoff("new-endpoint:6379", 12345)
        if err != nil {
            t.Fatalf("Failed to mark connection for handoff: %v", err)
        }

        // Return connection to pool - this should queue handoff
        testPool.Put(ctx, conn)

        // Give the on-demand worker a moment to start processing
        time.Sleep(10 * time.Millisecond)

        // Verify handoff was queued
        if !processor.IsHandoffPending(conn) {
            t.Error("Handoff should be queued in pending map")
        }

        // Try to get the same connection - should be skipped due to pending handoff
        conn2, err := testPool.Get(ctx)
        if err != nil {
            t.Fatalf("Failed to get second connection: %v", err)
        }

        // Should get a different connection (the pending one should be skipped)
        if conn == conn2 {
            t.Error("Should have gotten a different connection while handoff is pending")
        }

        // Return the second connection
        testPool.Put(ctx, conn2)

        // Wait for handoff to complete
        time.Sleep(200 * time.Millisecond)

        // Verify handoff completed (removed from pending map)
        if processor.IsHandoffPending(conn) {
            t.Error("Handoff should have completed and been removed from pending map")
        }

        if !initConnCalled {
            t.Error("InitConn should have been called during handoff")
        }

        // Now the original connection should be available again
        conn3, err := testPool.Get(ctx)
        if err != nil {
            t.Fatalf("Failed to get third connection: %v", err)
        }

        // Could be the original connection (now handed off) or a new one
        testPool.Put(ctx, conn3)
    })

    t.Run("ConcurrentHandoffs", func(t *testing.T) {
        // Create a base dialer that simulates slow handoffs
        baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
            time.Sleep(50 * time.Millisecond) // Simulate network delay
            return &mockNetConn{addr: addr}, nil
        }

        processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
        defer processor.Shutdown(context.Background())

        // Create hooks manager and add processor as hook
        hookManager := pool.NewPoolHookManager()
        hookManager.AddHook(processor)

        testPool := pool.NewConnPool(&pool.Options{
            Dialer: func(ctx context.Context) (net.Conn, error) {
                return &mockNetConn{addr: "original:6379"}, nil
            },

            PoolSize:    int32(10),
            PoolTimeout: time.Second,
        })
        defer testPool.Close()

        // Add the hook to the pool after creation
        testPool.AddPoolHook(processor)

        // Set the pool reference in the processor
        processor.SetPool(testPool)

        ctx := context.Background()
        var wg sync.WaitGroup

        // Start multiple concurrent handoffs
        for i := 0; i < 5; i++ {
            wg.Add(1)
            go func(id int) {
                defer wg.Done()

                // Get connection
                conn, err := testPool.Get(ctx)
                if err != nil {
                    t.Errorf("Failed to get connection %d: %v", id, err)
                    return
                }

                // Set initialization function
                initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
                    return nil
                }
                conn.SetInitConnFunc(initConnFunc)

                // Mark for handoff
                conn.MarkForHandoff("new-endpoint:6379", int64(id))

                // Return to pool (starts async handoff)
                testPool.Put(ctx, conn)
            }(i)
        }

        wg.Wait()

        // Wait for all handoffs to complete
        time.Sleep(300 * time.Millisecond)

        // Verify pool is still functional
        conn, err := testPool.Get(ctx)
        if err != nil {
            t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err)
        }
        testPool.Put(ctx, conn)
    })

    t.Run("HandoffFailureRecovery", func(t *testing.T) {
        // Create a failing base dialer
        failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
            return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}}
        }

        processor := hitless.NewPoolHook(failingDialer, "tcp", nil, nil)
        defer processor.Shutdown(context.Background())

        // Create hooks manager and add processor as hook
        hookManager := pool.NewPoolHookManager()
        hookManager.AddHook(processor)

        testPool := pool.NewConnPool(&pool.Options{
            Dialer: func(ctx context.Context) (net.Conn, error) {
                return &mockNetConn{addr: "original:6379"}, nil
            },

            PoolSize:    int32(3),
            PoolTimeout: time.Second,
        })
        defer testPool.Close()

        // Add the hook to the pool after creation
        testPool.AddPoolHook(processor)

        // Set the pool reference in the processor
        processor.SetPool(testPool)

        ctx := context.Background()

        // Get connection and mark for handoff
        conn, err := testPool.Get(ctx)
        if err != nil {
            t.Fatalf("Failed to get connection: %v", err)
        }

        conn.MarkForHandoff("unreachable-endpoint:6379", 12345)

        // Return to pool (starts async handoff that will fail)
        testPool.Put(ctx, conn)

        // Wait for handoff to fail
        time.Sleep(200 * time.Millisecond)

        // Connection should be removed from pending map after failed handoff
        if processor.IsHandoffPending(conn) {
            t.Error("Connection should be removed from pending map after failed handoff")
        }

        // Pool should still be functional
        conn2, err := testPool.Get(ctx)
        if err != nil {
            t.Fatalf("Pool should still be functional: %v", err)
        }

        // In event-driven approach, the original connection remains in pool
        // even after failed handoff (it's still a valid connection)
        // We might get the same connection or a different one
        testPool.Put(ctx, conn2)
    })

    t.Run("GracefulShutdown", func(t *testing.T) {
        // Create a slow base dialer
        slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
            time.Sleep(100 * time.Millisecond)
            return &mockNetConn{addr: addr}, nil
        }

        processor := hitless.NewPoolHook(slowDialer, "tcp", nil, nil)
        defer processor.Shutdown(context.Background())

        // Create hooks manager and add processor as hook
        hookManager := pool.NewPoolHookManager()
        hookManager.AddHook(processor)

        testPool := pool.NewConnPool(&pool.Options{
            Dialer: func(ctx context.Context) (net.Conn, error) {
                return &mockNetConn{addr: "original:6379"}, nil
            },

            PoolSize:    int32(2),
            PoolTimeout: time.Second,
        })
        defer testPool.Close()

        // Add the hook to the pool after creation
        testPool.AddPoolHook(processor)

        // Set the pool reference in the processor
        processor.SetPool(testPool)

        ctx := context.Background()

        // Start a handoff
        conn, err := testPool.Get(ctx)
        if err != nil {
            t.Fatalf("Failed to get connection: %v", err)
        }

        if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
            t.Fatalf("Failed to mark connection for handoff: %v", err)
        }

        // Set a mock initialization function with delay to ensure handoff is pending
        conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
            time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
            return nil
        })

        testPool.Put(ctx, conn)

        // Give the on-demand worker a moment to start and begin processing
        // The handoff should be pending because the slowDialer takes 100ms
        time.Sleep(10 * time.Millisecond)

        // Verify handoff was queued and is being processed
        if !processor.IsHandoffPending(conn) {
            t.Error("Handoff should be queued in pending map")
        }

        // Give the handoff a moment to start processing
        time.Sleep(50 * time.Millisecond)

        // Shutdown processor gracefully
        // Use a longer timeout to account for slow dialer (100ms) plus processing overhead
        shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
        defer cancel()

        err = processor.Shutdown(shutdownCtx)
        if err != nil {
            t.Errorf("Graceful shutdown should succeed: %v", err)
        }

        // Handoff should have completed (removed from pending map)
        if processor.IsHandoffPending(conn) {
            t.Error("Handoff should have completed and been removed from pending map after shutdown")
        }
    })
}
commands.go (18 lines changed)

@@ -193,6 +193,7 @@ type Cmdable interface {
 	ClientID(ctx context.Context) *IntCmd
 	ClientUnblock(ctx context.Context, id int64) *IntCmd
 	ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
+	ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
 	ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
 	ConfigResetStat(ctx context.Context) *StatusCmd
 	ConfigSet(ctx context.Context, parameter, value string) *StatusCmd
@@ -518,6 +519,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
 	return cmd
 }
 
+// ClientMaintNotifications enables or disables maintenance notifications for hitless upgrades.
+// When enabled, the client will receive push notifications about Redis maintenance events.
+func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
+	args := []interface{}{"client", "maint_notifications"}
+	if enabled {
+		if endpointType == "" {
+			endpointType = "none"
+		}
+		args = append(args, "on", "moving-endpoint-type", endpointType)
+	} else {
+		args = append(args, "off")
+	}
+	cmd := NewStatusCmd(ctx, args...)
+	_ = c(ctx, cmd)
+	return cmd
+}
+
 // ------------------------------------------------------------------------------------------------
 
 func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {
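The new command can be issued like any other `Cmdable` method. A minimal, hedged usage sketch (the `"external-ip"` endpoint type matches the instrumentation output later in this diff; `rdb` is an existing `*redis.Client`):

```go
// Enable maintenance notifications and ask the server to report external IPs
// in MOVING notifications.
if err := rdb.ClientMaintNotifications(ctx, true, "external-ip").Err(); err != nil {
    // Servers without maintenance-notification support return an error here;
    // per hitless/config.go, MaintNotificationsAuto mode tolerates this and
    // disables the feature, while MaintNotificationsEnabled treats it as fatal.
    log.Printf("could not enable maint_notifications: %v", err)
}

// Turn notifications back off; the endpoint type argument is ignored when disabling.
_ = rdb.ClientMaintNotifications(ctx, false, "").Err()
```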
example/pubsub/go.mod (new file, 12 lines)

@@ -0,0 +1,12 @@
module github.com/redis/go-redis/example/pubsub

go 1.18

replace github.com/redis/go-redis/v9 => ../..

require github.com/redis/go-redis/v9 v9.11.0

require (
    github.com/cespare/xxhash/v2 v2.3.0 // indirect
    github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)
example/pubsub/go.sum (new file, 6 lines)

@@ -0,0 +1,6 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
example/pubsub/main.go (new file, 146 lines)

@@ -0,0 +1,146 @@
package main

import (
    "context"
    "fmt"
    "log"
    "sync"
    "time"

    "github.com/redis/go-redis/v9"
    "github.com/redis/go-redis/v9/hitless"
)

var ctx = context.Background()

// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management.
// It was used to find regressions in pool management in hitless mode.
// Please don't use it as a reference for how to use pubsub.
func main() {
    wg := &sync.WaitGroup{}
    rdb := redis.NewClient(&redis.Options{
        Addr: ":6379",
        HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
            Mode: hitless.MaintNotificationsEnabled,
        },
    })
    _ = rdb.FlushDB(ctx).Err()

    go func() {
        for {
            time.Sleep(2 * time.Second)
            fmt.Printf("pool stats: %+v\n", rdb.PoolStats())
        }
    }()
    err := rdb.Ping(ctx).Err()
    if err != nil {
        panic(err)
    }
    if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil {
        panic(err)
    }
    if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil {
        panic(err)
    }
    if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil {
        panic(err)
    }
    if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil {
        panic(err)
    }
    fmt.Println("published", rdb.Get(ctx, "published").Val())
    fmt.Println("received", rdb.Get(ctx, "received").Val())
    subCtx, cancelSubCtx := context.WithCancel(ctx)
    pubCtx, cancelPublishers := context.WithCancel(ctx)
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go subscribe(subCtx, rdb, "test", i, wg)
    }
    time.Sleep(time.Second)
    cancelSubCtx()
    time.Sleep(time.Second)
    subCtx, cancelSubCtx = context.WithCancel(ctx)
    for i := 0; i < 10; i++ {
        if err := rdb.Incr(ctx, "publishers").Err(); err != nil {
            panic(err)
        }
        wg.Add(1)
        go floodThePool(pubCtx, rdb, wg)
    }

    for i := 0; i < 500; i++ {
        if err := rdb.Incr(ctx, "subscribers").Err(); err != nil {
            panic(err)
        }
        wg.Add(1)
        go subscribe(subCtx, rdb, "test2", i, wg)
    }
    time.Sleep(5 * time.Second)
    fmt.Println("canceling publishers")
    cancelPublishers()
    time.Sleep(10 * time.Second)
    fmt.Println("canceling subscribers")
    cancelSubCtx()
    wg.Wait()
    published, err := rdb.Get(ctx, "published").Result()
    received, err := rdb.Get(ctx, "received").Result()
    publishers, err := rdb.Get(ctx, "publishers").Result()
    subscribers, err := rdb.Get(ctx, "subscribers").Result()
    fmt.Printf("publishers: %s\n", publishers)
    fmt.Printf("published: %s\n", published)
    fmt.Printf("subscribers: %s\n", subscribers)
    fmt.Printf("received: %s\n", received)
    publishedInt, err := rdb.Get(ctx, "published").Int()
    subscribersInt, err := rdb.Get(ctx, "subscribers").Int()
    fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt)

    time.Sleep(2 * time.Second)
}

func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
    defer wg.Done()
    for {
        select {
        case <-ctx.Done():
            return
        default:
        }
        err := rdb.Publish(ctx, "test2", "hello").Err()
        if err != nil {
            // noop
            //log.Println("publish error:", err)
        }

        err = rdb.Incr(ctx, "published").Err()
        if err != nil {
            // noop
            //log.Println("incr error:", err)
        }
        time.Sleep(10 * time.Nanosecond)
    }
}

func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) {
    defer wg.Done()
    rec := rdb.Subscribe(ctx, topic)
    recChan := rec.Channel()
    for {
        select {
        case <-ctx.Done():
            rec.Close()
            return
        default:
            select {
            case <-ctx.Done():
                rec.Close()
                return
            case msg := <-recChan:
                err := rdb.Incr(ctx, "received").Err()
                if err != nil {
                    log.Println("incr error:", err)
                }
                _ = msg // Use the message to avoid unused variable warning
            }
        }
    }
}
@@ -57,6 +57,8 @@ func Example_instrumentation() {
 	// finished dialing tcp :6379
 	// starting processing: <[hello 3]>
 	// finished processing: <[hello 3]>
+	// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
+	// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
 	// finished processing: <[ping]>
 }
 
@@ -78,6 +80,8 @@ func ExamplePipeline_instrumentation() {
 	// finished dialing tcp :6379
 	// starting processing: <[hello 3]>
 	// finished processing: <[hello 3]>
+	// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
+	// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
 	// pipeline finished processing: [[ping] [ping]]
 }
 
@@ -99,6 +103,8 @@ func ExampleClient_Watch_instrumentation() {
 	// finished dialing tcp :6379
 	// starting processing: <[hello 3]>
 	// finished processing: <[hello 3]>
+	// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
+	// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
 	// finished processing: <[watch foo]>
 	// starting processing: <[ping]>
 	// finished processing: <[ping]>
hitless/README.md (new file, 72 lines)

@@ -0,0 +1,72 @@
# Hitless Upgrades

Seamless Redis connection handoffs during topology changes without interrupting operations.

## Quick Start

```go
import "github.com/redis/go-redis/v9/hitless"

opt := &redis.Options{
    Addr:     "localhost:6379",
    Protocol: 3, // RESP3 required
    HitlessUpgrades: &redis.HitlessUpgradeConfig{
        Mode: hitless.MaintNotificationsEnabled, // or MaintNotificationsAuto
    },
}
client := redis.NewClient(opt)
```

## Modes

- **`MaintNotificationsDisabled`**: Hitless upgrades are completely disabled
- **`MaintNotificationsEnabled`**: Hitless upgrades are forcefully enabled (fails if the server doesn't support it)
- **`MaintNotificationsAuto`**: Hitless upgrades are enabled if the server supports it (default)

## Configuration

```go
import "github.com/redis/go-redis/v9/hitless"

Config: &hitless.Config{
    Mode:                       hitless.MaintNotificationsAuto, // Notification mode
    MaxHandoffRetries:          3,                              // Retry failed handoffs
    HandoffTimeout:             15 * time.Second,               // Handoff operation timeout
    RelaxedTimeout:             10 * time.Second,               // Extended timeout during migrations
    PostHandoffRelaxedDuration: 20 * time.Second,               // Keep relaxed timeout after handoff
    LogLevel:                   1,                              // 0=errors, 1=warnings, 2=info, 3=debug
    MaxWorkers:                 15,                             // Concurrent handoff workers
    HandoffQueueSize:           50,                             // Handoff request queue size
}
```

### Worker Scaling

- **Auto-calculated**: `min(10, PoolSize/3)` - scales with pool size, capped at 10
- **Explicit values**: `max(10, set_value)` - enforces a minimum of 10 workers
- **On-demand**: Workers are created when needed and cleaned up when idle

### Queue Sizing

- **Auto-calculated**: `10 × MaxWorkers`, capped by pool size
- **Always capped**: Queue size never exceeds pool size

## Metrics Hook Example

A metrics collection hook is available in `example_hooks.go` that demonstrates how to monitor hitless upgrade operations:

```go
import "github.com/redis/go-redis/v9/hitless"

metricsHook := hitless.NewMetricsHook()
// Use with your monitoring system
```

The metrics hook tracks:

- Handoff success/failure rates
- Handoff duration
- Queue depth
- Worker utilization
- Connection lifecycle events

## Requirements

- **RESP3 Protocol**: Required for push notifications
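For illustration, the Worker Scaling and Queue Sizing rules from the README above combine as in the following sketch. It mirrors the `ApplyDefaultsWithPoolSize` logic in `hitless/config.go` (next file in this diff); the Go 1.21 `min`/`max` builtins stand in for the internal `util` helpers:

```go
// sizingSketch shows how worker and queue defaults are derived from pool size.
func sizingSketch(maxWorkers, poolSize int) (workers, queue int) {
    if maxWorkers == 0 {
        workers = min(10, poolSize/3) // auto: scale with pool size, cap at 10
    } else {
        workers = max(10, maxWorkers) // explicit: enforce a minimum of 10
    }
    if workers < 1 {
        workers = 1 // fallback for very small pools
    }
    queue = min(workers*10, poolSize) // 10x workers, never more than pool size
    if queue < 2 {
        queue = 2 // minimum queue size
    }
    return workers, queue
}
```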
hitless/config.go (new file, 377 lines)

@@ -0,0 +1,377 @@
package hitless

import (
    "net"
    "runtime"
    "time"

    "github.com/redis/go-redis/v9/internal/util"
)

// MaintNotificationsMode represents the maintenance notifications mode
type MaintNotificationsMode string

// Constants for maintenance push notifications modes
const (
    MaintNotificationsDisabled MaintNotificationsMode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
    MaintNotificationsEnabled  MaintNotificationsMode = "enabled"  // Client forcefully sends command, interrupts connection on error
    MaintNotificationsAuto     MaintNotificationsMode = "auto"     // Client tries to send command, disables feature on error
)

// IsValid returns true if the maintenance notifications mode is valid
func (m MaintNotificationsMode) IsValid() bool {
    switch m {
    case MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto:
        return true
    default:
        return false
    }
}

// String returns the string representation of the mode
func (m MaintNotificationsMode) String() string {
    return string(m)
}

// EndpointType represents the type of endpoint to request in MOVING notifications
type EndpointType string

// Constants for endpoint types
const (
    EndpointTypeAuto         EndpointType = "auto"          // Auto-detect based on connection
    EndpointTypeInternalIP   EndpointType = "internal-ip"   // Internal IP address
    EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN
    EndpointTypeExternalIP   EndpointType = "external-ip"   // External IP address
    EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN
    EndpointTypeNone         EndpointType = "none"          // No endpoint (reconnect with current config)
)

// IsValid returns true if the endpoint type is valid
func (e EndpointType) IsValid() bool {
    switch e {
    case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
        EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone:
        return true
    default:
        return false
    }
}

// String returns the string representation of the endpoint type
func (e EndpointType) String() string {
    return string(e)
}

// Config provides configuration options for hitless upgrades.
type Config struct {
    // Mode controls how client maintenance notifications are handled.
    // Valid values: MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto
    // Default: MaintNotificationsAuto
    Mode MaintNotificationsMode

    // EndpointType specifies the type of endpoint to request in MOVING notifications.
    // Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
    // EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
    // Default: EndpointTypeAuto
    EndpointType EndpointType

    // RelaxedTimeout is the concrete timeout value to use during
    // MIGRATING/FAILING_OVER states to accommodate increased latency.
    // This applies to both read and write timeouts.
    // Default: 10 seconds
    RelaxedTimeout time.Duration

    // HandoffTimeout is the maximum time to wait for connection handoff to complete.
    // If handoff takes longer than this, the old connection will be forcibly closed.
    // Default: 15 seconds (matches server-side eviction timeout)
    HandoffTimeout time.Duration

    // MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
    // Workers are created on-demand and automatically cleaned up when idle.
    // If zero, defaults to min(10, PoolSize/3) to handle bursts effectively.
    // If explicitly set, enforces minimum of 10 workers.
    //
    // Default: min(10, PoolSize/3), Minimum when set: 10
    MaxWorkers int

    // HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
    // If the queue is full, new handoff requests will be rejected.
    // Always capped by pool size since you can't handoff more connections than exist.
    //
    // Default: 10x max workers, capped by pool size, min 2
    HandoffQueueSize int

    // PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
    // after a handoff completes. This provides additional resilience during cluster transitions.
    // Default: 2 * RelaxedTimeout
    PostHandoffRelaxedDuration time.Duration

    // ScaleDownDelay is the delay before checking if workers should be scaled down.
    // This prevents expensive checks on every handoff completion and avoids rapid scaling cycles.
    // Default: 2 seconds
    ScaleDownDelay time.Duration

    // LogLevel controls the verbosity of hitless upgrade logging.
    // 0 = errors only, 1 = warnings, 2 = info, 3 = debug
    // Default: 1 (warnings)
    LogLevel int

    // MaxHandoffRetries is the maximum number of times to retry a failed handoff.
    // After this many retries, the connection will be removed from the pool.
    // Default: 3
    MaxHandoffRetries int
}

func (c *Config) IsEnabled() bool {
    return c != nil && c.Mode != MaintNotificationsDisabled
}

// DefaultConfig returns a Config with sensible defaults.
func DefaultConfig() *Config {
    return &Config{
        Mode:                       MaintNotificationsAuto, // Enable by default for Redis Cloud
        EndpointType:               EndpointTypeAuto,       // Auto-detect based on connection
        RelaxedTimeout:             10 * time.Second,
        HandoffTimeout:             15 * time.Second,
        MaxWorkers:                 0, // Auto-calculated based on pool size
        HandoffQueueSize:           0, // Auto-calculated based on max workers
        PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
        ScaleDownDelay:             2 * time.Second,
        LogLevel:                   1,

        // Connection Handoff Configuration
        MaxHandoffRetries: 3,
    }
}

// Validate checks if the configuration is valid.
func (c *Config) Validate() error {
    if c.RelaxedTimeout <= 0 {
        return ErrInvalidRelaxedTimeout
    }
    if c.HandoffTimeout <= 0 {
        return ErrInvalidHandoffTimeout
    }
    // Validate worker configuration
    // Allow 0 for auto-calculation, but negative values are invalid
    if c.MaxWorkers < 0 {
        return ErrInvalidHandoffWorkers
    }
    // HandoffQueueSize validation - allow 0 for auto-calculation
    if c.HandoffQueueSize < 0 {
        return ErrInvalidHandoffQueueSize
    }
    if c.PostHandoffRelaxedDuration < 0 {
        return ErrInvalidPostHandoffRelaxedDuration
    }
    if c.LogLevel < 0 || c.LogLevel > 3 {
        return ErrInvalidLogLevel
    }

    // Validate Mode (maintenance notifications mode)
    if !c.Mode.IsValid() {
        return ErrInvalidMaintNotifications
    }

    // Validate EndpointType
    if !c.EndpointType.IsValid() {
        return ErrInvalidEndpointType
    }

    // Validate configuration fields
    if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
        return ErrInvalidHandoffRetries
    }

    return nil
}

// ApplyDefaults applies default values to any zero-value fields in the configuration.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaults() *Config {
    return c.ApplyDefaultsWithPoolSize(0)
}

// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
// using the provided pool size to calculate worker defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
    if c == nil {
        return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
    }

    defaults := DefaultConfig()
    result := &Config{}

    // Apply defaults for enum fields (empty/zero means not set)
    if c.Mode == "" {
        result.Mode = defaults.Mode
    } else {
        result.Mode = c.Mode
    }

    if c.EndpointType == "" {
        result.EndpointType = defaults.EndpointType
    } else {
        result.EndpointType = c.EndpointType
    }

    // Apply defaults for duration fields (zero means not set)
    if c.RelaxedTimeout <= 0 {
        result.RelaxedTimeout = defaults.RelaxedTimeout
    } else {
        result.RelaxedTimeout = c.RelaxedTimeout
    }

    if c.HandoffTimeout <= 0 {
        result.HandoffTimeout = defaults.HandoffTimeout
    } else {
        result.HandoffTimeout = c.HandoffTimeout
    }

    // Apply defaults for integer fields (zero means not set)
    if c.HandoffQueueSize <= 0 {
        result.HandoffQueueSize = defaults.HandoffQueueSize
    } else {
        result.HandoffQueueSize = c.HandoffQueueSize
    }

    // Copy worker configuration
    result.MaxWorkers = c.MaxWorkers

    // Apply worker defaults based on pool size
    result.applyWorkerDefaults(poolSize)

    // Apply queue size defaults based on max workers, capped by pool size
    if c.HandoffQueueSize <= 0 {
        // Queue size: 10x max workers, but never more than pool size
        workerBasedSize := result.MaxWorkers * 10
        result.HandoffQueueSize = util.Min(workerBasedSize, poolSize)
    } else {
        result.HandoffQueueSize = c.HandoffQueueSize
    }

    // Always cap queue size by pool size - no point having more queue slots than connections
    result.HandoffQueueSize = util.Min(result.HandoffQueueSize, poolSize)

    // Ensure minimum queue size of 2
    if result.HandoffQueueSize < 2 {
        result.HandoffQueueSize = 2
    }

    if c.PostHandoffRelaxedDuration <= 0 {
        result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
    } else {
        result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
    }

    if c.ScaleDownDelay <= 0 {
        result.ScaleDownDelay = defaults.ScaleDownDelay
    } else {
        result.ScaleDownDelay = c.ScaleDownDelay
    }

    // LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set
    // We'll use the provided value as-is, since 0 is valid
    result.LogLevel = c.LogLevel

    // Apply defaults for configuration fields
    if c.MaxHandoffRetries <= 0 {
        result.MaxHandoffRetries = defaults.MaxHandoffRetries
    } else {
        result.MaxHandoffRetries = c.MaxHandoffRetries
    }

    return result
}

// Clone creates a deep copy of the configuration.
func (c *Config) Clone() *Config {
    if c == nil {
        return DefaultConfig()
    }

    return &Config{
        Mode:                       c.Mode,
        EndpointType:               c.EndpointType,
        RelaxedTimeout:             c.RelaxedTimeout,
        HandoffTimeout:             c.HandoffTimeout,
        MaxWorkers:                 c.MaxWorkers,
        HandoffQueueSize:           c.HandoffQueueSize,
        PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
        ScaleDownDelay:             c.ScaleDownDelay,
        LogLevel:                   c.LogLevel,

        // Configuration fields
        MaxHandoffRetries: c.MaxHandoffRetries,
    }
}

// applyWorkerDefaults calculates and applies worker defaults based on pool size
func (c *Config) applyWorkerDefaults(poolSize int) {
    // Calculate defaults based on pool size
    if poolSize <= 0 {
        poolSize = 10 * runtime.GOMAXPROCS(0)
    }

    if c.MaxWorkers == 0 {
        // When not set: min(10, poolSize/3) - don't exceed 10 workers for small pools
        c.MaxWorkers = util.Min(10, poolSize/3)
    } else {
        // When explicitly set: max(10, set_value) - ensure at least 10 workers
        c.MaxWorkers = util.Max(10, c.MaxWorkers)
    }

    // Ensure minimum of 1 worker (fallback for very small pools)
    if c.MaxWorkers < 1 {
        c.MaxWorkers = 1
    }
}

// DetectEndpointType automatically detects the appropriate endpoint type
// based on the connection address and TLS configuration.
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
    // Parse the address to determine if it's an IP or hostname
    isPrivate := isPrivateIP(addr)

    var endpointType EndpointType

    if tlsEnabled {
        // TLS requires FQDN for certificate validation
        if isPrivate {
            endpointType = EndpointTypeInternalFQDN
        } else {
            endpointType = EndpointTypeExternalFQDN
        }
    } else {
        // No TLS, can use IP addresses
        if isPrivate {
            endpointType = EndpointTypeInternalIP
        } else {
            endpointType = EndpointTypeExternalIP
        }
    }

    return endpointType
}

// isPrivateIP checks if the given address is in a private IP range.
func isPrivateIP(addr string) bool {
    // Extract host from "host:port" format
    host, _, err := net.SplitHostPort(addr)
    if err != nil {
        host = addr // Assume no port
    }

    ip := net.ParseIP(host)
    if ip == nil {
        return false // Not an IP address (likely hostname)
    }

    // Check for private/loopback ranges
    return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
}
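`DetectEndpointType` and `isPrivateIP` above drive the endpoint type requested in MOVING notifications when `EndpointTypeAuto` is configured. A small hedged usage sketch (the addresses are illustrative):

```go
// How auto-detection maps addresses to endpoint types.
fmt.Println(hitless.DetectEndpointType("10.0.0.5:6379", false))         // internal-ip (private range, no TLS)
fmt.Println(hitless.DetectEndpointType("10.0.0.5:6379", true))          // internal-fqdn (TLS prefers FQDNs)
fmt.Println(hitless.DetectEndpointType("203.0.113.10:6379", false))     // external-ip (public address)
fmt.Println(hitless.DetectEndpointType("redis.example.com:6379", true)) // external-fqdn (hostnames are treated as non-private)
```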
hitless/config_test.go (new file, 427 lines; truncated in this capture)

@@ -0,0 +1,427 @@
package hitless

import (
    "context"
    "net"
    "testing"
    "time"

    "github.com/redis/go-redis/v9/internal/util"
)

func TestConfig(t *testing.T) {
    t.Run("DefaultConfig", func(t *testing.T) {
        config := DefaultConfig()

        // MaxWorkers should be 0 in default config (auto-calculated)
        if config.MaxWorkers != 0 {
            t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers)
        }

        // HandoffQueueSize should be 0 in default config (auto-calculated)
        if config.HandoffQueueSize != 0 {
            t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize)
        }

        if config.RelaxedTimeout != 10*time.Second {
            t.Errorf("Expected RelaxedTimeout to be 10s, got %v", config.RelaxedTimeout)
        }

        // Test configuration fields have proper defaults
        if config.MaxHandoffRetries != 3 {
            t.Errorf("Expected MaxHandoffRetries to be 3, got %d", config.MaxHandoffRetries)
        }

        if config.HandoffTimeout != 15*time.Second {
            t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout)
        }

        if config.PostHandoffRelaxedDuration != 0 {
            t.Errorf("Expected PostHandoffRelaxedDuration to be 0 (auto-calculated), got %v", config.PostHandoffRelaxedDuration)
        }

        // Test that defaults are applied correctly
        configWithDefaults := config.ApplyDefaultsWithPoolSize(100)
        if configWithDefaults.PostHandoffRelaxedDuration != 20*time.Second {
            t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout) after applying defaults, got %v", configWithDefaults.PostHandoffRelaxedDuration)
        }
    })

    t.Run("ConfigValidation", func(t *testing.T) {
        // Valid config with applied defaults
        config := DefaultConfig().ApplyDefaults()
        if err := config.Validate(); err != nil {
            t.Errorf("Default config with applied defaults should be valid: %v", err)
        }

        // Invalid worker configuration (negative MaxWorkers)
        config = &Config{
            RelaxedTimeout:             30 * time.Second,
            HandoffTimeout:             15 * time.Second,
            MaxWorkers:                 -1, // This should be invalid
            HandoffQueueSize:           100,
            PostHandoffRelaxedDuration: 10 * time.Second,
            LogLevel:                   1,
            MaxHandoffRetries:          3, // Add required field
        }
        if err := config.Validate(); err != ErrInvalidHandoffWorkers {
            t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err)
        }

        // Invalid HandoffQueueSize
        config = DefaultConfig().ApplyDefaults()
        config.HandoffQueueSize = -1
        if err := config.Validate(); err != ErrInvalidHandoffQueueSize {
            t.Errorf("Expected ErrInvalidHandoffQueueSize, got %v", err)
        }

        // Invalid PostHandoffRelaxedDuration
        config = DefaultConfig().ApplyDefaults()
        config.PostHandoffRelaxedDuration = -1 * time.Second
        if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration {
            t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err)
        }
    })

    t.Run("ConfigClone", func(t *testing.T) {
        original := DefaultConfig()
        original.MaxWorkers = 20
        original.HandoffQueueSize = 200

        cloned := original.Clone()

        if cloned.MaxWorkers != 20 {
            t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers)
        }

        if cloned.HandoffQueueSize != 200 {
            t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize)
        }

        // Modify original to ensure clone is independent
        original.MaxWorkers = 2
        if cloned.MaxWorkers != 20 {
            t.Error("Clone should be independent of original")
        }
    })
}

func TestApplyDefaults(t *testing.T) {
    t.Run("NilConfig", func(t *testing.T) {
        var config *Config
        result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing

        // With nil config, should get default config with auto-calculated workers
        if result.MaxWorkers <= 0 {
            t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers)
        }

        // HandoffQueueSize should be auto-calculated (10 * MaxWorkers, capped by pool size)
        workerBasedSize := result.MaxWorkers * 10
        poolSize := 100 // Default pool size used in ApplyDefaults
        expectedQueueSize := util.Min(workerBasedSize, poolSize)
        if result.HandoffQueueSize != expectedQueueSize {
            t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
                expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
        }
    })

    t.Run("PartialConfig", func(t *testing.T) {
        config := &Config{
            MaxWorkers: 12, // Set this field explicitly
            // Leave other fields as zero values
        }

        result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing

        // Should keep the explicitly set values
        if result.MaxWorkers != 12 {
            t.Errorf("Expected MaxWorkers to be 12 (explicitly set), got %d", result.MaxWorkers)
        }

        // Should apply default for unset fields (auto-calculated queue size, capped by pool size)
        workerBasedSize := result.MaxWorkers * 10
        poolSize := 100 // Default pool size used in ApplyDefaults
        expectedQueueSize := util.Min(workerBasedSize, poolSize)
        if result.HandoffQueueSize != expectedQueueSize {
            t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
                expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
        }

        // Test explicit queue size capping by pool size
        configWithLargeQueue := &Config{
            MaxWorkers:       5,
            HandoffQueueSize: 1000, // Much larger than pool size
        }

        resultCapped := configWithLargeQueue.ApplyDefaultsWithPoolSize(20) // Small pool size
        if resultCapped.HandoffQueueSize != 20 {
            t.Errorf("Expected HandoffQueueSize to be capped by pool size (20), got %d", resultCapped.HandoffQueueSize)
        }

        if result.RelaxedTimeout != 10*time.Second {
            t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
        }

        if result.HandoffTimeout != 15*time.Second {
            t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", result.HandoffTimeout)
        }
    })

    t.Run("ZeroValues", func(t *testing.T) {
        config := &Config{
            MaxWorkers:       0, // Zero value should get auto-calculated defaults
            HandoffQueueSize: 0, // Zero value should get default
            RelaxedTimeout:   0, // Zero value should get default
            LogLevel:         0, // Zero is valid for LogLevel (errors only)
        }

        result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing

        // Zero values should get auto-calculated defaults
        if result.MaxWorkers <= 0 {
            t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers)
        }

        // HandoffQueueSize should be auto-calculated (10 * MaxWorkers, capped by pool size)
        workerBasedSize := result.MaxWorkers * 10
        poolSize := 100 // Default pool size used in ApplyDefaults
        expectedQueueSize := util.Min(workerBasedSize, poolSize)
        if result.HandoffQueueSize != expectedQueueSize {
            t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
                expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
        }

        if result.RelaxedTimeout != 10*time.Second {
            t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
        }

        // LogLevel 0 should be preserved (it's a valid value)
        if result.LogLevel != 0 {
            t.Errorf("Expected LogLevel to be 0 (preserved), got %d", result.LogLevel)
        }
    })
}

func TestProcessorWithConfig(t *testing.T) {
    t.Run("ProcessorUsesConfigValues", func(t *testing.T) {
        config := &Config{
            MaxWorkers:       5,
            HandoffQueueSize: 50,
            RelaxedTimeout:   10 * time.Second,
            HandoffTimeout:   5 * time.Second,
        }

        baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
            return &mockNetConn{addr: addr}, nil
        }

        processor := NewPoolHook(baseDialer, "tcp", config, nil)
        defer processor.Shutdown(context.Background())

        // The processor should be created successfully with custom config
        if processor == nil {
            t.Error("Processor should be created with custom config")
        }
    })

    t.Run("ProcessorWithPartialConfig", func(t *testing.T) {
        config := &Config{
            MaxWorkers: 7, // Only set worker field
            // Other fields will get defaults
        }
||||||
|
|
||||||
|
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return &mockNetConn{addr: addr}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
// Should work with partial config (defaults applied)
|
||||||
|
if processor == nil {
|
||||||
|
t.Error("Processor should be created with partial config")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ProcessorWithNilConfig", func(t *testing.T) {
|
||||||
|
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return &mockNetConn{addr: addr}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
// Should use default config when nil is passed
|
||||||
|
if processor == nil {
|
||||||
|
t.Error("Processor should be created with nil config (using defaults)")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationWithApplyDefaults(t *testing.T) {
|
||||||
|
t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) {
|
||||||
|
// Create a partial config with only some fields set
|
||||||
|
partialConfig := &Config{
|
||||||
|
MaxWorkers: 15, // Custom value (>= 10 to test preservation)
|
||||||
|
LogLevel: 2, // Custom value
|
||||||
|
// Other fields left as zero values - should get defaults
|
||||||
|
}
|
||||||
|
|
||||||
|
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return &mockNetConn{addr: addr}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create processor - should apply defaults to missing fields
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", partialConfig, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
// Processor should be created successfully
|
||||||
|
if processor == nil {
|
||||||
|
t.Error("Processor should be created with partial config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that the ApplyDefaults method worked correctly by creating the same config
|
||||||
|
// and applying defaults manually
|
||||||
|
expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
|
||||||
|
|
||||||
|
// Should preserve custom values (when >= 10)
|
||||||
|
if expectedConfig.MaxWorkers != 15 {
|
||||||
|
t.Errorf("Expected MaxWorkers to be 15, got %d", expectedConfig.MaxWorkers)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectedConfig.LogLevel != 2 {
|
||||||
|
t.Errorf("Expected LogLevel to be 2, got %d", expectedConfig.LogLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should apply defaults for missing fields (auto-calculated queue size, capped by pool size)
|
||||||
|
workerBasedSize := expectedConfig.MaxWorkers * 10
|
||||||
|
poolSize := 100 // Default pool size used in ApplyDefaults
|
||||||
|
expectedQueueSize := util.Min(workerBasedSize, poolSize)
|
||||||
|
if expectedConfig.HandoffQueueSize != expectedQueueSize {
|
||||||
|
t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
|
||||||
|
expectedQueueSize, workerBasedSize, poolSize, expectedConfig.HandoffQueueSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that queue size is always capped by pool size
|
||||||
|
if expectedConfig.HandoffQueueSize > poolSize {
|
||||||
|
t.Errorf("HandoffQueueSize (%d) should never exceed pool size (%d)",
|
||||||
|
expectedConfig.HandoffQueueSize, poolSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectedConfig.RelaxedTimeout != 10*time.Second {
|
||||||
|
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", expectedConfig.RelaxedTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectedConfig.HandoffTimeout != 15*time.Second {
|
||||||
|
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectedConfig.PostHandoffRelaxedDuration != 20*time.Second {
|
||||||
|
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout), got %v", expectedConfig.PostHandoffRelaxedDuration)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnhancedConfigValidation(t *testing.T) {
|
||||||
|
t.Run("ValidateFields", func(t *testing.T) {
|
||||||
|
config := DefaultConfig()
|
||||||
|
config.ApplyDefaultsWithPoolSize(100) // Apply defaults with pool size 100
|
||||||
|
|
||||||
|
// Should pass validation with default values
|
||||||
|
if err := config.Validate(); err != nil {
|
||||||
|
t.Errorf("Default config should be valid, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid MaxHandoffRetries
|
||||||
|
config.MaxHandoffRetries = 0
|
||||||
|
if err := config.Validate(); err == nil {
|
||||||
|
t.Error("Expected validation error for MaxHandoffRetries = 0")
|
||||||
|
}
|
||||||
|
config.MaxHandoffRetries = 11
|
||||||
|
if err := config.Validate(); err == nil {
|
||||||
|
t.Error("Expected validation error for MaxHandoffRetries = 11")
|
||||||
|
}
|
||||||
|
config.MaxHandoffRetries = 3 // Reset to valid value
|
||||||
|
|
||||||
|
// Should pass validation again
|
||||||
|
if err := config.Validate(); err != nil {
|
||||||
|
t.Errorf("Config should be valid after reset, got error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigClone(t *testing.T) {
|
||||||
|
original := DefaultConfig()
|
||||||
|
original.MaxHandoffRetries = 7
|
||||||
|
original.HandoffTimeout = 8 * time.Second
|
||||||
|
|
||||||
|
cloned := original.Clone()
|
||||||
|
|
||||||
|
// Test that values are copied
|
||||||
|
if cloned.MaxHandoffRetries != 7 {
|
||||||
|
t.Errorf("Expected cloned MaxHandoffRetries to be 7, got %d", cloned.MaxHandoffRetries)
|
||||||
|
}
|
||||||
|
if cloned.HandoffTimeout != 8*time.Second {
|
||||||
|
t.Errorf("Expected cloned HandoffTimeout to be 8s, got %v", cloned.HandoffTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that modifying clone doesn't affect original
|
||||||
|
cloned.MaxHandoffRetries = 10
|
||||||
|
if original.MaxHandoffRetries != 7 {
|
||||||
|
t.Errorf("Modifying clone should not affect original, original MaxHandoffRetries changed to %d", original.MaxHandoffRetries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxWorkersLogic(t *testing.T) {
|
||||||
|
t.Run("AutoCalculatedMaxWorkers", func(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
poolSize int
|
||||||
|
expectedWorkers int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{6, 2, "Small pool: min(10, 6/3) = min(10, 2) = 2"},
|
||||||
|
{15, 5, "Medium pool: min(10, 15/3) = min(10, 5) = 5"},
|
||||||
|
{30, 10, "Large pool: min(10, 30/3) = min(10, 10) = 10"},
|
||||||
|
{60, 10, "Very large pool: min(10, 60/3) = min(10, 20) = 10"},
|
||||||
|
{120, 10, "Huge pool: min(10, 120/3) = min(10, 40) = 10"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
config := &Config{} // MaxWorkers = 0 (not set)
|
||||||
|
result := config.ApplyDefaultsWithPoolSize(tc.poolSize)
|
||||||
|
|
||||||
|
if result.MaxWorkers != tc.expectedWorkers {
|
||||||
|
t.Errorf("PoolSize=%d: expected MaxWorkers=%d, got %d (%s)",
|
||||||
|
tc.poolSize, tc.expectedWorkers, result.MaxWorkers, tc.description)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ExplicitlySetMaxWorkers", func(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
setValue int
|
||||||
|
expectedWorkers int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{1, 10, "Set 1: max(10, 1) = 10 (enforced minimum)"},
|
||||||
|
{5, 10, "Set 5: max(10, 5) = 10 (enforced minimum)"},
|
||||||
|
{8, 10, "Set 8: max(10, 8) = 10 (enforced minimum)"},
|
||||||
|
{10, 10, "Set 10: max(10, 10) = 10 (exact minimum)"},
|
||||||
|
{15, 15, "Set 15: max(10, 15) = 15 (respects user choice)"},
|
||||||
|
{20, 20, "Set 20: max(10, 20) = 20 (respects user choice)"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
config := &Config{
|
||||||
|
MaxWorkers: tc.setValue, // Explicitly set
|
||||||
|
}
|
||||||
|
result := config.ApplyDefaultsWithPoolSize(100) // Pool size doesn't affect explicit values
|
||||||
|
|
||||||
|
if result.MaxWorkers != tc.expectedWorkers {
|
||||||
|
t.Errorf("Set MaxWorkers=%d: expected %d, got %d (%s)",
|
||||||
|
tc.setValue, tc.expectedWorkers, result.MaxWorkers, tc.description)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
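
The auto-calculation cases above encode a simple sizing rule: an unset MaxWorkers becomes min(10, poolSize/3), while an explicitly set value is raised to at least 10. A minimal runnable sketch of that rule follows; the helper name is hypothetical and the actual logic is assumed to live in ApplyDefaultsWithPoolSize.

package main

import "fmt"

// workerCountFor mirrors the sizing rule the test cases above imply
// (hypothetical helper; not part of this diff).
func workerCountFor(configured, poolSize int) int {
	if configured > 0 {
		if configured < 10 {
			return 10 // explicit values are raised to a minimum of 10 workers
		}
		return configured // larger explicit values are respected
	}
	workers := poolSize / 3 // unset: derive from pool size
	if workers > 10 {
		workers = 10 // capped at 10 workers
	}
	return workers
}

func main() {
	fmt.Println(workerCountFor(0, 6), workerCountFor(0, 120), workerCountFor(1, 100), workerCountFor(15, 100)) // 2 10 10 15
}
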
76
hitless/errors.go
Normal file
@@ -0,0 +1,76 @@
package hitless

import (
	"errors"
	"fmt"
)

// Configuration errors
var (
	ErrInvalidRelaxedTimeout             = errors.New("hitless: relaxed timeout must be greater than 0")
	ErrInvalidHandoffTimeout             = errors.New("hitless: handoff timeout must be greater than 0")
	ErrInvalidHandoffWorkers             = errors.New("hitless: MaxWorkers must be greater than or equal to 0")
	ErrInvalidHandoffQueueSize           = errors.New("hitless: handoff queue size must be greater than 0")
	ErrInvalidPostHandoffRelaxedDuration = errors.New("hitless: post-handoff relaxed duration must be greater than or equal to 0")
	ErrInvalidLogLevel                   = errors.New("hitless: log level must be between 0 and 3")
	ErrInvalidEndpointType               = errors.New("hitless: invalid endpoint type")
	ErrInvalidMaintNotifications         = errors.New("hitless: invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')")
	ErrMaxHandoffRetriesReached          = errors.New("hitless: max handoff retries reached")

	// Configuration validation errors
	ErrInvalidHandoffRetries                = errors.New("hitless: MaxHandoffRetries must be between 1 and 10")
	ErrInvalidConnectionValidationTimeout   = errors.New("hitless: ConnectionValidationTimeout must be greater than 0 and less than 30 seconds")
	ErrInvalidConnectionHealthCheckInterval = errors.New("hitless: ConnectionHealthCheckInterval must be between 0 and 1 hour")
	ErrInvalidOperationCleanupInterval      = errors.New("hitless: OperationCleanupInterval must be greater than 0 and less than 1 hour")
	ErrInvalidMaxActiveOperations           = errors.New("hitless: MaxActiveOperations must be between 100 and 100000")
	ErrInvalidNotificationBufferSize        = errors.New("hitless: NotificationBufferSize must be between 10 and 10000")
	ErrInvalidNotificationTimeout           = errors.New("hitless: NotificationTimeout must be greater than 0 and less than 30 seconds")
)

// Integration errors
var (
	ErrInvalidClient = errors.New("hitless: invalid client type")
)

// Handoff errors
var (
	ErrHandoffInProgress   = errors.New("hitless: handoff already in progress")
	ErrNoHandoffInProgress = errors.New("hitless: no handoff in progress")
	ErrConnectionFailed    = errors.New("hitless: failed to establish new connection")
)

// Dead error variables removed - unused in simplified architecture

// Notification errors
var (
	ErrInvalidNotification = errors.New("hitless: invalid notification format")
)

// Dead error variables removed - unused in simplified architecture

// HandoffError represents an error that occurred during connection handoff.
type HandoffError struct {
	Operation string
	Endpoint  string
	Cause     error
}

func (e *HandoffError) Error() string {
	if e.Cause != nil {
		return fmt.Sprintf("hitless: handoff %s failed for endpoint %s: %v", e.Operation, e.Endpoint, e.Cause)
	}
	return fmt.Sprintf("hitless: handoff %s failed for endpoint %s", e.Operation, e.Endpoint)
}

func (e *HandoffError) Unwrap() error {
	return e.Cause
}

// NewHandoffError creates a new HandoffError.
func NewHandoffError(operation, endpoint string, cause error) *HandoffError {
	return &HandoffError{
		Operation: operation,
		Endpoint:  endpoint,
		Cause:     cause,
	}
}
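
Because HandoffError implements Unwrap, wrapped sentinel errors remain matchable with errors.Is and errors.As. A minimal usage sketch, assuming the package is importable at its in-repo path github.com/redis/go-redis/v9/hitless; the endpoint value is illustrative only.

package main

import (
	"errors"
	"fmt"

	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	// Wrap a sentinel error and recover it through the Unwrap chain.
	err := hitless.NewHandoffError("dial", "10.0.0.5:6379", hitless.ErrConnectionFailed)
	fmt.Println(err)                                          // hitless: handoff dial failed for endpoint 10.0.0.5:6379: ...
	fmt.Println(errors.Is(err, hitless.ErrConnectionFailed))  // true, via Unwrap

	var hoErr *hitless.HandoffError
	if errors.As(err, &hoErr) {
		fmt.Println(hoErr.Operation, hoErr.Endpoint) // dial 10.0.0.5:6379
	}
}
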
63
hitless/example_hooks.go
Normal file
@@ -0,0 +1,63 @@
package hitless

import (
	"context"
	"time"
)

// contextKey is a custom type for context keys to avoid collisions
type contextKey string

const (
	startTimeKey contextKey = "notif_hitless_start_time"
)

// MetricsHook collects metrics about notification processing.
type MetricsHook struct {
	NotificationCounts map[string]int64
	ProcessingTimes    map[string]time.Duration
	ErrorCounts        map[string]int64
}

// NewMetricsHook creates a new metrics collection hook.
func NewMetricsHook() *MetricsHook {
	return &MetricsHook{
		NotificationCounts: make(map[string]int64),
		ProcessingTimes:    make(map[string]time.Duration),
		ErrorCounts:        make(map[string]int64),
	}
}

// PreHook records the start time for processing metrics.
func (mh *MetricsHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
	mh.NotificationCounts[notificationType]++

	// Store start time in context for duration calculation
	startTime := time.Now()
	_ = context.WithValue(ctx, startTimeKey, startTime) // Context not used further

	return notification, true
}

// PostHook records processing completion and any errors.
func (mh *MetricsHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
	// Calculate processing duration
	if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
		duration := time.Since(startTime)
		mh.ProcessingTimes[notificationType] = duration
	}

	// Record errors
	if result != nil {
		mh.ErrorCounts[notificationType]++
	}
}

// GetMetrics returns a summary of collected metrics.
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
	return map[string]interface{}{
		"notification_counts": mh.NotificationCounts,
		"processing_times":    mh.ProcessingTimes,
		"error_counts":        mh.ErrorCounts,
	}
}
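
For context, this is the calling pattern the manager uses for a hook such as MetricsHook: PreHook runs before the type-specific handler and PostHook receives the handler's result. A minimal standalone sketch, with the import path assumed; the registration API is not part of this excerpt, so the hook is invoked directly and the payload is illustrative.

package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	hook := hitless.NewMetricsHook()
	ctx := context.Background()

	// Drive the hook the way the manager would: PreHook before handling,
	// PostHook with the handler's result afterwards.
	notification := []interface{}{"MIGRATING", "1"}
	if modified, ok := hook.PreHook(ctx, "MIGRATING", notification); ok {
		var handleErr error // the real handler's result would go here
		hook.PostHook(ctx, "MIGRATING", modified, handleErr)
	}

	fmt.Println(hook.GetMetrics()["notification_counts"]) // map[MIGRATING:1]
}
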
299
hitless/hitless_manager.go
Normal file
@@ -0,0 +1,299 @@
package hitless

import (
	"context"
	"fmt"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/interfaces"
	"github.com/redis/go-redis/v9/internal/pool"
)

// Push notification type constants for hitless upgrades
const (
	NotificationMoving      = "MOVING"
	NotificationMigrating   = "MIGRATING"
	NotificationMigrated    = "MIGRATED"
	NotificationFailingOver = "FAILING_OVER"
	NotificationFailedOver  = "FAILED_OVER"
)

// hitlessNotificationTypes contains all notification types that hitless upgrades handles
var hitlessNotificationTypes = []string{
	NotificationMoving,
	NotificationMigrating,
	NotificationMigrated,
	NotificationFailingOver,
	NotificationFailedOver,
}

// NotificationHook is called before and after notification processing.
// PreHook can modify the notification and return false to skip processing.
// PostHook is called after processing completes.
type NotificationHook interface {
	PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool)
	PostHook(ctx context.Context, notificationType string, notification []interface{}, result error)
}

// MovingOperationKey provides a unique key for tracking MOVING operations
// that combines sequence ID with connection identifier to handle duplicate
// sequence IDs across multiple connections to the same node.
type MovingOperationKey struct {
	SeqID  int64  // Sequence ID from MOVING notification
	ConnID uint64 // Unique connection identifier
}

// String returns a string representation of the key for debugging
func (k MovingOperationKey) String() string {
	return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
}

// HitlessManager provides simplified hitless upgrade functionality with hooks and atomic state.
type HitlessManager struct {
	client  interfaces.ClientInterface
	config  *Config
	options interfaces.OptionsInterface
	pool    pool.Pooler

	// MOVING operation tracking - using sync.Map for better concurrent performance
	activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation

	// Atomic state tracking - no locks needed for state queries
	activeOperationCount atomic.Int64 // Number of active operations
	closed               atomic.Bool  // Manager closed state

	// Notification hooks for extensibility
	hooks        []NotificationHook
	hooksMu      sync.RWMutex // Protects hooks slice
	poolHooksRef *PoolHook
}

// MovingOperation tracks an active MOVING operation.
type MovingOperation struct {
	SeqID       int64
	NewEndpoint string
	StartTime   time.Time
	Deadline    time.Time
}

// NewHitlessManager creates a new simplified hitless manager.
func NewHitlessManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*HitlessManager, error) {
	if client == nil {
		return nil, ErrInvalidClient
	}

	hm := &HitlessManager{
		client:  client,
		pool:    pool,
		options: client.GetOptions(),
		config:  config.Clone(),
		hooks:   make([]NotificationHook, 0),
	}

	// Set up push notification handling
	if err := hm.setupPushNotifications(); err != nil {
		return nil, err
	}

	return hm, nil
}

// InitPoolHook creates a pool hook with a custom dialer and registers it with the pool.
func (hm *HitlessManager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
	poolHook := hm.createPoolHook(baseDialer)
	hm.pool.AddPoolHook(poolHook)
}

// setupPushNotifications sets up push notification handling by registering with the client's processor.
func (hm *HitlessManager) setupPushNotifications() error {
	processor := hm.client.GetPushProcessor()
	if processor == nil {
		return ErrInvalidClient // Client doesn't support push notifications
	}

	// Create our notification handler
	handler := &NotificationHandler{manager: hm}

	// Register handlers for all hitless upgrade notifications with the client's processor
	for _, notificationType := range hitlessNotificationTypes {
		if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
			return fmt.Errorf("failed to register handler for %s: %w", notificationType, err)
		}
	}

	return nil
}

// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
	// Create composite key
	key := MovingOperationKey{
		SeqID:  seqID,
		ConnID: connID,
	}

	// Create MOVING operation record
	movingOp := &MovingOperation{
		SeqID:       seqID,
		NewEndpoint: newEndpoint,
		StartTime:   time.Now(),
		Deadline:    deadline,
	}

	// Use LoadOrStore for atomic check-and-set operation
	if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
		// Duplicate MOVING notification, ignore
		internal.Logger.Printf(ctx, "Duplicate MOVING operation ignored: %s", key.String())
		return nil
	}

	// Increment active operation count atomically
	hm.activeOperationCount.Add(1)

	return nil
}

// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64) {
	// Create composite key
	key := MovingOperationKey{
		SeqID:  seqID,
		ConnID: connID,
	}

	// Remove from active operations atomically
	if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
		// Decrement active operation count only if operation existed
		hm.activeOperationCount.Add(-1)
	}
}

// GetActiveMovingOperations returns active operations with composite keys.
func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
	result := make(map[MovingOperationKey]*MovingOperation)

	// Iterate over sync.Map to build result
	hm.activeMovingOps.Range(func(key, value interface{}) bool {
		k := key.(MovingOperationKey)
		op := value.(*MovingOperation)

		// Create a copy to avoid sharing references
		result[k] = &MovingOperation{
			SeqID:       op.SeqID,
			NewEndpoint: op.NewEndpoint,
			StartTime:   op.StartTime,
			Deadline:    op.Deadline,
		}
		return true // Continue iteration
	})

	return result
}

// IsHandoffInProgress returns true if any handoff is in progress.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) IsHandoffInProgress() bool {
	return hm.activeOperationCount.Load() > 0
}

// GetActiveOperationCount returns the number of active operations.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) GetActiveOperationCount() int64 {
	return hm.activeOperationCount.Load()
}

// Close closes the hitless manager.
func (hm *HitlessManager) Close() error {
	// Use atomic operation for thread-safe close check
	if !hm.closed.CompareAndSwap(false, true) {
		return nil // Already closed
	}

	// Shutdown the pool hook if it exists
	if hm.poolHooksRef != nil {
		// Use a timeout to prevent hanging indefinitely
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		err := hm.poolHooksRef.Shutdown(shutdownCtx)
		if err != nil {
			// was not able to close pool hook, keep closed state false
			hm.closed.Store(false)
			return err
		}
		// Remove the pool hook from the pool
		if hm.pool != nil {
			hm.pool.RemovePoolHook(hm.poolHooksRef)
		}
	}

	// Clear all active operations
	hm.activeMovingOps.Range(func(key, value interface{}) bool {
		hm.activeMovingOps.Delete(key)
		return true
	})

	// Reset counter
	hm.activeOperationCount.Store(0)

	return nil
}

// GetState returns current state using atomic counter for lock-free operation.
func (hm *HitlessManager) GetState() State {
	if hm.activeOperationCount.Load() > 0 {
		return StateMoving
	}
	return StateIdle
}

// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
	hm.hooksMu.RLock()
	defer hm.hooksMu.RUnlock()

	currentNotification := notification

	for _, hook := range hm.hooks {
		modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationType, currentNotification)
		if !shouldContinue {
			return modifiedNotification, false
		}
		currentNotification = modifiedNotification
	}

	return currentNotification, true
}

// processPostHooks calls all post-hooks with the processing result.
func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationType string, notification []interface{}, result error) {
	hm.hooksMu.RLock()
	defer hm.hooksMu.RUnlock()

	for _, hook := range hm.hooks {
		hook.PostHook(ctx, notificationType, notification, result)
	}
}

// createPoolHook creates a pool hook with this manager already set.
func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
	if hm.poolHooksRef != nil {
		return hm.poolHooksRef
	}
	// Get pool size from client options for better worker defaults
	poolSize := 0
	if hm.options != nil {
		poolSize = hm.options.GetPoolSize()
	}

	hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
	hm.poolHooksRef.SetPool(hm.pool)

	return hm.poolHooksRef
}
260
hitless/hitless_manager_test.go
Normal file
@@ -0,0 +1,260 @@
package hitless

import (
	"context"
	"net"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/internal/interfaces"
)

// MockClient implements interfaces.ClientInterface for testing
type MockClient struct {
	options interfaces.OptionsInterface
}

func (mc *MockClient) GetOptions() interfaces.OptionsInterface {
	return mc.options
}

func (mc *MockClient) GetPushProcessor() interfaces.NotificationProcessor {
	return &MockPushProcessor{}
}

// MockPushProcessor implements interfaces.NotificationProcessor for testing
type MockPushProcessor struct{}

func (mpp *MockPushProcessor) RegisterHandler(notificationType string, handler interface{}, protected bool) error {
	return nil
}

func (mpp *MockPushProcessor) UnregisterHandler(pushNotificationName string) error {
	return nil
}

func (mpp *MockPushProcessor) GetHandler(pushNotificationName string) interface{} {
	return nil
}

// MockOptions implements interfaces.OptionsInterface for testing
type MockOptions struct{}

func (mo *MockOptions) GetReadTimeout() time.Duration {
	return 5 * time.Second
}

func (mo *MockOptions) GetWriteTimeout() time.Duration {
	return 5 * time.Second
}

func (mo *MockOptions) GetAddr() string {
	return "localhost:6379"
}

func (mo *MockOptions) IsTLSEnabled() bool {
	return false
}

func (mo *MockOptions) GetProtocol() int {
	return 3 // RESP3
}

func (mo *MockOptions) GetPoolSize() int {
	return 10
}

func (mo *MockOptions) GetNetwork() string {
	return "tcp"
}

func (mo *MockOptions) NewDialer() func(context.Context) (net.Conn, error) {
	return func(ctx context.Context) (net.Conn, error) {
		return nil, nil
	}
}

func TestHitlessManagerRefactoring(t *testing.T) {
	t.Run("AtomicStateTracking", func(t *testing.T) {
		config := DefaultConfig()
		client := &MockClient{options: &MockOptions{}}

		manager, err := NewHitlessManager(client, nil, config)
		if err != nil {
			t.Fatalf("Failed to create hitless manager: %v", err)
		}
		defer manager.Close()

		// Test initial state
		if manager.IsHandoffInProgress() {
			t.Error("Expected no handoff in progress initially")
		}

		if manager.GetActiveOperationCount() != 0 {
			t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
		}

		if manager.GetState() != StateIdle {
			t.Errorf("Expected StateIdle, got %v", manager.GetState())
		}

		// Add an operation
		ctx := context.Background()
		deadline := time.Now().Add(30 * time.Second)
		err = manager.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 12345, 1)
		if err != nil {
			t.Fatalf("Failed to track operation: %v", err)
		}

		// Test state after adding operation
		if !manager.IsHandoffInProgress() {
			t.Error("Expected handoff in progress after adding operation")
		}

		if manager.GetActiveOperationCount() != 1 {
			t.Errorf("Expected 1 active operation, got %d", manager.GetActiveOperationCount())
		}

		if manager.GetState() != StateMoving {
			t.Errorf("Expected StateMoving, got %v", manager.GetState())
		}

		// Remove the operation
		manager.UntrackOperationWithConnID(12345, 1)

		// Test state after removing operation
		if manager.IsHandoffInProgress() {
			t.Error("Expected no handoff in progress after removing operation")
		}

		if manager.GetActiveOperationCount() != 0 {
			t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
		}

		if manager.GetState() != StateIdle {
			t.Errorf("Expected StateIdle, got %v", manager.GetState())
		}
	})

	t.Run("SyncMapPerformance", func(t *testing.T) {
		config := DefaultConfig()
		client := &MockClient{options: &MockOptions{}}

		manager, err := NewHitlessManager(client, nil, config)
		if err != nil {
			t.Fatalf("Failed to create hitless manager: %v", err)
		}
		defer manager.Close()

		ctx := context.Background()
		deadline := time.Now().Add(30 * time.Second)

		// Test concurrent operations
		const numOps = 100
		for i := 0; i < numOps; i++ {
			err := manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, int64(i), uint64(i))
			if err != nil {
				t.Fatalf("Failed to track operation %d: %v", i, err)
			}
		}

		if manager.GetActiveOperationCount() != numOps {
			t.Errorf("Expected %d active operations, got %d", numOps, manager.GetActiveOperationCount())
		}

		// Test GetActiveMovingOperations
		operations := manager.GetActiveMovingOperations()
		if len(operations) != numOps {
			t.Errorf("Expected %d operations in map, got %d", numOps, len(operations))
		}

		// Remove all operations
		for i := 0; i < numOps; i++ {
			manager.UntrackOperationWithConnID(int64(i), uint64(i))
		}

		if manager.GetActiveOperationCount() != 0 {
			t.Errorf("Expected 0 active operations after cleanup, got %d", manager.GetActiveOperationCount())
		}
	})

	t.Run("DuplicateOperationHandling", func(t *testing.T) {
		config := DefaultConfig()
		client := &MockClient{options: &MockOptions{}}

		manager, err := NewHitlessManager(client, nil, config)
		if err != nil {
			t.Fatalf("Failed to create hitless manager: %v", err)
		}
		defer manager.Close()

		ctx := context.Background()
		deadline := time.Now().Add(30 * time.Second)

		// Add operation
		err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
		if err != nil {
			t.Fatalf("Failed to track operation: %v", err)
		}

		// Try to add duplicate operation
		err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
		if err != nil {
			t.Fatalf("Duplicate operation should not return error: %v", err)
		}

		// Should still have only 1 operation
		if manager.GetActiveOperationCount() != 1 {
			t.Errorf("Expected 1 active operation after duplicate, got %d", manager.GetActiveOperationCount())
		}
	})

	t.Run("NotificationTypeConstants", func(t *testing.T) {
		// Test that constants are properly defined
		expectedTypes := []string{
			NotificationMoving,
			NotificationMigrating,
			NotificationMigrated,
			NotificationFailingOver,
			NotificationFailedOver,
		}

		if len(hitlessNotificationTypes) != len(expectedTypes) {
			t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(hitlessNotificationTypes))
		}

		// Test that all expected types are present
		typeMap := make(map[string]bool)
		for _, nt := range hitlessNotificationTypes {
			typeMap[nt] = true
		}

		for _, expected := range expectedTypes {
			if !typeMap[expected] {
				t.Errorf("Expected notification type %s not found in hitlessNotificationTypes", expected)
			}
		}

		// Test that hitlessNotificationTypes contains all expected constants
		expectedConstants := []string{
			NotificationMoving,
			NotificationMigrating,
			NotificationMigrated,
			NotificationFailingOver,
			NotificationFailedOver,
		}

		for _, expected := range expectedConstants {
			found := false
			for _, actual := range hitlessNotificationTypes {
				if actual == expected {
					found = true
					break
				}
			}
			if !found {
				t.Errorf("Expected constant %s not found in hitlessNotificationTypes", expected)
			}
		}
	})
}
48
hitless/hooks.go
Normal file
@@ -0,0 +1,48 @@
package hitless

import (
	"context"

	"github.com/redis/go-redis/v9/internal"
)

// LoggingHook is an example hook implementation that logs all notifications.
type LoggingHook struct {
	LogLevel int
}

// PreHook logs the notification before processing and allows modification.
func (lh *LoggingHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
	if lh.LogLevel >= 2 { // Info level
		internal.Logger.Printf(ctx, "hitless: processing %s notification: %v", notificationType, notification)
	}
	return notification, true // Continue processing with unmodified notification
}

// PostHook logs the result after processing.
func (lh *LoggingHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
	if result != nil && lh.LogLevel >= 1 { // Warning level
		internal.Logger.Printf(ctx, "hitless: %s notification processing failed: %v", notificationType, result)
	} else if lh.LogLevel >= 3 { // Debug level
		internal.Logger.Printf(ctx, "hitless: %s notification processed successfully", notificationType)
	}
}

// FilterHook is an example hook that can filter out certain notifications.
type FilterHook struct {
	BlockedTypes map[string]bool
}

// PreHook filters notifications based on type.
func (fh *FilterHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
	if fh.BlockedTypes[notificationType] {
		internal.Logger.Printf(ctx, "hitless: filtering out %s notification", notificationType)
		return notification, false // Skip processing
	}
	return notification, true
}

// PostHook does nothing for filter hook.
func (fh *FilterHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
	// No post-processing needed for filter hook
}
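
A brief sketch of FilterHook in isolation: the second return value of PreHook is what tells the manager whether to continue processing. The import path is assumed and the payloads are illustrative.

package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	// Drop MIGRATING notifications while letting everything else through.
	filter := &hitless.FilterHook{BlockedTypes: map[string]bool{"MIGRATING": true}}

	_, ok := filter.PreHook(context.Background(), "MIGRATING", []interface{}{"MIGRATING", "1"})
	fmt.Println(ok) // false: processing is skipped for blocked types

	_, ok = filter.PreHook(context.Background(), "MOVING", []interface{}{"MOVING", "1", "30"})
	fmt.Println(ok) // true: processing continues
}
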
247
hitless/notification_handler.go
Normal file
@@ -0,0 +1,247 @@
package hitless

import (
	"context"
	"fmt"
	"strconv"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/interfaces"
	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/push"
)

// NotificationHandler handles push notifications for the simplified manager.
type NotificationHandler struct {
	manager *HitlessManager
}

// HandlePushNotification processes push notifications with hook support.
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	if len(notification) == 0 {
		return ErrInvalidNotification
	}

	notificationType, ok := notification[0].(string)
	if !ok {
		return ErrInvalidNotification
	}

	// Process pre-hooks - they can modify the notification or skip processing
	modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, notificationType, notification)
	if !shouldContinue {
		return nil // Hooks decided to skip processing
	}

	var err error
	switch notificationType {
	case NotificationMoving:
		err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
	case NotificationMigrating:
		err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
	case NotificationMigrated:
		err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
	case NotificationFailingOver:
		err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
	case NotificationFailedOver:
		err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
	default:
		// Ignore other notification types (e.g., pub/sub messages)
		err = nil
	}

	// Process post-hooks with the result
	snh.manager.processPostHooks(ctx, notificationType, modifiedNotification, err)

	return err
}

// handleMoving processes MOVING notifications.
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	if len(notification) < 3 {
		return ErrInvalidNotification
	}
	seqIDStr, ok := notification[1].(string)
	if !ok {
		return ErrInvalidNotification
	}

	seqID, err := strconv.ParseInt(seqIDStr, 10, 64)
	if err != nil {
		return ErrInvalidNotification
	}

	// Extract timeS
	timeSStr, ok := notification[2].(string)
	if !ok {
		return ErrInvalidNotification
	}

	timeS, err := strconv.ParseInt(timeSStr, 10, 64)
	if err != nil {
		return ErrInvalidNotification
	}

	newEndpoint := ""
	if len(notification) > 3 {
		// Extract new endpoint
		newEndpoint, ok = notification[3].(string)
		if !ok {
			return ErrInvalidNotification
		}
	}

	// Get the connection that received this notification
	conn := handlerCtx.Conn
	if conn == nil {
		return ErrInvalidNotification
	}

	// Type assert to get the underlying pool connection
	var poolConn *pool.Conn
	if connAdapter, ok := conn.(interface{ GetPoolConn() *pool.Conn }); ok {
		poolConn = connAdapter.GetPoolConn()
	} else if pc, ok := conn.(*pool.Conn); ok {
		poolConn = pc
	} else {
		return ErrInvalidNotification
	}

	deadline := time.Now().Add(time.Duration(timeS) * time.Second)
	// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
	if newEndpoint == "" || newEndpoint == internal.RedisNull {
		// same as current endpoint
		newEndpoint = snh.manager.options.GetAddr()
		// delay the handoff for timeS/2 seconds to the same endpoint
		// do this in a goroutine to avoid blocking the notification handler
		go func() {
			time.Sleep(time.Duration(timeS/2) * time.Second)
			if poolConn == nil || poolConn.IsClosed() {
				return
			}
			if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
				// Log error but don't fail the goroutine
				internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
			}
		}()
		return nil
	}

	return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
}

func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
	if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
		// Connection is already marked for handoff, which is acceptable
		// This can happen if multiple MOVING notifications are received for the same connection
		return nil
	}
	// Optionally track in hitless manager for monitoring/debugging
	if snh.manager != nil {
		connID := conn.GetID()

		// Track the operation (ignore errors since this is optional)
		_ = snh.manager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
	} else {
		return fmt.Errorf("hitless: manager not initialized")
	}
	return nil
}

// handleMigrating processes MIGRATING notifications.
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	// MIGRATING notifications indicate that a connection is about to be migrated
	// Apply relaxed timeouts to the specific connection that received this notification
	if len(notification) < 2 {
		return ErrInvalidNotification
	}

	// Get the connection from handler context and type assert to connectionAdapter
	if handlerCtx.Conn == nil {
		return ErrInvalidNotification
	}

	// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
	connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
	if !ok {
		return ErrInvalidNotification
	}

	// Apply relaxed timeout to this specific connection
	connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
	return nil
}

// handleMigrated processes MIGRATED notifications.
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	// MIGRATED notifications indicate that a connection migration has completed
	// Restore normal timeouts for the specific connection that received this notification
	if len(notification) < 2 {
		return ErrInvalidNotification
	}

	// Get the connection from handler context and type assert to connectionAdapter
	if handlerCtx.Conn == nil {
		return ErrInvalidNotification
	}

	// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
	connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
	if !ok {
		return ErrInvalidNotification
	}

	// Clear relaxed timeout for this specific connection
	connAdapter.ClearRelaxedTimeout()
	return nil
}

// handleFailingOver processes FAILING_OVER notifications.
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	// FAILING_OVER notifications indicate that a connection is about to failover
	// Apply relaxed timeouts to the specific connection that received this notification
	if len(notification) < 2 {
		return ErrInvalidNotification
	}

	// Get the connection from handler context and type assert to connectionAdapter
	if handlerCtx.Conn == nil {
		return ErrInvalidNotification
	}

	// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
	connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
	if !ok {
		return ErrInvalidNotification
	}

	// Apply relaxed timeout to this specific connection
	connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
	return nil
}

// handleFailedOver processes FAILED_OVER notifications.
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	// FAILED_OVER notifications indicate that a connection failover has completed
	// Restore normal timeouts for the specific connection that received this notification
	if len(notification) < 2 {
		return ErrInvalidNotification
	}

	// Get the connection from handler context and type assert to connectionAdapter
	if handlerCtx.Conn == nil {
		return ErrInvalidNotification
	}

	// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
	connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
	if !ok {
		return ErrInvalidNotification
	}

	// Clear relaxed timeout for this specific connection
	connAdapter.ClearRelaxedTimeout()
	return nil
}
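
To make the wire format above concrete, here is a worked example of the payload handleMoving expects; the values are illustrative only.

package main

import (
	"fmt"
	"strconv"
	"time"
)

func main() {
	// Example MOVING payload as the handler above expects it:
	// ["MOVING", seqID, timeS, newEndpoint], with seqID and timeS sent as strings.
	notification := []interface{}{"MOVING", "42", "30", "10.0.0.7:6379"}

	seqID, _ := strconv.ParseInt(notification[1].(string), 10, 64)
	timeS, _ := strconv.ParseInt(notification[2].(string), 10, 64)
	deadline := time.Now().Add(time.Duration(timeS) * time.Second)

	fmt.Println(seqID, notification[3], deadline.Format(time.RFC3339))
	// An empty or null endpoint means: hand off to the current address after timeS/2 seconds.
}
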
477
hitless/pool_hook.go
Normal file
@@ -0,0 +1,477 @@
package hitless

import (
	"context"
	"errors"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/pool"
)

// HitlessManagerInterface defines the interface for completing handoff operations
type HitlessManagerInterface interface {
	TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
	UntrackOperationWithConnID(seqID int64, connID uint64)
}

// HandoffRequest represents a request to handoff a connection to a new endpoint
type HandoffRequest struct {
	Conn     *pool.Conn
	ConnID   uint64 // Unique connection identifier
	Endpoint string
	SeqID    int64
	Pool     pool.Pooler // Pool to remove connection from on failure
}

// PoolHook implements pool.PoolHook for Redis-specific connection handling
// with hitless upgrade support.
type PoolHook struct {
	// Base dialer for creating connections to new endpoints during handoffs
	// args are network and address
	baseDialer func(context.Context, string, string) (net.Conn, error)

	// Network type (e.g., "tcp", "unix")
	network string

	// Event-driven handoff support
	handoffQueue chan HandoffRequest // Queue for handoff requests
	shutdown     chan struct{}       // Shutdown signal
	shutdownOnce sync.Once           // Ensure clean shutdown
	workerWg     sync.WaitGroup      // Track worker goroutines

	// On-demand worker management
	maxWorkers    int
	activeWorkers int32         // Atomic counter for active workers
	workerTimeout time.Duration // How long workers wait for work before exiting

	// Simple state tracking
	pending sync.Map // map[uint64]int64 (connID -> seqID)

	// Configuration for the hitless upgrade
	config *Config

	// Hitless manager for operation completion tracking
	hitlessManager HitlessManagerInterface

	// Pool interface for removing connections on handoff failure
	pool pool.Pooler
}

// NewPoolHook creates a new pool hook
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface) *PoolHook {
	return NewPoolHookWithPoolSize(baseDialer, network, config, hitlessManager, 0)
}

// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface, poolSize int) *PoolHook {
	// Apply defaults to any missing configuration fields, using pool size for worker calculations
	config = config.ApplyDefaultsWithPoolSize(poolSize)

	ph := &PoolHook{
		// baseDialer is used to create connections to new endpoints during handoffs
		baseDialer: baseDialer,
		network:    network,
		// handoffQueue is a buffered channel for queuing handoff requests
		handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
		// shutdown is a channel for signaling shutdown
		shutdown:      make(chan struct{}),
		maxWorkers:    config.MaxWorkers,
		activeWorkers: 0,                // Start with no workers - create on demand
		workerTimeout: 30 * time.Second, // Workers exit after 30s of inactivity
		config:        config,
		// Hitless manager for operation completion tracking
		hitlessManager: hitlessManager,
	}

	// No upfront worker creation - workers are created on demand

	return ph
}

// SetPool sets the pool interface for removing connections on handoff failure
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
	ph.pool = pooler
}

// GetCurrentWorkers returns the current number of active workers (for testing)
func (ph *PoolHook) GetCurrentWorkers() int {
	return int(atomic.LoadInt32(&ph.activeWorkers))
}

// GetScaleLevel returns 1 if workers are active, 0 if none (for testing compatibility)
func (ph *PoolHook) GetScaleLevel() int {
	if atomic.LoadInt32(&ph.activeWorkers) > 0 {
		return 1
	}
	return 0
}

// IsHandoffPending returns true if the given connection has a pending handoff
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
	_, pending := ph.pending.Load(conn.GetID())
	return pending
}

// OnGet is called when a connection is retrieved from the pool
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, isNewConn bool) error {
	// NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is
	// in a handoff state at the moment.

	// Check if connection is usable (not in a handoff state)
	// Should not happen since the pool will not return a connection that is not usable.
	if !conn.IsUsable() {
		return ErrConnectionMarkedForHandoff
	}

	// Check if connection is marked for handoff, which means it will be queued for handoff on put.
	if conn.ShouldHandoff() {
		return ErrConnectionMarkedForHandoff
	}

	return nil
}

// OnPut is called when a connection is returned to the pool
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
	// first check if we should handoff for faster rejection
	if conn.ShouldHandoff() {
		// check pending handoff to not queue the same connection twice
|
||||||
|
_, hasPendingHandoff := ph.pending.Load(conn.GetID())
|
||||||
|
if !hasPendingHandoff {
|
||||||
|
// Check for empty endpoint first (synchronous check)
|
||||||
|
if conn.GetHandoffEndpoint() == "" {
|
||||||
|
conn.ClearHandoffState()
|
||||||
|
} else {
|
||||||
|
if err := ph.queueHandoff(conn); err != nil {
|
||||||
|
// Failed to queue handoff, remove the connection
|
||||||
|
internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err)
|
||||||
|
return false, true, nil // Don't pool, remove connection, no error to caller
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if handoff was already processed by a worker before we can mark it as queued
|
||||||
|
if !conn.ShouldHandoff() {
|
||||||
|
// Handoff was already processed - this is normal and the connection should be pooled
|
||||||
|
return true, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.MarkQueuedForHandoff(); err != nil {
|
||||||
|
// If marking fails, check if handoff was processed in the meantime
|
||||||
|
if !conn.ShouldHandoff() {
|
||||||
|
// Handoff was processed - this is normal, pool the connection
|
||||||
|
return true, false, nil
|
||||||
|
}
|
||||||
|
// Other error - remove the connection
|
||||||
|
return false, true, nil
|
||||||
|
}
|
||||||
|
return true, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Default: pool the connection
|
||||||
|
return true, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureWorkerAvailable ensures at least one worker is available to process requests
|
||||||
|
// Creates a new worker if needed and under the max limit
|
||||||
|
func (ph *PoolHook) ensureWorkerAvailable() {
|
||||||
|
select {
|
||||||
|
case <-ph.shutdown:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// Check if we need a new worker
|
||||||
|
currentWorkers := atomic.LoadInt32(&ph.activeWorkers)
|
||||||
|
if currentWorkers < int32(ph.maxWorkers) {
|
||||||
|
// Try to create a new worker (atomic increment to prevent race)
|
||||||
|
if atomic.CompareAndSwapInt32(&ph.activeWorkers, currentWorkers, currentWorkers+1) {
|
||||||
|
ph.workerWg.Add(1)
|
||||||
|
go ph.onDemandWorker()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// onDemandWorker processes handoff requests and exits when idle
|
||||||
|
func (ph *PoolHook) onDemandWorker() {
|
||||||
|
defer func() {
|
||||||
|
// Decrement active worker count when exiting
|
||||||
|
atomic.AddInt32(&ph.activeWorkers, -1)
|
||||||
|
ph.workerWg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case request := <-ph.handoffQueue:
|
||||||
|
// Check for shutdown before processing
|
||||||
|
select {
|
||||||
|
case <-ph.shutdown:
|
||||||
|
// Clean up the request before exiting
|
||||||
|
ph.pending.Delete(request.ConnID)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// Process the request
|
||||||
|
ph.processHandoffRequest(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-time.After(ph.workerTimeout):
|
||||||
|
// Worker has been idle for too long, exit to save resources
|
||||||
|
if ph.config != nil && ph.config.LogLevel >= 3 { // Debug level
|
||||||
|
internal.Logger.Printf(context.Background(),
|
||||||
|
"hitless: worker exiting due to inactivity timeout (%v)", ph.workerTimeout)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-ph.shutdown:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processHandoffRequest processes a single handoff request
|
||||||
|
func (ph *PoolHook) processHandoffRequest(request HandoffRequest) {
|
||||||
|
// Remove from pending map
|
||||||
|
defer ph.pending.Delete(request.Conn.GetID())
|
||||||
|
|
||||||
|
// Create a context with handoff timeout from config
|
||||||
|
handoffTimeout := 30 * time.Second // Default fallback
|
||||||
|
if ph.config != nil && ph.config.HandoffTimeout > 0 {
|
||||||
|
handoffTimeout = ph.config.HandoffTimeout
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Create a context that also respects the shutdown signal
|
||||||
|
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
|
||||||
|
defer shutdownCancel()
|
||||||
|
|
||||||
|
// Monitor shutdown signal in a separate goroutine
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ph.shutdown:
|
||||||
|
shutdownCancel()
|
||||||
|
case <-shutdownCtx.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Perform the handoff with cancellable context
|
||||||
|
err := ph.performConnectionHandoffWithPool(shutdownCtx, request.Conn, request.Pool)
|
||||||
|
|
||||||
|
// If handoff failed, restore the handoff state for potential retry
|
||||||
|
if err != nil {
|
||||||
|
request.Conn.RestoreHandoffState()
|
||||||
|
internal.Logger.Printf(context.Background(), "Handoff failed for connection WILL RETRY: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No need for scale down scheduling with on-demand workers
|
||||||
|
// Workers automatically exit when idle
|
||||||
|
}
|
||||||
|
|
||||||
|
// queueHandoff queues a handoff request for processing
|
||||||
|
// if err is returned, connection will be removed from pool
|
||||||
|
func (ph *PoolHook) queueHandoff(conn *pool.Conn) error {
|
||||||
|
// Create handoff request
|
||||||
|
request := HandoffRequest{
|
||||||
|
Conn: conn,
|
||||||
|
ConnID: conn.GetID(),
|
||||||
|
Endpoint: conn.GetHandoffEndpoint(),
|
||||||
|
SeqID: conn.GetMovingSeqID(),
|
||||||
|
Pool: ph.pool, // Include pool for connection removal on failure
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
// priority to shutdown
|
||||||
|
case <-ph.shutdown:
|
||||||
|
return errors.New("shutdown")
|
||||||
|
default:
|
||||||
|
select {
|
||||||
|
case <-ph.shutdown:
|
||||||
|
return errors.New("shutdown")
|
||||||
|
case ph.handoffQueue <- request:
|
||||||
|
// Store in pending map
|
||||||
|
ph.pending.Store(request.ConnID, request.SeqID)
|
||||||
|
// Ensure we have a worker to process this request
|
||||||
|
ph.ensureWorkerAvailable()
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
// Queue is full - log and attempt scaling
|
||||||
|
queueLen := len(ph.handoffQueue)
|
||||||
|
queueCap := cap(ph.handoffQueue)
|
||||||
|
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
|
||||||
|
internal.Logger.Printf(context.Background(),
|
||||||
|
"hitless: handoff queue is full (%d/%d), attempting timeout queuing and scaling workers",
|
||||||
|
queueLen, queueCap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we have workers available to handle the load
|
||||||
|
ph.ensureWorkerAvailable()
|
||||||
|
return errors.New("queue full")
|
||||||
|
}
|
||||||
|
|
||||||
|
// performConnectionHandoffWithPool performs the actual connection handoff with pool for connection removal on failure
|
||||||
|
// if err is returned, connection will be removed from pool
|
||||||
|
func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn *pool.Conn, pooler pool.Pooler) error {
|
||||||
|
// Clear handoff state after successful handoff
|
||||||
|
seqID := conn.GetMovingSeqID()
|
||||||
|
connID := conn.GetID()
|
||||||
|
|
||||||
|
// Notify hitless manager of completion if available
|
||||||
|
if ph.hitlessManager != nil {
|
||||||
|
defer ph.hitlessManager.UntrackOperationWithConnID(seqID, connID)
|
||||||
|
}
|
||||||
|
|
||||||
|
newEndpoint := conn.GetHandoffEndpoint()
|
||||||
|
if newEndpoint == "" {
|
||||||
|
// TODO(hitless): Handle by performing the handoff to the current endpoint in N seconds,
|
||||||
|
// Where N is the time in the moving notification...
|
||||||
|
// For now, clear the handoff state and return
|
||||||
|
conn.ClearHandoffState()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
retries := conn.IncrementAndGetHandoffRetries(1)
|
||||||
|
maxRetries := 3 // Default fallback
|
||||||
|
if ph.config != nil {
|
||||||
|
maxRetries = ph.config.MaxHandoffRetries
|
||||||
|
}
|
||||||
|
|
||||||
|
if retries > maxRetries {
|
||||||
|
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
|
||||||
|
internal.Logger.Printf(ctx,
|
||||||
|
"hitless: reached max retries (%d) for handoff of connection %d to %s",
|
||||||
|
maxRetries, conn.GetID(), conn.GetHandoffEndpoint())
|
||||||
|
}
|
||||||
|
err := ErrMaxHandoffRetriesReached
|
||||||
|
if pooler != nil {
|
||||||
|
go pooler.Remove(ctx, conn, err)
|
||||||
|
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
|
||||||
|
internal.Logger.Printf(ctx,
|
||||||
|
"hitless: removed connection %d from pool due to max handoff retries reached",
|
||||||
|
conn.GetID())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
go conn.Close()
|
||||||
|
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
|
||||||
|
internal.Logger.Printf(ctx,
|
||||||
|
"hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v",
|
||||||
|
conn.GetID(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create endpoint-specific dialer
|
||||||
|
endpointDialer := ph.createEndpointDialer(newEndpoint)
|
||||||
|
|
||||||
|
// Create new connection to the new endpoint
|
||||||
|
newNetConn, err := endpointDialer(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// TODO(hitless): retry
|
||||||
|
// This is the only case where we should retry the handoff request
|
||||||
|
// Should we do anything else other than return the error?
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the old connection
|
||||||
|
oldConn := conn.GetNetConn()
|
||||||
|
|
||||||
|
// Replace the connection and execute initialization
|
||||||
|
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
|
||||||
|
if err != nil {
|
||||||
|
// Remove the connection from the pool since it's in a bad state
|
||||||
|
if pooler != nil {
|
||||||
|
// Use pool.Pooler interface directly - no adapter needed
|
||||||
|
go pooler.Remove(ctx, conn, err)
|
||||||
|
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
|
||||||
|
internal.Logger.Printf(ctx,
|
||||||
|
"hitless: removed connection %d from pool due to handoff initialization failure: %v",
|
||||||
|
conn.GetID(), err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
go conn.Close()
|
||||||
|
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
|
||||||
|
internal.Logger.Printf(ctx,
|
||||||
|
"hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v",
|
||||||
|
conn.GetID(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep the handoff state for retry
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if oldConn != nil {
|
||||||
|
oldConn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn.ClearHandoffState()
|
||||||
|
|
||||||
|
// Apply relaxed timeout to the new connection for the configured post-handoff duration
|
||||||
|
// This gives the new connection more time to handle operations during cluster transition
|
||||||
|
if ph.config != nil && ph.config.PostHandoffRelaxedDuration > 0 {
|
||||||
|
relaxedTimeout := ph.config.RelaxedTimeout
|
||||||
|
postHandoffDuration := ph.config.PostHandoffRelaxedDuration
|
||||||
|
|
||||||
|
// Set relaxed timeout with deadline - no background goroutine needed
|
||||||
|
deadline := time.Now().Add(postHandoffDuration)
|
||||||
|
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
|
||||||
|
|
||||||
|
if ph.config.LogLevel >= 2 { // Info level
|
||||||
|
internal.Logger.Printf(context.Background(),
|
||||||
|
"hitless: applied post-handoff relaxed timeout (%v) until %v for connection %d",
|
||||||
|
relaxedTimeout, deadline.Format("15:04:05.000"), connID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createEndpointDialer creates a dialer function that connects to a specific endpoint
|
||||||
|
func (ph *PoolHook) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
|
||||||
|
return func(ctx context.Context) (net.Conn, error) {
|
||||||
|
// Parse endpoint to extract host and port
|
||||||
|
host, port, err := net.SplitHostPort(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
// If no port specified, assume default Redis port
|
||||||
|
host = endpoint
|
||||||
|
if port == "" {
|
||||||
|
port = "6379"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the base dialer to connect to the new endpoint
|
||||||
|
return ph.baseDialer(ctx, ph.network, net.JoinHostPort(host, port))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully shuts down the processor, waiting for workers to complete
|
||||||
|
func (ph *PoolHook) Shutdown(ctx context.Context) error {
|
||||||
|
ph.shutdownOnce.Do(func() {
|
||||||
|
close(ph.shutdown)
|
||||||
|
|
||||||
|
// No timers to clean up with on-demand workers
|
||||||
|
})
|
||||||
|
|
||||||
|
// Wait for workers to complete
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ph.workerWg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
|
||||||
|
// and should not be used until the handoff is complete
|
||||||
|
var ErrConnectionMarkedForHandoff = errors.New("connection marked for handoff")
|
959
hitless/pool_hook_test.go
Normal file
959
hitless/pool_hook_test.go
Normal file
@@ -0,0 +1,959 @@
|
|||||||
|
package hitless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9/internal/pool"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockNetConn implements net.Conn for testing
|
||||||
|
type mockNetConn struct {
|
||||||
|
addr string
|
||||||
|
shouldFailInit bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
|
||||||
|
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
|
||||||
|
func (m *mockNetConn) Close() error { return nil }
|
||||||
|
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
|
||||||
|
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
|
||||||
|
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
|
||||||
|
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
|
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
|
|
||||||
|
type mockAddr struct {
|
||||||
|
addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAddr) Network() string { return "tcp" }
|
||||||
|
func (m *mockAddr) String() string { return m.addr }
|
||||||
|
|
||||||
|
// createMockPoolConnection creates a mock pool connection for testing
|
||||||
|
func createMockPoolConnection() *pool.Conn {
|
||||||
|
mockNetConn := &mockNetConn{addr: "test:6379"}
|
||||||
|
conn := pool.NewConn(mockNetConn)
|
||||||
|
conn.SetUsable(true) // Make connection usable for testing
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockPool implements pool.Pooler for testing
|
||||||
|
type mockPool struct {
|
||||||
|
removedConnections map[uint64]bool
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockPool) CloseConn(conn *pool.Conn) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) {
|
||||||
|
// Not implemented for testing
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
// Use pool.Conn directly - no adapter needed
|
||||||
|
mp.removedConnections[conn.GetID()] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// WasRemoved safely checks if a connection was removed from the pool
|
||||||
|
func (mp *mockPool) WasRemoved(connID uint64) bool {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
return mp.removedConnections[connID]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockPool) Len() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockPool) IdleLen() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockPool) Stats() *pool.Stats {
|
||||||
|
return &pool.Stats{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockPool) AddPoolHook(hook pool.PoolHook) {
|
||||||
|
// Mock implementation - do nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockPool) RemovePoolHook(hook pool.PoolHook) {
|
||||||
|
// Mock implementation - do nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockPool) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConnectionHook tests the Redis connection processor functionality
|
||||||
|
func TestConnectionHook(t *testing.T) {
|
||||||
|
// Create a base dialer for testing
|
||||||
|
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return &mockNetConn{addr: addr}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
Mode: MaintNotificationsAuto,
|
||||||
|
EndpointType: EndpointTypeAuto,
|
||||||
|
MaxWorkers: 1, // Use only 1 worker to ensure synchronization
|
||||||
|
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
|
||||||
|
MaxHandoffRetries: 3,
|
||||||
|
LogLevel: 2,
|
||||||
|
}
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||||
|
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify connection is marked for handoff
|
||||||
|
if !conn.ShouldHandoff() {
|
||||||
|
t.Fatal("Connection should be marked for handoff")
|
||||||
|
}
|
||||||
|
// Set a mock initialization function with synchronization
|
||||||
|
initConnCalled := make(chan bool, 1)
|
||||||
|
proceedWithInit := make(chan bool, 1)
|
||||||
|
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
|
||||||
|
select {
|
||||||
|
case initConnCalled <- true:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
// Wait for test to proceed
|
||||||
|
<-proceedWithInit
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
conn.SetInitConnFunc(initConnFunc)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("OnPut should not error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should pool the connection immediately (handoff queued)
|
||||||
|
if !shouldPool {
|
||||||
|
t.Error("Connection should be pooled immediately with event-driven handoff")
|
||||||
|
}
|
||||||
|
if shouldRemove {
|
||||||
|
t.Error("Connection should not be removed when queuing handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for initialization to be called (indicates handoff started)
|
||||||
|
select {
|
||||||
|
case <-initConnCalled:
|
||||||
|
// Good, initialization was called
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Fatal("Timeout waiting for initialization function to be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection should be in pending map while initialization is blocked
|
||||||
|
if _, pending := processor.pending.Load(conn.GetID()); !pending {
|
||||||
|
t.Error("Connection should be in pending handoffs map")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow initialization to proceed
|
||||||
|
proceedWithInit <- true
|
||||||
|
|
||||||
|
// Wait for handoff to complete with proper timeout and polling
|
||||||
|
timeout := time.After(2 * time.Second)
|
||||||
|
ticker := time.NewTicker(10 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
handoffCompleted := false
|
||||||
|
for !handoffCompleted {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatal("Timeout waiting for handoff to complete")
|
||||||
|
case <-ticker.C:
|
||||||
|
if _, pending := processor.pending.Load(conn); !pending {
|
||||||
|
handoffCompleted = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify handoff completed (removed from pending map)
|
||||||
|
if _, pending := processor.pending.Load(conn); pending {
|
||||||
|
t.Error("Connection should be removed from pending map after handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify connection is usable again
|
||||||
|
if !conn.IsUsable() {
|
||||||
|
t.Error("Connection should be usable after successful handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify handoff state is cleared
|
||||||
|
if conn.ShouldHandoff() {
|
||||||
|
t.Error("Connection should not be marked for handoff after completion")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("HandoffNotNeeded", func(t *testing.T) {
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
// Don't mark for handoff
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("OnPut should not error when handoff not needed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should pool the connection normally
|
||||||
|
if !shouldPool {
|
||||||
|
t.Error("Connection should be pooled when no handoff needed")
|
||||||
|
}
|
||||||
|
if shouldRemove {
|
||||||
|
t.Error("Connection should not be removed when no handoff needed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EmptyEndpoint", func(t *testing.T) {
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
|
||||||
|
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("OnPut should not error with empty endpoint: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should pool the connection (empty endpoint clears state)
|
||||||
|
if !shouldPool {
|
||||||
|
t.Error("Connection should be pooled after clearing empty endpoint")
|
||||||
|
}
|
||||||
|
if shouldRemove {
|
||||||
|
t.Error("Connection should not be removed after clearing empty endpoint")
|
||||||
|
}
|
||||||
|
|
||||||
|
// State should be cleared
|
||||||
|
if conn.ShouldHandoff() {
|
||||||
|
t.Error("Connection should not be marked for handoff after clearing empty endpoint")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EventDrivenHandoffDialerError", func(t *testing.T) {
|
||||||
|
// Create a failing base dialer
|
||||||
|
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return nil, errors.New("dial failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
Mode: MaintNotificationsAuto,
|
||||||
|
EndpointType: EndpointTypeAuto,
|
||||||
|
MaxWorkers: 2,
|
||||||
|
HandoffQueueSize: 10,
|
||||||
|
MaxHandoffRetries: 3,
|
||||||
|
HandoffTimeout: 1 * time.Second, // Shorter timeout for faster test
|
||||||
|
LogLevel: 2,
|
||||||
|
}
|
||||||
|
processor := NewPoolHook(failingDialer, "tcp", config, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||||
|
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("OnPut should not return error to caller: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should pool the connection initially (handoff queued)
|
||||||
|
if !shouldPool {
|
||||||
|
t.Error("Connection should be pooled initially with event-driven handoff")
|
||||||
|
}
|
||||||
|
if shouldRemove {
|
||||||
|
t.Error("Connection should not be removed when queuing handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for handoff to complete and fail with proper timeout and polling
|
||||||
|
// Use longer timeout to account for handoff timeout + processing time
|
||||||
|
timeout := time.After(5 * time.Second)
|
||||||
|
ticker := time.NewTicker(10 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
// wait for handoff to start
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
handoffCompleted := false
|
||||||
|
for !handoffCompleted {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatal("Timeout waiting for failed handoff to complete")
|
||||||
|
case <-ticker.C:
|
||||||
|
if _, pending := processor.pending.Load(conn.GetID()); !pending {
|
||||||
|
handoffCompleted = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection should be removed from pending map after failed handoff
|
||||||
|
if _, pending := processor.pending.Load(conn.GetID()); pending {
|
||||||
|
t.Error("Connection should be removed from pending map after failed handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handoff state should still be set (since handoff failed)
|
||||||
|
if !conn.ShouldHandoff() {
|
||||||
|
t.Error("Connection should still be marked for handoff after failed handoff")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("BufferedDataRESP2", func(t *testing.T) {
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
|
||||||
|
// For this test, we'll just verify the logic works for connections without buffered data
|
||||||
|
// The actual buffered data detection is handled by the pool's connection health check
|
||||||
|
// which is outside the scope of the Redis connection processor
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("OnPut should not error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should pool the connection normally (no buffered data in mock)
|
||||||
|
if !shouldPool {
|
||||||
|
t.Error("Connection should be pooled when no buffered data")
|
||||||
|
}
|
||||||
|
if shouldRemove {
|
||||||
|
t.Error("Connection should not be removed when no buffered data")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OnGet", func(t *testing.T) {
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := processor.OnGet(ctx, conn, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("OnGet should not error for normal connection: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
Mode: MaintNotificationsAuto,
|
||||||
|
EndpointType: EndpointTypeAuto,
|
||||||
|
MaxWorkers: 2,
|
||||||
|
HandoffQueueSize: 10,
|
||||||
|
MaxHandoffRetries: 3, // Explicit queue size to avoid 0-size queue
|
||||||
|
LogLevel: 2,
|
||||||
|
}
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
|
||||||
|
// Simulate a pending handoff by marking for handoff and queuing
|
||||||
|
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||||
|
processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||||
|
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := processor.OnGet(ctx, conn, false)
|
||||||
|
if err != ErrConnectionMarkedForHandoff {
|
||||||
|
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
processor.pending.Delete(conn)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EventDrivenStateManagement", func(t *testing.T) {
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
|
||||||
|
// Test initial state - no pending handoffs
|
||||||
|
if _, pending := processor.pending.Load(conn); pending {
|
||||||
|
t.Error("New connection should not have pending handoffs")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test adding to pending map
|
||||||
|
conn.MarkForHandoff("new-endpoint:6379", 12345)
|
||||||
|
processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID
|
||||||
|
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
|
||||||
|
|
||||||
|
if _, pending := processor.pending.Load(conn.GetID()); !pending {
|
||||||
|
t.Error("Connection should be in pending map")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test OnGet with pending handoff
|
||||||
|
ctx := context.Background()
|
||||||
|
err := processor.OnGet(ctx, conn, false)
|
||||||
|
if err != ErrConnectionMarkedForHandoff {
|
||||||
|
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test removing from pending map and clearing handoff state
|
||||||
|
processor.pending.Delete(conn)
|
||||||
|
if _, pending := processor.pending.Load(conn); pending {
|
||||||
|
t.Error("Connection should be removed from pending map")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear handoff state to simulate completed handoff
|
||||||
|
conn.ClearHandoffState()
|
||||||
|
conn.SetUsable(true) // Make connection usable again
|
||||||
|
|
||||||
|
// Test OnGet without pending handoff
|
||||||
|
err = processor.OnGet(ctx, conn, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Should not return error for non-pending connection: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
|
||||||
|
// Create processor with small queue to test optimization features
|
||||||
|
config := &Config{
|
||||||
|
MaxWorkers: 3,
|
||||||
|
HandoffQueueSize: 2,
|
||||||
|
MaxHandoffRetries: 3, // Small queue to trigger optimizations
|
||||||
|
LogLevel: 3, // Debug level to see optimization logs
|
||||||
|
}
|
||||||
|
|
||||||
|
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
// Add small delay to simulate network latency
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
return &mockNetConn{addr: addr}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
// Create multiple connections that need handoff to fill the queue
|
||||||
|
connections := make([]*pool.Conn, 5)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
connections[i] = createMockPoolConnection()
|
||||||
|
if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil {
|
||||||
|
t.Fatalf("Failed to mark connection %d for handoff: %v", i, err)
|
||||||
|
}
|
||||||
|
// Set a mock initialization function
|
||||||
|
connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
successCount := 0
|
||||||
|
|
||||||
|
// Process connections - should trigger scaling and timeout logic
|
||||||
|
for _, conn := range connections {
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("OnPut returned error (expected with timeout): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldPool && !shouldRemove {
|
||||||
|
successCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// With timeout and scaling, most handoffs should eventually succeed
|
||||||
|
if successCount == 0 {
|
||||||
|
t.Error("Should have queued some handoffs with timeout and scaling")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Successfully queued %d handoffs with optimization features", successCount)
|
||||||
|
|
||||||
|
// Give time for workers to process and scaling to occur
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WorkerScalingBehavior", func(t *testing.T) {
|
||||||
|
// Create processor with small queue to test scaling behavior
|
||||||
|
config := &Config{
|
||||||
|
MaxWorkers: 15, // Set to >= 10 to test explicit value preservation
|
||||||
|
HandoffQueueSize: 1,
|
||||||
|
MaxHandoffRetries: 3, // Very small queue to force scaling
|
||||||
|
LogLevel: 2, // Info level to see scaling logs
|
||||||
|
}
|
||||||
|
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
// Verify initial worker count (should be 0 with on-demand workers)
|
||||||
|
if processor.GetCurrentWorkers() != 0 {
|
||||||
|
t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers())
|
||||||
|
}
|
||||||
|
if processor.GetScaleLevel() != 0 {
|
||||||
|
t.Errorf("Processor should be at scale level 0 initially, got %d", processor.GetScaleLevel())
|
||||||
|
}
|
||||||
|
if processor.maxWorkers != 15 {
|
||||||
|
t.Errorf("Expected maxWorkers=15, got %d", processor.maxWorkers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The on-demand worker behavior creates workers only when needed
|
||||||
|
// This test just verifies the basic configuration is correct
|
||||||
|
t.Logf("On-demand worker configuration verified - Max: %d, Current: %d",
|
||||||
|
processor.maxWorkers, processor.GetCurrentWorkers())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PassiveTimeoutRestoration", func(t *testing.T) {
|
||||||
|
// Create processor with fast post-handoff duration for testing
|
||||||
|
config := &Config{
|
||||||
|
MaxWorkers: 2,
|
||||||
|
HandoffQueueSize: 10,
|
||||||
|
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing
|
||||||
|
RelaxedTimeout: 5 * time.Second,
|
||||||
|
LogLevel: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a connection and trigger handoff
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||||
|
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a mock initialization function
|
||||||
|
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Process the connection to trigger handoff
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Handoff should succeed: %v", err)
|
||||||
|
}
|
||||||
|
if !shouldPool || shouldRemove {
|
||||||
|
t.Error("Connection should be pooled after handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for handoff to complete with proper timeout and polling
|
||||||
|
timeout := time.After(1 * time.Second)
|
||||||
|
ticker := time.NewTicker(5 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
handoffCompleted := false
|
||||||
|
for !handoffCompleted {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatal("Timeout waiting for handoff to complete")
|
||||||
|
case <-ticker.C:
|
||||||
|
if _, pending := processor.pending.Load(conn); !pending {
|
||||||
|
handoffCompleted = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify relaxed timeout is set with deadline
|
||||||
|
if !conn.HasRelaxedTimeout() {
|
||||||
|
t.Error("Connection should have relaxed timeout after handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that timeout is still active before deadline
|
||||||
|
// We'll use HasRelaxedTimeout which internally checks the deadline
|
||||||
|
if !conn.HasRelaxedTimeout() {
|
||||||
|
t.Error("Connection should still have active relaxed timeout before deadline")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for deadline to pass
|
||||||
|
time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer
|
||||||
|
|
||||||
|
// Test that timeout is automatically restored after deadline
|
||||||
|
// HasRelaxedTimeout should return false after deadline passes
|
||||||
|
if conn.HasRelaxedTimeout() {
|
||||||
|
t.Error("Connection should not have active relaxed timeout after deadline")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Additional verification: calling HasRelaxedTimeout again should still return false
|
||||||
|
// and should have cleared the internal timeout values
|
||||||
|
if conn.HasRelaxedTimeout() {
|
||||||
|
t.Error("Connection should not have relaxed timeout after deadline (second check)")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Passive timeout restoration test completed successfully")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UsableFlagBehavior", func(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
MaxWorkers: 2,
|
||||||
|
HandoffQueueSize: 10,
|
||||||
|
MaxHandoffRetries: 3,
|
||||||
|
LogLevel: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a new connection without setting it usable
|
||||||
|
mockNetConn := &mockNetConn{addr: "test:6379"}
|
||||||
|
conn := pool.NewConn(mockNetConn)
|
||||||
|
|
||||||
|
// Initially, connection should not be usable (not initialized)
|
||||||
|
if conn.IsUsable() {
|
||||||
|
t.Error("New connection should not be usable before initialization")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate initialization by setting usable to true
|
||||||
|
conn.SetUsable(true)
|
||||||
|
if !conn.IsUsable() {
|
||||||
|
t.Error("Connection should be usable after initialization")
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnGet should succeed for usable connection
|
||||||
|
err := processor.OnGet(ctx, conn, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("OnGet should succeed for usable connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark connection for handoff
|
||||||
|
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||||
|
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a mock initialization function
|
||||||
|
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Connection should still be usable until queued, but marked for handoff
|
||||||
|
if !conn.IsUsable() {
|
||||||
|
t.Error("Connection should still be usable after being marked for handoff (until queued)")
|
||||||
|
}
|
||||||
|
if !conn.ShouldHandoff() {
|
||||||
|
t.Error("Connection should be marked for handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnGet should fail for connection marked for handoff
|
||||||
|
err = processor.OnGet(ctx, conn, false)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("OnGet should fail for connection marked for handoff")
|
||||||
|
}
|
||||||
|
if err != ErrConnectionMarkedForHandoff {
|
||||||
|
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the connection to trigger handoff
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("OnPut should succeed: %v", err)
|
||||||
|
}
|
||||||
|
if !shouldPool || shouldRemove {
|
||||||
|
t.Error("Connection should be pooled after handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for handoff to complete
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// After handoff completion, connection should be usable again
|
||||||
|
if !conn.IsUsable() {
|
||||||
|
t.Error("Connection should be usable after handoff completion")
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnGet should succeed again
|
||||||
|
err = processor.OnGet(ctx, conn, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("OnGet should succeed after handoff completion: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Usable flag behavior test completed successfully")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("StaticQueueBehavior", func(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
MaxWorkers: 3,
|
||||||
|
HandoffQueueSize: 50,
|
||||||
|
MaxHandoffRetries: 3, // Explicit static queue size
|
||||||
|
LogLevel: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
// Verify queue capacity matches configured size
|
||||||
|
queueCapacity := cap(processor.handoffQueue)
|
||||||
|
if queueCapacity != 50 {
|
||||||
|
t.Errorf("Expected queue capacity 50, got %d", queueCapacity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that queue size is static regardless of pool size
|
||||||
|
// (No dynamic resizing should occur)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Fill part of the queue
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil {
|
||||||
|
t.Fatalf("Failed to mark connection %d for handoff: %v", i, err)
|
||||||
|
}
|
||||||
|
// Set a mock initialization function
|
||||||
|
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to queue handoff %d: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !shouldPool || shouldRemove {
|
||||||
|
t.Errorf("Connection %d should be pooled after handoff (shouldPool=%v, shouldRemove=%v)",
|
||||||
|
i, shouldPool, shouldRemove)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify queue capacity remains static (the main purpose of this test)
|
||||||
|
finalCapacity := cap(processor.handoffQueue)
|
||||||
|
|
||||||
|
if finalCapacity != 50 {
|
||||||
|
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: We don't check queue size here because workers process items quickly
|
||||||
|
// The important thing is that the capacity remains static regardless of pool size
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) {
|
||||||
|
// Create a failing dialer that will cause handoff initialization to fail
|
||||||
|
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
// Return a connection that will fail during initialization
|
||||||
|
return &mockNetConn{addr: addr, shouldFailInit: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
MaxWorkers: 2,
|
||||||
|
HandoffQueueSize: 10,
|
||||||
|
MaxHandoffRetries: 3,
|
||||||
|
LogLevel: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
processor := NewPoolHook(failingDialer, "tcp", config, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
// Create a mock pool that tracks removals
|
||||||
|
mockPool := &mockPool{removedConnections: make(map[uint64]bool)}
|
||||||
|
processor.SetPool(mockPool)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a connection and mark it for handoff
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||||
|
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a failing initialization function
|
||||||
|
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||||
|
return fmt.Errorf("initialization failed")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Process the connection - handoff should fail and connection should be removed
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("OnPut should not error: %v", err)
|
||||||
|
}
|
||||||
|
if !shouldPool || shouldRemove {
|
||||||
|
t.Error("Connection should be pooled after failed handoff attempt")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for handoff to be attempted and fail
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify that the connection was removed from the pool
|
||||||
|
if !mockPool.WasRemoved(conn.GetID()) {
|
||||||
|
t.Errorf("Connection %d should have been removed from pool after handoff failure", conn.GetID())
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Connection removal on handoff failure test completed successfully")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) {
|
||||||
|
// Create config with short post-handoff duration for testing
|
||||||
|
config := &Config{
|
||||||
|
MaxWorkers: 2,
|
||||||
|
HandoffQueueSize: 10,
|
||||||
|
RelaxedTimeout: 5 * time.Second,
|
||||||
|
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing
|
||||||
|
}
|
||||||
|
|
||||||
|
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return &mockNetConn{addr: addr}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
|
||||||
|
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a mock initialization function
|
||||||
|
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("OnPut failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !shouldPool {
|
||||||
|
t.Error("Connection should be pooled after successful handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldRemove {
|
||||||
|
t.Error("Connection should not be removed after successful handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the handoff to complete (it happens asynchronously)
|
||||||
|
timeout := time.After(1 * time.Second)
|
||||||
|
ticker := time.NewTicker(5 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
handoffCompleted := false
|
||||||
|
for !handoffCompleted {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatal("Timeout waiting for handoff to complete")
|
||||||
|
case <-ticker.C:
|
||||||
|
if _, pending := processor.pending.Load(conn); !pending {
|
||||||
|
handoffCompleted = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that relaxed timeout was applied to the new connection
|
||||||
|
if !conn.HasRelaxedTimeout() {
|
||||||
|
t.Error("New connection should have relaxed timeout applied after handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the post-handoff duration to expire
|
||||||
|
time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration
|
||||||
|
|
||||||
|
// Verify that relaxed timeout was automatically cleared
|
||||||
|
if conn.HasRelaxedTimeout() {
|
||||||
|
t.Error("Relaxed timeout should be automatically cleared after post-handoff duration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("MarkForHandoff returns error when already marked", func(t *testing.T) {
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
|
||||||
|
// First mark should succeed
|
||||||
|
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
|
||||||
|
t.Fatalf("First MarkForHandoff should succeed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second mark should fail
|
||||||
|
if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil {
|
||||||
|
t.Fatal("Second MarkForHandoff should return error")
|
||||||
|
} else if err.Error() != "connection is already marked for handoff" {
|
||||||
|
t.Fatalf("Expected specific error message, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify original handoff data is preserved
|
||||||
|
if !conn.ShouldHandoff() {
|
||||||
|
t.Fatal("Connection should still be marked for handoff")
|
||||||
|
}
|
||||||
|
if conn.GetHandoffEndpoint() != "new-endpoint:6379" {
|
||||||
|
t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint())
|
||||||
|
}
|
||||||
|
if conn.GetMovingSeqID() != 1 {
|
||||||
|
t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("HandoffTimeoutConfiguration", func(t *testing.T) {
|
||||||
|
// Test that HandoffTimeout from config is actually used
|
||||||
|
customTimeout := 2 * time.Second
|
||||||
|
config := &Config{
|
||||||
|
MaxWorkers: 2,
|
||||||
|
HandoffQueueSize: 10,
|
||||||
|
HandoffTimeout: customTimeout, // Custom timeout
|
||||||
|
MaxHandoffRetries: 1, // Single retry to speed up test
|
||||||
|
LogLevel: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
processor := NewPoolHook(baseDialer, "tcp", config, nil)
|
||||||
|
defer processor.Shutdown(context.Background())
|
||||||
|
|
||||||
|
// Create a connection that will test the timeout
|
||||||
|
conn := createMockPoolConnection()
|
||||||
|
if err := conn.MarkForHandoff("test-endpoint:6379", 123); err != nil {
|
||||||
|
t.Fatalf("Failed to mark connection for handoff: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a dialer that will check the context timeout
|
||||||
|
var timeoutVerified int32 // Use atomic for thread safety
|
||||||
|
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
|
||||||
|
// Check that the context has the expected timeout
|
||||||
|
deadline, ok := ctx.Deadline()
|
||||||
|
if !ok {
|
||||||
|
t.Error("Context should have a deadline")
|
||||||
|
return errors.New("no deadline")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The deadline should be approximately customTimeout from now
|
||||||
|
expectedDeadline := time.Now().Add(customTimeout)
|
||||||
|
timeDiff := deadline.Sub(expectedDeadline)
|
||||||
|
if timeDiff < -500*time.Millisecond || timeDiff > 500*time.Millisecond {
|
||||||
|
t.Errorf("Context deadline not as expected. Expected around %v, got %v (diff: %v)",
|
||||||
|
expectedDeadline, deadline, timeDiff)
|
||||||
|
} else {
|
||||||
|
atomic.StoreInt32(&timeoutVerified, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil // Successful handoff
|
||||||
|
})
|
||||||
|
|
||||||
|
// Trigger handoff
|
||||||
|
shouldPool, shouldRemove, err := processor.OnPut(context.Background(), conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("OnPut should not return error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection should be queued for handoff
|
||||||
|
if !shouldPool || shouldRemove {
|
||||||
|
t.Errorf("Connection should be pooled for handoff processing")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for handoff to complete
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
if atomic.LoadInt32(&timeoutVerified) == 0 {
|
||||||
|
t.Error("HandoffTimeout was not properly applied to context")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("HandoffTimeout configuration test completed successfully")
|
||||||
|
})
|
||||||
|
}
|
24
hitless/state.go
Normal file
@@ -0,0 +1,24 @@
package hitless

// State represents the current state of a hitless upgrade operation.
type State int

const (
	// StateIdle indicates no upgrade is in progress
	StateIdle State = iota

	// StateMoving indicates a connection handoff is in progress
	StateMoving
)

// String returns a string representation of the state.
func (s State) String() string {
	switch s {
	case StateIdle:
		return "idle"
	case StateMoving:
		return "moving"
	default:
		return "unknown"
	}
}
67
internal/interfaces/interfaces.go
Normal file
@@ -0,0 +1,67 @@
// Package interfaces provides shared interfaces used by both the main redis package
// and the hitless upgrade package to avoid circular dependencies.
package interfaces

import (
	"context"
	"net"
	"time"
)

// NotificationProcessor is a forward declaration to avoid circular imports.
type NotificationProcessor interface {
	RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
	UnregisterHandler(pushNotificationName string) error
	GetHandler(pushNotificationName string) interface{}
}

// ClientInterface defines the interface that clients must implement for hitless upgrades.
type ClientInterface interface {
	// GetOptions returns the client options.
	GetOptions() OptionsInterface

	// GetPushProcessor returns the client's push notification processor.
	GetPushProcessor() NotificationProcessor
}

// OptionsInterface defines the interface for client options.
type OptionsInterface interface {
	// GetReadTimeout returns the read timeout.
	GetReadTimeout() time.Duration

	// GetWriteTimeout returns the write timeout.
	GetWriteTimeout() time.Duration

	// GetNetwork returns the network type.
	GetNetwork() string

	// GetAddr returns the connection address.
	GetAddr() string

	// IsTLSEnabled returns true if TLS is enabled.
	IsTLSEnabled() bool

	// GetProtocol returns the protocol version.
	GetProtocol() int

	// GetPoolSize returns the connection pool size.
	GetPoolSize() int

	// NewDialer returns a new dialer function for the connection.
	NewDialer() func(context.Context) (net.Conn, error)
}

// ConnectionWithRelaxedTimeout defines the interface for connections that support relaxed timeout adjustment.
// This is used by the hitless upgrade system for per-connection timeout management.
type ConnectionWithRelaxedTimeout interface {
	// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
	// These timeouts remain active until explicitly cleared.
	SetRelaxedTimeout(readTimeout, writeTimeout time.Duration)

	// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
	// After the deadline, timeouts automatically revert to normal values.
	SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time)

	// ClearRelaxedTimeout clears relaxed timeouts for this connection.
	ClearRelaxedTimeout()
}
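For illustration only (not part of this diff), the hitless package is expected to consume these interfaces rather than the concrete client; the helper below is a hypothetical sketch, and the "MOVING" handler registration is an assumption based on the notification name used elsewhere in this change.

// registerMovingHandler shows how an upgrade manager could hook into a client
// using only the shared interfaces, avoiding a circular import on the redis package.
func registerMovingHandler(client interfaces.ClientInterface, handler interface{}) error {
	// Register a protected handler for MOVING push notifications.
	return client.GetPushProcessor().RegisterHandler("MOVING", handler, true)
}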
@@ -2,6 +2,7 @@ package pool_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -31,7 +32,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
|
|||||||
b.Run(bm.String(), func(b *testing.B) {
|
b.Run(bm.String(), func(b *testing.B) {
|
||||||
connPool := pool.NewConnPool(&pool.Options{
|
connPool := pool.NewConnPool(&pool.Options{
|
||||||
Dialer: dummyDialer,
|
Dialer: dummyDialer,
|
||||||
PoolSize: bm.poolSize,
|
PoolSize: int32(bm.poolSize),
|
||||||
PoolTimeout: time.Second,
|
PoolTimeout: time.Second,
|
||||||
DialTimeout: 1 * time.Second,
|
DialTimeout: 1 * time.Second,
|
||||||
ConnMaxIdleTime: time.Hour,
|
ConnMaxIdleTime: time.Hour,
|
||||||
@@ -75,7 +76,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
|
|||||||
b.Run(bm.String(), func(b *testing.B) {
|
b.Run(bm.String(), func(b *testing.B) {
|
||||||
connPool := pool.NewConnPool(&pool.Options{
|
connPool := pool.NewConnPool(&pool.Options{
|
||||||
Dialer: dummyDialer,
|
Dialer: dummyDialer,
|
||||||
PoolSize: bm.poolSize,
|
PoolSize: int32(bm.poolSize),
|
||||||
PoolTimeout: time.Second,
|
PoolTimeout: time.Second,
|
||||||
DialTimeout: 1 * time.Second,
|
DialTimeout: 1 * time.Second,
|
||||||
ConnMaxIdleTime: time.Hour,
|
ConnMaxIdleTime: time.Hour,
|
||||||
@@ -89,7 +90,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
connPool.Remove(ctx, cn, nil)
|
connPool.Remove(ctx, cn, errors.New("Bench test remove"))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@@ -26,7 +26,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
|||||||
It("should use default buffer sizes when not specified", func() {
|
It("should use default buffer sizes when not specified", func() {
|
||||||
connPool = pool.NewConnPool(&pool.Options{
|
connPool = pool.NewConnPool(&pool.Options{
|
||||||
Dialer: dummyDialer,
|
Dialer: dummyDialer,
|
||||||
PoolSize: 1,
|
PoolSize: int32(1),
|
||||||
PoolTimeout: 1000,
|
PoolTimeout: 1000,
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -48,7 +48,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
|||||||
|
|
||||||
connPool = pool.NewConnPool(&pool.Options{
|
connPool = pool.NewConnPool(&pool.Options{
|
||||||
Dialer: dummyDialer,
|
Dialer: dummyDialer,
|
||||||
PoolSize: 1,
|
PoolSize: int32(1),
|
||||||
PoolTimeout: 1000,
|
PoolTimeout: 1000,
|
||||||
ReadBufferSize: customReadSize,
|
ReadBufferSize: customReadSize,
|
||||||
WriteBufferSize: customWriteSize,
|
WriteBufferSize: customWriteSize,
|
||||||
@@ -69,7 +69,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
|||||||
It("should handle zero buffer sizes by using defaults", func() {
|
It("should handle zero buffer sizes by using defaults", func() {
|
||||||
connPool = pool.NewConnPool(&pool.Options{
|
connPool = pool.NewConnPool(&pool.Options{
|
||||||
Dialer: dummyDialer,
|
Dialer: dummyDialer,
|
||||||
PoolSize: 1,
|
PoolSize: int32(1),
|
||||||
PoolTimeout: 1000,
|
PoolTimeout: 1000,
|
||||||
ReadBufferSize: 0, // Should use default
|
ReadBufferSize: 0, // Should use default
|
||||||
WriteBufferSize: 0, // Should use default
|
WriteBufferSize: 0, // Should use default
|
||||||
@@ -105,7 +105,7 @@ var _ = Describe("Buffer Size Configuration", func() {
|
|||||||
// without setting ReadBufferSize and WriteBufferSize
|
// without setting ReadBufferSize and WriteBufferSize
|
||||||
connPool = pool.NewConnPool(&pool.Options{
|
connPool = pool.NewConnPool(&pool.Options{
|
||||||
Dialer: dummyDialer,
|
Dialer: dummyDialer,
|
||||||
PoolSize: 1,
|
PoolSize: int32(1),
|
||||||
PoolTimeout: 1000,
|
PoolTimeout: 1000,
|
||||||
// ReadBufferSize and WriteBufferSize are not set (will be 0)
|
// ReadBufferSize and WriteBufferSize are not set (will be 0)
|
||||||
})
|
})
|
||||||
|
@@ -3,7 +3,10 @@ package pool
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -12,17 +15,64 @@ import (
|
|||||||
|
|
||||||
var noDeadline = time.Time{}
|
var noDeadline = time.Time{}
|
||||||
|
|
||||||
|
// Global atomic counter for connection IDs
|
||||||
|
var connIDCounter uint64
|
||||||
|
|
||||||
|
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
|
||||||
|
type atomicNetConn struct {
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateConnID generates a fast unique identifier for a connection with zero allocations
|
||||||
|
func generateConnID() uint64 {
|
||||||
|
return atomic.AddUint64(&connIDCounter, 1)
|
||||||
|
}
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
usedAt int64 // atomic
|
usedAt int64 // atomic
|
||||||
netConn net.Conn
|
|
||||||
|
// Lock-free netConn access using atomic.Value
|
||||||
|
// Contains *atomicNetConn wrapper, accessed atomically for better performance
|
||||||
|
netConnAtomic atomic.Value // stores *atomicNetConn
|
||||||
|
|
||||||
rd *proto.Reader
|
rd *proto.Reader
|
||||||
bw *bufio.Writer
|
bw *bufio.Writer
|
||||||
wr *proto.Writer
|
wr *proto.Writer
|
||||||
|
|
||||||
Inited bool
|
// Lightweight mutex to protect reader operations during handoff
|
||||||
|
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
|
||||||
|
readerMu sync.RWMutex
|
||||||
|
|
||||||
|
Inited atomic.Bool
|
||||||
pooled bool
|
pooled bool
|
||||||
|
closed atomic.Bool
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
|
expiresAt time.Time
|
||||||
|
|
||||||
|
// Hitless upgrade support: relaxed timeouts during migrations/failovers
|
||||||
|
// Using atomic operations for lock-free access to avoid mutex contention
|
||||||
|
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||||
|
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||||
|
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
|
||||||
|
|
||||||
|
// Counter tracking nested relaxed-timeout setters.
|
||||||
|
// It is decremented when ClearRelaxedTimeout is called or the deadline is reached;
|
||||||
|
// when it reaches 0, the relaxed timeouts are cleared.
|
||||||
|
relaxedCounter atomic.Int32
|
||||||
|
|
||||||
|
// Connection initialization function for reconnections
|
||||||
|
initConnFunc func(context.Context, *Conn) error
|
||||||
|
|
||||||
|
// Connection identifier for unique tracking across handoffs
|
||||||
|
id uint64 // Unique numeric identifier for this connection
|
||||||
|
|
||||||
|
// Handoff state - using atomic operations for lock-free access
|
||||||
|
usableAtomic atomic.Bool // Connection usability state
|
||||||
|
shouldHandoffAtomic atomic.Bool // Whether connection should be handed off
|
||||||
|
movingSeqIDAtomic atomic.Int64 // Sequence ID from MOVING notification
|
||||||
|
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
|
||||||
|
// newEndpointAtomic needs special handling as it's a string
|
||||||
|
newEndpointAtomic atomic.Value // stores string
|
||||||
|
|
||||||
onClose func() error
|
onClose func() error
|
||||||
}
|
}
|
||||||
@@ -33,8 +83,8 @@ func NewConn(netConn net.Conn) *Conn {
|
|||||||
|
|
||||||
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
|
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
|
||||||
cn := &Conn{
|
cn := &Conn{
|
||||||
netConn: netConn,
|
|
||||||
createdAt: time.Now(),
|
createdAt: time.Now(),
|
||||||
|
id: generateConnID(), // Generate unique ID for this connection
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
|
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
|
||||||
@@ -50,6 +100,16 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
|
|||||||
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
|
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store netConn atomically for lock-free access using wrapper
|
||||||
|
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||||
|
|
||||||
|
// Initialize atomic handoff state
|
||||||
|
cn.usableAtomic.Store(false) // false initially, set to true after initialization
|
||||||
|
cn.shouldHandoffAtomic.Store(false) // false initially
|
||||||
|
cn.movingSeqIDAtomic.Store(0) // 0 initially
|
||||||
|
cn.handoffRetriesAtomic.Store(0) // 0 initially
|
||||||
|
cn.newEndpointAtomic.Store("") // empty string initially
|
||||||
|
|
||||||
cn.wr = proto.NewWriter(cn.bw)
|
cn.wr = proto.NewWriter(cn.bw)
|
||||||
cn.SetUsedAt(time.Now())
|
cn.SetUsedAt(time.Now())
|
||||||
return cn
|
return cn
|
||||||
@@ -64,23 +124,368 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
|
|||||||
atomic.StoreInt64(&cn.usedAt, tm.Unix())
|
atomic.StoreInt64(&cn.usedAt, tm.Unix())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getNetConn returns the current network connection using atomic load (lock-free).
|
||||||
|
// This is the fast path for accessing netConn without mutex overhead.
|
||||||
|
func (cn *Conn) getNetConn() net.Conn {
|
||||||
|
if v := cn.netConnAtomic.Load(); v != nil {
|
||||||
|
if wrapper, ok := v.(*atomicNetConn); ok {
|
||||||
|
return wrapper.conn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setNetConn stores the network connection atomically (lock-free).
|
||||||
|
// This is used for the fast path of connection replacement.
|
||||||
|
func (cn *Conn) setNetConn(netConn net.Conn) {
|
||||||
|
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lock-free helper methods for handoff state management
|
||||||
|
|
||||||
|
// isUsable returns true if the connection is safe to use (lock-free).
|
||||||
|
func (cn *Conn) isUsable() bool {
|
||||||
|
return cn.usableAtomic.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// setUsable sets the usable flag atomically (lock-free).
|
||||||
|
func (cn *Conn) setUsable(usable bool) {
|
||||||
|
cn.usableAtomic.Store(usable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldHandoff returns true if connection needs handoff (lock-free).
|
||||||
|
func (cn *Conn) shouldHandoff() bool {
|
||||||
|
return cn.shouldHandoffAtomic.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// setShouldHandoff sets the handoff flag atomically (lock-free).
|
||||||
|
func (cn *Conn) setShouldHandoff(should bool) {
|
||||||
|
cn.shouldHandoffAtomic.Store(should)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMovingSeqID returns the sequence ID atomically (lock-free).
|
||||||
|
func (cn *Conn) getMovingSeqID() int64 {
|
||||||
|
return cn.movingSeqIDAtomic.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// setMovingSeqID sets the sequence ID atomically (lock-free).
|
||||||
|
func (cn *Conn) setMovingSeqID(seqID int64) {
|
||||||
|
cn.movingSeqIDAtomic.Store(seqID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNewEndpoint returns the new endpoint atomically (lock-free).
|
||||||
|
func (cn *Conn) getNewEndpoint() string {
|
||||||
|
if endpoint := cn.newEndpointAtomic.Load(); endpoint != nil {
|
||||||
|
return endpoint.(string)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// setNewEndpoint sets the new endpoint atomically (lock-free).
|
||||||
|
func (cn *Conn) setNewEndpoint(endpoint string) {
|
||||||
|
cn.newEndpointAtomic.Store(endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setHandoffRetries sets the retry count atomically (lock-free).
|
||||||
|
func (cn *Conn) setHandoffRetries(retries int) {
|
||||||
|
cn.handoffRetriesAtomic.Store(uint32(retries))
|
||||||
|
}
|
||||||
|
|
||||||
|
// incrementHandoffRetries atomically increments and returns the new retry count (lock-free).
|
||||||
|
func (cn *Conn) incrementHandoffRetries(delta int) int {
|
||||||
|
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
|
||||||
|
func (cn *Conn) IsUsable() bool {
|
||||||
|
return cn.isUsable()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cn *Conn) IsInited() bool {
|
||||||
|
return cn.Inited.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUsable sets the usable flag for the connection (lock-free).
|
||||||
|
func (cn *Conn) SetUsable(usable bool) {
|
||||||
|
cn.setUsable(usable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
|
||||||
|
// These timeouts will be used for all subsequent commands until the deadline expires.
|
||||||
|
// Uses atomic operations for lock-free access.
|
||||||
|
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
|
||||||
|
cn.relaxedCounter.Add(1)
|
||||||
|
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
|
||||||
|
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
|
||||||
|
// After the deadline, timeouts automatically revert to normal values.
|
||||||
|
// Uses atomic operations for lock-free access.
|
||||||
|
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
|
||||||
|
cn.relaxedCounter.Add(1)
|
||||||
|
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
|
||||||
|
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
|
||||||
|
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
|
||||||
|
// Uses atomic operations for lock-free access.
|
||||||
|
func (cn *Conn) ClearRelaxedTimeout() {
|
||||||
|
// Atomically decrement counter and check if we should clear
|
||||||
|
newCount := cn.relaxedCounter.Add(-1)
|
||||||
|
if newCount <= 0 {
|
||||||
|
// Use compare-and-swap to ensure only one goroutine clears
|
||||||
|
if cn.relaxedCounter.CompareAndSwap(newCount, 0) {
|
||||||
|
cn.clearRelaxedTimeout()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cn *Conn) clearRelaxedTimeout() {
|
||||||
|
cn.relaxedReadTimeoutNs.Store(0)
|
||||||
|
cn.relaxedWriteTimeoutNs.Store(0)
|
||||||
|
cn.relaxedDeadlineNs.Store(0)
|
||||||
|
cn.relaxedCounter.Store(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
|
||||||
|
// This checks both the timeout values and the deadline (if set).
|
||||||
|
// Uses atomic operations for lock-free access.
|
||||||
|
func (cn *Conn) HasRelaxedTimeout() bool {
|
||||||
|
// Fast path: no relaxed timeouts are set
|
||||||
|
if cn.relaxedCounter.Load() <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||||
|
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||||
|
|
||||||
|
// If no relaxed timeouts are set, return false
|
||||||
|
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||||
|
// If no deadline is set, relaxed timeouts are active
|
||||||
|
if deadlineNs == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If deadline is set, check if it's still in the future
|
||||||
|
return time.Now().UnixNano() < deadlineNs
|
||||||
|
}
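// Illustration only (not part of this diff): a hitless upgrade handler might
// relax timeouts for the duration of a migration and let the deadline restore
// them automatically. The durations below are assumptions made for this sketch.
//
//	func onMigrationStarted(cn *Conn) {
//		// Give in-flight commands extra headroom while the endpoint is moving.
//		cn.SetRelaxedTimeoutWithDeadline(30*time.Second, 30*time.Second, time.Now().Add(time.Minute))
//	}
//
//	func onMigrationCompleted(cn *Conn) {
//		cn.ClearRelaxedTimeout() // drop back to the configured timeouts right away
//	}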
|
||||||
|
|
||||||
|
// getEffectiveReadTimeout returns the timeout to use for read operations.
|
||||||
|
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||||
|
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||||
|
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
|
||||||
|
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||||
|
|
||||||
|
// Fast path: no relaxed timeout set
|
||||||
|
if readTimeoutNs <= 0 {
|
||||||
|
return normalTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||||
|
// If no deadline is set, use relaxed timeout
|
||||||
|
if deadlineNs == 0 {
|
||||||
|
return time.Duration(readTimeoutNs)
|
||||||
|
}
|
||||||
|
|
||||||
|
nowNs := time.Now().UnixNano()
|
||||||
|
// Check if deadline has passed
|
||||||
|
if nowNs < deadlineNs {
|
||||||
|
// Deadline is in the future, use relaxed timeout
|
||||||
|
return time.Duration(readTimeoutNs)
|
||||||
|
} else {
|
||||||
|
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||||
|
cn.relaxedCounter.Add(-1)
|
||||||
|
if cn.relaxedCounter.Load() <= 0 {
|
||||||
|
cn.clearRelaxedTimeout()
|
||||||
|
}
|
||||||
|
return normalTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEffectiveWriteTimeout returns the timeout to use for write operations.
|
||||||
|
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||||
|
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||||
|
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
|
||||||
|
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||||
|
|
||||||
|
// Fast path: no relaxed timeout set
|
||||||
|
if writeTimeoutNs <= 0 {
|
||||||
|
return normalTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||||
|
// If no deadline is set, use relaxed timeout
|
||||||
|
if deadlineNs == 0 {
|
||||||
|
return time.Duration(writeTimeoutNs)
|
||||||
|
}
|
||||||
|
|
||||||
|
nowNs := time.Now().UnixNano()
|
||||||
|
// Check if deadline has passed
|
||||||
|
if nowNs < deadlineNs {
|
||||||
|
// Deadline is in the future, use relaxed timeout
|
||||||
|
return time.Duration(writeTimeoutNs)
|
||||||
|
} else {
|
||||||
|
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||||
|
cn.relaxedCounter.Add(-1)
|
||||||
|
if cn.relaxedCounter.Load() <= 0 {
|
||||||
|
cn.clearRelaxedTimeout()
|
||||||
|
}
|
||||||
|
return normalTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (cn *Conn) SetOnClose(fn func() error) {
|
func (cn *Conn) SetOnClose(fn func() error) {
|
||||||
cn.onClose = fn
|
cn.onClose = fn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
|
||||||
|
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
|
||||||
|
cn.initConnFunc = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteInitConn runs the stored connection initialization function if available.
|
||||||
|
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
|
||||||
|
if cn.initConnFunc != nil {
|
||||||
|
return cn.initConnFunc(ctx, cn)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("redis: no initConnFunc set for connection %d", cn.GetID())
|
||||||
|
}
|
||||||
|
|
||||||
func (cn *Conn) SetNetConn(netConn net.Conn) {
|
func (cn *Conn) SetNetConn(netConn net.Conn) {
|
||||||
cn.netConn = netConn
|
// Store the new connection atomically first (lock-free)
|
||||||
|
cn.setNetConn(netConn)
|
||||||
|
// Clear relaxed timeouts when connection is replaced
|
||||||
|
cn.clearRelaxedTimeout()
|
||||||
|
|
||||||
|
// Protect reader reset operations to avoid data races
|
||||||
|
// Use write lock since we're modifying the reader state
|
||||||
|
cn.readerMu.Lock()
|
||||||
cn.rd.Reset(netConn)
|
cn.rd.Reset(netConn)
|
||||||
|
cn.readerMu.Unlock()
|
||||||
|
|
||||||
cn.bw.Reset(netConn)
|
cn.bw.Reset(netConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetNetConn safely returns the current network connection using atomic load (lock-free).
|
||||||
|
// This method is used by the pool for health checks and provides better performance.
|
||||||
|
func (cn *Conn) GetNetConn() net.Conn {
|
||||||
|
return cn.getNetConn()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
|
||||||
|
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
|
||||||
|
// New connection is not initialized yet
|
||||||
|
cn.Inited.Store(false)
|
||||||
|
// Replace the underlying connection
|
||||||
|
cn.SetNetConn(netConn)
|
||||||
|
return cn.ExecuteInitConn(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
|
||||||
|
// Returns an error if the connection is already marked for handoff.
|
||||||
|
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
|
||||||
|
// Use single atomic CAS operation for state transition
|
||||||
|
if !cn.shouldHandoffAtomic.CompareAndSwap(false, true) {
|
||||||
|
return errors.New("connection is already marked for handoff")
|
||||||
|
}
|
||||||
|
|
||||||
|
cn.setNewEndpoint(newEndpoint)
|
||||||
|
cn.setMovingSeqID(seqID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cn *Conn) MarkQueuedForHandoff() error {
|
||||||
|
// Use single atomic CAS operation for state transition
|
||||||
|
if !cn.shouldHandoffAtomic.CompareAndSwap(true, false) {
|
||||||
|
return errors.New("connection was not marked for handoff")
|
||||||
|
}
|
||||||
|
cn.setUsable(false)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreHandoffState restores the handoff state after a failed handoff (lock-free).
|
||||||
|
func (cn *Conn) RestoreHandoffState() {
|
||||||
|
// Restore shouldHandoff flag for retry
|
||||||
|
cn.shouldHandoffAtomic.Store(true)
|
||||||
|
// Keep usable=false to prevent the connection from being used until handoff succeeds
|
||||||
|
cn.setUsable(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
|
||||||
|
func (cn *Conn) ShouldHandoff() bool {
|
||||||
|
return cn.shouldHandoff()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
|
||||||
|
func (cn *Conn) GetHandoffEndpoint() string {
|
||||||
|
return cn.getNewEndpoint()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
|
||||||
|
func (cn *Conn) GetMovingSeqID() int64 {
|
||||||
|
return cn.getMovingSeqID()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID returns the unique identifier for this connection.
|
||||||
|
func (cn *Conn) GetID() uint64 {
|
||||||
|
return cn.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearHandoffState clears the handoff state after successful handoff (lock-free).
|
||||||
|
func (cn *Conn) ClearHandoffState() {
|
||||||
|
// clear handoff state
|
||||||
|
cn.setShouldHandoff(false)
|
||||||
|
cn.setNewEndpoint("")
|
||||||
|
cn.setMovingSeqID(0)
|
||||||
|
cn.setHandoffRetries(0)
|
||||||
|
cn.setUsable(true) // Connection is safe to use again after handoff completes
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
|
||||||
|
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
|
||||||
|
return cn.incrementHandoffRetries(n)
|
||||||
|
}
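// Illustration only (not part of this diff): the intended handoff lifecycle,
// as driven by the hitless pool hook; the ordering below is an assumption
// derived from the methods defined in this file.
//
//	_ = cn.MarkForHandoff("new-endpoint:6379", seqID) // MOVING notification received
//	_ = cn.MarkQueuedForHandoff()                     // worker dequeued it; conn is now unusable
//	if err := cn.SetNetConnAndInitConn(ctx, newNetConn); err != nil {
//		cn.RestoreHandoffState() // keep it unusable and retry later
//	} else {
//		cn.ClearHandoffState() // usable again on the new endpoint
//	}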
|
||||||
|
|
||||||
|
// HasBufferedData safely checks if the connection has buffered data.
|
||||||
|
// This method is used to avoid data races when checking for push notifications.
|
||||||
|
func (cn *Conn) HasBufferedData() bool {
|
||||||
|
// Use read lock for concurrent access to reader state
|
||||||
|
cn.readerMu.RLock()
|
||||||
|
defer cn.readerMu.RUnlock()
|
||||||
|
return cn.rd.Buffered() > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeekReplyTypeSafe safely peeks at the reply type.
|
||||||
|
// This method is used to avoid data races when checking for push notifications.
|
||||||
|
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
|
||||||
|
// Use read lock for concurrent access to reader state
|
||||||
|
cn.readerMu.RLock()
|
||||||
|
defer cn.readerMu.RUnlock()
|
||||||
|
|
||||||
|
if cn.rd.Buffered() <= 0 {
|
||||||
|
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
|
||||||
|
}
|
||||||
|
return cn.rd.PeekReplyType()
|
||||||
|
}
|
||||||
|
|
||||||
func (cn *Conn) Write(b []byte) (int, error) {
|
func (cn *Conn) Write(b []byte) (int, error) {
|
||||||
return cn.netConn.Write(b)
|
// Lock-free netConn access for better performance
|
||||||
|
if netConn := cn.getNetConn(); netConn != nil {
|
||||||
|
return netConn.Write(b)
|
||||||
|
}
|
||||||
|
return 0, net.ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cn *Conn) RemoteAddr() net.Addr {
|
func (cn *Conn) RemoteAddr() net.Addr {
|
||||||
if cn.netConn != nil {
|
// Lock-free netConn access for better performance
|
||||||
return cn.netConn.RemoteAddr()
|
if netConn := cn.getNetConn(); netConn != nil {
|
||||||
|
return netConn.RemoteAddr()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -89,7 +494,16 @@ func (cn *Conn) WithReader(
|
|||||||
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
|
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
|
||||||
) error {
|
) error {
|
||||||
if timeout >= 0 {
|
if timeout >= 0 {
|
||||||
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
|
// Use relaxed timeout if set, otherwise use provided timeout
|
||||||
|
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
|
||||||
|
|
||||||
|
// Get the connection directly from atomic storage
|
||||||
|
netConn := cn.getNetConn()
|
||||||
|
if netConn == nil {
|
||||||
|
return fmt.Errorf("redis: connection not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -100,13 +514,26 @@ func (cn *Conn) WithWriter(
|
|||||||
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
|
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
|
||||||
) error {
|
) error {
|
||||||
if timeout >= 0 {
|
if timeout >= 0 {
|
||||||
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
|
// Use relaxed timeout if set, otherwise use provided timeout
|
||||||
return err
|
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
|
||||||
|
|
||||||
|
// Set a write deadline on the current connection; if no connection is available,
|
||||||
|
// fail fast below instead of letting the write hang indefinitely.
|
||||||
|
if netConn := cn.getNetConn(); netConn != nil {
|
||||||
|
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If getNetConn() returns nil, we still need to respect the timeout
|
||||||
|
// Return an error to prevent indefinite blocking
|
||||||
|
return fmt.Errorf("redis: connection not available for write operation")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if cn.bw.Buffered() > 0 {
|
if cn.bw.Buffered() > 0 {
|
||||||
cn.bw.Reset(cn.netConn)
|
if netConn := cn.getNetConn(); netConn != nil {
|
||||||
|
cn.bw.Reset(netConn)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fn(cn.wr); err != nil {
|
if err := fn(cn.wr); err != nil {
|
||||||
@@ -116,19 +543,33 @@ func (cn *Conn) WithWriter(
|
|||||||
return cn.bw.Flush()
|
return cn.bw.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cn *Conn) IsClosed() bool {
|
||||||
|
return cn.closed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
func (cn *Conn) Close() error {
|
func (cn *Conn) Close() error {
|
||||||
|
cn.closed.Store(true)
|
||||||
if cn.onClose != nil {
|
if cn.onClose != nil {
|
||||||
// ignore error
|
// ignore error
|
||||||
_ = cn.onClose()
|
_ = cn.onClose()
|
||||||
}
|
}
|
||||||
return cn.netConn.Close()
|
|
||||||
|
// Lock-free netConn access for better performance
|
||||||
|
if netConn := cn.getNetConn(); netConn != nil {
|
||||||
|
return netConn.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaybeHasData tries to peek at the next byte in the socket without consuming it
|
// MaybeHasData tries to peek at the next byte in the socket without consuming it
|
||||||
// This is used to check if there are push notifications available
|
// This is used to check if there are push notifications available
|
||||||
// Important: This will work on Linux, but not on Windows
|
// Important: This will work on Linux, but not on Windows
|
||||||
func (cn *Conn) MaybeHasData() bool {
|
func (cn *Conn) MaybeHasData() bool {
|
||||||
return maybeHasData(cn.netConn)
|
// Lock-free netConn access for better performance
|
||||||
|
if netConn := cn.getNetConn(); netConn != nil {
|
||||||
|
return maybeHasData(netConn)
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
|
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
|
||||||
|
@@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cn *Conn) NetConn() net.Conn {
|
func (cn *Conn) NetConn() net.Conn {
|
||||||
return cn.netConn
|
return cn.getNetConn()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) CheckMinIdleConns() {
|
func (p *ConnPool) CheckMinIdleConns() {
|
||||||
|
114
internal/pool/hooks.go
Normal file
@@ -0,0 +1,114 @@
package pool

import (
	"context"
	"sync"
)

// PoolHook defines the interface for connection lifecycle hooks.
type PoolHook interface {
	// OnGet is called when a connection is retrieved from the pool.
	// It can modify the connection or return an error to prevent its use.
	// The isNewConn flag indicates whether the connection was just created
	// (rather than taken from the idle pool) and can be used to gather
	// metrics on the pool hit/miss ratio.
	OnGet(ctx context.Context, conn *Conn, isNewConn bool) error

	// OnPut is called when a connection is returned to the pool.
	// It returns whether the connection should be pooled and whether it should be removed.
	OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
}

// PoolHookManager manages multiple pool hooks.
type PoolHookManager struct {
	hooks   []PoolHook
	hooksMu sync.RWMutex
}

// NewPoolHookManager creates a new pool hook manager.
func NewPoolHookManager() *PoolHookManager {
	return &PoolHookManager{
		hooks: make([]PoolHook, 0),
	}
}

// AddHook adds a pool hook to the manager.
// Hooks are called in the order they were added.
func (phm *PoolHookManager) AddHook(hook PoolHook) {
	phm.hooksMu.Lock()
	defer phm.hooksMu.Unlock()
	phm.hooks = append(phm.hooks, hook)
}

// RemoveHook removes a pool hook from the manager.
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
	phm.hooksMu.Lock()
	defer phm.hooksMu.Unlock()

	for i, h := range phm.hooks {
		if h == hook {
			// Remove hook by swapping with last element and truncating
			phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
			phm.hooks = phm.hooks[:len(phm.hooks)-1]
			break
		}
	}
}

// ProcessOnGet calls all OnGet hooks in order.
// If any hook returns an error, processing stops and the error is returned.
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
	phm.hooksMu.RLock()
	defer phm.hooksMu.RUnlock()

	for _, hook := range phm.hooks {
		if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
			return err
		}
	}
	return nil
}

// ProcessOnPut calls all OnPut hooks in order.
// A hook that returns shouldRemove=true stops processing immediately; a hook
// that returns shouldPool=false marks the connection as non-poolable, but
// later hooks still run.
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
	phm.hooksMu.RLock()
	defer phm.hooksMu.RUnlock()

	shouldPool = true // Default to pooling the connection

	for _, hook := range phm.hooks {
		hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)

		if hookErr != nil {
			return false, true, hookErr
		}

		// If any hook says to remove or not pool, respect that decision
		if hookShouldRemove {
			return false, true, nil
		}

		if !hookShouldPool {
			shouldPool = false
		}
	}

	return shouldPool, false, nil
}

// GetHookCount returns the number of registered hooks (for testing).
func (phm *PoolHookManager) GetHookCount() int {
	phm.hooksMu.RLock()
	defer phm.hooksMu.RUnlock()
	return len(phm.hooks)
}

// GetHooks returns a copy of all registered hooks.
func (phm *PoolHookManager) GetHooks() []PoolHook {
	phm.hooksMu.RLock()
	defer phm.hooksMu.RUnlock()

	hooks := make([]PoolHook, len(phm.hooks))
	copy(hooks, phm.hooks)
	return hooks
}
213
internal/pool/hooks_test.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package pool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestHook for testing hook functionality
|
||||||
|
type TestHook struct {
|
||||||
|
OnGetCalled int
|
||||||
|
OnPutCalled int
|
||||||
|
GetError error
|
||||||
|
PutError error
|
||||||
|
ShouldPool bool
|
||||||
|
ShouldRemove bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
|
||||||
|
th.OnGetCalled++
|
||||||
|
return th.GetError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||||
|
th.OnPutCalled++
|
||||||
|
return th.ShouldPool, th.ShouldRemove, th.PutError
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolHookManager(t *testing.T) {
|
||||||
|
manager := NewPoolHookManager()
|
||||||
|
|
||||||
|
// Test initial state
|
||||||
|
if manager.GetHookCount() != 0 {
|
||||||
|
t.Errorf("Expected 0 hooks initially, got %d", manager.GetHookCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add hooks
|
||||||
|
hook1 := &TestHook{ShouldPool: true}
|
||||||
|
hook2 := &TestHook{ShouldPool: true}
|
||||||
|
|
||||||
|
manager.AddHook(hook1)
|
||||||
|
manager.AddHook(hook2)
|
||||||
|
|
||||||
|
if manager.GetHookCount() != 2 {
|
||||||
|
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test ProcessOnGet
|
||||||
|
ctx := context.Background()
|
||||||
|
conn := &Conn{} // Mock connection
|
||||||
|
|
||||||
|
err := manager.ProcessOnGet(ctx, conn, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ProcessOnGet should not error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hook1.OnGetCalled != 1 {
|
||||||
|
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hook2.OnGetCalled != 1 {
|
||||||
|
t.Errorf("Expected hook2.OnGetCalled to be 1, got %d", hook2.OnGetCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test ProcessOnPut
|
||||||
|
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ProcessOnPut should not error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !shouldPool {
|
||||||
|
t.Error("Expected shouldPool to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldRemove {
|
||||||
|
t.Error("Expected shouldRemove to be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if hook1.OnPutCalled != 1 {
|
||||||
|
t.Errorf("Expected hook1.OnPutCalled to be 1, got %d", hook1.OnPutCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hook2.OnPutCalled != 1 {
|
||||||
|
t.Errorf("Expected hook2.OnPutCalled to be 1, got %d", hook2.OnPutCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove a hook
|
||||||
|
manager.RemoveHook(hook1)
|
||||||
|
|
||||||
|
if manager.GetHookCount() != 1 {
|
||||||
|
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHookErrorHandling(t *testing.T) {
|
||||||
|
manager := NewPoolHookManager()
|
||||||
|
|
||||||
|
// Hook that returns error on Get
|
||||||
|
errorHook := &TestHook{
|
||||||
|
GetError: errors.New("test error"),
|
||||||
|
ShouldPool: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
normalHook := &TestHook{ShouldPool: true}
|
||||||
|
|
||||||
|
manager.AddHook(errorHook)
|
||||||
|
manager.AddHook(normalHook)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
conn := &Conn{}
|
||||||
|
|
||||||
|
// Test that error stops processing
|
||||||
|
err := manager.ProcessOnGet(ctx, conn, false)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error from ProcessOnGet")
|
||||||
|
}
|
||||||
|
|
||||||
|
if errorHook.OnGetCalled != 1 {
|
||||||
|
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalHook should not be called due to error
|
||||||
|
if normalHook.OnGetCalled != 0 {
|
||||||
|
t.Errorf("Expected normalHook.OnGetCalled to be 0, got %d", normalHook.OnGetCalled)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHookShouldRemove(t *testing.T) {
|
||||||
|
manager := NewPoolHookManager()
|
||||||
|
|
||||||
|
// Hook that says to remove connection
|
||||||
|
removeHook := &TestHook{
|
||||||
|
ShouldPool: false,
|
||||||
|
ShouldRemove: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
normalHook := &TestHook{ShouldPool: true}
|
||||||
|
|
||||||
|
manager.AddHook(removeHook)
|
||||||
|
manager.AddHook(normalHook)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
conn := &Conn{}
|
||||||
|
|
||||||
|
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ProcessOnPut should not error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldPool {
|
||||||
|
t.Error("Expected shouldPool to be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !shouldRemove {
|
||||||
|
t.Error("Expected shouldRemove to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if removeHook.OnPutCalled != 1 {
|
||||||
|
t.Errorf("Expected removeHook.OnPutCalled to be 1, got %d", removeHook.OnPutCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalHook should not be called due to early return
|
||||||
|
if normalHook.OnPutCalled != 0 {
|
||||||
|
t.Errorf("Expected normalHook.OnPutCalled to be 0, got %d", normalHook.OnPutCalled)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolWithHooks(t *testing.T) {
|
||||||
|
// Create a pool with hooks
|
||||||
|
hookManager := NewPoolHookManager()
|
||||||
|
testHook := &TestHook{ShouldPool: true}
|
||||||
|
hookManager.AddHook(testHook)
|
||||||
|
|
||||||
|
opt := &Options{
|
||||||
|
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||||
|
return &net.TCPConn{}, nil // Mock connection
|
||||||
|
},
|
||||||
|
PoolSize: 1,
|
||||||
|
DialTimeout: time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := NewConnPool(opt)
|
||||||
|
defer pool.Close()
|
||||||
|
|
||||||
|
// Add hook to pool after creation
|
||||||
|
pool.AddPoolHook(testHook)
|
||||||
|
|
||||||
|
// Verify hooks are initialized
|
||||||
|
if pool.hookManager == nil {
|
||||||
|
t.Error("Expected hookManager to be initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if pool.hookManager.GetHookCount() != 1 {
|
||||||
|
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test adding hook to pool
|
||||||
|
additionalHook := &TestHook{ShouldPool: true}
|
||||||
|
pool.AddPoolHook(additionalHook)
|
||||||
|
|
||||||
|
if pool.hookManager.GetHookCount() != 2 {
|
||||||
|
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test removing hook from pool
|
||||||
|
pool.RemovePoolHook(additionalHook)
|
||||||
|
|
||||||
|
if pool.hookManager.GetHookCount() != 1 {
|
||||||
|
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount())
|
||||||
|
}
|
||||||
|
}
|
@@ -3,6 +3,7 @@ package pool
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -22,6 +23,12 @@ var (
|
|||||||
|
|
||||||
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
|
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
|
||||||
ErrPoolTimeout = errors.New("redis: connection pool timeout")
|
ErrPoolTimeout = errors.New("redis: connection pool timeout")
|
||||||
|
|
||||||
|
popAttempts = 10
|
||||||
|
getAttempts = 3
|
||||||
|
minTime = time.Unix(-2208988800, 0) // Jan 1, 1900
|
||||||
|
maxTime = minTime.Add(1<<63 - 1)
|
||||||
|
noExpiration = maxTime
|
||||||
)
|
)
|
||||||
|
|
||||||
var timers = sync.Pool{
|
var timers = sync.Pool{
|
||||||
@@ -38,11 +45,14 @@ type Stats struct {
|
|||||||
Misses uint32 // number of times free connection was NOT found in the pool
|
Misses uint32 // number of times free connection was NOT found in the pool
|
||||||
Timeouts uint32 // number of times a wait timeout occurred
|
Timeouts uint32 // number of times a wait timeout occurred
|
||||||
WaitCount uint32 // number of times a connection was waited
|
WaitCount uint32 // number of times a connection was waited
|
||||||
|
Unusable uint32 // number of times a connection was found to be unusable
|
||||||
WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds
|
WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds
|
||||||
|
|
||||||
TotalConns uint32 // number of total connections in the pool
|
TotalConns uint32 // number of total connections in the pool
|
||||||
IdleConns uint32 // number of idle connections in the pool
|
IdleConns uint32 // number of idle connections in the pool
|
||||||
StaleConns uint32 // number of stale connections removed from the pool
|
StaleConns uint32 // number of stale connections removed from the pool
|
||||||
|
|
||||||
|
PubSubStats PubSubStats
|
||||||
}
|
}
|
||||||
|
|
||||||
type Pooler interface {
|
type Pooler interface {
|
||||||
@@ -57,29 +67,27 @@ type Pooler interface {
|
|||||||
IdleLen() int
|
IdleLen() int
|
||||||
Stats() *Stats
|
Stats() *Stats
|
||||||
|
|
||||||
|
AddPoolHook(hook PoolHook)
|
||||||
|
RemovePoolHook(hook PoolHook)
|
||||||
|
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
Dialer func(context.Context) (net.Conn, error)
|
Dialer func(context.Context) (net.Conn, error)
|
||||||
|
|
||||||
PoolFIFO bool
|
|
||||||
PoolSize int
|
|
||||||
DialTimeout time.Duration
|
|
||||||
PoolTimeout time.Duration
|
|
||||||
MinIdleConns int
|
|
||||||
MaxIdleConns int
|
|
||||||
MaxActiveConns int
|
|
||||||
ConnMaxIdleTime time.Duration
|
|
||||||
ConnMaxLifetime time.Duration
|
|
||||||
|
|
||||||
|
|
||||||
// Protocol version for optimization (3 = RESP3 with push notifications, 2 = RESP2 without)
|
|
||||||
Protocol int
|
|
||||||
|
|
||||||
ReadBufferSize int
|
ReadBufferSize int
|
||||||
WriteBufferSize int
|
WriteBufferSize int
|
||||||
|
|
||||||
|
PoolFIFO bool
|
||||||
|
PoolSize int32
|
||||||
|
DialTimeout time.Duration
|
||||||
|
PoolTimeout time.Duration
|
||||||
|
MinIdleConns int32
|
||||||
|
MaxIdleConns int32
|
||||||
|
MaxActiveConns int32
|
||||||
|
ConnMaxIdleTime time.Duration
|
||||||
|
ConnMaxLifetime time.Duration
|
||||||
|
PushNotificationsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type lastDialErrorWrap struct {
|
type lastDialErrorWrap struct {
|
||||||
@@ -95,16 +103,21 @@ type ConnPool struct {
|
|||||||
queue chan struct{}
|
queue chan struct{}
|
||||||
|
|
||||||
connsMu sync.Mutex
|
connsMu sync.Mutex
|
||||||
conns []*Conn
|
conns map[uint64]*Conn
|
||||||
idleConns []*Conn
|
idleConns []*Conn
|
||||||
|
|
||||||
poolSize int
|
poolSize atomic.Int32
|
||||||
idleConnsLen int
|
idleConnsLen atomic.Int32
|
||||||
|
idleCheckInProgress atomic.Bool
|
||||||
|
|
||||||
stats Stats
|
stats Stats
|
||||||
waitDurationNs atomic.Int64
|
waitDurationNs atomic.Int64
|
||||||
|
|
||||||
_closed uint32 // atomic
|
_closed uint32 // atomic
|
||||||
|
|
||||||
|
// Pool hooks manager for flexible connection processing
|
||||||
|
hookManagerMu sync.RWMutex
|
||||||
|
hookManager *PoolHookManager
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Pooler = (*ConnPool)(nil)
|
var _ Pooler = (*ConnPool)(nil)
|
||||||
@@ -114,34 +127,69 @@ func NewConnPool(opt *Options) *ConnPool {
|
|||||||
cfg: opt,
|
cfg: opt,
|
||||||
|
|
||||||
queue: make(chan struct{}, opt.PoolSize),
|
queue: make(chan struct{}, opt.PoolSize),
|
||||||
conns: make([]*Conn, 0, opt.PoolSize),
|
conns: make(map[uint64]*Conn),
|
||||||
idleConns: make([]*Conn, 0, opt.PoolSize),
|
idleConns: make([]*Conn, 0, opt.PoolSize),
|
||||||
}
|
}
|
||||||
|
|
||||||
p.connsMu.Lock()
|
// Only create MinIdleConns if explicitly requested (> 0)
|
||||||
p.checkMinIdleConns()
|
// This avoids creating connections during pool initialization for tests
|
||||||
p.connsMu.Unlock()
|
if opt.MinIdleConns > 0 {
|
||||||
|
p.connsMu.Lock()
|
||||||
|
p.checkMinIdleConns()
|
||||||
|
p.connsMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initializeHooks sets up the pool hooks system.
|
||||||
|
func (p *ConnPool) initializeHooks() {
|
||||||
|
p.hookManager = NewPoolHookManager()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPoolHook adds a pool hook to the pool.
|
||||||
|
func (p *ConnPool) AddPoolHook(hook PoolHook) {
|
||||||
|
p.hookManagerMu.Lock()
|
||||||
|
defer p.hookManagerMu.Unlock()
|
||||||
|
|
||||||
|
if p.hookManager == nil {
|
||||||
|
p.initializeHooks()
|
||||||
|
}
|
||||||
|
p.hookManager.AddHook(hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePoolHook removes a pool hook from the pool.
|
||||||
|
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
|
||||||
|
p.hookManagerMu.Lock()
|
||||||
|
defer p.hookManagerMu.Unlock()
|
||||||
|
|
||||||
|
if p.hookManager != nil {
|
||||||
|
p.hookManager.RemoveHook(hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ConnPool) checkMinIdleConns() {
|
func (p *ConnPool) checkMinIdleConns() {
|
||||||
|
if !p.idleCheckInProgress.CompareAndSwap(false, true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer p.idleCheckInProgress.Store(false)
|
||||||
|
|
||||||
if p.cfg.MinIdleConns == 0 {
|
if p.cfg.MinIdleConns == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
|
|
||||||
|
// Only create idle connections if we haven't reached the total pool size limit
|
||||||
|
// MinIdleConns should be a subset of PoolSize, not additional connections
|
||||||
|
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
|
||||||
select {
|
select {
|
||||||
case p.queue <- struct{}{}:
|
case p.queue <- struct{}{}:
|
||||||
p.poolSize++
|
p.poolSize.Add(1)
|
||||||
p.idleConnsLen++
|
p.idleConnsLen.Add(1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
p.connsMu.Lock()
|
p.poolSize.Add(-1)
|
||||||
p.poolSize--
|
p.idleConnsLen.Add(-1)
|
||||||
p.idleConnsLen--
|
|
||||||
p.connsMu.Unlock()
|
|
||||||
|
|
||||||
p.freeTurn()
|
p.freeTurn()
|
||||||
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
|
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
|
||||||
@@ -150,12 +198,9 @@ func (p *ConnPool) checkMinIdleConns() {
|
|||||||
|
|
||||||
err := p.addIdleConn()
|
err := p.addIdleConn()
|
||||||
if err != nil && err != ErrClosed {
|
if err != nil && err != ErrClosed {
|
||||||
p.connsMu.Lock()
|
p.poolSize.Add(-1)
|
||||||
p.poolSize--
|
p.idleConnsLen.Add(-1)
|
||||||
p.idleConnsLen--
|
|
||||||
p.connsMu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
p.freeTurn()
|
p.freeTurn()
|
||||||
}()
|
}()
|
||||||
default:
|
default:
|
||||||
@@ -172,6 +217,9 @@ func (p *ConnPool) addIdleConn() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// Mark connection as usable after successful creation
|
||||||
|
// This is essential for normal pool operations
|
||||||
|
cn.SetUsable(true)
|
||||||
|
|
||||||
p.connsMu.Lock()
|
p.connsMu.Lock()
|
||||||
defer p.connsMu.Unlock()
|
defer p.connsMu.Unlock()
|
||||||
@@ -182,11 +230,15 @@ func (p *ConnPool) addIdleConn() error {
|
|||||||
return ErrClosed
|
return ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
p.conns = append(p.conns, cn)
|
p.conns[cn.GetID()] = cn
|
||||||
p.idleConns = append(p.idleConns, cn)
|
p.idleConns = append(p.idleConns, cn)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewConn creates a new connection and returns it to the user.
|
||||||
|
// The connection still counts toward MaxActiveConns, but it is not added to the pool and does not increase the pool size.
|
||||||
|
//
|
||||||
|
// NOTE: A connection obtained this way is not pooled and does not support hitless upgrades.
|
||||||
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
|
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
|
||||||
return p.newConn(ctx, false)
|
return p.newConn(ctx, false)
|
||||||
}
|
}
|
||||||
@@ -196,33 +248,42 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
|||||||
return nil, ErrClosed
|
return nil, ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
p.connsMu.Lock()
|
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
|
||||||
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
|
|
||||||
p.connsMu.Unlock()
|
|
||||||
return nil, ErrPoolExhausted
|
return nil, ErrPoolExhausted
|
||||||
}
|
}
|
||||||
p.connsMu.Unlock()
|
|
||||||
|
|
||||||
cn, err := p.dialConn(ctx, pooled)
|
cn, err := p.dialConn(ctx, pooled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// Mark connection as usable after successful creation
|
||||||
|
// This is essential for normal pool operations
|
||||||
|
cn.SetUsable(true)
|
||||||
|
|
||||||
p.connsMu.Lock()
|
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
|
||||||
defer p.connsMu.Unlock()
|
|
||||||
|
|
||||||
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
|
|
||||||
_ = cn.Close()
|
_ = cn.Close()
|
||||||
return nil, ErrPoolExhausted
|
return nil, ErrPoolExhausted
|
||||||
}
|
}
|
||||||
|
|
||||||
p.conns = append(p.conns, cn)
|
p.connsMu.Lock()
|
||||||
|
defer p.connsMu.Unlock()
|
||||||
|
if p.closed() {
|
||||||
|
_ = cn.Close()
|
||||||
|
return nil, ErrClosed
|
||||||
|
}
|
||||||
|
// Check if pool was closed while we were waiting for the lock
|
||||||
|
if p.conns == nil {
|
||||||
|
p.conns = make(map[uint64]*Conn)
|
||||||
|
}
|
||||||
|
p.conns[cn.GetID()] = cn
|
||||||
|
|
||||||
if pooled {
|
if pooled {
|
||||||
// If pool is full remove the cn on next Put.
|
// If pool is full remove the cn on next Put.
|
||||||
if p.poolSize >= p.cfg.PoolSize {
|
currentPoolSize := p.poolSize.Load()
|
||||||
|
if currentPoolSize >= int32(p.cfg.PoolSize) {
|
||||||
cn.pooled = false
|
cn.pooled = false
|
||||||
} else {
|
} else {
|
||||||
p.poolSize++
|
p.poolSize.Add(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -249,6 +310,12 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
|
|||||||
|
|
||||||
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
|
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
|
||||||
cn.pooled = pooled
|
cn.pooled = pooled
|
||||||
|
if p.cfg.ConnMaxLifetime > 0 {
|
||||||
|
cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime)
|
||||||
|
} else {
|
||||||
|
cn.expiresAt = noExpiration
|
||||||
|
}
|
||||||
|
|
||||||
return cn, nil
|
return cn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -289,6 +356,14 @@ func (p *ConnPool) getLastDialError() error {
|
|||||||
|
|
||||||
// Get returns an existing connection from the pool or creates a new one.
|
// Get returns an existing connection from the pool or creates a new one.
|
||||||
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||||
|
return p.getConn(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConn returns a connection from the pool.
|
||||||
|
func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
|
||||||
|
var cn *Conn
|
||||||
|
var err error
|
||||||
|
|
||||||
if p.closed() {
|
if p.closed() {
|
||||||
return nil, ErrClosed
|
return nil, ErrClosed
|
||||||
}
|
}
|
||||||
@@ -297,9 +372,17 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
		return nil, err
	}

+	now := time.Now()
+	attempts := 0
	for {
+		if attempts >= getAttempts {
+			log.Printf("redis: connection pool: failed to get an connection accepted by hook after %d attempts", attempts)
+			break
+		}
+		attempts++
+
		p.connsMu.Lock()
-		cn, err := p.popIdle()
+		cn, err = p.popIdle()
		p.connsMu.Unlock()

		if err != nil {
@@ -311,11 +394,25 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
			break
		}

-		if !p.isHealthyConn(cn) {
+		if !p.isHealthyConn(cn, now) {
			_ = p.CloseConn(cn)
			continue
		}

+		// Process connection using the hooks system
+		p.hookManagerMu.RLock()
+		hookManager := p.hookManager
+		p.hookManagerMu.RUnlock()
+
+		if hookManager != nil {
+			if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
+				log.Printf("redis: connection pool: failed to process idle connection by hook: %v", err)
+				// Failed to process connection, discard it
+				_ = p.CloseConn(cn)
+				continue
+			}
+		}
+
		atomic.AddUint32(&p.stats.Hits, 1)
		return cn, nil
	}
@@ -328,6 +425,20 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
		return nil, err
	}

+	// Process connection using the hooks system
+	p.hookManagerMu.RLock()
+	hookManager := p.hookManager
+	p.hookManagerMu.RUnlock()
+
+	if hookManager != nil {
+		if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
+			// Failed to process connection, discard it
+			log.Printf("redis: connection pool: failed to process new connection by hook: %v", err)
+			_ = p.CloseConn(newcn)
+			return nil, err
+		}
+	}
+
	return newcn, nil
}

@@ -356,7 +467,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error {
		}
		return ctx.Err()
	case p.queue <- struct{}{}:
-		p.waitDurationNs.Add(time.Since(start).Nanoseconds())
+		p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
		atomic.AddUint32(&p.stats.WaitCount, 1)
		if !timer.Stop() {
			<-timer.C
@@ -376,68 +487,128 @@ func (p *ConnPool) popIdle() (*Conn, error) {
	if p.closed() {
		return nil, ErrClosed
	}

	n := len(p.idleConns)
	if n == 0 {
		return nil, nil
	}

	var cn *Conn
-	if p.cfg.PoolFIFO {
-		cn = p.idleConns[0]
-		copy(p.idleConns, p.idleConns[1:])
-		p.idleConns = p.idleConns[:n-1]
-	} else {
-		idx := n - 1
-		cn = p.idleConns[idx]
-		p.idleConns = p.idleConns[:idx]
+	attempts := 0
+
+	for attempts < popAttempts {
+		if len(p.idleConns) == 0 {
+			return nil, nil
+		}
+
+		if p.cfg.PoolFIFO {
+			cn = p.idleConns[0]
+			copy(p.idleConns, p.idleConns[1:])
+			p.idleConns = p.idleConns[:len(p.idleConns)-1]
+		} else {
+			idx := len(p.idleConns) - 1
+			cn = p.idleConns[idx]
+			p.idleConns = p.idleConns[:idx]
+		}
+		attempts++
+
+		if cn.IsUsable() {
+			p.idleConnsLen.Add(-1)
+			break
+		}
+
+		// Connection is not usable, put it back in the pool
+		if p.cfg.PoolFIFO {
+			// FIFO: put at end (will be picked up last since we pop from front)
+			p.idleConns = append(p.idleConns, cn)
+		} else {
+			// LIFO: put at beginning (will be picked up last since we pop from end)
+			p.idleConns = append([]*Conn{cn}, p.idleConns...)
+		}
	}
-	p.idleConnsLen--
+
+	// If we exhausted all attempts without finding a usable connection, return nil
+	if attempts >= popAttempts {
+		log.Printf("redis: connection pool: failed to get an usable connection after %d attempts", popAttempts)
+		return nil, nil
+	}
+
	p.checkMinIdleConns()
	return cn, nil
}

func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
+	// Process connection using the hooks system
+	shouldPool := true
	shouldRemove := false
-	if cn.rd.Buffered() > 0 {
-		// Check if this might be push notification data
-		if p.cfg.Protocol == 3 {
-			// we know that there is something in the buffer, so peek at the next reply type without
-			// the potential to block and check if it's a push notification
-			if replyType, err := cn.rd.PeekReplyType(); err != nil || replyType != proto.RespPush {
-				shouldRemove = true
-			}
-		} else {
-			// not a push notification since protocol 2 doesn't support them
-			shouldRemove = true
-		}
-
-		if shouldRemove {
-			// For non-RESP3 or data that is not a push notification, buffered data is unexpected
-			internal.Logger.Printf(ctx, "Conn has unread data, closing it")
-			p.Remove(ctx, cn, BadConnError{})
+	var err error
+
+	if cn.HasBufferedData() {
+		// Peek at the reply type to check if it's a push notification
+		if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush {
+			// Not a push notification or error peeking, remove connection
+			internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it")
+			p.Remove(ctx, cn, err)
+		}
+		// It's a push notification, allow pooling (client will handle it)
+	}
+
+	p.hookManagerMu.RLock()
+	hookManager := p.hookManager
+	p.hookManagerMu.RUnlock()
+
+	if hookManager != nil {
+		shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
+		if err != nil {
+			internal.Logger.Printf(ctx, "Connection hook error: %v", err)
+			p.Remove(ctx, cn, err)
			return
		}
	}

+	// If hooks say to remove the connection, do so
+	if shouldRemove {
+		p.Remove(ctx, cn, errors.New("hook requested removal"))
+		return
+	}
+
+	// If processor says not to pool the connection, remove it
+	if !shouldPool {
+		p.Remove(ctx, cn, errors.New("hook requested no pooling"))
+		return
+	}
+
	if !cn.pooled {
-		p.Remove(ctx, cn, nil)
+		p.Remove(ctx, cn, errors.New("connection not pooled"))
		return
	}

	var shouldCloseConn bool

-	p.connsMu.Lock()
-	if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns {
-		p.idleConns = append(p.idleConns, cn)
-		p.idleConnsLen++
+	if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
+		// unusable conns are expected to become usable at some point (background process is reconnecting them)
+		// put them at the opposite end of the queue
+		if !cn.IsUsable() {
+			if p.cfg.PoolFIFO {
+				p.connsMu.Lock()
+				p.idleConns = append(p.idleConns, cn)
+				p.connsMu.Unlock()
+			} else {
+				p.connsMu.Lock()
+				p.idleConns = append([]*Conn{cn}, p.idleConns...)
+				p.connsMu.Unlock()
+			}
+		} else {
+			p.connsMu.Lock()
+			p.idleConns = append(p.idleConns, cn)
+			p.connsMu.Unlock()
+		}
+		p.idleConnsLen.Add(1)
	} else {
-		p.removeConn(cn)
+		p.removeConnWithLock(cn)
		shouldCloseConn = true
	}

-	p.connsMu.Unlock()

	p.freeTurn()

	if shouldCloseConn {
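The Get and Put paths above funnel every connection through the pool's hook manager (ProcessOnGet on checkout, ProcessOnPut on return). As a rough sketch only, a custom hook could look like the following; the method set shown here is assumed from those calls and from the AddPoolHook/RemovePoolHook methods added elsewhere in this commit, and may differ from the real PoolHook definition:

	// Assumed shape, for illustration only.
	type loggingHook struct{}

	// Assumed to run for every connection handed out by Get; isNewConn reports
	// whether the connection was just dialed. Returning an error makes the pool discard it.
	func (loggingHook) OnGet(ctx context.Context, cn *Conn, isNewConn bool) error {
		if isNewConn {
			log.Printf("pool hook: dialed conn %d", cn.GetID())
		}
		return nil
	}

	// Assumed to run on Put; the pool honors shouldPool/shouldRemove as shown above.
	func (loggingHook) OnPut(ctx context.Context, cn *Conn) (shouldPool bool, shouldRemove bool, err error) {
		return true, false, nil
	}

A hook of this kind would be attached with AddPoolHook and detached with RemovePoolHook; the no-op implementations for SingleConnPool and StickyConnPool appear further below.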
@@ -449,6 +620,9 @@ func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
	p.removeConnWithLock(cn)
	p.freeTurn()
	_ = p.closeConn(cn)
+
+	// Check if we need to create new idle connections to maintain MinIdleConns
+	p.checkMinIdleConns()
}

func (p *ConnPool) CloseConn(cn *Conn) error {
@@ -463,17 +637,13 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) {
}

func (p *ConnPool) removeConn(cn *Conn) {
-	for i, c := range p.conns {
-		if c == cn {
-			p.conns = append(p.conns[:i], p.conns[i+1:]...)
-			if cn.pooled {
-				p.poolSize--
-				p.checkMinIdleConns()
-			}
-			break
-		}
-	}
+	delete(p.conns, cn.GetID())
	atomic.AddUint32(&p.stats.StaleConns, 1)
+
+	// Decrement pool size counter when removing a connection
+	if cn.pooled {
+		p.poolSize.Add(-1)
+	}
}

func (p *ConnPool) closeConn(cn *Conn) error {
@@ -491,9 +661,9 @@ func (p *ConnPool) Len() int {
// IdleLen returns number of idle connections.
func (p *ConnPool) IdleLen() int {
	p.connsMu.Lock()
-	n := p.idleConnsLen
+	n := p.idleConnsLen.Load()
	p.connsMu.Unlock()
-	return n
+	return int(n)
}

func (p *ConnPool) Stats() *Stats {
@@ -502,6 +672,7 @@ func (p *ConnPool) Stats() *Stats {
		Misses:         atomic.LoadUint32(&p.stats.Misses),
		Timeouts:       atomic.LoadUint32(&p.stats.Timeouts),
		WaitCount:      atomic.LoadUint32(&p.stats.WaitCount),
+		Unusable:       atomic.LoadUint32(&p.stats.Unusable),
		WaitDurationNs: p.waitDurationNs.Load(),

		TotalConns: uint32(p.Len()),
@@ -542,30 +713,32 @@ func (p *ConnPool) Close() error {
		}
	}
	p.conns = nil
-	p.poolSize = 0
+	p.poolSize.Store(0)
	p.idleConns = nil
-	p.idleConnsLen = 0
+	p.idleConnsLen.Store(0)
	p.connsMu.Unlock()

	return firstErr
}

-func (p *ConnPool) isHealthyConn(cn *Conn) bool {
-	now := time.Now()
-
-	if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
+func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
+	// slight optimization, check expiresAt first.
+	if cn.expiresAt.Before(now) {
		return false
	}

	if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
		return false
	}

-	// Check connection health, but be aware of push notifications
-	if err := connCheck(cn.netConn); err != nil {
+	cn.SetUsedAt(now)
+	// Check basic connection health
+	// Use GetNetConn() to safely access netConn and avoid data races
+	if err := connCheck(cn.getNetConn()); err != nil {
		// If there's unexpected data, it might be push notifications (RESP3)
		// However, push notification processing is now handled by the client
		// before WithReader to ensure proper context is available to handlers
-		if err == errUnexpectedRead && p.cfg.Protocol == 3 {
+		if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
			// we know that there is something in the buffer, so peek at the next reply type without
			// the potential to block
			if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
@@ -579,7 +752,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
			return false
		}
	}

-	cn.SetUsedAt(now)
	return true
}

@@ -1,6 +1,8 @@
package pool

-import "context"
+import (
+	"context"
+)

type SingleConnPool struct {
	pool Pooler
@@ -56,3 +58,7 @@ func (p *SingleConnPool) IdleLen() int {
func (p *SingleConnPool) Stats() *Stats {
	return &Stats{}
}
+
+func (p *SingleConnPool) AddPoolHook(hook PoolHook) {}
+
+func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {}
@@ -199,3 +199,7 @@ func (p *StickyConnPool) IdleLen() int {
func (p *StickyConnPool) Stats() *Stats {
	return &Stats{}
}
+
+func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}
+
+func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}
@@ -2,6 +2,7 @@ package pool_test

import (
	"context"
+	"errors"
	"net"
	"sync"
	"testing"
@@ -20,7 +21,7 @@ var _ = Describe("ConnPool", func() {
	BeforeEach(func() {
		connPool = pool.NewConnPool(&pool.Options{
			Dialer:          dummyDialer,
-			PoolSize:        10,
+			PoolSize:        int32(10),
			PoolTimeout:     time.Hour,
			DialTimeout:     1 * time.Second,
			ConnMaxIdleTime: time.Millisecond,
@@ -45,11 +46,11 @@ var _ = Describe("ConnPool", func() {
				<-closedChan
				return &net.TCPConn{}, nil
			},
-			PoolSize:        10,
+			PoolSize:        int32(10),
			PoolTimeout:     time.Hour,
			DialTimeout:     1 * time.Second,
			ConnMaxIdleTime: time.Millisecond,
-			MinIdleConns:    minIdleConns,
+			MinIdleConns:    int32(minIdleConns),
		})
		wg.Wait()
		Expect(connPool.Close()).NotTo(HaveOccurred())
@@ -105,7 +106,7 @@ var _ = Describe("ConnPool", func() {
			// ok
		}

-		connPool.Remove(ctx, cn, nil)
+		connPool.Remove(ctx, cn, errors.New("test"))

		// Check that Get is unblocked.
		select {
@@ -130,8 +131,8 @@ var _ = Describe("MinIdleConns", func() {
	newConnPool := func() *pool.ConnPool {
		connPool := pool.NewConnPool(&pool.Options{
			Dialer:          dummyDialer,
-			PoolSize:        poolSize,
-			MinIdleConns:    minIdleConns,
+			PoolSize:        int32(poolSize),
+			MinIdleConns:    int32(minIdleConns),
			PoolTimeout:     100 * time.Millisecond,
			DialTimeout:     1 * time.Second,
			ConnMaxIdleTime: -1,
@@ -168,7 +169,7 @@ var _ = Describe("MinIdleConns", func() {

	Context("after Remove", func() {
		BeforeEach(func() {
-			connPool.Remove(ctx, cn, nil)
+			connPool.Remove(ctx, cn, errors.New("test"))
		})

		It("has idle connections", func() {
@@ -245,7 +246,7 @@ var _ = Describe("MinIdleConns", func() {
		BeforeEach(func() {
			perform(len(cns), func(i int) {
				mu.RLock()
-				connPool.Remove(ctx, cns[i], nil)
+				connPool.Remove(ctx, cns[i], errors.New("test"))
				mu.RUnlock()
			})

@@ -309,7 +310,7 @@ var _ = Describe("race", func() {
	It("does not happen on Get, Put, and Remove", func() {
		connPool = pool.NewConnPool(&pool.Options{
			Dialer:          dummyDialer,
-			PoolSize:        10,
+			PoolSize:        int32(10),
			PoolTimeout:     time.Minute,
			DialTimeout:     1 * time.Second,
			ConnMaxIdleTime: time.Millisecond,
@@ -328,7 +329,7 @@ var _ = Describe("race", func() {
				cn, err := connPool.Get(ctx)
				Expect(err).NotTo(HaveOccurred())
				if err == nil {
-					connPool.Remove(ctx, cn, nil)
+					connPool.Remove(ctx, cn, errors.New("test"))
				}
			}
		})
@@ -339,15 +340,15 @@ var _ = Describe("race", func() {
			Dialer: func(ctx context.Context) (net.Conn, error) {
				return &net.TCPConn{}, nil
			},
-			PoolSize:     1000,
-			MinIdleConns: 50,
+			PoolSize:     int32(1000),
+			MinIdleConns: int32(50),
			PoolTimeout:  3 * time.Second,
			DialTimeout:  1 * time.Second,
		}
		p := pool.NewConnPool(opt)

		var wg sync.WaitGroup
-		for i := 0; i < opt.PoolSize; i++ {
+		for i := int32(0); i < opt.PoolSize; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
@@ -366,8 +367,8 @@ var _ = Describe("race", func() {
			Dialer: func(ctx context.Context) (net.Conn, error) {
				panic("test panic")
			},
-			PoolSize:     100,
-			MinIdleConns: 30,
+			PoolSize:     int32(100),
+			MinIdleConns: int32(30),
		}
		p := pool.NewConnPool(opt)

@@ -377,14 +378,14 @@ var _ = Describe("race", func() {
			state := p.Stats()
			return state.TotalConns == 0 && state.IdleConns == 0 && p.QueueLen() == 0
		}, "3s", "50ms").Should(BeTrue())
	})

	It("wait", func() {
		opt := &pool.Options{
			Dialer: func(ctx context.Context) (net.Conn, error) {
				return &net.TCPConn{}, nil
			},
-			PoolSize:    1,
+			PoolSize:    int32(1),
			PoolTimeout: 3 * time.Second,
		}
		p := pool.NewConnPool(opt)
@@ -415,7 +416,7 @@ var _ = Describe("race", func() {

				return &net.TCPConn{}, nil
			},
-			PoolSize:    1,
+			PoolSize:    int32(1),
			PoolTimeout: testPoolTimeout,
		}
		p := pool.NewConnPool(opt)
internal/pool/pubsub.go (new file, 77 lines)
@@ -0,0 +1,77 @@
package pool

import (
	"context"
	"net"
	"sync"
	"sync/atomic"
)

type PubSubStats struct {
	Created   uint32
	Untracked uint32
	Active    uint32
}

// PubSubPool manages a pool of PubSub connections.
type PubSubPool struct {
	opt       *Options
	netDialer func(ctx context.Context, network, addr string) (net.Conn, error)

	// Map to track active PubSub connections
	activeConns sync.Map // map[uint64]*Conn (connID -> conn)
	closed      atomic.Bool
	stats       PubSubStats
}

func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
	return &PubSubPool{
		opt:       opt,
		netDialer: netDialer,
	}
}

func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
	if p.closed.Load() {
		return nil, ErrClosed
	}

	netConn, err := p.netDialer(ctx, network, addr)
	if err != nil {
		return nil, err
	}
	cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
	atomic.AddUint32(&p.stats.Created, 1)
	return cn, nil

}

func (p *PubSubPool) TrackConn(cn *Conn) {
	atomic.AddUint32(&p.stats.Active, 1)
	p.activeConns.Store(cn.GetID(), cn)
}

func (p *PubSubPool) UntrackConn(cn *Conn) {
	atomic.AddUint32(&p.stats.Active, ^uint32(0))
	atomic.AddUint32(&p.stats.Untracked, 1)
	p.activeConns.Delete(cn.GetID())
}

func (p *PubSubPool) Close() error {
	p.closed.Store(true)
	p.activeConns.Range(func(key, value interface{}) bool {
		cn := value.(*Conn)
		_ = cn.Close()
		return true
	})
	return nil
}

func (p *PubSubPool) Stats() *PubSubStats {
	// load stats atomically
	return &PubSubStats{
		Created:   atomic.LoadUint32(&p.stats.Created),
		Untracked: atomic.LoadUint32(&p.stats.Untracked),
		Active:    atomic.LoadUint32(&p.stats.Active),
	}
}
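As a sketch of the intended lifecycle (mirroring how the cluster pubSub code later in this commit uses it; internal/pool is only consumed from inside the client, and the dialer, address and channel names here are placeholders):

	func dialPubSubConn(ctx context.Context, opt *pool.Options,
		dial func(ctx context.Context, network, addr string) (net.Conn, error)) error {
		p := pool.NewPubSubPool(opt, dial)
		cn, err := p.NewConn(ctx, "tcp", "localhost:6379", []string{"news"})
		if err != nil {
			return err
		}
		p.TrackConn(cn)         // counted in Stats().Active
		defer p.UntrackConn(cn) // removed from the active map when done
		defer cn.Close()
		// ... initialize the connection and read pubsub messages on cn ...
		return nil
	}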
internal/redis.go (new file, 3 lines)
@@ -0,0 +1,3 @@
package internal

const RedisNull = "null"
internal/util/math.go (new file, 17 lines)
@@ -0,0 +1,17 @@
package util

// Max returns the maximum of two integers
func Max(a, b int) int {
	if a > b {
		return a
	}
	return b
}

// Min returns the minimum of two integers
func Min(a, b int) int {
	if a < b {
		return a
	}
	return b
}
options.go (81 changed lines)
@@ -14,9 +14,10 @@ import (
	"time"

	"github.com/redis/go-redis/v9/auth"
+	"github.com/redis/go-redis/v9/hitless"
	"github.com/redis/go-redis/v9/internal/pool"
-	"github.com/redis/go-redis/v9/push"
	"github.com/redis/go-redis/v9/internal/proto"
+	"github.com/redis/go-redis/v9/push"
)

// Limiter is the interface of a rate limiter or a circuit breaker.
@@ -153,6 +154,7 @@ type Options struct {
	//
	// Note that FIFO has slightly higher overhead compared to LIFO,
	// but it helps closing idle connections faster reducing the pool size.
+	// default: false
	PoolFIFO bool

	// PoolSize is the base number of socket connections.
@@ -244,8 +246,19 @@ type Options struct {
	// When a node is marked as failing, it will be avoided for this duration.
	// Default is 15 seconds.
	FailingTimeoutSeconds int
+
+	// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
+	// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
+	// cluster upgrade notifications gracefully and manage connection/pool state
+	// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
+	// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
+	HitlessUpgradeConfig *HitlessUpgradeConfig
}

+// HitlessUpgradeConfig provides configuration options for hitless upgrades.
+// This is an alias to hitless.Config for convenience.
+type HitlessUpgradeConfig = hitless.Config
+
func (opt *Options) init() {
	if opt.Addr == "" {
		opt.Addr = "localhost:6379"
@@ -320,13 +333,36 @@ func (opt *Options) init() {
	case 0:
		opt.MaxRetryBackoff = 512 * time.Millisecond
	}

+	opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolSize(opt.PoolSize)
+
+	// auto-detect endpoint type if not specified
+	endpointType := opt.HitlessUpgradeConfig.EndpointType
+	if endpointType == "" || endpointType == hitless.EndpointTypeAuto {
+		// Auto-detect endpoint type if not specified
+		endpointType = hitless.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
+	}
+	opt.HitlessUpgradeConfig.EndpointType = endpointType
}

func (opt *Options) clone() *Options {
	clone := *opt

+	// Deep clone HitlessUpgradeConfig to avoid sharing between clients
+	if opt.HitlessUpgradeConfig != nil {
+		configClone := *opt.HitlessUpgradeConfig
+		clone.HitlessUpgradeConfig = &configClone
+	}
+
	return &clone
}

+// NewDialer returns a function that will be used as the default dialer
+// when none is specified in Options.Dialer.
+func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
+	return NewDialer(opt)
+}
+
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
@@ -617,18 +653,35 @@ func newConnPool(
		Dialer: func(ctx context.Context) (net.Conn, error) {
			return dialer(ctx, opt.Network, opt.Addr)
		},
		PoolFIFO:        opt.PoolFIFO,
-		PoolSize:        opt.PoolSize,
+		PoolSize:        int32(opt.PoolSize),
		PoolTimeout:     opt.PoolTimeout,
		DialTimeout:     opt.DialTimeout,
-		MinIdleConns:    opt.MinIdleConns,
-		MaxIdleConns:    opt.MaxIdleConns,
-		MaxActiveConns:  opt.MaxActiveConns,
+		MinIdleConns:    int32(opt.MinIdleConns),
+		MaxIdleConns:    int32(opt.MaxIdleConns),
+		MaxActiveConns:  int32(opt.MaxActiveConns),
		ConnMaxIdleTime: opt.ConnMaxIdleTime,
		ConnMaxLifetime: opt.ConnMaxLifetime,
-		// Pass protocol version for push notification optimization
-		Protocol:        opt.Protocol,
-		ReadBufferSize:  opt.ReadBufferSize,
-		WriteBufferSize: opt.WriteBufferSize,
+		ReadBufferSize:           opt.ReadBufferSize,
+		WriteBufferSize:          opt.WriteBufferSize,
+		PushNotificationsEnabled: opt.Protocol == 3,
	})
}

+func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
+) *pool.PubSubPool {
+	return pool.NewPubSubPool(&pool.Options{
+		PoolFIFO:                 opt.PoolFIFO,
+		PoolSize:                 int32(opt.PoolSize),
+		PoolTimeout:              opt.PoolTimeout,
+		DialTimeout:              opt.DialTimeout,
+		MinIdleConns:             int32(opt.MinIdleConns),
+		MaxIdleConns:             int32(opt.MaxIdleConns),
+		MaxActiveConns:           int32(opt.MaxActiveConns),
+		ConnMaxIdleTime:          opt.ConnMaxIdleTime,
+		ConnMaxLifetime:          opt.ConnMaxLifetime,
+		ReadBufferSize:           32 * 1024,
+		WriteBufferSize:          32 * 1024,
+		PushNotificationsEnabled: opt.Protocol == 3,
+	}, dialer)
+}
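Putting the new option together, a hedged configuration sketch (fields of hitless.Config beyond EndpointType, and the exact "auto"/"disabled" mode values quoted in the comments above, are not spelled out in this diff; leaving HitlessUpgradeConfig nil keeps the documented "auto" behaviour):

	rdb := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379",
		Protocol: 3, // RESP3 is required for the push notifications hitless upgrades rely on
		// nil HitlessUpgradeConfig => "auto": enabled when the server supports it.
		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
			// EndpointType is auto-detected from Addr/TLS in init() when left empty.
			EndpointType: hitless.EndpointTypeAuto,
		},
	})
	defer rdb.Close()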
@@ -38,6 +38,7 @@ type ClusterOptions struct {
	ClientName string

	// NewClient creates a cluster node client with provided name and options.
+	// If NewClient is set by the user, the user is responsible for handling hitless upgrades and push notifications.
	NewClient func(opt *Options) *Client

	// The maximum number of retries before giving up. Command is retried
@@ -129,6 +130,14 @@ type ClusterOptions struct {
	// When a node is marked as failing, it will be avoided for this duration.
	// Default is 15 seconds.
	FailingTimeoutSeconds int
+
+	// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
+	// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
+	// cluster upgrade notifications gracefully and manage connection/pool state
+	// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
+	// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
+	// The ClusterClient does not directly work with hitless, it is up to the clients in the Nodes map to work with hitless.
+	HitlessUpgradeConfig *HitlessUpgradeConfig
}

func (opt *ClusterOptions) init() {
@@ -319,6 +328,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
}

func (opt *ClusterOptions) clientOptions() *Options {
+	// Clone HitlessUpgradeConfig to avoid sharing between cluster node clients
+	var hitlessConfig *HitlessUpgradeConfig
+	if opt.HitlessUpgradeConfig != nil {
+		configClone := *opt.HitlessUpgradeConfig
+		hitlessConfig = &configClone
+	}
+
	return &Options{
		ClientName: opt.ClientName,
		Dialer:     opt.Dialer,
@@ -360,8 +376,9 @@ func (opt *ClusterOptions) clientOptions() *Options {
		// much use for ClusterSlots config). This means we cannot execute the
		// READONLY command against that node -- setting readOnly to false in such
		// situations in the options below will prevent that from happening.
		readOnly:             opt.ReadOnly && opt.ClusterSlots == nil,
		UnstableResp3:        opt.UnstableResp3,
+		HitlessUpgradeConfig: hitlessConfig,
	}
}

@@ -1830,12 +1847,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
		return err
	}

+// hitless won't work here for now
func (c *ClusterClient) pubSub() *PubSub {
	var node *clusterNode
	pubsub := &PubSub{
		opt: c.opt.clientOptions(),
-		newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
+		newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
			if node != nil {
				panic("node != nil")
			}
@@ -1850,18 +1867,25 @@ func (c *ClusterClient) pubSub() *PubSub {
			if err != nil {
				return nil, err
			}
-			cn, err := node.Client.newConn(context.TODO())
+			cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
			if err != nil {
				node = nil

				return nil, err
			}
+			// will return nil if already initialized
+			err = node.Client.initConn(ctx, cn)
+			if err != nil {
+				_ = cn.Close()
+				node = nil
+				return nil, err
+			}
+			node.Client.pubSubPool.TrackConn(cn)
			return cn, nil
		},
		closeConn: func(cn *pool.Conn) error {
-			err := node.Client.connPool.CloseConn(cn)
+			// Untrack connection from PubSubPool
+			node.Client.pubSubPool.UntrackConn(cn)
+			err := cn.Close()
			node = nil
			return err
		},
pool_pubsub_bench_test.go (new file, 375 lines)
@@ -0,0 +1,375 @@
// Pool and PubSub Benchmark Suite
//
// This file contains comprehensive benchmarks for both pool operations and PubSub initialization.
// It's designed to be run against different branches to compare performance.
//
// Usage Examples:
// # Run all benchmarks
// go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go
//
// # Run only pool benchmarks
// go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go
//
// # Run only PubSub benchmarks
// go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go
//
// # Compare between branches
// git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt
// git checkout branch2 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch2.txt
// benchcmp branch1.txt branch2.txt
//
// # Run with memory profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go
//
// # Run with CPU profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go

package redis_test

import (
	"context"
	"fmt"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/internal/pool"
)

// dummyDialer creates a mock connection for benchmarking
func dummyDialer(ctx context.Context) (net.Conn, error) {
	return &dummyConn{}, nil
}

// dummyConn implements net.Conn for benchmarking
type dummyConn struct{}

func (c *dummyConn) Read(b []byte) (n int, err error)  { return len(b), nil }
func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Close() error                      { return nil }
func (c *dummyConn) LocalAddr() net.Addr               { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} }
func (c *dummyConn) RemoteAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379}
}
func (c *dummyConn) SetDeadline(t time.Time) error      { return nil }
func (c *dummyConn) SetReadDeadline(t time.Time) error  { return nil }
func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil }

// =============================================================================
// POOL BENCHMARKS
// =============================================================================

// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations
func BenchmarkPoolGetPut(b *testing.B) {
	ctx := context.Background()

	poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128}

	for _, poolSize := range poolSizes {
		b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
			connPool := pool.NewConnPool(&pool.Options{
				Dialer:          dummyDialer,
				PoolSize:        int32(poolSize),
				PoolTimeout:     time.Second,
				DialTimeout:     time.Second,
				ConnMaxIdleTime: time.Hour,
				MinIdleConns:    int32(0), // Start with no idle connections
			})
			defer connPool.Close()

			b.ResetTimer()
			b.ReportAllocs()

			b.RunParallel(func(pb *testing.PB) {
				for pb.Next() {
					cn, err := connPool.Get(ctx)
					if err != nil {
						b.Fatal(err)
					}
					connPool.Put(ctx, cn)
				}
			})
		})
	}
}

// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with MinIdleConns
func BenchmarkPoolGetPutWithMinIdle(b *testing.B) {
	ctx := context.Background()

	configs := []struct {
		poolSize     int
		minIdleConns int
	}{
		{8, 2},
		{16, 4},
		{32, 8},
		{64, 16},
	}

	for _, config := range configs {
		b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) {
			connPool := pool.NewConnPool(&pool.Options{
				Dialer:          dummyDialer,
				PoolSize:        int32(config.poolSize),
				MinIdleConns:    int32(config.minIdleConns),
				PoolTimeout:     time.Second,
				DialTimeout:     time.Second,
				ConnMaxIdleTime: time.Hour,
			})
			defer connPool.Close()

			b.ResetTimer()
			b.ReportAllocs()

			b.RunParallel(func(pb *testing.PB) {
				for pb.Next() {
					cn, err := connPool.Get(ctx)
					if err != nil {
						b.Fatal(err)
					}
					connPool.Put(ctx, cn)
				}
			})
		})
	}
}

// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency
func BenchmarkPoolConcurrentGetPut(b *testing.B) {
	ctx := context.Background()

	connPool := pool.NewConnPool(&pool.Options{
		Dialer:          dummyDialer,
		PoolSize:        int32(32),
		PoolTimeout:     time.Second,
		DialTimeout:     time.Second,
		ConnMaxIdleTime: time.Hour,
		MinIdleConns:    int32(0),
	})
	defer connPool.Close()

	b.ResetTimer()
	b.ReportAllocs()

	// Test with different levels of concurrency
	concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64}

	for _, concurrency := range concurrencyLevels {
		b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
			b.SetParallelism(concurrency)
			b.RunParallel(func(pb *testing.PB) {
				for pb.Next() {
					cn, err := connPool.Get(ctx)
					if err != nil {
						b.Fatal(err)
					}
					connPool.Put(ctx, cn)
				}
			})
		})
	}
}

// =============================================================================
// PUBSUB BENCHMARKS
// =============================================================================

// benchmarkClient creates a Redis client for benchmarking with mock dialer
func benchmarkClient(poolSize int) *redis.Client {
	return redis.NewClient(&redis.Options{
		Addr:         "localhost:6379", // Mock address
		DialTimeout:  time.Second,
		ReadTimeout:  time.Second,
		WriteTimeout: time.Second,
		PoolSize:     poolSize,
		MinIdleConns: 0, // Start with no idle connections for consistent benchmarks
	})
}

// BenchmarkPubSubCreation benchmarks PubSub creation and subscription
func BenchmarkPubSubCreation(b *testing.B) {
	ctx := context.Background()

	poolSizes := []int{1, 4, 8, 16, 32}

	for _, poolSize := range poolSizes {
		b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
			client := benchmarkClient(poolSize)
			defer client.Close()

			b.ResetTimer()
			b.ReportAllocs()

			for i := 0; i < b.N; i++ {
				pubsub := client.Subscribe(ctx, "test-channel")
				pubsub.Close()
			}
		})
	}
}

// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription
func BenchmarkPubSubPatternCreation(b *testing.B) {
	ctx := context.Background()

	poolSizes := []int{1, 4, 8, 16, 32}

	for _, poolSize := range poolSizes {
		b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
			client := benchmarkClient(poolSize)
			defer client.Close()

			b.ResetTimer()
			b.ReportAllocs()

			for i := 0; i < b.N; i++ {
				pubsub := client.PSubscribe(ctx, "test-*")
				pubsub.Close()
			}
		})
	}
}

// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation
func BenchmarkPubSubConcurrentCreation(b *testing.B) {
	ctx := context.Background()
	client := benchmarkClient(32)
	defer client.Close()

	concurrencyLevels := []int{1, 2, 4, 8, 16}

	for _, concurrency := range concurrencyLevels {
		b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
			b.ResetTimer()
			b.ReportAllocs()

			var wg sync.WaitGroup
			semaphore := make(chan struct{}, concurrency)

			for i := 0; i < b.N; i++ {
				wg.Add(1)
				semaphore <- struct{}{}

				go func() {
					defer wg.Done()
					defer func() { <-semaphore }()

					pubsub := client.Subscribe(ctx, "test-channel")
					pubsub.Close()
				}()
			}

			wg.Wait()
		})
	}
}

// BenchmarkPubSubMultipleChannels benchmarks subscribing to multiple channels
func BenchmarkPubSubMultipleChannels(b *testing.B) {
	ctx := context.Background()
	client := benchmarkClient(16)
	defer client.Close()

	channelCounts := []int{1, 5, 10, 25, 50, 100}

	for _, channelCount := range channelCounts {
		b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) {
			// Prepare channel names
			channels := make([]string, channelCount)
			for i := 0; i < channelCount; i++ {
				channels[i] = fmt.Sprintf("channel-%d", i)
			}

			b.ResetTimer()
			b.ReportAllocs()

			for i := 0; i < b.N; i++ {
				pubsub := client.Subscribe(ctx, channels...)
				pubsub.Close()
			}
		})
	}
}

// BenchmarkPubSubReuse benchmarks reusing PubSub connections
func BenchmarkPubSubReuse(b *testing.B) {
	ctx := context.Background()
	client := benchmarkClient(16)
	defer client.Close()

	b.ResetTimer()
	b.ReportAllocs()

	for i := 0; i < b.N; i++ {
		// Benchmark just the creation and closing of PubSub connections
		// This simulates reuse patterns without requiring actual Redis operations
		pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i))
		pubsub.Close()
	}
}

// =============================================================================
// COMBINED BENCHMARKS
// =============================================================================

// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations
func BenchmarkPoolAndPubSubMixed(b *testing.B) {
	ctx := context.Background()
	client := benchmarkClient(32)
	defer client.Close()

	b.ResetTimer()
	b.ReportAllocs()

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			// Mix of pool stats collection and PubSub creation
			if pb.Next() {
				// Pool stats operation
				stats := client.PoolStats()
				_ = stats.Hits + stats.Misses // Use the stats to prevent optimization
			}

			if pb.Next() {
				// PubSub operation
				pubsub := client.Subscribe(ctx, "test-channel")
				pubsub.Close()
			}
		}
	})
}

// BenchmarkPoolStatsCollection benchmarks pool statistics collection
func BenchmarkPoolStatsCollection(b *testing.B) {
	client := benchmarkClient(16)
	defer client.Close()

	b.ResetTimer()
	b.ReportAllocs()

	for i := 0; i < b.N; i++ {
		stats := client.PoolStats()
		_ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization
	}
}

// BenchmarkPoolHighContention tests pool performance under high contention
func BenchmarkPoolHighContention(b *testing.B) {
	ctx := context.Background()
	client := benchmarkClient(32)
	defer client.Close()

	b.ResetTimer()
	b.ReportAllocs()

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			// High contention Get/Put operations
			pubsub := client.Subscribe(ctx, "test-channel")
			pubsub.Close()
		}
	})
}
pubsub.go (40 changed lines)
@@ -22,7 +22,7 @@ import (
type PubSub struct {
	opt *Options

-	newConn   func(ctx context.Context, channels []string) (*pool.Conn, error)
+	newConn   func(ctx context.Context, addr string, channels []string) (*pool.Conn, error)
	closeConn func(*pool.Conn) error

	mu sync.Mutex
@@ -42,6 +42,9 @@ type PubSub struct {

	// Push notification processor for handling generic push notifications
	pushProcessor push.NotificationProcessor
+
+	// Cleanup callback for hitless upgrade tracking
+	onClose func()
}

func (c *PubSub) init() {
@@ -73,10 +76,18 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er
		return c.cn, nil
	}

+	if c.opt.Addr == "" {
+		// TODO(hitless):
+		// this is probably cluster client
+		// c.newConn will ignore the addr argument
+		// will be changed when we have hitless upgrades for cluster clients
+		c.opt.Addr = internal.RedisNull
+	}
+
	channels := mapKeys(c.channels)
	channels = append(channels, newChannels...)

-	cn, err := c.newConn(ctx, channels)
+	cn, err := c.newConn(ctx, c.opt.Addr, channels)
	if err != nil {
		return nil, err
	}
@@ -157,12 +168,28 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo
	if c.cn != cn {
		return
	}

+	if !cn.IsUsable() || cn.ShouldHandoff() {
+		c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable"))
+	}
+
	if isBadConn(err, allowTimeout, c.opt.Addr) {
		c.reconnect(ctx, err)
	}
}

func (c *PubSub) reconnect(ctx context.Context, reason error) {
+	if c.cn != nil && c.cn.ShouldHandoff() {
+		newEndpoint := c.cn.GetHandoffEndpoint()
+		// If new endpoint is NULL, use the original address
+		if newEndpoint == internal.RedisNull {
+			newEndpoint = c.opt.Addr
+		}
+
+		if newEndpoint != "" {
+			c.opt.Addr = newEndpoint
+		}
+	}
	_ = c.closeTheCn(reason)
	_, _ = c.conn(ctx, nil)
}
@@ -189,6 +216,11 @@ func (c *PubSub) Close() error {
	c.closed = true
	close(c.exit)

+	// Call cleanup callback if set
+	if c.onClose != nil {
+		c.onClose()
+	}
+
	return c.closeTheCn(pool.ErrClosed)
}

@@ -461,6 +493,7 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
// Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
+// This will block until a message is received.
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
	return c.ReceiveTimeout(ctx, 0)
}
@@ -543,7 +576,8 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac
}

func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
-	if c.pushProcessor == nil {
+	// Only process push notifications for RESP3 connections with a processor
+	if c.opt.Protocol != 3 || c.pushProcessor == nil {
		return nil
	}

|
@@ -1,8 +1,6 @@
|
|||||||
package push
|
package push
|
||||||
|
|
||||||
import (
|
// No imports needed for this file
|
||||||
"github.com/redis/go-redis/v9/internal/pool"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NotificationHandlerContext provides context information about where a push notification was received.
|
// NotificationHandlerContext provides context information about where a push notification was received.
|
||||||
// This struct allows handlers to make informed decisions based on the source of the notification
|
// This struct allows handlers to make informed decisions based on the source of the notification
|
||||||
@@ -35,7 +33,12 @@ type NotificationHandlerContext struct {
|
|||||||
PubSub interface{}
|
PubSub interface{}
|
||||||
|
|
||||||
// Conn is the specific connection on which the notification was received.
|
// Conn is the specific connection on which the notification was received.
|
||||||
Conn *pool.Conn
|
// It is interface to both allow for future expansion and to avoid
|
||||||
|
// circular dependencies. The developer is responsible for type assertion.
|
||||||
|
// It can be one of the following types:
|
||||||
|
// - *pool.Conn
|
||||||
|
// - *connectionAdapter (for hitless upgrades)
|
||||||
|
Conn interface{}
|
||||||
|
|
||||||
// IsBlocking indicates if the notification was received on a blocking connection.
|
// IsBlocking indicates if the notification was received on a blocking connection.
|
||||||
IsBlocking bool
|
IsBlocking bool
|
||||||
|
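Because Conn is now an interface{}, a handler has to type-assert before touching the underlying connection. A minimal sketch of such a handler, written as if it lived in the root redis package (which can import internal/pool; code outside the module cannot). The handler name is illustrative only:

// connAwareHandler is a hypothetical handler used only for illustration.
type connAwareHandler struct{}

func (h *connAwareHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	switch cn := handlerCtx.Conn.(type) {
	case *pool.Conn:
		_ = cn // raw pool connection: connection-level metadata is available here
	default:
		// e.g. the hitless connectionAdapter, or nil: fall back to generic handling
	}
	return nil
}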
315	push/processor_unit_test.go	Normal file
@@ -0,0 +1,315 @@
package push

import (
	"context"
	"testing"
)

// TestProcessorCreation tests processor creation and initialization
func TestProcessorCreation(t *testing.T) {
	t.Run("NewProcessor", func(t *testing.T) {
		processor := NewProcessor()
		if processor == nil {
			t.Fatal("NewProcessor should not return nil")
		}
		if processor.registry == nil {
			t.Error("Processor should have a registry")
		}
	})

	t.Run("NewVoidProcessor", func(t *testing.T) {
		voidProcessor := NewVoidProcessor()
		if voidProcessor == nil {
			t.Fatal("NewVoidProcessor should not return nil")
		}
	})
}

// TestProcessorHandlerManagement tests handler registration and retrieval
func TestProcessorHandlerManagement(t *testing.T) {
	processor := NewProcessor()
	handler := &UnitTestHandler{name: "test-handler"}

	t.Run("RegisterHandler", func(t *testing.T) {
		err := processor.RegisterHandler("TEST", handler, false)
		if err != nil {
			t.Errorf("RegisterHandler should not error: %v", err)
		}

		// Verify handler is registered
		retrievedHandler := processor.GetHandler("TEST")
		if retrievedHandler != handler {
			t.Error("GetHandler should return the registered handler")
		}
	})

	t.Run("RegisterProtectedHandler", func(t *testing.T) {
		protectedHandler := &UnitTestHandler{name: "protected-handler"}
		err := processor.RegisterHandler("PROTECTED", protectedHandler, true)
		if err != nil {
			t.Errorf("RegisterHandler should not error for protected handler: %v", err)
		}

		// Verify handler is registered
		retrievedHandler := processor.GetHandler("PROTECTED")
		if retrievedHandler != protectedHandler {
			t.Error("GetHandler should return the protected handler")
		}
	})

	t.Run("GetNonExistentHandler", func(t *testing.T) {
		handler := processor.GetHandler("NONEXISTENT")
		if handler != nil {
			t.Error("GetHandler should return nil for non-existent handler")
		}
	})

	t.Run("UnregisterHandler", func(t *testing.T) {
		err := processor.UnregisterHandler("TEST")
		if err != nil {
			t.Errorf("UnregisterHandler should not error: %v", err)
		}

		// Verify handler is removed
		retrievedHandler := processor.GetHandler("TEST")
		if retrievedHandler != nil {
			t.Error("GetHandler should return nil after unregistering")
		}
	})

	t.Run("UnregisterProtectedHandler", func(t *testing.T) {
		err := processor.UnregisterHandler("PROTECTED")
		if err == nil {
			t.Error("UnregisterHandler should error for protected handler")
		}

		// Verify handler is still there
		retrievedHandler := processor.GetHandler("PROTECTED")
		if retrievedHandler == nil {
			t.Error("Protected handler should not be removed")
		}
	})
}

// TestVoidProcessorBehavior tests void processor behavior
func TestVoidProcessorBehavior(t *testing.T) {
	voidProcessor := NewVoidProcessor()
	handler := &UnitTestHandler{name: "test-handler"}

	t.Run("GetHandler", func(t *testing.T) {
		retrievedHandler := voidProcessor.GetHandler("ANY")
		if retrievedHandler != nil {
			t.Error("VoidProcessor GetHandler should always return nil")
		}
	})

	t.Run("RegisterHandler", func(t *testing.T) {
		err := voidProcessor.RegisterHandler("TEST", handler, false)
		if err == nil {
			t.Error("VoidProcessor RegisterHandler should return error")
		}

		// Check error type
		if !IsVoidProcessorError(err) {
			t.Error("Error should be a VoidProcessorError")
		}
	})

	t.Run("UnregisterHandler", func(t *testing.T) {
		err := voidProcessor.UnregisterHandler("TEST")
		if err == nil {
			t.Error("VoidProcessor UnregisterHandler should return error")
		}

		// Check error type
		if !IsVoidProcessorError(err) {
			t.Error("Error should be a VoidProcessorError")
		}
	})
}

// TestProcessPendingNotificationsNilReader tests handling of nil reader
func TestProcessPendingNotificationsNilReader(t *testing.T) {
	t.Run("ProcessorWithNilReader", func(t *testing.T) {
		processor := NewProcessor()
		ctx := context.Background()
		handlerCtx := NotificationHandlerContext{}

		err := processor.ProcessPendingNotifications(ctx, handlerCtx, nil)
		if err != nil {
			t.Errorf("ProcessPendingNotifications should not error with nil reader: %v", err)
		}
	})

	t.Run("VoidProcessorWithNilReader", func(t *testing.T) {
		voidProcessor := NewVoidProcessor()
		ctx := context.Background()
		handlerCtx := NotificationHandlerContext{}

		err := voidProcessor.ProcessPendingNotifications(ctx, handlerCtx, nil)
		if err != nil {
			t.Errorf("VoidProcessor ProcessPendingNotifications should not error with nil reader: %v", err)
		}
	})
}

// TestWillHandleNotificationInClient tests the notification filtering logic
func TestWillHandleNotificationInClient(t *testing.T) {
	testCases := []struct {
		name             string
		notificationType string
		shouldHandle     bool
	}{
		// Pub/Sub notifications (should be handled in client)
		{"message", "message", true},
		{"pmessage", "pmessage", true},
		{"subscribe", "subscribe", true},
		{"unsubscribe", "unsubscribe", true},
		{"psubscribe", "psubscribe", true},
		{"punsubscribe", "punsubscribe", true},
		{"smessage", "smessage", true},
		{"ssubscribe", "ssubscribe", true},
		{"sunsubscribe", "sunsubscribe", true},

		// Push notifications (should be handled by processor)
		{"MOVING", "MOVING", false},
		{"MIGRATING", "MIGRATING", false},
		{"MIGRATED", "MIGRATED", false},
		{"FAILING_OVER", "FAILING_OVER", false},
		{"FAILED_OVER", "FAILED_OVER", false},
		{"custom", "custom", false},
		{"unknown", "unknown", false},
		{"empty", "", false},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			result := willHandleNotificationInClient(tc.notificationType)
			if result != tc.shouldHandle {
				t.Errorf("willHandleNotificationInClient(%q) = %v, want %v", tc.notificationType, result, tc.shouldHandle)
			}
		})
	}
}

// TestProcessorErrorHandlingUnit tests error handling scenarios
func TestProcessorErrorHandlingUnit(t *testing.T) {
	processor := NewProcessor()

	t.Run("RegisterNilHandler", func(t *testing.T) {
		err := processor.RegisterHandler("TEST", nil, false)
		if err == nil {
			t.Error("RegisterHandler should error with nil handler")
		}

		// Check error type
		if !IsHandlerNilError(err) {
			t.Error("Error should be a HandlerNilError")
		}
	})

	t.Run("RegisterDuplicateHandler", func(t *testing.T) {
		handler1 := &UnitTestHandler{name: "handler1"}
		handler2 := &UnitTestHandler{name: "handler2"}

		// Register first handler
		err := processor.RegisterHandler("DUPLICATE", handler1, false)
		if err != nil {
			t.Errorf("First RegisterHandler should not error: %v", err)
		}

		// Try to register second handler with same name
		err = processor.RegisterHandler("DUPLICATE", handler2, false)
		if err == nil {
			t.Error("RegisterHandler should error when registering duplicate handler")
		}

		// Verify original handler is still there
		retrievedHandler := processor.GetHandler("DUPLICATE")
		if retrievedHandler != handler1 {
			t.Error("Original handler should remain after failed duplicate registration")
		}
	})

	t.Run("UnregisterNonExistentHandler", func(t *testing.T) {
		err := processor.UnregisterHandler("NONEXISTENT")
		if err != nil {
			t.Errorf("UnregisterHandler should not error for non-existent handler: %v", err)
		}
	})
}

// TestProcessorConcurrentAccess tests concurrent access to processor
func TestProcessorConcurrentAccess(t *testing.T) {
	processor := NewProcessor()

	t.Run("ConcurrentRegisterAndGet", func(t *testing.T) {
		done := make(chan bool, 2)

		// Goroutine 1: Register handlers
		go func() {
			defer func() { done <- true }()
			for i := 0; i < 100; i++ {
				handler := &UnitTestHandler{name: "concurrent-handler"}
				processor.RegisterHandler("CONCURRENT", handler, false)
				processor.UnregisterHandler("CONCURRENT")
			}
		}()

		// Goroutine 2: Get handlers
		go func() {
			defer func() { done <- true }()
			for i := 0; i < 100; i++ {
				processor.GetHandler("CONCURRENT")
			}
		}()

		// Wait for both goroutines to complete
		<-done
		<-done
	})
}

// TestProcessorInterfaceCompliance tests interface compliance
func TestProcessorInterfaceCompliance(t *testing.T) {
	t.Run("ProcessorImplementsInterface", func(t *testing.T) {
		var _ NotificationProcessor = (*Processor)(nil)
	})

	t.Run("VoidProcessorImplementsInterface", func(t *testing.T) {
		var _ NotificationProcessor = (*VoidProcessor)(nil)
	})
}

// UnitTestHandler is a test implementation of NotificationHandler
type UnitTestHandler struct {
	name             string
	lastNotification []interface{}
	errorToReturn    error
	callCount        int
}

func (h *UnitTestHandler) HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error {
	h.callCount++
	h.lastNotification = notification
	return h.errorToReturn
}

// Helper methods for UnitTestHandler
func (h *UnitTestHandler) GetCallCount() int {
	return h.callCount
}

func (h *UnitTestHandler) GetLastNotification() []interface{} {
	return h.lastNotification
}

func (h *UnitTestHandler) SetErrorToReturn(err error) {
	h.errorToReturn = err
}

func (h *UnitTestHandler) Reset() {
	h.callCount = 0
	h.lastNotification = nil
	h.errorToReturn = nil
}
@@ -4,24 +4,6 @@ import (
	"github.com/redis/go-redis/v9/push"
)

-// Push notification constants for cluster operations
-const (
-	// MOVING indicates a slot is being moved to a different node
-	PushNotificationMoving = "MOVING"
-
-	// MIGRATING indicates a slot is being migrated from this node
-	PushNotificationMigrating = "MIGRATING"
-
-	// MIGRATED indicates a slot has been migrated to this node
-	PushNotificationMigrated = "MIGRATED"
-
-	// FAILING_OVER indicates a failover is starting
-	PushNotificationFailingOver = "FAILING_OVER"
-
-	// FAILED_OVER indicates a failover has completed
-	PushNotificationFailedOver = "FAILED_OVER"
-)

// NewPushNotificationProcessor creates a new push notification processor
// This processor maintains a registry of handlers and processes push notifications
// It is used for RESP3 connections where push notifications are available
211	redis.go
@@ -10,6 +10,7 @@ import (
	"time"

	"github.com/redis/go-redis/v9/auth"
+	"github.com/redis/go-redis/v9/hitless"
	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/hscan"
	"github.com/redis/go-redis/v9/internal/pool"

@@ -204,19 +205,35 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) error {
//------------------------------------------------------------------------------

type baseClient struct {
	opt *Options
-	connPool pool.Pooler
+	optLock    sync.RWMutex
+	connPool   pool.Pooler
+	pubSubPool *pool.PubSubPool
	hooksMixin

	onClose func() error // hook called when client is closed

	// Push notification processing
	pushProcessor push.NotificationProcessor
+
+	// Hitless upgrade manager
+	hitlessManager     *hitless.HitlessManager
+	hitlessManagerLock sync.RWMutex
}

func (c *baseClient) clone() *baseClient {
-	clone := *c
-	return &clone
+	c.hitlessManagerLock.RLock()
+	hitlessManager := c.hitlessManager
+	c.hitlessManagerLock.RUnlock()
+
+	clone := &baseClient{
+		opt:            c.opt,
+		connPool:       c.connPool,
+		onClose:        c.onClose,
+		pushProcessor:  c.pushProcessor,
+		hitlessManager: hitlessManager,
+	}
+	return clone
}

func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
@@ -234,21 +251,6 @@ func (c *baseClient) String() string {
	return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}

-func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
-	cn, err := c.connPool.NewConn(ctx)
-	if err != nil {
-		return nil, err
-	}
-
-	err = c.initConn(ctx, cn)
-	if err != nil {
-		_ = c.connPool.CloseConn(cn)
-		return nil, err
-	}
-
-	return cn, nil
-}
-
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
	if c.opt.Limiter != nil {
		err := c.opt.Limiter.Allow()

@@ -274,7 +276,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
		return nil, err
	}

-	if cn.Inited {
+	if cn.IsInited() {
		return cn, nil
	}

@@ -356,12 +358,10 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
}

func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
-	if cn.Inited {
+	if !cn.Inited.CompareAndSwap(false, true) {
		return nil
	}

	var err error
-	cn.Inited = true
	connPool := pool.NewSingleConnPool(c.connPool, cn)
	conn := newConn(c.opt, connPool, &c.hooksMixin)

@@ -430,6 +430,50 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
		return fmt.Errorf("failed to initialize connection options: %w", err)
	}

+	// Enable maintenance notifications if hitless upgrades are configured
+	c.optLock.RLock()
+	hitlessEnabled := c.opt.HitlessUpgradeConfig != nil && c.opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled
+	protocol := c.opt.Protocol
+	endpointType := c.opt.HitlessUpgradeConfig.EndpointType
+	c.optLock.RUnlock()
+	var hitlessHandshakeErr error
+	if hitlessEnabled && protocol == 3 {
+		hitlessHandshakeErr = conn.ClientMaintNotifications(
+			ctx,
+			true,
+			endpointType.String(),
+		).Err()
+		if hitlessHandshakeErr != nil {
+			if !isRedisError(hitlessHandshakeErr) {
+				// if not redis error, fail the connection
+				return hitlessHandshakeErr
+			}
+			c.optLock.Lock()
+			// handshake failed - check and modify config atomically
+			switch c.opt.HitlessUpgradeConfig.Mode {
+			case hitless.MaintNotificationsEnabled:
+				// enabled mode, fail the connection
+				c.optLock.Unlock()
+				return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr)
+			default: // will handle auto and any other
+				c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsDisabled
+				c.optLock.Unlock()
+				// auto mode, disable hitless upgrades and continue
+				if err := c.disableHitlessUpgrades(); err != nil {
+					// Log error but continue - auto mode should be resilient
+					internal.Logger.Printf(ctx, "hitless: failed to disable hitless upgrades in auto mode: %v", err)
+				}
+			}
+		} else {
+			// handshake was executed successfully
+			// to make sure that the handshake will be executed on other connections as well if it was successfully
+			// executed on this connection, we will force the handshake to be executed on all connections
+			c.optLock.Lock()
+			c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsEnabled
+			c.optLock.Unlock()
+		}
+	}

	if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
		libName := ""
		libVer := Version()

@@ -446,6 +490,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
		}
	}

+	cn.SetUsable(true)
+	cn.Inited.Store(true)
+
+	// Set the connection initialization function for potential reconnections
+	cn.SetInitConnFunc(c.createInitConnFunc())
+
	if c.opt.OnConnect != nil {
		return c.opt.OnConnect(ctx, conn)
	}
@@ -593,19 +643,76 @@ func (c *baseClient) context(ctx context.Context) context.Context {
	return context.Background()
}

+// createInitConnFunc creates a connection initialization function that can be used for reconnections.
+func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error {
+	return func(ctx context.Context, cn *pool.Conn) error {
+		return c.initConn(ctx, cn)
+	}
+}
+
+// enableHitlessUpgrades initializes the hitless upgrade manager and pool hook.
+// This function is called during client initialization. It will register push
+// notification handlers for all hitless upgrade events and will start background
+// workers for handoff processing in the pool hook.
+func (c *baseClient) enableHitlessUpgrades() error {
+	// Create client adapter
+	clientAdapterInstance := newClientAdapter(c)
+
+	// Create hitless manager directly
+	manager, err := hitless.NewHitlessManager(clientAdapterInstance, c.connPool, c.opt.HitlessUpgradeConfig)
+	if err != nil {
+		return err
+	}
+	// Set the manager reference and initialize pool hook
+	c.hitlessManagerLock.Lock()
+	c.hitlessManager = manager
+	c.hitlessManagerLock.Unlock()
+
+	// Initialize pool hook (safe to call without lock since manager is now set)
+	manager.InitPoolHook(c.dialHook)
+	return nil
+}
+
+func (c *baseClient) disableHitlessUpgrades() error {
+	c.hitlessManagerLock.Lock()
+	defer c.hitlessManagerLock.Unlock()
+
+	// Close the hitless manager
+	if c.hitlessManager != nil {
+		// Closing the manager will also shutdown the pool hook
+		// and remove it from the pool
+		c.hitlessManager.Close()
+		c.hitlessManager = nil
+	}
+	return nil
+}
+
// Close closes the client, releasing any open resources.
//
// It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
	var firstErr error

+	// Close hitless manager first
+	if err := c.disableHitlessUpgrades(); err != nil {
+		firstErr = err
+	}
+
	if c.onClose != nil {
-		if err := c.onClose(); err != nil {
+		if err := c.onClose(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
-	if err := c.connPool.Close(); err != nil && firstErr == nil {
-		firstErr = err
+	if c.connPool != nil {
+		if err := c.connPool.Close(); err != nil && firstErr == nil {
+			firstErr = err
+		}
+	}
+	if c.pubSubPool != nil {
+		if err := c.pubSubPool.Close(); err != nil && firstErr == nil {
+			firstErr = err
+		}
	}
	return firstErr
}
@@ -810,11 +917,24 @@ func NewClient(opt *Options) *Client {
	// Initialize push notification processor using shared helper
	// Use void processor for RESP2 connections (push notifications not available)
	c.pushProcessor = initializePushProcessor(opt)
-	// Update options with the initialized push processor for connection pool
+	// Update options with the initialized push processor
	opt.PushNotificationProcessor = c.pushProcessor

+	// Create connection pools
	c.connPool = newConnPool(opt, c.dialHook)
+	c.pubSubPool = newPubSubPool(opt, c.dialHook)
+
+	// Initialize hitless upgrades first if enabled and protocol is RESP3
+	if opt.HitlessUpgradeConfig != nil && opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled && opt.Protocol == 3 {
+		err := c.enableHitlessUpgrades()
+		if err != nil {
+			internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
+			if opt.HitlessUpgradeConfig.Mode == hitless.MaintNotificationsEnabled {
+				// panic so we fail fast without breaking existing clients api
+				panic(fmt.Errorf("failed to enable hitless upgrades: %w", err))
+			}
+		}
+	}
+
	return &c
}

@@ -851,6 +971,14 @@ func (c *Client) Options() *Options {
	return c.opt
}

+// GetHitlessManager returns the hitless manager instance for monitoring and control.
+// Returns nil if hitless upgrades are not enabled.
+func (c *Client) GetHitlessManager() *hitless.HitlessManager {
+	c.hitlessManagerLock.RLock()
+	defer c.hitlessManagerLock.RUnlock()
+	return c.hitlessManager
+}
+
// initializePushProcessor initializes the push notification processor for any client type.
// This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient.
func initializePushProcessor(opt *Options) push.NotificationProcessor {

@@ -887,6 +1015,7 @@ type PoolStats pool.Stats
// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
	stats := c.connPool.Stats()
+	stats.PubSubStats = *(c.pubSubPool.Stats())
	return (*PoolStats)(stats)
}
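A hedged usage sketch of the new accessor, assuming the config fields referenced in this change (Mode, hitless.MaintNotificationsEnabled) and a placeholder address; GetHitlessManager simply returns nil when hitless upgrades were not enabled:

package main

import (
	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	rdb := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379", // placeholder address
		Protocol: 3,                // RESP3 is required for hitless upgrades
		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
			Mode: hitless.MaintNotificationsEnabled, // constant referenced in this change
		},
	})
	defer rdb.Close()

	if mgr := rdb.GetHitlessManager(); mgr != nil {
		_ = mgr // monitoring / control entry point for upgrades
	}
	_ = rdb.PoolStats() // now also carries PubSubStats from the PubSub pool
}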
@@ -921,11 +1050,27 @@ func (c *Client) TxPipeline() Pipeliner {
func (c *Client) pubSub() *PubSub {
	pubsub := &PubSub{
		opt: c.opt,
-		newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
-			return c.newConn(ctx)
+		newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
+			cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
+			if err != nil {
+				return nil, err
+			}
+			// will return nil if already initialized
+			err = c.initConn(ctx, cn)
+			if err != nil {
+				_ = cn.Close()
+				return nil, err
+			}
+			// Track connection in PubSubPool
+			c.pubSubPool.TrackConn(cn)
+			return cn, nil
+		},
+		closeConn: func(cn *pool.Conn) error {
+			// Untrack connection from PubSubPool
+			c.pubSubPool.UntrackConn(cn)
+			_ = cn.Close()
+			return nil
		},
-		closeConn: c.connPool.CloseConn,
		pushProcessor: c.pushProcessor,
	}
	pubsub.init()

@@ -1113,6 +1258,6 @@ func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.NotificationHandlerContext {
	return push.NotificationHandlerContext{
		Client:   c,
		ConnPool: c.connPool,
-		Conn:     cn,
+		Conn:     &connectionAdapter{conn: cn}, // Wrap in adapter for easier interface access
	}
}
@@ -12,7 +12,6 @@ import (

	. "github.com/bsm/ginkgo/v2"
	. "github.com/bsm/gomega"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/auth"
)
52	sentinel.go
@@ -16,8 +16,8 @@ import (
	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/internal/rand"
-	"github.com/redis/go-redis/v9/push"
	"github.com/redis/go-redis/v9/internal/util"
+	"github.com/redis/go-redis/v9/push"
)

//------------------------------------------------------------------------------

@@ -139,6 +139,14 @@ type FailoverOptions struct {
	FailingTimeoutSeconds int

	UnstableResp3 bool

+	// Hitless is not supported for FailoverClients at the moment.
+	// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
+	// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
+	// upgrade notifications gracefully and manage connection/pool state transitions
+	// seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
+	// If nil, hitless upgrades are disabled.
+	//HitlessUpgradeConfig *HitlessUpgradeConfig
}

func (opt *FailoverOptions) clientOptions() *Options {
@@ -456,8 +464,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
	opt.Dialer = masterReplicaDialer(failover)
	opt.init()

-	var connPool *pool.ConnPool
-
	rdb := &Client{
		baseClient: &baseClient{
			opt: opt,

@@ -469,15 +475,18 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
	// Use void processor by default for RESP2 connections
	rdb.pushProcessor = initializePushProcessor(opt)

-	connPool = newConnPool(opt, rdb.dialHook)
-	rdb.connPool = connPool
+	rdb.connPool = newConnPool(opt, rdb.dialHook)
+	rdb.pubSubPool = newPubSubPool(opt, rdb.dialHook)

	rdb.onClose = rdb.wrappedOnClose(failover.Close)

	failover.mu.Lock()
	failover.onFailover = func(ctx context.Context, addr string) {
-		_ = connPool.Filter(func(cn *pool.Conn) bool {
-			return cn.RemoteAddr().String() != addr
-		})
+		if connPool, ok := rdb.connPool.(*pool.ConnPool); ok {
+			_ = connPool.Filter(func(cn *pool.Conn) bool {
+				return cn.RemoteAddr().String() != addr
+			})
+		}
	}
	failover.mu.Unlock()

@@ -544,6 +553,7 @@ func NewSentinelClient(opt *Options) *SentinelClient {
		process: c.baseClient.process,
	})
	c.connPool = newConnPool(opt, c.dialHook)
+	c.pubSubPool = newPubSubPool(opt, c.dialHook)

	return c
}
@@ -570,13 +580,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
func (c *SentinelClient) pubSub() *PubSub {
	pubsub := &PubSub{
		opt: c.opt,
-		newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
-			return c.newConn(ctx)
+		newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
+			cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
+			if err != nil {
+				return nil, err
+			}
+			// will return nil if already initialized
+			err = c.initConn(ctx, cn)
+			if err != nil {
+				_ = cn.Close()
+				return nil, err
+			}
+			// Track connection in PubSubPool
+			c.pubSubPool.TrackConn(cn)
+			return cn, nil
		},
-		closeConn: c.connPool.CloseConn,
+		closeConn: func(cn *pool.Conn) error {
+			// Untrack connection from PubSubPool
+			c.pubSubPool.UntrackConn(cn)
+			_ = cn.Close()
+			return nil
+		},
+		pushProcessor: c.pushProcessor,
	}
	pubsub.init()

	return pubsub
}
2	tx.go
@@ -24,7 +24,7 @@ type Tx struct {
func (c *Client) newTx() *Tx {
	tx := Tx{
		baseClient: baseClient{
-			opt:           c.opt,
+			opt:           c.opt.clone(), // Clone options to avoid sharing HitlessUpgradeConfig
			connPool:      pool.NewStickyConnPool(c.connPool),
			hooksMixin:    c.hooksMixin.clone(),
			pushProcessor: c.pushProcessor, // Copy push processor from parent client
14	universal.go
@@ -122,6 +122,9 @@ type UniversalOptions struct {

	// IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint).
	IsClusterMode bool
+
+	// HitlessUpgradeConfig provides configuration for hitless upgrades.
+	HitlessUpgradeConfig *HitlessUpgradeConfig
}

// Cluster returns cluster options created from the universal options.

@@ -177,6 +180,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions {
		IdentitySuffix:        o.IdentitySuffix,
		FailingTimeoutSeconds: o.FailingTimeoutSeconds,
		UnstableResp3:         o.UnstableResp3,
+		HitlessUpgradeConfig:  o.HitlessUpgradeConfig,
	}
}

@@ -237,6 +241,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions {
		DisableIndentity: o.DisableIndentity,
		IdentitySuffix:   o.IdentitySuffix,
		UnstableResp3:    o.UnstableResp3,
+		// Note: HitlessUpgradeConfig not supported for FailoverOptions
	}
}

@@ -284,10 +289,11 @@ func (o *UniversalOptions) Simple() *Options {

		TLSConfig: o.TLSConfig,

		DisableIdentity:      o.DisableIdentity,
		DisableIndentity:     o.DisableIndentity,
		IdentitySuffix:       o.IdentitySuffix,
		UnstableResp3:        o.UnstableResp3,
+		HitlessUpgradeConfig: o.HitlessUpgradeConfig,
	}
}
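Since UniversalOptions forwards HitlessUpgradeConfig to the cluster and simple (single-node) options but not to failover options, a universal client can opt in like this. A minimal sketch; the address and mode constant are placeholders based on the fields referenced in this change:

package main

import (
	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	uni := redis.NewUniversalClient(&redis.UniversalOptions{
		Addrs:    []string{"localhost:6379"}, // placeholder address
		Protocol: 3,                          // RESP3 required for upgrade notifications
		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
			Mode: hitless.MaintNotificationsEnabled, // constant referenced in this change
		},
	})
	defer uni.Close()
}

With sentinel/failover addresses the same config would be dropped, matching the "not supported for FailoverOptions" note above.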