
feat(hitless): Introduce handlers for hitless upgrades

This commit includes all the work on hitless upgrades, adding:

- Pubsub Pool
- Examples
- Refactor of push
- Refactor of pool (using atomics for most things)
- Introduction of hooks in the pool
Author: Nedyalko Dyakov
Date: 2025-08-18 22:14:06 +03:00
Parent: 36f9f58c67
Commit: 5649ffb314
46 changed files with 6347 additions and 249 deletions

3
.gitignore vendored

@@ -9,3 +9,6 @@ coverage.txt
**/coverage.txt
.vscode
tmp/*
# Hitless upgrade documentation (temporary)
hitless/docs/

149
adapters.go Normal file

@@ -0,0 +1,149 @@
package redis
import (
"context"
"errors"
"net"
"time"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand.
var ErrInvalidCommand = errors.New("invalid command type")
// ErrInvalidPool is returned when the pool type is not supported.
var ErrInvalidPool = errors.New("invalid pool type")
// newClientAdapter creates a new client adapter for regular Redis clients.
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
return &clientAdapter{client: client}
}
// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
type clientAdapter struct {
client *baseClient
}
// GetOptions returns the client options.
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
return &optionsAdapter{options: ca.client.opt}
}
// GetPushProcessor returns the client's push notification processor.
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
return &pushProcessorAdapter{processor: ca.client.pushProcessor}
}
// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
type optionsAdapter struct {
options *Options
}
// GetReadTimeout returns the read timeout.
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
return oa.options.ReadTimeout
}
// GetWriteTimeout returns the write timeout.
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
return oa.options.WriteTimeout
}
// GetNetwork returns the network type.
func (oa *optionsAdapter) GetNetwork() string {
return oa.options.Network
}
// GetAddr returns the connection address.
func (oa *optionsAdapter) GetAddr() string {
return oa.options.Addr
}
// IsTLSEnabled returns true if TLS is enabled.
func (oa *optionsAdapter) IsTLSEnabled() bool {
return oa.options.TLSConfig != nil
}
// GetProtocol returns the protocol version.
func (oa *optionsAdapter) GetProtocol() int {
return oa.options.Protocol
}
// GetPoolSize returns the connection pool size.
func (oa *optionsAdapter) GetPoolSize() int {
return oa.options.PoolSize
}
// NewDialer returns a new dialer function for the connection.
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
baseDialer := oa.options.NewDialer()
return func(ctx context.Context) (net.Conn, error) {
// Extract network and address from the options
network := oa.options.Network
addr := oa.options.Addr
return baseDialer(ctx, network, addr)
}
}
// connectionAdapter adapts a Redis connection to interfaces.ConnectionWithRelaxedTimeout
type connectionAdapter struct {
conn *pool.Conn
}
// Close closes the connection.
func (ca *connectionAdapter) Close() error {
return ca.conn.Close()
}
// IsUsable returns true if the connection is safe to use for new commands.
func (ca *connectionAdapter) IsUsable() bool {
return ca.conn.IsUsable()
}
// GetPoolConnection returns the underlying pool connection.
func (ca *connectionAdapter) GetPoolConnection() *pool.Conn {
return ca.conn
}
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
// These timeouts remain active until explicitly cleared.
func (ca *connectionAdapter) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
ca.conn.SetRelaxedTimeout(readTimeout, writeTimeout)
}
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
func (ca *connectionAdapter) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
ca.conn.SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout, deadline)
}
// ClearRelaxedTimeout clears relaxed timeouts for this connection.
func (ca *connectionAdapter) ClearRelaxedTimeout() {
ca.conn.ClearRelaxedTimeout()
}
// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
type pushProcessorAdapter struct {
processor push.NotificationProcessor
}
// RegisterHandler registers a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
if pushHandler, ok := handler.(push.NotificationHandler); ok {
return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
}
return errors.New("handler must implement push.NotificationHandler")
}
// UnregisterHandler removes a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
return ppa.processor.UnregisterHandler(pushNotificationName)
}
// GetHandler returns the handler for a specific push notification name.
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
return ppa.processor.GetHandler(pushNotificationName)
}
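// Sketch (not part of this diff): how a consumer on the other side of these
// narrow interfaces might use the adapter. describeClient is a hypothetical
// helper; it touches only the methods defined above.
//
//	func describeClient(c interfaces.ClientInterface) string {
//		opts := c.GetOptions()
//		scheme := "redis"
//		if opts.IsTLSEnabled() {
//			scheme = "rediss"
//		}
//		return fmt.Sprintf("%s://%s (RESP%d, pool=%d)",
//			scheme, opts.GetAddr(), opts.GetProtocol(), opts.GetPoolSize())
//	}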


@@ -0,0 +1,348 @@
package redis
import (
"context"
"net"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal/pool"
)
// mockNetConn implements net.Conn for testing
type mockNetConn struct {
addr string
}
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (m *mockNetConn) Close() error { return nil }
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
type mockAddr struct {
addr string
}
func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string { return m.addr }
// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow
func TestEventDrivenHandoffIntegration(t *testing.T) {
t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) {
// Create a base dialer for testing
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
// Create processor with event-driven handoff support
processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create a test pool with hooks
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(5),
PoolTimeout: time.Second,
})
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
defer testPool.Close()
// Set the pool reference in the processor for connection removal on handoff failure
processor.SetPool(testPool)
ctx := context.Background()
// Get a connection and mark it for handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
// Set initialization function with a small delay to ensure handoff is pending
initConnCalled := false
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
initConnCalled = true
return nil
}
conn.SetInitConnFunc(initConnFunc)
// Mark connection for handoff
err = conn.MarkForHandoff("new-endpoint:6379", 12345)
if err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Return connection to pool - this should queue handoff
testPool.Put(ctx, conn)
// Give the on-demand worker a moment to start processing
time.Sleep(10 * time.Millisecond)
// Verify handoff was queued
if !processor.IsHandoffPending(conn) {
t.Error("Handoff should be queued in pending map")
}
// Try to get the same connection - should be skipped due to pending handoff
conn2, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get second connection: %v", err)
}
// Should get a different connection (the pending one should be skipped)
if conn == conn2 {
t.Error("Should have gotten a different connection while handoff is pending")
}
// Return the second connection
testPool.Put(ctx, conn2)
// Wait for handoff to complete
time.Sleep(200 * time.Millisecond)
// Verify handoff completed (removed from pending map)
if processor.IsHandoffPending(conn) {
t.Error("Handoff should have completed and been removed from pending map")
}
if !initConnCalled {
t.Error("InitConn should have been called during handoff")
}
// Now the original connection should be available again
conn3, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get third connection: %v", err)
}
// Could be the original connection (now handed off) or a new one
testPool.Put(ctx, conn3)
})
t.Run("ConcurrentHandoffs", func(t *testing.T) {
// Create a base dialer that simulates slow handoffs
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
time.Sleep(50 * time.Millisecond) // Simulate network delay
return &mockNetConn{addr: addr}, nil
}
processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(10),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
var wg sync.WaitGroup
// Start multiple concurrent handoffs
for i := 0; i < 5; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
// Get connection
conn, err := testPool.Get(ctx)
if err != nil {
t.Errorf("Failed to get connection %d: %v", id, err)
return
}
// Set initialization function
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
return nil
}
conn.SetInitConnFunc(initConnFunc)
// Mark for handoff
conn.MarkForHandoff("new-endpoint:6379", int64(id))
// Return to pool (starts async handoff)
testPool.Put(ctx, conn)
}(i)
}
wg.Wait()
// Wait for all handoffs to complete
time.Sleep(300 * time.Millisecond)
// Verify pool is still functional
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err)
}
testPool.Put(ctx, conn)
})
t.Run("HandoffFailureRecovery", func(t *testing.T) {
// Create a failing base dialer
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}}
}
processor := hitless.NewPoolHook(failingDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(3),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
// Get connection and mark for handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
conn.MarkForHandoff("unreachable-endpoint:6379", 12345)
// Return to pool (starts async handoff that will fail)
testPool.Put(ctx, conn)
// Wait for handoff to fail
time.Sleep(200 * time.Millisecond)
// Connection should be removed from pending map after failed handoff
if processor.IsHandoffPending(conn) {
t.Error("Connection should be removed from pending map after failed handoff")
}
// Pool should still be functional
conn2, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Pool should still be functional: %v", err)
}
// In event-driven approach, the original connection remains in pool
// even after failed handoff (it's still a valid connection)
// We might get the same connection or a different one
testPool.Put(ctx, conn2)
})
t.Run("GracefulShutdown", func(t *testing.T) {
// Create a slow base dialer
slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
time.Sleep(100 * time.Millisecond)
return &mockNetConn{addr: addr}, nil
}
processor := hitless.NewPoolHook(slowDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(2),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
// Start a handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function with delay to ensure handoff is pending
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
return nil
})
testPool.Put(ctx, conn)
// Give the on-demand worker a moment to start and begin processing
// The handoff should be pending because the slowDialer takes 100ms
time.Sleep(10 * time.Millisecond)
// Verify handoff was queued and is being processed
if !processor.IsHandoffPending(conn) {
t.Error("Handoff should be queued in pending map")
}
// Give the handoff a moment to start processing
time.Sleep(50 * time.Millisecond)
// Shutdown processor gracefully
// Use a longer timeout to account for slow dialer (100ms) plus processing overhead
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err = processor.Shutdown(shutdownCtx)
if err != nil {
t.Errorf("Graceful shutdown should succeed: %v", err)
}
// Handoff should have completed (removed from pending map)
if processor.IsHandoffPending(conn) {
t.Error("Handoff should have completed and been removed from pending map after shutdown")
}
})
}


@@ -193,6 +193,7 @@ type Cmdable interface {
ClientID(ctx context.Context) *IntCmd
ClientUnblock(ctx context.Context, id int64) *IntCmd
ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
ConfigResetStat(ctx context.Context) *StatusCmd
ConfigSet(ctx context.Context, parameter, value string) *StatusCmd
@@ -518,6 +519,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
return cmd
}
// ClientMaintNotifications enables or disables maintenance notifications for hitless upgrades.
// When enabled, the client will receive push notifications about Redis maintenance events.
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
args := []interface{}{"client", "maint_notifications"}
if enabled {
if endpointType == "" {
endpointType = "none"
}
args = append(args, "on", "moving-endpoint-type", endpointType)
} else {
args = append(args, "off")
}
cmd := NewStatusCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
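// Usage sketch (not part of this diff): the hitless upgrade handshake normally
// issues this command automatically on RESP3 connections; calling it by hand is
// mainly useful for experimentation. Assumes an existing *redis.Client rdb and ctx:
//
//	status, err := rdb.ClientMaintNotifications(ctx, true, "external-ip").Result()
//	if err != nil {
//		// servers without maintenance notification support return an error here
//	}
//	_ = status // "OK" on success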
// ------------------------------------------------------------------------------------------------
func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {

12
example/pubsub/go.mod Normal file

@@ -0,0 +1,12 @@
module github.com/redis/go-redis/example/pubsub
go 1.18
replace github.com/redis/go-redis/v9 => ../..
require github.com/redis/go-redis/v9 v9.11.0
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)

6
example/pubsub/go.sum Normal file

@@ -0,0 +1,6 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=

146
example/pubsub/main.go Normal file

@@ -0,0 +1,146 @@
package main
import (
"context"
"fmt"
"log"
"sync"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/hitless"
)
var ctx = context.Background()
// This example is not meant to be run as-is. It is a stress test used to observe how pubsub behaves in relation to pool management
// and to find regressions in pool management in hitless mode.
// Please don't use it as a reference for how to use pubsub.
func main() {
wg := &sync.WaitGroup{}
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
Mode: hitless.MaintNotificationsEnabled,
},
})
_ = rdb.FlushDB(ctx).Err()
go func() {
for {
time.Sleep(2 * time.Second)
fmt.Printf("pool stats: %+v\n", rdb.PoolStats())
}
}()
err := rdb.Ping(ctx).Err()
if err != nil {
panic(err)
}
if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil {
panic(err)
}
fmt.Println("published", rdb.Get(ctx, "published").Val())
fmt.Println("received", rdb.Get(ctx, "received").Val())
subCtx, cancelSubCtx := context.WithCancel(ctx)
pubCtx, cancelPublishers := context.WithCancel(ctx)
for i := 0; i < 10; i++ {
wg.Add(1)
go subscribe(subCtx, rdb, "test", i, wg)
}
time.Sleep(time.Second)
cancelSubCtx()
time.Sleep(time.Second)
subCtx, cancelSubCtx = context.WithCancel(ctx)
for i := 0; i < 10; i++ {
if err := rdb.Incr(ctx, "publishers").Err(); err != nil {
panic(err)
}
wg.Add(1)
go floodThePool(pubCtx, rdb, wg)
}
for i := 0; i < 500; i++ {
if err := rdb.Incr(ctx, "subscribers").Err(); err != nil {
panic(err)
}
wg.Add(1)
go subscribe(subCtx, rdb, "test2", i, wg)
}
time.Sleep(5 * time.Second)
fmt.Println("canceling publishers")
cancelPublishers()
time.Sleep(10 * time.Second)
fmt.Println("canceling subscribers")
cancelSubCtx()
wg.Wait()
published, _ := rdb.Get(ctx, "published").Result()
received, _ := rdb.Get(ctx, "received").Result()
publishers, _ := rdb.Get(ctx, "publishers").Result()
subscribers, _ := rdb.Get(ctx, "subscribers").Result()
fmt.Printf("publishers: %s\n", publishers)
fmt.Printf("published: %s\n", published)
fmt.Printf("subscribers: %s\n", subscribers)
fmt.Printf("received: %s\n", received)
publishedInt, _ := rdb.Get(ctx, "published").Int()
subscribersInt, _ := rdb.Get(ctx, "subscribers").Int()
fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt)
time.Sleep(2 * time.Second)
}
func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
}
err := rdb.Publish(ctx, "test2", "hello").Err()
if err != nil {
// noop
//log.Println("publish error:", err)
}
err = rdb.Incr(ctx, "published").Err()
if err != nil {
// noop
//log.Println("incr error:", err)
}
time.Sleep(10 * time.Nanosecond)
}
}
func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) {
defer wg.Done()
rec := rdb.Subscribe(ctx, topic)
recChan := rec.Channel()
for {
select {
case <-ctx.Done():
rec.Close()
return
default:
select {
case <-ctx.Done():
rec.Close()
return
case msg := <-recChan:
err := rdb.Incr(ctx, "received").Err()
if err != nil {
log.Println("incr error:", err)
}
_ = msg // Use the message to avoid unused variable warning
}
}
}
}


@@ -57,6 +57,8 @@ func Example_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// finished processing: <[ping]>
}
@@ -78,6 +80,8 @@ func ExamplePipeline_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// pipeline finished processing: [[ping] [ping]]
}
@@ -99,6 +103,8 @@ func ExampleClient_Watch_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// finished processing: <[watch foo]>
// starting processing: <[ping]>
// finished processing: <[ping]>

72
hitless/README.md Normal file

@@ -0,0 +1,72 @@
# Hitless Upgrades
Seamless Redis connection handoffs during topology changes without interrupting operations.
## Quick Start
```go
import "github.com/redis/go-redis/v9/hitless"
opt := &redis.Options{
Addr: "localhost:6379",
Protocol: 3, // RESP3 required
HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
Mode: hitless.MaintNotificationsEnabled, // or MaintNotificationsAuto
},
}
client := redis.NewClient(opt)
```
## Modes
- **`MaintNotificationsDisabled`**: Hitless upgrades are completely disabled
- **`MaintNotificationsEnabled`**: Hitless upgrades are forcefully enabled (fails if server doesn't support it)
- **`MaintNotificationsAuto`**: Hitless upgrades are enabled if server supports it (default)
## Configuration
```go
import "github.com/redis/go-redis/v9/hitless"
config := &hitless.Config{
Mode: hitless.MaintNotificationsAuto, // Notification mode
MaxHandoffRetries: 3, // Retry failed handoffs
HandoffTimeout: 15 * time.Second, // Handoff operation timeout
RelaxedTimeout: 10 * time.Second, // Extended timeout during migrations
PostHandoffRelaxedDuration: 20 * time.Second, // Keep relaxed timeout after handoff
LogLevel: 1, // 0=errors, 1=warnings, 2=info, 3=debug
MaxWorkers: 15, // Concurrent handoff workers
HandoffQueueSize: 50, // Handoff request queue size
}
```
### Worker Scaling
- **Auto-calculated**: `min(10, PoolSize/3)` - scales with pool size, capped at 10
- **Explicit values**: `max(10, set_value)` - enforces minimum 10 workers
- **On-demand**: Workers created when needed, cleaned up when idle
### Queue Sizing
- **Auto-calculated**: `10 × MaxWorkers`, capped by pool size
- **Always capped**: Queue size never exceeds pool size
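A worked example of these rules (illustrative arithmetic, not library output):
```go
// PoolSize = 30:
//   MaxWorkers       = min(10, 30/3)  = 10
//   HandoffQueueSize = min(10*10, 30) = 30  // capped by pool size
//
// PoolSize = 6:
//   MaxWorkers       = min(10, 6/3)   = 2
//   HandoffQueueSize = min(10*2, 6)   = 6   // still above the minimum of 2
```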
## Metrics Hook Example
A metrics-collection hook in `example_hooks.go` demonstrates how to monitor hitless upgrade notification processing:
```go
import "github.com/redis/go-redis/v9/hitless"
metricsHook := hitless.NewMetricsHook()
// Use with your monitoring system
```
The metrics hook tracks:
- Notification counts by type
- Processing durations
- Error counts
## Requirements
- **RESP3 Protocol**: Required for push notifications

377
hitless/config.go Normal file

@@ -0,0 +1,377 @@
package hitless
import (
"net"
"runtime"
"time"
"github.com/redis/go-redis/v9/internal/util"
)
// MaintNotificationsMode represents the maintenance notifications mode
type MaintNotificationsMode string
// Constants for maintenance push notifications modes
const (
MaintNotificationsDisabled MaintNotificationsMode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
MaintNotificationsEnabled MaintNotificationsMode = "enabled" // Client forcefully sends command, interrupts connection on error
MaintNotificationsAuto MaintNotificationsMode = "auto" // Client tries to send command, disables feature on error
)
// IsValid returns true if the maintenance notifications mode is valid
func (m MaintNotificationsMode) IsValid() bool {
switch m {
case MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto:
return true
default:
return false
}
}
// String returns the string representation of the mode
func (m MaintNotificationsMode) String() string {
return string(m)
}
// EndpointType represents the type of endpoint to request in MOVING notifications
type EndpointType string
// Constants for endpoint types
const (
EndpointTypeAuto EndpointType = "auto" // Auto-detect based on connection
EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address
EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN
EndpointTypeExternalIP EndpointType = "external-ip" // External IP address
EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN
EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config)
)
// IsValid returns true if the endpoint type is valid
func (e EndpointType) IsValid() bool {
switch e {
case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone:
return true
default:
return false
}
}
// String returns the string representation of the endpoint type
func (e EndpointType) String() string {
return string(e)
}
// Config provides configuration options for hitless upgrades.
type Config struct {
// Mode controls how client maintenance notifications are handled.
// Valid values: MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto
// Default: MaintNotificationsAuto
Mode MaintNotificationsMode
// EndpointType specifies the type of endpoint to request in MOVING notifications.
// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
// EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
// Default: EndpointTypeAuto
EndpointType EndpointType
// RelaxedTimeout is the concrete timeout value to use during
// MIGRATING/FAILING_OVER states to accommodate increased latency.
// This applies to both read and write timeouts.
// Default: 10 seconds
RelaxedTimeout time.Duration
// HandoffTimeout is the maximum time to wait for connection handoff to complete.
// If handoff takes longer than this, the old connection will be forcibly closed.
// Default: 15 seconds (matches server-side eviction timeout)
HandoffTimeout time.Duration
// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
// Workers are created on-demand and automatically cleaned up when idle.
// If zero, defaults to min(10, PoolSize/3) to handle bursts effectively.
// If explicitly set, enforces minimum of 10 workers.
//
// Default: min(10, PoolSize/3), Minimum when set: 10
MaxWorkers int
// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
// If the queue is full, new handoff requests will be rejected.
// Always capped by pool size since you can't handoff more connections than exist.
//
// Default: 10x max workers, capped by pool size, min 2
HandoffQueueSize int
// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
// after a handoff completes. This provides additional resilience during cluster transitions.
// Default: 2 * RelaxedTimeout
PostHandoffRelaxedDuration time.Duration
// ScaleDownDelay is the delay before checking if workers should be scaled down.
// This prevents expensive checks on every handoff completion and avoids rapid scaling cycles.
// Default: 2 seconds
ScaleDownDelay time.Duration
// LogLevel controls the verbosity of hitless upgrade logging.
// 0 = errors only, 1 = warnings, 2 = info, 3 = debug
// Default: 1 (warnings)
LogLevel int
// MaxHandoffRetries is the maximum number of times to retry a failed handoff.
// After this many retries, the connection will be removed from the pool.
// Default: 3
MaxHandoffRetries int
}
func (c *Config) IsEnabled() bool {
return c != nil && c.Mode != MaintNotificationsDisabled
}
// DefaultConfig returns a Config with sensible defaults.
func DefaultConfig() *Config {
return &Config{
Mode: MaintNotificationsAuto, // Enable by default for Redis Cloud
EndpointType: EndpointTypeAuto, // Auto-detect based on connection
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: 0, // Auto-calculated based on pool size
HandoffQueueSize: 0, // Auto-calculated based on max workers
PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
ScaleDownDelay: 2 * time.Second,
LogLevel: 1,
// Connection Handoff Configuration
MaxHandoffRetries: 3,
}
}
// Validate checks if the configuration is valid.
func (c *Config) Validate() error {
if c.RelaxedTimeout <= 0 {
return ErrInvalidRelaxedTimeout
}
if c.HandoffTimeout <= 0 {
return ErrInvalidHandoffTimeout
}
// Validate worker configuration
// Allow 0 for auto-calculation, but negative values are invalid
if c.MaxWorkers < 0 {
return ErrInvalidHandoffWorkers
}
// HandoffQueueSize validation - allow 0 for auto-calculation
if c.HandoffQueueSize < 0 {
return ErrInvalidHandoffQueueSize
}
if c.PostHandoffRelaxedDuration < 0 {
return ErrInvalidPostHandoffRelaxedDuration
}
if c.LogLevel < 0 || c.LogLevel > 3 {
return ErrInvalidLogLevel
}
// Validate Mode (maintenance notifications mode)
if !c.Mode.IsValid() {
return ErrInvalidMaintNotifications
}
// Validate EndpointType
if !c.EndpointType.IsValid() {
return ErrInvalidEndpointType
}
// Validate configuration fields
if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
return ErrInvalidHandoffRetries
}
return nil
}
// ApplyDefaults applies default values to any zero-value fields in the configuration.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaults() *Config {
return c.ApplyDefaultsWithPoolSize(0)
}
// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
// using the provided pool size to calculate worker defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
if c == nil {
return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
}
defaults := DefaultConfig()
result := &Config{}
// Apply defaults for enum fields (empty/zero means not set)
if c.Mode == "" {
result.Mode = defaults.Mode
} else {
result.Mode = c.Mode
}
if c.EndpointType == "" {
result.EndpointType = defaults.EndpointType
} else {
result.EndpointType = c.EndpointType
}
// Apply defaults for duration fields (zero means not set)
if c.RelaxedTimeout <= 0 {
result.RelaxedTimeout = defaults.RelaxedTimeout
} else {
result.RelaxedTimeout = c.RelaxedTimeout
}
if c.HandoffTimeout <= 0 {
result.HandoffTimeout = defaults.HandoffTimeout
} else {
result.HandoffTimeout = c.HandoffTimeout
}
// Apply defaults for integer fields (zero means not set)
if c.HandoffQueueSize <= 0 {
result.HandoffQueueSize = defaults.HandoffQueueSize
} else {
result.HandoffQueueSize = c.HandoffQueueSize
}
// Copy worker configuration
result.MaxWorkers = c.MaxWorkers
// Apply worker defaults based on pool size
result.applyWorkerDefaults(poolSize)
// Apply queue size defaults based on max workers, capped by pool size
if c.HandoffQueueSize <= 0 {
// Queue size: 10x max workers, but never more than pool size
workerBasedSize := result.MaxWorkers * 10
result.HandoffQueueSize = util.Min(workerBasedSize, poolSize)
} else {
result.HandoffQueueSize = c.HandoffQueueSize
}
// Always cap queue size by pool size - no point having more queue slots than connections
result.HandoffQueueSize = util.Min(result.HandoffQueueSize, poolSize)
// Ensure minimum queue size of 2
if result.HandoffQueueSize < 2 {
result.HandoffQueueSize = 2
}
if c.PostHandoffRelaxedDuration <= 0 {
result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
} else {
result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
}
if c.ScaleDownDelay <= 0 {
result.ScaleDownDelay = defaults.ScaleDownDelay
} else {
result.ScaleDownDelay = c.ScaleDownDelay
}
// LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set
// We'll use the provided value as-is, since 0 is valid
result.LogLevel = c.LogLevel
// Apply defaults for configuration fields
if c.MaxHandoffRetries <= 0 {
result.MaxHandoffRetries = defaults.MaxHandoffRetries
} else {
result.MaxHandoffRetries = c.MaxHandoffRetries
}
return result
}
// Clone creates a deep copy of the configuration.
func (c *Config) Clone() *Config {
if c == nil {
return DefaultConfig()
}
return &Config{
Mode: c.Mode,
EndpointType: c.EndpointType,
RelaxedTimeout: c.RelaxedTimeout,
HandoffTimeout: c.HandoffTimeout,
MaxWorkers: c.MaxWorkers,
HandoffQueueSize: c.HandoffQueueSize,
PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
ScaleDownDelay: c.ScaleDownDelay,
LogLevel: c.LogLevel,
// Configuration fields
MaxHandoffRetries: c.MaxHandoffRetries,
}
}
// applyWorkerDefaults calculates and applies worker defaults based on pool size
func (c *Config) applyWorkerDefaults(poolSize int) {
// Calculate defaults based on pool size
if poolSize <= 0 {
poolSize = 10 * runtime.GOMAXPROCS(0)
}
if c.MaxWorkers == 0 {
// When not set: min(10, poolSize/3) - scale with the pool but never exceed 10 workers
c.MaxWorkers = util.Min(10, poolSize/3)
} else {
// When explicitly set: max(10, set_value) - ensure at least 10 workers
c.MaxWorkers = util.Max(10, c.MaxWorkers)
}
// Ensure minimum of 1 worker (fallback for very small pools)
if c.MaxWorkers < 1 {
c.MaxWorkers = 1
}
}
// DetectEndpointType automatically detects the appropriate endpoint type
// based on the connection address and TLS configuration.
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
// Parse the address to determine if it's an IP or hostname
isPrivate := isPrivateIP(addr)
var endpointType EndpointType
if tlsEnabled {
// TLS requires FQDN for certificate validation
if isPrivate {
endpointType = EndpointTypeInternalFQDN
} else {
endpointType = EndpointTypeExternalFQDN
}
} else {
// No TLS, can use IP addresses
if isPrivate {
endpointType = EndpointTypeInternalIP
} else {
endpointType = EndpointTypeExternalIP
}
}
return endpointType
}
// isPrivateIP checks if the given address is in a private IP range.
func isPrivateIP(addr string) bool {
// Extract host from "host:port" format
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr // Assume no port
}
ip := net.ParseIP(host)
if ip == nil {
return false // Not an IP address (likely hostname)
}
// Check for private/loopback ranges
return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
}
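// Illustrative expectations for DetectEndpointType (a sketch; the values follow
// from the logic above):
//
//	DetectEndpointType("10.0.0.5:6379", false)          // EndpointTypeInternalIP (private range)
//	DetectEndpointType("10.0.0.5:6379", true)           // EndpointTypeInternalFQDN (TLS prefers names)
//	DetectEndpointType("203.0.113.7:6379", false)       // EndpointTypeExternalIP (public range)
//	DetectEndpointType("redis.example.com:6379", true)  // EndpointTypeExternalFQDN (hostname, not an IP)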

427
hitless/config_test.go Normal file

@@ -0,0 +1,427 @@
package hitless
import (
"context"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/util"
)
func TestConfig(t *testing.T) {
t.Run("DefaultConfig", func(t *testing.T) {
config := DefaultConfig()
// MaxWorkers should be 0 in default config (auto-calculated)
if config.MaxWorkers != 0 {
t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers)
}
// HandoffQueueSize should be 0 in default config (auto-calculated)
if config.HandoffQueueSize != 0 {
t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize)
}
if config.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s, got %v", config.RelaxedTimeout)
}
// Test configuration fields have proper defaults
if config.MaxHandoffRetries != 3 {
t.Errorf("Expected MaxHandoffRetries to be 3, got %d", config.MaxHandoffRetries)
}
if config.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout)
}
if config.PostHandoffRelaxedDuration != 0 {
t.Errorf("Expected PostHandoffRelaxedDuration to be 0 (auto-calculated), got %v", config.PostHandoffRelaxedDuration)
}
// Test that defaults are applied correctly
configWithDefaults := config.ApplyDefaultsWithPoolSize(100)
if configWithDefaults.PostHandoffRelaxedDuration != 20*time.Second {
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout) after applying defaults, got %v", configWithDefaults.PostHandoffRelaxedDuration)
}
})
t.Run("ConfigValidation", func(t *testing.T) {
// Valid config with applied defaults
config := DefaultConfig().ApplyDefaults()
if err := config.Validate(); err != nil {
t.Errorf("Default config with applied defaults should be valid: %v", err)
}
// Invalid worker configuration (negative MaxWorkers)
config = &Config{
RelaxedTimeout: 30 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: -1, // This should be invalid
HandoffQueueSize: 100,
PostHandoffRelaxedDuration: 10 * time.Second,
LogLevel: 1,
MaxHandoffRetries: 3, // Add required field
}
if err := config.Validate(); err != ErrInvalidHandoffWorkers {
t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err)
}
// Invalid HandoffQueueSize
config = DefaultConfig().ApplyDefaults()
config.HandoffQueueSize = -1
if err := config.Validate(); err != ErrInvalidHandoffQueueSize {
t.Errorf("Expected ErrInvalidHandoffQueueSize, got %v", err)
}
// Invalid PostHandoffRelaxedDuration
config = DefaultConfig().ApplyDefaults()
config.PostHandoffRelaxedDuration = -1 * time.Second
if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration {
t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err)
}
})
t.Run("ConfigClone", func(t *testing.T) {
original := DefaultConfig()
original.MaxWorkers = 20
original.HandoffQueueSize = 200
cloned := original.Clone()
if cloned.MaxWorkers != 20 {
t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers)
}
if cloned.HandoffQueueSize != 200 {
t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize)
}
// Modify original to ensure clone is independent
original.MaxWorkers = 2
if cloned.MaxWorkers != 20 {
t.Error("Clone should be independent of original")
}
})
}
func TestApplyDefaults(t *testing.T) {
t.Run("NilConfig", func(t *testing.T) {
var config *Config
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// With nil config, should get default config with auto-calculated workers
if result.MaxWorkers <= 0 {
t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers)
}
// HandoffQueueSize should be auto-calculated (10 * MaxWorkers, capped by pool size)
workerBasedSize := result.MaxWorkers * 10
poolSize := 100 // Pool size passed to ApplyDefaultsWithPoolSize
expectedQueueSize := util.Min(workerBasedSize, poolSize)
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
}
})
t.Run("PartialConfig", func(t *testing.T) {
config := &Config{
MaxWorkers: 12, // Set this field explicitly
// Leave other fields as zero values
}
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Should keep the explicitly set values
if result.MaxWorkers != 12 {
t.Errorf("Expected MaxWorkers to be 12 (explicitly set), got %d", result.MaxWorkers)
}
// Should apply default for unset fields (auto-calculated queue size, capped by pool size)
workerBasedSize := result.MaxWorkers * 10
poolSize := 100 // Pool size passed to ApplyDefaultsWithPoolSize
expectedQueueSize := util.Min(workerBasedSize, poolSize)
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
}
// Test explicit queue size capping by pool size
configWithLargeQueue := &Config{
MaxWorkers: 5,
HandoffQueueSize: 1000, // Much larger than pool size
}
resultCapped := configWithLargeQueue.ApplyDefaultsWithPoolSize(20) // Small pool size
if resultCapped.HandoffQueueSize != 20 {
t.Errorf("Expected HandoffQueueSize to be capped by pool size (20), got %d", resultCapped.HandoffQueueSize)
}
if result.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
}
if result.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", result.HandoffTimeout)
}
})
t.Run("ZeroValues", func(t *testing.T) {
config := &Config{
MaxWorkers: 0, // Zero value should get auto-calculated defaults
HandoffQueueSize: 0, // Zero value should get default
RelaxedTimeout: 0, // Zero value should get default
LogLevel: 0, // Zero is valid for LogLevel (errors only)
}
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Zero values should get auto-calculated defaults
if result.MaxWorkers <= 0 {
t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers)
}
// HandoffQueueSize should be auto-calculated (10 * MaxWorkers, capped by pool size)
workerBasedSize := result.MaxWorkers * 10
poolSize := 100 // Pool size passed to ApplyDefaultsWithPoolSize
expectedQueueSize := util.Min(workerBasedSize, poolSize)
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
}
if result.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
}
// LogLevel 0 should be preserved (it's a valid value)
if result.LogLevel != 0 {
t.Errorf("Expected LogLevel to be 0 (preserved), got %d", result.LogLevel)
}
})
}
func TestProcessorWithConfig(t *testing.T) {
t.Run("ProcessorUsesConfigValues", func(t *testing.T) {
config := &Config{
MaxWorkers: 5,
HandoffQueueSize: 50,
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 5 * time.Second,
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// The processor should be created successfully with custom config
if processor == nil {
t.Error("Processor should be created with custom config")
}
})
t.Run("ProcessorWithPartialConfig", func(t *testing.T) {
config := &Config{
MaxWorkers: 7, // Only set worker field
// Other fields will get defaults
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Should work with partial config (defaults applied)
if processor == nil {
t.Error("Processor should be created with partial config")
}
})
t.Run("ProcessorWithNilConfig", func(t *testing.T) {
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Should use default config when nil is passed
if processor == nil {
t.Error("Processor should be created with nil config (using defaults)")
}
})
}
func TestIntegrationWithApplyDefaults(t *testing.T) {
t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) {
// Create a partial config with only some fields set
partialConfig := &Config{
MaxWorkers: 15, // Custom value (>= 10 to test preservation)
LogLevel: 2, // Custom value
// Other fields left as zero values - should get defaults
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
// Create processor - should apply defaults to missing fields
processor := NewPoolHook(baseDialer, "tcp", partialConfig, nil)
defer processor.Shutdown(context.Background())
// Processor should be created successfully
if processor == nil {
t.Error("Processor should be created with partial config")
}
// Test that the ApplyDefaults method worked correctly by creating the same config
// and applying defaults manually
expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Should preserve custom values (when >= 10)
if expectedConfig.MaxWorkers != 15 {
t.Errorf("Expected MaxWorkers to be 15, got %d", expectedConfig.MaxWorkers)
}
if expectedConfig.LogLevel != 2 {
t.Errorf("Expected LogLevel to be 2, got %d", expectedConfig.LogLevel)
}
// Should apply defaults for missing fields (auto-calculated queue size, capped by pool size)
workerBasedSize := expectedConfig.MaxWorkers * 10
poolSize := 100 // Pool size passed to ApplyDefaultsWithPoolSize
expectedQueueSize := util.Min(workerBasedSize, poolSize)
if expectedConfig.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
expectedQueueSize, workerBasedSize, poolSize, expectedConfig.HandoffQueueSize)
}
// Test that queue size is always capped by pool size
if expectedConfig.HandoffQueueSize > poolSize {
t.Errorf("HandoffQueueSize (%d) should never exceed pool size (%d)",
expectedConfig.HandoffQueueSize, poolSize)
}
if expectedConfig.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", expectedConfig.RelaxedTimeout)
}
if expectedConfig.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout)
}
if expectedConfig.PostHandoffRelaxedDuration != 20*time.Second {
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout), got %v", expectedConfig.PostHandoffRelaxedDuration)
}
})
}
func TestEnhancedConfigValidation(t *testing.T) {
t.Run("ValidateFields", func(t *testing.T) {
config := DefaultConfig()
config.ApplyDefaultsWithPoolSize(100) // Apply defaults with pool size 100
// Should pass validation with default values
if err := config.Validate(); err != nil {
t.Errorf("Default config should be valid, got error: %v", err)
}
// Test invalid MaxHandoffRetries
config.MaxHandoffRetries = 0
if err := config.Validate(); err == nil {
t.Error("Expected validation error for MaxHandoffRetries = 0")
}
config.MaxHandoffRetries = 11
if err := config.Validate(); err == nil {
t.Error("Expected validation error for MaxHandoffRetries = 11")
}
config.MaxHandoffRetries = 3 // Reset to valid value
// Should pass validation again
if err := config.Validate(); err != nil {
t.Errorf("Config should be valid after reset, got error: %v", err)
}
})
}
func TestConfigClone(t *testing.T) {
original := DefaultConfig()
original.MaxHandoffRetries = 7
original.HandoffTimeout = 8 * time.Second
cloned := original.Clone()
// Test that values are copied
if cloned.MaxHandoffRetries != 7 {
t.Errorf("Expected cloned MaxHandoffRetries to be 7, got %d", cloned.MaxHandoffRetries)
}
if cloned.HandoffTimeout != 8*time.Second {
t.Errorf("Expected cloned HandoffTimeout to be 8s, got %v", cloned.HandoffTimeout)
}
// Test that modifying clone doesn't affect original
cloned.MaxHandoffRetries = 10
if original.MaxHandoffRetries != 7 {
t.Errorf("Modifying clone should not affect original, original MaxHandoffRetries changed to %d", original.MaxHandoffRetries)
}
}
func TestMaxWorkersLogic(t *testing.T) {
t.Run("AutoCalculatedMaxWorkers", func(t *testing.T) {
testCases := []struct {
poolSize int
expectedWorkers int
description string
}{
{6, 2, "Small pool: min(10, 6/3) = min(10, 2) = 2"},
{15, 5, "Medium pool: min(10, 15/3) = min(10, 5) = 5"},
{30, 10, "Large pool: min(10, 30/3) = min(10, 10) = 10"},
{60, 10, "Very large pool: min(10, 60/3) = min(10, 20) = 10"},
{120, 10, "Huge pool: min(10, 120/3) = min(10, 40) = 10"},
}
for _, tc := range testCases {
config := &Config{} // MaxWorkers = 0 (not set)
result := config.ApplyDefaultsWithPoolSize(tc.poolSize)
if result.MaxWorkers != tc.expectedWorkers {
t.Errorf("PoolSize=%d: expected MaxWorkers=%d, got %d (%s)",
tc.poolSize, tc.expectedWorkers, result.MaxWorkers, tc.description)
}
}
})
t.Run("ExplicitlySetMaxWorkers", func(t *testing.T) {
testCases := []struct {
setValue int
expectedWorkers int
description string
}{
{1, 10, "Set 1: max(10, 1) = 10 (enforced minimum)"},
{5, 10, "Set 5: max(10, 5) = 10 (enforced minimum)"},
{8, 10, "Set 8: max(10, 8) = 10 (enforced minimum)"},
{10, 10, "Set 10: max(10, 10) = 10 (exact minimum)"},
{15, 15, "Set 15: max(10, 15) = 15 (respects user choice)"},
{20, 20, "Set 20: max(10, 20) = 20 (respects user choice)"},
}
for _, tc := range testCases {
config := &Config{
MaxWorkers: tc.setValue, // Explicitly set
}
result := config.ApplyDefaultsWithPoolSize(100) // Pool size doesn't affect explicit values
if result.MaxWorkers != tc.expectedWorkers {
t.Errorf("Set MaxWorkers=%d: expected %d, got %d (%s)",
tc.setValue, tc.expectedWorkers, result.MaxWorkers, tc.description)
}
}
})
}

76
hitless/errors.go Normal file

@@ -0,0 +1,76 @@
package hitless
import (
"errors"
"fmt"
)
// Configuration errors
var (
ErrInvalidRelaxedTimeout = errors.New("hitless: relaxed timeout must be greater than 0")
ErrInvalidHandoffTimeout = errors.New("hitless: handoff timeout must be greater than 0")
ErrInvalidHandoffWorkers = errors.New("hitless: MaxWorkers must be greater than or equal to 0")
ErrInvalidHandoffQueueSize = errors.New("hitless: handoff queue size must be greater than or equal to 0 (0 means auto-calculated)")
ErrInvalidPostHandoffRelaxedDuration = errors.New("hitless: post-handoff relaxed duration must be greater than or equal to 0")
ErrInvalidLogLevel = errors.New("hitless: log level must be between 0 and 3")
ErrInvalidEndpointType = errors.New("hitless: invalid endpoint type")
ErrInvalidMaintNotifications = errors.New("hitless: invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')")
ErrMaxHandoffRetriesReached = errors.New("hitless: max handoff retries reached")
// Configuration validation errors
ErrInvalidHandoffRetries = errors.New("hitless: MaxHandoffRetries must be between 1 and 10")
ErrInvalidConnectionValidationTimeout = errors.New("hitless: ConnectionValidationTimeout must be greater than 0 and less than 30 seconds")
ErrInvalidConnectionHealthCheckInterval = errors.New("hitless: ConnectionHealthCheckInterval must be between 0 and 1 hour")
ErrInvalidOperationCleanupInterval = errors.New("hitless: OperationCleanupInterval must be greater than 0 and less than 1 hour")
ErrInvalidMaxActiveOperations = errors.New("hitless: MaxActiveOperations must be between 100 and 100000")
ErrInvalidNotificationBufferSize = errors.New("hitless: NotificationBufferSize must be between 10 and 10000")
ErrInvalidNotificationTimeout = errors.New("hitless: NotificationTimeout must be greater than 0 and less than 30 seconds")
)
// Integration errors
var (
ErrInvalidClient = errors.New("hitless: invalid client type")
)
// Handoff errors
var (
ErrHandoffInProgress = errors.New("hitless: handoff already in progress")
ErrNoHandoffInProgress = errors.New("hitless: no handoff in progress")
ErrConnectionFailed = errors.New("hitless: failed to establish new connection")
)
// Dead error variables removed - unused in simplified architecture
// Notification errors
var (
ErrInvalidNotification = errors.New("hitless: invalid notification format")
)
// HandoffError represents an error that occurred during connection handoff.
type HandoffError struct {
Operation string
Endpoint string
Cause error
}
func (e *HandoffError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("hitless: handoff %s failed for endpoint %s: %v", e.Operation, e.Endpoint, e.Cause)
}
return fmt.Sprintf("hitless: handoff %s failed for endpoint %s", e.Operation, e.Endpoint)
}
func (e *HandoffError) Unwrap() error {
return e.Cause
}
// NewHandoffError creates a new HandoffError.
func NewHandoffError(operation, endpoint string, cause error) *HandoffError {
return &HandoffError{
Operation: operation,
Endpoint: endpoint,
Cause: cause,
}
}
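// Sketch (not part of this diff): because HandoffError implements Unwrap, the
// standard errors helpers can match both the wrapper and the underlying cause.
//
//	err := NewHandoffError("dial", "new-endpoint:6379", ErrConnectionFailed)
//	errors.Is(err, ErrConnectionFailed) // true: Is walks Unwrap() to the sentinel
//
//	var hoErr *HandoffError
//	if errors.As(err, &hoErr) {
//		_ = hoErr.Endpoint // "new-endpoint:6379"
//	}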

63
hitless/example_hooks.go Normal file

@@ -0,0 +1,63 @@
package hitless
import (
"context"
"time"
)
// contextKey is a custom type for context keys to avoid collisions
type contextKey string
const (
startTimeKey contextKey = "notif_hitless_start_time"
)
// MetricsHook collects metrics about notification processing.
type MetricsHook struct {
NotificationCounts map[string]int64
ProcessingTimes map[string]time.Duration
ErrorCounts map[string]int64
}
// NewMetricsHook creates a new metrics collection hook.
func NewMetricsHook() *MetricsHook {
return &MetricsHook{
NotificationCounts: make(map[string]int64),
ProcessingTimes: make(map[string]time.Duration),
ErrorCounts: make(map[string]int64),
}
}
// PreHook records the start time for processing metrics.
func (mh *MetricsHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
mh.NotificationCounts[notificationType]++
// Record the start time; the hook API cannot return a modified context, so the
// derived context is discarded here. ProcessingTimes is therefore only populated
// when the caller itself stores the start time under startTimeKey.
startTime := time.Now()
_ = context.WithValue(ctx, startTimeKey, startTime)
return notification, true
}
// PostHook records processing completion and any errors.
func (mh *MetricsHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
// Calculate processing duration
if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
duration := time.Since(startTime)
mh.ProcessingTimes[notificationType] = duration
}
// Record errors
if result != nil {
mh.ErrorCounts[notificationType]++
}
}
// GetMetrics returns a summary of collected metrics.
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
return map[string]interface{}{
"notification_counts": mh.NotificationCounts,
"processing_times": mh.ProcessingTimes,
"error_counts": mh.ErrorCounts,
}
}
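// Sketch (not part of this diff): exercising the hook by hand. In practice the
// manager invokes PreHook/PostHook around each notification. The maps are not
// synchronized, so concurrent use would need locking; the notification payload
// shape below is hypothetical.
//
//	hook := NewMetricsHook()
//	ctx := context.WithValue(context.Background(), startTimeKey, time.Now())
//	notif := []interface{}{"MOVING", int64(1), "15", "new-endpoint:6379"}
//	if modified, proceed := hook.PreHook(ctx, "MOVING", notif); proceed {
//		hook.PostHook(ctx, "MOVING", modified, nil)
//	}
//	_ = hook.GetMetrics() // notification_counts, processing_times, error_counts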

299
hitless/hitless_manager.go Normal file

@@ -0,0 +1,299 @@
package hitless
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/internal/pool"
)
// Push notification type constants for hitless upgrades
const (
NotificationMoving = "MOVING"
NotificationMigrating = "MIGRATING"
NotificationMigrated = "MIGRATED"
NotificationFailingOver = "FAILING_OVER"
NotificationFailedOver = "FAILED_OVER"
)
// hitlessNotificationTypes contains all notification types that hitless upgrades handles
var hitlessNotificationTypes = []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
// NotificationHook is called before and after notification processing
// PreHook can modify the notification and return false to skip processing
// PostHook is called after successful processing
type NotificationHook interface {
PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool)
PostHook(ctx context.Context, notificationType string, notification []interface{}, result error)
}
// MovingOperationKey provides a unique key for tracking MOVING operations
// that combines sequence ID with connection identifier to handle duplicate
// sequence IDs across multiple connections to the same node.
type MovingOperationKey struct {
SeqID int64 // Sequence ID from MOVING notification
ConnID uint64 // Unique connection identifier
}
// String returns a string representation of the key for debugging
func (k MovingOperationKey) String() string {
return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
}
// HitlessManager provides a simplified hitless upgrade functionality with hooks and atomic state.
type HitlessManager struct {
client interfaces.ClientInterface
config *Config
options interfaces.OptionsInterface
pool pool.Pooler
// MOVING operation tracking - using sync.Map for better concurrent performance
activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation
// Atomic state tracking - no locks needed for state queries
activeOperationCount atomic.Int64 // Number of active operations
closed atomic.Bool // Manager closed state
// Notification hooks for extensibility
hooks []NotificationHook
hooksMu sync.RWMutex // Protects hooks slice
poolHooksRef *PoolHook
}
// MovingOperation tracks an active MOVING operation.
type MovingOperation struct {
SeqID int64
NewEndpoint string
StartTime time.Time
Deadline time.Time
}
// NewHitlessManager creates a new simplified hitless manager.
func NewHitlessManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*HitlessManager, error) {
if client == nil {
return nil, ErrInvalidClient
}
hm := &HitlessManager{
client: client,
pool: pool,
options: client.GetOptions(),
config: config.Clone(),
hooks: make([]NotificationHook, 0),
}
// Set up push notification handling
if err := hm.setupPushNotifications(); err != nil {
return nil, err
}
return hm, nil
}
// InitPoolHook creates a pool hook around the given dialer and registers it with the pool.
func (hm *HitlessManager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
poolHook := hm.createPoolHook(baseDialer)
hm.pool.AddPoolHook(poolHook)
}
// setupPushNotifications sets up push notification handling by registering with the client's processor.
func (hm *HitlessManager) setupPushNotifications() error {
processor := hm.client.GetPushProcessor()
if processor == nil {
return ErrInvalidClient // Client doesn't support push notifications
}
// Create our notification handler
handler := &NotificationHandler{manager: hm}
// Register handlers for all hitless upgrade notifications with the client's processor
for _, notificationType := range hitlessNotificationTypes {
if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
return fmt.Errorf("failed to register handler for %s: %w", notificationType, err)
}
}
return nil
}
// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Create MOVING operation record
movingOp := &MovingOperation{
SeqID: seqID,
NewEndpoint: newEndpoint,
StartTime: time.Now(),
Deadline: deadline,
}
// Use LoadOrStore for atomic check-and-set operation
if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
// Duplicate MOVING notification, ignore
internal.Logger.Printf(ctx, "Duplicate MOVING operation ignored: %s", key.String())
return nil
}
// Increment active operation count atomically
hm.activeOperationCount.Add(1)
return nil
}
// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64) {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Remove from active operations atomically
if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
// Decrement active operation count only if operation existed
hm.activeOperationCount.Add(-1)
}
}
// GetActiveMovingOperations returns active operations with composite keys.
func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
result := make(map[MovingOperationKey]*MovingOperation)
// Iterate over sync.Map to build result
hm.activeMovingOps.Range(func(key, value interface{}) bool {
k := key.(MovingOperationKey)
op := value.(*MovingOperation)
// Create a copy to avoid sharing references
result[k] = &MovingOperation{
SeqID: op.SeqID,
NewEndpoint: op.NewEndpoint,
StartTime: op.StartTime,
Deadline: op.Deadline,
}
return true // Continue iteration
})
return result
}
// IsHandoffInProgress returns true if any handoff is in progress.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) IsHandoffInProgress() bool {
return hm.activeOperationCount.Load() > 0
}
// GetActiveOperationCount returns the number of active operations.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) GetActiveOperationCount() int64 {
return hm.activeOperationCount.Load()
}
// Close closes the hitless manager.
func (hm *HitlessManager) Close() error {
// Use atomic operation for thread-safe close check
if !hm.closed.CompareAndSwap(false, true) {
return nil // Already closed
}
// Shutdown the pool hook if it exists
if hm.poolHooksRef != nil {
// Use a timeout to prevent hanging indefinitely
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := hm.poolHooksRef.Shutdown(shutdownCtx)
if err != nil {
// Shutdown failed; revert the closed flag so Close can be retried
hm.closed.Store(false)
return err
}
// Remove the pool hook from the pool
if hm.pool != nil {
hm.pool.RemovePoolHook(hm.poolHooksRef)
}
}
// Clear all active operations
hm.activeMovingOps.Range(func(key, value interface{}) bool {
hm.activeMovingOps.Delete(key)
return true
})
// Reset counter
hm.activeOperationCount.Store(0)
return nil
}
// GetState returns current state using atomic counter for lock-free operation.
func (hm *HitlessManager) GetState() State {
if hm.activeOperationCount.Load() > 0 {
return StateMoving
}
return StateIdle
}
// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
currentNotification := notification
for _, hook := range hm.hooks {
modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationType, currentNotification)
if !shouldContinue {
return modifiedNotification, false
}
currentNotification = modifiedNotification
}
return currentNotification, true
}
// processPostHooks calls all post-hooks with the processing result.
func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationType string, notification []interface{}, result error) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
for _, hook := range hm.hooks {
hook.PostHook(ctx, notificationType, notification, result)
}
}
// createPoolHook creates a pool hook with this manager already set.
func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
if hm.poolHooksRef != nil {
return hm.poolHooksRef
}
// Get pool size from client options for better worker defaults
poolSize := 0
if hm.options != nil {
poolSize = hm.options.GetPoolSize()
}
hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
hm.poolHooksRef.SetPool(hm.pool)
return hm.poolHooksRef
}
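// Lifecycle sketch (illustrative only; hm and ctx are assumed to exist):
// tracking and completing a single MOVING operation.
//
// deadline := time.Now().Add(30 * time.Second)
// _ = hm.TrackMovingOperationWithConnID(ctx, "10.0.0.2:6379", deadline, 42, 7)
// fmt.Println(hm.GetState()) // StateMoving
// hm.UntrackOperationWithConnID(42, 7)
// fmt.Println(hm.GetState()) // StateIdle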

View File

@@ -0,0 +1,260 @@
package hitless
import (
"context"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/interfaces"
)
// MockClient implements interfaces.ClientInterface for testing
type MockClient struct {
options interfaces.OptionsInterface
}
func (mc *MockClient) GetOptions() interfaces.OptionsInterface {
return mc.options
}
func (mc *MockClient) GetPushProcessor() interfaces.NotificationProcessor {
return &MockPushProcessor{}
}
// MockPushProcessor implements interfaces.NotificationProcessor for testing
type MockPushProcessor struct{}
func (mpp *MockPushProcessor) RegisterHandler(notificationType string, handler interface{}, protected bool) error {
return nil
}
func (mpp *MockPushProcessor) UnregisterHandler(pushNotificationName string) error {
return nil
}
func (mpp *MockPushProcessor) GetHandler(pushNotificationName string) interface{} {
return nil
}
// MockOptions implements interfaces.OptionsInterface for testing
type MockOptions struct{}
func (mo *MockOptions) GetReadTimeout() time.Duration {
return 5 * time.Second
}
func (mo *MockOptions) GetWriteTimeout() time.Duration {
return 5 * time.Second
}
func (mo *MockOptions) GetAddr() string {
return "localhost:6379"
}
func (mo *MockOptions) IsTLSEnabled() bool {
return false
}
func (mo *MockOptions) GetProtocol() int {
return 3 // RESP3
}
func (mo *MockOptions) GetPoolSize() int {
return 10
}
func (mo *MockOptions) GetNetwork() string {
return "tcp"
}
func (mo *MockOptions) NewDialer() func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
return nil, nil
}
}
func TestHitlessManagerRefactoring(t *testing.T) {
t.Run("AtomicStateTracking", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
// Test initial state
if manager.IsHandoffInProgress() {
t.Error("Expected no handoff in progress initially")
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateIdle {
t.Errorf("Expected StateIdle, got %v", manager.GetState())
}
// Add an operation
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
err = manager.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Failed to track operation: %v", err)
}
// Test state after adding operation
if !manager.IsHandoffInProgress() {
t.Error("Expected handoff in progress after adding operation")
}
if manager.GetActiveOperationCount() != 1 {
t.Errorf("Expected 1 active operation, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateMoving {
t.Errorf("Expected StateMoving, got %v", manager.GetState())
}
// Remove the operation
manager.UntrackOperationWithConnID(12345, 1)
// Test state after removing operation
if manager.IsHandoffInProgress() {
t.Error("Expected no handoff in progress after removing operation")
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateIdle {
t.Errorf("Expected StateIdle, got %v", manager.GetState())
}
})
t.Run("SyncMapPerformance", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
// Test concurrent operations
const numOps = 100
for i := 0; i < numOps; i++ {
err := manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, int64(i), uint64(i))
if err != nil {
t.Fatalf("Failed to track operation %d: %v", i, err)
}
}
if manager.GetActiveOperationCount() != numOps {
t.Errorf("Expected %d active operations, got %d", numOps, manager.GetActiveOperationCount())
}
// Test GetActiveMovingOperations
operations := manager.GetActiveMovingOperations()
if len(operations) != numOps {
t.Errorf("Expected %d operations in map, got %d", numOps, len(operations))
}
// Remove all operations
for i := 0; i < numOps; i++ {
manager.UntrackOperationWithConnID(int64(i), uint64(i))
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations after cleanup, got %d", manager.GetActiveOperationCount())
}
})
t.Run("DuplicateOperationHandling", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
// Add operation
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Failed to track operation: %v", err)
}
// Try to add duplicate operation
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Duplicate operation should not return error: %v", err)
}
// Should still have only 1 operation
if manager.GetActiveOperationCount() != 1 {
t.Errorf("Expected 1 active operation after duplicate, got %d", manager.GetActiveOperationCount())
}
})
t.Run("NotificationTypeConstants", func(t *testing.T) {
// Test that constants are properly defined
expectedTypes := []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
if len(hitlessNotificationTypes) != len(expectedTypes) {
t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(hitlessNotificationTypes))
}
// Test that all expected types are present
typeMap := make(map[string]bool)
for _, typ := range hitlessNotificationTypes {
typeMap[typ] = true
}
for _, expected := range expectedTypes {
if !typeMap[expected] {
t.Errorf("Expected notification type %s not found in hitlessNotificationTypes", expected)
}
}
})
}

48
hitless/hooks.go Normal file
View File

@@ -0,0 +1,48 @@
package hitless
import (
"context"
"github.com/redis/go-redis/v9/internal"
)
// LoggingHook is an example hook implementation that logs all notifications.
type LoggingHook struct {
LogLevel int
}
// PreHook logs the notification before processing and allows modification.
func (lh *LoggingHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
if lh.LogLevel >= 2 { // Info level
internal.Logger.Printf(ctx, "hitless: processing %s notification: %v", notificationType, notification)
}
return notification, true // Continue processing with unmodified notification
}
// PostHook logs the result after processing.
func (lh *LoggingHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
if result != nil && lh.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx, "hitless: %s notification processing failed: %v", notificationType, result)
} else if lh.LogLevel >= 3 { // Debug level
internal.Logger.Printf(ctx, "hitless: %s notification processed successfully", notificationType)
}
}
// FilterHook is an example hook that can filter out certain notifications.
type FilterHook struct {
BlockedTypes map[string]bool
}
// PreHook filters notifications based on type.
func (fh *FilterHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
if fh.BlockedTypes[notificationType] {
internal.Logger.Printf(ctx, "hitless: filtering out %s notification", notificationType)
return notification, false // Skip processing
}
return notification, true
}
// PostHook does nothing for filter hook.
func (fh *FilterHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
// No post-processing needed for filter hook
}
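// Construction sketch (illustrative): wiring up the example hooks. How hooks
// are registered with a manager is outside the scope of this file.
//
// logging := &LoggingHook{LogLevel: 2} // log at info level and above
// filter := &FilterHook{BlockedTypes: map[string]bool{
// NotificationMigrating: true, // drop MIGRATING notifications
// }}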

View File

@@ -0,0 +1,247 @@
package hitless
import (
"context"
"fmt"
"strconv"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// NotificationHandler handles push notifications for the simplified manager.
type NotificationHandler struct {
manager *HitlessManager
}
// HandlePushNotification processes push notifications with hook support.
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) == 0 {
return ErrInvalidNotification
}
notificationType, ok := notification[0].(string)
if !ok {
return ErrInvalidNotification
}
// Process pre-hooks - they can modify the notification or skip processing
modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, notificationType, notification)
if !shouldContinue {
return nil // Hooks decided to skip processing
}
var err error
switch notificationType {
case NotificationMoving:
err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
case NotificationMigrating:
err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
case NotificationMigrated:
err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
case NotificationFailingOver:
err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
case NotificationFailedOver:
err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
default:
// Ignore other notification types (e.g., pub/sub messages)
err = nil
}
// Process post-hooks with the result
snh.manager.processPostHooks(ctx, notificationType, modifiedNotification, err)
return err
}
// handleMoving processes MOVING notifications.
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 3 {
return ErrInvalidNotification
}
seqIDStr, ok := notification[1].(string)
if !ok {
return ErrInvalidNotification
}
seqID, err := strconv.ParseInt(seqIDStr, 10, 64)
if err != nil {
return ErrInvalidNotification
}
// Extract timeS
timeSStr, ok := notification[2].(string)
if !ok {
return ErrInvalidNotification
}
timeS, err := strconv.ParseInt(timeSStr, 10, 64)
if err != nil {
return ErrInvalidNotification
}
newEndpoint := ""
if len(notification) > 3 {
// Extract new endpoint
newEndpoint, ok = notification[3].(string)
if !ok {
return ErrInvalidNotification
}
}
// Get the connection that received this notification
conn := handlerCtx.Conn
if conn == nil {
return ErrInvalidNotification
}
// Type assert to get the underlying pool connection
var poolConn *pool.Conn
if connAdapter, ok := conn.(interface{ GetPoolConn() *pool.Conn }); ok {
poolConn = connAdapter.GetPoolConn()
} else if pc, ok := conn.(*pool.Conn); ok {
poolConn = pc
} else {
return ErrInvalidNotification
}
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
if newEndpoint == "" || newEndpoint == internal.RedisNull {
// same as current endpoint
newEndpoint = snh.manager.options.GetAddr()
// delay the handoff for timeS/2 seconds to the same endpoint
// do this in a goroutine to avoid blocking the notification handler
go func() {
time.Sleep(time.Duration(timeS/2) * time.Second)
if poolConn == nil || poolConn.IsClosed() {
return
}
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
// Log error but don't fail the goroutine
internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
}
}()
return nil
}
return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
}
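// A MOVING payload, as parsed above (values are illustrative):
//
// notification := []interface{}{"MOVING", "12345", "30", "10.0.0.2:6379"}
// // seqID 12345, deadline = now + 30s, handoff target 10.0.0.2:6379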
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
// Connection is already marked for handoff, which is acceptable
// This can happen if multiple MOVING notifications are received for the same connection
return nil
}
// Optionally track in hitless manager for monitoring/debugging
if snh.manager != nil {
connID := conn.GetID()
// Track the operation (ignore errors since this is optional)
_ = snh.manager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
} else {
return fmt.Errorf("hitless: manager not initialized")
}
return nil
}
// handleMigrating processes MIGRATING notifications.
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATING notifications indicate that a connection is about to be migrated
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
return ErrInvalidNotification
}
// Get the connection from handler context and type assert to connectionAdapter
if handlerCtx.Conn == nil {
return ErrInvalidNotification
}
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
if !ok {
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleMigrated processes MIGRATED notifications.
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATED notifications indicate that a connection migration has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
return ErrInvalidNotification
}
// Get the connection from handler context and type assert to connectionAdapter
if handlerCtx.Conn == nil {
return ErrInvalidNotification
}
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
if !ok {
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
connAdapter.ClearRelaxedTimeout()
return nil
}
// handleFailingOver processes FAILING_OVER notifications.
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILING_OVER notifications indicate that a connection is about to failover
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
return ErrInvalidNotification
}
// Get the connection from handler context and type assert to connectionAdapter
if handlerCtx.Conn == nil {
return ErrInvalidNotification
}
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
if !ok {
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleFailedOver processes FAILED_OVER notifications.
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILED_OVER notifications indicate that a connection failover has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
return ErrInvalidNotification
}
// Get the connection from handler context and type assert to connectionAdapter
if handlerCtx.Conn == nil {
return ErrInvalidNotification
}
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
if !ok {
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
connAdapter.ClearRelaxedTimeout()
return nil
}

477
hitless/pool_hook.go Normal file
View File

@@ -0,0 +1,477 @@
package hitless
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
)
// HitlessManagerInterface defines the interface for completing handoff operations
type HitlessManagerInterface interface {
TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
UntrackOperationWithConnID(seqID int64, connID uint64)
}
// HandoffRequest represents a request to handoff a connection to a new endpoint
type HandoffRequest struct {
Conn *pool.Conn
ConnID uint64 // Unique connection identifier
Endpoint string
SeqID int64
Pool pool.Pooler // Pool to remove connection from on failure
}
// PoolHook implements pool.PoolHook for Redis-specific connection handling
// with hitless upgrade support.
type PoolHook struct {
// Base dialer for creating connections to new endpoints during handoffs
// args are network and address
baseDialer func(context.Context, string, string) (net.Conn, error)
// Network type (e.g., "tcp", "unix")
network string
// Event-driven handoff support
handoffQueue chan HandoffRequest // Queue for handoff requests
shutdown chan struct{} // Shutdown signal
shutdownOnce sync.Once // Ensure clean shutdown
workerWg sync.WaitGroup // Track worker goroutines
// On-demand worker management
maxWorkers int
activeWorkers int32 // Atomic counter for active workers
workerTimeout time.Duration // How long workers wait for work before exiting
// Simple state tracking
pending sync.Map // map[uint64]int64 (connID -> seqID)
// Configuration for the hitless upgrade
config *Config
// Hitless manager for operation completion tracking
hitlessManager HitlessManagerInterface
// Pool interface for removing connections on handoff failure
pool pool.Pooler
}
// NewPoolHook creates a new pool hook
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface) *PoolHook {
return NewPoolHookWithPoolSize(baseDialer, network, config, hitlessManager, 0)
}
// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface, poolSize int) *PoolHook {
// Apply defaults to any missing configuration fields, using pool size for worker calculations
config = config.ApplyDefaultsWithPoolSize(poolSize)
ph := &PoolHook{
// baseDialer is used to create connections to new endpoints during handoffs
baseDialer: baseDialer,
network: network,
// handoffQueue is a buffered channel for queuing handoff requests
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
// shutdown is a channel for signaling shutdown
shutdown: make(chan struct{}),
maxWorkers: config.MaxWorkers,
activeWorkers: 0, // Start with no workers - create on demand
workerTimeout: 30 * time.Second, // Workers exit after 30s of inactivity
config: config,
// Hitless manager for operation completion tracking
hitlessManager: hitlessManager,
}
// No upfront worker creation - workers are created on demand
return ph
}
// SetPool sets the pool interface for removing connections on handoff failure
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
ph.pool = pooler
}
// GetCurrentWorkers returns the current number of active workers (for testing)
func (ph *PoolHook) GetCurrentWorkers() int {
return int(atomic.LoadInt32(&ph.activeWorkers))
}
// GetScaleLevel returns 1 if workers are active, 0 if none (for testing compatibility)
func (ph *PoolHook) GetScaleLevel() int {
if atomic.LoadInt32(&ph.activeWorkers) > 0 {
return 1
}
return 0
}
// IsHandoffPending returns true if the given connection has a pending handoff
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
_, pending := ph.pending.Load(conn.GetID())
return pending
}
// OnGet is called when a connection is retrieved from the pool
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, isNewConn bool) error {
// NOTE: Two checks ensure we never return a connection that is due for, or
// currently undergoing, a handoff.
// Check that the connection is usable (not in a handoff state).
// This should not happen, since the pool will not return an unusable connection.
if !conn.IsUsable() {
return ErrConnectionMarkedForHandoff
}
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
if conn.ShouldHandoff() {
return ErrConnectionMarkedForHandoff
}
return nil
}
// OnPut is called when a connection is returned to the pool
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
// First check whether the connection needs a handoff, for fast rejection
if conn.ShouldHandoff() {
// Check for a pending handoff so the same connection is not queued twice
_, hasPendingHandoff := ph.pending.Load(conn.GetID())
if !hasPendingHandoff {
// Check for empty endpoint first (synchronous check)
if conn.GetHandoffEndpoint() == "" {
conn.ClearHandoffState()
} else {
if err := ph.queueHandoff(conn); err != nil {
// Failed to queue handoff, remove the connection
internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err)
return false, true, nil // Don't pool, remove connection, no error to caller
}
// Check if handoff was already processed by a worker before we can mark it as queued
if !conn.ShouldHandoff() {
// Handoff was already processed - this is normal and the connection should be pooled
return true, false, nil
}
if err := conn.MarkQueuedForHandoff(); err != nil {
// If marking fails, check if handoff was processed in the meantime
if !conn.ShouldHandoff() {
// Handoff was processed - this is normal, pool the connection
return true, false, nil
}
// Other error - remove the connection
return false, true, nil
}
return true, false, nil
}
}
}
// Default: pool the connection
return true, false, nil
}
// ensureWorkerAvailable ensures at least one worker is available to process requests
// Creates a new worker if needed and under the max limit
func (ph *PoolHook) ensureWorkerAvailable() {
select {
case <-ph.shutdown:
return
default:
// Check if we need a new worker
currentWorkers := atomic.LoadInt32(&ph.activeWorkers)
if currentWorkers < int32(ph.maxWorkers) {
// Try to create a new worker (atomic increment to prevent race)
if atomic.CompareAndSwapInt32(&ph.activeWorkers, currentWorkers, currentWorkers+1) {
ph.workerWg.Add(1)
go ph.onDemandWorker()
}
}
}
}
// onDemandWorker processes handoff requests and exits when idle
func (ph *PoolHook) onDemandWorker() {
defer func() {
// Decrement active worker count when exiting
atomic.AddInt32(&ph.activeWorkers, -1)
ph.workerWg.Done()
}()
for {
select {
case request := <-ph.handoffQueue:
// Check for shutdown before processing
select {
case <-ph.shutdown:
// Clean up the request before exiting
ph.pending.Delete(request.ConnID)
return
default:
// Process the request
ph.processHandoffRequest(request)
}
case <-time.After(ph.workerTimeout):
// Worker has been idle for too long, exit to save resources
if ph.config != nil && ph.config.LogLevel >= 3 { // Debug level
internal.Logger.Printf(context.Background(),
"hitless: worker exiting due to inactivity timeout (%v)", ph.workerTimeout)
}
return
case <-ph.shutdown:
return
}
}
}
// processHandoffRequest processes a single handoff request
func (ph *PoolHook) processHandoffRequest(request HandoffRequest) {
// Remove from pending map
defer ph.pending.Delete(request.Conn.GetID())
// Create a context with handoff timeout from config
handoffTimeout := 30 * time.Second // Default fallback
if ph.config != nil && ph.config.HandoffTimeout > 0 {
handoffTimeout = ph.config.HandoffTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
defer cancel()
// Create a context that also respects the shutdown signal
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
defer shutdownCancel()
// Monitor shutdown signal in a separate goroutine
go func() {
select {
case <-ph.shutdown:
shutdownCancel()
case <-shutdownCtx.Done():
}
}()
// Perform the handoff with cancellable context
err := ph.performConnectionHandoffWithPool(shutdownCtx, request.Conn, request.Pool)
// If handoff failed, restore the handoff state for potential retry
if err != nil {
request.Conn.RestoreHandoffState()
internal.Logger.Printf(context.Background(), "hitless: handoff failed for connection %d, will retry: %v", request.ConnID, err)
}
// No need for scale down scheduling with on-demand workers
// Workers automatically exit when idle
}
// queueHandoff queues a handoff request for processing.
// If an error is returned, the connection will be removed from the pool.
func (ph *PoolHook) queueHandoff(conn *pool.Conn) error {
// Create handoff request
request := HandoffRequest{
Conn: conn,
ConnID: conn.GetID(),
Endpoint: conn.GetHandoffEndpoint(),
SeqID: conn.GetMovingSeqID(),
Pool: ph.pool, // Include pool for connection removal on failure
}
select {
// Give priority to shutdown
case <-ph.shutdown:
return errors.New("shutdown")
default:
select {
case <-ph.shutdown:
return errors.New("shutdown")
case ph.handoffQueue <- request:
// Store in pending map
ph.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
ph.ensureWorkerAvailable()
return nil
default:
// Queue is full - log and attempt scaling
queueLen := len(ph.handoffQueue)
queueCap := cap(ph.handoffQueue)
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(context.Background(),
"hitless: handoff queue is full (%d/%d), attempting timeout queuing and scaling workers",
queueLen, queueCap)
}
}
}
// Ensure we have workers available to handle the load
ph.ensureWorkerAvailable()
return errors.New("queue full")
}
// performConnectionHandoffWithPool performs the actual connection handoff, using the pool
// to remove the connection on failure.
// If an error is returned, the connection will be removed from the pool.
func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn *pool.Conn, pooler pool.Pooler) error {
seqID := conn.GetMovingSeqID()
connID := conn.GetID()
// Notify the hitless manager of completion once the handoff finishes
if ph.hitlessManager != nil {
defer ph.hitlessManager.UntrackOperationWithConnID(seqID, connID)
}
newEndpoint := conn.GetHandoffEndpoint()
if newEndpoint == "" {
// TODO(hitless): Handle by performing the handoff to the current endpoint in N seconds,
// Where N is the time in the moving notification...
// For now, clear the handoff state and return
conn.ClearHandoffState()
return nil
}
retries := conn.IncrementAndGetHandoffRetries(1)
maxRetries := 3 // Default fallback
if ph.config != nil {
maxRetries = ph.config.MaxHandoffRetries
}
if retries > maxRetries {
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx,
"hitless: reached max retries (%d) for handoff of connection %d to %s",
maxRetries, conn.GetID(), conn.GetHandoffEndpoint())
}
err := ErrMaxHandoffRetriesReached
if pooler != nil {
go pooler.Remove(ctx, conn, err)
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx,
"hitless: removed connection %d from pool due to max handoff retries reached",
conn.GetID())
}
} else {
go conn.Close()
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx,
"hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v",
conn.GetID(), err)
}
}
return err
}
// Create endpoint-specific dialer
endpointDialer := ph.createEndpointDialer(newEndpoint)
// Create new connection to the new endpoint
newNetConn, err := endpointDialer(ctx)
if err != nil {
// TODO(hitless): retry
// This is the only case where we should retry the handoff request
// Should we do anything else other than return the error?
return err
}
// Get the old connection
oldConn := conn.GetNetConn()
// Replace the connection and execute initialization
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
if err != nil {
// Remove the connection from the pool since it's in a bad state
if pooler != nil {
// Use pool.Pooler interface directly - no adapter needed
go pooler.Remove(ctx, conn, err)
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx,
"hitless: removed connection %d from pool due to handoff initialization failure: %v",
conn.GetID(), err)
}
} else {
go conn.Close()
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx,
"hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v",
conn.GetID(), err)
}
}
// Keep the handoff state for retry
return err
}
defer func() {
if oldConn != nil {
oldConn.Close()
}
}()
conn.ClearHandoffState()
// Apply relaxed timeout to the new connection for the configured post-handoff duration
// This gives the new connection more time to handle operations during cluster transition
if ph.config != nil && ph.config.PostHandoffRelaxedDuration > 0 {
relaxedTimeout := ph.config.RelaxedTimeout
postHandoffDuration := ph.config.PostHandoffRelaxedDuration
// Set relaxed timeout with deadline - no background goroutine needed
deadline := time.Now().Add(postHandoffDuration)
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
if ph.config.LogLevel >= 2 { // Info level
internal.Logger.Printf(context.Background(),
"hitless: applied post-handoff relaxed timeout (%v) until %v for connection %d",
relaxedTimeout, deadline.Format("15:04:05.000"), connID)
}
}
return nil
}
// createEndpointDialer creates a dialer function that connects to a specific endpoint
func (ph *PoolHook) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
// Parse endpoint to extract host and port
host, port, err := net.SplitHostPort(endpoint)
if err != nil {
// No port specified; fall back to the default Redis port
host = endpoint
port = "6379"
}
// Use the base dialer to connect to the new endpoint
return ph.baseDialer(ctx, ph.network, net.JoinHostPort(host, port))
}
}
// Shutdown gracefully shuts down the processor, waiting for workers to complete
func (ph *PoolHook) Shutdown(ctx context.Context) error {
ph.shutdownOnce.Do(func() {
close(ph.shutdown)
// No timers to clean up with on-demand workers
})
// Wait for workers to complete
done := make(chan struct{})
go func() {
ph.workerWg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
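// Shutdown sketch (illustrative): give in-flight handoffs up to five seconds
// to drain before giving up.
//
// ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// defer cancel()
// if err := ph.Shutdown(ctx); err != nil {
// // workers did not finish in time
// }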
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
// and should not be used until the handoff is complete
var ErrConnectionMarkedForHandoff = errors.New("connection marked for handoff")

959
hitless/pool_hook_test.go Normal file
View File

@@ -0,0 +1,959 @@
package hitless
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/pool"
)
// mockNetConn implements net.Conn for testing
type mockNetConn struct {
addr string
shouldFailInit bool
}
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (m *mockNetConn) Close() error { return nil }
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
type mockAddr struct {
addr string
}
func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string { return m.addr }
// createMockPoolConnection creates a mock pool connection for testing
func createMockPoolConnection() *pool.Conn {
mockNetConn := &mockNetConn{addr: "test:6379"}
conn := pool.NewConn(mockNetConn)
conn.SetUsable(true) // Make connection usable for testing
return conn
}
// mockPool implements pool.Pooler for testing
type mockPool struct {
removedConnections map[uint64]bool
mu sync.Mutex
}
func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) {
return nil, errors.New("not implemented")
}
func (mp *mockPool) CloseConn(conn *pool.Conn) error {
return nil
}
func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) {
return nil, errors.New("not implemented")
}
func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) {
// Not implemented for testing
}
func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) {
mp.mu.Lock()
defer mp.mu.Unlock()
// Use pool.Conn directly - no adapter needed
mp.removedConnections[conn.GetID()] = true
}
// WasRemoved safely checks if a connection was removed from the pool
func (mp *mockPool) WasRemoved(connID uint64) bool {
mp.mu.Lock()
defer mp.mu.Unlock()
return mp.removedConnections[connID]
}
func (mp *mockPool) Len() int {
return 0
}
func (mp *mockPool) IdleLen() int {
return 0
}
func (mp *mockPool) Stats() *pool.Stats {
return &pool.Stats{}
}
func (mp *mockPool) AddPoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
func (mp *mockPool) RemovePoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
func (mp *mockPool) Close() error {
return nil
}
// TestConnectionHook tests the Redis connection processor functionality
func TestConnectionHook(t *testing.T) {
// Create a base dialer for testing
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) {
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 1, // Use only 1 worker to ensure synchronization
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Verify connection is marked for handoff
if !conn.ShouldHandoff() {
t.Fatal("Connection should be marked for handoff")
}
// Set a mock initialization function with synchronization
initConnCalled := make(chan bool, 1)
proceedWithInit := make(chan bool, 1)
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
select {
case initConnCalled <- true:
default:
}
// Wait for test to proceed
<-proceedWithInit
return nil
}
conn.SetInitConnFunc(initConnFunc)
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
// Should pool the connection immediately (handoff queued)
if !shouldPool {
t.Error("Connection should be pooled immediately with event-driven handoff")
}
if shouldRemove {
t.Error("Connection should not be removed when queuing handoff")
}
// Wait for initialization to be called (indicates handoff started)
select {
case <-initConnCalled:
// Good, initialization was called
case <-time.After(1 * time.Second):
t.Fatal("Timeout waiting for initialization function to be called")
}
// Connection should be in pending map while initialization is blocked
if _, pending := processor.pending.Load(conn.GetID()); !pending {
t.Error("Connection should be in pending handoffs map")
}
// Allow initialization to proceed
proceedWithInit <- true
// Wait for handoff to complete with proper timeout and polling
timeout := time.After(2 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify handoff completed (removed from pending map)
if _, pending := processor.pending.Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map after handoff")
}
// Verify connection is usable again
if !conn.IsUsable() {
t.Error("Connection should be usable after successful handoff")
}
// Verify handoff state is cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after completion")
}
})
t.Run("HandoffNotNeeded", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
// Don't mark for handoff
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error when handoff not needed: %v", err)
}
// Should pool the connection normally
if !shouldPool {
t.Error("Connection should be pooled when no handoff needed")
}
if shouldRemove {
t.Error("Connection should not be removed when no handoff needed")
}
})
t.Run("EmptyEndpoint", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error with empty endpoint: %v", err)
}
// Should pool the connection (empty endpoint clears state)
if !shouldPool {
t.Error("Connection should be pooled after clearing empty endpoint")
}
if shouldRemove {
t.Error("Connection should not be removed after clearing empty endpoint")
}
// State should be cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after clearing empty endpoint")
}
})
t.Run("EventDrivenHandoffDialerError", func(t *testing.T) {
// Create a failing base dialer
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("dial failed")
}
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
HandoffTimeout: 1 * time.Second, // Shorter timeout for faster test
LogLevel: 2,
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not return error to caller: %v", err)
}
// Should pool the connection initially (handoff queued)
if !shouldPool {
t.Error("Connection should be pooled initially with event-driven handoff")
}
if shouldRemove {
t.Error("Connection should not be removed when queuing handoff")
}
// Wait for handoff to complete and fail with proper timeout and polling
// Use longer timeout to account for handoff timeout + processing time
timeout := time.After(5 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
// wait for handoff to start
time.Sleep(100 * time.Millisecond)
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for failed handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Connection should be removed from pending map after failed handoff
if _, pending := processor.pending.Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map after failed handoff")
}
// Handoff state should still be set (since handoff failed)
if !conn.ShouldHandoff() {
t.Error("Connection should still be marked for handoff after failed handoff")
}
})
t.Run("BufferedDataRESP2", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
// For this test, we'll just verify the logic works for connections without buffered data
// The actual buffered data detection is handled by the pool's connection health check
// which is outside the scope of the Redis connection processor
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
// Should pool the connection normally (no buffered data in mock)
if !shouldPool {
t.Error("Connection should be pooled when no buffered data")
}
if shouldRemove {
t.Error("Connection should not be removed when no buffered data")
}
})
t.Run("OnGet", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should not error for normal connection: %v", err)
}
})
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10, // Explicit queue size to avoid a zero-size queue
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
// Simulate a pending handoff by marking for handoff and queuing
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
// Clean up
processor.pending.Delete(conn.GetID())
})
t.Run("EventDrivenStateManagement", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
// Test initial state - no pending handoffs
if _, pending := processor.pending.Load(conn.GetID()); pending {
t.Error("New connection should not have pending handoffs")
}
// Test adding to pending map
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
if _, pending := processor.pending.Load(conn.GetID()); !pending {
t.Error("Connection should be in pending map")
}
// Test OnGet with pending handoff
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
}
// Test removing from pending map and clearing handoff state
processor.pending.Delete(conn.GetID())
if _, pending := processor.pending.Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map")
}
// Clear handoff state to simulate completed handoff
conn.ClearHandoffState()
conn.SetUsable(true) // Make connection usable again
// Test OnGet without pending handoff
err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("Should not return error for non-pending connection: %v", err)
}
})
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
// Create processor with small queue to test optimization features
config := &Config{
MaxWorkers: 3,
HandoffQueueSize: 2, // Small queue to trigger optimizations
MaxHandoffRetries: 3,
LogLevel: 3, // Debug level to see optimization logs
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Add small delay to simulate network latency
time.Sleep(10 * time.Millisecond)
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create multiple connections that need handoff to fill the queue
connections := make([]*pool.Conn, 5)
for i := 0; i < 5; i++ {
connections[i] = createMockPoolConnection()
if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil {
t.Fatalf("Failed to mark connection %d for handoff: %v", i, err)
}
// Set a mock initialization function
connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
}
ctx := context.Background()
successCount := 0
// Process connections - should trigger scaling and timeout logic
for _, conn := range connections {
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Logf("OnPut returned error (expected with timeout): %v", err)
}
if shouldPool && !shouldRemove {
successCount++
}
}
// With timeout and scaling, most handoffs should eventually succeed
if successCount == 0 {
t.Error("Should have queued some handoffs with timeout and scaling")
}
t.Logf("Successfully queued %d handoffs with optimization features", successCount)
// Give time for workers to process and scaling to occur
time.Sleep(100 * time.Millisecond)
})
t.Run("WorkerScalingBehavior", func(t *testing.T) {
// Create processor with small queue to test scaling behavior
config := &Config{
MaxWorkers: 15, // Set to >= 10 to test explicit value preservation
HandoffQueueSize: 1, // Very small queue to force scaling
MaxHandoffRetries: 3,
LogLevel: 2, // Info level to see scaling logs
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Verify initial worker count (should be 0 with on-demand workers)
if processor.GetCurrentWorkers() != 0 {
t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers())
}
if processor.GetScaleLevel() != 0 {
t.Errorf("Processor should be at scale level 0 initially, got %d", processor.GetScaleLevel())
}
if processor.maxWorkers != 15 {
t.Errorf("Expected maxWorkers=15, got %d", processor.maxWorkers)
}
// The on-demand worker behavior creates workers only when needed
// This test just verifies the basic configuration is correct
t.Logf("On-demand worker configuration verified - Max: %d, Current: %d",
processor.maxWorkers, processor.GetCurrentWorkers())
})
t.Run("PassiveTimeoutRestoration", func(t *testing.T) {
// Create processor with fast post-handoff duration for testing
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing
RelaxedTimeout: 5 * time.Second,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
ctx := context.Background()
// Create a connection and trigger handoff
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("Handoff should succeed: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after handoff")
}
// Wait for handoff to complete with proper timeout and polling
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify relaxed timeout is set with deadline
if !conn.HasRelaxedTimeout() {
t.Error("Connection should have relaxed timeout after handoff")
}
// Test that timeout is still active before deadline
// We'll use HasRelaxedTimeout which internally checks the deadline
if !conn.HasRelaxedTimeout() {
t.Error("Connection should still have active relaxed timeout before deadline")
}
// Wait for deadline to pass
time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer
// Test that timeout is automatically restored after deadline
// HasRelaxedTimeout should return false after deadline passes
if conn.HasRelaxedTimeout() {
t.Error("Connection should not have active relaxed timeout after deadline")
}
// Additional verification: calling HasRelaxedTimeout again should still return false
// and should have cleared the internal timeout values
if conn.HasRelaxedTimeout() {
t.Error("Connection should not have relaxed timeout after deadline (second check)")
}
t.Logf("Passive timeout restoration test completed successfully")
})
t.Run("UsableFlagBehavior", func(t *testing.T) {
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
ctx := context.Background()
// Create a new connection without setting it usable
mockNetConn := &mockNetConn{addr: "test:6379"}
conn := pool.NewConn(mockNetConn)
// Initially, connection should not be usable (not initialized)
if conn.IsUsable() {
t.Error("New connection should not be usable before initialization")
}
// Simulate initialization by setting usable to true
conn.SetUsable(true)
if !conn.IsUsable() {
t.Error("Connection should be usable after initialization")
}
// OnGet should succeed for usable connection
err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed for usable connection: %v", err)
}
// Mark connection for handoff
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
// Connection should still be usable until queued, but marked for handoff
if !conn.IsUsable() {
t.Error("Connection should still be usable after being marked for handoff (until queued)")
}
if !conn.ShouldHandoff() {
t.Error("Connection should be marked for handoff")
}
// OnGet should fail for connection marked for handoff
err = processor.OnGet(ctx, conn, false)
if err == nil {
t.Error("OnGet should fail for connection marked for handoff")
}
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should succeed: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after handoff")
}
// Wait for handoff to complete
time.Sleep(50 * time.Millisecond)
// After handoff completion, connection should be usable again
if !conn.IsUsable() {
t.Error("Connection should be usable after handoff completion")
}
// OnGet should succeed again
err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed after handoff completion: %v", err)
}
t.Logf("Usable flag behavior test completed successfully")
})
t.Run("StaticQueueBehavior", func(t *testing.T) {
config := &Config{
MaxWorkers: 3,
HandoffQueueSize: 50, // Explicit static queue size
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100
defer processor.Shutdown(context.Background())
// Verify queue capacity matches configured size
queueCapacity := cap(processor.handoffQueue)
if queueCapacity != 50 {
t.Errorf("Expected queue capacity 50, got %d", queueCapacity)
}
// Test that queue size is static regardless of pool size
// (No dynamic resizing should occur)
ctx := context.Background()
// Fill part of the queue
for i := 0; i < 10; i++ {
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil {
t.Fatalf("Failed to mark connection %d for handoff: %v", i, err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("Failed to queue handoff %d: %v", i, err)
}
if !shouldPool || shouldRemove {
t.Errorf("Connection %d should be pooled after handoff (shouldPool=%v, shouldRemove=%v)",
i, shouldPool, shouldRemove)
}
}
// Verify queue capacity remains static (the main purpose of this test)
finalCapacity := cap(processor.handoffQueue)
if finalCapacity != 50 {
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
}
// Note: We don't check queue size here because workers process items quickly
// The important thing is that the capacity remains static regardless of pool size
})
t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) {
// Create a failing dialer that will cause handoff initialization to fail
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Return a connection that will fail during initialization
return &mockNetConn{addr: addr, shouldFailInit: true}, nil
}
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create a mock pool that tracks removals
mockPool := &mockPool{removedConnections: make(map[uint64]bool)}
processor.SetPool(mockPool)
ctx := context.Background()
// Create a connection and mark it for handoff
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a failing initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return fmt.Errorf("initialization failed")
})
// Process the connection - handoff should fail and connection should be removed
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after failed handoff attempt")
}
// Wait for handoff to be attempted and fail
time.Sleep(100 * time.Millisecond)
// Verify that the connection was removed from the pool
if !mockPool.WasRemoved(conn.GetID()) {
t.Errorf("Connection %d should have been removed from pool after handoff failure", conn.GetID())
}
t.Logf("Connection removal on handoff failure test completed successfully")
})
t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) {
// Create config with short post-handoff duration for testing
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
RelaxedTimeout: 5 * time.Second,
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Fatalf("OnPut failed: %v", err)
}
if !shouldPool {
t.Error("Connection should be pooled after successful handoff")
}
if shouldRemove {
t.Error("Connection should not be removed after successful handoff")
}
// Wait for the handoff to complete (it happens asynchronously)
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn); !pending {
handoffCompleted = true
}
}
}
// Verify that relaxed timeout was applied to the new connection
if !conn.HasRelaxedTimeout() {
t.Error("New connection should have relaxed timeout applied after handoff")
}
// Wait for the post-handoff duration to expire
time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration
// Verify that relaxed timeout was automatically cleared
if conn.HasRelaxedTimeout() {
t.Error("Relaxed timeout should be automatically cleared after post-handoff duration")
}
})
t.Run("MarkForHandoff returns error when already marked", func(t *testing.T) {
conn := createMockPoolConnection()
// First mark should succeed
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("First MarkForHandoff should succeed: %v", err)
}
// Second mark should fail
if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil {
t.Fatal("Second MarkForHandoff should return error")
} else if err.Error() != "connection is already marked for handoff" {
t.Fatalf("Expected specific error message, got: %v", err)
}
// Verify original handoff data is preserved
if !conn.ShouldHandoff() {
t.Fatal("Connection should still be marked for handoff")
}
if conn.GetHandoffEndpoint() != "new-endpoint:6379" {
t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint())
}
if conn.GetMovingSeqID() != 1 {
t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID())
}
})
t.Run("HandoffTimeoutConfiguration", func(t *testing.T) {
// Test that HandoffTimeout from config is actually used
customTimeout := 2 * time.Second
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
HandoffTimeout: customTimeout, // Custom timeout
MaxHandoffRetries: 1, // Single retry to speed up test
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create a connection that will test the timeout
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("test-endpoint:6379", 123); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set an init function that will check the context timeout
var timeoutVerified int32 // Use atomic for thread safety
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
// Check that the context has the expected timeout
deadline, ok := ctx.Deadline()
if !ok {
t.Error("Context should have a deadline")
return errors.New("no deadline")
}
// The deadline should be approximately customTimeout from now
expectedDeadline := time.Now().Add(customTimeout)
timeDiff := deadline.Sub(expectedDeadline)
if timeDiff < -500*time.Millisecond || timeDiff > 500*time.Millisecond {
t.Errorf("Context deadline not as expected. Expected around %v, got %v (diff: %v)",
expectedDeadline, deadline, timeDiff)
} else {
atomic.StoreInt32(&timeoutVerified, 1)
}
return nil // Successful handoff
})
// Trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(context.Background(), conn)
if err != nil {
t.Errorf("OnPut should not return error: %v", err)
}
// Connection should be queued for handoff
if !shouldPool || shouldRemove {
t.Errorf("Connection should be pooled for handoff processing")
}
// Wait for handoff to complete
time.Sleep(500 * time.Millisecond)
if atomic.LoadInt32(&timeoutVerified) == 0 {
t.Error("HandoffTimeout was not properly applied to context")
}
t.Logf("HandoffTimeout configuration test completed successfully")
})
}
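Taken together, these tests walk a connection through the full handoff lifecycle. As a rough sketch of that flow, assuming the pool.Conn API introduced in this commit (the endpoint, sequence ID, and dialer below are placeholders, and error handling is trimmed):

// handoffSketch illustrates the handoff lifecycle exercised by the tests above.
func handoffSketch(ctx context.Context, cn *pool.Conn, dial func(context.Context) (net.Conn, error)) error {
	// 1. A MOVING push notification marks the connection for handoff.
	if err := cn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
		return err // already marked
	}
	// 2. OnPut queues the handoff and makes the connection unusable.
	if err := cn.MarkQueuedForHandoff(); err != nil {
		return err
	}
	// 3. A worker dials the new endpoint and re-initializes the connection.
	netConn, err := dial(ctx)
	if err != nil {
		cn.RestoreHandoffState() // keep it marked so the handoff can be retried
		return err
	}
	if err := cn.SetNetConnAndInitConn(ctx, netConn); err != nil {
		return err
	}
	// 4. On success the handoff state is cleared and the connection is usable again.
	cn.ClearHandoffState()
	return nil
}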

hitless/state.go

@@ -0,0 +1,24 @@
package hitless
// State represents the current state of a hitless upgrade operation.
type State int
const (
// StateIdle indicates no upgrade is in progress
StateIdle State = iota
// StateMoving indicates a connection handoff is in progress
StateMoving
)
// String returns a string representation of the state.
func (s State) String() string {
switch s {
case StateIdle:
return "idle"
case StateMoving:
return "moving"
default:
return "unknown"
}
}
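A hitless upgrade manager flips between these states around MOVING notifications. A minimal sketch of consuming the type, assuming a hypothetical tracker that is not part of this commit (imports sync/atomic):

// stateTracker is an illustrative example of storing and reading a State.
type stateTracker struct {
	state atomic.Int32 // stores a State value
}

func (t *stateTracker) startMoving()  { t.state.Store(int32(StateMoving)) }
func (t *stateTracker) finishMoving() { t.state.Store(int32(StateIdle)) }

func (t *stateTracker) current() State {
	return State(t.state.Load())
}

// t.current().String() then yields "idle" or "moving" for logging.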


@@ -0,0 +1,67 @@
// Package interfaces provides shared interfaces used by both the main redis package
// and the hitless upgrade package to avoid circular dependencies.
package interfaces
import (
"context"
"net"
"time"
)
// NotificationProcessor is re-declared here (instead of importing the push package) to avoid circular imports
type NotificationProcessor interface {
RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
UnregisterHandler(pushNotificationName string) error
GetHandler(pushNotificationName string) interface{}
}
// ClientInterface defines the interface that clients must implement for hitless upgrades.
type ClientInterface interface {
// GetOptions returns the client options.
GetOptions() OptionsInterface
// GetPushProcessor returns the client's push notification processor.
GetPushProcessor() NotificationProcessor
}
// OptionsInterface defines the interface for client options.
type OptionsInterface interface {
// GetReadTimeout returns the read timeout.
GetReadTimeout() time.Duration
// GetWriteTimeout returns the write timeout.
GetWriteTimeout() time.Duration
// GetNetwork returns the network type.
GetNetwork() string
// GetAddr returns the connection address.
GetAddr() string
// IsTLSEnabled returns true if TLS is enabled.
IsTLSEnabled() bool
// GetProtocol returns the protocol version.
GetProtocol() int
// GetPoolSize returns the connection pool size.
GetPoolSize() int
// NewDialer returns a new dialer function for the connection.
NewDialer() func(context.Context) (net.Conn, error)
}
// ConnectionWithRelaxedTimeout defines the interface for connections that support relaxed timeout adjustment.
// This is used by the hitless upgrade system for per-connection timeout management.
type ConnectionWithRelaxedTimeout interface {
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
// These timeouts remain active until explicitly cleared.
SetRelaxedTimeout(readTimeout, writeTimeout time.Duration)
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time)
// ClearRelaxedTimeout clears relaxed timeouts for this connection.
ClearRelaxedTimeout()
}
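Any type with these three methods satisfies the contract. As an illustrative sketch (not part of this commit), a hypothetical in-memory mock for tests:

// relaxedMockConn is a hypothetical implementation of
// ConnectionWithRelaxedTimeout, useful for unit tests.
type relaxedMockConn struct {
	readTimeout  time.Duration
	writeTimeout time.Duration
	deadline     time.Time
}

func (m *relaxedMockConn) SetRelaxedTimeout(r, w time.Duration) {
	m.readTimeout, m.writeTimeout = r, w
	m.deadline = time.Time{} // no expiration
}

func (m *relaxedMockConn) SetRelaxedTimeoutWithDeadline(r, w time.Duration, d time.Time) {
	m.readTimeout, m.writeTimeout, m.deadline = r, w, d
}

func (m *relaxedMockConn) ClearRelaxedTimeout() {
	m.readTimeout, m.writeTimeout, m.deadline = 0, 0, time.Time{}
}

// Compile-time check that the mock satisfies the interface.
var _ ConnectionWithRelaxedTimeout = (*relaxedMockConn)(nil)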


@@ -2,6 +2,7 @@ package pool_test
import (
"context"
"errors"
"fmt"
"testing"
"time"
@@ -31,7 +32,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
@@ -75,7 +76,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
@@ -89,7 +90,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
if err != nil {
b.Fatal(err)
}
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("Bench test remove"))
}
})
})


@@ -26,7 +26,7 @@ var _ = Describe("Buffer Size Configuration", func() {
It("should use default buffer sizes when not specified", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
})
@@ -48,7 +48,7 @@ var _ = Describe("Buffer Size Configuration", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: customReadSize,
WriteBufferSize: customWriteSize,
@@ -69,7 +69,7 @@ var _ = Describe("Buffer Size Configuration", func() {
It("should handle zero buffer sizes by using defaults", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: 0, // Should use default
WriteBufferSize: 0, // Should use default
@@ -105,7 +105,7 @@ var _ = Describe("Buffer Size Configuration", func() {
// without setting ReadBufferSize and WriteBufferSize
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
// ReadBufferSize and WriteBufferSize are not set (will be 0)
})


@@ -3,7 +3,10 @@ package pool
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
@@ -12,17 +15,64 @@ import (
var noDeadline = time.Time{}
// Global atomic counter for connection IDs
var connIDCounter uint64
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
type atomicNetConn struct {
conn net.Conn
}
// generateConnID generates a fast unique identifier for a connection with zero allocations
func generateConnID() uint64 {
return atomic.AddUint64(&connIDCounter, 1)
}
type Conn struct {
usedAt int64 // atomic
netConn net.Conn
// Lock-free netConn access using atomic.Value
// Contains *atomicNetConn wrapper, accessed atomically for better performance
netConnAtomic atomic.Value // stores *atomicNetConn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
Inited bool
// Lightweight mutex to protect reader operations during handoff
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
readerMu sync.RWMutex
Inited atomic.Bool
pooled bool
closed atomic.Bool
createdAt time.Time
expiresAt time.Time
// Hitless upgrade support: relaxed timeouts during migrations/failovers
// Using atomic operations for lock-free access to avoid mutex contention
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
// Counter tracking nested relaxed-timeout setters. It is decremented when
// ClearRelaxedTimeout is called or the deadline is reached; when it reaches 0,
// the relaxed timeouts are cleared.
relaxedCounter atomic.Int32
// Connection initialization function for reconnections
initConnFunc func(context.Context, *Conn) error
// Connection identifier for unique tracking across handoffs
id uint64 // Unique numeric identifier for this connection
// Handoff state - using atomic operations for lock-free access
usableAtomic atomic.Bool // Connection usability state
shouldHandoffAtomic atomic.Bool // Whether connection should be handed off
movingSeqIDAtomic atomic.Int64 // Sequence ID from MOVING notification
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
// newEndpointAtomic needs special handling as it's a string
newEndpointAtomic atomic.Value // stores string
onClose func() error
}
@@ -33,8 +83,8 @@ func NewConn(netConn net.Conn) *Conn {
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
cn := &Conn{
netConn: netConn,
createdAt: time.Now(),
id: generateConnID(), // Generate unique ID for this connection
}
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
@@ -50,6 +100,16 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
}
// Store netConn atomically for lock-free access using wrapper
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
// Initialize atomic handoff state
cn.usableAtomic.Store(false) // false initially, set to true after initialization
cn.shouldHandoffAtomic.Store(false) // false initially
cn.movingSeqIDAtomic.Store(0) // 0 initially
cn.handoffRetriesAtomic.Store(0) // 0 initially
cn.newEndpointAtomic.Store("") // empty string initially
cn.wr = proto.NewWriter(cn.bw)
cn.SetUsedAt(time.Now())
return cn
@@ -64,23 +124,368 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix())
}
// getNetConn returns the current network connection using atomic load (lock-free).
// This is the fast path for accessing netConn without mutex overhead.
func (cn *Conn) getNetConn() net.Conn {
if v := cn.netConnAtomic.Load(); v != nil {
if wrapper, ok := v.(*atomicNetConn); ok {
return wrapper.conn
}
}
return nil
}
// setNetConn stores the network connection atomically (lock-free).
// This is used for the fast path of connection replacement.
func (cn *Conn) setNetConn(netConn net.Conn) {
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
}
// Lock-free helper methods for handoff state management
// isUsable returns true if the connection is safe to use (lock-free).
func (cn *Conn) isUsable() bool {
return cn.usableAtomic.Load()
}
// setUsable sets the usable flag atomically (lock-free).
func (cn *Conn) setUsable(usable bool) {
cn.usableAtomic.Store(usable)
}
// shouldHandoff returns true if connection needs handoff (lock-free).
func (cn *Conn) shouldHandoff() bool {
return cn.shouldHandoffAtomic.Load()
}
// setShouldHandoff sets the handoff flag atomically (lock-free).
func (cn *Conn) setShouldHandoff(should bool) {
cn.shouldHandoffAtomic.Store(should)
}
// getMovingSeqID returns the sequence ID atomically (lock-free).
func (cn *Conn) getMovingSeqID() int64 {
return cn.movingSeqIDAtomic.Load()
}
// setMovingSeqID sets the sequence ID atomically (lock-free).
func (cn *Conn) setMovingSeqID(seqID int64) {
cn.movingSeqIDAtomic.Store(seqID)
}
// getNewEndpoint returns the new endpoint atomically (lock-free).
func (cn *Conn) getNewEndpoint() string {
if endpoint := cn.newEndpointAtomic.Load(); endpoint != nil {
return endpoint.(string)
}
return ""
}
// setNewEndpoint sets the new endpoint atomically (lock-free).
func (cn *Conn) setNewEndpoint(endpoint string) {
cn.newEndpointAtomic.Store(endpoint)
}
// setHandoffRetries sets the retry count atomically (lock-free).
func (cn *Conn) setHandoffRetries(retries int) {
cn.handoffRetriesAtomic.Store(uint32(retries))
}
// incrementHandoffRetries atomically increments and returns the new retry count (lock-free).
func (cn *Conn) incrementHandoffRetries(delta int) int {
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
}
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
func (cn *Conn) IsUsable() bool {
return cn.isUsable()
}
func (cn *Conn) IsInited() bool {
return cn.Inited.Load()
}
// SetUsable sets the usable flag for the connection (lock-free).
func (cn *Conn) SetUsable(usable bool) {
cn.setUsable(usable)
}
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
// These timeouts will be used for all subsequent commands until the deadline expires.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
cn.relaxedCounter.Add(1)
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
}
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
cn.relaxedCounter.Add(1)
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
}
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
// Uses atomic operations for lock-free access.
func (cn *Conn) ClearRelaxedTimeout() {
// Atomically decrement counter and check if we should clear
newCount := cn.relaxedCounter.Add(-1)
if newCount <= 0 {
// Use compare-and-swap to ensure only one goroutine clears
if cn.relaxedCounter.CompareAndSwap(newCount, 0) {
cn.clearRelaxedTimeout()
}
}
}
func (cn *Conn) clearRelaxedTimeout() {
cn.relaxedReadTimeoutNs.Store(0)
cn.relaxedWriteTimeoutNs.Store(0)
cn.relaxedDeadlineNs.Store(0)
cn.relaxedCounter.Store(0)
}
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
// This checks both the timeout values and the deadline (if set).
// Uses atomic operations for lock-free access.
func (cn *Conn) HasRelaxedTimeout() bool {
// Fast path: no relaxed timeouts are set
if cn.relaxedCounter.Load() <= 0 {
return false
}
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// If no relaxed timeouts are set, return false
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
return false
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, relaxed timeouts are active
if deadlineNs == 0 {
return true
}
// If deadline is set, check if it's still in the future
return time.Now().UnixNano() < deadlineNs
}
// getEffectiveReadTimeout returns the timeout to use for read operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
// Fast path: no relaxed timeout set
if readTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(readTimeoutNs)
}
nowNs := time.Now().UnixNano()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(readTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
cn.relaxedCounter.Add(-1)
if cn.relaxedCounter.Load() <= 0 {
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
// getEffectiveWriteTimeout returns the timeout to use for write operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// Fast path: no relaxed timeout set
if writeTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(writeTimeoutNs)
}
nowNs := time.Now().UnixNano()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(writeTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
cn.relaxedCounter.Add(-1)
if cn.relaxedCounter.Load() <= 0 {
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
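// Illustrative sketch (not part of this file): how the effective timeouts
// behave around a deadline, assuming a *Conn cn with a 3s normal timeout.
//
//	cn.SetRelaxedTimeoutWithDeadline(10*time.Second, 10*time.Second,
//		time.Now().Add(200*time.Millisecond))
//	cn.getEffectiveReadTimeout(3*time.Second) // 10s while the deadline holds
//	time.Sleep(250 * time.Millisecond)
//	cn.getEffectiveReadTimeout(3*time.Second) // 3s; relaxed state auto-cleared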
func (cn *Conn) SetOnClose(fn func() error) {
cn.onClose = fn
}
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
cn.initConnFunc = fn
}
// ExecuteInitConn runs the stored connection initialization function if available.
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
if cn.initConnFunc != nil {
return cn.initConnFunc(ctx, cn)
}
return fmt.Errorf("redis: no initConnFunc set for connection %d", cn.GetID())
}
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
// Store the new connection atomically first (lock-free)
cn.setNetConn(netConn)
// Clear relaxed timeouts when connection is replaced
cn.clearRelaxedTimeout()
// Protect reader reset operations to avoid data races
// Use write lock since we're modifying the reader state
cn.readerMu.Lock()
cn.rd.Reset(netConn)
cn.readerMu.Unlock()
cn.bw.Reset(netConn)
}
// GetNetConn safely returns the current network connection using atomic load (lock-free).
// This method is used by the pool for health checks and provides better performance.
func (cn *Conn) GetNetConn() net.Conn {
return cn.getNetConn()
}
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
// New connection is not initialized yet
cn.Inited.Store(false)
// Replace the underlying connection
cn.SetNetConn(netConn)
return cn.ExecuteInitConn(ctx)
}
// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
// Returns an error if the connection is already marked for handoff.
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
// Use single atomic CAS operation for state transition
if !cn.shouldHandoffAtomic.CompareAndSwap(false, true) {
return errors.New("connection is already marked for handoff")
}
cn.setNewEndpoint(newEndpoint)
cn.setMovingSeqID(seqID)
return nil
}
func (cn *Conn) MarkQueuedForHandoff() error {
// Use single atomic CAS operation for state transition
if !cn.shouldHandoffAtomic.CompareAndSwap(true, false) {
return errors.New("connection was not marked for handoff")
}
cn.setUsable(false)
return nil
}
// RestoreHandoffState restores the handoff state after a failed handoff (lock-free).
func (cn *Conn) RestoreHandoffState() {
// Restore shouldHandoff flag for retry
cn.shouldHandoffAtomic.Store(true)
// Keep usable=false to prevent the connection from being used until handoff succeeds
cn.setUsable(false)
}
// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
func (cn *Conn) ShouldHandoff() bool {
return cn.shouldHandoff()
}
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
func (cn *Conn) GetHandoffEndpoint() string {
return cn.getNewEndpoint()
}
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
func (cn *Conn) GetMovingSeqID() int64 {
return cn.getMovingSeqID()
}
// GetID returns the unique identifier for this connection.
func (cn *Conn) GetID() uint64 {
return cn.id
}
// ClearHandoffState clears the handoff state after successful handoff (lock-free).
func (cn *Conn) ClearHandoffState() {
// clear handoff state
cn.setShouldHandoff(false)
cn.setNewEndpoint("")
cn.setMovingSeqID(0)
cn.setHandoffRetries(0)
cn.setUsable(true) // Connection is safe to use again after handoff completes
}
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
return cn.incrementHandoffRetries(n)
}
// HasBufferedData safely checks if the connection has buffered data.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) HasBufferedData() bool {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
return cn.rd.Buffered() > 0
}
// PeekReplyTypeSafe safely peeks at the reply type.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
if cn.rd.Buffered() <= 0 {
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
}
return cn.rd.PeekReplyType()
}
func (cn *Conn) Write(b []byte) (int, error) {
return cn.netConn.Write(b)
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Write(b)
}
return 0, net.ErrClosed
}
func (cn *Conn) RemoteAddr() net.Addr {
if cn.netConn != nil {
return cn.netConn.RemoteAddr()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.RemoteAddr()
}
return nil
}
@@ -89,7 +494,16 @@ func (cn *Conn) WithReader(
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
// Get the connection directly from atomic storage
netConn := cn.getNetConn()
if netConn == nil {
return fmt.Errorf("redis: connection not available")
}
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
}
@@ -100,13 +514,26 @@ func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
// Always set write deadline, even if getNetConn() returns nil
// This prevents write operations from hanging indefinitely
if netConn := cn.getNetConn(); netConn != nil {
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
} else {
// If getNetConn() returns nil, we still need to respect the timeout
// Return an error to prevent indefinite blocking
return fmt.Errorf("redis: connection not available for write operation")
}
}
if cn.bw.Buffered() > 0 {
cn.bw.Reset(cn.netConn)
if netConn := cn.getNetConn(); netConn != nil {
cn.bw.Reset(netConn)
}
}
if err := fn(cn.wr); err != nil {
@@ -116,19 +543,33 @@ func (cn *Conn) WithWriter(
return cn.bw.Flush()
}
func (cn *Conn) IsClosed() bool {
return cn.closed.Load()
}
func (cn *Conn) Close() error {
cn.closed.Store(true)
if cn.onClose != nil {
// ignore error
_ = cn.onClose()
}
return cn.netConn.Close()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Close()
}
return nil
}
// MaybeHasData tries to peek at the next byte in the socket without consuming it
// This is used to check if there are push notifications available
// Important: This will work on Linux, but not on Windows
func (cn *Conn) MaybeHasData() bool {
return maybeHasData(cn.netConn)
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return maybeHasData(netConn)
}
return false
}
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {


@@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) {
}
func (cn *Conn) NetConn() net.Conn {
return cn.netConn
return cn.getNetConn()
}
func (p *ConnPool) CheckMinIdleConns() {

internal/pool/hooks.go

@@ -0,0 +1,114 @@
package pool
import (
"context"
"sync"
)
// PoolHook defines the interface for connection lifecycle hooks.
type PoolHook interface {
// OnGet is called when a connection is retrieved from the pool.
// It can modify the connection or return an error to prevent its use.
// The isNewConn flag indicates whether the connection was newly created (rather than taken idle from the pool).
// The flag can be used for gathering metrics on pool hit/miss ratio.
OnGet(ctx context.Context, conn *Conn, isNewConn bool) error
// OnPut is called when a connection is returned to the pool.
// It returns whether the connection should be pooled and whether it should be removed.
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
}
// PoolHookManager manages multiple pool hooks.
type PoolHookManager struct {
hooks []PoolHook
hooksMu sync.RWMutex
}
// NewPoolHookManager creates a new pool hook manager.
func NewPoolHookManager() *PoolHookManager {
return &PoolHookManager{
hooks: make([]PoolHook, 0),
}
}
// AddHook adds a pool hook to the manager.
// Hooks are called in the order they were added.
func (phm *PoolHookManager) AddHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
phm.hooks = append(phm.hooks, hook)
}
// RemoveHook removes a pool hook from the manager.
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
for i, h := range phm.hooks {
if h == hook {
// Remove hook by swapping with last element and truncating
phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
phm.hooks = phm.hooks[:len(phm.hooks)-1]
break
}
}
}
// ProcessOnGet calls all OnGet hooks in order.
// If any hook returns an error, processing stops and the error is returned.
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
for _, hook := range phm.hooks {
if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
return err
}
}
return nil
}
// ProcessOnPut calls all OnPut hooks in order.
// The first hook that returns shouldRemove=true or shouldPool=false will stop processing.
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
shouldPool = true // Default to pooling the connection
for _, hook := range phm.hooks {
hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)
if hookErr != nil {
return false, true, hookErr
}
// If any hook says to remove or not pool, respect that decision
if hookShouldRemove {
return false, true, nil
}
if !hookShouldPool {
shouldPool = false
}
}
return shouldPool, false, nil
}
// GetHookCount returns the number of registered hooks (for testing).
func (phm *PoolHookManager) GetHookCount() int {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
return len(phm.hooks)
}
// GetHooks returns a copy of all registered hooks.
func (phm *PoolHookManager) GetHooks() []PoolHook {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
hooks := make([]PoolHook, len(phm.hooks))
copy(hooks, phm.hooks)
return hooks
}
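A hook implementation only needs OnGet and OnPut. For example, a hypothetical metrics hook (illustrative, not part of this commit) that could be registered with p.AddPoolHook(&metricsHook{}); it assumes sync/atomic is imported:

// metricsHook is an illustrative PoolHook that counts pool activity.
type metricsHook struct {
	gets atomic.Int64
	puts atomic.Int64
}

func (h *metricsHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
	h.gets.Add(1)
	return nil // never veto a Get
}

func (h *metricsHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
	h.puts.Add(1)
	return true, false, nil // keep pooling the connection
}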

internal/pool/hooks_test.go

@@ -0,0 +1,213 @@
package pool
import (
"context"
"errors"
"net"
"testing"
"time"
)
// TestHook for testing hook functionality
type TestHook struct {
OnGetCalled int
OnPutCalled int
GetError error
PutError error
ShouldPool bool
ShouldRemove bool
}
func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
th.OnGetCalled++
return th.GetError
}
func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
th.OnPutCalled++
return th.ShouldPool, th.ShouldRemove, th.PutError
}
func TestPoolHookManager(t *testing.T) {
manager := NewPoolHookManager()
// Test initial state
if manager.GetHookCount() != 0 {
t.Errorf("Expected 0 hooks initially, got %d", manager.GetHookCount())
}
// Add hooks
hook1 := &TestHook{ShouldPool: true}
hook2 := &TestHook{ShouldPool: true}
manager.AddHook(hook1)
manager.AddHook(hook2)
if manager.GetHookCount() != 2 {
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
}
// Test ProcessOnGet
ctx := context.Background()
conn := &Conn{} // Mock connection
err := manager.ProcessOnGet(ctx, conn, false)
if err != nil {
t.Errorf("ProcessOnGet should not error: %v", err)
}
if hook1.OnGetCalled != 1 {
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
}
if hook2.OnGetCalled != 1 {
t.Errorf("Expected hook2.OnGetCalled to be 1, got %d", hook2.OnGetCalled)
}
// Test ProcessOnPut
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
if err != nil {
t.Errorf("ProcessOnPut should not error: %v", err)
}
if !shouldPool {
t.Error("Expected shouldPool to be true")
}
if shouldRemove {
t.Error("Expected shouldRemove to be false")
}
if hook1.OnPutCalled != 1 {
t.Errorf("Expected hook1.OnPutCalled to be 1, got %d", hook1.OnPutCalled)
}
if hook2.OnPutCalled != 1 {
t.Errorf("Expected hook2.OnPutCalled to be 1, got %d", hook2.OnPutCalled)
}
// Remove a hook
manager.RemoveHook(hook1)
if manager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
}
}
func TestHookErrorHandling(t *testing.T) {
manager := NewPoolHookManager()
// Hook that returns error on Get
errorHook := &TestHook{
GetError: errors.New("test error"),
ShouldPool: true,
}
normalHook := &TestHook{ShouldPool: true}
manager.AddHook(errorHook)
manager.AddHook(normalHook)
ctx := context.Background()
conn := &Conn{}
// Test that error stops processing
err := manager.ProcessOnGet(ctx, conn, false)
if err == nil {
t.Error("Expected error from ProcessOnGet")
}
if errorHook.OnGetCalled != 1 {
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
}
// normalHook should not be called due to error
if normalHook.OnGetCalled != 0 {
t.Errorf("Expected normalHook.OnGetCalled to be 0, got %d", normalHook.OnGetCalled)
}
}
func TestHookShouldRemove(t *testing.T) {
manager := NewPoolHookManager()
// Hook that says to remove connection
removeHook := &TestHook{
ShouldPool: false,
ShouldRemove: true,
}
normalHook := &TestHook{ShouldPool: true}
manager.AddHook(removeHook)
manager.AddHook(normalHook)
ctx := context.Background()
conn := &Conn{}
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
if err != nil {
t.Errorf("ProcessOnPut should not error: %v", err)
}
if shouldPool {
t.Error("Expected shouldPool to be false")
}
if !shouldRemove {
t.Error("Expected shouldRemove to be true")
}
if removeHook.OnPutCalled != 1 {
t.Errorf("Expected removeHook.OnPutCalled to be 1, got %d", removeHook.OnPutCalled)
}
// normalHook should not be called due to early return
if normalHook.OnPutCalled != 0 {
t.Errorf("Expected normalHook.OnPutCalled to be 0, got %d", normalHook.OnPutCalled)
}
}
func TestPoolWithHooks(t *testing.T) {
// Create a pool with hooks
hookManager := NewPoolHookManager()
testHook := &TestHook{ShouldPool: true}
hookManager.AddHook(testHook)
opt := &Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil // Mock connection
},
PoolSize: 1,
DialTimeout: time.Second,
}
pool := NewConnPool(opt)
defer pool.Close()
// Add hook to pool after creation
pool.AddPoolHook(testHook)
// Verify hooks are initialized
if pool.hookManager == nil {
t.Error("Expected hookManager to be initialized")
}
if pool.hookManager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount())
}
// Test adding hook to pool
additionalHook := &TestHook{ShouldPool: true}
pool.AddPoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 2 {
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount())
}
// Test removing hook from pool
pool.RemovePoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount())
}
}


@@ -3,6 +3,7 @@ package pool
import (
"context"
"errors"
"log"
"net"
"sync"
"sync/atomic"
@@ -22,6 +23,12 @@ var (
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
ErrPoolTimeout = errors.New("redis: connection pool timeout")
popAttempts = 10
getAttempts = 3
minTime = time.Unix(-2208988800, 0) // Jan 1, 1900
maxTime = minTime.Add(1<<63 - 1)
noExpiration = maxTime
)
var timers = sync.Pool{
@@ -38,11 +45,14 @@ type Stats struct {
Misses uint32 // number of times free connection was NOT found in the pool
Timeouts uint32 // number of times a wait timeout occurred
WaitCount uint32 // number of times a connection was waited
Unusable uint32 // number of times a connection was found to be unusable
WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds
TotalConns uint32 // number of total connections in the pool
IdleConns uint32 // number of idle connections in the pool
StaleConns uint32 // number of stale connections removed from the pool
PubSubStats PubSubStats
}
type Pooler interface {
@@ -57,29 +67,27 @@ type Pooler interface {
IdleLen() int
Stats() *Stats
AddPoolHook(hook PoolHook)
RemovePoolHook(hook PoolHook)
Close() error
}
type Options struct {
Dialer func(context.Context) (net.Conn, error)
PoolFIFO bool
PoolSize int
DialTimeout time.Duration
PoolTimeout time.Duration
MinIdleConns int
MaxIdleConns int
MaxActiveConns int
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
// Protocol version for optimization (3 = RESP3 with push notifications, 2 = RESP2 without)
Protocol int
ReadBufferSize int
WriteBufferSize int
PoolFIFO bool
PoolSize int32
DialTimeout time.Duration
PoolTimeout time.Duration
MinIdleConns int32
MaxIdleConns int32
MaxActiveConns int32
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
PushNotificationsEnabled bool
}
type lastDialErrorWrap struct {
@@ -95,16 +103,21 @@ type ConnPool struct {
queue chan struct{}
connsMu sync.Mutex
conns []*Conn
conns map[uint64]*Conn
idleConns []*Conn
poolSize int
idleConnsLen int
poolSize atomic.Int32
idleConnsLen atomic.Int32
idleCheckInProgress atomic.Bool
stats Stats
waitDurationNs atomic.Int64
_closed uint32 // atomic
// Pool hooks manager for flexible connection processing
hookManagerMu sync.RWMutex
hookManager *PoolHookManager
}
var _ Pooler = (*ConnPool)(nil)
@@ -114,34 +127,69 @@ func NewConnPool(opt *Options) *ConnPool {
cfg: opt,
queue: make(chan struct{}, opt.PoolSize),
conns: make([]*Conn, 0, opt.PoolSize),
conns: make(map[uint64]*Conn),
idleConns: make([]*Conn, 0, opt.PoolSize),
}
// Only create MinIdleConns if explicitly requested (> 0)
// This avoids creating connections during pool initialization for tests
if opt.MinIdleConns > 0 {
p.connsMu.Lock()
p.checkMinIdleConns()
p.connsMu.Unlock()
}
return p
}
// initializeHooks sets up the pool hooks system.
func (p *ConnPool) initializeHooks() {
p.hookManager = NewPoolHookManager()
}
// AddPoolHook adds a pool hook to the pool.
func (p *ConnPool) AddPoolHook(hook PoolHook) {
p.hookManagerMu.Lock()
defer p.hookManagerMu.Unlock()
if p.hookManager == nil {
p.initializeHooks()
}
p.hookManager.AddHook(hook)
}
// RemovePoolHook removes a pool hook from the pool.
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
p.hookManagerMu.Lock()
defer p.hookManagerMu.Unlock()
if p.hookManager != nil {
p.hookManager.RemoveHook(hook)
}
}
func (p *ConnPool) checkMinIdleConns() {
if !p.idleCheckInProgress.CompareAndSwap(false, true) {
return
}
defer p.idleCheckInProgress.Store(false)
if p.cfg.MinIdleConns == 0 {
return
}
for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
// Only create idle connections if we haven't reached the total pool size limit
// MinIdleConns should be a subset of PoolSize, not additional connections
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
select {
case p.queue <- struct{}{}:
p.poolSize++
p.idleConnsLen++
p.poolSize.Add(1)
p.idleConnsLen.Add(1)
go func() {
defer func() {
if err := recover(); err != nil {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
p.freeTurn()
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
@@ -150,12 +198,9 @@ func (p *ConnPool) checkMinIdleConns() {
err := p.addIdleConn()
if err != nil && err != ErrClosed {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
}
p.freeTurn()
}()
default:
@@ -172,6 +217,9 @@ func (p *ConnPool) addIdleConn() error {
if err != nil {
return err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
p.connsMu.Lock()
defer p.connsMu.Unlock()
@@ -182,11 +230,15 @@ func (p *ConnPool) addIdleConn() error {
return ErrClosed
}
p.conns = append(p.conns, cn)
p.conns[cn.GetID()] = cn
p.idleConns = append(p.idleConns, cn)
return nil
}
// NewConn creates a new connection and returns it to the user.
// It still obeys MaxActiveConns, but the connection is not pooled and does not increase the pool size.
//
// NOTE: Connections obtained this way are not pooled and do not support hitless upgrades.
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.newConn(ctx, false)
}
@@ -196,33 +248,42 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, ErrClosed
}
p.connsMu.Lock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
p.connsMu.Unlock()
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
return nil, ErrPoolExhausted
}
p.connsMu.Unlock()
cn, err := p.dialConn(ctx, pooled)
if err != nil {
return nil, err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
_ = cn.Close()
return nil, ErrPoolExhausted
}
p.conns = append(p.conns, cn)
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.closed() {
_ = cn.Close()
return nil, ErrClosed
}
// Close() may have set conns to nil while we were waiting for the lock; recreate it before registering the connection
if p.conns == nil {
p.conns = make(map[uint64]*Conn)
}
p.conns[cn.GetID()] = cn
if pooled {
// If pool is full remove the cn on next Put.
if p.poolSize >= p.cfg.PoolSize {
currentPoolSize := p.poolSize.Load()
if currentPoolSize >= int32(p.cfg.PoolSize) {
cn.pooled = false
} else {
p.poolSize++
p.poolSize.Add(1)
}
}
@@ -249,6 +310,12 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
cn.pooled = pooled
if p.cfg.ConnMaxLifetime > 0 {
cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime)
} else {
cn.expiresAt = noExpiration
}
return cn, nil
}
@@ -289,6 +356,14 @@ func (p *ConnPool) getLastDialError() error {
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return p.getConn(ctx)
}
// getConn returns a connection from the pool.
func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
var cn *Conn
var err error
if p.closed() {
return nil, ErrClosed
}
@@ -297,9 +372,17 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return nil, err
}
now := time.Now()
attempts := 0
for {
if attempts >= getAttempts {
log.Printf("redis: connection pool: failed to get an connection accepted by hook after %d attempts", attempts)
break
}
attempts++
p.connsMu.Lock()
cn, err := p.popIdle()
cn, err = p.popIdle()
p.connsMu.Unlock()
if err != nil {
@@ -311,11 +394,25 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
break
}
if !p.isHealthyConn(cn) {
if !p.isHealthyConn(cn, now) {
_ = p.CloseConn(cn)
continue
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
log.Printf("redis: connection pool: failed to process idle connection by hook: %v", err)
// Failed to process connection, discard it
_ = p.CloseConn(cn)
continue
}
}
atomic.AddUint32(&p.stats.Hits, 1)
return cn, nil
}
@@ -328,6 +425,20 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return nil, err
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
// Failed to process connection, discard it
log.Printf("redis: connection pool: failed to process new connection by hook: %v", err)
_ = p.CloseConn(newcn)
return nil, err
}
}
return newcn, nil
}
@@ -356,7 +467,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error {
}
return ctx.Err()
case p.queue <- struct{}{}:
p.waitDurationNs.Add(time.Since(start).Nanoseconds())
p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
atomic.AddUint32(&p.stats.WaitCount, 1)
if !timer.Stop() {
<-timer.C
@@ -376,68 +487,128 @@ func (p *ConnPool) popIdle() (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
n := len(p.idleConns)
if n == 0 {
return nil, nil
}
var cn *Conn
attempts := 0
for attempts < popAttempts {
if len(p.idleConns) == 0 {
return nil, nil
}
if p.cfg.PoolFIFO {
cn = p.idleConns[0]
copy(p.idleConns, p.idleConns[1:])
p.idleConns = p.idleConns[:n-1]
p.idleConns = p.idleConns[:len(p.idleConns)-1]
} else {
idx := n - 1
idx := len(p.idleConns) - 1
cn = p.idleConns[idx]
p.idleConns = p.idleConns[:idx]
}
p.idleConnsLen--
attempts++
if cn.IsUsable() {
p.idleConnsLen.Add(-1)
break
}
// Connection is not usable, put it back in the pool
if p.cfg.PoolFIFO {
// FIFO: put at end (will be picked up last since we pop from front)
p.idleConns = append(p.idleConns, cn)
} else {
// LIFO: put at beginning (will be picked up last since we pop from end)
p.idleConns = append([]*Conn{cn}, p.idleConns...)
}
}
// If we exhausted all attempts without finding a usable connection, return nil
if attempts >= popAttempts {
log.Printf("redis: connection pool: failed to get an usable connection after %d attempts", popAttempts)
return nil, nil
}
p.checkMinIdleConns()
return cn, nil
}
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
// Process connection using the hooks system
shouldPool := true
shouldRemove := false
if cn.rd.Buffered() > 0 {
// Check if this might be push notification data
if p.cfg.Protocol == 3 {
// we know that there is something in the buffer, so peek at the next reply type without
// the potential to block and check if it's a push notification
if replyType, err := cn.rd.PeekReplyType(); err != nil || replyType != proto.RespPush {
shouldRemove = true
var err error
if cn.HasBufferedData() {
// Peek at the reply type to check if it's a push notification
if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush {
// Not a push notification or error peeking, remove connection
internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it")
p.Remove(ctx, cn, err)
}
} else {
// not a push notification since protocol 2 doesn't support them
shouldRemove = true
// It's a push notification, allow pooling (client will handle it)
}
if shouldRemove {
// For non-RESP3 or data that is not a push notification, buffered data is unexpected
internal.Logger.Printf(ctx, "Conn has unread data, closing it")
p.Remove(ctx, cn, BadConnError{})
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
if err != nil {
internal.Logger.Printf(ctx, "Connection hook error: %v", err)
p.Remove(ctx, cn, err)
return
}
}
// If hooks say to remove the connection, do so
if shouldRemove {
p.Remove(ctx, cn, errors.New("hook requested removal"))
return
}
// If processor says not to pool the connection, remove it
if !shouldPool {
p.Remove(ctx, cn, errors.New("hook requested no pooling"))
return
}
if !cn.pooled {
p.Remove(ctx, cn, nil)
p.Remove(ctx, cn, errors.New("connection not pooled"))
return
}
var shouldCloseConn bool
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
// unusable conns are expected to become usable at some point (background process is reconnecting them)
// put them at the opposite end of the queue
if !cn.IsUsable() {
if p.cfg.PoolFIFO {
p.connsMu.Lock()
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns {
p.idleConns = append(p.idleConns, cn)
p.idleConnsLen++
p.connsMu.Unlock()
} else {
p.removeConn(cn)
p.connsMu.Lock()
p.idleConns = append([]*Conn{cn}, p.idleConns...)
p.connsMu.Unlock()
}
} else {
p.connsMu.Lock()
p.idleConns = append(p.idleConns, cn)
p.connsMu.Unlock()
}
p.idleConnsLen.Add(1)
} else {
p.removeConnWithLock(cn)
shouldCloseConn = true
}
p.connsMu.Unlock()
p.freeTurn()
if shouldCloseConn {
@@ -449,6 +620,9 @@ func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
p.removeConnWithLock(cn)
p.freeTurn()
_ = p.closeConn(cn)
// Check if we need to create new idle connections to maintain MinIdleConns
p.checkMinIdleConns()
}
func (p *ConnPool) CloseConn(cn *Conn) error {
@@ -463,17 +637,13 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) {
}
func (p *ConnPool) removeConn(cn *Conn) {
for i, c := range p.conns {
if c == cn {
p.conns = append(p.conns[:i], p.conns[i+1:]...)
if cn.pooled {
p.poolSize--
p.checkMinIdleConns()
}
break
}
}
delete(p.conns, cn.GetID())
atomic.AddUint32(&p.stats.StaleConns, 1)
// Decrement pool size counter when removing a connection
if cn.pooled {
p.poolSize.Add(-1)
}
}
func (p *ConnPool) closeConn(cn *Conn) error {
@@ -491,9 +661,9 @@ func (p *ConnPool) Len() int {
// IdleLen returns number of idle connections.
func (p *ConnPool) IdleLen() int {
p.connsMu.Lock()
n := p.idleConnsLen
n := p.idleConnsLen.Load()
p.connsMu.Unlock()
return n
return int(n)
}
func (p *ConnPool) Stats() *Stats {
@@ -502,6 +672,7 @@ func (p *ConnPool) Stats() *Stats {
Misses: atomic.LoadUint32(&p.stats.Misses),
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
WaitCount: atomic.LoadUint32(&p.stats.WaitCount),
Unusable: atomic.LoadUint32(&p.stats.Unusable),
WaitDurationNs: p.waitDurationNs.Load(),
TotalConns: uint32(p.Len()),
@@ -542,30 +713,32 @@ func (p *ConnPool) Close() error {
}
}
p.conns = nil
p.poolSize = 0
p.poolSize.Store(0)
p.idleConns = nil
p.idleConnsLen = 0
p.idleConnsLen.Store(0)
p.connsMu.Unlock()
return firstErr
}
func (p *ConnPool) isHealthyConn(cn *Conn) bool {
now := time.Now()
if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
// slight optimization, check expiresAt first.
if cn.expiresAt.Before(now) {
return false
}
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
return false
}
// Check connection health, but be aware of push notifications
if err := connCheck(cn.netConn); err != nil {
cn.SetUsedAt(now)
// Check basic connection health
// Use GetNetConn() to safely access netConn and avoid data races
if err := connCheck(cn.getNetConn()); err != nil {
// If there's unexpected data, it might be push notifications (RESP3)
// However, push notification processing is now handled by the client
// before WithReader to ensure proper context is available to handlers
if err == errUnexpectedRead && p.cfg.Protocol == 3 {
if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
// we know that there is something in the buffer, so peek at the next reply type without
// the potential to block
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
@@ -579,7 +752,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
return false
}
}
return true
}


@@ -1,6 +1,8 @@
package pool
import "context"
import (
"context"
)
type SingleConnPool struct {
pool Pooler
@@ -56,3 +58,7 @@ func (p *SingleConnPool) IdleLen() int {
func (p *SingleConnPool) Stats() *Stats {
return &Stats{}
}
func (p *SingleConnPool) AddPoolHook(hook PoolHook) {}
func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {}


@@ -199,3 +199,7 @@ func (p *StickyConnPool) IdleLen() int {
func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}
func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}
func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}


@@ -2,6 +2,7 @@ package pool_test
import (
"context"
"errors"
"net"
"sync"
"testing"
@@ -20,7 +21,7 @@ var _ = Describe("ConnPool", func() {
BeforeEach(func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Hour,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
@@ -45,11 +46,11 @@ var _ = Describe("ConnPool", func() {
<-closedChan
return &net.TCPConn{}, nil
},
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Hour,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
MinIdleConns: minIdleConns,
MinIdleConns: int32(minIdleConns),
})
wg.Wait()
Expect(connPool.Close()).NotTo(HaveOccurred())
@@ -105,7 +106,7 @@ var _ = Describe("ConnPool", func() {
// ok
}
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
// Check that Get is unblocked.
select {
@@ -130,8 +131,8 @@ var _ = Describe("MinIdleConns", func() {
newConnPool := func() *pool.ConnPool {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: poolSize,
MinIdleConns: minIdleConns,
PoolSize: int32(poolSize),
MinIdleConns: int32(minIdleConns),
PoolTimeout: 100 * time.Millisecond,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: -1,
@@ -168,7 +169,7 @@ var _ = Describe("MinIdleConns", func() {
Context("after Remove", func() {
BeforeEach(func() {
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
})
It("has idle connections", func() {
@@ -245,7 +246,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() {
perform(len(cns), func(i int) {
mu.RLock()
connPool.Remove(ctx, cns[i], nil)
connPool.Remove(ctx, cns[i], errors.New("test"))
mu.RUnlock()
})
@@ -309,7 +310,7 @@ var _ = Describe("race", func() {
It("does not happen on Get, Put, and Remove", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Minute,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
@@ -328,7 +329,7 @@ var _ = Describe("race", func() {
cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
}
}
})
@@ -339,15 +340,15 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
},
PoolSize: 1000,
MinIdleConns: 50,
PoolSize: int32(1000),
MinIdleConns: int32(50),
PoolTimeout: 3 * time.Second,
DialTimeout: 1 * time.Second,
}
p := pool.NewConnPool(opt)
var wg sync.WaitGroup
for i := 0; i < opt.PoolSize; i++ {
for i := int32(0); i < opt.PoolSize; i++ {
wg.Add(1)
go func() {
defer wg.Done()
@@ -366,8 +367,8 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
panic("test panic")
},
PoolSize: 100,
MinIdleConns: 30,
PoolSize: int32(100),
MinIdleConns: int32(30),
}
p := pool.NewConnPool(opt)
@@ -384,7 +385,7 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
},
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 3 * time.Second,
}
p := pool.NewConnPool(opt)
@@ -415,7 +416,7 @@ var _ = Describe("race", func() {
return &net.TCPConn{}, nil
},
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: testPoolTimeout,
}
p := pool.NewConnPool(opt)

internal/pool/pubsub.go Normal file

@@ -0,0 +1,77 @@
package pool
import (
"context"
"net"
"sync"
"sync/atomic"
)
type PubSubStats struct {
Created uint32
Untracked uint32
Active uint32
}
// PubSubPool manages a pool of PubSub connections.
type PubSubPool struct {
opt *Options
netDialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Map to track active PubSub connections
activeConns sync.Map // map[uint64]*Conn (connID -> conn)
closed atomic.Bool
stats PubSubStats
}
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
return &PubSubPool{
opt: opt,
netDialer: netDialer,
}
}
func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
if p.closed.Load() {
return nil, ErrClosed
}
netConn, err := p.netDialer(ctx, network, addr)
if err != nil {
return nil, err
}
cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
atomic.AddUint32(&p.stats.Created, 1)
return cn, nil
}
func (p *PubSubPool) TrackConn(cn *Conn) {
atomic.AddUint32(&p.stats.Active, 1)
p.activeConns.Store(cn.GetID(), cn)
}
func (p *PubSubPool) UntrackConn(cn *Conn) {
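	// Adding ^uint32(0) (all bits set) wraps around, atomically decrementing Active by one.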
atomic.AddUint32(&p.stats.Active, ^uint32(0))
atomic.AddUint32(&p.stats.Untracked, 1)
p.activeConns.Delete(cn.GetID())
}
func (p *PubSubPool) Close() error {
p.closed.Store(true)
p.activeConns.Range(func(key, value interface{}) bool {
cn := value.(*Conn)
_ = cn.Close()
return true
})
return nil
}
func (p *PubSubPool) Stats() *PubSubStats {
// load stats atomically
return &PubSubStats{
Created: atomic.LoadUint32(&p.stats.Created),
Untracked: atomic.LoadUint32(&p.stats.Untracked),
Active: atomic.LoadUint32(&p.stats.Active),
}
}
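
Pieced together from the call sites later in this diff, the intended lifecycle is: dial with NewConn, register with TrackConn, and mirror every Close with UntrackConn. A minimal in-package sketch (the dialer and option values are illustrative, error handling elided):

dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
	return (&net.Dialer{}).DialContext(ctx, network, addr)
}
psPool := NewPubSubPool(&Options{
	ReadBufferSize:  32 * 1024,
	WriteBufferSize: 32 * 1024,
}, dial)

cn, _ := psPool.NewConn(ctx, "tcp", "localhost:6379", []string{"news"})
psPool.TrackConn(cn) // count as active
// ... use cn ...
psPool.UntrackConn(cn) // always mirror TrackConn before closing
_ = cn.Close()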

internal/redis.go Normal file

@@ -0,0 +1,3 @@
package internal
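// RedisNull is a sentinel used where a real address is not available, e.g. for
// cluster pub/sub connections and MOVING notifications that carry a NULL endpoint.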
const RedisNull = "null"

internal/util/math.go Normal file

@@ -0,0 +1,17 @@
package util
// Max returns the maximum of two integers
func Max(a, b int) int {
if a > b {
return a
}
return b
}
// Min returns the minimum of two integers
func Min(a, b int) int {
if a < b {
return a
}
return b
}


@@ -14,9 +14,10 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/push"
)
// Limiter is the interface of a rate limiter or a circuit breaker.
@@ -153,6 +154,7 @@ type Options struct {
//
// Note that FIFO has slightly higher overhead compared to LIFO,
// but it helps close idle connections faster, reducing the pool size.
// default: false
PoolFIFO bool
// PoolSize is the base number of socket connections.
@@ -244,8 +246,19 @@ type Options struct {
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
// HitlessUpgradeConfig provides configuration options for hitless upgrades.
// This is an alias to hitless.Config for convenience.
type HitlessUpgradeConfig = hitless.Config
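For illustration, a client opting in explicitly might be configured like the sketch below; Mode and EndpointType are the fields shown in this diff, and the remaining Config fields are filled in by ApplyDefaultsWithPoolSize during init:

rdb := redis.NewClient(&redis.Options{
	Addr:     "localhost:6379",
	Protocol: 3, // RESP3 is required for push notifications
	HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
		Mode:         hitless.MaintNotificationsEnabled, // fail fast if the server lacks support
		EndpointType: hitless.EndpointTypeAuto,          // detect from Addr and TLS
	},
})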
func (opt *Options) init() {
if opt.Addr == "" {
opt.Addr = "localhost:6379"
@@ -320,13 +333,36 @@ func (opt *Options) init() {
case 0:
opt.MaxRetryBackoff = 512 * time.Millisecond
}
opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolSize(opt.PoolSize)
// auto-detect endpoint type if not specified
endpointType := opt.HitlessUpgradeConfig.EndpointType
if endpointType == "" || endpointType == hitless.EndpointTypeAuto {
// Auto-detect endpoint type if not specified
endpointType = hitless.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
}
opt.HitlessUpgradeConfig.EndpointType = endpointType
}
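DetectEndpointType can also be called directly; a sketch with hypothetical inputs, using the signature seen in init above:

et := hitless.DetectEndpointType("10.0.0.5:6379", true) // addr, tlsEnabled
_ = et // one of the hitless.EndpointType values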
func (opt *Options) clone() *Options {
clone := *opt
// Deep clone HitlessUpgradeConfig to avoid sharing between clients
if opt.HitlessUpgradeConfig != nil {
configClone := *opt.HitlessUpgradeConfig
clone.HitlessUpgradeConfig = &configClone
}
return &clone
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
return NewDialer(opt)
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
@@ -618,17 +654,34 @@ func newConnPool(
return dialer(ctx, opt.Network, opt.Addr)
},
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
PoolSize: int32(opt.PoolSize),
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
MinIdleConns: opt.MinIdleConns,
MaxIdleConns: opt.MaxIdleConns,
MaxActiveConns: opt.MaxActiveConns,
MinIdleConns: int32(opt.MinIdleConns),
MaxIdleConns: int32(opt.MaxIdleConns),
MaxActiveConns: int32(opt.MaxActiveConns),
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
// Pass protocol version for push notification optimization
Protocol: opt.Protocol,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
PushNotificationsEnabled: opt.Protocol == 3,
})
}
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
) *pool.PubSubPool {
return pool.NewPubSubPool(&pool.Options{
PoolFIFO: opt.PoolFIFO,
PoolSize: int32(opt.PoolSize),
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
MinIdleConns: int32(opt.MinIdleConns),
MaxIdleConns: int32(opt.MaxIdleConns),
MaxActiveConns: int32(opt.MaxActiveConns),
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
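// Pub/sub connections currently use fixed 32 KiB read/write buffers.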
ReadBufferSize: 32 * 1024,
WriteBufferSize: 32 * 1024,
PushNotificationsEnabled: opt.Protocol == 3,
}, dialer)
}


@@ -38,6 +38,7 @@ type ClusterOptions struct {
ClientName string
// NewClient creates a cluster node client with provided name and options.
// If NewClient is set by the user, the user is responsible for handling hitless upgrades and push notifications.
NewClient func(opt *Options) *Client
// The maximum number of retries before giving up. Command is retried
@@ -129,6 +130,14 @@ type ClusterOptions struct {
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
// The ClusterClient does not handle hitless upgrades directly; that is delegated to the node clients it creates.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
func (opt *ClusterOptions) init() {
@@ -319,6 +328,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
}
func (opt *ClusterOptions) clientOptions() *Options {
// Clone HitlessUpgradeConfig to avoid sharing between cluster node clients
var hitlessConfig *HitlessUpgradeConfig
if opt.HitlessUpgradeConfig != nil {
configClone := *opt.HitlessUpgradeConfig
hitlessConfig = &configClone
}
return &Options{
ClientName: opt.ClientName,
Dialer: opt.Dialer,
@@ -362,6 +378,7 @@ func (opt *ClusterOptions) clientOptions() *Options {
// situations in the options below will prevent that from happening.
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
HitlessUpgradeConfig: hitlessConfig,
}
}
@@ -1830,12 +1847,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
return err
}
// hitless won't work here for now
func (c *ClusterClient) pubSub() *PubSub {
var node *clusterNode
pubsub := &PubSub{
opt: c.opt.clientOptions(),
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
if node != nil {
panic("node != nil")
}
@@ -1850,18 +1867,25 @@ func (c *ClusterClient) pubSub() *PubSub {
if err != nil {
return nil, err
}
cn, err := node.Client.newConn(context.TODO())
cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
if err != nil {
node = nil
return nil, err
}
// will return nil if already initialized
err = node.Client.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
node = nil
return nil, err
}
node.Client.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: func(cn *pool.Conn) error {
err := node.Client.connPool.CloseConn(cn)
// Untrack connection from PubSubPool
node.Client.pubSubPool.UntrackConn(cn)
err := cn.Close()
node = nil
return err
},

pool_pubsub_bench_test.go Normal file

@@ -0,0 +1,375 @@
// Pool and PubSub Benchmark Suite
//
// This file contains comprehensive benchmarks for both pool operations and PubSub initialization.
// It's designed to be run against different branches to compare performance.
//
// Usage Examples:
// # Run all benchmarks
// go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go
//
// # Run only pool benchmarks
// go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go
//
// # Run only PubSub benchmarks
// go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go
//
// # Compare between branches
// git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt
// git checkout branch2 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch2.txt
// benchcmp branch1.txt branch2.txt
//
// # Run with memory profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go
//
// # Run with CPU profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go
package redis_test
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/internal/pool"
)
// dummyDialer creates a mock connection for benchmarking
func dummyDialer(ctx context.Context) (net.Conn, error) {
return &dummyConn{}, nil
}
// dummyConn implements net.Conn for benchmarking
type dummyConn struct{}
func (c *dummyConn) Read(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Close() error { return nil }
func (c *dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} }
func (c *dummyConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379}
}
func (c *dummyConn) SetDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetReadDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil }
// =============================================================================
// POOL BENCHMARKS
// =============================================================================
// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations
func BenchmarkPoolGetPut(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(poolSize),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0), // Start with no idle connections
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with MinIdleConns
func BenchmarkPoolGetPutWithMinIdle(b *testing.B) {
ctx := context.Background()
configs := []struct {
poolSize int
minIdleConns int
}{
{8, 2},
{16, 4},
{32, 8},
{64, 16},
}
for _, config := range configs {
b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(config.poolSize),
MinIdleConns: int32(config.minIdleConns),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency
func BenchmarkPoolConcurrentGetPut(b *testing.B) {
ctx := context.Background()
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(32),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0),
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
// Test with different levels of concurrency
concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64}
for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.SetParallelism(concurrency)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// =============================================================================
// PUBSUB BENCHMARKS
// =============================================================================
// benchmarkClient creates a Redis client for benchmarking with mock dialer
func benchmarkClient(poolSize int) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: "localhost:6379", // Mock address
DialTimeout: time.Second,
ReadTimeout: time.Second,
WriteTimeout: time.Second,
PoolSize: poolSize,
MinIdleConns: 0, // Start with no idle connections for consistent benchmarks
})
}
// BenchmarkPubSubCreation benchmarks PubSub creation and subscription
func BenchmarkPubSubCreation(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 4, 8, 16, 32}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}
}
// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription
func BenchmarkPubSubPatternCreation(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 4, 8, 16, 32}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.PSubscribe(ctx, "test-*")
pubsub.Close()
}
})
}
}
// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation
func BenchmarkPubSubConcurrentCreation(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
concurrencyLevels := []int{1, 2, 4, 8, 16}
for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
var wg sync.WaitGroup
semaphore := make(chan struct{}, concurrency)
for i := 0; i < b.N; i++ {
wg.Add(1)
semaphore <- struct{}{}
go func() {
defer wg.Done()
defer func() { <-semaphore }()
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}()
}
wg.Wait()
})
}
}
// BenchmarkPubSubMultipleChannels benchmarks subscribing to multiple channels
func BenchmarkPubSubMultipleChannels(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()
channelCounts := []int{1, 5, 10, 25, 50, 100}
for _, channelCount := range channelCounts {
b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) {
// Prepare channel names
channels := make([]string, channelCount)
for i := 0; i < channelCount; i++ {
channels[i] = fmt.Sprintf("channel-%d", i)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, channels...)
pubsub.Close()
}
})
}
}
// BenchmarkPubSubReuse benchmarks reusing PubSub connections
func BenchmarkPubSubReuse(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Benchmark just the creation and closing of PubSub connections
// This simulates reuse patterns without requiring actual Redis operations
pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i))
pubsub.Close()
}
}
// =============================================================================
// COMBINED BENCHMARKS
// =============================================================================
// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations
func BenchmarkPoolAndPubSubMixed(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// Mix of pool stats collection and PubSub creation
if pb.Next() {
// Pool stats operation
stats := client.PoolStats()
_ = stats.Hits + stats.Misses // Use the stats to prevent optimization
}
if pb.Next() {
// PubSub operation
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
}
})
}
// BenchmarkPoolStatsCollection benchmarks pool statistics collection
func BenchmarkPoolStatsCollection(b *testing.B) {
client := benchmarkClient(16)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
stats := client.PoolStats()
_ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization
}
}
// BenchmarkPoolHighContention tests pool performance under high contention
func BenchmarkPoolHighContention(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// High contention Get/Put operations
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}


@@ -22,7 +22,7 @@ import (
type PubSub struct {
opt *Options
newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
newConn func(ctx context.Context, addr string, channels []string) (*pool.Conn, error)
closeConn func(*pool.Conn) error
mu sync.Mutex
@@ -42,6 +42,9 @@ type PubSub struct {
// Push notification processor for handling generic push notifications
pushProcessor push.NotificationProcessor
// Cleanup callback for hitless upgrade tracking
onClose func()
}
func (c *PubSub) init() {
@@ -73,10 +76,18 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er
return c.cn, nil
}
if c.opt.Addr == "" {
// TODO(hitless): an empty Addr most likely means this is a cluster client,
// whose newConn ignores the addr argument. This will change once hitless
// upgrades support cluster clients.
c.opt.Addr = internal.RedisNull
}
channels := mapKeys(c.channels)
channels = append(channels, newChannels...)
cn, err := c.newConn(ctx, channels)
cn, err := c.newConn(ctx, c.opt.Addr, channels)
if err != nil {
return nil, err
}
@@ -157,12 +168,28 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo
if c.cn != cn {
return
}
if !cn.IsUsable() || cn.ShouldHandoff() {
c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable"))
}
if isBadConn(err, allowTimeout, c.opt.Addr) {
c.reconnect(ctx, err)
}
}
func (c *PubSub) reconnect(ctx context.Context, reason error) {
if c.cn != nil && c.cn.ShouldHandoff() {
newEndpoint := c.cn.GetHandoffEndpoint()
// If the new endpoint is the NULL sentinel, fall back to the original address
if newEndpoint == internal.RedisNull {
newEndpoint = c.opt.Addr
}
if newEndpoint != "" {
c.opt.Addr = newEndpoint
}
}
_ = c.closeTheCn(reason)
_, _ = c.conn(ctx, nil)
}
@@ -189,6 +216,11 @@ func (c *PubSub) Close() error {
c.closed = true
close(c.exit)
// Call cleanup callback if set
if c.onClose != nil {
c.onClose()
}
return c.closeTheCn(pool.ErrClosed)
}
@@ -461,6 +493,7 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
// Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
// This will block until a message is received.
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
return c.ReceiveTimeout(ctx, 0)
}
@@ -543,7 +576,8 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac
}
func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
if c.pushProcessor == nil {
// Only process push notifications for RESP3 connections with a processor
if c.opt.Protocol != 3 || c.pushProcessor == nil {
return nil
}


@@ -1,8 +1,6 @@
package push
import (
"github.com/redis/go-redis/v9/internal/pool"
)
// No imports needed for this file
// NotificationHandlerContext provides context information about where a push notification was received.
// This struct allows handlers to make informed decisions based on the source of the notification
@@ -35,7 +33,12 @@ type NotificationHandlerContext struct {
PubSub interface{}
// Conn is the specific connection on which the notification was received.
Conn *pool.Conn
// It is an interface both to allow for future expansion and to avoid
// circular dependencies; the developer is responsible for the type assertion.
// It can be one of the following types:
// - *pool.Conn
// - *connectionAdapter (for hitless upgrades)
Conn interface{}
// IsBlocking indicates if the notification was received on a blocking connection.
IsBlocking bool
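
With Conn now an interface{}, a handler performs the assertion itself. A minimal sketch with a hypothetical myHandler type, written as code inside the go-redis module (for example the hitless package, which can import both push and internal/pool; external modules cannot import internal/pool):

type myHandler struct{}

func (myHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	// Conn is an interface{}; recover the concrete type before use.
	if cn, ok := handlerCtx.Conn.(*pool.Conn); ok {
		_ = cn // use the raw pool connection when available
	}
	return nil
}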

push/processor_unit_test.go Normal file

@@ -0,0 +1,315 @@
package push
import (
"context"
"testing"
)
// TestProcessorCreation tests processor creation and initialization
func TestProcessorCreation(t *testing.T) {
t.Run("NewProcessor", func(t *testing.T) {
processor := NewProcessor()
if processor == nil {
t.Fatal("NewProcessor should not return nil")
}
if processor.registry == nil {
t.Error("Processor should have a registry")
}
})
t.Run("NewVoidProcessor", func(t *testing.T) {
voidProcessor := NewVoidProcessor()
if voidProcessor == nil {
t.Fatal("NewVoidProcessor should not return nil")
}
})
}
// TestProcessorHandlerManagement tests handler registration and retrieval
func TestProcessorHandlerManagement(t *testing.T) {
processor := NewProcessor()
handler := &UnitTestHandler{name: "test-handler"}
t.Run("RegisterHandler", func(t *testing.T) {
err := processor.RegisterHandler("TEST", handler, false)
if err != nil {
t.Errorf("RegisterHandler should not error: %v", err)
}
// Verify handler is registered
retrievedHandler := processor.GetHandler("TEST")
if retrievedHandler != handler {
t.Error("GetHandler should return the registered handler")
}
})
t.Run("RegisterProtectedHandler", func(t *testing.T) {
protectedHandler := &UnitTestHandler{name: "protected-handler"}
err := processor.RegisterHandler("PROTECTED", protectedHandler, true)
if err != nil {
t.Errorf("RegisterHandler should not error for protected handler: %v", err)
}
// Verify handler is registered
retrievedHandler := processor.GetHandler("PROTECTED")
if retrievedHandler != protectedHandler {
t.Error("GetHandler should return the protected handler")
}
})
t.Run("GetNonExistentHandler", func(t *testing.T) {
handler := processor.GetHandler("NONEXISTENT")
if handler != nil {
t.Error("GetHandler should return nil for non-existent handler")
}
})
t.Run("UnregisterHandler", func(t *testing.T) {
err := processor.UnregisterHandler("TEST")
if err != nil {
t.Errorf("UnregisterHandler should not error: %v", err)
}
// Verify handler is removed
retrievedHandler := processor.GetHandler("TEST")
if retrievedHandler != nil {
t.Error("GetHandler should return nil after unregistering")
}
})
t.Run("UnregisterProtectedHandler", func(t *testing.T) {
err := processor.UnregisterHandler("PROTECTED")
if err == nil {
t.Error("UnregisterHandler should error for protected handler")
}
// Verify handler is still there
retrievedHandler := processor.GetHandler("PROTECTED")
if retrievedHandler == nil {
t.Error("Protected handler should not be removed")
}
})
}
// TestVoidProcessorBehavior tests void processor behavior
func TestVoidProcessorBehavior(t *testing.T) {
voidProcessor := NewVoidProcessor()
handler := &UnitTestHandler{name: "test-handler"}
t.Run("GetHandler", func(t *testing.T) {
retrievedHandler := voidProcessor.GetHandler("ANY")
if retrievedHandler != nil {
t.Error("VoidProcessor GetHandler should always return nil")
}
})
t.Run("RegisterHandler", func(t *testing.T) {
err := voidProcessor.RegisterHandler("TEST", handler, false)
if err == nil {
t.Error("VoidProcessor RegisterHandler should return error")
}
// Check error type
if !IsVoidProcessorError(err) {
t.Error("Error should be a VoidProcessorError")
}
})
t.Run("UnregisterHandler", func(t *testing.T) {
err := voidProcessor.UnregisterHandler("TEST")
if err == nil {
t.Error("VoidProcessor UnregisterHandler should return error")
}
// Check error type
if !IsVoidProcessorError(err) {
t.Error("Error should be a VoidProcessorError")
}
})
}
// TestProcessPendingNotificationsNilReader tests handling of nil reader
func TestProcessPendingNotificationsNilReader(t *testing.T) {
t.Run("ProcessorWithNilReader", func(t *testing.T) {
processor := NewProcessor()
ctx := context.Background()
handlerCtx := NotificationHandlerContext{}
err := processor.ProcessPendingNotifications(ctx, handlerCtx, nil)
if err != nil {
t.Errorf("ProcessPendingNotifications should not error with nil reader: %v", err)
}
})
t.Run("VoidProcessorWithNilReader", func(t *testing.T) {
voidProcessor := NewVoidProcessor()
ctx := context.Background()
handlerCtx := NotificationHandlerContext{}
err := voidProcessor.ProcessPendingNotifications(ctx, handlerCtx, nil)
if err != nil {
t.Errorf("VoidProcessor ProcessPendingNotifications should not error with nil reader: %v", err)
}
})
}
// TestWillHandleNotificationInClient tests the notification filtering logic
func TestWillHandleNotificationInClient(t *testing.T) {
testCases := []struct {
name string
notificationType string
shouldHandle bool
}{
// Pub/Sub notifications (should be handled in client)
{"message", "message", true},
{"pmessage", "pmessage", true},
{"subscribe", "subscribe", true},
{"unsubscribe", "unsubscribe", true},
{"psubscribe", "psubscribe", true},
{"punsubscribe", "punsubscribe", true},
{"smessage", "smessage", true},
{"ssubscribe", "ssubscribe", true},
{"sunsubscribe", "sunsubscribe", true},
// Push notifications (should be handled by processor)
{"MOVING", "MOVING", false},
{"MIGRATING", "MIGRATING", false},
{"MIGRATED", "MIGRATED", false},
{"FAILING_OVER", "FAILING_OVER", false},
{"FAILED_OVER", "FAILED_OVER", false},
{"custom", "custom", false},
{"unknown", "unknown", false},
{"empty", "", false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := willHandleNotificationInClient(tc.notificationType)
if result != tc.shouldHandle {
t.Errorf("willHandleNotificationInClient(%q) = %v, want %v", tc.notificationType, result, tc.shouldHandle)
}
})
}
}
// TestProcessorErrorHandlingUnit tests error handling scenarios
func TestProcessorErrorHandlingUnit(t *testing.T) {
processor := NewProcessor()
t.Run("RegisterNilHandler", func(t *testing.T) {
err := processor.RegisterHandler("TEST", nil, false)
if err == nil {
t.Error("RegisterHandler should error with nil handler")
}
// Check error type
if !IsHandlerNilError(err) {
t.Error("Error should be a HandlerNilError")
}
})
t.Run("RegisterDuplicateHandler", func(t *testing.T) {
handler1 := &UnitTestHandler{name: "handler1"}
handler2 := &UnitTestHandler{name: "handler2"}
// Register first handler
err := processor.RegisterHandler("DUPLICATE", handler1, false)
if err != nil {
t.Errorf("First RegisterHandler should not error: %v", err)
}
// Try to register second handler with same name
err = processor.RegisterHandler("DUPLICATE", handler2, false)
if err == nil {
t.Error("RegisterHandler should error when registering duplicate handler")
}
// Verify original handler is still there
retrievedHandler := processor.GetHandler("DUPLICATE")
if retrievedHandler != handler1 {
t.Error("Original handler should remain after failed duplicate registration")
}
})
t.Run("UnregisterNonExistentHandler", func(t *testing.T) {
err := processor.UnregisterHandler("NONEXISTENT")
if err != nil {
t.Errorf("UnregisterHandler should not error for non-existent handler: %v", err)
}
})
}
// TestProcessorConcurrentAccess tests concurrent access to processor
func TestProcessorConcurrentAccess(t *testing.T) {
processor := NewProcessor()
t.Run("ConcurrentRegisterAndGet", func(t *testing.T) {
done := make(chan bool, 2)
// Goroutine 1: Register handlers
go func() {
defer func() { done <- true }()
for i := 0; i < 100; i++ {
handler := &UnitTestHandler{name: "concurrent-handler"}
processor.RegisterHandler("CONCURRENT", handler, false)
processor.UnregisterHandler("CONCURRENT")
}
}()
// Goroutine 2: Get handlers
go func() {
defer func() { done <- true }()
for i := 0; i < 100; i++ {
processor.GetHandler("CONCURRENT")
}
}()
// Wait for both goroutines to complete
<-done
<-done
})
}
// TestProcessorInterfaceCompliance tests interface compliance
func TestProcessorInterfaceCompliance(t *testing.T) {
t.Run("ProcessorImplementsInterface", func(t *testing.T) {
var _ NotificationProcessor = (*Processor)(nil)
})
t.Run("VoidProcessorImplementsInterface", func(t *testing.T) {
var _ NotificationProcessor = (*VoidProcessor)(nil)
})
}
// UnitTestHandler is a test implementation of NotificationHandler
type UnitTestHandler struct {
name string
lastNotification []interface{}
errorToReturn error
callCount int
}
func (h *UnitTestHandler) HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error {
h.callCount++
h.lastNotification = notification
return h.errorToReturn
}
// Helper methods for UnitTestHandler
func (h *UnitTestHandler) GetCallCount() int {
return h.callCount
}
func (h *UnitTestHandler) GetLastNotification() []interface{} {
return h.lastNotification
}
func (h *UnitTestHandler) SetErrorToReturn(err error) {
h.errorToReturn = err
}
func (h *UnitTestHandler) Reset() {
h.callCount = 0
h.lastNotification = nil
h.errorToReturn = nil
}

View File

@@ -4,24 +4,6 @@ import (
"github.com/redis/go-redis/v9/push"
)
// Push notification constants for cluster operations
const (
// MOVING indicates a slot is being moved to a different node
PushNotificationMoving = "MOVING"
// MIGRATING indicates a slot is being migrated from this node
PushNotificationMigrating = "MIGRATING"
// MIGRATED indicates a slot has been migrated to this node
PushNotificationMigrated = "MIGRATED"
// FAILING_OVER indicates a failover is starting
PushNotificationFailingOver = "FAILING_OVER"
// FAILED_OVER indicates a failover has completed
PushNotificationFailedOver = "FAILED_OVER"
)
// NewPushNotificationProcessor creates a new push notification processor
// This processor maintains a registry of handlers and processes push notifications
// It is used for RESP3 connections where push notifications are available

redis.go

@@ -10,6 +10,7 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/hscan"
"github.com/redis/go-redis/v9/internal/pool"
@@ -205,18 +206,34 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
type baseClient struct {
opt *Options
optLock sync.RWMutex
connPool pool.Pooler
pubSubPool *pool.PubSubPool
hooksMixin
onClose func() error // hook called when client is closed
// Push notification processing
pushProcessor push.NotificationProcessor
// Hitless upgrade manager
hitlessManager *hitless.HitlessManager
hitlessManagerLock sync.RWMutex
}
func (c *baseClient) clone() *baseClient {
clone := *c
return &clone
c.hitlessManagerLock.RLock()
hitlessManager := c.hitlessManager
c.hitlessManagerLock.RUnlock()
clone := &baseClient{
opt: c.opt,
connPool: c.connPool,
onClose: c.onClose,
pushProcessor: c.pushProcessor,
hitlessManager: hitlessManager,
}
return clone
}
func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
@@ -234,21 +251,6 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}
func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
cn, err := c.connPool.NewConn(ctx)
if err != nil {
return nil, err
}
err = c.initConn(ctx, cn)
if err != nil {
_ = c.connPool.CloseConn(cn)
return nil, err
}
return cn, nil
}
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
if c.opt.Limiter != nil {
err := c.opt.Limiter.Allow()
@@ -274,7 +276,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return nil, err
}
if cn.Inited {
if cn.IsInited() {
return cn, nil
}
@@ -356,12 +358,10 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
}
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
if !cn.Inited.CompareAndSwap(false, true) {
return nil
}
var err error
cn.Inited = true
connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool, &c.hooksMixin)
@@ -430,6 +430,50 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return fmt.Errorf("failed to initialize connection options: %w", err)
}
// Enable maintenance notifications if hitless upgrades are configured
c.optLock.RLock()
hitlessEnabled := c.opt.HitlessUpgradeConfig != nil && c.opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled
protocol := c.opt.Protocol
endpointType := c.opt.HitlessUpgradeConfig.EndpointType
c.optLock.RUnlock()
var hitlessHandshakeErr error
if hitlessEnabled && protocol == 3 {
hitlessHandshakeErr = conn.ClientMaintNotifications(
ctx,
true,
endpointType.String(),
).Err()
if hitlessHandshakeErr != nil {
if !isRedisError(hitlessHandshakeErr) {
// if not redis error, fail the connection
return hitlessHandshakeErr
}
c.optLock.Lock()
// handshake failed - check and modify config atomically
switch c.opt.HitlessUpgradeConfig.Mode {
case hitless.MaintNotificationsEnabled:
// enabled mode, fail the connection
c.optLock.Unlock()
return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr)
default: // will handle auto and any other
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsDisabled
c.optLock.Unlock()
// auto mode, disable hitless upgrades and continue
if err := c.disableHitlessUpgrades(); err != nil {
// Log error but continue - auto mode should be resilient
internal.Logger.Printf(ctx, "hitless: failed to disable hitless upgrades in auto mode: %v", err)
}
}
} else {
// The handshake succeeded on this connection, so force it to run on all
// other connections as well by switching the mode to "enabled".
c.optLock.Lock()
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsEnabled
c.optLock.Unlock()
}
}
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
libName := ""
libVer := Version()
@@ -446,6 +490,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
}
}
cn.SetUsable(true)
cn.Inited.Store(true)
// Set the connection initialization function for potential reconnections
cn.SetInitConnFunc(c.createInitConnFunc())
if c.opt.OnConnect != nil {
return c.opt.OnConnect(ctx, conn)
}
@@ -593,20 +643,77 @@ func (c *baseClient) context(ctx context.Context) context.Context {
return context.Background()
}
// createInitConnFunc creates a connection initialization function that can be used for reconnections.
func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error {
return func(ctx context.Context, cn *pool.Conn) error {
return c.initConn(ctx, cn)
}
}
// enableHitlessUpgrades initializes the hitless upgrade manager and pool hook.
// This function is called during client initialization.
// will register push notification handlers for all hitless upgrade events.
// will start background workers for handoff processing in the pool hook.
func (c *baseClient) enableHitlessUpgrades() error {
// Create client adapter
clientAdapterInstance := newClientAdapter(c)
// Create hitless manager directly
manager, err := hitless.NewHitlessManager(clientAdapterInstance, c.connPool, c.opt.HitlessUpgradeConfig)
if err != nil {
return err
}
// Set the manager reference and initialize pool hook
c.hitlessManagerLock.Lock()
c.hitlessManager = manager
c.hitlessManagerLock.Unlock()
// Initialize pool hook (safe to call without lock since manager is now set)
manager.InitPoolHook(c.dialHook)
return nil
}
func (c *baseClient) disableHitlessUpgrades() error {
c.hitlessManagerLock.Lock()
defer c.hitlessManagerLock.Unlock()
// Close the hitless manager
if c.hitlessManager != nil {
// Closing the manager will also shutdown the pool hook
// and remove it from the pool
c.hitlessManager.Close()
c.hitlessManager = nil
}
return nil
}
// Close closes the client, releasing any open resources.
//
// It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
var firstErr error
// Close hitless manager first
if err := c.disableHitlessUpgrades(); err != nil {
firstErr = err
}
if c.onClose != nil {
if err := c.onClose(); err != nil {
if err := c.onClose(); err != nil && firstErr == nil {
firstErr = err
}
}
if c.connPool != nil {
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if c.pubSubPool != nil {
if err := c.pubSubPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
@@ -810,11 +917,24 @@ func NewClient(opt *Options) *Client {
// Initialize push notification processor using shared helper
// Use void processor for RESP2 connections (push notifications not available)
c.pushProcessor = initializePushProcessor(opt)
// Update options with the initialized push processor for connection pool
// Update options with the initialized push processor
opt.PushNotificationProcessor = c.pushProcessor
// Create connection pools
c.connPool = newConnPool(opt, c.dialHook)
c.pubSubPool = newPubSubPool(opt, c.dialHook)
// Initialize hitless upgrades first if enabled and protocol is RESP3
if opt.HitlessUpgradeConfig != nil && opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled && opt.Protocol == 3 {
err := c.enableHitlessUpgrades()
if err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
if opt.HitlessUpgradeConfig.Mode == hitless.MaintNotificationsEnabled {
// panic so we fail fast without breaking the existing client API
panic(fmt.Errorf("failed to enable hitless upgrades: %w", err))
}
}
}
return &c
}
@@ -851,6 +971,14 @@ func (c *Client) Options() *Options {
return c.opt
}
// GetHitlessManager returns the hitless manager instance for monitoring and control.
// Returns nil if hitless upgrades are not enabled.
func (c *Client) GetHitlessManager() *hitless.HitlessManager {
c.hitlessManagerLock.RLock()
defer c.hitlessManagerLock.RUnlock()
return c.hitlessManager
}
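A quick sketch of using it from application code:

if hm := rdb.GetHitlessManager(); hm != nil {
	// Hitless upgrades are active; hm exposes the hitless package's
	// monitoring and control API.
}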
// initializePushProcessor initializes the push notification processor for any client type.
// This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient.
func initializePushProcessor(opt *Options) push.NotificationProcessor {
@@ -887,6 +1015,7 @@ type PoolStats pool.Stats
// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
stats := c.connPool.Stats()
stats.PubSubStats = *(c.pubSubPool.Stats())
return (*PoolStats)(stats)
}
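Because pub/sub connections are now tracked outside the main pool, their counters ride along in the same stats value. A sketch of reading both:

stats := rdb.PoolStats()
fmt.Printf("pool hits=%d misses=%d pubsub created=%d active=%d\n",
	stats.Hits, stats.Misses,
	stats.PubSubStats.Created, stats.PubSubStats.Active)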
@@ -921,11 +1050,27 @@ func (c *Client) TxPipeline() Pipeliner {
func (c *Client) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
closeConn: c.connPool.CloseConn,
pushProcessor: c.pushProcessor,
}
pubsub.init()
@@ -1113,6 +1258,6 @@ func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.Notifica
return push.NotificationHandlerContext{
Client: c,
ConnPool: c.connPool,
Conn: cn,
Conn: &connectionAdapter{conn: cn}, // Wrap in adapter for easier interface access
}
}


@@ -12,7 +12,6 @@ import (
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/auth"
)


@@ -16,8 +16,8 @@ import (
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/rand"
"github.com/redis/go-redis/v9/push"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/push"
)
//------------------------------------------------------------------------------
@@ -139,6 +139,14 @@ type FailoverOptions struct {
FailingTimeoutSeconds int
UnstableResp3 bool
// Hitless is not supported for FailoverClients at the moment
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// upgrade notifications gracefully and manage connection/pool state transitions
// seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are disabled.
//HitlessUpgradeConfig *HitlessUpgradeConfig
}
func (opt *FailoverOptions) clientOptions() *Options {
@@ -456,8 +464,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
opt.Dialer = masterReplicaDialer(failover)
opt.init()
var connPool *pool.ConnPool
rdb := &Client{
baseClient: &baseClient{
opt: opt,
@@ -469,16 +475,19 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
// Use void processor by default for RESP2 connections
rdb.pushProcessor = initializePushProcessor(opt)
connPool = newConnPool(opt, rdb.dialHook)
rdb.connPool = connPool
rdb.connPool = newConnPool(opt, rdb.dialHook)
rdb.pubSubPool = newPubSubPool(opt, rdb.dialHook)
rdb.onClose = rdb.wrappedOnClose(failover.Close)
failover.mu.Lock()
failover.onFailover = func(ctx context.Context, addr string) {
if connPool, ok := rdb.connPool.(*pool.ConnPool); ok {
_ = connPool.Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
})
}
}
failover.mu.Unlock()
return rdb
@@ -544,6 +553,7 @@ func NewSentinelClient(opt *Options) *SentinelClient {
process: c.baseClient.process,
})
c.connPool = newConnPool(opt, c.dialHook)
c.pubSubPool = newPubSubPool(opt, c.dialHook)
return c
}
@@ -570,13 +580,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
func (c *SentinelClient) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: c.connPool.CloseConn,
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
pushProcessor: c.pushProcessor,
}
pubsub.init()
return pubsub
}

tx.go

@@ -24,7 +24,7 @@ type Tx struct {
func (c *Client) newTx() *Tx {
tx := Tx{
baseClient: baseClient{
opt: c.opt,
opt: c.opt.clone(), // Clone options to avoid sharing HitlessUpgradeConfig
connPool: pool.NewStickyConnPool(c.connPool),
hooksMixin: c.hooksMixin.clone(),
pushProcessor: c.pushProcessor, // Copy push processor from parent client


@@ -122,6 +122,9 @@ type UniversalOptions struct {
// IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint).
IsClusterMode bool
// HitlessUpgradeConfig provides configuration for hitless upgrades.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
// Cluster returns cluster options created from the universal options.
@@ -177,6 +180,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions {
IdentitySuffix: o.IdentitySuffix,
FailingTimeoutSeconds: o.FailingTimeoutSeconds,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
}
}
@@ -237,6 +241,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions {
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
// Note: HitlessUpgradeConfig not supported for FailoverOptions
}
}
@@ -288,6 +293,7 @@ func (o *UniversalOptions) Simple() *Options {
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
}
}