
feat(hitless): Introduce handlers for hitless upgrades

This commit includes all the work on hitless upgrades with the addition
of:

- Pubsub Pool
- Examples
- Refactor of the push notification handling
- Refactor of the pool (using atomics for most state)
- Introduction of hooks in the pool (see the sketch below)
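A minimal sketch of the new pool-hook wiring, mirroring the integration tests added in this commit (the addresses and dialer below are placeholders, and `internal/pool` is only importable from inside the go-redis module):

```go
package redis

import (
	"context"
	"net"
	"time"

	"github.com/redis/go-redis/v9/hitless"
	"github.com/redis/go-redis/v9/internal/pool"
)

// wirePoolHook is a hypothetical helper showing how the new hooks attach to a pool.
func wirePoolHook() {
	// Dialer used when a connection is handed off to a new endpoint (placeholder).
	baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
		return net.Dial(network, addr)
	}
	hook := hitless.NewPoolHook(baseDialer, "tcp", nil, nil) // nil config => defaults
	defer hook.Shutdown(context.Background())

	p := pool.NewConnPool(&pool.Options{
		Dialer:      func(ctx context.Context) (net.Conn, error) { return net.Dial("tcp", "localhost:6379") },
		PoolSize:    int32(10),
		PoolTimeout: time.Second,
	})
	defer p.Close()

	p.AddPoolHook(hook) // hooks are the new extension point in the pool
	hook.SetPool(p)     // lets the hook remove connections when a handoff fails
}
```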
This commit is contained in:
Nedyalko Dyakov
2025-08-18 22:14:06 +03:00
parent 36f9f58c67
commit 5649ffb314
46 changed files with 6347 additions and 249 deletions

3
.gitignore vendored

@@ -9,3 +9,6 @@ coverage.txt
**/coverage.txt
.vscode
tmp/*
# Hitless upgrade documentation (temporary)
hitless/docs/

149
adapters.go Normal file

@@ -0,0 +1,149 @@
package redis
import (
"context"
"errors"
"net"
"time"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand.
var ErrInvalidCommand = errors.New("invalid command type")
// ErrInvalidPool is returned when the pool type is not supported.
var ErrInvalidPool = errors.New("invalid pool type")
// newClientAdapter creates a new client adapter for regular Redis clients.
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
return &clientAdapter{client: client}
}
// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
type clientAdapter struct {
client *baseClient
}
// GetOptions returns the client options.
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
return &optionsAdapter{options: ca.client.opt}
}
// GetPushProcessor returns the client's push notification processor.
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
return &pushProcessorAdapter{processor: ca.client.pushProcessor}
}
// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
type optionsAdapter struct {
options *Options
}
// GetReadTimeout returns the read timeout.
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
return oa.options.ReadTimeout
}
// GetWriteTimeout returns the write timeout.
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
return oa.options.WriteTimeout
}
// GetNetwork returns the network type.
func (oa *optionsAdapter) GetNetwork() string {
return oa.options.Network
}
// GetAddr returns the connection address.
func (oa *optionsAdapter) GetAddr() string {
return oa.options.Addr
}
// IsTLSEnabled returns true if TLS is enabled.
func (oa *optionsAdapter) IsTLSEnabled() bool {
return oa.options.TLSConfig != nil
}
// GetProtocol returns the protocol version.
func (oa *optionsAdapter) GetProtocol() int {
return oa.options.Protocol
}
// GetPoolSize returns the connection pool size.
func (oa *optionsAdapter) GetPoolSize() int {
return oa.options.PoolSize
}
// NewDialer returns a new dialer function for the connection.
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
baseDialer := oa.options.NewDialer()
return func(ctx context.Context) (net.Conn, error) {
// Extract network and address from the options
network := oa.options.Network
addr := oa.options.Addr
return baseDialer(ctx, network, addr)
}
}
// connectionAdapter adapts a Redis connection to interfaces.ConnectionWithRelaxedTimeout
type connectionAdapter struct {
conn *pool.Conn
}
// Close closes the connection.
func (ca *connectionAdapter) Close() error {
return ca.conn.Close()
}
// IsUsable returns true if the connection is safe to use for new commands.
func (ca *connectionAdapter) IsUsable() bool {
return ca.conn.IsUsable()
}
// GetPoolConnection returns the underlying pool connection.
func (ca *connectionAdapter) GetPoolConnection() *pool.Conn {
return ca.conn
}
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
// These timeouts remain active until explicitly cleared.
func (ca *connectionAdapter) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
ca.conn.SetRelaxedTimeout(readTimeout, writeTimeout)
}
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
func (ca *connectionAdapter) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
ca.conn.SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout, deadline)
}
// ClearRelaxedTimeout clears relaxed timeouts for this connection.
func (ca *connectionAdapter) ClearRelaxedTimeout() {
ca.conn.ClearRelaxedTimeout()
}
// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
type pushProcessorAdapter struct {
processor push.NotificationProcessor
}
// RegisterHandler registers a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
if pushHandler, ok := handler.(push.NotificationHandler); ok {
return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
}
return errors.New("handler must implement push.NotificationHandler")
}
// UnregisterHandler removes a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
return ppa.processor.UnregisterHandler(pushNotificationName)
}
// GetHandler returns the handler for a specific push notification name.
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
return ppa.processor.GetHandler(pushNotificationName)
}


@@ -0,0 +1,348 @@
package redis
import (
"context"
"net"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal/pool"
)
// mockNetConn implements net.Conn for testing
type mockNetConn struct {
addr string
}
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (m *mockNetConn) Close() error { return nil }
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
type mockAddr struct {
addr string
}
func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string { return m.addr }
// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow
func TestEventDrivenHandoffIntegration(t *testing.T) {
t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) {
// Create a base dialer for testing
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
// Create processor with event-driven handoff support
processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create a test pool with hooks
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(5),
PoolTimeout: time.Second,
})
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
defer testPool.Close()
// Set the pool reference in the processor for connection removal on handoff failure
processor.SetPool(testPool)
ctx := context.Background()
// Get a connection and mark it for handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
// Set initialization function with a small delay to ensure handoff is pending
initConnCalled := false
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
initConnCalled = true
return nil
}
conn.SetInitConnFunc(initConnFunc)
// Mark connection for handoff
err = conn.MarkForHandoff("new-endpoint:6379", 12345)
if err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Return connection to pool - this should queue handoff
testPool.Put(ctx, conn)
// Give the on-demand worker a moment to start processing
time.Sleep(10 * time.Millisecond)
// Verify handoff was queued
if !processor.IsHandoffPending(conn) {
t.Error("Handoff should be queued in pending map")
}
// Try to get the same connection - should be skipped due to pending handoff
conn2, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get second connection: %v", err)
}
// Should get a different connection (the pending one should be skipped)
if conn == conn2 {
t.Error("Should have gotten a different connection while handoff is pending")
}
// Return the second connection
testPool.Put(ctx, conn2)
// Wait for handoff to complete
time.Sleep(200 * time.Millisecond)
// Verify handoff completed (removed from pending map)
if processor.IsHandoffPending(conn) {
t.Error("Handoff should have completed and been removed from pending map")
}
if !initConnCalled {
t.Error("InitConn should have been called during handoff")
}
// Now the original connection should be available again
conn3, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get third connection: %v", err)
}
// Could be the original connection (now handed off) or a new one
testPool.Put(ctx, conn3)
})
t.Run("ConcurrentHandoffs", func(t *testing.T) {
// Create a base dialer that simulates slow handoffs
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
time.Sleep(50 * time.Millisecond) // Simulate network delay
return &mockNetConn{addr: addr}, nil
}
processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(10),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
var wg sync.WaitGroup
// Start multiple concurrent handoffs
for i := 0; i < 5; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
// Get connection
conn, err := testPool.Get(ctx)
if err != nil {
t.Errorf("Failed to get connection %d: %v", id, err)
return
}
// Set initialization function
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
return nil
}
conn.SetInitConnFunc(initConnFunc)
// Mark for handoff
conn.MarkForHandoff("new-endpoint:6379", int64(id))
// Return to pool (starts async handoff)
testPool.Put(ctx, conn)
}(i)
}
wg.Wait()
// Wait for all handoffs to complete
time.Sleep(300 * time.Millisecond)
// Verify pool is still functional
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err)
}
testPool.Put(ctx, conn)
})
t.Run("HandoffFailureRecovery", func(t *testing.T) {
// Create a failing base dialer
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}}
}
processor := hitless.NewPoolHook(failingDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(3),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
// Get connection and mark for handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
conn.MarkForHandoff("unreachable-endpoint:6379", 12345)
// Return to pool (starts async handoff that will fail)
testPool.Put(ctx, conn)
// Wait for handoff to fail
time.Sleep(200 * time.Millisecond)
// Connection should be removed from pending map after failed handoff
if processor.IsHandoffPending(conn) {
t.Error("Connection should be removed from pending map after failed handoff")
}
// Pool should still be functional
conn2, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Pool should still be functional: %v", err)
}
// In event-driven approach, the original connection remains in pool
// even after failed handoff (it's still a valid connection)
// We might get the same connection or a different one
testPool.Put(ctx, conn2)
})
t.Run("GracefulShutdown", func(t *testing.T) {
// Create a slow base dialer
slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
time.Sleep(100 * time.Millisecond)
return &mockNetConn{addr: addr}, nil
}
processor := hitless.NewPoolHook(slowDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(2),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
// Start a handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function with delay to ensure handoff is pending
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
return nil
})
testPool.Put(ctx, conn)
// Give the on-demand worker a moment to start and begin processing
// The handoff should be pending because the slowDialer takes 100ms
time.Sleep(10 * time.Millisecond)
// Verify handoff was queued and is being processed
if !processor.IsHandoffPending(conn) {
t.Error("Handoff should be queued in pending map")
}
// Give the handoff a moment to start processing
time.Sleep(50 * time.Millisecond)
// Shutdown processor gracefully
// Use a longer timeout to account for slow dialer (100ms) plus processing overhead
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err = processor.Shutdown(shutdownCtx)
if err != nil {
t.Errorf("Graceful shutdown should succeed: %v", err)
}
// Handoff should have completed (removed from pending map)
if processor.IsHandoffPending(conn) {
t.Error("Handoff should have completed and been removed from pending map after shutdown")
}
})
}


@@ -193,6 +193,7 @@ type Cmdable interface {
ClientID(ctx context.Context) *IntCmd
ClientUnblock(ctx context.Context, id int64) *IntCmd
ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
ConfigResetStat(ctx context.Context) *StatusCmd
ConfigSet(ctx context.Context, parameter, value string) *StatusCmd
@@ -518,6 +519,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
return cmd
}
// ClientMaintNotifications enables or disables maintenance notifications for hitless upgrades.
// When enabled, the client will receive push notifications about Redis maintenance events.
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
args := []interface{}{"client", "maint_notifications"}
if enabled {
if endpointType == "" {
endpointType = "none"
}
args = append(args, "on", "moving-endpoint-type", endpointType)
} else {
args = append(args, "off")
}
cmd := NewStatusCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
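A brief usage sketch for the new command (address and endpoint type are placeholders); the instrumentation examples later in this diff show the resulting `client maint_notifications on moving-endpoint-type external-ip` call on the wire:

```go
package main

import (
	"context"
	"log"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Protocol: 3}) // RESP3 is required for push notifications

	// Enable maintenance notifications and ask for external IPs in MOVING notifications.
	// Passing false sends CLIENT MAINT_NOTIFICATIONS OFF instead.
	if err := rdb.ClientMaintNotifications(ctx, true, "external-ip").Err(); err != nil {
		log.Printf("maintenance notifications not enabled: %v", err)
	}
}
```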
// ------------------------------------------------------------------------------------------------
func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {

12
example/pubsub/go.mod Normal file

@@ -0,0 +1,12 @@
module github.com/redis/go-redis/example/pubsub
go 1.18
replace github.com/redis/go-redis/v9 => ../..
require github.com/redis/go-redis/v9 v9.11.0
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)

6
example/pubsub/go.sum Normal file

@@ -0,0 +1,6 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=

146
example/pubsub/main.go Normal file

@@ -0,0 +1,146 @@
package main
import (
"context"
"fmt"
"log"
"sync"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/hitless"
)
var ctx = context.Background()
// This example is not meant to be run as is. It is a stress test for observing how pub/sub
// behaves with respect to pool management and was used to find regressions in pool management
// in hitless mode. Do not use it as a reference for how to use pub/sub.
func main() {
wg := &sync.WaitGroup{}
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
Mode: hitless.MaintNotificationsEnabled,
},
})
_ = rdb.FlushDB(ctx).Err()
go func() {
for {
time.Sleep(2 * time.Second)
fmt.Printf("pool stats: %+v\n", rdb.PoolStats())
}
}()
err := rdb.Ping(ctx).Err()
if err != nil {
panic(err)
}
if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil {
panic(err)
}
fmt.Println("published", rdb.Get(ctx, "published").Val())
fmt.Println("received", rdb.Get(ctx, "received").Val())
subCtx, cancelSubCtx := context.WithCancel(ctx)
pubCtx, cancelPublishers := context.WithCancel(ctx)
for i := 0; i < 10; i++ {
wg.Add(1)
go subscribe(subCtx, rdb, "test", i, wg)
}
time.Sleep(time.Second)
cancelSubCtx()
time.Sleep(time.Second)
subCtx, cancelSubCtx = context.WithCancel(ctx)
for i := 0; i < 10; i++ {
if err := rdb.Incr(ctx, "publishers").Err(); err != nil {
panic(err)
}
wg.Add(1)
go floodThePool(pubCtx, rdb, wg)
}
for i := 0; i < 500; i++ {
if err := rdb.Incr(ctx, "subscribers").Err(); err != nil {
panic(err)
}
wg.Add(1)
go subscribe(subCtx, rdb, "test2", i, wg)
}
time.Sleep(5 * time.Second)
fmt.Println("canceling publishers")
cancelPublishers()
time.Sleep(10 * time.Second)
fmt.Println("canceling subscribers")
cancelSubCtx()
wg.Wait()
published, err := rdb.Get(ctx, "published").Result()
received, err := rdb.Get(ctx, "received").Result()
publishers, err := rdb.Get(ctx, "publishers").Result()
subscribers, err := rdb.Get(ctx, "subscribers").Result()
fmt.Printf("publishers: %s\n", publishers)
fmt.Printf("published: %s\n", published)
fmt.Printf("subscribers: %s\n", subscribers)
fmt.Printf("received: %s\n", received)
publishedInt, err := rdb.Get(ctx, "published").Int()
subscribersInt, err := rdb.Get(ctx, "subscribers").Int()
fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt)
time.Sleep(2 * time.Second)
}
func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
}
err := rdb.Publish(ctx, "test2", "hello").Err()
if err != nil {
// noop
//log.Println("publish error:", err)
}
err = rdb.Incr(ctx, "published").Err()
if err != nil {
// noop
//log.Println("incr error:", err)
}
time.Sleep(10 * time.Nanosecond)
}
}
func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) {
defer wg.Done()
rec := rdb.Subscribe(ctx, topic)
recChan := rec.Channel()
for {
select {
case <-ctx.Done():
rec.Close()
return
default:
select {
case <-ctx.Done():
rec.Close()
return
case msg := <-recChan:
err := rdb.Incr(ctx, "received").Err()
if err != nil {
log.Println("incr error:", err)
}
_ = msg // Use the message to avoid unused variable warning
}
}
}
}


@@ -57,6 +57,8 @@ func Example_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// finished processing: <[ping]>
}
@@ -78,6 +80,8 @@ func ExamplePipeline_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// pipeline finished processing: [[ping] [ping]]
}
@@ -99,6 +103,8 @@ func ExampleClient_Watch_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// finished processing: <[client maint_notifications on moving-endpoint-type external-ip]>
// finished processing: <[watch foo]>
// starting processing: <[ping]>
// finished processing: <[ping]>

72
hitless/README.md Normal file

@@ -0,0 +1,72 @@
# Hitless Upgrades
Seamless Redis connection handoffs during topology changes without interrupting operations.
## Quick Start
```go
import "github.com/redis/go-redis/v9/hitless"
opt := &redis.Options{
Addr: "localhost:6379",
Protocol: 3, // RESP3 required
HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
Mode: hitless.MaintNotificationsEnabled, // or MaintNotificationsAuto
},
}
client := redis.NewClient(opt)
```
## Modes
- **`MaintNotificationsDisabled`**: Hitless upgrades are completely disabled
- **`MaintNotificationsEnabled`**: Hitless upgrades are forcefully enabled (fails if the server does not support them)
- **`MaintNotificationsAuto`**: Hitless upgrades are enabled if the server supports them (default)
## Configuration
```go
import "github.com/redis/go-redis/v9/hitless"
Config: &hitless.Config{
Mode: hitless.MaintNotificationsAuto, // Notification mode
MaxHandoffRetries: 3, // Retry failed handoffs
HandoffTimeout: 15 * time.Second, // Handoff operation timeout
RelaxedTimeout: 10 * time.Second, // Extended timeout during migrations
PostHandoffRelaxedDuration: 20 * time.Second, // Keep relaxed timeout after handoff
LogLevel: 1, // 0=errors, 1=warnings, 2=info, 3=debug
MaxWorkers: 15, // Concurrent handoff workers
HandoffQueueSize: 50, // Handoff request queue size
}
```
### Worker Scaling
- **Auto-calculated**: `min(10, PoolSize/3)` - scales with pool size, capped at 10
- **Explicit values**: `max(10, set_value)` - enforces minimum 10 workers
- **On-demand**: Workers created when needed, cleaned up when idle
### Queue Sizing
- **Auto-calculated**: `10 × MaxWorkers`, capped by pool size
- **Always capped**: Queue size never exceeds the pool size (minimum of 2); see the sketch below
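The defaults above are applied by `Config.ApplyDefaultsWithPoolSize`; a minimal sketch of the resulting values (pool size 30 is just an example):
```go
import "github.com/redis/go-redis/v9/hitless"

// Pool size 30 is only an example value.
cfg := (&hitless.Config{}).ApplyDefaultsWithPoolSize(30)

// cfg.MaxWorkers       == min(10, 30/3)          == 10
// cfg.HandoffQueueSize == min(10*MaxWorkers, 30) == 30
```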
## Metrics Hook Example
A metrics collection hook is available in `example_hooks.go` that demonstrates how to monitor hitless upgrade operations:
```go
import "github.com/redis/go-redis/v9/hitless"
metricsHook := hitless.NewMetricsHook()
// Use with your monitoring system
```
The metrics hook tracks, per notification type (a usage sketch follows):
- Notification counts
- Processing durations
- Error counts
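A minimal usage sketch of the collected data, exercising only the hook's own API (how the hook is registered with a client is not shown in this file):
```go
import "github.com/redis/go-redis/v9/hitless"

metricsHook := hitless.NewMetricsHook()

// After notifications have flowed through PreHook/PostHook:
m := metricsHook.GetMetrics()
_ = m["notification_counts"] // map[string]int64: notifications seen per type
_ = m["processing_times"]    // map[string]time.Duration: last processing duration per type
_ = m["error_counts"]        // map[string]int64: failed notifications per type
```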
## Requirements
- **RESP3 Protocol**: Required for push notifications

377
hitless/config.go Normal file

@@ -0,0 +1,377 @@
package hitless
import (
"net"
"runtime"
"time"
"github.com/redis/go-redis/v9/internal/util"
)
// MaintNotificationsMode represents the maintenance notifications mode
type MaintNotificationsMode string
// Constants for maintenance push notifications modes
const (
MaintNotificationsDisabled MaintNotificationsMode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
MaintNotificationsEnabled MaintNotificationsMode = "enabled" // Client forcefully sends command, interrupts connection on error
MaintNotificationsAuto MaintNotificationsMode = "auto" // Client tries to send command, disables feature on error
)
// IsValid returns true if the maintenance notifications mode is valid
func (m MaintNotificationsMode) IsValid() bool {
switch m {
case MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto:
return true
default:
return false
}
}
// String returns the string representation of the mode
func (m MaintNotificationsMode) String() string {
return string(m)
}
// EndpointType represents the type of endpoint to request in MOVING notifications
type EndpointType string
// Constants for endpoint types
const (
EndpointTypeAuto EndpointType = "auto" // Auto-detect based on connection
EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address
EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN
EndpointTypeExternalIP EndpointType = "external-ip" // External IP address
EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN
EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config)
)
// IsValid returns true if the endpoint type is valid
func (e EndpointType) IsValid() bool {
switch e {
case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone:
return true
default:
return false
}
}
// String returns the string representation of the endpoint type
func (e EndpointType) String() string {
return string(e)
}
// Config provides configuration options for hitless upgrades.
type Config struct {
// Mode controls how client maintenance notifications are handled.
// Valid values: MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto
// Default: MaintNotificationsAuto
Mode MaintNotificationsMode
// EndpointType specifies the type of endpoint to request in MOVING notifications.
// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
// EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
// Default: EndpointTypeAuto
EndpointType EndpointType
// RelaxedTimeout is the concrete timeout value to use during
// MIGRATING/FAILING_OVER states to accommodate increased latency.
// This applies to both read and write timeouts.
// Default: 10 seconds
RelaxedTimeout time.Duration
// HandoffTimeout is the maximum time to wait for connection handoff to complete.
// If handoff takes longer than this, the old connection will be forcibly closed.
// Default: 15 seconds (matches server-side eviction timeout)
HandoffTimeout time.Duration
// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
// Workers are created on-demand and automatically cleaned up when idle.
// If zero, defaults to min(10, PoolSize/3) to handle bursts effectively.
// If explicitly set, enforces minimum of 10 workers.
//
// Default: min(10, PoolSize/3), Minimum when set: 10
MaxWorkers int
// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
// If the queue is full, new handoff requests will be rejected.
// Always capped by pool size since you can't handoff more connections than exist.
//
// Default: 10x max workers, capped by pool size, min 2
HandoffQueueSize int
// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
// after a handoff completes. This provides additional resilience during cluster transitions.
// Default: 2 * RelaxedTimeout
PostHandoffRelaxedDuration time.Duration
// ScaleDownDelay is the delay before checking if workers should be scaled down.
// This prevents expensive checks on every handoff completion and avoids rapid scaling cycles.
// Default: 2 seconds
ScaleDownDelay time.Duration
// LogLevel controls the verbosity of hitless upgrade logging.
// 0 = errors only, 1 = warnings, 2 = info, 3 = debug
// Default: 1 (warnings)
LogLevel int
// MaxHandoffRetries is the maximum number of times to retry a failed handoff.
// After this many retries, the connection will be removed from the pool.
// Default: 3
MaxHandoffRetries int
}
func (c *Config) IsEnabled() bool {
return c != nil && c.Mode != MaintNotificationsDisabled
}
// DefaultConfig returns a Config with sensible defaults.
func DefaultConfig() *Config {
return &Config{
Mode: MaintNotificationsAuto, // Enable by default for Redis Cloud
EndpointType: EndpointTypeAuto, // Auto-detect based on connection
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: 0, // Auto-calculated based on pool size
HandoffQueueSize: 0, // Auto-calculated based on max workers
PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
ScaleDownDelay: 2 * time.Second,
LogLevel: 1,
// Connection Handoff Configuration
MaxHandoffRetries: 3,
}
}
// Validate checks if the configuration is valid.
func (c *Config) Validate() error {
if c.RelaxedTimeout <= 0 {
return ErrInvalidRelaxedTimeout
}
if c.HandoffTimeout <= 0 {
return ErrInvalidHandoffTimeout
}
// Validate worker configuration
// Allow 0 for auto-calculation, but negative values are invalid
if c.MaxWorkers < 0 {
return ErrInvalidHandoffWorkers
}
// HandoffQueueSize validation - allow 0 for auto-calculation
if c.HandoffQueueSize < 0 {
return ErrInvalidHandoffQueueSize
}
if c.PostHandoffRelaxedDuration < 0 {
return ErrInvalidPostHandoffRelaxedDuration
}
if c.LogLevel < 0 || c.LogLevel > 3 {
return ErrInvalidLogLevel
}
// Validate Mode (maintenance notifications mode)
if !c.Mode.IsValid() {
return ErrInvalidMaintNotifications
}
// Validate EndpointType
if !c.EndpointType.IsValid() {
return ErrInvalidEndpointType
}
// Validate configuration fields
if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
return ErrInvalidHandoffRetries
}
return nil
}
// ApplyDefaults applies default values to any zero-value fields in the configuration.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaults() *Config {
return c.ApplyDefaultsWithPoolSize(0)
}
// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
// using the provided pool size to calculate worker defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
if c == nil {
return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
}
defaults := DefaultConfig()
result := &Config{}
// Apply defaults for enum fields (empty/zero means not set)
if c.Mode == "" {
result.Mode = defaults.Mode
} else {
result.Mode = c.Mode
}
if c.EndpointType == "" {
result.EndpointType = defaults.EndpointType
} else {
result.EndpointType = c.EndpointType
}
// Apply defaults for duration fields (zero means not set)
if c.RelaxedTimeout <= 0 {
result.RelaxedTimeout = defaults.RelaxedTimeout
} else {
result.RelaxedTimeout = c.RelaxedTimeout
}
if c.HandoffTimeout <= 0 {
result.HandoffTimeout = defaults.HandoffTimeout
} else {
result.HandoffTimeout = c.HandoffTimeout
}
// Apply defaults for integer fields (zero means not set)
if c.HandoffQueueSize <= 0 {
result.HandoffQueueSize = defaults.HandoffQueueSize
} else {
result.HandoffQueueSize = c.HandoffQueueSize
}
// Copy worker configuration
result.MaxWorkers = c.MaxWorkers
// Apply worker defaults based on pool size
result.applyWorkerDefaults(poolSize)
// Apply queue size defaults based on max workers, capped by pool size
if c.HandoffQueueSize <= 0 {
// Queue size: 10x max workers, but never more than pool size
workerBasedSize := result.MaxWorkers * 10
result.HandoffQueueSize = util.Min(workerBasedSize, poolSize)
} else {
result.HandoffQueueSize = c.HandoffQueueSize
}
// Always cap queue size by pool size - no point having more queue slots than connections
result.HandoffQueueSize = util.Min(result.HandoffQueueSize, poolSize)
// Ensure minimum queue size of 2
if result.HandoffQueueSize < 2 {
result.HandoffQueueSize = 2
}
if c.PostHandoffRelaxedDuration <= 0 {
result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
} else {
result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
}
if c.ScaleDownDelay <= 0 {
result.ScaleDownDelay = defaults.ScaleDownDelay
} else {
result.ScaleDownDelay = c.ScaleDownDelay
}
// LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set
// We'll use the provided value as-is, since 0 is valid
result.LogLevel = c.LogLevel
// Apply defaults for configuration fields
if c.MaxHandoffRetries <= 0 {
result.MaxHandoffRetries = defaults.MaxHandoffRetries
} else {
result.MaxHandoffRetries = c.MaxHandoffRetries
}
return result
}
// Clone creates a deep copy of the configuration.
func (c *Config) Clone() *Config {
if c == nil {
return DefaultConfig()
}
return &Config{
Mode: c.Mode,
EndpointType: c.EndpointType,
RelaxedTimeout: c.RelaxedTimeout,
HandoffTimeout: c.HandoffTimeout,
MaxWorkers: c.MaxWorkers,
HandoffQueueSize: c.HandoffQueueSize,
PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
ScaleDownDelay: c.ScaleDownDelay,
LogLevel: c.LogLevel,
// Configuration fields
MaxHandoffRetries: c.MaxHandoffRetries,
}
}
// applyWorkerDefaults calculates and applies worker defaults based on pool size
func (c *Config) applyWorkerDefaults(poolSize int) {
// Calculate defaults based on pool size
if poolSize <= 0 {
poolSize = 10 * runtime.GOMAXPROCS(0)
}
if c.MaxWorkers == 0 {
// When not set: min(10, poolSize/3) - don't exceed 10 workers for small pools
c.MaxWorkers = util.Min(10, poolSize/3)
} else {
// When explicitly set: max(10, set_value) - ensure at least 10 workers
c.MaxWorkers = util.Max(10, c.MaxWorkers)
}
// Ensure minimum of 1 worker (fallback for very small pools)
if c.MaxWorkers < 1 {
c.MaxWorkers = 1
}
}
// DetectEndpointType automatically detects the appropriate endpoint type
// based on the connection address and TLS configuration.
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
// Parse the address to determine if it's an IP or hostname
isPrivate := isPrivateIP(addr)
var endpointType EndpointType
if tlsEnabled {
// TLS requires FQDN for certificate validation
if isPrivate {
endpointType = EndpointTypeInternalFQDN
} else {
endpointType = EndpointTypeExternalFQDN
}
} else {
// No TLS, can use IP addresses
if isPrivate {
endpointType = EndpointTypeInternalIP
} else {
endpointType = EndpointTypeExternalIP
}
}
return endpointType
}
// isPrivateIP checks if the given address is in a private IP range.
func isPrivateIP(addr string) bool {
// Extract host from "host:port" format
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr // Assume no port
}
ip := net.ParseIP(host)
if ip == nil {
return false // Not an IP address (likely hostname)
}
// Check for private/loopback ranges
return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
}
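A hypothetical usage sketch for the helpers above (values are illustrative): build a partial config, apply pool-size-based defaults, and validate it.

```go
package hitless

// exampleConfigUsage is a hypothetical helper illustrating the configuration flow.
func exampleConfigUsage() error {
	cfg := (&Config{
		Mode:         MaintNotificationsAuto,
		EndpointType: DetectEndpointType("10.0.0.5:6379", false), // private IP, no TLS => internal-ip
	}).ApplyDefaultsWithPoolSize(30)
	return cfg.Validate() // nil: the applied defaults satisfy all validation rules
}
```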

427
hitless/config_test.go Normal file

@@ -0,0 +1,427 @@
package hitless
import (
"context"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/util"
)
func TestConfig(t *testing.T) {
t.Run("DefaultConfig", func(t *testing.T) {
config := DefaultConfig()
// MaxWorkers should be 0 in default config (auto-calculated)
if config.MaxWorkers != 0 {
t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers)
}
// HandoffQueueSize should be 0 in default config (auto-calculated)
if config.HandoffQueueSize != 0 {
t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize)
}
if config.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s, got %v", config.RelaxedTimeout)
}
// Test configuration fields have proper defaults
if config.MaxHandoffRetries != 3 {
t.Errorf("Expected MaxHandoffRetries to be 3, got %d", config.MaxHandoffRetries)
}
if config.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout)
}
if config.PostHandoffRelaxedDuration != 0 {
t.Errorf("Expected PostHandoffRelaxedDuration to be 0 (auto-calculated), got %v", config.PostHandoffRelaxedDuration)
}
// Test that defaults are applied correctly
configWithDefaults := config.ApplyDefaultsWithPoolSize(100)
if configWithDefaults.PostHandoffRelaxedDuration != 20*time.Second {
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout) after applying defaults, got %v", configWithDefaults.PostHandoffRelaxedDuration)
}
})
t.Run("ConfigValidation", func(t *testing.T) {
// Valid config with applied defaults
config := DefaultConfig().ApplyDefaults()
if err := config.Validate(); err != nil {
t.Errorf("Default config with applied defaults should be valid: %v", err)
}
// Invalid worker configuration (negative MaxWorkers)
config = &Config{
RelaxedTimeout: 30 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: -1, // This should be invalid
HandoffQueueSize: 100,
PostHandoffRelaxedDuration: 10 * time.Second,
LogLevel: 1,
MaxHandoffRetries: 3, // Add required field
}
if err := config.Validate(); err != ErrInvalidHandoffWorkers {
t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err)
}
// Invalid HandoffQueueSize
config = DefaultConfig().ApplyDefaults()
config.HandoffQueueSize = -1
if err := config.Validate(); err != ErrInvalidHandoffQueueSize {
t.Errorf("Expected ErrInvalidHandoffQueueSize, got %v", err)
}
// Invalid PostHandoffRelaxedDuration
config = DefaultConfig().ApplyDefaults()
config.PostHandoffRelaxedDuration = -1 * time.Second
if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration {
t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err)
}
})
t.Run("ConfigClone", func(t *testing.T) {
original := DefaultConfig()
original.MaxWorkers = 20
original.HandoffQueueSize = 200
cloned := original.Clone()
if cloned.MaxWorkers != 20 {
t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers)
}
if cloned.HandoffQueueSize != 200 {
t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize)
}
// Modify original to ensure clone is independent
original.MaxWorkers = 2
if cloned.MaxWorkers != 20 {
t.Error("Clone should be independent of original")
}
})
}
func TestApplyDefaults(t *testing.T) {
t.Run("NilConfig", func(t *testing.T) {
var config *Config
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// With nil config, should get default config with auto-calculated workers
if result.MaxWorkers <= 0 {
t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers)
}
// HandoffQueueSize should be auto-calculated (10 * MaxWorkers, capped by pool size)
workerBasedSize := result.MaxWorkers * 10
poolSize := 100 // Default pool size used in ApplyDefaults
expectedQueueSize := util.Min(workerBasedSize, poolSize)
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
}
})
t.Run("PartialConfig", func(t *testing.T) {
config := &Config{
MaxWorkers: 12, // Set this field explicitly
// Leave other fields as zero values
}
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Should keep the explicitly set values
if result.MaxWorkers != 12 {
t.Errorf("Expected MaxWorkers to be 12 (explicitly set), got %d", result.MaxWorkers)
}
// Should apply default for unset fields (auto-calculated queue size, capped by pool size)
workerBasedSize := result.MaxWorkers * 10
poolSize := 100 // Default pool size used in ApplyDefaults
expectedQueueSize := util.Min(workerBasedSize, poolSize)
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
}
// Test explicit queue size capping by pool size
configWithLargeQueue := &Config{
MaxWorkers: 5,
HandoffQueueSize: 1000, // Much larger than pool size
}
resultCapped := configWithLargeQueue.ApplyDefaultsWithPoolSize(20) // Small pool size
if resultCapped.HandoffQueueSize != 20 {
t.Errorf("Expected HandoffQueueSize to be capped by pool size (20), got %d", resultCapped.HandoffQueueSize)
}
if result.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
}
if result.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", result.HandoffTimeout)
}
})
t.Run("ZeroValues", func(t *testing.T) {
config := &Config{
MaxWorkers: 0, // Zero value should get auto-calculated defaults
HandoffQueueSize: 0, // Zero value should get default
RelaxedTimeout: 0, // Zero value should get default
LogLevel: 0, // Zero is valid for LogLevel (errors only)
}
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Zero values should get auto-calculated defaults
if result.MaxWorkers <= 0 {
t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers)
}
// HandoffQueueSize should be auto-calculated (10 * MaxWorkers, capped by pool size)
workerBasedSize := result.MaxWorkers * 10
poolSize := 100 // Default pool size used in ApplyDefaults
expectedQueueSize := util.Min(workerBasedSize, poolSize)
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize)
}
if result.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
}
// LogLevel 0 should be preserved (it's a valid value)
if result.LogLevel != 0 {
t.Errorf("Expected LogLevel to be 0 (preserved), got %d", result.LogLevel)
}
})
}
func TestProcessorWithConfig(t *testing.T) {
t.Run("ProcessorUsesConfigValues", func(t *testing.T) {
config := &Config{
MaxWorkers: 5,
HandoffQueueSize: 50,
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 5 * time.Second,
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// The processor should be created successfully with custom config
if processor == nil {
t.Error("Processor should be created with custom config")
}
})
t.Run("ProcessorWithPartialConfig", func(t *testing.T) {
config := &Config{
MaxWorkers: 7, // Only set worker field
// Other fields will get defaults
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Should work with partial config (defaults applied)
if processor == nil {
t.Error("Processor should be created with partial config")
}
})
t.Run("ProcessorWithNilConfig", func(t *testing.T) {
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Should use default config when nil is passed
if processor == nil {
t.Error("Processor should be created with nil config (using defaults)")
}
})
}
func TestIntegrationWithApplyDefaults(t *testing.T) {
t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) {
// Create a partial config with only some fields set
partialConfig := &Config{
MaxWorkers: 15, // Custom value (>= 10 to test preservation)
LogLevel: 2, // Custom value
// Other fields left as zero values - should get defaults
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
// Create processor - should apply defaults to missing fields
processor := NewPoolHook(baseDialer, "tcp", partialConfig, nil)
defer processor.Shutdown(context.Background())
// Processor should be created successfully
if processor == nil {
t.Error("Processor should be created with partial config")
}
// Test that the ApplyDefaults method worked correctly by creating the same config
// and applying defaults manually
expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Should preserve custom values (when >= 10)
if expectedConfig.MaxWorkers != 15 {
t.Errorf("Expected MaxWorkers to be 15, got %d", expectedConfig.MaxWorkers)
}
if expectedConfig.LogLevel != 2 {
t.Errorf("Expected LogLevel to be 2, got %d", expectedConfig.LogLevel)
}
// Should apply defaults for missing fields (auto-calculated queue size, capped by pool size)
workerBasedSize := expectedConfig.MaxWorkers * 10
poolSize := 100 // Default pool size used in ApplyDefaults
expectedQueueSize := util.Min(workerBasedSize, poolSize)
if expectedConfig.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d",
expectedQueueSize, workerBasedSize, poolSize, expectedConfig.HandoffQueueSize)
}
// Test that queue size is always capped by pool size
if expectedConfig.HandoffQueueSize > poolSize {
t.Errorf("HandoffQueueSize (%d) should never exceed pool size (%d)",
expectedConfig.HandoffQueueSize, poolSize)
}
if expectedConfig.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", expectedConfig.RelaxedTimeout)
}
if expectedConfig.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout)
}
if expectedConfig.PostHandoffRelaxedDuration != 20*time.Second {
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout), got %v", expectedConfig.PostHandoffRelaxedDuration)
}
})
}
func TestEnhancedConfigValidation(t *testing.T) {
t.Run("ValidateFields", func(t *testing.T) {
config := DefaultConfig()
config.ApplyDefaultsWithPoolSize(100) // Apply defaults with pool size 100
// Should pass validation with default values
if err := config.Validate(); err != nil {
t.Errorf("Default config should be valid, got error: %v", err)
}
// Test invalid MaxHandoffRetries
config.MaxHandoffRetries = 0
if err := config.Validate(); err == nil {
t.Error("Expected validation error for MaxHandoffRetries = 0")
}
config.MaxHandoffRetries = 11
if err := config.Validate(); err == nil {
t.Error("Expected validation error for MaxHandoffRetries = 11")
}
config.MaxHandoffRetries = 3 // Reset to valid value
// Should pass validation again
if err := config.Validate(); err != nil {
t.Errorf("Config should be valid after reset, got error: %v", err)
}
})
}
func TestConfigClone(t *testing.T) {
original := DefaultConfig()
original.MaxHandoffRetries = 7
original.HandoffTimeout = 8 * time.Second
cloned := original.Clone()
// Test that values are copied
if cloned.MaxHandoffRetries != 7 {
t.Errorf("Expected cloned MaxHandoffRetries to be 7, got %d", cloned.MaxHandoffRetries)
}
if cloned.HandoffTimeout != 8*time.Second {
t.Errorf("Expected cloned HandoffTimeout to be 8s, got %v", cloned.HandoffTimeout)
}
// Test that modifying clone doesn't affect original
cloned.MaxHandoffRetries = 10
if original.MaxHandoffRetries != 7 {
t.Errorf("Modifying clone should not affect original, original MaxHandoffRetries changed to %d", original.MaxHandoffRetries)
}
}
func TestMaxWorkersLogic(t *testing.T) {
t.Run("AutoCalculatedMaxWorkers", func(t *testing.T) {
testCases := []struct {
poolSize int
expectedWorkers int
description string
}{
{6, 2, "Small pool: min(10, 6/3) = min(10, 2) = 2"},
{15, 5, "Medium pool: min(10, 15/3) = min(10, 5) = 5"},
{30, 10, "Large pool: min(10, 30/3) = min(10, 10) = 10"},
{60, 10, "Very large pool: min(10, 60/3) = min(10, 20) = 10"},
{120, 10, "Huge pool: min(10, 120/3) = min(10, 40) = 10"},
}
for _, tc := range testCases {
config := &Config{} // MaxWorkers = 0 (not set)
result := config.ApplyDefaultsWithPoolSize(tc.poolSize)
if result.MaxWorkers != tc.expectedWorkers {
t.Errorf("PoolSize=%d: expected MaxWorkers=%d, got %d (%s)",
tc.poolSize, tc.expectedWorkers, result.MaxWorkers, tc.description)
}
}
})
t.Run("ExplicitlySetMaxWorkers", func(t *testing.T) {
testCases := []struct {
setValue int
expectedWorkers int
description string
}{
{1, 10, "Set 1: max(10, 1) = 10 (enforced minimum)"},
{5, 10, "Set 5: max(10, 5) = 10 (enforced minimum)"},
{8, 10, "Set 8: max(10, 8) = 10 (enforced minimum)"},
{10, 10, "Set 10: max(10, 10) = 10 (exact minimum)"},
{15, 15, "Set 15: max(10, 15) = 15 (respects user choice)"},
{20, 20, "Set 20: max(10, 20) = 20 (respects user choice)"},
}
for _, tc := range testCases {
config := &Config{
MaxWorkers: tc.setValue, // Explicitly set
}
result := config.ApplyDefaultsWithPoolSize(100) // Pool size doesn't affect explicit values
if result.MaxWorkers != tc.expectedWorkers {
t.Errorf("Set MaxWorkers=%d: expected %d, got %d (%s)",
tc.setValue, tc.expectedWorkers, result.MaxWorkers, tc.description)
}
}
})
}

76
hitless/errors.go Normal file

@@ -0,0 +1,76 @@
package hitless
import (
"errors"
"fmt"
)
// Configuration errors
var (
ErrInvalidRelaxedTimeout = errors.New("hitless: relaxed timeout must be greater than 0")
ErrInvalidHandoffTimeout = errors.New("hitless: handoff timeout must be greater than 0")
ErrInvalidHandoffWorkers = errors.New("hitless: MaxWorkers must be greater than or equal to 0")
ErrInvalidHandoffQueueSize = errors.New("hitless: handoff queue size must be greater than 0")
ErrInvalidPostHandoffRelaxedDuration = errors.New("hitless: post-handoff relaxed duration must be greater than or equal to 0")
ErrInvalidLogLevel = errors.New("hitless: log level must be between 0 and 3")
ErrInvalidEndpointType = errors.New("hitless: invalid endpoint type")
ErrInvalidMaintNotifications = errors.New("hitless: invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')")
ErrMaxHandoffRetriesReached = errors.New("hitless: max handoff retries reached")
// Configuration validation errors
ErrInvalidHandoffRetries = errors.New("hitless: MaxHandoffRetries must be between 1 and 10")
ErrInvalidConnectionValidationTimeout = errors.New("hitless: ConnectionValidationTimeout must be greater than 0 and less than 30 seconds")
ErrInvalidConnectionHealthCheckInterval = errors.New("hitless: ConnectionHealthCheckInterval must be between 0 and 1 hour")
ErrInvalidOperationCleanupInterval = errors.New("hitless: OperationCleanupInterval must be greater than 0 and less than 1 hour")
ErrInvalidMaxActiveOperations = errors.New("hitless: MaxActiveOperations must be between 100 and 100000")
ErrInvalidNotificationBufferSize = errors.New("hitless: NotificationBufferSize must be between 10 and 10000")
ErrInvalidNotificationTimeout = errors.New("hitless: NotificationTimeout must be greater than 0 and less than 30 seconds")
)
// Integration errors
var (
ErrInvalidClient = errors.New("hitless: invalid client type")
)
// Handoff errors
var (
ErrHandoffInProgress = errors.New("hitless: handoff already in progress")
ErrNoHandoffInProgress = errors.New("hitless: no handoff in progress")
ErrConnectionFailed = errors.New("hitless: failed to establish new connection")
)
// Dead error variables removed - unused in simplified architecture
// Notification errors
var (
ErrInvalidNotification = errors.New("hitless: invalid notification format")
)
// Dead error variables removed - unused in simplified architecture
// HandoffError represents an error that occurred during connection handoff.
type HandoffError struct {
Operation string
Endpoint string
Cause error
}
func (e *HandoffError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("hitless: handoff %s failed for endpoint %s: %v", e.Operation, e.Endpoint, e.Cause)
}
return fmt.Sprintf("hitless: handoff %s failed for endpoint %s", e.Operation, e.Endpoint)
}
func (e *HandoffError) Unwrap() error {
return e.Cause
}
// NewHandoffError creates a new HandoffError.
func NewHandoffError(operation, endpoint string, cause error) *HandoffError {
return &HandoffError{
Operation: operation,
Endpoint: endpoint,
Cause: cause,
}
}
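Since `HandoffError` implements `Unwrap`, the standard `errors` helpers can match either the wrapped cause or the handoff error itself; a hedged sketch:

```go
package hitless

import "errors"

// exampleHandoffErrorUsage is a hypothetical helper showing error matching.
func exampleHandoffErrorUsage(dialErr error) (string, bool) {
	err := NewHandoffError("dial", "new-endpoint:6379", dialErr)

	var hoErr *HandoffError
	if errors.As(err, &hoErr) { // matches the *HandoffError itself
		return hoErr.Endpoint, errors.Is(err, dialErr) // errors.Is walks Unwrap to the cause
	}
	return "", false
}
```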

63
hitless/example_hooks.go Normal file

@@ -0,0 +1,63 @@
package hitless
import (
"context"
"time"
)
// contextKey is a custom type for context keys to avoid collisions
type contextKey string
const (
startTimeKey contextKey = "notif_hitless_start_time"
)
// MetricsHook collects metrics about notification processing.
type MetricsHook struct {
NotificationCounts map[string]int64
ProcessingTimes map[string]time.Duration
ErrorCounts map[string]int64
}
// NewMetricsHook creates a new metrics collection hook.
func NewMetricsHook() *MetricsHook {
return &MetricsHook{
NotificationCounts: make(map[string]int64),
ProcessingTimes: make(map[string]time.Duration),
ErrorCounts: make(map[string]int64),
}
}
// PreHook records the start time for processing metrics.
func (mh *MetricsHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
mh.NotificationCounts[notificationType]++
// Record the start time for duration calculation. PreHook cannot return a
// modified context, so the derived context is discarded here; PostHook only
// sees the start time if the caller propagates it through ctx.
startTime := time.Now()
_ = context.WithValue(ctx, startTimeKey, startTime)
return notification, true
}
// PostHook records processing completion and any errors.
func (mh *MetricsHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
// Calculate processing duration
if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
duration := time.Since(startTime)
mh.ProcessingTimes[notificationType] = duration
}
// Record errors
if result != nil {
mh.ErrorCounts[notificationType]++
}
}
// GetMetrics returns a summary of collected metrics.
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
return map[string]interface{}{
"notification_counts": mh.NotificationCounts,
"processing_times": mh.ProcessingTimes,
"error_counts": mh.ErrorCounts,
}
}

299
hitless/hitless_manager.go Normal file

@@ -0,0 +1,299 @@
package hitless
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/internal/pool"
)
// Push notification type constants for hitless upgrades
const (
NotificationMoving = "MOVING"
NotificationMigrating = "MIGRATING"
NotificationMigrated = "MIGRATED"
NotificationFailingOver = "FAILING_OVER"
NotificationFailedOver = "FAILED_OVER"
)
// hitlessNotificationTypes contains all notification types that hitless upgrades handles
var hitlessNotificationTypes = []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
// NotificationHook is called before and after notification processing
// PreHook can modify the notification and return false to skip processing
// PostHook is called after successful processing
type NotificationHook interface {
PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool)
PostHook(ctx context.Context, notificationType string, notification []interface{}, result error)
}
// MovingOperationKey provides a unique key for tracking MOVING operations
// that combines sequence ID with connection identifier to handle duplicate
// sequence IDs across multiple connections to the same node.
type MovingOperationKey struct {
SeqID int64 // Sequence ID from MOVING notification
ConnID uint64 // Unique connection identifier
}
// String returns a string representation of the key for debugging
func (k MovingOperationKey) String() string {
return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
}
// HitlessManager provides simplified hitless upgrade functionality with hooks and atomic state.
type HitlessManager struct {
client interfaces.ClientInterface
config *Config
options interfaces.OptionsInterface
pool pool.Pooler
// MOVING operation tracking - using sync.Map for better concurrent performance
activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation
// Atomic state tracking - no locks needed for state queries
activeOperationCount atomic.Int64 // Number of active operations
closed atomic.Bool // Manager closed state
// Notification hooks for extensibility
hooks []NotificationHook
hooksMu sync.RWMutex // Protects hooks slice
poolHooksRef *PoolHook
}
// MovingOperation tracks an active MOVING operation.
type MovingOperation struct {
SeqID int64
NewEndpoint string
StartTime time.Time
Deadline time.Time
}
// NewHitlessManager creates a new simplified hitless manager.
func NewHitlessManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*HitlessManager, error) {
if client == nil {
return nil, ErrInvalidClient
}
hm := &HitlessManager{
client: client,
pool: pool,
options: client.GetOptions(),
config: config.Clone(),
hooks: make([]NotificationHook, 0),
}
// Set up push notification handling
if err := hm.setupPushNotifications(); err != nil {
return nil, err
}
return hm, nil
}
// InitPoolHook creates a pool hook with the given base dialer and registers it with the pool.
func (hm *HitlessManager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
poolHook := hm.createPoolHook(baseDialer)
hm.pool.AddPoolHook(poolHook)
}
// setupPushNotifications sets up push notification handling by registering with the client's processor.
func (hm *HitlessManager) setupPushNotifications() error {
processor := hm.client.GetPushProcessor()
if processor == nil {
return ErrInvalidClient // Client doesn't support push notifications
}
// Create our notification handler
handler := &NotificationHandler{manager: hm}
// Register handlers for all hitless upgrade notifications with the client's processor
for _, notificationType := range hitlessNotificationTypes {
if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
return fmt.Errorf("failed to register handler for %s: %w", notificationType, err)
}
}
return nil
}
// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Create MOVING operation record
movingOp := &MovingOperation{
SeqID: seqID,
NewEndpoint: newEndpoint,
StartTime: time.Now(),
Deadline: deadline,
}
// Use LoadOrStore for atomic check-and-set operation
if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
// Duplicate MOVING notification, ignore
internal.Logger.Printf(ctx, "Duplicate MOVING operation ignored: %s", key.String())
return nil
}
// Increment active operation count atomically
hm.activeOperationCount.Add(1)
return nil
}
// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64) {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Remove from active operations atomically
if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
// Decrement active operation count only if operation existed
hm.activeOperationCount.Add(-1)
}
}
// GetActiveMovingOperations returns active operations with composite keys.
func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
result := make(map[MovingOperationKey]*MovingOperation)
// Iterate over sync.Map to build result
hm.activeMovingOps.Range(func(key, value interface{}) bool {
k := key.(MovingOperationKey)
op := value.(*MovingOperation)
// Create a copy to avoid sharing references
result[k] = &MovingOperation{
SeqID: op.SeqID,
NewEndpoint: op.NewEndpoint,
StartTime: op.StartTime,
Deadline: op.Deadline,
}
return true // Continue iteration
})
return result
}
// IsHandoffInProgress returns true if any handoff is in progress.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) IsHandoffInProgress() bool {
return hm.activeOperationCount.Load() > 0
}
// GetActiveOperationCount returns the number of active operations.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) GetActiveOperationCount() int64 {
return hm.activeOperationCount.Load()
}
// Close closes the hitless manager.
func (hm *HitlessManager) Close() error {
// Use atomic operation for thread-safe close check
if !hm.closed.CompareAndSwap(false, true) {
return nil // Already closed
}
// Shutdown the pool hook if it exists
if hm.poolHooksRef != nil {
// Use a timeout to prevent hanging indefinitely
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := hm.poolHooksRef.Shutdown(shutdownCtx)
if err != nil {
// The pool hook could not be shut down; revert the closed flag so Close can be retried
hm.closed.Store(false)
return err
}
// Remove the pool hook from the pool
if hm.pool != nil {
hm.pool.RemovePoolHook(hm.poolHooksRef)
}
}
// Clear all active operations
hm.activeMovingOps.Range(func(key, value interface{}) bool {
hm.activeMovingOps.Delete(key)
return true
})
// Reset counter
hm.activeOperationCount.Store(0)
return nil
}
// GetState returns current state using atomic counter for lock-free operation.
func (hm *HitlessManager) GetState() State {
if hm.activeOperationCount.Load() > 0 {
return StateMoving
}
return StateIdle
}
// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
currentNotification := notification
for _, hook := range hm.hooks {
modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationType, currentNotification)
if !shouldContinue {
return modifiedNotification, false
}
currentNotification = modifiedNotification
}
return currentNotification, true
}
// processPostHooks calls all post-hooks with the processing result.
func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationType string, notification []interface{}, result error) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
for _, hook := range hm.hooks {
hook.PostHook(ctx, notificationType, notification, result)
}
}
// createPoolHook creates a pool hook with this manager already set.
func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
if hm.poolHooksRef != nil {
return hm.poolHooksRef
}
// Get pool size from client options for better worker defaults
poolSize := 0
if hm.options != nil {
poolSize = hm.options.GetPoolSize()
}
hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
hm.poolHooksRef.SetPool(hm.pool)
return hm.poolHooksRef
}

View File

@@ -0,0 +1,260 @@
package hitless
import (
"context"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/interfaces"
)
// MockClient implements interfaces.ClientInterface for testing
type MockClient struct {
options interfaces.OptionsInterface
}
func (mc *MockClient) GetOptions() interfaces.OptionsInterface {
return mc.options
}
func (mc *MockClient) GetPushProcessor() interfaces.NotificationProcessor {
return &MockPushProcessor{}
}
// MockPushProcessor implements interfaces.NotificationProcessor for testing
type MockPushProcessor struct{}
func (mpp *MockPushProcessor) RegisterHandler(notificationType string, handler interface{}, protected bool) error {
return nil
}
func (mpp *MockPushProcessor) UnregisterHandler(pushNotificationName string) error {
return nil
}
func (mpp *MockPushProcessor) GetHandler(pushNotificationName string) interface{} {
return nil
}
// MockOptions implements interfaces.OptionsInterface for testing
type MockOptions struct{}
func (mo *MockOptions) GetReadTimeout() time.Duration {
return 5 * time.Second
}
func (mo *MockOptions) GetWriteTimeout() time.Duration {
return 5 * time.Second
}
func (mo *MockOptions) GetAddr() string {
return "localhost:6379"
}
func (mo *MockOptions) IsTLSEnabled() bool {
return false
}
func (mo *MockOptions) GetProtocol() int {
return 3 // RESP3
}
func (mo *MockOptions) GetPoolSize() int {
return 10
}
func (mo *MockOptions) GetNetwork() string {
return "tcp"
}
func (mo *MockOptions) NewDialer() func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
return nil, nil
}
}
func TestHitlessManagerRefactoring(t *testing.T) {
t.Run("AtomicStateTracking", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
// Test initial state
if manager.IsHandoffInProgress() {
t.Error("Expected no handoff in progress initially")
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateIdle {
t.Errorf("Expected StateIdle, got %v", manager.GetState())
}
// Add an operation
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
err = manager.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Failed to track operation: %v", err)
}
// Test state after adding operation
if !manager.IsHandoffInProgress() {
t.Error("Expected handoff in progress after adding operation")
}
if manager.GetActiveOperationCount() != 1 {
t.Errorf("Expected 1 active operation, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateMoving {
t.Errorf("Expected StateMoving, got %v", manager.GetState())
}
// Remove the operation
manager.UntrackOperationWithConnID(12345, 1)
// Test state after removing operation
if manager.IsHandoffInProgress() {
t.Error("Expected no handoff in progress after removing operation")
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateIdle {
t.Errorf("Expected StateIdle, got %v", manager.GetState())
}
})
t.Run("SyncMapPerformance", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
// Test concurrent operations
const numOps = 100
for i := 0; i < numOps; i++ {
err := manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, int64(i), uint64(i))
if err != nil {
t.Fatalf("Failed to track operation %d: %v", i, err)
}
}
if manager.GetActiveOperationCount() != numOps {
t.Errorf("Expected %d active operations, got %d", numOps, manager.GetActiveOperationCount())
}
// Test GetActiveMovingOperations
operations := manager.GetActiveMovingOperations()
if len(operations) != numOps {
t.Errorf("Expected %d operations in map, got %d", numOps, len(operations))
}
// Remove all operations
for i := 0; i < numOps; i++ {
manager.UntrackOperationWithConnID(int64(i), uint64(i))
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations after cleanup, got %d", manager.GetActiveOperationCount())
}
})
t.Run("DuplicateOperationHandling", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
// Add operation
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Failed to track operation: %v", err)
}
// Try to add duplicate operation
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Duplicate operation should not return error: %v", err)
}
// Should still have only 1 operation
if manager.GetActiveOperationCount() != 1 {
t.Errorf("Expected 1 active operation after duplicate, got %d", manager.GetActiveOperationCount())
}
})
t.Run("NotificationTypeConstants", func(t *testing.T) {
// Test that constants are properly defined
expectedTypes := []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
if len(hitlessNotificationTypes) != len(expectedTypes) {
t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(hitlessNotificationTypes))
}
// Test that all expected types are present
typeMap := make(map[string]bool)
for _, nt := range hitlessNotificationTypes {
typeMap[nt] = true
}
for _, expected := range expectedTypes {
if !typeMap[expected] {
t.Errorf("Expected notification type %s not found in hitlessNotificationTypes", expected)
}
}
// Test that hitlessNotificationTypes contains all expected constants
expectedConstants := []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
for _, expected := range expectedConstants {
found := false
for _, actual := range hitlessNotificationTypes {
if actual == expected {
found = true
break
}
}
if !found {
t.Errorf("Expected constant %s not found in hitlessNotificationTypes", expected)
}
}
})
}

48
hitless/hooks.go Normal file
View File

@@ -0,0 +1,48 @@
package hitless
import (
"context"
"github.com/redis/go-redis/v9/internal"
)
// LoggingHook is an example hook implementation that logs all notifications.
type LoggingHook struct {
LogLevel int
}
// PreHook logs the notification before processing and allows modification.
func (lh *LoggingHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
if lh.LogLevel >= 2 { // Info level
internal.Logger.Printf(ctx, "hitless: processing %s notification: %v", notificationType, notification)
}
return notification, true // Continue processing with unmodified notification
}
// PostHook logs the result after processing.
func (lh *LoggingHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
if result != nil && lh.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx, "hitless: %s notification processing failed: %v", notificationType, result)
} else if lh.LogLevel >= 3 { // Debug level
internal.Logger.Printf(ctx, "hitless: %s notification processed successfully", notificationType)
}
}
// FilterHook is an example hook that can filter out certain notifications.
type FilterHook struct {
BlockedTypes map[string]bool
}
// PreHook filters notifications based on type.
func (fh *FilterHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) {
if fh.BlockedTypes[notificationType] {
internal.Logger.Printf(ctx, "hitless: filtering out %s notification", notificationType)
return notification, false // Skip processing
}
return notification, true
}
// PostHook does nothing for filter hook.
func (fh *FilterHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) {
// No post-processing needed for filter hook
}
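A short composition sketch, assuming the constants from hitless_manager.go and a hypothetical wrapper function: it chains these example hooks the way HitlessManager.processPreHooks and processPostHooks iterate registered hooks.
// Sketch: chain FilterHook and LoggingHook in order, pre-hooks first, then post-hooks.
func examplePreHookChain() {
    hooks := []NotificationHook{
        &FilterHook{BlockedTypes: map[string]bool{NotificationMigrating: true}},
        &LoggingHook{LogLevel: 2},
    }
    ctx := context.Background()
    notification := []interface{}{NotificationMoving, "7", "30", "10.0.0.5:6379"}

    current, keep := notification, true
    for _, h := range hooks {
        if current, keep = h.PreHook(ctx, NotificationMoving, current); !keep {
            break // a hook filtered the notification out
        }
    }
    if keep {
        // ... the (possibly modified) notification would be handled here ...
        for _, h := range hooks {
            h.PostHook(ctx, NotificationMoving, current, nil)
        }
    }
}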

View File

@@ -0,0 +1,247 @@
package hitless
import (
"context"
"fmt"
"strconv"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// NotificationHandler handles push notifications for the simplified manager.
type NotificationHandler struct {
manager *HitlessManager
}
// HandlePushNotification processes push notifications with hook support.
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) == 0 {
return ErrInvalidNotification
}
notificationType, ok := notification[0].(string)
if !ok {
return ErrInvalidNotification
}
// Process pre-hooks - they can modify the notification or skip processing
modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, notificationType, notification)
if !shouldContinue {
return nil // Hooks decided to skip processing
}
var err error
switch notificationType {
case NotificationMoving:
err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
case NotificationMigrating:
err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
case NotificationMigrated:
err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
case NotificationFailingOver:
err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
case NotificationFailedOver:
err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
default:
// Ignore other notification types (e.g., pub/sub messages)
err = nil
}
// Process post-hooks with the result
snh.manager.processPostHooks(ctx, notificationType, modifiedNotification, err)
return err
}
// handleMoving processes MOVING notifications.
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 3 {
return ErrInvalidNotification
}
seqIDStr, ok := notification[1].(string)
if !ok {
return ErrInvalidNotification
}
seqID, err := strconv.ParseInt(seqIDStr, 10, 64)
if err != nil {
return ErrInvalidNotification
}
// Extract timeS
timeSStr, ok := notification[2].(string)
if !ok {
return ErrInvalidNotification
}
timeS, err := strconv.ParseInt(timeSStr, 10, 64)
if err != nil {
return ErrInvalidNotification
}
newEndpoint := ""
if len(notification) > 3 {
// Extract new endpoint
newEndpoint, ok = notification[3].(string)
if !ok {
return ErrInvalidNotification
}
}
// Get the connection that received this notification
conn := handlerCtx.Conn
if conn == nil {
return ErrInvalidNotification
}
// Type assert to get the underlying pool connection
var poolConn *pool.Conn
if connAdapter, ok := conn.(interface{ GetPoolConn() *pool.Conn }); ok {
poolConn = connAdapter.GetPoolConn()
} else if pc, ok := conn.(*pool.Conn); ok {
poolConn = pc
} else {
return ErrInvalidNotification
}
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
if newEndpoint == "" || newEndpoint == internal.RedisNull {
// same as current endpoint
newEndpoint = snh.manager.options.GetAddr()
// delay the handoff for timeS/2 seconds to the same endpoint
// do this in a goroutine to avoid blocking the notification handler
go func() {
time.Sleep(time.Duration(timeS/2) * time.Second)
if poolConn == nil || poolConn.IsClosed() {
return
}
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
// Log error but don't fail the goroutine
internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
}
}()
return nil
}
return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
}
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
// Connection is already marked for handoff, which is acceptable
// This can happen if multiple MOVING notifications are received for the same connection
return nil
}
// Optionally track in hitless manager for monitoring/debugging
if snh.manager != nil {
connID := conn.GetID()
// Track the operation (ignore errors since this is optional)
_ = snh.manager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
} else {
return fmt.Errorf("hitless: manager not initialized")
}
return nil
}
// handleMigrating processes MIGRATING notifications.
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATING notifications indicate that a connection is about to be migrated
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
return ErrInvalidNotification
}
// Get the connection from handler context and type assert to connectionAdapter
if handlerCtx.Conn == nil {
return ErrInvalidNotification
}
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
if !ok {
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleMigrated processes MIGRATED notifications.
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATED notifications indicate that a connection migration has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
return ErrInvalidNotification
}
// Get the connection from handler context and type assert to connectionAdapter
if handlerCtx.Conn == nil {
return ErrInvalidNotification
}
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
if !ok {
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
connAdapter.ClearRelaxedTimeout()
return nil
}
// handleFailingOver processes FAILING_OVER notifications.
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILING_OVER notifications indicate that a connection is about to failover
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
return ErrInvalidNotification
}
// Get the connection from handler context and type assert to connectionAdapter
if handlerCtx.Conn == nil {
return ErrInvalidNotification
}
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
if !ok {
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleFailedOver processes FAILED_OVER notifications.
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILED_OVER notifications indicate that a connection failover has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
return ErrInvalidNotification
}
// Get the connection from handler context and type assert to connectionAdapter
if handlerCtx.Conn == nil {
return ErrInvalidNotification
}
// Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout
connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout)
if !ok {
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
connAdapter.ClearRelaxedTimeout()
return nil
}
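For reference, a sketch of the MOVING payload shape handleMoving expects and the fields it derives from it; the values and the wrapper function are illustrative, and the imports fmt, strconv, and time are assumed.
// Sketch: ["MOVING", "<seqID>", "<timeS>", "<endpoint>"]; an empty or null
// endpoint means "hand off back to the current endpoint after timeS/2 seconds".
func exampleMovingPayload() {
    notification := []interface{}{"MOVING", "12345", "15", "10.0.0.2:6379"}
    seqID, _ := strconv.ParseInt(notification[1].(string), 10, 64)
    timeS, _ := strconv.ParseInt(notification[2].(string), 10, 64)
    deadline := time.Now().Add(time.Duration(timeS) * time.Second)
    fmt.Printf("seq=%d endpoint=%v deadline=%s\n",
        seqID, notification[3], deadline.Format(time.RFC3339))
}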

477
hitless/pool_hook.go Normal file
View File

@@ -0,0 +1,477 @@
package hitless
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
)
// HitlessManagerInterface defines the interface for completing handoff operations
type HitlessManagerInterface interface {
TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
UntrackOperationWithConnID(seqID int64, connID uint64)
}
// HandoffRequest represents a request to handoff a connection to a new endpoint
type HandoffRequest struct {
Conn *pool.Conn
ConnID uint64 // Unique connection identifier
Endpoint string
SeqID int64
Pool pool.Pooler // Pool to remove connection from on failure
}
// PoolHook implements pool.PoolHook for Redis-specific connection handling
// with hitless upgrade support.
type PoolHook struct {
// Base dialer for creating connections to new endpoints during handoffs
// args are network and address
baseDialer func(context.Context, string, string) (net.Conn, error)
// Network type (e.g., "tcp", "unix")
network string
// Event-driven handoff support
handoffQueue chan HandoffRequest // Queue for handoff requests
shutdown chan struct{} // Shutdown signal
shutdownOnce sync.Once // Ensure clean shutdown
workerWg sync.WaitGroup // Track worker goroutines
// On-demand worker management
maxWorkers int
activeWorkers int32 // Atomic counter for active workers
workerTimeout time.Duration // How long workers wait for work before exiting
// Simple state tracking
pending sync.Map // map[uint64]int64 (connID -> seqID)
// Configuration for the hitless upgrade
config *Config
// Hitless manager for operation completion tracking
hitlessManager HitlessManagerInterface
// Pool interface for removing connections on handoff failure
pool pool.Pooler
}
// NewPoolHook creates a new pool hook
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface) *PoolHook {
return NewPoolHookWithPoolSize(baseDialer, network, config, hitlessManager, 0)
}
// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface, poolSize int) *PoolHook {
// Apply defaults to any missing configuration fields, using pool size for worker calculations
config = config.ApplyDefaultsWithPoolSize(poolSize)
ph := &PoolHook{
// baseDialer is used to create connections to new endpoints during handoffs
baseDialer: baseDialer,
network: network,
// handoffQueue is a buffered channel for queuing handoff requests
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
// shutdown is a channel for signaling shutdown
shutdown: make(chan struct{}),
maxWorkers: config.MaxWorkers,
activeWorkers: 0, // Start with no workers - create on demand
workerTimeout: 30 * time.Second, // Workers exit after 30s of inactivity
config: config,
// Hitless manager for operation completion tracking
hitlessManager: hitlessManager,
}
// No upfront worker creation - workers are created on demand
return ph
}
// SetPool sets the pool interface for removing connections on handoff failure
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
ph.pool = pooler
}
// GetCurrentWorkers returns the current number of active workers (for testing)
func (ph *PoolHook) GetCurrentWorkers() int {
return int(atomic.LoadInt32(&ph.activeWorkers))
}
// GetScaleLevel returns 1 if workers are active, 0 if none (for testing compatibility)
func (ph *PoolHook) GetScaleLevel() int {
if atomic.LoadInt32(&ph.activeWorkers) > 0 {
return 1
}
return 0
}
// IsHandoffPending returns true if the given connection has a pending handoff
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
_, pending := ph.pending.Load(conn.GetID())
return pending
}
// OnGet is called when a connection is retrieved from the pool
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, isNewConn bool) error {
// NOTE: Two checks ensure we never return a connection that is marked for
// handoff or is currently being handed off.
// Check if connection is usable (not in a handoff state)
// Should not happen since the pool will not return a connection that is not usable.
if !conn.IsUsable() {
return ErrConnectionMarkedForHandoff
}
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
if conn.ShouldHandoff() {
return ErrConnectionMarkedForHandoff
}
return nil
}
// OnPut is called when a connection is returned to the pool
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
// first check if we should handoff for faster rejection
if conn.ShouldHandoff() {
// check pending handoff to not queue the same connection twice
_, hasPendingHandoff := ph.pending.Load(conn.GetID())
if !hasPendingHandoff {
// Check for empty endpoint first (synchronous check)
if conn.GetHandoffEndpoint() == "" {
conn.ClearHandoffState()
} else {
if err := ph.queueHandoff(conn); err != nil {
// Failed to queue handoff, remove the connection
internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err)
return false, true, nil // Don't pool, remove connection, no error to caller
}
// Check if handoff was already processed by a worker before we can mark it as queued
if !conn.ShouldHandoff() {
// Handoff was already processed - this is normal and the connection should be pooled
return true, false, nil
}
if err := conn.MarkQueuedForHandoff(); err != nil {
// If marking fails, check if handoff was processed in the meantime
if !conn.ShouldHandoff() {
// Handoff was processed - this is normal, pool the connection
return true, false, nil
}
// Other error - remove the connection
return false, true, nil
}
return true, false, nil
}
}
}
// Default: pool the connection
return true, false, nil
}
// ensureWorkerAvailable ensures at least one worker is available to process requests
// Creates a new worker if needed and under the max limit
func (ph *PoolHook) ensureWorkerAvailable() {
select {
case <-ph.shutdown:
return
default:
// Check if we need a new worker
currentWorkers := atomic.LoadInt32(&ph.activeWorkers)
if currentWorkers < int32(ph.maxWorkers) {
// Try to create a new worker (atomic increment to prevent race)
if atomic.CompareAndSwapInt32(&ph.activeWorkers, currentWorkers, currentWorkers+1) {
ph.workerWg.Add(1)
go ph.onDemandWorker()
}
}
}
}
// onDemandWorker processes handoff requests and exits when idle
func (ph *PoolHook) onDemandWorker() {
defer func() {
// Decrement active worker count when exiting
atomic.AddInt32(&ph.activeWorkers, -1)
ph.workerWg.Done()
}()
for {
select {
case request := <-ph.handoffQueue:
// Check for shutdown before processing
select {
case <-ph.shutdown:
// Clean up the request before exiting
ph.pending.Delete(request.ConnID)
return
default:
// Process the request
ph.processHandoffRequest(request)
}
case <-time.After(ph.workerTimeout):
// Worker has been idle for too long, exit to save resources
if ph.config != nil && ph.config.LogLevel >= 3 { // Debug level
internal.Logger.Printf(context.Background(),
"hitless: worker exiting due to inactivity timeout (%v)", ph.workerTimeout)
}
return
case <-ph.shutdown:
return
}
}
}
// processHandoffRequest processes a single handoff request
func (ph *PoolHook) processHandoffRequest(request HandoffRequest) {
// Remove from pending map
defer ph.pending.Delete(request.Conn.GetID())
// Create a context with handoff timeout from config
handoffTimeout := 30 * time.Second // Default fallback
if ph.config != nil && ph.config.HandoffTimeout > 0 {
handoffTimeout = ph.config.HandoffTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
defer cancel()
// Create a context that also respects the shutdown signal
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
defer shutdownCancel()
// Monitor shutdown signal in a separate goroutine
go func() {
select {
case <-ph.shutdown:
shutdownCancel()
case <-shutdownCtx.Done():
}
}()
// Perform the handoff with cancellable context
err := ph.performConnectionHandoffWithPool(shutdownCtx, request.Conn, request.Pool)
// If handoff failed, restore the handoff state for potential retry
if err != nil {
request.Conn.RestoreHandoffState()
internal.Logger.Printf(context.Background(), "Handoff failed for connection WILL RETRY: %v", err)
}
// No need for scale down scheduling with on-demand workers
// Workers automatically exit when idle
}
// queueHandoff queues a handoff request for processing.
// If an error is returned, the caller removes the connection from the pool.
func (ph *PoolHook) queueHandoff(conn *pool.Conn) error {
// Create handoff request
request := HandoffRequest{
Conn: conn,
ConnID: conn.GetID(),
Endpoint: conn.GetHandoffEndpoint(),
SeqID: conn.GetMovingSeqID(),
Pool: ph.pool, // Include pool for connection removal on failure
}
select {
// priority to shutdown
case <-ph.shutdown:
return errors.New("shutdown")
default:
select {
case <-ph.shutdown:
return errors.New("shutdown")
case ph.handoffQueue <- request:
// Store in pending map
ph.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
ph.ensureWorkerAvailable()
return nil
default:
// Queue is full - log, ensure workers are available, and reject the request
queueLen := len(ph.handoffQueue)
queueCap := cap(ph.handoffQueue)
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(context.Background(),
"hitless: handoff queue is full (%d/%d), attempting timeout queuing and scaling workers",
queueLen, queueCap)
}
}
}
// Ensure we have workers available to handle the load
ph.ensureWorkerAvailable()
return errors.New("queue full")
}
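The nested select above gives shutdown priority over an enqueue that could otherwise win a two-way select. A stripped-down sketch of the same idiom (stand-alone, with direct returns instead of the fall-through used above):
// Sketch of the shutdown-priority enqueue idiom.
func tryEnqueue(queue chan int, shutdown <-chan struct{}, item int) error {
    select {
    case <-shutdown:
        return errors.New("shutdown")
    default:
        select {
        case <-shutdown: // still honor shutdown while racing the enqueue
            return errors.New("shutdown")
        case queue <- item:
            return nil
        default:
            return errors.New("queue full") // non-blocking: never stall the caller
        }
    }
}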
// performConnectionHandoffWithPool performs the actual connection handoff, using the pool to remove the connection on failure.
// If an error is returned, the connection will be removed from the pool.
func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn *pool.Conn, pooler pool.Pooler) error {
// Capture handoff metadata up front; the handoff state itself is cleared only after a successful handoff
seqID := conn.GetMovingSeqID()
connID := conn.GetID()
// Notify hitless manager of completion if available
if ph.hitlessManager != nil {
defer ph.hitlessManager.UntrackOperationWithConnID(seqID, connID)
}
newEndpoint := conn.GetHandoffEndpoint()
if newEndpoint == "" {
// TODO(hitless): Handle by performing the handoff to the current endpoint in N seconds,
// Where N is the time in the moving notification...
// For now, clear the handoff state and return
conn.ClearHandoffState()
return nil
}
retries := conn.IncrementAndGetHandoffRetries(1)
maxRetries := 3 // Default fallback
if ph.config != nil {
maxRetries = ph.config.MaxHandoffRetries
}
if retries > maxRetries {
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx,
"hitless: reached max retries (%d) for handoff of connection %d to %s",
maxRetries, conn.GetID(), conn.GetHandoffEndpoint())
}
err := ErrMaxHandoffRetriesReached
if pooler != nil {
go pooler.Remove(ctx, conn, err)
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx,
"hitless: removed connection %d from pool due to max handoff retries reached",
conn.GetID())
}
} else {
go conn.Close()
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx,
"hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v",
conn.GetID(), err)
}
}
return err
}
// Create endpoint-specific dialer
endpointDialer := ph.createEndpointDialer(newEndpoint)
// Create new connection to the new endpoint
newNetConn, err := endpointDialer(ctx)
if err != nil {
// TODO(hitless): retry
// This is the only case where we should retry the handoff request
// Should we do anything else other than return the error?
return err
}
// Get the old connection
oldConn := conn.GetNetConn()
// Replace the connection and execute initialization
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
if err != nil {
// Remove the connection from the pool since it's in a bad state
if pooler != nil {
// Use pool.Pooler interface directly - no adapter needed
go pooler.Remove(ctx, conn, err)
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx,
"hitless: removed connection %d from pool due to handoff initialization failure: %v",
conn.GetID(), err)
}
} else {
go conn.Close()
if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx,
"hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v",
conn.GetID(), err)
}
}
// Keep the handoff state for retry
return err
}
defer func() {
if oldConn != nil {
oldConn.Close()
}
}()
conn.ClearHandoffState()
// Apply relaxed timeout to the new connection for the configured post-handoff duration
// This gives the new connection more time to handle operations during cluster transition
if ph.config != nil && ph.config.PostHandoffRelaxedDuration > 0 {
relaxedTimeout := ph.config.RelaxedTimeout
postHandoffDuration := ph.config.PostHandoffRelaxedDuration
// Set relaxed timeout with deadline - no background goroutine needed
deadline := time.Now().Add(postHandoffDuration)
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
if ph.config.LogLevel >= 2 { // Info level
internal.Logger.Printf(context.Background(),
"hitless: applied post-handoff relaxed timeout (%v) until %v for connection %d",
relaxedTimeout, deadline.Format("15:04:05.000"), connID)
}
}
return nil
}
// createEndpointDialer creates a dialer function that connects to a specific endpoint
func (ph *PoolHook) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
// Parse endpoint to extract host and port
host, port, err := net.SplitHostPort(endpoint)
if err != nil {
// If no port specified, assume default Redis port
host = endpoint
if port == "" {
port = "6379"
}
}
// Use the base dialer to connect to the new endpoint
return ph.baseDialer(ctx, ph.network, net.JoinHostPort(host, port))
}
}
// Shutdown gracefully shuts down the processor, waiting for workers to complete
func (ph *PoolHook) Shutdown(ctx context.Context) error {
ph.shutdownOnce.Do(func() {
close(ph.shutdown)
// No timers to clean up with on-demand workers
})
// Wait for workers to complete
done := make(chan struct{})
go func() {
ph.workerWg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
// and should not be used until the handoff is complete
var ErrConnectionMarkedForHandoff = errors.New("connection marked for handoff")
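A minimal lifecycle sketch, assuming a pool that supports hooks and using DefaultConfig as defined elsewhere in this commit; the dialer, wrapper function, and timeout values are placeholders, and the imports context, fmt, net, and time are assumed.
// Sketch: wire a PoolHook into a pool, then drain it on shutdown.
func examplePoolHookLifecycle(pooler pool.Pooler) {
    dialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
        return (&net.Dialer{}).DialContext(ctx, network, addr)
    }
    hook := NewPoolHook(dialer, "tcp", DefaultConfig(), nil) // nil: no hitless manager attached
    hook.SetPool(pooler)
    pooler.AddPoolHook(hook)

    // ... pool serves traffic; MOVING notifications mark connections for handoff ...

    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    if err := hook.Shutdown(ctx); err != nil {
        fmt.Println("pool hook shutdown timed out:", err)
    }
    pooler.RemovePoolHook(hook)
}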

959
hitless/pool_hook_test.go Normal file
View File

@@ -0,0 +1,959 @@
package hitless
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/pool"
)
// mockNetConn implements net.Conn for testing
type mockNetConn struct {
addr string
shouldFailInit bool
}
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (m *mockNetConn) Close() error { return nil }
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
type mockAddr struct {
addr string
}
func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string { return m.addr }
// createMockPoolConnection creates a mock pool connection for testing
func createMockPoolConnection() *pool.Conn {
mockNetConn := &mockNetConn{addr: "test:6379"}
conn := pool.NewConn(mockNetConn)
conn.SetUsable(true) // Make connection usable for testing
return conn
}
// mockPool implements pool.Pooler for testing
type mockPool struct {
removedConnections map[uint64]bool
mu sync.Mutex
}
func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) {
return nil, errors.New("not implemented")
}
func (mp *mockPool) CloseConn(conn *pool.Conn) error {
return nil
}
func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) {
return nil, errors.New("not implemented")
}
func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) {
// Not implemented for testing
}
func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) {
mp.mu.Lock()
defer mp.mu.Unlock()
// Use pool.Conn directly - no adapter needed
mp.removedConnections[conn.GetID()] = true
}
// WasRemoved safely checks if a connection was removed from the pool
func (mp *mockPool) WasRemoved(connID uint64) bool {
mp.mu.Lock()
defer mp.mu.Unlock()
return mp.removedConnections[connID]
}
func (mp *mockPool) Len() int {
return 0
}
func (mp *mockPool) IdleLen() int {
return 0
}
func (mp *mockPool) Stats() *pool.Stats {
return &pool.Stats{}
}
func (mp *mockPool) AddPoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
func (mp *mockPool) RemovePoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
func (mp *mockPool) Close() error {
return nil
}
// TestConnectionHook tests the Redis connection processor functionality
func TestConnectionHook(t *testing.T) {
// Create a base dialer for testing
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) {
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 1, // Use only 1 worker to ensure synchronization
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Verify connection is marked for handoff
if !conn.ShouldHandoff() {
t.Fatal("Connection should be marked for handoff")
}
// Set a mock initialization function with synchronization
initConnCalled := make(chan bool, 1)
proceedWithInit := make(chan bool, 1)
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
select {
case initConnCalled <- true:
default:
}
// Wait for test to proceed
<-proceedWithInit
return nil
}
conn.SetInitConnFunc(initConnFunc)
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
// Should pool the connection immediately (handoff queued)
if !shouldPool {
t.Error("Connection should be pooled immediately with event-driven handoff")
}
if shouldRemove {
t.Error("Connection should not be removed when queuing handoff")
}
// Wait for initialization to be called (indicates handoff started)
select {
case <-initConnCalled:
// Good, initialization was called
case <-time.After(1 * time.Second):
t.Fatal("Timeout waiting for initialization function to be called")
}
// Connection should be in pending map while initialization is blocked
if _, pending := processor.pending.Load(conn.GetID()); !pending {
t.Error("Connection should be in pending handoffs map")
}
// Allow initialization to proceed
proceedWithInit <- true
// Wait for handoff to complete with proper timeout and polling
timeout := time.After(2 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify handoff completed (removed from pending map)
if _, pending := processor.pending.Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map after handoff")
}
// Verify connection is usable again
if !conn.IsUsable() {
t.Error("Connection should be usable after successful handoff")
}
// Verify handoff state is cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after completion")
}
})
t.Run("HandoffNotNeeded", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
// Don't mark for handoff
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error when handoff not needed: %v", err)
}
// Should pool the connection normally
if !shouldPool {
t.Error("Connection should be pooled when no handoff needed")
}
if shouldRemove {
t.Error("Connection should not be removed when no handoff needed")
}
})
t.Run("EmptyEndpoint", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error with empty endpoint: %v", err)
}
// Should pool the connection (empty endpoint clears state)
if !shouldPool {
t.Error("Connection should be pooled after clearing empty endpoint")
}
if shouldRemove {
t.Error("Connection should not be removed after clearing empty endpoint")
}
// State should be cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after clearing empty endpoint")
}
})
t.Run("EventDrivenHandoffDialerError", func(t *testing.T) {
// Create a failing base dialer
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("dial failed")
}
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
HandoffTimeout: 1 * time.Second, // Shorter timeout for faster test
LogLevel: 2,
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not return error to caller: %v", err)
}
// Should pool the connection initially (handoff queued)
if !shouldPool {
t.Error("Connection should be pooled initially with event-driven handoff")
}
if shouldRemove {
t.Error("Connection should not be removed when queuing handoff")
}
// Wait for handoff to complete and fail with proper timeout and polling
// Use longer timeout to account for handoff timeout + processing time
timeout := time.After(5 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
// wait for handoff to start
time.Sleep(100 * time.Millisecond)
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for failed handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Connection should be removed from pending map after failed handoff
if _, pending := processor.pending.Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map after failed handoff")
}
// Handoff state should still be set (since handoff failed)
if !conn.ShouldHandoff() {
t.Error("Connection should still be marked for handoff after failed handoff")
}
})
t.Run("BufferedDataRESP2", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
// For this test, we'll just verify the logic works for connections without buffered data
// The actual buffered data detection is handled by the pool's connection health check
// which is outside the scope of the Redis connection processor
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
// Should pool the connection normally (no buffered data in mock)
if !shouldPool {
t.Error("Connection should be pooled when no buffered data")
}
if shouldRemove {
t.Error("Connection should not be removed when no buffered data")
}
})
t.Run("OnGet", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should not error for normal connection: %v", err)
}
})
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
// Simulate a pending handoff by marking for handoff and queuing
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
// Clean up
processor.pending.Delete(conn.GetID())
})
t.Run("EventDrivenStateManagement", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
// Test initial state - no pending handoffs
if _, pending := processor.pending.Load(conn.GetID()); pending {
t.Error("New connection should not have pending handoffs")
}
// Test adding to pending map
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
if _, pending := processor.pending.Load(conn.GetID()); !pending {
t.Error("Connection should be in pending map")
}
// Test OnGet with pending handoff
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
}
// Test removing from pending map and clearing handoff state
processor.pending.Delete(conn.GetID())
if _, pending := processor.pending.Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map")
}
// Clear handoff state to simulate completed handoff
conn.ClearHandoffState()
conn.SetUsable(true) // Make connection usable again
// Test OnGet without pending handoff
err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("Should not return error for non-pending connection: %v", err)
}
})
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
// Create processor with small queue to test optimization features
config := &Config{
MaxWorkers: 3,
HandoffQueueSize: 2, // Small queue to trigger optimizations
MaxHandoffRetries: 3,
LogLevel: 3, // Debug level to see optimization logs
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Add small delay to simulate network latency
time.Sleep(10 * time.Millisecond)
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create multiple connections that need handoff to fill the queue
connections := make([]*pool.Conn, 5)
for i := 0; i < 5; i++ {
connections[i] = createMockPoolConnection()
if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil {
t.Fatalf("Failed to mark connection %d for handoff: %v", i, err)
}
// Set a mock initialization function
connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
}
ctx := context.Background()
successCount := 0
// Process connections - should trigger scaling and timeout logic
for _, conn := range connections {
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Logf("OnPut returned error (expected with timeout): %v", err)
}
if shouldPool && !shouldRemove {
successCount++
}
}
// With timeout and scaling, most handoffs should eventually succeed
if successCount == 0 {
t.Error("Should have queued some handoffs with timeout and scaling")
}
t.Logf("Successfully queued %d handoffs with optimization features", successCount)
// Give time for workers to process and scaling to occur
time.Sleep(100 * time.Millisecond)
})
t.Run("WorkerScalingBehavior", func(t *testing.T) {
// Create processor with small queue to test scaling behavior
config := &Config{
MaxWorkers: 15, // Set to >= 10 to test explicit value preservation
HandoffQueueSize: 1, // Very small queue to force scaling
MaxHandoffRetries: 3,
LogLevel: 2, // Info level to see scaling logs
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Verify initial worker count (should be 0 with on-demand workers)
if processor.GetCurrentWorkers() != 0 {
t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers())
}
if processor.GetScaleLevel() != 0 {
t.Errorf("Processor should be at scale level 0 initially, got %d", processor.GetScaleLevel())
}
if processor.maxWorkers != 15 {
t.Errorf("Expected maxWorkers=15, got %d", processor.maxWorkers)
}
// The on-demand worker behavior creates workers only when needed
// This test just verifies the basic configuration is correct
t.Logf("On-demand worker configuration verified - Max: %d, Current: %d",
processor.maxWorkers, processor.GetCurrentWorkers())
})
t.Run("PassiveTimeoutRestoration", func(t *testing.T) {
// Create processor with fast post-handoff duration for testing
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing
RelaxedTimeout: 5 * time.Second,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
ctx := context.Background()
// Create a connection and trigger handoff
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("Handoff should succeed: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after handoff")
}
// Wait for handoff to complete with proper timeout and polling
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify relaxed timeout is set with deadline
if !conn.HasRelaxedTimeout() {
t.Error("Connection should have relaxed timeout after handoff")
}
// Test that timeout is still active before deadline
// We'll use HasRelaxedTimeout which internally checks the deadline
if !conn.HasRelaxedTimeout() {
t.Error("Connection should still have active relaxed timeout before deadline")
}
// Wait for deadline to pass
time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer
// Test that timeout is automatically restored after deadline
// HasRelaxedTimeout should return false after deadline passes
if conn.HasRelaxedTimeout() {
t.Error("Connection should not have active relaxed timeout after deadline")
}
// Additional verification: calling HasRelaxedTimeout again should still return false
// and should have cleared the internal timeout values
if conn.HasRelaxedTimeout() {
t.Error("Connection should not have relaxed timeout after deadline (second check)")
}
t.Logf("Passive timeout restoration test completed successfully")
})
t.Run("UsableFlagBehavior", func(t *testing.T) {
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
ctx := context.Background()
// Create a new connection without setting it usable
mockNetConn := &mockNetConn{addr: "test:6379"}
conn := pool.NewConn(mockNetConn)
// Initially, connection should not be usable (not initialized)
if conn.IsUsable() {
t.Error("New connection should not be usable before initialization")
}
// Simulate initialization by setting usable to true
conn.SetUsable(true)
if !conn.IsUsable() {
t.Error("Connection should be usable after initialization")
}
// OnGet should succeed for usable connection
err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed for usable connection: %v", err)
}
// Mark connection for handoff
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
// Connection should still be usable until queued, but marked for handoff
if !conn.IsUsable() {
t.Error("Connection should still be usable after being marked for handoff (until queued)")
}
if !conn.ShouldHandoff() {
t.Error("Connection should be marked for handoff")
}
// OnGet should fail for connection marked for handoff
err = processor.OnGet(ctx, conn, false)
if err == nil {
t.Error("OnGet should fail for connection marked for handoff")
}
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should succeed: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after handoff")
}
// Wait for handoff to complete
time.Sleep(50 * time.Millisecond)
// After handoff completion, connection should be usable again
if !conn.IsUsable() {
t.Error("Connection should be usable after handoff completion")
}
// OnGet should succeed again
err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed after handoff completion: %v", err)
}
t.Logf("Usable flag behavior test completed successfully")
})
t.Run("StaticQueueBehavior", func(t *testing.T) {
config := &Config{
MaxWorkers: 3,
HandoffQueueSize: 50, // Explicit static queue size
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100
defer processor.Shutdown(context.Background())
// Verify queue capacity matches configured size
queueCapacity := cap(processor.handoffQueue)
if queueCapacity != 50 {
t.Errorf("Expected queue capacity 50, got %d", queueCapacity)
}
// Test that queue size is static regardless of pool size
// (No dynamic resizing should occur)
ctx := context.Background()
// Fill part of the queue
for i := 0; i < 10; i++ {
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil {
t.Fatalf("Failed to mark connection %d for handoff: %v", i, err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("Failed to queue handoff %d: %v", i, err)
}
if !shouldPool || shouldRemove {
t.Errorf("Connection %d should be pooled after handoff (shouldPool=%v, shouldRemove=%v)",
i, shouldPool, shouldRemove)
}
}
// Verify queue capacity remains static (the main purpose of this test)
finalCapacity := cap(processor.handoffQueue)
if finalCapacity != 50 {
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
}
// Note: We don't check queue size here because workers process items quickly
// The important thing is that the capacity remains static regardless of pool size
})
t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) {
// Create a failing dialer that will cause handoff initialization to fail
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Return a connection that will fail during initialization
return &mockNetConn{addr: addr, shouldFailInit: true}, nil
}
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create a mock pool that tracks removals
mockPool := &mockPool{removedConnections: make(map[uint64]bool)}
processor.SetPool(mockPool)
ctx := context.Background()
// Create a connection and mark it for handoff
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a failing initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return fmt.Errorf("initialization failed")
})
// Process the connection - handoff should fail and connection should be removed
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after failed handoff attempt")
}
// Wait for handoff to be attempted and fail
time.Sleep(100 * time.Millisecond)
// Verify that the connection was removed from the pool
if !mockPool.WasRemoved(conn.GetID()) {
t.Errorf("Connection %d should have been removed from pool after handoff failure", conn.GetID())
}
t.Logf("Connection removal on handoff failure test completed successfully")
})
t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) {
// Create config with short post-handoff duration for testing
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
RelaxedTimeout: 5 * time.Second,
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Fatalf("OnPut failed: %v", err)
}
if !shouldPool {
t.Error("Connection should be pooled after successful handoff")
}
if shouldRemove {
t.Error("Connection should not be removed after successful handoff")
}
// Wait for the handoff to complete (it happens asynchronously)
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.pending.Load(conn); !pending {
handoffCompleted = true
}
}
}
// Verify that relaxed timeout was applied to the new connection
if !conn.HasRelaxedTimeout() {
t.Error("New connection should have relaxed timeout applied after handoff")
}
// Wait for the post-handoff duration to expire
time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration
// Verify that relaxed timeout was automatically cleared
if conn.HasRelaxedTimeout() {
t.Error("Relaxed timeout should be automatically cleared after post-handoff duration")
}
})
t.Run("MarkForHandoff returns error when already marked", func(t *testing.T) {
conn := createMockPoolConnection()
// First mark should succeed
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("First MarkForHandoff should succeed: %v", err)
}
// Second mark should fail
if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil {
t.Fatal("Second MarkForHandoff should return error")
} else if err.Error() != "connection is already marked for handoff" {
t.Fatalf("Expected specific error message, got: %v", err)
}
// Verify original handoff data is preserved
if !conn.ShouldHandoff() {
t.Fatal("Connection should still be marked for handoff")
}
if conn.GetHandoffEndpoint() != "new-endpoint:6379" {
t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint())
}
if conn.GetMovingSeqID() != 1 {
t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID())
}
})
t.Run("HandoffTimeoutConfiguration", func(t *testing.T) {
// Test that HandoffTimeout from config is actually used
customTimeout := 2 * time.Second
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
HandoffTimeout: customTimeout, // Custom timeout
MaxHandoffRetries: 1, // Single retry to speed up test
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create a connection that will test the timeout
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("test-endpoint:6379", 123); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a dialer that will check the context timeout
var timeoutVerified int32 // Use atomic for thread safety
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
// Check that the context has the expected timeout
deadline, ok := ctx.Deadline()
if !ok {
t.Error("Context should have a deadline")
return errors.New("no deadline")
}
// The deadline should be approximately customTimeout from now
expectedDeadline := time.Now().Add(customTimeout)
timeDiff := deadline.Sub(expectedDeadline)
if timeDiff < -500*time.Millisecond || timeDiff > 500*time.Millisecond {
t.Errorf("Context deadline not as expected. Expected around %v, got %v (diff: %v)",
expectedDeadline, deadline, timeDiff)
} else {
atomic.StoreInt32(&timeoutVerified, 1)
}
return nil // Successful handoff
})
// Trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(context.Background(), conn)
if err != nil {
t.Errorf("OnPut should not return error: %v", err)
}
// Connection should be queued for handoff
if !shouldPool || shouldRemove {
t.Errorf("Connection should be pooled for handoff processing")
}
// Wait for handoff to complete
time.Sleep(500 * time.Millisecond)
if atomic.LoadInt32(&timeoutVerified) == 0 {
t.Error("HandoffTimeout was not properly applied to context")
}
t.Logf("HandoffTimeout configuration test completed successfully")
})
}

24
hitless/state.go Normal file
View File

@@ -0,0 +1,24 @@
package hitless
// State represents the current state of a hitless upgrade operation.
type State int
const (
// StateIdle indicates no upgrade is in progress
StateIdle State = iota
// StateMoving indicates a connection handoff is in progress
StateMoving
)
// String returns a string representation of the state.
func (s State) String() string {
switch s {
case StateIdle:
return "idle"
case StateMoving:
return "moving"
default:
return "unknown"
}
}

View File

@@ -0,0 +1,67 @@
// Package interfaces provides shared interfaces used by both the main redis package
// and the hitless upgrade package to avoid circular dependencies.
package interfaces
import (
"context"
"net"
"time"
)
// NotificationProcessor is declared here rather than imported to avoid a circular dependency.
type NotificationProcessor interface {
RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
UnregisterHandler(pushNotificationName string) error
GetHandler(pushNotificationName string) interface{}
}
// ClientInterface defines the interface that clients must implement for hitless upgrades.
type ClientInterface interface {
// GetOptions returns the client options.
GetOptions() OptionsInterface
// GetPushProcessor returns the client's push notification processor.
GetPushProcessor() NotificationProcessor
}
// OptionsInterface defines the interface for client options.
type OptionsInterface interface {
// GetReadTimeout returns the read timeout.
GetReadTimeout() time.Duration
// GetWriteTimeout returns the write timeout.
GetWriteTimeout() time.Duration
// GetNetwork returns the network type.
GetNetwork() string
// GetAddr returns the connection address.
GetAddr() string
// IsTLSEnabled returns true if TLS is enabled.
IsTLSEnabled() bool
// GetProtocol returns the protocol version.
GetProtocol() int
// GetPoolSize returns the connection pool size.
GetPoolSize() int
// NewDialer returns a new dialer function for the connection.
NewDialer() func(context.Context) (net.Conn, error)
}
// ConnectionWithRelaxedTimeout defines the interface for connections that support relaxed timeout adjustment.
// This is used by the hitless upgrade system for per-connection timeout management.
type ConnectionWithRelaxedTimeout interface {
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
// These timeouts remain active until explicitly cleared.
SetRelaxedTimeout(readTimeout, writeTimeout time.Duration)
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time)
// ClearRelaxedTimeout clears relaxed timeouts for this connection.
ClearRelaxedTimeout()
}
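
For orientation, here is a minimal consumer-side sketch (not part of this commit, and assuming code living inside this module, since internal packages cannot be imported from outside). It derives relaxed timeouts from a ClientInterface; the 2x multiplier and 10s fallback are illustrative assumptions.

package hitlesssketch

import (
	"time"

	"github.com/redis/go-redis/v9/internal/interfaces"
)

// relaxedTimeoutsFor derives the relaxed read/write timeouts a hitless upgrade
// might use for a client, falling back to a default when the client has no
// explicit timeouts configured. Multiplier and fallback are assumptions made
// for this sketch, not values taken from the commit.
func relaxedTimeoutsFor(client interfaces.ClientInterface) (read, write time.Duration) {
	opts := client.GetOptions()
	read, write = opts.GetReadTimeout(), opts.GetWriteTimeout()
	if read <= 0 {
		read = 10 * time.Second
	}
	if write <= 0 {
		write = 10 * time.Second
	}
	return 2 * read, 2 * write
}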

View File

@@ -2,6 +2,7 @@ package pool_test
import (
"context"
"errors"
"fmt"
"testing"
"time"
@@ -31,7 +32,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
@@ -75,7 +76,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
@@ -89,7 +90,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
if err != nil {
b.Fatal(err)
}
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("Bench test remove"))
}
})
})

View File

@@ -26,7 +26,7 @@ var _ = Describe("Buffer Size Configuration", func() {
It("should use default buffer sizes when not specified", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
})
@@ -48,7 +48,7 @@ var _ = Describe("Buffer Size Configuration", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: customReadSize,
WriteBufferSize: customWriteSize,
@@ -69,7 +69,7 @@ var _ = Describe("Buffer Size Configuration", func() {
It("should handle zero buffer sizes by using defaults", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: 0, // Should use default
WriteBufferSize: 0, // Should use default
@@ -105,7 +105,7 @@ var _ = Describe("Buffer Size Configuration", func() {
// without setting ReadBufferSize and WriteBufferSize
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
// ReadBufferSize and WriteBufferSize are not set (will be 0)
})

View File

@@ -3,7 +3,10 @@ package pool
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
@@ -12,17 +15,64 @@ import (
var noDeadline = time.Time{}
// Global atomic counter for connection IDs
var connIDCounter uint64
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
type atomicNetConn struct {
conn net.Conn
}
// generateConnID generates a fast unique identifier for a connection with zero allocations
func generateConnID() uint64 {
return atomic.AddUint64(&connIDCounter, 1)
}
type Conn struct {
usedAt int64 // atomic
netConn net.Conn
usedAt int64 // atomic
// Lock-free netConn access using atomic.Value
// Contains *atomicNetConn wrapper, accessed atomically for better performance
netConnAtomic atomic.Value // stores *atomicNetConn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
Inited bool
// Lightweight mutex to protect reader operations during handoff
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
readerMu sync.RWMutex
Inited atomic.Bool
pooled bool
closed atomic.Bool
createdAt time.Time
expiresAt time.Time
// Hitless upgrade support: relaxed timeouts during migrations/failovers
// Using atomic operations for lock-free access to avoid mutex contention
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
// Counter tracking nested relaxed-timeout setters. It is decremented when
// ClearRelaxedTimeout is called or the deadline is reached; once it reaches 0,
// the relaxed timeouts are cleared.
relaxedCounter atomic.Int32
// Connection initialization function for reconnections
initConnFunc func(context.Context, *Conn) error
// Connection identifier for unique tracking across handoffs
id uint64 // Unique numeric identifier for this connection
// Handoff state - using atomic operations for lock-free access
usableAtomic atomic.Bool // Connection usability state
shouldHandoffAtomic atomic.Bool // Whether connection should be handed off
movingSeqIDAtomic atomic.Int64 // Sequence ID from MOVING notification
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
// newEndpointAtomic needs special handling as it's a string
newEndpointAtomic atomic.Value // stores string
onClose func() error
}
@@ -33,8 +83,8 @@ func NewConn(netConn net.Conn) *Conn {
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
cn := &Conn{
netConn: netConn,
createdAt: time.Now(),
id: generateConnID(), // Generate unique ID for this connection
}
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
@@ -50,6 +100,16 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
}
// Store netConn atomically for lock-free access using wrapper
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
// Initialize atomic handoff state
cn.usableAtomic.Store(false) // false initially, set to true after initialization
cn.shouldHandoffAtomic.Store(false) // false initially
cn.movingSeqIDAtomic.Store(0) // 0 initially
cn.handoffRetriesAtomic.Store(0) // 0 initially
cn.newEndpointAtomic.Store("") // empty string initially
cn.wr = proto.NewWriter(cn.bw)
cn.SetUsedAt(time.Now())
return cn
@@ -64,23 +124,368 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix())
}
// getNetConn returns the current network connection using atomic load (lock-free).
// This is the fast path for accessing netConn without mutex overhead.
func (cn *Conn) getNetConn() net.Conn {
if v := cn.netConnAtomic.Load(); v != nil {
if wrapper, ok := v.(*atomicNetConn); ok {
return wrapper.conn
}
}
return nil
}
// setNetConn stores the network connection atomically (lock-free).
// This is used for the fast path of connection replacement.
func (cn *Conn) setNetConn(netConn net.Conn) {
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
}
// Lock-free helper methods for handoff state management
// isUsable returns true if the connection is safe to use (lock-free).
func (cn *Conn) isUsable() bool {
return cn.usableAtomic.Load()
}
// setUsable sets the usable flag atomically (lock-free).
func (cn *Conn) setUsable(usable bool) {
cn.usableAtomic.Store(usable)
}
// shouldHandoff returns true if connection needs handoff (lock-free).
func (cn *Conn) shouldHandoff() bool {
return cn.shouldHandoffAtomic.Load()
}
// setShouldHandoff sets the handoff flag atomically (lock-free).
func (cn *Conn) setShouldHandoff(should bool) {
cn.shouldHandoffAtomic.Store(should)
}
// getMovingSeqID returns the sequence ID atomically (lock-free).
func (cn *Conn) getMovingSeqID() int64 {
return cn.movingSeqIDAtomic.Load()
}
// setMovingSeqID sets the sequence ID atomically (lock-free).
func (cn *Conn) setMovingSeqID(seqID int64) {
cn.movingSeqIDAtomic.Store(seqID)
}
// getNewEndpoint returns the new endpoint atomically (lock-free).
func (cn *Conn) getNewEndpoint() string {
if endpoint := cn.newEndpointAtomic.Load(); endpoint != nil {
return endpoint.(string)
}
return ""
}
// setNewEndpoint sets the new endpoint atomically (lock-free).
func (cn *Conn) setNewEndpoint(endpoint string) {
cn.newEndpointAtomic.Store(endpoint)
}
// setHandoffRetries sets the retry count atomically (lock-free).
func (cn *Conn) setHandoffRetries(retries int) {
cn.handoffRetriesAtomic.Store(uint32(retries))
}
// incrementHandoffRetries atomically increments and returns the new retry count (lock-free).
func (cn *Conn) incrementHandoffRetries(delta int) int {
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
}
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
func (cn *Conn) IsUsable() bool {
return cn.isUsable()
}
func (cn *Conn) IsInited() bool {
return cn.Inited.Load()
}
// SetUsable sets the usable flag for the connection (lock-free).
func (cn *Conn) SetUsable(usable bool) {
cn.setUsable(usable)
}
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
// These timeouts will be used for all subsequent commands until the deadline expires.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
cn.relaxedCounter.Add(1)
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
}
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
cn.relaxedCounter.Add(1)
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
}
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
// Uses atomic operations for lock-free access.
func (cn *Conn) ClearRelaxedTimeout() {
// Atomically decrement counter and check if we should clear
newCount := cn.relaxedCounter.Add(-1)
if newCount <= 0 {
// Use compare-and-swap to ensure only one goroutine clears
if cn.relaxedCounter.CompareAndSwap(newCount, 0) {
cn.clearRelaxedTimeout()
}
}
}
func (cn *Conn) clearRelaxedTimeout() {
cn.relaxedReadTimeoutNs.Store(0)
cn.relaxedWriteTimeoutNs.Store(0)
cn.relaxedDeadlineNs.Store(0)
cn.relaxedCounter.Store(0)
}
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
// This checks both the timeout values and the deadline (if set).
// Uses atomic operations for lock-free access.
func (cn *Conn) HasRelaxedTimeout() bool {
// Fast path: no relaxed timeouts are set
if cn.relaxedCounter.Load() <= 0 {
return false
}
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// If no relaxed timeouts are set, return false
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
return false
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, relaxed timeouts are active
if deadlineNs == 0 {
return true
}
// If deadline is set, check if it's still in the future
return time.Now().UnixNano() < deadlineNs
}
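
As a quick orientation on how the counter and the deadline interact, here is an illustrative in-package sketch (not part of conn.go in this commit; the durations and the 200ms deadline are arbitrary values chosen for the example):

// relaxedTimeoutSketch shows the nesting and deadline semantics of the
// relaxed-timeout API on a single connection.
func relaxedTimeoutSketch(netConn net.Conn) {
	cn := NewConn(netConn)

	cn.SetRelaxedTimeout(10*time.Second, 10*time.Second) // counter = 1, no deadline
	cn.SetRelaxedTimeoutWithDeadline(
		10*time.Second, 10*time.Second,
		time.Now().Add(200*time.Millisecond)) // counter = 2, deadline set

	_ = cn.HasRelaxedTimeout() // true: counter > 0 and the deadline is still in the future

	cn.ClearRelaxedTimeout()           // counter drops to 1, timeouts stay relaxed
	time.Sleep(250 * time.Millisecond) // let the deadline pass
	_ = cn.HasRelaxedTimeout()         // false: the deadline has expired
}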
// getEffectiveReadTimeout returns the timeout to use for read operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
// Fast path: no relaxed timeout set
if readTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(readTimeoutNs)
}
nowNs := time.Now().UnixNano()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(readTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
cn.relaxedCounter.Add(-1)
if cn.relaxedCounter.Load() <= 0 {
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
// getEffectiveWriteTimeout returns the timeout to use for write operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// Fast path: no relaxed timeout set
if writeTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(writeTimeoutNs)
}
nowNs := time.Now().UnixNano()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(writeTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
cn.relaxedCounter.Add(-1)
if cn.relaxedCounter.Load() <= 0 {
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
func (cn *Conn) SetOnClose(fn func() error) {
cn.onClose = fn
}
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
cn.initConnFunc = fn
}
// ExecuteInitConn runs the stored connection initialization function if available.
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
if cn.initConnFunc != nil {
return cn.initConnFunc(ctx, cn)
}
return fmt.Errorf("redis: no initConnFunc set for connection %d", cn.GetID())
}
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
// Store the new connection atomically first (lock-free)
cn.setNetConn(netConn)
// Clear relaxed timeouts when connection is replaced
cn.clearRelaxedTimeout()
// Protect reader reset operations to avoid data races
// Use write lock since we're modifying the reader state
cn.readerMu.Lock()
cn.rd.Reset(netConn)
cn.readerMu.Unlock()
cn.bw.Reset(netConn)
}
// GetNetConn safely returns the current network connection using atomic load (lock-free).
// This method is used by the pool for health checks and provides better performance.
func (cn *Conn) GetNetConn() net.Conn {
return cn.getNetConn()
}
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
// New connection is not initialized yet
cn.Inited.Store(false)
// Replace the underlying connection
cn.SetNetConn(netConn)
return cn.ExecuteInitConn(ctx)
}
// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
// Returns an error if the connection is already marked for handoff.
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
// Use single atomic CAS operation for state transition
if !cn.shouldHandoffAtomic.CompareAndSwap(false, true) {
return errors.New("connection is already marked for handoff")
}
cn.setNewEndpoint(newEndpoint)
cn.setMovingSeqID(seqID)
return nil
}
func (cn *Conn) MarkQueuedForHandoff() error {
// Use single atomic CAS operation for state transition
if !cn.shouldHandoffAtomic.CompareAndSwap(true, false) {
return errors.New("connection was not marked for handoff")
}
cn.setUsable(false)
return nil
}
// RestoreHandoffState restores the handoff state after a failed handoff (lock-free).
func (cn *Conn) RestoreHandoffState() {
// Restore shouldHandoff flag for retry
cn.shouldHandoffAtomic.Store(true)
// Keep usable=false to prevent the connection from being used until handoff succeeds
cn.setUsable(false)
}
// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
func (cn *Conn) ShouldHandoff() bool {
return cn.shouldHandoff()
}
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
func (cn *Conn) GetHandoffEndpoint() string {
return cn.getNewEndpoint()
}
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
func (cn *Conn) GetMovingSeqID() int64 {
return cn.getMovingSeqID()
}
// GetID returns the unique identifier for this connection.
func (cn *Conn) GetID() uint64 {
return cn.id
}
// ClearHandoffState clears the handoff state after successful handoff (lock-free).
func (cn *Conn) ClearHandoffState() {
// clear handoff state
cn.setShouldHandoff(false)
cn.setNewEndpoint("")
cn.setMovingSeqID(0)
cn.setHandoffRetries(0)
cn.setUsable(true) // Connection is safe to use again after handoff completes
}
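
For orientation, an illustrative in-package sketch (not part of this commit) of the handoff state transitions these methods implement; the endpoint and sequence ID are hypothetical values:

// handoffLifecycleSketch walks a connection through the MOVING handoff states.
func handoffLifecycleSketch(cn *Conn) error {
	// 1. MOVING notification received: remember the new endpoint and sequence ID.
	if err := cn.MarkForHandoff("new-endpoint:6379", 42); err != nil {
		return err // already marked for handoff
	}

	// 2. A pool hook queues the handoff; the connection becomes unusable
	//    until the handoff finishes.
	if err := cn.MarkQueuedForHandoff(); err != nil {
		return err
	}

	// 3a. On success the worker clears the state and the connection is usable again.
	cn.ClearHandoffState()

	// 3b. On failure a worker would instead restore the state for a retry:
	//     cn.RestoreHandoffState()
	return nil
}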
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
return cn.incrementHandoffRetries(n)
}
// HasBufferedData safely checks if the connection has buffered data.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) HasBufferedData() bool {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
return cn.rd.Buffered() > 0
}
// PeekReplyTypeSafe safely peeks at the reply type.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
if cn.rd.Buffered() <= 0 {
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
}
return cn.rd.PeekReplyType()
}
func (cn *Conn) Write(b []byte) (int, error) {
return cn.netConn.Write(b)
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Write(b)
}
return 0, net.ErrClosed
}
func (cn *Conn) RemoteAddr() net.Addr {
if cn.netConn != nil {
return cn.netConn.RemoteAddr()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.RemoteAddr()
}
return nil
}
@@ -89,7 +494,16 @@ func (cn *Conn) WithReader(
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
// Get the connection directly from atomic storage
netConn := cn.getNetConn()
if netConn == nil {
return fmt.Errorf("redis: connection not available")
}
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
}
@@ -100,13 +514,26 @@ func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
return err
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
// Set the write deadline when the underlying connection is available;
// otherwise fail fast below instead of letting the write block indefinitely.
if netConn := cn.getNetConn(); netConn != nil {
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
} else {
// If getNetConn() returns nil, we still need to respect the timeout
// Return an error to prevent indefinite blocking
return fmt.Errorf("redis: connection not available for write operation")
}
}
if cn.bw.Buffered() > 0 {
cn.bw.Reset(cn.netConn)
if netConn := cn.getNetConn(); netConn != nil {
cn.bw.Reset(netConn)
}
}
if err := fn(cn.wr); err != nil {
@@ -116,19 +543,33 @@ func (cn *Conn) WithWriter(
return cn.bw.Flush()
}
func (cn *Conn) IsClosed() bool {
return cn.closed.Load()
}
func (cn *Conn) Close() error {
cn.closed.Store(true)
if cn.onClose != nil {
// ignore error
_ = cn.onClose()
}
return cn.netConn.Close()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Close()
}
return nil
}
// MaybeHasData tries to peek at the next byte in the socket without consuming it
// This is used to check if there are push notifications available
// Important: This will work on Linux, but not on Windows
func (cn *Conn) MaybeHasData() bool {
return maybeHasData(cn.netConn)
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return maybeHasData(netConn)
}
return false
}
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {

View File

@@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) {
}
func (cn *Conn) NetConn() net.Conn {
return cn.netConn
return cn.getNetConn()
}
func (p *ConnPool) CheckMinIdleConns() {

114
internal/pool/hooks.go Normal file
View File

@@ -0,0 +1,114 @@
package pool
import (
"context"
"sync"
)
// PoolHook defines the interface for connection lifecycle hooks.
type PoolHook interface {
// OnGet is called when a connection is retrieved from the pool.
// It can modify the connection or return an error to prevent its use.
// The isNewConn flag indicates whether the connection was just created (rather than taken from the idle pool);
// it can be used to gather metrics on the pool hit/miss ratio.
OnGet(ctx context.Context, conn *Conn, isNewConn bool) error
// OnPut is called when a connection is returned to the pool.
// It returns whether the connection should be pooled and whether it should be removed.
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
}
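
For orientation, a minimal sketch (not part of this commit) of a PoolHook that only counts traffic and never vetoes pooling; it assumes an extra sync/atomic import and would be registered through the pool's AddPoolHook or a PoolHookManager:

// countingHook is an illustrative PoolHook that records pool traffic.
type countingHook struct {
	gets atomic.Int64
	puts atomic.Int64
}

func (h *countingHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
	h.gets.Add(1)
	return nil // never block the Get path
}

func (h *countingHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
	h.puts.Add(1)
	return true, false, nil // keep the connection pooled, do not remove it
}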
// PoolHookManager manages multiple pool hooks.
type PoolHookManager struct {
hooks []PoolHook
hooksMu sync.RWMutex
}
// NewPoolHookManager creates a new pool hook manager.
func NewPoolHookManager() *PoolHookManager {
return &PoolHookManager{
hooks: make([]PoolHook, 0),
}
}
// AddHook adds a pool hook to the manager.
// Hooks are called in the order they were added.
func (phm *PoolHookManager) AddHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
phm.hooks = append(phm.hooks, hook)
}
// RemoveHook removes a pool hook from the manager.
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
for i, h := range phm.hooks {
if h == hook {
// Remove hook by swapping with last element and truncating
phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
phm.hooks = phm.hooks[:len(phm.hooks)-1]
break
}
}
}
// ProcessOnGet calls all OnGet hooks in order.
// If any hook returns an error, processing stops and the error is returned.
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
for _, hook := range phm.hooks {
if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
return err
}
}
return nil
}
// ProcessOnPut calls all OnPut hooks in order.
// A hook that returns shouldRemove=true (or an error) stops processing immediately;
// shouldPool=false is recorded but the remaining hooks still run.
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
shouldPool = true // Default to pooling the connection
for _, hook := range phm.hooks {
hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)
if hookErr != nil {
return false, true, hookErr
}
// If any hook says to remove or not pool, respect that decision
if hookShouldRemove {
return false, true, nil
}
if !hookShouldPool {
shouldPool = false
}
}
return shouldPool, false, nil
}
// GetHookCount returns the number of registered hooks (for testing).
func (phm *PoolHookManager) GetHookCount() int {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
return len(phm.hooks)
}
// GetHooks returns a copy of all registered hooks.
func (phm *PoolHookManager) GetHooks() []PoolHook {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
hooks := make([]PoolHook, len(phm.hooks))
copy(hooks, phm.hooks)
return hooks
}

213
internal/pool/hooks_test.go Normal file
View File

@@ -0,0 +1,213 @@
package pool
import (
"context"
"errors"
"net"
"testing"
"time"
)
// TestHook for testing hook functionality
type TestHook struct {
OnGetCalled int
OnPutCalled int
GetError error
PutError error
ShouldPool bool
ShouldRemove bool
}
func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
th.OnGetCalled++
return th.GetError
}
func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
th.OnPutCalled++
return th.ShouldPool, th.ShouldRemove, th.PutError
}
func TestPoolHookManager(t *testing.T) {
manager := NewPoolHookManager()
// Test initial state
if manager.GetHookCount() != 0 {
t.Errorf("Expected 0 hooks initially, got %d", manager.GetHookCount())
}
// Add hooks
hook1 := &TestHook{ShouldPool: true}
hook2 := &TestHook{ShouldPool: true}
manager.AddHook(hook1)
manager.AddHook(hook2)
if manager.GetHookCount() != 2 {
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
}
// Test ProcessOnGet
ctx := context.Background()
conn := &Conn{} // Mock connection
err := manager.ProcessOnGet(ctx, conn, false)
if err != nil {
t.Errorf("ProcessOnGet should not error: %v", err)
}
if hook1.OnGetCalled != 1 {
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
}
if hook2.OnGetCalled != 1 {
t.Errorf("Expected hook2.OnGetCalled to be 1, got %d", hook2.OnGetCalled)
}
// Test ProcessOnPut
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
if err != nil {
t.Errorf("ProcessOnPut should not error: %v", err)
}
if !shouldPool {
t.Error("Expected shouldPool to be true")
}
if shouldRemove {
t.Error("Expected shouldRemove to be false")
}
if hook1.OnPutCalled != 1 {
t.Errorf("Expected hook1.OnPutCalled to be 1, got %d", hook1.OnPutCalled)
}
if hook2.OnPutCalled != 1 {
t.Errorf("Expected hook2.OnPutCalled to be 1, got %d", hook2.OnPutCalled)
}
// Remove a hook
manager.RemoveHook(hook1)
if manager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
}
}
func TestHookErrorHandling(t *testing.T) {
manager := NewPoolHookManager()
// Hook that returns error on Get
errorHook := &TestHook{
GetError: errors.New("test error"),
ShouldPool: true,
}
normalHook := &TestHook{ShouldPool: true}
manager.AddHook(errorHook)
manager.AddHook(normalHook)
ctx := context.Background()
conn := &Conn{}
// Test that error stops processing
err := manager.ProcessOnGet(ctx, conn, false)
if err == nil {
t.Error("Expected error from ProcessOnGet")
}
if errorHook.OnGetCalled != 1 {
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
}
// normalHook should not be called due to error
if normalHook.OnGetCalled != 0 {
t.Errorf("Expected normalHook.OnGetCalled to be 0, got %d", normalHook.OnGetCalled)
}
}
func TestHookShouldRemove(t *testing.T) {
manager := NewPoolHookManager()
// Hook that says to remove connection
removeHook := &TestHook{
ShouldPool: false,
ShouldRemove: true,
}
normalHook := &TestHook{ShouldPool: true}
manager.AddHook(removeHook)
manager.AddHook(normalHook)
ctx := context.Background()
conn := &Conn{}
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
if err != nil {
t.Errorf("ProcessOnPut should not error: %v", err)
}
if shouldPool {
t.Error("Expected shouldPool to be false")
}
if !shouldRemove {
t.Error("Expected shouldRemove to be true")
}
if removeHook.OnPutCalled != 1 {
t.Errorf("Expected removeHook.OnPutCalled to be 1, got %d", removeHook.OnPutCalled)
}
// normalHook should not be called due to early return
if normalHook.OnPutCalled != 0 {
t.Errorf("Expected normalHook.OnPutCalled to be 0, got %d", normalHook.OnPutCalled)
}
}
func TestPoolWithHooks(t *testing.T) {
// Create a pool with hooks
hookManager := NewPoolHookManager()
testHook := &TestHook{ShouldPool: true}
hookManager.AddHook(testHook)
opt := &Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil // Mock connection
},
PoolSize: 1,
DialTimeout: time.Second,
}
pool := NewConnPool(opt)
defer pool.Close()
// Add hook to pool after creation
pool.AddPoolHook(testHook)
// Verify hooks are initialized
if pool.hookManager == nil {
t.Error("Expected hookManager to be initialized")
}
if pool.hookManager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount())
}
// Test adding hook to pool
additionalHook := &TestHook{ShouldPool: true}
pool.AddPoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 2 {
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount())
}
// Test removing hook from pool
pool.RemovePoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount())
}
}

View File

@@ -3,6 +3,7 @@ package pool
import (
"context"
"errors"
"log"
"net"
"sync"
"sync/atomic"
@@ -22,6 +23,12 @@ var (
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
ErrPoolTimeout = errors.New("redis: connection pool timeout")
popAttempts = 10
getAttempts = 3
minTime = time.Unix(-2208988800, 0) // Jan 1, 1900
maxTime = minTime.Add(1<<63 - 1)
noExpiration = maxTime
)
var timers = sync.Pool{
@@ -38,11 +45,14 @@ type Stats struct {
Misses uint32 // number of times free connection was NOT found in the pool
Timeouts uint32 // number of times a wait timeout occurred
WaitCount uint32 // number of times a connection was waited
Unusable uint32 // number of times a connection was found to be unusable
WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds
TotalConns uint32 // number of total connections in the pool
IdleConns uint32 // number of idle connections in the pool
StaleConns uint32 // number of stale connections removed from the pool
PubSubStats PubSubStats
}
type Pooler interface {
@@ -57,29 +67,27 @@ type Pooler interface {
IdleLen() int
Stats() *Stats
AddPoolHook(hook PoolHook)
RemovePoolHook(hook PoolHook)
Close() error
}
type Options struct {
Dialer func(context.Context) (net.Conn, error)
PoolFIFO bool
PoolSize int
DialTimeout time.Duration
PoolTimeout time.Duration
MinIdleConns int
MaxIdleConns int
MaxActiveConns int
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
// Protocol version for optimization (3 = RESP3 with push notifications, 2 = RESP2 without)
Protocol int
Dialer func(context.Context) (net.Conn, error)
ReadBufferSize int
WriteBufferSize int
PoolFIFO bool
PoolSize int32
DialTimeout time.Duration
PoolTimeout time.Duration
MinIdleConns int32
MaxIdleConns int32
MaxActiveConns int32
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
PushNotificationsEnabled bool
}
type lastDialErrorWrap struct {
@@ -95,16 +103,21 @@ type ConnPool struct {
queue chan struct{}
connsMu sync.Mutex
conns []*Conn
conns map[uint64]*Conn
idleConns []*Conn
poolSize int
idleConnsLen int
poolSize atomic.Int32
idleConnsLen atomic.Int32
idleCheckInProgress atomic.Bool
stats Stats
waitDurationNs atomic.Int64
_closed uint32 // atomic
// Pool hooks manager for flexible connection processing
hookManagerMu sync.RWMutex
hookManager *PoolHookManager
}
var _ Pooler = (*ConnPool)(nil)
@@ -114,34 +127,69 @@ func NewConnPool(opt *Options) *ConnPool {
cfg: opt,
queue: make(chan struct{}, opt.PoolSize),
conns: make([]*Conn, 0, opt.PoolSize),
conns: make(map[uint64]*Conn),
idleConns: make([]*Conn, 0, opt.PoolSize),
}
p.connsMu.Lock()
p.checkMinIdleConns()
p.connsMu.Unlock()
// Only create MinIdleConns if explicitly requested (> 0)
// This avoids creating connections during pool initialization for tests
if opt.MinIdleConns > 0 {
p.connsMu.Lock()
p.checkMinIdleConns()
p.connsMu.Unlock()
}
return p
}
// initializeHooks sets up the pool hooks system.
func (p *ConnPool) initializeHooks() {
p.hookManager = NewPoolHookManager()
}
// AddPoolHook adds a pool hook to the pool.
func (p *ConnPool) AddPoolHook(hook PoolHook) {
p.hookManagerMu.Lock()
defer p.hookManagerMu.Unlock()
if p.hookManager == nil {
p.initializeHooks()
}
p.hookManager.AddHook(hook)
}
// RemovePoolHook removes a pool hook from the pool.
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
p.hookManagerMu.Lock()
defer p.hookManagerMu.Unlock()
if p.hookManager != nil {
p.hookManager.RemoveHook(hook)
}
}
func (p *ConnPool) checkMinIdleConns() {
if !p.idleCheckInProgress.CompareAndSwap(false, true) {
return
}
defer p.idleCheckInProgress.Store(false)
if p.cfg.MinIdleConns == 0 {
return
}
for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
// Only create idle connections if we haven't reached the total pool size limit
// MinIdleConns should be a subset of PoolSize, not additional connections
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
select {
case p.queue <- struct{}{}:
p.poolSize++
p.idleConnsLen++
p.poolSize.Add(1)
p.idleConnsLen.Add(1)
go func() {
defer func() {
if err := recover(); err != nil {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
p.freeTurn()
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
@@ -150,12 +198,9 @@ func (p *ConnPool) checkMinIdleConns() {
err := p.addIdleConn()
if err != nil && err != ErrClosed {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
}
p.freeTurn()
}()
default:
@@ -172,6 +217,9 @@ func (p *ConnPool) addIdleConn() error {
if err != nil {
return err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
p.connsMu.Lock()
defer p.connsMu.Unlock()
@@ -182,11 +230,15 @@ func (p *ConnPool) addIdleConn() error {
return ErrClosed
}
p.conns = append(p.conns, cn)
p.conns[cn.GetID()] = cn
p.idleConns = append(p.idleConns, cn)
return nil
}
// NewConn creates a new connection and returns it to the user.
// It still honors MaxActiveConns, but the connection is not tracked by the pool and does not increase the pool size.
//
// NOTE: Connections obtained this way are not pooled and do not support hitless upgrades.
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.newConn(ctx, false)
}
@@ -196,33 +248,42 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, ErrClosed
}
p.connsMu.Lock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
p.connsMu.Unlock()
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
return nil, ErrPoolExhausted
}
p.connsMu.Unlock()
cn, err := p.dialConn(ctx, pooled)
if err != nil {
return nil, err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
_ = cn.Close()
return nil, ErrPoolExhausted
}
p.conns = append(p.conns, cn)
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.closed() {
_ = cn.Close()
return nil, ErrClosed
}
// Recreate the conns map if it was cleared while we were waiting for the lock
if p.conns == nil {
p.conns = make(map[uint64]*Conn)
}
p.conns[cn.GetID()] = cn
if pooled {
// If pool is full remove the cn on next Put.
if p.poolSize >= p.cfg.PoolSize {
currentPoolSize := p.poolSize.Load()
if currentPoolSize >= int32(p.cfg.PoolSize) {
cn.pooled = false
} else {
p.poolSize++
p.poolSize.Add(1)
}
}
@@ -249,6 +310,12 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
cn.pooled = pooled
if p.cfg.ConnMaxLifetime > 0 {
cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime)
} else {
cn.expiresAt = noExpiration
}
return cn, nil
}
@@ -289,6 +356,14 @@ func (p *ConnPool) getLastDialError() error {
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return p.getConn(ctx)
}
// getConn returns a connection from the pool.
func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
var cn *Conn
var err error
if p.closed() {
return nil, ErrClosed
}
@@ -297,9 +372,17 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return nil, err
}
now := time.Now()
attempts := 0
for {
if attempts >= getAttempts {
log.Printf("redis: connection pool: failed to get an connection accepted by hook after %d attempts", attempts)
break
}
attempts++
p.connsMu.Lock()
cn, err := p.popIdle()
cn, err = p.popIdle()
p.connsMu.Unlock()
if err != nil {
@@ -311,11 +394,25 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
break
}
if !p.isHealthyConn(cn) {
if !p.isHealthyConn(cn, now) {
_ = p.CloseConn(cn)
continue
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
log.Printf("redis: connection pool: failed to process idle connection by hook: %v", err)
// Failed to process connection, discard it
_ = p.CloseConn(cn)
continue
}
}
atomic.AddUint32(&p.stats.Hits, 1)
return cn, nil
}
@@ -328,6 +425,20 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return nil, err
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
// Failed to process connection, discard it
log.Printf("redis: connection pool: failed to process new connection by hook: %v", err)
_ = p.CloseConn(newcn)
return nil, err
}
}
return newcn, nil
}
@@ -356,7 +467,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error {
}
return ctx.Err()
case p.queue <- struct{}{}:
p.waitDurationNs.Add(time.Since(start).Nanoseconds())
p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
atomic.AddUint32(&p.stats.WaitCount, 1)
if !timer.Stop() {
<-timer.C
@@ -376,68 +487,128 @@ func (p *ConnPool) popIdle() (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
n := len(p.idleConns)
if n == 0 {
return nil, nil
}
var cn *Conn
if p.cfg.PoolFIFO {
cn = p.idleConns[0]
copy(p.idleConns, p.idleConns[1:])
p.idleConns = p.idleConns[:n-1]
} else {
idx := n - 1
cn = p.idleConns[idx]
p.idleConns = p.idleConns[:idx]
attempts := 0
for attempts < popAttempts {
if len(p.idleConns) == 0 {
return nil, nil
}
if p.cfg.PoolFIFO {
cn = p.idleConns[0]
copy(p.idleConns, p.idleConns[1:])
p.idleConns = p.idleConns[:len(p.idleConns)-1]
} else {
idx := len(p.idleConns) - 1
cn = p.idleConns[idx]
p.idleConns = p.idleConns[:idx]
}
attempts++
if cn.IsUsable() {
p.idleConnsLen.Add(-1)
break
}
// Connection is not usable, put it back in the pool
if p.cfg.PoolFIFO {
// FIFO: put at end (will be picked up last since we pop from front)
p.idleConns = append(p.idleConns, cn)
} else {
// LIFO: put at beginning (will be picked up last since we pop from end)
p.idleConns = append([]*Conn{cn}, p.idleConns...)
}
}
p.idleConnsLen--
// If we exhausted all attempts without finding a usable connection, return nil
if attempts >= popAttempts {
log.Printf("redis: connection pool: failed to get an usable connection after %d attempts", popAttempts)
return nil, nil
}
p.checkMinIdleConns()
return cn, nil
}
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
// Process connection using the hooks system
shouldPool := true
shouldRemove := false
if cn.rd.Buffered() > 0 {
// Check if this might be push notification data
if p.cfg.Protocol == 3 {
// we know that there is something in the buffer, so peek at the next reply type without
// the potential to block and check if it's a push notification
if replyType, err := cn.rd.PeekReplyType(); err != nil || replyType != proto.RespPush {
shouldRemove = true
}
} else {
// not a push notification since protocol 2 doesn't support them
shouldRemove = true
}
var err error
if shouldRemove {
// For non-RESP3 or data that is not a push notification, buffered data is unexpected
internal.Logger.Printf(ctx, "Conn has unread data, closing it")
p.Remove(ctx, cn, BadConnError{})
if cn.HasBufferedData() {
// Peek at the reply type to check if it's a push notification
if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush {
// Not a push notification or error peeking, remove connection
internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it")
p.Remove(ctx, cn, err)
// Remove has already freed the turn and closed the connection; stop here.
return
}
// It's a push notification, allow pooling (client will handle it)
}
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
if err != nil {
internal.Logger.Printf(ctx, "Connection hook error: %v", err)
p.Remove(ctx, cn, err)
return
}
}
// If hooks say to remove the connection, do so
if shouldRemove {
p.Remove(ctx, cn, errors.New("hook requested removal"))
return
}
// If processor says not to pool the connection, remove it
if !shouldPool {
p.Remove(ctx, cn, errors.New("hook requested no pooling"))
return
}
if !cn.pooled {
p.Remove(ctx, cn, nil)
p.Remove(ctx, cn, errors.New("connection not pooled"))
return
}
var shouldCloseConn bool
p.connsMu.Lock()
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns {
p.idleConns = append(p.idleConns, cn)
p.idleConnsLen++
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
// Unusable conns are expected to become usable again (a background process is reconnecting them),
// so put them at the opposite end of the queue where they will be picked up last.
if !cn.IsUsable() {
if p.cfg.PoolFIFO {
p.connsMu.Lock()
p.idleConns = append(p.idleConns, cn)
p.connsMu.Unlock()
} else {
p.connsMu.Lock()
p.idleConns = append([]*Conn{cn}, p.idleConns...)
p.connsMu.Unlock()
}
} else {
p.connsMu.Lock()
p.idleConns = append(p.idleConns, cn)
p.connsMu.Unlock()
}
p.idleConnsLen.Add(1)
} else {
p.removeConn(cn)
p.removeConnWithLock(cn)
shouldCloseConn = true
}
p.connsMu.Unlock()
p.freeTurn()
if shouldCloseConn {
@@ -449,6 +620,9 @@ func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
p.removeConnWithLock(cn)
p.freeTurn()
_ = p.closeConn(cn)
// Check if we need to create new idle connections to maintain MinIdleConns
p.checkMinIdleConns()
}
func (p *ConnPool) CloseConn(cn *Conn) error {
@@ -463,17 +637,13 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) {
}
func (p *ConnPool) removeConn(cn *Conn) {
for i, c := range p.conns {
if c == cn {
p.conns = append(p.conns[:i], p.conns[i+1:]...)
if cn.pooled {
p.poolSize--
p.checkMinIdleConns()
}
break
}
}
delete(p.conns, cn.GetID())
atomic.AddUint32(&p.stats.StaleConns, 1)
// Decrement pool size counter when removing a connection
if cn.pooled {
p.poolSize.Add(-1)
}
}
func (p *ConnPool) closeConn(cn *Conn) error {
@@ -491,9 +661,9 @@ func (p *ConnPool) Len() int {
// IdleLen returns number of idle connections.
func (p *ConnPool) IdleLen() int {
p.connsMu.Lock()
n := p.idleConnsLen
n := p.idleConnsLen.Load()
p.connsMu.Unlock()
return n
return int(n)
}
func (p *ConnPool) Stats() *Stats {
@@ -502,6 +672,7 @@ func (p *ConnPool) Stats() *Stats {
Misses: atomic.LoadUint32(&p.stats.Misses),
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
WaitCount: atomic.LoadUint32(&p.stats.WaitCount),
Unusable: atomic.LoadUint32(&p.stats.Unusable),
WaitDurationNs: p.waitDurationNs.Load(),
TotalConns: uint32(p.Len()),
@@ -542,30 +713,32 @@ func (p *ConnPool) Close() error {
}
}
p.conns = nil
p.poolSize = 0
p.poolSize.Store(0)
p.idleConns = nil
p.idleConnsLen = 0
p.idleConnsLen.Store(0)
p.connsMu.Unlock()
return firstErr
}
func (p *ConnPool) isHealthyConn(cn *Conn) bool {
now := time.Now()
if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
// slight optimization, check expiresAt first.
if cn.expiresAt.Before(now) {
return false
}
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
return false
}
// Check connection health, but be aware of push notifications
if err := connCheck(cn.netConn); err != nil {
cn.SetUsedAt(now)
// Check basic connection health
// Use GetNetConn() to safely access netConn and avoid data races
if err := connCheck(cn.getNetConn()); err != nil {
// If there's unexpected data, it might be push notifications (RESP3)
// However, push notification processing is now handled by the client
// before WithReader to ensure proper context is available to handlers
if err == errUnexpectedRead && p.cfg.Protocol == 3 {
if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
// we know that there is something in the buffer, so peek at the next reply type without
// the potential to block
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
@@ -579,7 +752,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
return false
}
}
cn.SetUsedAt(now)
return true
}

View File

@@ -1,6 +1,8 @@
package pool
import "context"
import (
"context"
)
type SingleConnPool struct {
pool Pooler
@@ -56,3 +58,7 @@ func (p *SingleConnPool) IdleLen() int {
func (p *SingleConnPool) Stats() *Stats {
return &Stats{}
}
func (p *SingleConnPool) AddPoolHook(hook PoolHook) {}
func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {}

View File

@@ -199,3 +199,7 @@ func (p *StickyConnPool) IdleLen() int {
func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}
func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}
func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}

View File

@@ -2,6 +2,7 @@ package pool_test
import (
"context"
"errors"
"net"
"sync"
"testing"
@@ -20,7 +21,7 @@ var _ = Describe("ConnPool", func() {
BeforeEach(func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Hour,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
@@ -45,11 +46,11 @@ var _ = Describe("ConnPool", func() {
<-closedChan
return &net.TCPConn{}, nil
},
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Hour,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
MinIdleConns: minIdleConns,
MinIdleConns: int32(minIdleConns),
})
wg.Wait()
Expect(connPool.Close()).NotTo(HaveOccurred())
@@ -105,7 +106,7 @@ var _ = Describe("ConnPool", func() {
// ok
}
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
// Check that Get is unblocked.
select {
@@ -130,8 +131,8 @@ var _ = Describe("MinIdleConns", func() {
newConnPool := func() *pool.ConnPool {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: poolSize,
MinIdleConns: minIdleConns,
PoolSize: int32(poolSize),
MinIdleConns: int32(minIdleConns),
PoolTimeout: 100 * time.Millisecond,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: -1,
@@ -168,7 +169,7 @@ var _ = Describe("MinIdleConns", func() {
Context("after Remove", func() {
BeforeEach(func() {
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
})
It("has idle connections", func() {
@@ -245,7 +246,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() {
perform(len(cns), func(i int) {
mu.RLock()
connPool.Remove(ctx, cns[i], nil)
connPool.Remove(ctx, cns[i], errors.New("test"))
mu.RUnlock()
})
@@ -309,7 +310,7 @@ var _ = Describe("race", func() {
It("does not happen on Get, Put, and Remove", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Minute,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
@@ -328,7 +329,7 @@ var _ = Describe("race", func() {
cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
}
}
})
@@ -339,15 +340,15 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
},
PoolSize: 1000,
MinIdleConns: 50,
PoolSize: int32(1000),
MinIdleConns: int32(50),
PoolTimeout: 3 * time.Second,
DialTimeout: 1 * time.Second,
}
p := pool.NewConnPool(opt)
var wg sync.WaitGroup
for i := 0; i < opt.PoolSize; i++ {
for i := int32(0); i < opt.PoolSize; i++ {
wg.Add(1)
go func() {
defer wg.Done()
@@ -366,8 +367,8 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
panic("test panic")
},
PoolSize: 100,
MinIdleConns: 30,
PoolSize: int32(100),
MinIdleConns: int32(30),
}
p := pool.NewConnPool(opt)
@@ -377,14 +378,14 @@ var _ = Describe("race", func() {
state := p.Stats()
return state.TotalConns == 0 && state.IdleConns == 0 && p.QueueLen() == 0
}, "3s", "50ms").Should(BeTrue())
})
})
It("wait", func() {
opt := &pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
},
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 3 * time.Second,
}
p := pool.NewConnPool(opt)
@@ -415,7 +416,7 @@ var _ = Describe("race", func() {
return &net.TCPConn{}, nil
},
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: testPoolTimeout,
}
p := pool.NewConnPool(opt)

77
internal/pool/pubsub.go Normal file
View File

@@ -0,0 +1,77 @@
package pool
import (
"context"
"net"
"sync"
"sync/atomic"
)
type PubSubStats struct {
Created uint32
Untracked uint32
Active uint32
}
// PubSubPool manages a pool of PubSub connections.
type PubSubPool struct {
opt *Options
netDialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Map to track active PubSub connections
activeConns sync.Map // map[uint64]*Conn (connID -> conn)
closed atomic.Bool
stats PubSubStats
}
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
return &PubSubPool{
opt: opt,
netDialer: netDialer,
}
}
func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
if p.closed.Load() {
return nil, ErrClosed
}
netConn, err := p.netDialer(ctx, network, addr)
if err != nil {
return nil, err
}
cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
atomic.AddUint32(&p.stats.Created, 1)
return cn, nil
}
func (p *PubSubPool) TrackConn(cn *Conn) {
atomic.AddUint32(&p.stats.Active, 1)
p.activeConns.Store(cn.GetID(), cn)
}
func (p *PubSubPool) UntrackConn(cn *Conn) {
atomic.AddUint32(&p.stats.Active, ^uint32(0))
atomic.AddUint32(&p.stats.Untracked, 1)
p.activeConns.Delete(cn.GetID())
}
func (p *PubSubPool) Close() error {
p.closed.Store(true)
p.activeConns.Range(func(key, value interface{}) bool {
cn := value.(*Conn)
_ = cn.Close()
return true
})
return nil
}
func (p *PubSubPool) Stats() *PubSubStats {
// load stats atomically
return &PubSubStats{
Created: atomic.LoadUint32(&p.stats.Created),
Untracked: atomic.LoadUint32(&p.stats.Untracked),
Active: atomic.LoadUint32(&p.stats.Active),
}
}
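A minimal usage sketch of the new PubSubPool, assuming it runs inside the go-redis module (pool is an internal package) and imports context, net, and internal/pool; the dialer, address, and channel name are illustrative placeholders:
func pubSubPoolSketch(ctx context.Context) error {
	// Placeholder dialer for the sketch; real clients pass the client's dial hook.
	dialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
		return net.Dial(network, addr)
	}
	p := pool.NewPubSubPool(&pool.Options{
		ReadBufferSize:  32 * 1024,
		WriteBufferSize: 32 * 1024,
	}, dialer)
	cn, err := p.NewConn(ctx, "tcp", "localhost:6379", []string{"news"})
	if err != nil {
		return err
	}
	p.TrackConn(cn) // count the connection as active
	// ... run the SUBSCRIBE / receive loop on cn ...
	p.UntrackConn(cn) // stop tracking before closing
	_ = cn.Close()
	return p.Close()
}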

3
internal/redis.go Normal file
View File

@@ -0,0 +1,3 @@
package internal
const RedisNull = "null"

17
internal/util/math.go Normal file
View File

@@ -0,0 +1,17 @@
package util
// Max returns the maximum of two integers
func Max(a, b int) int {
if a > b {
return a
}
return b
}
// Min returns the minimum of two integers
func Min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@@ -14,9 +14,10 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/push"
)
// Limiter is the interface of a rate limiter or a circuit breaker.
@@ -153,6 +154,7 @@ type Options struct {
//
// Note that FIFO has slightly higher overhead compared to LIFO,
// but it helps close idle connections faster, reducing the pool size.
// default: false
PoolFIFO bool
// PoolSize is the base number of socket connections.
@@ -244,8 +246,19 @@ type Options struct {
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
// HitlessUpgradeConfig provides configuration options for hitless upgrades.
// This is an alias to hitless.Config for convenience.
type HitlessUpgradeConfig = hitless.Config
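A hedged example of opting in explicitly (imports github.com/redis/go-redis/v9 and github.com/redis/go-redis/v9/hitless); Mode, EndpointType, and the hitless constants are the ones introduced in this change, everything else is illustrative:
rdb := redis.NewClient(&redis.Options{
	Addr:     "localhost:6379",
	Protocol: 3, // RESP3 is required for push notifications
	HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
		Mode:         hitless.MaintNotificationsEnabled,
		EndpointType: hitless.EndpointTypeAuto, // auto-detected from Addr/TLS when left empty
	},
})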
func (opt *Options) init() {
if opt.Addr == "" {
opt.Addr = "localhost:6379"
@@ -320,13 +333,36 @@ func (opt *Options) init() {
case 0:
opt.MaxRetryBackoff = 512 * time.Millisecond
}
opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolSize(opt.PoolSize)
// auto-detect endpoint type if not specified
endpointType := opt.HitlessUpgradeConfig.EndpointType
if endpointType == "" || endpointType == hitless.EndpointTypeAuto {
// Auto-detect endpoint type if not specified
endpointType = hitless.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
}
opt.HitlessUpgradeConfig.EndpointType = endpointType
}
func (opt *Options) clone() *Options {
clone := *opt
// Deep clone HitlessUpgradeConfig to avoid sharing between clients
if opt.HitlessUpgradeConfig != nil {
configClone := *opt.HitlessUpgradeConfig
clone.HitlessUpgradeConfig = &configClone
}
return &clone
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
return NewDialer(opt)
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
@@ -617,18 +653,35 @@ func newConnPool(
Dialer: func(ctx context.Context) (net.Conn, error) {
return dialer(ctx, opt.Network, opt.Addr)
},
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
MinIdleConns: opt.MinIdleConns,
MaxIdleConns: opt.MaxIdleConns,
MaxActiveConns: opt.MaxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
// Pass protocol version for push notification optimization
Protocol: opt.Protocol,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
PoolFIFO: opt.PoolFIFO,
PoolSize: int32(opt.PoolSize),
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
MinIdleConns: int32(opt.MinIdleConns),
MaxIdleConns: int32(opt.MaxIdleConns),
MaxActiveConns: int32(opt.MaxActiveConns),
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
PushNotificationsEnabled: opt.Protocol == 3,
})
}
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
) *pool.PubSubPool {
return pool.NewPubSubPool(&pool.Options{
PoolFIFO: opt.PoolFIFO,
PoolSize: int32(opt.PoolSize),
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
MinIdleConns: int32(opt.MinIdleConns),
MaxIdleConns: int32(opt.MaxIdleConns),
MaxActiveConns: int32(opt.MaxActiveConns),
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: 32 * 1024,
WriteBufferSize: 32 * 1024,
PushNotificationsEnabled: opt.Protocol == 3,
}, dialer)
}

View File

@@ -38,6 +38,7 @@ type ClusterOptions struct {
ClientName string
// NewClient creates a cluster node client with provided name and options.
// If NewClient is set by the user, the user is responsible for handling hitless upgrades and push notifications.
NewClient func(opt *Options) *Client
// The maximum number of retries before giving up. Command is retried
@@ -129,6 +130,14 @@ type ClusterOptions struct {
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
// The ClusterClient does not handle hitless upgrades directly; that is left to the node clients in the Nodes map.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
func (opt *ClusterOptions) init() {
@@ -319,6 +328,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
}
func (opt *ClusterOptions) clientOptions() *Options {
// Clone HitlessUpgradeConfig to avoid sharing between cluster node clients
var hitlessConfig *HitlessUpgradeConfig
if opt.HitlessUpgradeConfig != nil {
configClone := *opt.HitlessUpgradeConfig
hitlessConfig = &configClone
}
return &Options{
ClientName: opt.ClientName,
Dialer: opt.Dialer,
@@ -360,8 +376,9 @@ func (opt *ClusterOptions) clientOptions() *Options {
// much use for ClusterSlots config). This means we cannot execute the
// READONLY command against that node -- setting readOnly to false in such
// situations in the options below will prevent that from happening.
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
HitlessUpgradeConfig: hitlessConfig,
}
}
@@ -1830,12 +1847,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
return err
}
// hitless won't work here for now
func (c *ClusterClient) pubSub() *PubSub {
var node *clusterNode
pubsub := &PubSub{
opt: c.opt.clientOptions(),
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
if node != nil {
panic("node != nil")
}
@@ -1850,18 +1867,25 @@ func (c *ClusterClient) pubSub() *PubSub {
if err != nil {
return nil, err
}
cn, err := node.Client.newConn(context.TODO())
cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
if err != nil {
node = nil
return nil, err
}
// will return nil if already initialized
err = node.Client.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
node = nil
return nil, err
}
node.Client.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: func(cn *pool.Conn) error {
err := node.Client.connPool.CloseConn(cn)
// Untrack connection from PubSubPool
node.Client.pubSubPool.UntrackConn(cn)
err := cn.Close()
node = nil
return err
},

375
pool_pubsub_bench_test.go Normal file
View File

@@ -0,0 +1,375 @@
// Pool and PubSub Benchmark Suite
//
// This file contains comprehensive benchmarks for both pool operations and PubSub initialization.
// It's designed to be run against different branches to compare performance.
//
// Usage Examples:
// # Run all benchmarks
// go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go
//
// # Run only pool benchmarks
// go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go
//
// # Run only PubSub benchmarks
// go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go
//
// # Compare between branches
// git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt
// git checkout branch2 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch2.txt
// benchcmp branch1.txt branch2.txt
//
// # Run with memory profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go
//
// # Run with CPU profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go
package redis_test
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/internal/pool"
)
// dummyDialer creates a mock connection for benchmarking
func dummyDialer(ctx context.Context) (net.Conn, error) {
return &dummyConn{}, nil
}
// dummyConn implements net.Conn for benchmarking
type dummyConn struct{}
func (c *dummyConn) Read(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Close() error { return nil }
func (c *dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} }
func (c *dummyConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379}
}
func (c *dummyConn) SetDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetReadDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil }
// =============================================================================
// POOL BENCHMARKS
// =============================================================================
// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations
func BenchmarkPoolGetPut(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(poolSize),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0), // Start with no idle connections
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with MinIdleConns
func BenchmarkPoolGetPutWithMinIdle(b *testing.B) {
ctx := context.Background()
configs := []struct {
poolSize int
minIdleConns int
}{
{8, 2},
{16, 4},
{32, 8},
{64, 16},
}
for _, config := range configs {
b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(config.poolSize),
MinIdleConns: int32(config.minIdleConns),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency
func BenchmarkPoolConcurrentGetPut(b *testing.B) {
ctx := context.Background()
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(32),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0),
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
// Test with different levels of concurrency
concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64}
for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.SetParallelism(concurrency)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// =============================================================================
// PUBSUB BENCHMARKS
// =============================================================================
// benchmarkClient creates a Redis client for benchmarking with mock dialer
func benchmarkClient(poolSize int) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: "localhost:6379", // Mock address
DialTimeout: time.Second,
ReadTimeout: time.Second,
WriteTimeout: time.Second,
PoolSize: poolSize,
MinIdleConns: 0, // Start with no idle connections for consistent benchmarks
})
}
// BenchmarkPubSubCreation benchmarks PubSub creation and subscription
func BenchmarkPubSubCreation(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 4, 8, 16, 32}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}
}
// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription
func BenchmarkPubSubPatternCreation(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 4, 8, 16, 32}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.PSubscribe(ctx, "test-*")
pubsub.Close()
}
})
}
}
// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation
func BenchmarkPubSubConcurrentCreation(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
concurrencyLevels := []int{1, 2, 4, 8, 16}
for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
var wg sync.WaitGroup
semaphore := make(chan struct{}, concurrency)
for i := 0; i < b.N; i++ {
wg.Add(1)
semaphore <- struct{}{}
go func() {
defer wg.Done()
defer func() { <-semaphore }()
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}()
}
wg.Wait()
})
}
}
// BenchmarkPubSubMultipleChannels benchmarks subscribing to multiple channels
func BenchmarkPubSubMultipleChannels(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()
channelCounts := []int{1, 5, 10, 25, 50, 100}
for _, channelCount := range channelCounts {
b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) {
// Prepare channel names
channels := make([]string, channelCount)
for i := 0; i < channelCount; i++ {
channels[i] = fmt.Sprintf("channel-%d", i)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, channels...)
pubsub.Close()
}
})
}
}
// BenchmarkPubSubReuse benchmarks reusing PubSub connections
func BenchmarkPubSubReuse(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Benchmark just the creation and closing of PubSub connections
// This simulates reuse patterns without requiring actual Redis operations
pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i))
pubsub.Close()
}
}
// =============================================================================
// COMBINED BENCHMARKS
// =============================================================================
// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations
func BenchmarkPoolAndPubSubMixed(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// Mix of pool stats collection and PubSub creation
if pb.Next() {
// Pool stats operation
stats := client.PoolStats()
_ = stats.Hits + stats.Misses // Use the stats to prevent optimization
}
if pb.Next() {
// PubSub operation
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
}
})
}
// BenchmarkPoolStatsCollection benchmarks pool statistics collection
func BenchmarkPoolStatsCollection(b *testing.B) {
client := benchmarkClient(16)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
stats := client.PoolStats()
_ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization
}
}
// BenchmarkPoolHighContention tests pool performance under high contention
func BenchmarkPoolHighContention(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// High contention Get/Put operations
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}

View File

@@ -22,7 +22,7 @@ import (
type PubSub struct {
opt *Options
newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
newConn func(ctx context.Context, addr string, channels []string) (*pool.Conn, error)
closeConn func(*pool.Conn) error
mu sync.Mutex
@@ -42,6 +42,9 @@ type PubSub struct {
// Push notification processor for handling generic push notifications
pushProcessor push.NotificationProcessor
// Cleanup callback for hitless upgrade tracking
onClose func()
}
func (c *PubSub) init() {
@@ -73,10 +76,18 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er
return c.cn, nil
}
if c.opt.Addr == "" {
// TODO(hitless):
// this is probably cluster client
// c.newConn will ignore the addr argument
// will be changed when we have hitless upgrades for cluster clients
c.opt.Addr = internal.RedisNull
}
channels := mapKeys(c.channels)
channels = append(channels, newChannels...)
cn, err := c.newConn(ctx, channels)
cn, err := c.newConn(ctx, c.opt.Addr, channels)
if err != nil {
return nil, err
}
@@ -157,12 +168,28 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo
if c.cn != cn {
return
}
if !cn.IsUsable() || cn.ShouldHandoff() {
c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable"))
}
if isBadConn(err, allowTimeout, c.opt.Addr) {
c.reconnect(ctx, err)
}
}
func (c *PubSub) reconnect(ctx context.Context, reason error) {
if c.cn != nil && c.cn.ShouldHandoff() {
newEndpoint := c.cn.GetHandoffEndpoint()
// If new endpoint is NULL, use the original address
if newEndpoint == internal.RedisNull {
newEndpoint = c.opt.Addr
}
if newEndpoint != "" {
c.opt.Addr = newEndpoint
}
}
_ = c.closeTheCn(reason)
_, _ = c.conn(ctx, nil)
}
@@ -189,6 +216,11 @@ func (c *PubSub) Close() error {
c.closed = true
close(c.exit)
// Call cleanup callback if set
if c.onClose != nil {
c.onClose()
}
return c.closeTheCn(pool.ErrClosed)
}
@@ -461,6 +493,7 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
// Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
// This will block until a message is received.
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
return c.ReceiveTimeout(ctx, 0)
}
@@ -543,7 +576,8 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac
}
func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
if c.pushProcessor == nil {
// Only process push notifications for RESP3 connections with a processor
if c.opt.Protocol != 3 || c.pushProcessor == nil {
return nil
}

View File

@@ -1,8 +1,6 @@
package push
import (
"github.com/redis/go-redis/v9/internal/pool"
)
// No imports needed for this file
// NotificationHandlerContext provides context information about where a push notification was received.
// This struct allows handlers to make informed decisions based on the source of the notification
@@ -35,7 +33,12 @@ type NotificationHandlerContext struct {
PubSub interface{}
// Conn is the specific connection on which the notification was received.
Conn *pool.Conn
// It is an interface both to allow for future expansion and to avoid
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *pool.Conn
// - *connectionAdapter (for hitless upgrades)
Conn interface{}
// IsBlocking indicates if the notification was received on a blocking connection.
IsBlocking bool
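A hedged sketch of a handler that type-asserts the looser Conn field, assuming it lives elsewhere inside the module (where internal/pool is importable); the handler type and notification name are illustrative:
type movingHandler struct{}
func (h *movingHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
	// The concrete type depends on where the notification was received.
	if cn, ok := handlerCtx.Conn.(*pool.Conn); ok {
		_ = cn // raw pool connection
	}
	return nil
}
// Registration uses the existing processor API, e.g.:
// processor.RegisterHandler("MOVING", &movingHandler{}, false)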

315
push/processor_unit_test.go Normal file
View File

@@ -0,0 +1,315 @@
package push
import (
"context"
"testing"
)
// TestProcessorCreation tests processor creation and initialization
func TestProcessorCreation(t *testing.T) {
t.Run("NewProcessor", func(t *testing.T) {
processor := NewProcessor()
if processor == nil {
t.Fatal("NewProcessor should not return nil")
}
if processor.registry == nil {
t.Error("Processor should have a registry")
}
})
t.Run("NewVoidProcessor", func(t *testing.T) {
voidProcessor := NewVoidProcessor()
if voidProcessor == nil {
t.Fatal("NewVoidProcessor should not return nil")
}
})
}
// TestProcessorHandlerManagement tests handler registration and retrieval
func TestProcessorHandlerManagement(t *testing.T) {
processor := NewProcessor()
handler := &UnitTestHandler{name: "test-handler"}
t.Run("RegisterHandler", func(t *testing.T) {
err := processor.RegisterHandler("TEST", handler, false)
if err != nil {
t.Errorf("RegisterHandler should not error: %v", err)
}
// Verify handler is registered
retrievedHandler := processor.GetHandler("TEST")
if retrievedHandler != handler {
t.Error("GetHandler should return the registered handler")
}
})
t.Run("RegisterProtectedHandler", func(t *testing.T) {
protectedHandler := &UnitTestHandler{name: "protected-handler"}
err := processor.RegisterHandler("PROTECTED", protectedHandler, true)
if err != nil {
t.Errorf("RegisterHandler should not error for protected handler: %v", err)
}
// Verify handler is registered
retrievedHandler := processor.GetHandler("PROTECTED")
if retrievedHandler != protectedHandler {
t.Error("GetHandler should return the protected handler")
}
})
t.Run("GetNonExistentHandler", func(t *testing.T) {
handler := processor.GetHandler("NONEXISTENT")
if handler != nil {
t.Error("GetHandler should return nil for non-existent handler")
}
})
t.Run("UnregisterHandler", func(t *testing.T) {
err := processor.UnregisterHandler("TEST")
if err != nil {
t.Errorf("UnregisterHandler should not error: %v", err)
}
// Verify handler is removed
retrievedHandler := processor.GetHandler("TEST")
if retrievedHandler != nil {
t.Error("GetHandler should return nil after unregistering")
}
})
t.Run("UnregisterProtectedHandler", func(t *testing.T) {
err := processor.UnregisterHandler("PROTECTED")
if err == nil {
t.Error("UnregisterHandler should error for protected handler")
}
// Verify handler is still there
retrievedHandler := processor.GetHandler("PROTECTED")
if retrievedHandler == nil {
t.Error("Protected handler should not be removed")
}
})
}
// TestVoidProcessorBehavior tests void processor behavior
func TestVoidProcessorBehavior(t *testing.T) {
voidProcessor := NewVoidProcessor()
handler := &UnitTestHandler{name: "test-handler"}
t.Run("GetHandler", func(t *testing.T) {
retrievedHandler := voidProcessor.GetHandler("ANY")
if retrievedHandler != nil {
t.Error("VoidProcessor GetHandler should always return nil")
}
})
t.Run("RegisterHandler", func(t *testing.T) {
err := voidProcessor.RegisterHandler("TEST", handler, false)
if err == nil {
t.Error("VoidProcessor RegisterHandler should return error")
}
// Check error type
if !IsVoidProcessorError(err) {
t.Error("Error should be a VoidProcessorError")
}
})
t.Run("UnregisterHandler", func(t *testing.T) {
err := voidProcessor.UnregisterHandler("TEST")
if err == nil {
t.Error("VoidProcessor UnregisterHandler should return error")
}
// Check error type
if !IsVoidProcessorError(err) {
t.Error("Error should be a VoidProcessorError")
}
})
}
// TestProcessPendingNotificationsNilReader tests handling of nil reader
func TestProcessPendingNotificationsNilReader(t *testing.T) {
t.Run("ProcessorWithNilReader", func(t *testing.T) {
processor := NewProcessor()
ctx := context.Background()
handlerCtx := NotificationHandlerContext{}
err := processor.ProcessPendingNotifications(ctx, handlerCtx, nil)
if err != nil {
t.Errorf("ProcessPendingNotifications should not error with nil reader: %v", err)
}
})
t.Run("VoidProcessorWithNilReader", func(t *testing.T) {
voidProcessor := NewVoidProcessor()
ctx := context.Background()
handlerCtx := NotificationHandlerContext{}
err := voidProcessor.ProcessPendingNotifications(ctx, handlerCtx, nil)
if err != nil {
t.Errorf("VoidProcessor ProcessPendingNotifications should not error with nil reader: %v", err)
}
})
}
// TestWillHandleNotificationInClient tests the notification filtering logic
func TestWillHandleNotificationInClient(t *testing.T) {
testCases := []struct {
name string
notificationType string
shouldHandle bool
}{
// Pub/Sub notifications (should be handled in client)
{"message", "message", true},
{"pmessage", "pmessage", true},
{"subscribe", "subscribe", true},
{"unsubscribe", "unsubscribe", true},
{"psubscribe", "psubscribe", true},
{"punsubscribe", "punsubscribe", true},
{"smessage", "smessage", true},
{"ssubscribe", "ssubscribe", true},
{"sunsubscribe", "sunsubscribe", true},
// Push notifications (should be handled by processor)
{"MOVING", "MOVING", false},
{"MIGRATING", "MIGRATING", false},
{"MIGRATED", "MIGRATED", false},
{"FAILING_OVER", "FAILING_OVER", false},
{"FAILED_OVER", "FAILED_OVER", false},
{"custom", "custom", false},
{"unknown", "unknown", false},
{"empty", "", false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := willHandleNotificationInClient(tc.notificationType)
if result != tc.shouldHandle {
t.Errorf("willHandleNotificationInClient(%q) = %v, want %v", tc.notificationType, result, tc.shouldHandle)
}
})
}
}
// TestProcessorErrorHandlingUnit tests error handling scenarios
func TestProcessorErrorHandlingUnit(t *testing.T) {
processor := NewProcessor()
t.Run("RegisterNilHandler", func(t *testing.T) {
err := processor.RegisterHandler("TEST", nil, false)
if err == nil {
t.Error("RegisterHandler should error with nil handler")
}
// Check error type
if !IsHandlerNilError(err) {
t.Error("Error should be a HandlerNilError")
}
})
t.Run("RegisterDuplicateHandler", func(t *testing.T) {
handler1 := &UnitTestHandler{name: "handler1"}
handler2 := &UnitTestHandler{name: "handler2"}
// Register first handler
err := processor.RegisterHandler("DUPLICATE", handler1, false)
if err != nil {
t.Errorf("First RegisterHandler should not error: %v", err)
}
// Try to register second handler with same name
err = processor.RegisterHandler("DUPLICATE", handler2, false)
if err == nil {
t.Error("RegisterHandler should error when registering duplicate handler")
}
// Verify original handler is still there
retrievedHandler := processor.GetHandler("DUPLICATE")
if retrievedHandler != handler1 {
t.Error("Original handler should remain after failed duplicate registration")
}
})
t.Run("UnregisterNonExistentHandler", func(t *testing.T) {
err := processor.UnregisterHandler("NONEXISTENT")
if err != nil {
t.Errorf("UnregisterHandler should not error for non-existent handler: %v", err)
}
})
}
// TestProcessorConcurrentAccess tests concurrent access to processor
func TestProcessorConcurrentAccess(t *testing.T) {
processor := NewProcessor()
t.Run("ConcurrentRegisterAndGet", func(t *testing.T) {
done := make(chan bool, 2)
// Goroutine 1: Register handlers
go func() {
defer func() { done <- true }()
for i := 0; i < 100; i++ {
handler := &UnitTestHandler{name: "concurrent-handler"}
processor.RegisterHandler("CONCURRENT", handler, false)
processor.UnregisterHandler("CONCURRENT")
}
}()
// Goroutine 2: Get handlers
go func() {
defer func() { done <- true }()
for i := 0; i < 100; i++ {
processor.GetHandler("CONCURRENT")
}
}()
// Wait for both goroutines to complete
<-done
<-done
})
}
// TestProcessorInterfaceCompliance tests interface compliance
func TestProcessorInterfaceCompliance(t *testing.T) {
t.Run("ProcessorImplementsInterface", func(t *testing.T) {
var _ NotificationProcessor = (*Processor)(nil)
})
t.Run("VoidProcessorImplementsInterface", func(t *testing.T) {
var _ NotificationProcessor = (*VoidProcessor)(nil)
})
}
// UnitTestHandler is a test implementation of NotificationHandler
type UnitTestHandler struct {
name string
lastNotification []interface{}
errorToReturn error
callCount int
}
func (h *UnitTestHandler) HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error {
h.callCount++
h.lastNotification = notification
return h.errorToReturn
}
// Helper methods for UnitTestHandler
func (h *UnitTestHandler) GetCallCount() int {
return h.callCount
}
func (h *UnitTestHandler) GetLastNotification() []interface{} {
return h.lastNotification
}
func (h *UnitTestHandler) SetErrorToReturn(err error) {
h.errorToReturn = err
}
func (h *UnitTestHandler) Reset() {
h.callCount = 0
h.lastNotification = nil
h.errorToReturn = nil
}

View File

@@ -4,24 +4,6 @@ import (
"github.com/redis/go-redis/v9/push"
)
// Push notification constants for cluster operations
const (
// MOVING indicates a slot is being moved to a different node
PushNotificationMoving = "MOVING"
// MIGRATING indicates a slot is being migrated from this node
PushNotificationMigrating = "MIGRATING"
// MIGRATED indicates a slot has been migrated to this node
PushNotificationMigrated = "MIGRATED"
// FAILING_OVER indicates a failover is starting
PushNotificationFailingOver = "FAILING_OVER"
// FAILED_OVER indicates a failover has completed
PushNotificationFailedOver = "FAILED_OVER"
)
// NewPushNotificationProcessor creates a new push notification processor
// This processor maintains a registry of handlers and processes push notifications
// It is used for RESP3 connections where push notifications are available

211
redis.go
View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/hscan"
"github.com/redis/go-redis/v9/internal/pool"
@@ -204,19 +205,35 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
//------------------------------------------------------------------------------
type baseClient struct {
opt *Options
connPool pool.Pooler
opt *Options
optLock sync.RWMutex
connPool pool.Pooler
pubSubPool *pool.PubSubPool
hooksMixin
onClose func() error // hook called when client is closed
// Push notification processing
pushProcessor push.NotificationProcessor
// Hitless upgrade manager
hitlessManager *hitless.HitlessManager
hitlessManagerLock sync.RWMutex
}
func (c *baseClient) clone() *baseClient {
clone := *c
return &clone
c.hitlessManagerLock.RLock()
hitlessManager := c.hitlessManager
c.hitlessManagerLock.RUnlock()
clone := &baseClient{
opt: c.opt,
connPool: c.connPool,
onClose: c.onClose,
pushProcessor: c.pushProcessor,
hitlessManager: hitlessManager,
}
return clone
}
func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
@@ -234,21 +251,6 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}
func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
cn, err := c.connPool.NewConn(ctx)
if err != nil {
return nil, err
}
err = c.initConn(ctx, cn)
if err != nil {
_ = c.connPool.CloseConn(cn)
return nil, err
}
return cn, nil
}
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
if c.opt.Limiter != nil {
err := c.opt.Limiter.Allow()
@@ -274,7 +276,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return nil, err
}
if cn.Inited {
if cn.IsInited() {
return cn, nil
}
@@ -356,12 +358,10 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
}
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
if !cn.Inited.CompareAndSwap(false, true) {
return nil
}
var err error
cn.Inited = true
connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool, &c.hooksMixin)
@@ -430,6 +430,50 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return fmt.Errorf("failed to initialize connection options: %w", err)
}
// Enable maintenance notifications if hitless upgrades are configured
c.optLock.RLock()
hitlessEnabled := c.opt.HitlessUpgradeConfig != nil && c.opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled
protocol := c.opt.Protocol
endpointType := c.opt.HitlessUpgradeConfig.EndpointType
c.optLock.RUnlock()
var hitlessHandshakeErr error
if hitlessEnabled && protocol == 3 {
hitlessHandshakeErr = conn.ClientMaintNotifications(
ctx,
true,
endpointType.String(),
).Err()
if hitlessHandshakeErr != nil {
if !isRedisError(hitlessHandshakeErr) {
// if not redis error, fail the connection
return hitlessHandshakeErr
}
c.optLock.Lock()
// handshake failed - check and modify config atomically
switch c.opt.HitlessUpgradeConfig.Mode {
case hitless.MaintNotificationsEnabled:
// enabled mode, fail the connection
c.optLock.Unlock()
return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr)
default: // will handle auto and any other
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsDisabled
c.optLock.Unlock()
// auto mode, disable hitless upgrades and continue
if err := c.disableHitlessUpgrades(); err != nil {
// Log error but continue - auto mode should be resilient
internal.Logger.Printf(ctx, "hitless: failed to disable hitless upgrades in auto mode: %v", err)
}
}
} else {
// The handshake succeeded on this connection, so switch the mode to
// enabled to ensure the handshake is executed on all other connections as well.
c.optLock.Lock()
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsEnabled
c.optLock.Unlock()
}
}
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
libName := ""
libVer := Version()
@@ -446,6 +490,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
}
}
cn.SetUsable(true)
cn.Inited.Store(true)
// Set the connection initialization function for potential reconnections
cn.SetInitConnFunc(c.createInitConnFunc())
if c.opt.OnConnect != nil {
return c.opt.OnConnect(ctx, conn)
}
@@ -593,19 +643,76 @@ func (c *baseClient) context(ctx context.Context) context.Context {
return context.Background()
}
// createInitConnFunc creates a connection initialization function that can be used for reconnections.
func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error {
return func(ctx context.Context, cn *pool.Conn) error {
return c.initConn(ctx, cn)
}
}
// enableHitlessUpgrades initializes the hitless upgrade manager and pool hook.
// This function is called during client initialization.
// It registers push notification handlers for all hitless upgrade events
// and starts background workers for handoff processing in the pool hook.
func (c *baseClient) enableHitlessUpgrades() error {
// Create client adapter
clientAdapterInstance := newClientAdapter(c)
// Create hitless manager directly
manager, err := hitless.NewHitlessManager(clientAdapterInstance, c.connPool, c.opt.HitlessUpgradeConfig)
if err != nil {
return err
}
// Set the manager reference and initialize pool hook
c.hitlessManagerLock.Lock()
c.hitlessManager = manager
c.hitlessManagerLock.Unlock()
// Initialize pool hook (safe to call without lock since manager is now set)
manager.InitPoolHook(c.dialHook)
return nil
}
func (c *baseClient) disableHitlessUpgrades() error {
c.hitlessManagerLock.Lock()
defer c.hitlessManagerLock.Unlock()
// Close the hitless manager
if c.hitlessManager != nil {
// Closing the manager will also shutdown the pool hook
// and remove it from the pool
c.hitlessManager.Close()
c.hitlessManager = nil
}
return nil
}
// Close closes the client, releasing any open resources.
//
// It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
var firstErr error
// Close hitless manager first
if err := c.disableHitlessUpgrades(); err != nil {
firstErr = err
}
if c.onClose != nil {
if err := c.onClose(); err != nil {
if err := c.onClose(); err != nil && firstErr == nil {
firstErr = err
}
}
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
if c.connPool != nil {
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if c.pubSubPool != nil {
if err := c.pubSubPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
@@ -810,11 +917,24 @@ func NewClient(opt *Options) *Client {
// Initialize push notification processor using shared helper
// Use void processor for RESP2 connections (push notifications not available)
c.pushProcessor = initializePushProcessor(opt)
// Update options with the initialized push processor for connection pool
// Update options with the initialized push processor
opt.PushNotificationProcessor = c.pushProcessor
// Create connection pools
c.connPool = newConnPool(opt, c.dialHook)
c.pubSubPool = newPubSubPool(opt, c.dialHook)
// Initialize hitless upgrades first if enabled and protocol is RESP3
if opt.HitlessUpgradeConfig != nil && opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled && opt.Protocol == 3 {
err := c.enableHitlessUpgrades()
if err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
if opt.HitlessUpgradeConfig.Mode == hitless.MaintNotificationsEnabled {
// panic so we fail fast without breaking the existing client API
panic(fmt.Errorf("failed to enable hitless upgrades: %w", err))
}
}
}
return &c
}
@@ -851,6 +971,14 @@ func (c *Client) Options() *Options {
return c.opt
}
// GetHitlessManager returns the hitless manager instance for monitoring and control.
// Returns nil if hitless upgrades are not enabled.
func (c *Client) GetHitlessManager() *hitless.HitlessManager {
c.hitlessManagerLock.RLock()
defer c.hitlessManagerLock.RUnlock()
return c.hitlessManager
}
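Callers can use the new accessor defensively; a short hedged sketch (the manager's monitoring API lives in the hitless package and is not shown in this diff):
if mgr := rdb.GetHitlessManager(); mgr != nil {
	// Hitless upgrades are active on this client; mgr can be handed to
	// monitoring code that understands *hitless.HitlessManager.
}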
// initializePushProcessor initializes the push notification processor for any client type.
// This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient.
func initializePushProcessor(opt *Options) push.NotificationProcessor {
@@ -887,6 +1015,7 @@ type PoolStats pool.Stats
// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
stats := c.connPool.Stats()
stats.PubSubStats = *(c.pubSubPool.Stats())
return (*PoolStats)(stats)
}
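A small sketch of reading the extended stats; Unusable and PubSubStats are the fields added in this change, the formatting is illustrative:
stats := rdb.PoolStats()
fmt.Printf("hits=%d misses=%d unusable=%d pubsub_created=%d pubsub_active=%d\n",
	stats.Hits, stats.Misses, stats.Unusable,
	stats.PubSubStats.Created, stats.PubSubStats.Active)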
@@ -921,11 +1050,27 @@ func (c *Client) TxPipeline() Pipeliner {
func (c *Client) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
closeConn: c.connPool.CloseConn,
pushProcessor: c.pushProcessor,
}
pubsub.init()
@@ -1113,6 +1258,6 @@ func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.Notifica
return push.NotificationHandlerContext{
Client: c,
ConnPool: c.connPool,
Conn: cn,
Conn: &connectionAdapter{conn: cn}, // Wrap in adapter for easier interface access
}
}

View File

@@ -12,7 +12,6 @@ import (
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/auth"
)

View File

@@ -16,8 +16,8 @@ import (
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/rand"
"github.com/redis/go-redis/v9/push"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/push"
)
//------------------------------------------------------------------------------
@@ -139,6 +139,14 @@ type FailoverOptions struct {
FailingTimeoutSeconds int
UnstableResp3 bool
// Hitless is not supported for FailoverClients at the moment
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// upgrade notifications gracefully and manage connection/pool state transitions
// seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are disabled.
//HitlessUpgradeConfig *HitlessUpgradeConfig
}
func (opt *FailoverOptions) clientOptions() *Options {
@@ -456,8 +464,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
opt.Dialer = masterReplicaDialer(failover)
opt.init()
var connPool *pool.ConnPool
rdb := &Client{
baseClient: &baseClient{
opt: opt,
@@ -469,15 +475,18 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
// Use void processor by default for RESP2 connections
rdb.pushProcessor = initializePushProcessor(opt)
connPool = newConnPool(opt, rdb.dialHook)
rdb.connPool = connPool
rdb.connPool = newConnPool(opt, rdb.dialHook)
rdb.pubSubPool = newPubSubPool(opt, rdb.dialHook)
rdb.onClose = rdb.wrappedOnClose(failover.Close)
failover.mu.Lock()
failover.onFailover = func(ctx context.Context, addr string) {
_ = connPool.Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
})
if connPool, ok := rdb.connPool.(*pool.ConnPool); ok {
_ = connPool.Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
})
}
}
failover.mu.Unlock()
@@ -544,6 +553,7 @@ func NewSentinelClient(opt *Options) *SentinelClient {
process: c.baseClient.process,
})
c.connPool = newConnPool(opt, c.dialHook)
c.pubSubPool = newPubSubPool(opt, c.dialHook)
return c
}
@@ -570,13 +580,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
func (c *SentinelClient) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: c.connPool.CloseConn,
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
pushProcessor: c.pushProcessor,
}
pubsub.init()
return pubsub
}

2
tx.go
View File

@@ -24,7 +24,7 @@ type Tx struct {
func (c *Client) newTx() *Tx {
tx := Tx{
baseClient: baseClient{
opt: c.opt,
opt: c.opt.clone(), // Clone options to avoid sharing HitlessUpgradeConfig
connPool: pool.NewStickyConnPool(c.connPool),
hooksMixin: c.hooksMixin.clone(),
pushProcessor: c.pushProcessor, // Copy push processor from parent client

View File

@@ -122,6 +122,9 @@ type UniversalOptions struct {
// IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint).
IsClusterMode bool
// HitlessUpgradeConfig provides configuration for hitless upgrades.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
// Cluster returns cluster options created from the universal options.
@@ -177,6 +180,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions {
IdentitySuffix: o.IdentitySuffix,
FailingTimeoutSeconds: o.FailingTimeoutSeconds,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
}
}
@@ -237,6 +241,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions {
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
// Note: HitlessUpgradeConfig not supported for FailoverOptions
}
}
@@ -284,10 +289,11 @@ func (o *UniversalOptions) Simple() *Options {
TLSConfig: o.TLSConfig,
DisableIdentity: o.DisableIdentity,
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
DisableIdentity: o.DisableIdentity,
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
}
}