
[CAE-1072] Hitless Upgrades (#3447)

* feat(hitless): Introduce handlers for hitless upgrades

This commit contains all of the work on hitless upgrades, including:

- Pubsub pool
- Examples
- Refactor of push
- Refactor of the pool (using atomics for most state)
- Introduction of hooks in the pool


---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Nedyalko Dyakov
2025-09-03 14:49:16 +03:00
committed by GitHub
parent 36f9f58c67
commit cb3af0800e
56 changed files with 8062 additions and 286 deletions

.gitignore (vendored, +3)

@@ -9,3 +9,6 @@ coverage.txt
**/coverage.txt
.vscode
tmp/*
# Hitless upgrade documentation (temporary)
hitless/docs/

adapters.go (new file, +111)

@@ -0,0 +1,111 @@
package redis
import (
"context"
"errors"
"net"
"time"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/push"
)
// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand.
var ErrInvalidCommand = errors.New("invalid command type")
// ErrInvalidPool is returned when the pool type is not supported.
var ErrInvalidPool = errors.New("invalid pool type")
// newClientAdapter creates a new client adapter for regular Redis clients.
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
return &clientAdapter{client: client}
}
// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
type clientAdapter struct {
client *baseClient
}
// GetOptions returns the client options.
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
return &optionsAdapter{options: ca.client.opt}
}
// GetPushProcessor returns the client's push notification processor.
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
return &pushProcessorAdapter{processor: ca.client.pushProcessor}
}
// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
type optionsAdapter struct {
options *Options
}
// GetReadTimeout returns the read timeout.
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
return oa.options.ReadTimeout
}
// GetWriteTimeout returns the write timeout.
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
return oa.options.WriteTimeout
}
// GetNetwork returns the network type.
func (oa *optionsAdapter) GetNetwork() string {
return oa.options.Network
}
// GetAddr returns the connection address.
func (oa *optionsAdapter) GetAddr() string {
return oa.options.Addr
}
// IsTLSEnabled returns true if TLS is enabled.
func (oa *optionsAdapter) IsTLSEnabled() bool {
return oa.options.TLSConfig != nil
}
// GetProtocol returns the protocol version.
func (oa *optionsAdapter) GetProtocol() int {
return oa.options.Protocol
}
// GetPoolSize returns the connection pool size.
func (oa *optionsAdapter) GetPoolSize() int {
return oa.options.PoolSize
}
// NewDialer returns a new dialer function for the connection.
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
baseDialer := oa.options.NewDialer()
return func(ctx context.Context) (net.Conn, error) {
// Extract network and address from the options
network := oa.options.Network
addr := oa.options.Addr
return baseDialer(ctx, network, addr)
}
}
// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
type pushProcessorAdapter struct {
processor push.NotificationProcessor
}
// RegisterHandler registers a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
if pushHandler, ok := handler.(push.NotificationHandler); ok {
return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
}
return errors.New("handler must implement push.NotificationHandler")
}
// UnregisterHandler removes a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
return ppa.processor.UnregisterHandler(pushNotificationName)
}
// GetHandler returns the handler for a specific push notification name.
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
return ppa.processor.GetHandler(pushNotificationName)
}


@@ -0,0 +1,353 @@
package redis
import (
"context"
"net"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/logging"
)
// mockNetConn implements net.Conn for testing
type mockNetConn struct {
addr string
}
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (m *mockNetConn) Close() error { return nil }
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
type mockAddr struct {
addr string
}
func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string { return m.addr }
// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow
func TestEventDrivenHandoffIntegration(t *testing.T) {
t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) {
// Create a base dialer for testing
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
// Create processor with event-driven handoff support
processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create a test pool with hooks
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(5),
PoolTimeout: time.Second,
})
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
defer testPool.Close()
// Set the pool reference in the processor for connection removal on handoff failure
processor.SetPool(testPool)
ctx := context.Background()
// Get a connection and mark it for handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
// Set initialization function with a small delay to ensure handoff is pending
initConnCalled := false
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
initConnCalled = true
return nil
}
conn.SetInitConnFunc(initConnFunc)
// Mark connection for handoff
err = conn.MarkForHandoff("new-endpoint:6379", 12345)
if err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Return connection to pool - this should queue handoff
testPool.Put(ctx, conn)
// Give the on-demand worker a moment to start processing
time.Sleep(10 * time.Millisecond)
// Verify handoff was queued
if !processor.IsHandoffPending(conn) {
t.Error("Handoff should be queued in pending map")
}
// Try to get the same connection - should be skipped due to pending handoff
conn2, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get second connection: %v", err)
}
// Should get a different connection (the pending one should be skipped)
if conn == conn2 {
t.Error("Should have gotten a different connection while handoff is pending")
}
// Return the second connection
testPool.Put(ctx, conn2)
// Wait for handoff to complete
time.Sleep(200 * time.Millisecond)
// Verify handoff completed (removed from pending map)
if processor.IsHandoffPending(conn) {
t.Error("Handoff should have completed and been removed from pending map")
}
if !initConnCalled {
t.Error("InitConn should have been called during handoff")
}
// Now the original connection should be available again
conn3, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get third connection: %v", err)
}
// Could be the original connection (now handed off) or a new one
testPool.Put(ctx, conn3)
})
t.Run("ConcurrentHandoffs", func(t *testing.T) {
// Create a base dialer that simulates slow handoffs
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
time.Sleep(50 * time.Millisecond) // Simulate network delay
return &mockNetConn{addr: addr}, nil
}
processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(10),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
var wg sync.WaitGroup
// Start multiple concurrent handoffs
for i := 0; i < 5; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
// Get connection
conn, err := testPool.Get(ctx)
if err != nil {
t.Errorf("Failed to get conn[%d]: %v", id, err)
return
}
// Set initialization function
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
return nil
}
conn.SetInitConnFunc(initConnFunc)
// Mark for handoff
conn.MarkForHandoff("new-endpoint:6379", int64(id))
// Return to pool (starts async handoff)
testPool.Put(ctx, conn)
}(i)
}
wg.Wait()
// Wait for all handoffs to complete
time.Sleep(300 * time.Millisecond)
// Verify pool is still functional
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err)
}
testPool.Put(ctx, conn)
})
t.Run("HandoffFailureRecovery", func(t *testing.T) {
// Create a failing base dialer
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}}
}
processor := hitless.NewPoolHook(failingDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(3),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
// Get connection and mark for handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
conn.MarkForHandoff("unreachable-endpoint:6379", 12345)
// Return to pool (starts async handoff that will fail)
testPool.Put(ctx, conn)
// Wait for handoff to fail
time.Sleep(200 * time.Millisecond)
// Connection should be removed from pending map after failed handoff
if processor.IsHandoffPending(conn) {
t.Error("Connection should be removed from pending map after failed handoff")
}
// Pool should still be functional
conn2, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Pool should still be functional: %v", err)
}
// In event-driven approach, the original connection remains in pool
// even after failed handoff (it's still a valid connection)
// We might get the same connection or a different one
testPool.Put(ctx, conn2)
})
t.Run("GracefulShutdown", func(t *testing.T) {
// Create a slow base dialer
slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
time.Sleep(100 * time.Millisecond)
return &mockNetConn{addr: addr}, nil
}
processor := hitless.NewPoolHook(slowDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(2),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
// Start a handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function with delay to ensure handoff is pending
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
return nil
})
testPool.Put(ctx, conn)
// Give the on-demand worker a moment to start and begin processing
// The handoff should be pending because the slowDialer takes 100ms
time.Sleep(10 * time.Millisecond)
// Verify handoff was queued and is being processed
if !processor.IsHandoffPending(conn) {
t.Error("Handoff should be queued in pending map")
}
// Give the handoff a moment to start processing
time.Sleep(50 * time.Millisecond)
// Shutdown processor gracefully
// Use a longer timeout to account for slow dialer (100ms) plus processing overhead
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err = processor.Shutdown(shutdownCtx)
if err != nil {
t.Errorf("Graceful shutdown should succeed: %v", err)
}
// Handoff should have completed (removed from pending map)
if processor.IsHandoffPending(conn) {
t.Error("Handoff should have completed and been removed from pending map after shutdown")
}
})
}
func init() {
logging.Disable()
}


@@ -193,6 +193,7 @@ type Cmdable interface {
ClientID(ctx context.Context) *IntCmd
ClientUnblock(ctx context.Context, id int64) *IntCmd
ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
ConfigResetStat(ctx context.Context) *StatusCmd
ConfigSet(ctx context.Context, parameter, value string) *StatusCmd
@@ -518,6 +519,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
return cmd
}
// ClientMaintNotifications enables or disables maintenance notifications for hitless upgrades.
// When enabled, the client will receive push notifications about Redis maintenance events.
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
args := []interface{}{"client", "maint_notifications"}
if enabled {
if endpointType == "" {
endpointType = "none"
}
args = append(args, "on", "moving-endpoint-type", endpointType)
} else {
args = append(args, "off")
}
cmd := NewStatusCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// ------------------------------------------------------------------------------------------------
func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {

example/pubsub/go.mod (new file, +12)

@@ -0,0 +1,12 @@
module github.com/redis/go-redis/example/pubsub
go 1.18
replace github.com/redis/go-redis/v9 => ../..
require github.com/redis/go-redis/v9 v9.11.0
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)

example/pubsub/go.sum (new file, +6)

@@ -0,0 +1,6 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=

example/pubsub/main.go (new file, +171)

@@ -0,0 +1,171 @@
package main
import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/logging"
)
var ctx = context.Background()
var cntErrors atomic.Int64
var cntSuccess atomic.Int64
var startTime = time.Now()
// This example is not meant to be run as-is. It is a stress test of how pubsub interacts with pool management
// and was used to find regressions in pool management in hitless mode.
// Please don't use it as a reference for how to use pubsub.
func main() {
startTime = time.Now()
wg := &sync.WaitGroup{}
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
Mode: hitless.MaintNotificationsEnabled,
},
})
_ = rdb.FlushDB(ctx).Err()
hitlessManager := rdb.GetHitlessManager()
if hitlessManager == nil {
panic("hitless manager is nil")
}
loggingHook := hitless.NewLoggingHook(logging.LogLevelDebug)
hitlessManager.AddNotificationHook(loggingHook)
go func() {
for {
time.Sleep(2 * time.Second)
fmt.Printf("pool stats: %+v\n", rdb.PoolStats())
}
}()
err := rdb.Ping(ctx).Err()
if err != nil {
panic(err)
}
if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil {
panic(err)
}
fmt.Println("published", rdb.Get(ctx, "published").Val())
fmt.Println("received", rdb.Get(ctx, "received").Val())
subCtx, cancelSubCtx := context.WithCancel(ctx)
pubCtx, cancelPublishers := context.WithCancel(ctx)
for i := 0; i < 10; i++ {
wg.Add(1)
go subscribe(subCtx, rdb, "test", i, wg)
}
time.Sleep(time.Second)
cancelSubCtx()
time.Sleep(time.Second)
subCtx, cancelSubCtx = context.WithCancel(ctx)
for i := 0; i < 10; i++ {
if err := rdb.Incr(ctx, "publishers").Err(); err != nil {
fmt.Println("incr error:", err)
cntErrors.Add(1)
}
wg.Add(1)
go floodThePool(pubCtx, rdb, wg)
}
for i := 0; i < 500; i++ {
if err := rdb.Incr(ctx, "subscribers").Err(); err != nil {
fmt.Println("incr error:", err)
cntErrors.Add(1)
}
wg.Add(1)
go subscribe(subCtx, rdb, "test2", i, wg)
}
time.Sleep(120 * time.Second)
fmt.Println("canceling publishers")
cancelPublishers()
time.Sleep(10 * time.Second)
fmt.Println("canceling subscribers")
cancelSubCtx()
wg.Wait()
published, err := rdb.Get(ctx, "published").Result()
received, err := rdb.Get(ctx, "received").Result()
publishers, err := rdb.Get(ctx, "publishers").Result()
subscribers, err := rdb.Get(ctx, "subscribers").Result()
fmt.Printf("publishers: %s\n", publishers)
fmt.Printf("published: %s\n", published)
fmt.Printf("subscribers: %s\n", subscribers)
fmt.Printf("received: %s\n", received)
publishedInt, err := rdb.Get(ctx, "published").Int()
subscribersInt, err := rdb.Get(ctx, "subscribers").Int()
fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt)
time.Sleep(2 * time.Second)
fmt.Println("errors:", cntErrors.Load())
fmt.Println("success:", cntSuccess.Load())
fmt.Println("time:", time.Since(startTime))
}
func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
}
err := rdb.Publish(ctx, "test2", "hello").Err()
if err != nil {
if err.Error() != "context canceled" {
log.Println("publish error:", err)
cntErrors.Add(1)
}
}
err = rdb.Incr(ctx, "published").Err()
if err != nil {
if err.Error() != "context canceled" {
log.Println("incr error:", err)
cntErrors.Add(1)
}
}
time.Sleep(10 * time.Nanosecond)
}
}
func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) {
defer wg.Done()
rec := rdb.Subscribe(ctx, topic)
recChan := rec.Channel()
for {
select {
case <-ctx.Done():
rec.Close()
return
default:
select {
case <-ctx.Done():
rec.Close()
return
case msg := <-recChan:
err := rdb.Incr(ctx, "received").Err()
if err != nil {
if err.Error() != "context canceled" {
log.Printf("%s\n", err.Error())
cntErrors.Add(1)
}
}
_ = msg // Use the message to avoid unused variable warning
}
}
}
}


@@ -57,6 +57,8 @@ func Example_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[ping]>
}
@@ -78,6 +80,8 @@ func ExamplePipeline_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// pipeline finished processing: [[ping] [ping]]
}
@@ -99,6 +103,8 @@ func ExampleClient_Watch_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[watch foo]>
// starting processing: <[ping]>
// finished processing: <[ping]>

hitless/README.md (new file, +98)

@@ -0,0 +1,98 @@
# Hitless Upgrades
Seamless Redis connection handoffs during cluster changes without dropping connections.
## Quick Start
```go
client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Protocol: 3, // RESP3 required
HitlessUpgrades: &hitless.Config{
Mode: hitless.MaintNotificationsEnabled,
},
})
```
## Modes
- **`MaintNotificationsDisabled`** - Hitless upgrades disabled
- **`MaintNotificationsEnabled`** - Forcefully enabled (fails if the server doesn't support it)
- **`MaintNotificationsAuto`** - Auto-detect server support (default)
## Configuration
```go
&hitless.Config{
Mode: hitless.MaintNotificationsAuto,
EndpointType: hitless.EndpointTypeAuto,
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxHandoffRetries: 3,
MaxWorkers: 0, // Auto-calculated
HandoffQueueSize: 0, // Auto-calculated
PostHandoffRelaxedDuration: 0, // 2 * RelaxedTimeout
LogLevel: logging.LogLevelError,
}
```
### Endpoint Types
- **`EndpointTypeAuto`** - Auto-detect based on connection (default)
- **`EndpointTypeInternalIP`** - Internal IP address
- **`EndpointTypeInternalFQDN`** - Internal FQDN
- **`EndpointTypeExternalIP`** - External IP address
- **`EndpointTypeExternalFQDN`** - External FQDN
- **`EndpointTypeNone`** - No endpoint (reconnect with current config)
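For example, a deployment resolved through internal DNS might request internal FQDNs explicitly rather than relying on auto-detection. A minimal sketch, assuming the `HitlessUpgrades` field shown in the Quick Start above (the pubsub example in this commit wires the config through a `HitlessUpgradeConfig` option instead, so treat the field name as illustrative):
```go
package main

import (
	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	// Ask MOVING notifications to carry an internal FQDN instead of
	// letting the client auto-detect the endpoint type.
	client := redis.NewClient(&redis.Options{
		Addr:     "redis.internal:6379", // hypothetical address
		Protocol: 3,                     // RESP3 required
		HitlessUpgrades: &hitless.Config{
			Mode:         hitless.MaintNotificationsAuto,
			EndpointType: hitless.EndpointTypeInternalFQDN,
		},
	})
	defer client.Close()
}
```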
### Auto-Scaling
**Workers**: `min(PoolSize/2, max(10, PoolSize/3))` when auto-calculated
**Queue**: `max(20×Workers, PoolSize)` capped by `MaxActiveConns+1` or `5×PoolSize`
**Examples:**
- Pool 100: 33 workers, 660 queue (capped at 500)
- Pool 100 + MaxActiveConns 150: 33 workers, 151 queue
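To make the arithmetic concrete, here is a small self-contained sketch of the same rules; it only illustrates the formulas above and is not the library's internal implementation:
```go
package main

import "fmt"

// autoScale mirrors the documented scaling rules:
//   workers = min(PoolSize/2, max(10, PoolSize/3))
//   queue   = max(20*workers, PoolSize), capped by MaxActiveConns+1 (if set) or 5*PoolSize
func autoScale(poolSize, maxActiveConns int) (workers, queue int) {
	workers = poolSize / 3
	if workers < 10 {
		workers = 10
	}
	if workers > poolSize/2 {
		workers = poolSize / 2
	}

	queue = 20 * workers
	if queue < poolSize {
		queue = poolSize
	}
	limit := 5 * poolSize
	if maxActiveConns > 0 {
		limit = maxActiveConns + 1
	}
	if queue > limit {
		queue = limit
	}
	return workers, queue
}

func main() {
	fmt.Println(autoScale(100, 0))   // 33 500 (660 capped at 5*PoolSize)
	fmt.Println(autoScale(100, 150)) // 33 151 (capped at MaxActiveConns+1)
}
```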
## How It Works
1. Redis sends push notifications about cluster changes
2. Client creates new connections to updated endpoints
3. Active operations transfer to new connections
4. Old connections close gracefully
## Supported Notifications
- `MOVING` - Endpoint moving to a new node
- `MIGRATING` - Slot in migration state
- `MIGRATED` - Migration completed
- `FAILING_OVER` - Node failing over
- `FAILED_OVER` - Failover completed
## Hooks (Optional)
Monitor and customize hitless operations:
```go
type NotificationHook interface {
PreHook(ctx, notificationCtx, notificationType, notification) ([]interface{}, bool)
PostHook(ctx, notificationCtx, notificationType, notification, result)
}
// Add custom hook
manager.AddNotificationHook(&MyHook{})
```
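As a sketch of what a custom hook might look like, the snippet below counts notifications by type. The parameter types are simplified to `interface{}` placeholders; the real `NotificationHook` interface in the `hitless` package defines concrete types, so adapt the signatures to match:
```go
package hookexample

import (
	"context"
	"log"
	"sync"
)

// countingHook is a hypothetical hook that tallies notifications by type.
type countingHook struct {
	mu     sync.Mutex
	counts map[string]int
}

// PreHook records the notification type and lets processing continue unchanged.
func (h *countingHook) PreHook(ctx context.Context, notificationCtx interface{}, notificationType string, notification []interface{}) ([]interface{}, bool) {
	h.mu.Lock()
	defer h.mu.Unlock()
	if h.counts == nil {
		h.counts = make(map[string]int)
	}
	h.counts[notificationType]++
	return notification, true // true = continue normal processing
}

// PostHook logs failures after the notification has been handled.
func (h *countingHook) PostHook(ctx context.Context, notificationCtx interface{}, notificationType string, notification []interface{}, result error) {
	if result != nil {
		log.Printf("hitless: %s notification failed: %v", notificationType, result)
	}
}
```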
### Metrics Hook Example
```go
// Create metrics hook
metricsHook := hitless.NewMetricsHook()
manager.AddNotificationHook(metricsHook)
// Access collected metrics
metrics := metricsHook.GetMetrics()
fmt.Printf("Notification counts: %v\n", metrics["notification_counts"])
fmt.Printf("Processing times: %v\n", metrics["processing_times"])
fmt.Printf("Error counts: %v\n", metrics["error_counts"])
```

hitless/circuit_breaker.go (new file, +360)

@@ -0,0 +1,360 @@
package hitless
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
)
// CircuitBreakerState represents the state of a circuit breaker
type CircuitBreakerState int32
const (
// CircuitBreakerClosed - normal operation, requests allowed
CircuitBreakerClosed CircuitBreakerState = iota
// CircuitBreakerOpen - failing fast, requests rejected
CircuitBreakerOpen
// CircuitBreakerHalfOpen - testing if service recovered
CircuitBreakerHalfOpen
)
func (s CircuitBreakerState) String() string {
switch s {
case CircuitBreakerClosed:
return "closed"
case CircuitBreakerOpen:
return "open"
case CircuitBreakerHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling
type CircuitBreaker struct {
// Configuration
failureThreshold int // Number of failures before opening
resetTimeout time.Duration // How long to stay open before testing
maxRequests int // Max requests allowed in half-open state
// State tracking (atomic for lock-free access)
state atomic.Int32 // CircuitBreakerState
failures atomic.Int64 // Current failure count
successes atomic.Int64 // Success count in half-open state
requests atomic.Int64 // Request count in half-open state
lastFailureTime atomic.Int64 // Unix timestamp of last failure
lastSuccessTime atomic.Int64 // Unix timestamp of last success
// Endpoint identification
endpoint string
config *Config
}
// newCircuitBreaker creates a new circuit breaker for an endpoint
func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker {
// Use configuration values with sensible defaults
failureThreshold := 5
resetTimeout := 60 * time.Second
maxRequests := 3
if config != nil {
failureThreshold = config.CircuitBreakerFailureThreshold
resetTimeout = config.CircuitBreakerResetTimeout
maxRequests = config.CircuitBreakerMaxRequests
}
return &CircuitBreaker{
failureThreshold: failureThreshold,
resetTimeout: resetTimeout,
maxRequests: maxRequests,
endpoint: endpoint,
config: config,
state: atomic.Int32{}, // Defaults to CircuitBreakerClosed (0)
}
}
// IsOpen returns true if the circuit breaker is open (rejecting requests)
func (cb *CircuitBreaker) IsOpen() bool {
state := CircuitBreakerState(cb.state.Load())
return state == CircuitBreakerOpen
}
// shouldAttemptReset checks if enough time has passed to attempt reset
func (cb *CircuitBreaker) shouldAttemptReset() bool {
lastFailure := time.Unix(cb.lastFailureTime.Load(), 0)
return time.Since(lastFailure) >= cb.resetTimeout
}
// Execute runs the given function with circuit breaker protection
func (cb *CircuitBreaker) Execute(fn func() error) error {
// Single atomic state load for consistency
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerOpen:
if cb.shouldAttemptReset() {
// Attempt transition to half-open
if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
cb.requests.Store(0)
cb.successes.Store(0)
if cb.config != nil && cb.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker for %s transitioning to half-open", cb.endpoint)
}
// Fall through to half-open logic
} else {
return ErrCircuitBreakerOpen
}
} else {
return ErrCircuitBreakerOpen
}
fallthrough
case CircuitBreakerHalfOpen:
requests := cb.requests.Add(1)
if requests > int64(cb.maxRequests) {
cb.requests.Add(-1) // Revert the increment
return ErrCircuitBreakerOpen
}
}
// Execute the function with consistent state
err := fn()
if err != nil {
cb.recordFailure()
return err
}
cb.recordSuccess()
return nil
}
// recordFailure records a failure and potentially opens the circuit
func (cb *CircuitBreaker) recordFailure() {
cb.lastFailureTime.Store(time.Now().Unix())
failures := cb.failures.Add(1)
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerClosed:
if failures >= int64(cb.failureThreshold) {
if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) {
if cb.config != nil && cb.config.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker opened for endpoint %s after %d failures",
cb.endpoint, failures)
}
}
}
case CircuitBreakerHalfOpen:
// Any failure in half-open state immediately opens the circuit
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) {
if cb.config != nil && cb.config.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker reopened for endpoint %s due to failure in half-open state",
cb.endpoint)
}
}
}
}
// recordSuccess records a success and potentially closes the circuit
func (cb *CircuitBreaker) recordSuccess() {
cb.lastSuccessTime.Store(time.Now().Unix())
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerClosed:
// Reset failure count on success in closed state
cb.failures.Store(0)
case CircuitBreakerHalfOpen:
successes := cb.successes.Add(1)
// If we've had enough successful requests, close the circuit
if successes >= int64(cb.maxRequests) {
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
cb.failures.Store(0)
if cb.config != nil && cb.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker closed for endpoint %s after %d successful requests",
cb.endpoint, successes)
}
}
}
}
}
// GetState returns the current state of the circuit breaker
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
return CircuitBreakerState(cb.state.Load())
}
// GetStats returns current statistics for monitoring
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
return CircuitBreakerStats{
Endpoint: cb.endpoint,
State: cb.GetState(),
Failures: cb.failures.Load(),
Successes: cb.successes.Load(),
Requests: cb.requests.Load(),
LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0),
LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0),
}
}
// CircuitBreakerStats provides statistics about a circuit breaker
type CircuitBreakerStats struct {
Endpoint string
State CircuitBreakerState
Failures int64
Successes int64
Requests int64
LastFailureTime time.Time
LastSuccessTime time.Time
}
// CircuitBreakerEntry wraps a circuit breaker with access tracking
type CircuitBreakerEntry struct {
breaker *CircuitBreaker
lastAccess atomic.Int64 // Unix timestamp
created time.Time
}
// CircuitBreakerManager manages circuit breakers for multiple endpoints
type CircuitBreakerManager struct {
breakers sync.Map // map[string]*CircuitBreakerEntry
config *Config
cleanupStop chan struct{}
cleanupMu sync.Mutex
lastCleanup atomic.Int64 // Unix timestamp
}
// newCircuitBreakerManager creates a new circuit breaker manager
func newCircuitBreakerManager(config *Config) *CircuitBreakerManager {
cbm := &CircuitBreakerManager{
config: config,
cleanupStop: make(chan struct{}),
}
cbm.lastCleanup.Store(time.Now().Unix())
// Start background cleanup goroutine
go cbm.cleanupLoop()
return cbm
}
// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary
func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker {
now := time.Now().Unix()
if entry, ok := cbm.breakers.Load(endpoint); ok {
cbEntry := entry.(*CircuitBreakerEntry)
cbEntry.lastAccess.Store(now)
return cbEntry.breaker
}
// Create new circuit breaker with metadata
newBreaker := newCircuitBreaker(endpoint, cbm.config)
newEntry := &CircuitBreakerEntry{
breaker: newBreaker,
created: time.Now(),
}
newEntry.lastAccess.Store(now)
actual, _ := cbm.breakers.LoadOrStore(endpoint, newEntry)
return actual.(*CircuitBreakerEntry).breaker
}
// GetAllStats returns statistics for all circuit breakers
func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats {
var stats []CircuitBreakerStats
cbm.breakers.Range(func(key, value interface{}) bool {
entry := value.(*CircuitBreakerEntry)
stats = append(stats, entry.breaker.GetStats())
return true
})
return stats
}
// cleanupLoop runs background cleanup of unused circuit breakers
func (cbm *CircuitBreakerManager) cleanupLoop() {
ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
defer ticker.Stop()
for {
select {
case <-ticker.C:
cbm.cleanup()
case <-cbm.cleanupStop:
return
}
}
}
// cleanup removes circuit breakers that haven't been accessed recently
func (cbm *CircuitBreakerManager) cleanup() {
// Prevent concurrent cleanups
if !cbm.cleanupMu.TryLock() {
return
}
defer cbm.cleanupMu.Unlock()
now := time.Now()
cutoff := now.Add(-30 * time.Minute).Unix() // 30 minute TTL
var toDelete []string
count := 0
cbm.breakers.Range(func(key, value interface{}) bool {
endpoint := key.(string)
entry := value.(*CircuitBreakerEntry)
count++
// Remove if not accessed recently
if entry.lastAccess.Load() < cutoff {
toDelete = append(toDelete, endpoint)
}
return true
})
// Delete expired entries
for _, endpoint := range toDelete {
cbm.breakers.Delete(endpoint)
}
// Log cleanup results
if len(toDelete) > 0 && cbm.config != nil && cbm.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker cleanup removed %d/%d entries", len(toDelete), count)
}
cbm.lastCleanup.Store(now.Unix())
}
// Shutdown stops the cleanup goroutine
func (cbm *CircuitBreakerManager) Shutdown() {
close(cbm.cleanupStop)
}
// Reset resets all circuit breakers (useful for testing)
func (cbm *CircuitBreakerManager) Reset() {
cbm.breakers.Range(func(key, value interface{}) bool {
entry := value.(*CircuitBreakerEntry)
breaker := entry.breaker
breaker.state.Store(int32(CircuitBreakerClosed))
breaker.failures.Store(0)
breaker.successes.Store(0)
breaker.requests.Store(0)
breaker.lastFailureTime.Store(0)
breaker.lastSuccessTime.Store(0)
return true
})
}


@@ -0,0 +1,356 @@
package hitless
import (
"errors"
"testing"
"time"
"github.com/redis/go-redis/v9/logging"
)
func TestCircuitBreaker(t *testing.T) {
config := &Config{
LogLevel: logging.LogLevelError, // Reduce noise in tests
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
}
t.Run("InitialState", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
if cb.IsOpen() {
t.Error("Circuit breaker should start in closed state")
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("SuccessfulExecution", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
err := cb.Execute(func() error {
return nil // Success
})
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("FailureThreshold", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Fail 4 times (below threshold of 5)
for i := 0; i < 4; i++ {
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Circuit should still be closed after %d failures", i+1)
}
}
// 5th failure should open the circuit
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
}
})
t.Run("OpenCircuitFailsFast", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Now it should fail fast
err := cb.Execute(func() error {
t.Error("Function should not be called when circuit is open")
return nil
})
if err != ErrCircuitBreakerOpen {
t.Errorf("Expected ErrCircuitBreakerOpen, got %v", err)
}
})
t.Run("HalfOpenTransition", func(t *testing.T) {
testConfig := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 100 * time.Millisecond, // Short timeout for testing
CircuitBreakerMaxRequests: 3,
}
cb := newCircuitBreaker("test-endpoint:6379", testConfig)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
if cb.GetState() != CircuitBreakerOpen {
t.Error("Circuit should be open")
}
// Wait for reset timeout
time.Sleep(150 * time.Millisecond)
// Next call should transition to half-open
executed := false
err := cb.Execute(func() error {
executed = true
return nil // Success
})
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if !executed {
t.Error("Function should have been executed in half-open state")
}
})
t.Run("HalfOpenToClosedTransition", func(t *testing.T) {
testConfig := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 50 * time.Millisecond,
CircuitBreakerMaxRequests: 3,
}
cb := newCircuitBreaker("test-endpoint:6379", testConfig)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Wait for reset timeout
time.Sleep(100 * time.Millisecond)
// Execute successful requests in half-open state
for i := 0; i < 3; i++ {
err := cb.Execute(func() error {
return nil // Success
})
if err != nil {
t.Errorf("Expected no error on attempt %d, got %v", i+1, err)
}
}
// Circuit should now be closed
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("HalfOpenToOpenOnFailure", func(t *testing.T) {
testConfig := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 50 * time.Millisecond,
CircuitBreakerMaxRequests: 3,
}
cb := newCircuitBreaker("test-endpoint:6379", testConfig)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Wait for reset timeout
time.Sleep(100 * time.Millisecond)
// First request in half-open state fails
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
// Circuit should be open again
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
}
})
t.Run("Stats", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Execute some operations
cb.Execute(func() error { return testError }) // Failure
cb.Execute(func() error { return testError }) // Failure
stats := cb.GetStats()
if stats.Endpoint != "test-endpoint:6379" {
t.Errorf("Expected endpoint 'test-endpoint:6379', got %s", stats.Endpoint)
}
if stats.Failures != 2 {
t.Errorf("Expected 2 failures, got %d", stats.Failures)
}
if stats.State != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, stats.State)
}
// Test that success resets failure count
cb.Execute(func() error { return nil }) // Success
stats = cb.GetStats()
if stats.Failures != 0 {
t.Errorf("Expected 0 failures after success, got %d", stats.Failures)
}
})
}
func TestCircuitBreakerManager(t *testing.T) {
config := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
}
t.Run("GetCircuitBreaker", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
cb1 := manager.GetCircuitBreaker("endpoint1:6379")
cb2 := manager.GetCircuitBreaker("endpoint2:6379")
cb3 := manager.GetCircuitBreaker("endpoint1:6379") // Same as cb1
if cb1 == cb2 {
t.Error("Different endpoints should have different circuit breakers")
}
if cb1 != cb3 {
t.Error("Same endpoint should return the same circuit breaker")
}
})
t.Run("GetAllStats", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
// Create circuit breakers for different endpoints
cb1 := manager.GetCircuitBreaker("endpoint1:6379")
cb2 := manager.GetCircuitBreaker("endpoint2:6379")
// Execute some operations
cb1.Execute(func() error { return nil })
cb2.Execute(func() error { return errors.New("test error") })
stats := manager.GetAllStats()
if len(stats) != 2 {
t.Errorf("Expected 2 circuit breaker stats, got %d", len(stats))
}
// Check that we have stats for both endpoints
endpoints := make(map[string]bool)
for _, stat := range stats {
endpoints[stat.Endpoint] = true
}
if !endpoints["endpoint1:6379"] || !endpoints["endpoint2:6379"] {
t.Error("Missing stats for expected endpoints")
}
})
t.Run("Reset", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
testError := errors.New("test error")
cb := manager.GetCircuitBreaker("test-endpoint:6379")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
if cb.GetState() != CircuitBreakerOpen {
t.Error("Circuit should be open")
}
// Reset all circuit breakers
manager.Reset()
if cb.GetState() != CircuitBreakerClosed {
t.Error("Circuit should be closed after reset")
}
if cb.failures.Load() != 0 {
t.Error("Failure count should be reset to 0")
}
})
t.Run("ConfigurableParameters", func(t *testing.T) {
config := &Config{
LogLevel: logging.LogLevelError,
CircuitBreakerFailureThreshold: 10,
CircuitBreakerResetTimeout: 30 * time.Second,
CircuitBreakerMaxRequests: 5,
}
cb := newCircuitBreaker("test-endpoint:6379", config)
// Test that configuration values are used
if cb.failureThreshold != 10 {
t.Errorf("Expected failureThreshold=10, got %d", cb.failureThreshold)
}
if cb.resetTimeout != 30*time.Second {
t.Errorf("Expected resetTimeout=30s, got %v", cb.resetTimeout)
}
if cb.maxRequests != 5 {
t.Errorf("Expected maxRequests=5, got %d", cb.maxRequests)
}
// Test that circuit opens after configured threshold
testError := errors.New("test error")
for i := 0; i < 9; i++ {
err := cb.Execute(func() error { return testError })
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Circuit should still be closed after %d failures", i+1)
}
}
// 10th failure should open the circuit
err := cb.Execute(func() error { return testError })
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
}
})
}

hitless/config.go (new file, +472)

@@ -0,0 +1,472 @@
package hitless
import (
"context"
"net"
"runtime"
"strings"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/logging"
)
// MaintNotificationsMode represents the maintenance notifications mode
type MaintNotificationsMode string
// Constants for maintenance push notifications modes
const (
MaintNotificationsDisabled MaintNotificationsMode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
MaintNotificationsEnabled MaintNotificationsMode = "enabled" // Client forcefully sends command, interrupts connection on error
MaintNotificationsAuto MaintNotificationsMode = "auto" // Client tries to send command, disables feature on error
)
// IsValid returns true if the maintenance notifications mode is valid
func (m MaintNotificationsMode) IsValid() bool {
switch m {
case MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto:
return true
default:
return false
}
}
// String returns the string representation of the mode
func (m MaintNotificationsMode) String() string {
return string(m)
}
// EndpointType represents the type of endpoint to request in MOVING notifications
type EndpointType string
// Constants for endpoint types
const (
EndpointTypeAuto EndpointType = "auto" // Auto-detect based on connection
EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address
EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN
EndpointTypeExternalIP EndpointType = "external-ip" // External IP address
EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN
EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config)
)
// IsValid returns true if the endpoint type is valid
func (e EndpointType) IsValid() bool {
switch e {
case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone:
return true
default:
return false
}
}
// String returns the string representation of the endpoint type
func (e EndpointType) String() string {
return string(e)
}
// Config provides configuration options for hitless upgrades.
type Config struct {
// Mode controls how client maintenance notifications are handled.
// Valid values: MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto
// Default: MaintNotificationsAuto
Mode MaintNotificationsMode
// EndpointType specifies the type of endpoint to request in MOVING notifications.
// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
// EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
// Default: EndpointTypeAuto
EndpointType EndpointType
// RelaxedTimeout is the concrete timeout value to use during
// MIGRATING/FAILING_OVER states to accommodate increased latency.
// This applies to both read and write timeouts.
// Default: 10 seconds
RelaxedTimeout time.Duration
// HandoffTimeout is the maximum time to wait for connection handoff to complete.
// If handoff takes longer than this, the old connection will be forcibly closed.
// Default: 15 seconds (matches server-side eviction timeout)
HandoffTimeout time.Duration
// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
// Workers are created on-demand and automatically cleaned up when idle.
// If zero, the value is auto-calculated from the pool size to handle bursts effectively.
// If explicitly set, a minimum of PoolSize/2 is enforced.
//
// Default: min(PoolSize/2, max(10, PoolSize/3)); minimum when set: PoolSize/2
MaxWorkers int
// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
// If the queue is full, new handoff requests will be rejected.
// Scales with both worker count and pool size for better burst handling.
//
// Default: max(20×MaxWorkers, PoolSize), capped by MaxActiveConns+1 (if set) or 5×PoolSize
// When set: minimum 200, capped by MaxActiveConns+1 (if set) or 5×PoolSize
HandoffQueueSize int
// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
// after a handoff completes. This provides additional resilience during cluster transitions.
// Default: 2 * RelaxedTimeout
PostHandoffRelaxedDuration time.Duration
// LogLevel controls the verbosity of hitless upgrade logging.
// LogLevelError (0) = errors only, LogLevelWarn (1) = warnings, LogLevelInfo (2) = info, LogLevelDebug (3) = debug
// Default: logging.LogLevelError(0)
LogLevel logging.LogLevel
// Circuit breaker configuration for endpoint failure handling
// CircuitBreakerFailureThreshold is the number of failures before opening the circuit.
// Default: 5
CircuitBreakerFailureThreshold int
// CircuitBreakerResetTimeout is how long to wait before testing if the endpoint recovered.
// Default: 60 seconds
CircuitBreakerResetTimeout time.Duration
// CircuitBreakerMaxRequests is the maximum number of requests allowed in half-open state.
// Default: 3
CircuitBreakerMaxRequests int
// MaxHandoffRetries is the maximum number of times to retry a failed handoff.
// After this many retries, the connection will be removed from the pool.
// Default: 3
MaxHandoffRetries int
}
func (c *Config) IsEnabled() bool {
return c != nil && c.Mode != MaintNotificationsDisabled
}
// DefaultConfig returns a Config with sensible defaults.
func DefaultConfig() *Config {
return &Config{
Mode: MaintNotificationsAuto, // Enable by default for Redis Cloud
EndpointType: EndpointTypeAuto, // Auto-detect based on connection
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: 0, // Auto-calculated based on pool size
HandoffQueueSize: 0, // Auto-calculated based on max workers
PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
LogLevel: logging.LogLevelError,
// Circuit breaker configuration
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
// Connection Handoff Configuration
MaxHandoffRetries: 3,
}
}
// Validate checks if the configuration is valid.
func (c *Config) Validate() error {
if c.RelaxedTimeout <= 0 {
return ErrInvalidRelaxedTimeout
}
if c.HandoffTimeout <= 0 {
return ErrInvalidHandoffTimeout
}
// Validate worker configuration
// Allow 0 for auto-calculation, but negative values are invalid
if c.MaxWorkers < 0 {
return ErrInvalidHandoffWorkers
}
// HandoffQueueSize validation - allow 0 for auto-calculation
if c.HandoffQueueSize < 0 {
return ErrInvalidHandoffQueueSize
}
if c.PostHandoffRelaxedDuration < 0 {
return ErrInvalidPostHandoffRelaxedDuration
}
if !c.LogLevel.IsValid() {
return ErrInvalidLogLevel
}
// Circuit breaker validation
if c.CircuitBreakerFailureThreshold < 1 {
return ErrInvalidCircuitBreakerFailureThreshold
}
if c.CircuitBreakerResetTimeout < 0 {
return ErrInvalidCircuitBreakerResetTimeout
}
if c.CircuitBreakerMaxRequests < 1 {
return ErrInvalidCircuitBreakerMaxRequests
}
// Validate Mode (maintenance notifications mode)
if !c.Mode.IsValid() {
return ErrInvalidMaintNotifications
}
// Validate EndpointType
if !c.EndpointType.IsValid() {
return ErrInvalidEndpointType
}
// Validate configuration fields
if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
return ErrInvalidHandoffRetries
}
return nil
}
// ApplyDefaults applies default values to any zero-value fields in the configuration.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaults() *Config {
return c.ApplyDefaultsWithPoolSize(0)
}
// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
// using the provided pool size to calculate worker defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
return c.ApplyDefaultsWithPoolConfig(poolSize, 0)
}
// ApplyDefaultsWithPoolConfig applies default values to any zero-value fields in the configuration,
// using the provided pool size and max active connections to calculate worker and queue defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *Config {
if c == nil {
return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
}
defaults := DefaultConfig()
result := &Config{}
// Apply defaults for enum fields (empty/zero means not set)
result.Mode = defaults.Mode
if c.Mode != "" {
result.Mode = c.Mode
}
result.EndpointType = defaults.EndpointType
if c.EndpointType != "" {
result.EndpointType = c.EndpointType
}
// Apply defaults for duration fields (zero means not set)
result.RelaxedTimeout = defaults.RelaxedTimeout
if c.RelaxedTimeout > 0 {
result.RelaxedTimeout = c.RelaxedTimeout
}
result.HandoffTimeout = defaults.HandoffTimeout
if c.HandoffTimeout > 0 {
result.HandoffTimeout = c.HandoffTimeout
}
// Copy worker configuration
result.MaxWorkers = c.MaxWorkers
// Apply worker defaults based on pool size
result.applyWorkerDefaults(poolSize)
// Apply queue size defaults with new scaling approach
// Default: max(20x workers, PoolSize), capped by maxActiveConns or 5x pool size
workerBasedSize := result.MaxWorkers * 20
poolBasedSize := poolSize
result.HandoffQueueSize = util.Max(workerBasedSize, poolBasedSize)
if c.HandoffQueueSize > 0 {
// When explicitly set: enforce minimum of 200
result.HandoffQueueSize = util.Max(200, c.HandoffQueueSize)
}
// Cap queue size: use maxActiveConns+1 if set, otherwise 5x pool size
var queueCap int
if maxActiveConns > 0 {
queueCap = maxActiveConns + 1
// Ensure queue cap is at least 2 for very small maxActiveConns
if queueCap < 2 {
queueCap = 2
}
} else {
queueCap = poolSize * 5
}
result.HandoffQueueSize = util.Min(result.HandoffQueueSize, queueCap)
// Ensure minimum queue size of 2 (fallback for very small pools)
if result.HandoffQueueSize < 2 {
result.HandoffQueueSize = 2
}
result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
if c.PostHandoffRelaxedDuration > 0 {
result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
}
// LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set
// We'll use the provided value as-is, since 0 is valid
result.LogLevel = c.LogLevel
// Apply defaults for configuration fields
result.MaxHandoffRetries = defaults.MaxHandoffRetries
if c.MaxHandoffRetries > 0 {
result.MaxHandoffRetries = c.MaxHandoffRetries
}
// Circuit breaker configuration
result.CircuitBreakerFailureThreshold = defaults.CircuitBreakerFailureThreshold
if c.CircuitBreakerFailureThreshold > 0 {
result.CircuitBreakerFailureThreshold = c.CircuitBreakerFailureThreshold
}
result.CircuitBreakerResetTimeout = defaults.CircuitBreakerResetTimeout
if c.CircuitBreakerResetTimeout > 0 {
result.CircuitBreakerResetTimeout = c.CircuitBreakerResetTimeout
}
result.CircuitBreakerMaxRequests = defaults.CircuitBreakerMaxRequests
if c.CircuitBreakerMaxRequests > 0 {
result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests
}
if result.LogLevel.DebugOrAbove() {
internal.Logger.Printf(context.Background(), "hitless: debug logging enabled")
internal.Logger.Printf(context.Background(), "hitless: config: %+v", result)
}
return result
}
// Clone creates a deep copy of the configuration.
func (c *Config) Clone() *Config {
if c == nil {
return DefaultConfig()
}
return &Config{
Mode: c.Mode,
EndpointType: c.EndpointType,
RelaxedTimeout: c.RelaxedTimeout,
HandoffTimeout: c.HandoffTimeout,
MaxWorkers: c.MaxWorkers,
HandoffQueueSize: c.HandoffQueueSize,
PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
LogLevel: c.LogLevel,
// Circuit breaker configuration
CircuitBreakerFailureThreshold: c.CircuitBreakerFailureThreshold,
CircuitBreakerResetTimeout: c.CircuitBreakerResetTimeout,
CircuitBreakerMaxRequests: c.CircuitBreakerMaxRequests,
// Configuration fields
MaxHandoffRetries: c.MaxHandoffRetries,
}
}
// applyWorkerDefaults calculates and applies worker defaults based on pool size
func (c *Config) applyWorkerDefaults(poolSize int) {
// Calculate defaults based on pool size
if poolSize <= 0 {
poolSize = 10 * runtime.GOMAXPROCS(0)
}
// When not set: min(poolSize/2, max(10, poolSize/3)) - balanced scaling approach
originalMaxWorkers := c.MaxWorkers
c.MaxWorkers = util.Min(poolSize/2, util.Max(10, poolSize/3))
if originalMaxWorkers != 0 {
// When explicitly set: max(poolSize/2, set_value) - ensure at least poolSize/2 workers
c.MaxWorkers = util.Max(poolSize/2, originalMaxWorkers)
}
// Ensure minimum of 1 worker (fallback for very small pools)
if c.MaxWorkers < 1 {
c.MaxWorkers = 1
}
}
// DetectEndpointType automatically detects the appropriate endpoint type
// based on the connection address and TLS configuration.
//
// For IP addresses:
// - If TLS is enabled: requests FQDN for proper certificate validation
// - If TLS is disabled: requests IP for better performance
//
// For hostnames:
// - Always requests FQDN (internal or external, based on naming heuristics), regardless of TLS
//
// Internal vs External detection:
// - For IPs: uses private IP range detection
// - For hostnames: uses heuristics based on common internal naming patterns
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
// Extract host from "host:port" format
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr // Assume no port
}
// Check if the host is an IP address or hostname
ip := net.ParseIP(host)
isIPAddress := ip != nil
var endpointType EndpointType
if isIPAddress {
// Address is an IP - determine if it's private or public
isPrivate := ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
if tlsEnabled {
// TLS with IP addresses - still prefer FQDN for certificate validation
if isPrivate {
endpointType = EndpointTypeInternalFQDN
} else {
endpointType = EndpointTypeExternalFQDN
}
} else {
// No TLS - can use IP addresses directly
if isPrivate {
endpointType = EndpointTypeInternalIP
} else {
endpointType = EndpointTypeExternalIP
}
}
} else {
// Address is a hostname
isInternalHostname := isInternalHostname(host)
if isInternalHostname {
endpointType = EndpointTypeInternalFQDN
} else {
endpointType = EndpointTypeExternalFQDN
}
}
return endpointType
}
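// Illustrative results of DetectEndpointType (using the EndpointType constants above):
//
//	DetectEndpointType("10.0.0.1:6379", false)         -> EndpointTypeInternalIP
//	DetectEndpointType("10.0.0.1:6379", true)          -> EndpointTypeInternalFQDN
//	DetectEndpointType("redis.example.com:6379", true) -> EndpointTypeExternalFQDN
//	DetectEndpointType("redis-1:6379", false)          -> EndpointTypeInternalFQDN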
// isInternalHostname determines if a hostname appears to be internal/private.
// This is a heuristic based on common naming patterns.
func isInternalHostname(hostname string) bool {
// Convert to lowercase for comparison
hostname = strings.ToLower(hostname)
// Common internal hostname patterns
internalPatterns := []string{
"localhost",
".local",
".internal",
".corp",
".lan",
".intranet",
".private",
}
// Check for exact match or suffix match
for _, pattern := range internalPatterns {
if hostname == pattern || strings.HasSuffix(hostname, pattern) {
return true
}
}
// Check for short, unqualified hostnames (e.g., redis-1, db-server)
// If the hostname doesn't contain dots, it's likely internal
if !strings.Contains(hostname, ".") {
return true
}
// Default to external for fully qualified domain names
return false
}
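// Examples (illustrative): "localhost", "db.corp", and "redis-1" are treated as
// internal; "redis.example.com" is treated as external.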

490
hitless/config_test.go Normal file
View File

@@ -0,0 +1,490 @@
package hitless
import (
"context"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/logging"
)
func TestConfig(t *testing.T) {
t.Run("DefaultConfig", func(t *testing.T) {
config := DefaultConfig()
// MaxWorkers should be 0 in default config (auto-calculated)
if config.MaxWorkers != 0 {
t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers)
}
// HandoffQueueSize should be 0 in default config (auto-calculated)
if config.HandoffQueueSize != 0 {
t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize)
}
if config.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s, got %v", config.RelaxedTimeout)
}
// Test configuration fields have proper defaults
if config.MaxHandoffRetries != 3 {
t.Errorf("Expected MaxHandoffRetries to be 3, got %d", config.MaxHandoffRetries)
}
// Circuit breaker defaults
if config.CircuitBreakerFailureThreshold != 5 {
t.Errorf("Expected CircuitBreakerFailureThreshold=5, got %d", config.CircuitBreakerFailureThreshold)
}
if config.CircuitBreakerResetTimeout != 60*time.Second {
t.Errorf("Expected CircuitBreakerResetTimeout=60s, got %v", config.CircuitBreakerResetTimeout)
}
if config.CircuitBreakerMaxRequests != 3 {
t.Errorf("Expected CircuitBreakerMaxRequests=3, got %d", config.CircuitBreakerMaxRequests)
}
if config.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout)
}
if config.PostHandoffRelaxedDuration != 0 {
t.Errorf("Expected PostHandoffRelaxedDuration to be 0 (auto-calculated), got %v", config.PostHandoffRelaxedDuration)
}
// Test that defaults are applied correctly
configWithDefaults := config.ApplyDefaultsWithPoolSize(100)
if configWithDefaults.PostHandoffRelaxedDuration != 20*time.Second {
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout) after applying defaults, got %v", configWithDefaults.PostHandoffRelaxedDuration)
}
})
t.Run("ConfigValidation", func(t *testing.T) {
// Valid config with applied defaults
config := DefaultConfig().ApplyDefaults()
if err := config.Validate(); err != nil {
t.Errorf("Default config with applied defaults should be valid: %v", err)
}
// Invalid worker configuration (negative MaxWorkers)
config = &Config{
RelaxedTimeout: 30 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: -1, // This should be invalid
HandoffQueueSize: 100,
PostHandoffRelaxedDuration: 10 * time.Second,
LogLevel: 1,
MaxHandoffRetries: 3, // Add required field
}
if err := config.Validate(); err != ErrInvalidHandoffWorkers {
t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err)
}
// Invalid HandoffQueueSize
config = DefaultConfig().ApplyDefaults()
config.HandoffQueueSize = -1
if err := config.Validate(); err != ErrInvalidHandoffQueueSize {
t.Errorf("Expected ErrInvalidHandoffQueueSize, got %v", err)
}
// Invalid PostHandoffRelaxedDuration
config = DefaultConfig().ApplyDefaults()
config.PostHandoffRelaxedDuration = -1 * time.Second
if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration {
t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err)
}
})
t.Run("ConfigClone", func(t *testing.T) {
original := DefaultConfig()
original.MaxWorkers = 20
original.HandoffQueueSize = 200
cloned := original.Clone()
if cloned.MaxWorkers != 20 {
t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers)
}
if cloned.HandoffQueueSize != 200 {
t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize)
}
// Modify original to ensure clone is independent
original.MaxWorkers = 2
if cloned.MaxWorkers != 20 {
t.Error("Clone should be independent of original")
}
})
}
func TestApplyDefaults(t *testing.T) {
t.Run("NilConfig", func(t *testing.T) {
var config *Config
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// With nil config, should get default config with auto-calculated workers
if result.MaxWorkers <= 0 {
t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers)
}
// HandoffQueueSize should be auto-calculated with hybrid scaling
workerBasedSize := result.MaxWorkers * 20
poolSize := 100 // Default pool size used in ApplyDefaults
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
}
})
t.Run("PartialConfig", func(t *testing.T) {
config := &Config{
MaxWorkers: 60, // Set this field explicitly (> poolSize/2 = 50)
// Leave other fields as zero values
}
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Should keep the explicitly set values when > poolSize/2
if result.MaxWorkers != 60 {
t.Errorf("Expected MaxWorkers to be 60 (explicitly set), got %d", result.MaxWorkers)
}
// Should apply default for unset fields (auto-calculated queue size with hybrid scaling)
workerBasedSize := result.MaxWorkers * 20
poolSize := 100 // Default pool size used in ApplyDefaults
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
}
// Test explicit queue size capping by 5x pool size
configWithLargeQueue := &Config{
MaxWorkers: 5,
HandoffQueueSize: 1000, // Much larger than 5x pool size
}
resultCapped := configWithLargeQueue.ApplyDefaultsWithPoolSize(20) // Small pool size
expectedCap := 20 * 5 // 5x pool size = 100
if resultCapped.HandoffQueueSize != expectedCap {
t.Errorf("Expected HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedCap, resultCapped.HandoffQueueSize)
}
// Test explicit queue size minimum enforcement
configWithSmallQueue := &Config{
MaxWorkers: 5,
HandoffQueueSize: 10, // Below minimum of 200
}
resultMinimum := configWithSmallQueue.ApplyDefaultsWithPoolSize(100) // Large pool size
if resultMinimum.HandoffQueueSize != 200 {
t.Errorf("Expected HandoffQueueSize to be enforced minimum (200), got %d", resultMinimum.HandoffQueueSize)
}
// Test that large explicit values are capped by 5x pool size
configWithVeryLargeQueue := &Config{
MaxWorkers: 5,
HandoffQueueSize: 1000, // Much larger than 5x pool size
}
resultVeryLarge := configWithVeryLargeQueue.ApplyDefaultsWithPoolSize(100) // Pool size 100
expectedVeryLargeCap := 100 * 5 // 5x pool size = 500
if resultVeryLarge.HandoffQueueSize != expectedVeryLargeCap {
t.Errorf("Expected very large HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedVeryLargeCap, resultVeryLarge.HandoffQueueSize)
}
if result.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
}
if result.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", result.HandoffTimeout)
}
})
t.Run("ZeroValues", func(t *testing.T) {
config := &Config{
MaxWorkers: 0, // Zero value should get auto-calculated defaults
HandoffQueueSize: 0, // Zero value should get default
RelaxedTimeout: 0, // Zero value should get default
LogLevel: 0, // Zero is valid for LogLevel (errors only)
}
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Zero values should get auto-calculated defaults
if result.MaxWorkers <= 0 {
t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers)
}
// HandoffQueueSize should be auto-calculated with hybrid scaling
workerBasedSize := result.MaxWorkers * 20
poolSize := 100 // Default pool size used in ApplyDefaults
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
}
if result.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
}
// LogLevel 0 should be preserved (it's a valid value)
if result.LogLevel != 0 {
t.Errorf("Expected LogLevel to be 0 (preserved), got %d", result.LogLevel)
}
})
}
func TestProcessorWithConfig(t *testing.T) {
t.Run("ProcessorUsesConfigValues", func(t *testing.T) {
config := &Config{
MaxWorkers: 5,
HandoffQueueSize: 50,
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 5 * time.Second,
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// The processor should be created successfully with custom config
if processor == nil {
t.Error("Processor should be created with custom config")
}
})
t.Run("ProcessorWithPartialConfig", func(t *testing.T) {
config := &Config{
MaxWorkers: 7, // Only set worker field
// Other fields will get defaults
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Should work with partial config (defaults applied)
if processor == nil {
t.Error("Processor should be created with partial config")
}
})
t.Run("ProcessorWithNilConfig", func(t *testing.T) {
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Should use default config when nil is passed
if processor == nil {
t.Error("Processor should be created with nil config (using defaults)")
}
})
}
func TestIntegrationWithApplyDefaults(t *testing.T) {
t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) {
// Create a partial config with only some fields set
partialConfig := &Config{
MaxWorkers: 15, // Custom value (>= 10 to test preservation)
LogLevel: logging.LogLevelInfo, // Custom value
// Other fields left as zero values - should get defaults
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
// Create processor - should apply defaults to missing fields
processor := NewPoolHook(baseDialer, "tcp", partialConfig, nil)
defer processor.Shutdown(context.Background())
// Processor should be created successfully
if processor == nil {
t.Error("Processor should be created with partial config")
}
// Test that the ApplyDefaults method worked correctly by creating the same config
// and applying defaults manually
expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Should preserve custom values (when >= poolSize/2)
if expectedConfig.MaxWorkers != 50 { // max(poolSize/2, 15) = max(50, 15) = 50
t.Errorf("Expected MaxWorkers to be 50, got %d", expectedConfig.MaxWorkers)
}
if expectedConfig.LogLevel != 2 {
t.Errorf("Expected LogLevel to be 2, got %d", expectedConfig.LogLevel)
}
// Should apply defaults for missing fields (auto-calculated queue size with hybrid scaling)
workerBasedSize := expectedConfig.MaxWorkers * 20
poolSize := 100 // Default pool size used in ApplyDefaults
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if expectedConfig.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, expectedConfig.HandoffQueueSize)
}
// Test that queue size is always capped by 5x pool size
if expectedConfig.HandoffQueueSize > poolSize*5 {
t.Errorf("HandoffQueueSize (%d) should never exceed 5x pool size (%d)",
expectedConfig.HandoffQueueSize, poolSize*5)
}
if expectedConfig.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", expectedConfig.RelaxedTimeout)
}
if expectedConfig.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout)
}
if expectedConfig.PostHandoffRelaxedDuration != 20*time.Second {
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout), got %v", expectedConfig.PostHandoffRelaxedDuration)
}
})
}
func TestEnhancedConfigValidation(t *testing.T) {
t.Run("ValidateFields", func(t *testing.T) {
config := DefaultConfig()
config = config.ApplyDefaultsWithPoolSize(100) // Apply defaults with pool size 100
// Should pass validation with default values
if err := config.Validate(); err != nil {
t.Errorf("Default config should be valid, got error: %v", err)
}
// Test invalid MaxHandoffRetries
config.MaxHandoffRetries = 0
if err := config.Validate(); err == nil {
t.Error("Expected validation error for MaxHandoffRetries = 0")
}
config.MaxHandoffRetries = 11
if err := config.Validate(); err == nil {
t.Error("Expected validation error for MaxHandoffRetries = 11")
}
config.MaxHandoffRetries = 3 // Reset to valid value
// Test circuit breaker validation
config.CircuitBreakerFailureThreshold = 0
if err := config.Validate(); err != ErrInvalidCircuitBreakerFailureThreshold {
t.Errorf("Expected ErrInvalidCircuitBreakerFailureThreshold, got %v", err)
}
config.CircuitBreakerFailureThreshold = 5 // Reset to valid value
config.CircuitBreakerResetTimeout = -1 * time.Second
if err := config.Validate(); err != ErrInvalidCircuitBreakerResetTimeout {
t.Errorf("Expected ErrInvalidCircuitBreakerResetTimeout, got %v", err)
}
config.CircuitBreakerResetTimeout = 60 * time.Second // Reset to valid value
config.CircuitBreakerMaxRequests = 0
if err := config.Validate(); err != ErrInvalidCircuitBreakerMaxRequests {
t.Errorf("Expected ErrInvalidCircuitBreakerMaxRequests, got %v", err)
}
config.CircuitBreakerMaxRequests = 3 // Reset to valid value
// Should pass validation again
if err := config.Validate(); err != nil {
t.Errorf("Config should be valid after reset, got error: %v", err)
}
})
}
func TestConfigClone(t *testing.T) {
original := DefaultConfig()
original.MaxHandoffRetries = 7
original.HandoffTimeout = 8 * time.Second
cloned := original.Clone()
// Test that values are copied
if cloned.MaxHandoffRetries != 7 {
t.Errorf("Expected cloned MaxHandoffRetries to be 7, got %d", cloned.MaxHandoffRetries)
}
if cloned.HandoffTimeout != 8*time.Second {
t.Errorf("Expected cloned HandoffTimeout to be 8s, got %v", cloned.HandoffTimeout)
}
// Test that modifying clone doesn't affect original
cloned.MaxHandoffRetries = 10
if original.MaxHandoffRetries != 7 {
t.Errorf("Modifying clone should not affect original, original MaxHandoffRetries changed to %d", original.MaxHandoffRetries)
}
}
func TestMaxWorkersLogic(t *testing.T) {
t.Run("AutoCalculatedMaxWorkers", func(t *testing.T) {
testCases := []struct {
poolSize int
expectedWorkers int
description string
}{
{6, 3, "Small pool: min(6/2, max(10, 6/3)) = min(3, max(10, 2)) = min(3, 10) = 3"},
{15, 7, "Medium pool: min(15/2, max(10, 15/3)) = min(7, max(10, 5)) = min(7, 10) = 7"},
{30, 10, "Large pool: min(30/2, max(10, 30/3)) = min(15, max(10, 10)) = min(15, 10) = 10"},
{60, 20, "Very large pool: min(60/2, max(10, 60/3)) = min(30, max(10, 20)) = min(30, 20) = 20"},
{120, 40, "Huge pool: min(120/2, max(10, 120/3)) = min(60, max(10, 40)) = min(60, 40) = 40"},
}
for _, tc := range testCases {
config := &Config{} // MaxWorkers = 0 (not set)
result := config.ApplyDefaultsWithPoolSize(tc.poolSize)
if result.MaxWorkers != tc.expectedWorkers {
t.Errorf("PoolSize=%d: expected MaxWorkers=%d, got %d (%s)",
tc.poolSize, tc.expectedWorkers, result.MaxWorkers, tc.description)
}
}
})
t.Run("ExplicitlySetMaxWorkers", func(t *testing.T) {
testCases := []struct {
setValue int
expectedWorkers int
description string
}{
{1, 50, "Set 1: max(poolSize/2, 1) = max(50, 1) = 50 (enforced minimum)"},
{5, 50, "Set 5: max(poolSize/2, 5) = max(50, 5) = 50 (enforced minimum)"},
{8, 50, "Set 8: max(poolSize/2, 8) = max(50, 8) = 50 (enforced minimum)"},
{10, 50, "Set 10: max(poolSize/2, 10) = max(50, 10) = 50 (enforced minimum)"},
{15, 50, "Set 15: max(poolSize/2, 15) = max(50, 15) = 50 (enforced minimum)"},
{60, 60, "Set 60: max(poolSize/2, 60) = max(50, 60) = 60 (respects user choice)"},
}
for _, tc := range testCases {
config := &Config{
MaxWorkers: tc.setValue, // Explicitly set
}
result := config.ApplyDefaultsWithPoolSize(100) // Pool size doesn't affect explicit values
if result.MaxWorkers != tc.expectedWorkers {
t.Errorf("Set MaxWorkers=%d: expected %d, got %d (%s)",
tc.setValue, tc.expectedWorkers, result.MaxWorkers, tc.description)
}
}
})
}

105
hitless/errors.go Normal file
View File

@@ -0,0 +1,105 @@
package hitless
import (
"errors"
"fmt"
"time"
)
// Configuration errors
var (
ErrInvalidRelaxedTimeout = errors.New("hitless: relaxed timeout must be greater than 0")
ErrInvalidHandoffTimeout = errors.New("hitless: handoff timeout must be greater than 0")
ErrInvalidHandoffWorkers = errors.New("hitless: MaxWorkers must be greater than or equal to 0")
ErrInvalidHandoffQueueSize = errors.New("hitless: handoff queue size must be greater than 0")
ErrInvalidPostHandoffRelaxedDuration = errors.New("hitless: post-handoff relaxed duration must be greater than or equal to 0")
ErrInvalidLogLevel = errors.New("hitless: log level must be LogLevelError (0), LogLevelWarn (1), LogLevelInfo (2), or LogLevelDebug (3)")
ErrInvalidEndpointType = errors.New("hitless: invalid endpoint type")
ErrInvalidMaintNotifications = errors.New("hitless: invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')")
ErrMaxHandoffRetriesReached = errors.New("hitless: max handoff retries reached")
// Configuration validation errors
ErrInvalidHandoffRetries = errors.New("hitless: MaxHandoffRetries must be between 1 and 10")
)
// Integration errors
var (
ErrInvalidClient = errors.New("hitless: invalid client type")
)
// Handoff errors
var (
ErrHandoffQueueFull = errors.New("hitless: handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration")
)
// Notification errors
var (
ErrInvalidNotification = errors.New("hitless: invalid notification format")
)
// connection handoff errors
var (
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
// and should not be used until the handoff is complete
ErrConnectionMarkedForHandoff = errors.New("hitless: connection marked for handoff")
// ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff
ErrConnectionInvalidHandoffState = errors.New("hitless: connection is in invalid state for handoff")
)
// general errors
var (
ErrShutdown = errors.New("hitless: shutdown")
)
// circuit breaker errors
var (
ErrCircuitBreakerOpen = errors.New("hitless: circuit breaker is open, failing fast")
)
// CircuitBreakerError provides detailed context for circuit breaker failures
type CircuitBreakerError struct {
Endpoint string
State string
Failures int64
LastFailure time.Time
NextAttempt time.Time
Message string
}
func (e *CircuitBreakerError) Error() string {
if e.NextAttempt.IsZero() {
return fmt.Sprintf("hitless: circuit breaker %s for %s (failures: %d, last: %v): %s",
e.State, e.Endpoint, e.Failures, e.LastFailure, e.Message)
}
return fmt.Sprintf("hitless: circuit breaker %s for %s (failures: %d, last: %v, next attempt: %v): %s",
e.State, e.Endpoint, e.Failures, e.LastFailure, e.NextAttempt, e.Message)
}
// HandoffError provides detailed context for connection handoff failures
type HandoffError struct {
ConnectionID uint64
SourceEndpoint string
TargetEndpoint string
Attempt int
MaxAttempts int
Duration time.Duration
FinalError error
Message string
}
func (e *HandoffError) Error() string {
return fmt.Sprintf("hitless: handoff failed for conn[%d] %s→%s (attempt %d/%d, duration: %v): %s",
e.ConnectionID, e.SourceEndpoint, e.TargetEndpoint,
e.Attempt, e.MaxAttempts, e.Duration, e.Message)
}
func (e *HandoffError) Unwrap() error {
return e.FinalError
}
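// Illustrative usage (sketch): callers can inspect a failed handoff with errors.As, e.g.
//
//	var hoErr *HandoffError
//	if errors.As(err, &hoErr) {
//		// hoErr.FinalError holds the underlying dial/init error
//	}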
// circuit breaker configuration errors
var (
ErrInvalidCircuitBreakerFailureThreshold = errors.New("hitless: circuit breaker failure threshold must be >= 1")
ErrInvalidCircuitBreakerResetTimeout = errors.New("hitless: circuit breaker reset timeout must be >= 0")
ErrInvalidCircuitBreakerMaxRequests = errors.New("hitless: circuit breaker max requests must be >= 1")
)

100
hitless/example_hooks.go Normal file
View File

@@ -0,0 +1,100 @@
package hitless
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// contextKey is a custom type for context keys to avoid collisions
type contextKey string
const (
startTimeKey contextKey = "notif_hitless_start_time"
)
// MetricsHook collects metrics about notification processing.
type MetricsHook struct {
NotificationCounts map[string]int64
ProcessingTimes map[string]time.Duration
ErrorCounts map[string]int64
HandoffCounts int64 // Total handoffs initiated
HandoffSuccesses int64 // Successful handoffs
HandoffFailures int64 // Failed handoffs
}
// NewMetricsHook creates a new metrics collection hook.
func NewMetricsHook() *MetricsHook {
return &MetricsHook{
NotificationCounts: make(map[string]int64),
ProcessingTimes: make(map[string]time.Duration),
ErrorCounts: make(map[string]int64),
}
}
// PreHook records the start time for processing metrics.
func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
mh.NotificationCounts[notificationType]++
// Log connection information if available
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
internal.Logger.Printf(ctx, "hitless: metrics hook processing %s notification on conn[%d]", notificationType, conn.GetID())
}
// Record the start time for duration calculation. Note that PreHook cannot replace the
// caller's context, so PostHook only sees this value if the caller propagates the derived
// context; otherwise the processing-time metric is not updated.
startTime := time.Now()
_ = context.WithValue(ctx, startTimeKey, startTime) // derived context is currently discarded
return notification, true
}
// PostHook records processing completion and any errors.
func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
// Calculate processing duration
if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
duration := time.Since(startTime)
mh.ProcessingTimes[notificationType] = duration
}
// Record errors
if result != nil {
mh.ErrorCounts[notificationType]++
// Log error details with connection information
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
internal.Logger.Printf(ctx, "hitless: metrics hook recorded error for %s notification on conn[%d]: %v", notificationType, conn.GetID(), result)
}
}
}
// GetMetrics returns a summary of collected metrics.
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
return map[string]interface{}{
"notification_counts": mh.NotificationCounts,
"processing_times": mh.ProcessingTimes,
"error_counts": mh.ErrorCounts,
}
}
// ExampleCircuitBreakerMonitor demonstrates how to monitor circuit breaker status
func ExampleCircuitBreakerMonitor(poolHook *PoolHook) {
// Get circuit breaker statistics
stats := poolHook.GetCircuitBreakerStats()
for _, stat := range stats {
fmt.Printf("Circuit Breaker for %s:\n", stat.Endpoint)
fmt.Printf(" State: %s\n", stat.State)
fmt.Printf(" Failures: %d\n", stat.Failures)
fmt.Printf(" Last Failure: %v\n", stat.LastFailureTime)
fmt.Printf(" Last Success: %v\n", stat.LastSuccessTime)
// Alert if circuit breaker is open
if stat.State.String() == "open" {
fmt.Printf(" ⚠️ ALERT: Circuit breaker is OPEN for %s\n", stat.Endpoint)
}
}
}
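// Illustrative wiring (sketch; "manager" is a hypothetical *HitlessManager):
//
//	hook := NewMetricsHook()
//	manager.AddNotificationHook(hook)
//	// ... later, e.g. for periodic reporting:
//	fmt.Println(hook.GetMetrics())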

455
hitless/handoff_worker.go Normal file
View File

@@ -0,0 +1,455 @@
package hitless
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
)
// handoffWorkerManager manages background workers and queue for connection handoffs
type handoffWorkerManager struct {
// Event-driven handoff support
handoffQueue chan HandoffRequest // Queue for handoff requests
shutdown chan struct{} // Shutdown signal
shutdownOnce sync.Once // Ensure clean shutdown
workerWg sync.WaitGroup // Track worker goroutines
// On-demand worker management
maxWorkers int
activeWorkers atomic.Int32
workerTimeout time.Duration // How long workers wait for work before exiting
workersScaling atomic.Bool
// Simple state tracking
pending sync.Map // map[uint64]int64 (connID -> seqID)
// Configuration for the hitless upgrade
config *Config
// Pool hook reference for handoff processing
poolHook *PoolHook
// Circuit breaker manager for endpoint failure handling
circuitBreakerManager *CircuitBreakerManager
}
// newHandoffWorkerManager creates a new handoff worker manager
func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
return &handoffWorkerManager{
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
shutdown: make(chan struct{}),
maxWorkers: config.MaxWorkers,
activeWorkers: atomic.Int32{}, // Start with no workers - create on demand
workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
config: config,
poolHook: poolHook,
circuitBreakerManager: newCircuitBreakerManager(config),
}
}
// getCurrentWorkers returns the current number of active workers (for testing)
func (hwm *handoffWorkerManager) getCurrentWorkers() int {
return int(hwm.activeWorkers.Load())
}
// getPendingMap returns the pending map for testing purposes
func (hwm *handoffWorkerManager) getPendingMap() *sync.Map {
return &hwm.pending
}
// getMaxWorkers returns the max workers for testing purposes
func (hwm *handoffWorkerManager) getMaxWorkers() int {
return hwm.maxWorkers
}
// getHandoffQueue returns the handoff queue for testing purposes
func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
return hwm.handoffQueue
}
// getCircuitBreakerStats returns circuit breaker statistics for monitoring
func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats {
return hwm.circuitBreakerManager.GetAllStats()
}
// resetCircuitBreakers resets all circuit breakers (useful for testing)
func (hwm *handoffWorkerManager) resetCircuitBreakers() {
hwm.circuitBreakerManager.Reset()
}
// isHandoffPending returns true if the given connection has a pending handoff
func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
_, pending := hwm.pending.Load(conn.GetID())
return pending
}
// ensureWorkerAvailable ensures at least one worker is available to process requests
// Creates a new worker if needed and under the max limit
func (hwm *handoffWorkerManager) ensureWorkerAvailable() {
select {
case <-hwm.shutdown:
return
default:
if hwm.workersScaling.CompareAndSwap(false, true) {
defer hwm.workersScaling.Store(false)
// Check if we need a new worker
currentWorkers := hwm.activeWorkers.Load()
workersWas := currentWorkers
for currentWorkers < int32(hwm.maxWorkers) {
hwm.workerWg.Add(1)
go hwm.onDemandWorker()
currentWorkers++
}
// workersWas is always <= currentWorkers. currentWorkers has reached maxWorkers here,
// but a worker may have exited while we were scaling, so add only the number of
// workers we actually created (currentWorkers - workersWas) rather than overwriting
// the counter.
hwm.activeWorkers.Add(currentWorkers - workersWas)
}
}
}
// onDemandWorker processes handoff requests and exits when idle
func (hwm *handoffWorkerManager) onDemandWorker() {
defer func() {
// Decrement active worker count when exiting
hwm.activeWorkers.Add(-1)
hwm.workerWg.Done()
}()
// Create reusable timer to prevent timer leaks
timer := time.NewTimer(hwm.workerTimeout)
defer timer.Stop()
for {
// Reset timer for next iteration
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(hwm.workerTimeout)
select {
case <-hwm.shutdown:
return
case <-timer.C:
// Worker has been idle for too long, exit to save resources
if hwm.config != nil && hwm.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: worker exiting due to inactivity timeout (%v)", hwm.workerTimeout)
}
return
case request := <-hwm.handoffQueue:
// Check for shutdown before processing
select {
case <-hwm.shutdown:
// Clean up the request before exiting
hwm.pending.Delete(request.ConnID)
return
default:
// Process the request
hwm.processHandoffRequest(request)
}
}
}
}
// processHandoffRequest processes a single handoff request
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
// Remove from pending map
defer hwm.pending.Delete(request.Conn.GetID())
internal.Logger.Printf(context.Background(), "hitless: conn[%d] Processing handoff request start", request.Conn.GetID())
// Create a context with handoff timeout from config
handoffTimeout := 15 * time.Second // Default timeout
if hwm.config != nil && hwm.config.HandoffTimeout > 0 {
handoffTimeout = hwm.config.HandoffTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
defer cancel()
// Create a context that also respects the shutdown signal
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
defer shutdownCancel()
// Monitor shutdown signal in a separate goroutine
go func() {
select {
case <-hwm.shutdown:
shutdownCancel()
case <-shutdownCtx.Done():
}
}()
// Perform the handoff with cancellable context
shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn)
minRetryBackoff := 500 * time.Millisecond
if err != nil {
if shouldRetry {
now := time.Now()
deadline, ok := shutdownCtx.Deadline()
thirdOfTimeout := handoffTimeout / 3
if !ok || deadline.Before(now) {
// wait a third of the timeout before retrying if no deadline or deadline has passed
deadline = now.Add(thirdOfTimeout)
}
afterTime := deadline.Sub(now)
if afterTime < minRetryBackoff {
afterTime = minRetryBackoff
}
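// Worked example (illustrative): with the default 15s HandoffTimeout, a dial failure
// 1s into the attempt retries after the ~14s remaining on the deadline; if the deadline
// already passed, the retry happens after handoffTimeout/3 = 5s, and never sooner than 500ms.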
internal.Logger.Printf(context.Background(), "Handoff failed for conn[%d] WILL RETRY After %v: %v", request.ConnID, afterTime, err)
time.AfterFunc(afterTime, func() {
if err := hwm.queueHandoff(request.Conn); err != nil {
internal.Logger.Printf(context.Background(), "can't queue handoff for retry: %v", err)
hwm.closeConnFromRequest(context.Background(), request, err)
}
})
return
} else {
go hwm.closeConnFromRequest(ctx, request, err)
}
// Clear handoff state if not returned for retry
seqID := request.Conn.GetMovingSeqID()
connID := request.Conn.GetID()
if hwm.poolHook.hitlessManager != nil {
hwm.poolHook.hitlessManager.UntrackOperationWithConnID(seqID, connID)
}
}
}
// queueHandoff queues a handoff request for processing
// if err is returned, connection will be removed from pool
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
// Create handoff request
request := HandoffRequest{
Conn: conn,
ConnID: conn.GetID(),
Endpoint: conn.GetHandoffEndpoint(),
SeqID: conn.GetMovingSeqID(),
Pool: hwm.poolHook.pool, // Include pool for connection removal on failure
}
select {
// priority to shutdown
case <-hwm.shutdown:
return ErrShutdown
default:
select {
case <-hwm.shutdown:
return ErrShutdown
case hwm.handoffQueue <- request:
// Store in pending map
hwm.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
hwm.ensureWorkerAvailable()
return nil
default:
select {
case <-hwm.shutdown:
return ErrShutdown
case hwm.handoffQueue <- request:
// Store in pending map
hwm.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
hwm.ensureWorkerAvailable()
return nil
case <-time.After(100 * time.Millisecond): // give workers a chance to process
// Queue is full - log and attempt scaling
queueLen := len(hwm.handoffQueue)
queueCap := cap(hwm.handoffQueue)
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(context.Background(),
"hitless: handoff queue is full (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration",
queueLen, queueCap)
}
}
}
}
// Ensure we have workers available to handle the load
hwm.ensureWorkerAvailable()
return ErrHandoffQueueFull
}
// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete
func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error {
hwm.shutdownOnce.Do(func() {
close(hwm.shutdown)
// workers will exit when they finish their current request
// Shutdown circuit breaker manager cleanup goroutine
if hwm.circuitBreakerManager != nil {
hwm.circuitBreakerManager.Shutdown()
}
})
// Wait for workers to complete
done := make(chan struct{})
go func() {
hwm.workerWg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
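// Illustrative usage (sketch): shut down with a bounded wait, e.g.
//
//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//	defer cancel()
//	_ = hwm.shutdownWorkers(ctx) // returns ctx.Err() if workers don't finish in time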
// performConnectionHandoff performs the actual connection handoff
// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached
func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) {
// Read the handoff target recorded on the connection
connID := conn.GetID()
newEndpoint := conn.GetHandoffEndpoint()
if newEndpoint == "" {
return false, ErrConnectionInvalidHandoffState
}
// Use circuit breaker to protect against failing endpoints
circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint)
// Check if circuit breaker is open before attempting handoff
if circuitBreaker.IsOpen() {
internal.Logger.Printf(ctx, "hitless: conn[%d] handoff to %s failed fast due to circuit breaker", connID, newEndpoint)
return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open
}
// Perform the handoff
shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID)
// Update circuit breaker based on result
if err != nil {
// Only track dial/network errors in circuit breaker, not initialization errors
if shouldRetry {
circuitBreaker.recordFailure()
}
return shouldRetry, err
}
// Success - record in circuit breaker
circuitBreaker.recordSuccess()
return false, nil
}
// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) {
retries := conn.IncrementAndGetHandoffRetries(1)
internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", connID, retries, newEndpoint, conn.RemoteAddr().String())
maxRetries := 3 // Default fallback
if hwm.config != nil {
maxRetries = hwm.config.MaxHandoffRetries
}
if retries > maxRetries {
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: reached max retries (%d) for handoff of conn[%d] to %s",
maxRetries, connID, newEndpoint)
}
// won't retry on ErrMaxHandoffRetriesReached
return false, ErrMaxHandoffRetriesReached
}
// Create endpoint-specific dialer
endpointDialer := hwm.createEndpointDialer(newEndpoint)
// Create new connection to the new endpoint
newNetConn, err := endpointDialer(ctx)
if err != nil {
internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", connID, newEndpoint, err)
// hitless: will retry
// Maybe a network error - retry after a delay
return true, err
}
// Get the old connection
oldConn := conn.GetNetConn()
// Apply relaxed timeout to the new connection for the configured post-handoff duration
// This gives the new connection more time to handle operations during cluster transition
// Setting this here (before initializing the connection) ensures that the connection
// uses the relaxed timeout for its first operations (auth/ACL select)
if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 {
relaxedTimeout := hwm.config.RelaxedTimeout
// Set relaxed timeout with deadline - no background goroutine needed
deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration)
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
if hwm.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: conn[%d] applied post-handoff relaxed timeout (%v) until %v",
connID, relaxedTimeout, deadline.Format("15:04:05.000"))
}
}
// Replace the connection and execute initialization
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
if err != nil {
// hitless: won't retry
// Initialization failed - remove the connection
return false, err
}
defer func() {
if oldConn != nil {
oldConn.Close()
}
}()
conn.ClearHandoffState()
internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", connID, newEndpoint)
return false, nil
}
// createEndpointDialer creates a dialer function that connects to a specific endpoint
func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
// Parse endpoint to extract host and port
host, port, err := net.SplitHostPort(endpoint)
if err != nil {
// No port specified - fall back to the default Redis port
host = endpoint
port = "6379"
}
// Use the base dialer to connect to the new endpoint
return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port))
}
}
// closeConnFromRequest closes the connection and logs the reason
func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
pooler := request.Pool
conn := request.Conn
if pooler != nil {
pooler.Remove(ctx, conn, err)
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: removed conn[%d] from pool due: %v",
conn.GetID(), err)
}
} else {
conn.Close()
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: no pool provided for conn[%d], cannot remove due to: %v",
conn.GetID(), err)
}
}
}

318
hitless/hitless_manager.go Normal file
View File

@@ -0,0 +1,318 @@
package hitless
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// Push notification type constants for hitless upgrades
const (
NotificationMoving = "MOVING"
NotificationMigrating = "MIGRATING"
NotificationMigrated = "MIGRATED"
NotificationFailingOver = "FAILING_OVER"
NotificationFailedOver = "FAILED_OVER"
)
// hitlessNotificationTypes contains all notification types that hitless upgrades handles
var hitlessNotificationTypes = []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
// NotificationHook is called before and after notification processing
// PreHook can modify the notification and return false to skip processing
// PostHook is called after successful processing
type NotificationHook interface {
PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool)
PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error)
}
// MovingOperationKey provides a unique key for tracking MOVING operations
// that combines sequence ID with connection identifier to handle duplicate
// sequence IDs across multiple connections to the same node.
type MovingOperationKey struct {
SeqID int64 // Sequence ID from MOVING notification
ConnID uint64 // Unique connection identifier
}
// String returns a string representation of the key for debugging
func (k MovingOperationKey) String() string {
return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
}
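// For example, MovingOperationKey{SeqID: 42, ConnID: 7}.String() == "seq:42-conn:7".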
// HitlessManager provides a simplified hitless upgrade functionality with hooks and atomic state.
type HitlessManager struct {
client interfaces.ClientInterface
config *Config
options interfaces.OptionsInterface
pool pool.Pooler
// MOVING operation tracking - using sync.Map for better concurrent performance
activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation
// Atomic state tracking - no locks needed for state queries
activeOperationCount atomic.Int64 // Number of active operations
closed atomic.Bool // Manager closed state
// Notification hooks for extensibility
hooks []NotificationHook
hooksMu sync.RWMutex // Protects hooks slice
poolHooksRef *PoolHook
}
// MovingOperation tracks an active MOVING operation.
type MovingOperation struct {
SeqID int64
NewEndpoint string
StartTime time.Time
Deadline time.Time
}
// NewHitlessManager creates a new simplified hitless manager.
func NewHitlessManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*HitlessManager, error) {
if client == nil {
return nil, ErrInvalidClient
}
hm := &HitlessManager{
client: client,
pool: pool,
options: client.GetOptions(),
config: config.Clone(),
hooks: make([]NotificationHook, 0),
}
// Set up push notification handling
if err := hm.setupPushNotifications(); err != nil {
return nil, err
}
return hm, nil
}
// InitPoolHook creates a pool hook with a custom dialer and registers it with the pool.
func (hm *HitlessManager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
poolHook := hm.createPoolHook(baseDialer)
hm.pool.AddPoolHook(poolHook)
}
// setupPushNotifications sets up push notification handling by registering with the client's processor.
func (hm *HitlessManager) setupPushNotifications() error {
processor := hm.client.GetPushProcessor()
if processor == nil {
return ErrInvalidClient // Client doesn't support push notifications
}
// Create our notification handler
handler := &NotificationHandler{manager: hm}
// Register handlers for all hitless upgrade notifications with the client's processor
for _, notificationType := range hitlessNotificationTypes {
if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
return fmt.Errorf("failed to register handler for %s: %w", notificationType, err)
}
}
return nil
}
// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Create MOVING operation record
movingOp := &MovingOperation{
SeqID: seqID,
NewEndpoint: newEndpoint,
StartTime: time.Now(),
Deadline: deadline,
}
// Use LoadOrStore for atomic check-and-set operation
if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
// Duplicate MOVING notification, ignore
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Duplicate MOVING operation ignored: %s", connID, seqID, key.String())
}
return nil
}
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Tracking MOVING operation: %s", connID, seqID, key.String())
}
// Increment active operation count atomically
hm.activeOperationCount.Add(1)
return nil
}
// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64) {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Remove from active operations atomically
if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Untracking MOVING operation: %s", connID, seqID, key.String())
}
// Decrement active operation count only if operation existed
hm.activeOperationCount.Add(-1)
} else {
if hm.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), "hitless: conn[%d] seqID[%d] Operation not found for untracking: %s", connID, seqID, key.String())
}
}
}
// GetActiveMovingOperations returns active operations with composite keys.
// WARNING: This method creates a new map and copies all operations on every call.
// Use sparingly, especially in hot paths or high-frequency logging.
func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
result := make(map[MovingOperationKey]*MovingOperation)
// Iterate over sync.Map to build result
hm.activeMovingOps.Range(func(key, value interface{}) bool {
k := key.(MovingOperationKey)
op := value.(*MovingOperation)
// Create a copy to avoid sharing references
result[k] = &MovingOperation{
SeqID: op.SeqID,
NewEndpoint: op.NewEndpoint,
StartTime: op.StartTime,
Deadline: op.Deadline,
}
return true // Continue iteration
})
return result
}
// IsHandoffInProgress returns true if any handoff is in progress.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) IsHandoffInProgress() bool {
return hm.activeOperationCount.Load() > 0
}
// GetActiveOperationCount returns the number of active operations.
// Uses atomic counter for lock-free operation.
func (hm *HitlessManager) GetActiveOperationCount() int64 {
return hm.activeOperationCount.Load()
}
// Close closes the hitless manager.
func (hm *HitlessManager) Close() error {
// Use atomic operation for thread-safe close check
if !hm.closed.CompareAndSwap(false, true) {
return nil // Already closed
}
// Shutdown the pool hook if it exists
if hm.poolHooksRef != nil {
// Use a timeout to prevent hanging indefinitely
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := hm.poolHooksRef.Shutdown(shutdownCtx)
if err != nil {
// could not shut down the pool hook; revert the closed flag
hm.closed.Store(false)
return err
}
// Remove the pool hook from the pool
if hm.pool != nil {
hm.pool.RemovePoolHook(hm.poolHooksRef)
}
}
// Clear all active operations
hm.activeMovingOps.Range(func(key, value interface{}) bool {
hm.activeMovingOps.Delete(key)
return true
})
// Reset counter
hm.activeOperationCount.Store(0)
return nil
}
// GetState returns current state using atomic counter for lock-free operation.
func (hm *HitlessManager) GetState() State {
if hm.activeOperationCount.Load() > 0 {
return StateMoving
}
return StateIdle
}
// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
currentNotification := notification
for _, hook := range hm.hooks {
modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification)
if !shouldContinue {
return modifiedNotification, false
}
currentNotification = modifiedNotification
}
return currentNotification, true
}
// processPostHooks calls all post-hooks with the processing result.
func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
for _, hook := range hm.hooks {
hook.PostHook(ctx, notificationCtx, notificationType, notification, result)
}
}
// createPoolHook creates a pool hook with this manager already set.
func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
if hm.poolHooksRef != nil {
return hm.poolHooksRef
}
// Get pool size from client options for better worker defaults
poolSize := 0
if hm.options != nil {
poolSize = hm.options.GetPoolSize()
}
hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
hm.poolHooksRef.SetPool(hm.pool)
return hm.poolHooksRef
}
func (hm *HitlessManager) AddNotificationHook(notificationHook NotificationHook) {
hm.hooksMu.Lock()
defer hm.hooksMu.Unlock()
hm.hooks = append(hm.hooks, notificationHook)
}

View File

@@ -0,0 +1,260 @@
package hitless
import (
"context"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/interfaces"
)
// MockClient implements interfaces.ClientInterface for testing
type MockClient struct {
options interfaces.OptionsInterface
}
func (mc *MockClient) GetOptions() interfaces.OptionsInterface {
return mc.options
}
func (mc *MockClient) GetPushProcessor() interfaces.NotificationProcessor {
return &MockPushProcessor{}
}
// MockPushProcessor implements interfaces.NotificationProcessor for testing
type MockPushProcessor struct{}
func (mpp *MockPushProcessor) RegisterHandler(notificationType string, handler interface{}, protected bool) error {
return nil
}
func (mpp *MockPushProcessor) UnregisterHandler(pushNotificationName string) error {
return nil
}
func (mpp *MockPushProcessor) GetHandler(pushNotificationName string) interface{} {
return nil
}
// MockOptions implements interfaces.OptionsInterface for testing
type MockOptions struct{}
func (mo *MockOptions) GetReadTimeout() time.Duration {
return 5 * time.Second
}
func (mo *MockOptions) GetWriteTimeout() time.Duration {
return 5 * time.Second
}
func (mo *MockOptions) GetAddr() string {
return "localhost:6379"
}
func (mo *MockOptions) IsTLSEnabled() bool {
return false
}
func (mo *MockOptions) GetProtocol() int {
return 3 // RESP3
}
func (mo *MockOptions) GetPoolSize() int {
return 10
}
func (mo *MockOptions) GetNetwork() string {
return "tcp"
}
func (mo *MockOptions) NewDialer() func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
return nil, nil
}
}
func TestHitlessManagerRefactoring(t *testing.T) {
t.Run("AtomicStateTracking", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
// Test initial state
if manager.IsHandoffInProgress() {
t.Error("Expected no handoff in progress initially")
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateIdle {
t.Errorf("Expected StateIdle, got %v", manager.GetState())
}
// Add an operation
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
err = manager.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Failed to track operation: %v", err)
}
// Test state after adding operation
if !manager.IsHandoffInProgress() {
t.Error("Expected handoff in progress after adding operation")
}
if manager.GetActiveOperationCount() != 1 {
t.Errorf("Expected 1 active operation, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateMoving {
t.Errorf("Expected StateMoving, got %v", manager.GetState())
}
// Remove the operation
manager.UntrackOperationWithConnID(12345, 1)
// Test state after removing operation
if manager.IsHandoffInProgress() {
t.Error("Expected no handoff in progress after removing operation")
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateIdle {
t.Errorf("Expected StateIdle, got %v", manager.GetState())
}
})
t.Run("SyncMapPerformance", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
// Test concurrent operations
const numOps = 100
for i := 0; i < numOps; i++ {
err := manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, int64(i), uint64(i))
if err != nil {
t.Fatalf("Failed to track operation %d: %v", i, err)
}
}
if manager.GetActiveOperationCount() != numOps {
t.Errorf("Expected %d active operations, got %d", numOps, manager.GetActiveOperationCount())
}
// Test GetActiveMovingOperations
operations := manager.GetActiveMovingOperations()
if len(operations) != numOps {
t.Errorf("Expected %d operations in map, got %d", numOps, len(operations))
}
// Remove all operations
for i := 0; i < numOps; i++ {
manager.UntrackOperationWithConnID(int64(i), uint64(i))
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations after cleanup, got %d", manager.GetActiveOperationCount())
}
})
t.Run("DuplicateOperationHandling", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewHitlessManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create hitless manager: %v", err)
}
defer manager.Close()
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
// Add operation
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Failed to track operation: %v", err)
}
// Try to add duplicate operation
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Duplicate operation should not return error: %v", err)
}
// Should still have only 1 operation
if manager.GetActiveOperationCount() != 1 {
t.Errorf("Expected 1 active operation after duplicate, got %d", manager.GetActiveOperationCount())
}
})
t.Run("NotificationTypeConstants", func(t *testing.T) {
// Test that constants are properly defined
expectedTypes := []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
if len(hitlessNotificationTypes) != len(expectedTypes) {
t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(hitlessNotificationTypes))
}
// Test that all expected types are present
typeMap := make(map[string]bool)
for _, notifType := range hitlessNotificationTypes {
typeMap[notifType] = true
}
for _, expected := range expectedTypes {
if !typeMap[expected] {
t.Errorf("Expected notification type %s not found in hitlessNotificationTypes", expected)
}
}
// Test that hitlessNotificationTypes contains all expected constants
expectedConstants := []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
for _, expected := range expectedConstants {
found := false
for _, actual := range hitlessNotificationTypes {
if actual == expected {
found = true
break
}
}
if !found {
t.Errorf("Expected constant %s not found in hitlessNotificationTypes", expected)
}
}
})
}

47
hitless/hooks.go Normal file
View File

@@ -0,0 +1,47 @@
package hitless
import (
"context"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/push"
)
// LoggingHook is an example hook implementation that logs all notifications.
type LoggingHook struct {
LogLevel logging.LogLevel
}
// PreHook logs the notification before processing and allows modification.
func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
if lh.LogLevel.InfoOrAbove() { // Info level
// Log the notification type and content
connID := uint64(0)
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
connID = conn.GetID()
}
internal.Logger.Printf(ctx, "hitless: conn[%d] processing %s notification: %v", connID, notificationType, notification)
}
return notification, true // Continue processing with unmodified notification
}
// PostHook logs the result after processing.
func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
connID := uint64(0)
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
connID = conn.GetID()
}
if result != nil && lh.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processing failed: %v - %v", connID, notificationType, result, notification)
} else if lh.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processed successfully", connID, notificationType)
}
}
// NewLoggingHook creates a new logging hook with the specified log level.
// Log levels: LogLevelError=errors, LogLevelWarn=warnings, LogLevelInfo=info, LogLevelDebug=debug
func NewLoggingHook(logLevel logging.LogLevel) *LoggingHook {
return &LoggingHook{LogLevel: logLevel}
}
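
A minimal usage sketch for the hook above. It relies only on identifiers visible in this diff; the log-level constant name (logging.LogLevelInfo) is assumed from the NewLoggingHook doc comment, and how the hook is registered with the hitless manager is not shown in this excerpt.

package main

import (
	"github.com/redis/go-redis/v9/hitless"
	"github.com/redis/go-redis/v9/logging"
)

func main() {
	// logging.LogLevelInfo is assumed from the doc comment above; adjust if the constant differs.
	hook := hitless.NewLoggingHook(logging.LogLevelInfo)
	_ = hook // register with the hitless manager's hook API (not part of this excerpt)
}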

179
hitless/pool_hook.go Normal file
View File

@@ -0,0 +1,179 @@
package hitless
import (
"context"
"net"
"sync"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
)
// HitlessManagerInterface defines the interface for completing handoff operations
type HitlessManagerInterface interface {
TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
UntrackOperationWithConnID(seqID int64, connID uint64)
}
// HandoffRequest represents a request to handoff a connection to a new endpoint
type HandoffRequest struct {
Conn *pool.Conn
ConnID uint64 // Unique connection identifier
Endpoint string
SeqID int64
Pool pool.Pooler // Pool to remove connection from on failure
}
// PoolHook implements pool.PoolHook for Redis-specific connection handling
// with hitless upgrade support.
type PoolHook struct {
// Base dialer for creating connections to new endpoints during handoffs
// args are network and address
baseDialer func(context.Context, string, string) (net.Conn, error)
// Network type (e.g., "tcp", "unix")
network string
// Worker manager for background handoff processing
workerManager *handoffWorkerManager
// Configuration for the hitless upgrade
config *Config
// Hitless manager for operation completion tracking
hitlessManager HitlessManagerInterface
// Pool interface for removing connections on handoff failure
pool pool.Pooler
}
// NewPoolHook creates a new pool hook
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface) *PoolHook {
return NewPoolHookWithPoolSize(baseDialer, network, config, hitlessManager, 0)
}
// NewPoolHookWithPoolSize creates a new pool hook, using the pool size to derive better worker defaults
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface, poolSize int) *PoolHook {
// Apply defaults if config is nil or has zero values
if config == nil {
config = config.ApplyDefaultsWithPoolSize(poolSize)
}
ph := &PoolHook{
// baseDialer is used to create connections to new endpoints during handoffs
baseDialer: baseDialer,
network: network,
config: config,
// Hitless manager for operation completion tracking
hitlessManager: hitlessManager,
}
// Create worker manager
ph.workerManager = newHandoffWorkerManager(config, ph)
return ph
}
// SetPool sets the pool interface for removing connections on handoff failure
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
ph.pool = pooler
}
// GetCurrentWorkers returns the current number of active workers (for testing)
func (ph *PoolHook) GetCurrentWorkers() int {
return ph.workerManager.getCurrentWorkers()
}
// IsHandoffPending returns true if the given connection has a pending handoff
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
return ph.workerManager.isHandoffPending(conn)
}
// GetPendingMap returns the pending map for testing purposes
func (ph *PoolHook) GetPendingMap() *sync.Map {
return ph.workerManager.getPendingMap()
}
// GetMaxWorkers returns the max workers for testing purposes
func (ph *PoolHook) GetMaxWorkers() int {
return ph.workerManager.getMaxWorkers()
}
// GetHandoffQueue returns the handoff queue for testing purposes
func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest {
return ph.workerManager.getHandoffQueue()
}
// GetCircuitBreakerStats returns circuit breaker statistics for monitoring
func (ph *PoolHook) GetCircuitBreakerStats() []CircuitBreakerStats {
return ph.workerManager.getCircuitBreakerStats()
}
// ResetCircuitBreakers resets all circuit breakers (useful for testing)
func (ph *PoolHook) ResetCircuitBreakers() {
ph.workerManager.resetCircuitBreakers()
}
// OnGet is called when a connection is retrieved from the pool
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error {
// NOTE: Two checks ensure we never return a connection that is marked for handoff
// or is currently in a handoff state.
// Check if the connection is usable (i.e. not in a handoff state).
// This should not happen, since the pool will not return an unusable connection.
if !conn.IsUsable() {
return ErrConnectionMarkedForHandoff
}
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
if conn.ShouldHandoff() {
return ErrConnectionMarkedForHandoff
}
return nil
}
// OnPut is called when a connection is returned to the pool
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
// First check whether the connection should be handed off, for fast rejection.
if !conn.ShouldHandoff() {
// Default behavior (no handoff): pool the connection
return true, false, nil
}
// Check for a pending handoff so the same connection is not queued twice.
if ph.workerManager.isHandoffPending(conn) {
// Default behavior (pending handoff): pool the connection
return true, false, nil
}
if err := ph.workerManager.queueHandoff(conn); err != nil {
// Failed to queue handoff, remove the connection
internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err)
// Don't pool, remove connection, no error to caller
return false, true, nil
}
// Check if handoff was already processed by a worker before we can mark it as queued
if !conn.ShouldHandoff() {
// Handoff was already processed - this is normal and the connection should be pooled
return true, false, nil
}
if err := conn.MarkQueuedForHandoff(); err != nil {
// If marking fails, check if handoff was processed in the meantime
if !conn.ShouldHandoff() {
// Handoff was processed - this is normal, pool the connection
return true, false, nil
}
// Other error - remove the connection
return false, true, nil
}
return true, false, nil
}
// Shutdown gracefully shuts down the processor, waiting for workers to complete
func (ph *PoolHook) Shutdown(ctx context.Context) error {
return ph.workerManager.shutdownWorkers(ctx)
}
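
A hedged sketch of how this hook might be wired into a connection pool from inside the go-redis module (internal/pool is not importable from outside the module). The dialer is illustrative, the nil config relies on the defaulting shown in NewPoolHookWithPoolSize, and AddPoolHook is assumed to be part of pool.Pooler because the mock pool in the tests below implements it.

package example

import (
	"context"
	"net"

	"github.com/redis/go-redis/v9/hitless"
	"github.com/redis/go-redis/v9/internal/pool"
)

func wirePoolHook(connPool pool.Pooler) *hitless.PoolHook {
	// Plain TCP dialer used when handing a connection off to a new endpoint (illustrative).
	dialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, network, addr)
	}
	hook := hitless.NewPoolHook(dialer, "tcp", nil, nil) // nil config falls back to defaults; no manager
	hook.SetPool(connPool)     // lets the hook remove connections on handoff failure
	connPool.AddPoolHook(hook) // pool calls OnGet/OnPut around checkout and check-in
	return hook
}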

964
hitless/pool_hook_test.go Normal file
View File

@@ -0,0 +1,964 @@
package hitless
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/pool"
)
// mockNetConn implements net.Conn for testing
type mockNetConn struct {
addr string
shouldFailInit bool
}
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (m *mockNetConn) Close() error { return nil }
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
type mockAddr struct {
addr string
}
func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string { return m.addr }
// createMockPoolConnection creates a mock pool connection for testing
func createMockPoolConnection() *pool.Conn {
mockNetConn := &mockNetConn{addr: "test:6379"}
conn := pool.NewConn(mockNetConn)
conn.SetUsable(true) // Make connection usable for testing
return conn
}
// mockPool implements pool.Pooler for testing
type mockPool struct {
removedConnections map[uint64]bool
mu sync.Mutex
}
func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) {
return nil, errors.New("not implemented")
}
func (mp *mockPool) CloseConn(conn *pool.Conn) error {
return nil
}
func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) {
return nil, errors.New("not implemented")
}
func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) {
// Not implemented for testing
}
func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) {
mp.mu.Lock()
defer mp.mu.Unlock()
// Use pool.Conn directly - no adapter needed
mp.removedConnections[conn.GetID()] = true
}
// WasRemoved safely checks if a connection was removed from the pool
func (mp *mockPool) WasRemoved(connID uint64) bool {
mp.mu.Lock()
defer mp.mu.Unlock()
return mp.removedConnections[connID]
}
func (mp *mockPool) Len() int {
return 0
}
func (mp *mockPool) IdleLen() int {
return 0
}
func (mp *mockPool) Stats() *pool.Stats {
return &pool.Stats{}
}
func (mp *mockPool) AddPoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
func (mp *mockPool) RemovePoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
func (mp *mockPool) Close() error {
return nil
}
// TestConnectionHook tests the Redis connection processor functionality
func TestConnectionHook(t *testing.T) {
// Create a base dialer for testing
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) {
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 1, // Use only 1 worker to ensure synchronization
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Verify connection is marked for handoff
if !conn.ShouldHandoff() {
t.Fatal("Connection should be marked for handoff")
}
// Set a mock initialization function with synchronization
initConnCalled := make(chan bool, 1)
proceedWithInit := make(chan bool, 1)
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
select {
case initConnCalled <- true:
default:
}
// Wait for test to proceed
<-proceedWithInit
return nil
}
conn.SetInitConnFunc(initConnFunc)
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
// Should pool the connection immediately (handoff queued)
if !shouldPool {
t.Error("Connection should be pooled immediately with event-driven handoff")
}
if shouldRemove {
t.Error("Connection should not be removed when queuing handoff")
}
// Wait for initialization to be called (indicates handoff started)
select {
case <-initConnCalled:
// Good, initialization was called
case <-time.After(1 * time.Second):
t.Fatal("Timeout waiting for initialization function to be called")
}
// Connection should be in pending map while initialization is blocked
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
t.Error("Connection should be in pending handoffs map")
}
// Allow initialization to proceed
proceedWithInit <- true
// Wait for handoff to complete with proper timeout and polling
timeout := time.After(2 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify handoff completed (removed from pending map)
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map after handoff")
}
// Verify connection is usable again
if !conn.IsUsable() {
t.Error("Connection should be usable after successful handoff")
}
// Verify handoff state is cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after completion")
}
})
t.Run("HandoffNotNeeded", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
// Don't mark for handoff
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error when handoff not needed: %v", err)
}
// Should pool the connection normally
if !shouldPool {
t.Error("Connection should be pooled when no handoff needed")
}
if shouldRemove {
t.Error("Connection should not be removed when no handoff needed")
}
})
t.Run("EmptyEndpoint", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error with empty endpoint: %v", err)
}
// Should pool the connection (empty endpoint clears state)
if !shouldPool {
t.Error("Connection should be pooled after clearing empty endpoint")
}
if shouldRemove {
t.Error("Connection should not be removed after clearing empty endpoint")
}
// State should be cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after clearing empty endpoint")
}
})
t.Run("EventDrivenHandoffDialerError", func(t *testing.T) {
// Create a failing base dialer
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("dial failed")
}
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 2, // Reduced retries for faster test
HandoffTimeout: 500 * time.Millisecond, // Shorter timeout for faster test
LogLevel: 2,
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not return error to caller: %v", err)
}
// Should pool the connection initially (handoff queued)
if !shouldPool {
t.Error("Connection should be pooled initially with event-driven handoff")
}
if shouldRemove {
t.Error("Connection should not be removed when queuing handoff")
}
// Wait for handoff to complete and fail with proper timeout and polling
timeout := time.After(3 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
// wait for handoff to start
time.Sleep(50 * time.Millisecond)
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for failed handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Connection should be removed from pending map after failed handoff
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map after failed handoff")
}
// Wait for retries to complete (with MaxHandoffRetries=2, it will retry twice then give up)
// Each retry has a delay of handoffTimeout/2 = 250ms, so wait for all retries to complete
time.Sleep(800 * time.Millisecond)
// After max retries are reached, the connection should be removed from pool
// and handoff state should be cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after max retries reached")
}
t.Logf("EventDrivenHandoffDialerError test completed successfully")
})
t.Run("BufferedDataRESP2", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
// For this test, we'll just verify the logic works for connections without buffered data
// The actual buffered data detection is handled by the pool's connection health check
// which is outside the scope of the Redis connection processor
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
// Should pool the connection normally (no buffered data in mock)
if !shouldPool {
t.Error("Connection should be pooled when no buffered data")
}
if shouldRemove {
t.Error("Connection should not be removed when no buffered data")
}
})
t.Run("OnGet", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should not error for normal connection: %v", err)
}
})
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
config := &Config{
Mode: MaintNotificationsAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
// Simulate a pending handoff by marking for handoff and queuing
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
// Clean up
processor.GetPendingMap().Delete(conn.GetID())
})
t.Run("EventDrivenStateManagement", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
// Test initial state - no pending handoffs
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("New connection should not have pending handoffs")
}
// Test adding to pending map
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
t.Error("Connection should be in pending map")
}
// Test OnGet with pending handoff
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
}
// Test removing from pending map and clearing handoff state
processor.GetPendingMap().Delete(conn.GetID())
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map")
}
// Clear handoff state to simulate completed handoff
conn.ClearHandoffState()
conn.SetUsable(true) // Make connection usable again
// Test OnGet without pending handoff
err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("Should not return error for non-pending connection: %v", err)
}
})
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
// Create processor with small queue to test optimization features
config := &Config{
MaxWorkers: 3,
HandoffQueueSize: 2, // Small queue to trigger optimizations
MaxHandoffRetries: 3,
LogLevel: 3, // Debug level to see optimization logs
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Add small delay to simulate network latency
time.Sleep(10 * time.Millisecond)
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create multiple connections that need handoff to fill the queue
connections := make([]*pool.Conn, 5)
for i := 0; i < 5; i++ {
connections[i] = createMockPoolConnection()
if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil {
t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err)
}
// Set a mock initialization function
connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
}
ctx := context.Background()
successCount := 0
// Process connections - should trigger scaling and timeout logic
for _, conn := range connections {
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Logf("OnPut returned error (expected with timeout): %v", err)
}
if shouldPool && !shouldRemove {
successCount++
}
}
// With timeout and scaling, most handoffs should eventually succeed
if successCount == 0 {
t.Error("Should have queued some handoffs with timeout and scaling")
}
t.Logf("Successfully queued %d handoffs with optimization features", successCount)
// Give time for workers to process and scaling to occur
time.Sleep(100 * time.Millisecond)
})
t.Run("WorkerScalingBehavior", func(t *testing.T) {
// Create processor with small queue to test scaling behavior
config := &Config{
MaxWorkers: 15, // Set to >= 10 to test explicit value preservation
HandoffQueueSize: 1, // Very small queue to force scaling
MaxHandoffRetries: 3,
LogLevel: 2, // Info level to see scaling logs
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Verify initial worker count (should be 0 with on-demand workers)
if processor.GetCurrentWorkers() != 0 {
t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers())
}
if processor.GetMaxWorkers() != 15 {
t.Errorf("Expected maxWorkers=15, got %d", processor.GetMaxWorkers())
}
// The on-demand worker behavior creates workers only when needed
// This test just verifies the basic configuration is correct
t.Logf("On-demand worker configuration verified - Max: %d, Current: %d",
processor.GetMaxWorkers(), processor.GetCurrentWorkers())
})
t.Run("PassiveTimeoutRestoration", func(t *testing.T) {
// Create processor with fast post-handoff duration for testing
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3, // Allow retries for successful handoff
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing
RelaxedTimeout: 5 * time.Second,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
ctx := context.Background()
// Create a connection and trigger handoff
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("Handoff should succeed: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after handoff")
}
// Wait for handoff to complete with proper timeout and polling
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify relaxed timeout is set with deadline
if !conn.HasRelaxedTimeout() {
t.Error("Connection should have relaxed timeout after handoff")
}
// Test that timeout is still active before deadline
// We'll use HasRelaxedTimeout which internally checks the deadline
if !conn.HasRelaxedTimeout() {
t.Error("Connection should still have active relaxed timeout before deadline")
}
// Wait for deadline to pass
time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer
// Test that timeout is automatically restored after deadline
// HasRelaxedTimeout should return false after deadline passes
if conn.HasRelaxedTimeout() {
t.Error("Connection should not have active relaxed timeout after deadline")
}
// Additional verification: calling HasRelaxedTimeout again should still return false
// and should have cleared the internal timeout values
if conn.HasRelaxedTimeout() {
t.Error("Connection should not have relaxed timeout after deadline (second check)")
}
t.Logf("Passive timeout restoration test completed successfully")
})
t.Run("UsableFlagBehavior", func(t *testing.T) {
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
ctx := context.Background()
// Create a new connection without setting it usable
mockNetConn := &mockNetConn{addr: "test:6379"}
conn := pool.NewConn(mockNetConn)
// Initially, connection should not be usable (not initialized)
if conn.IsUsable() {
t.Error("New connection should not be usable before initialization")
}
// Simulate initialization by setting usable to true
conn.SetUsable(true)
if !conn.IsUsable() {
t.Error("Connection should be usable after initialization")
}
// OnGet should succeed for usable connection
err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed for usable connection: %v", err)
}
// Mark connection for handoff
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
// Connection should still be usable until queued, but marked for handoff
if !conn.IsUsable() {
t.Error("Connection should still be usable after being marked for handoff (until queued)")
}
if !conn.ShouldHandoff() {
t.Error("Connection should be marked for handoff")
}
// OnGet should fail for connection marked for handoff
err = processor.OnGet(ctx, conn, false)
if err == nil {
t.Error("OnGet should fail for connection marked for handoff")
}
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should succeed: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after handoff")
}
// Wait for handoff to complete
time.Sleep(50 * time.Millisecond)
// After handoff completion, connection should be usable again
if !conn.IsUsable() {
t.Error("Connection should be usable after handoff completion")
}
// OnGet should succeed again
err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed after handoff completion: %v", err)
}
t.Logf("Usable flag behavior test completed successfully")
})
t.Run("StaticQueueBehavior", func(t *testing.T) {
config := &Config{
MaxWorkers: 3,
HandoffQueueSize: 50, // Explicit static queue size
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100
defer processor.Shutdown(context.Background())
// Verify queue capacity matches configured size
queueCapacity := cap(processor.GetHandoffQueue())
if queueCapacity != 50 {
t.Errorf("Expected queue capacity 50, got %d", queueCapacity)
}
// Test that queue size is static regardless of pool size
// (No dynamic resizing should occur)
ctx := context.Background()
// Fill part of the queue
for i := 0; i < 10; i++ {
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil {
t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("Failed to queue handoff %d: %v", i, err)
}
if !shouldPool || shouldRemove {
t.Errorf("conn[%d] should be pooled after handoff (shouldPool=%v, shouldRemove=%v)",
i, shouldPool, shouldRemove)
}
}
// Verify queue capacity remains static (the main purpose of this test)
finalCapacity := cap(processor.GetHandoffQueue())
if finalCapacity != 50 {
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
}
// Note: We don't check queue size here because workers process items quickly
// The important thing is that the capacity remains static regardless of pool size
})
t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) {
// Create a failing dialer that will cause handoff initialization to fail
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Return a connection that will fail during initialization
return &mockNetConn{addr: addr, shouldFailInit: true}, nil
}
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
LogLevel: 2,
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create a mock pool that tracks removals
mockPool := &mockPool{removedConnections: make(map[uint64]bool)}
processor.SetPool(mockPool)
ctx := context.Background()
// Create a connection and mark it for handoff
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a failing initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return fmt.Errorf("initialization failed")
})
// Process the connection - handoff should fail and connection should be removed
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after failed handoff attempt")
}
// Wait for handoff to be attempted and fail
time.Sleep(100 * time.Millisecond)
// Verify that the connection was removed from the pool
if !mockPool.WasRemoved(conn.GetID()) {
t.Errorf("conn[%d] should have been removed from pool after handoff failure", conn.GetID())
}
t.Logf("Connection removal on handoff failure test completed successfully")
})
t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) {
// Create config with short post-handoff duration for testing
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3, // Allow retries for successful handoff
RelaxedTimeout: 5 * time.Second,
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Fatalf("OnPut failed: %v", err)
}
if !shouldPool {
t.Error("Connection should be pooled after successful handoff")
}
if shouldRemove {
t.Error("Connection should not be removed after successful handoff")
}
// Wait for the handoff to complete (it happens asynchronously)
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify that relaxed timeout was applied to the new connection
if !conn.HasRelaxedTimeout() {
t.Error("New connection should have relaxed timeout applied after handoff")
}
// Wait for the post-handoff duration to expire
time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration
// Verify that relaxed timeout was automatically cleared
if conn.HasRelaxedTimeout() {
t.Error("Relaxed timeout should be automatically cleared after post-handoff duration")
}
})
t.Run("MarkForHandoff returns error when already marked", func(t *testing.T) {
conn := createMockPoolConnection()
// First mark should succeed
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("First MarkForHandoff should succeed: %v", err)
}
// Second mark should fail
if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil {
t.Fatal("Second MarkForHandoff should return error")
} else if err.Error() != "connection is already marked for handoff" {
t.Fatalf("Expected specific error message, got: %v", err)
}
// Verify original handoff data is preserved
if !conn.ShouldHandoff() {
t.Fatal("Connection should still be marked for handoff")
}
if conn.GetHandoffEndpoint() != "new-endpoint:6379" {
t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint())
}
if conn.GetMovingSeqID() != 1 {
t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID())
}
})
t.Run("HandoffTimeoutConfiguration", func(t *testing.T) {
// Test that HandoffTimeout from config is actually used
customTimeout := 2 * time.Second
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
HandoffTimeout: customTimeout, // Custom timeout
MaxHandoffRetries: 1, // Single retry to speed up test
LogLevel: 2,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create a connection that will test the timeout
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("test-endpoint:6379", 123); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a dialer that will check the context timeout
var timeoutVerified int32 // Use atomic for thread safety
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
// Check that the context has the expected timeout
deadline, ok := ctx.Deadline()
if !ok {
t.Error("Context should have a deadline")
return errors.New("no deadline")
}
// The deadline should be approximately customTimeout from now
expectedDeadline := time.Now().Add(customTimeout)
timeDiff := deadline.Sub(expectedDeadline)
if timeDiff < -500*time.Millisecond || timeDiff > 500*time.Millisecond {
t.Errorf("Context deadline not as expected. Expected around %v, got %v (diff: %v)",
expectedDeadline, deadline, timeDiff)
} else {
atomic.StoreInt32(&timeoutVerified, 1)
}
return nil // Successful handoff
})
// Trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(context.Background(), conn)
if err != nil {
t.Errorf("OnPut should not return error: %v", err)
}
// Connection should be queued for handoff
if !shouldPool || shouldRemove {
t.Errorf("Connection should be pooled for handoff processing")
}
// Wait for handoff to complete
time.Sleep(500 * time.Millisecond)
if atomic.LoadInt32(&timeoutVerified) == 0 {
t.Error("HandoffTimeout was not properly applied to context")
}
t.Logf("HandoffTimeout configuration test completed successfully")
})
}

View File

@@ -0,0 +1,276 @@
package hitless
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// NotificationHandler handles push notifications for the simplified manager.
type NotificationHandler struct {
manager *HitlessManager
}
// HandlePushNotification processes push notifications with hook support.
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) == 0 {
internal.Logger.Printf(ctx, "hitless: invalid notification format: %v", notification)
return ErrInvalidNotification
}
notificationType, ok := notification[0].(string)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid notification type format: %v", notification[0])
return ErrInvalidNotification
}
// Process pre-hooks - they can modify the notification or skip processing
modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, handlerCtx, notificationType, notification)
if !shouldContinue {
return nil // Hooks decided to skip processing
}
var err error
switch notificationType {
case NotificationMoving:
err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
case NotificationMigrating:
err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
case NotificationMigrated:
err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
case NotificationFailingOver:
err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
case NotificationFailedOver:
err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
default:
// Ignore other notification types (e.g., pub/sub messages)
err = nil
}
// Process post-hooks with the result
snh.manager.processPostHooks(ctx, handlerCtx, notificationType, modifiedNotification, err)
return err
}
// handleMoving processes MOVING notifications.
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 3 {
internal.Logger.Printf(ctx, "hitless: invalid MOVING notification: %v", notification)
return ErrInvalidNotification
}
seqID, ok := notification[1].(int64)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid seqID in MOVING notification: %v", notification[1])
return ErrInvalidNotification
}
// Extract timeS
timeS, ok := notification[2].(int64)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid timeS in MOVING notification: %v", notification[2])
return ErrInvalidNotification
}
newEndpoint := ""
if len(notification) > 3 {
// Extract new endpoint
newEndpoint, ok = notification[3].(string)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid newEndpoint in MOVING notification: %v", notification[3])
return ErrInvalidNotification
}
}
// Get the connection that received this notification
conn := handlerCtx.Conn
if conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MOVING notification")
return ErrInvalidNotification
}
// Type assert to get the underlying pool connection
var poolConn *pool.Conn
if pc, ok := conn.(*pool.Conn); ok {
poolConn = pc
} else {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MOVING notification - %T %#v", conn, handlerCtx)
return ErrInvalidNotification
}
// If the connection is closed or not pooled, we can ignore the notification:
// the pool won't keep a reference to this connection and it will be garbage collected.
// PubSub connections are the exception: they are not pooled but are long-lived and
// should still be handed off (the PubSub instance will reconnect and swap the
// underlying *pool.Conn).
if (poolConn.IsClosed() || !poolConn.IsPooled()) && !poolConn.IsPubSub() {
return nil
}
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
if newEndpoint == "" || newEndpoint == internal.RedisNull {
if snh.manager.config.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(ctx, "hitless: conn[%d] scheduling handoff to current endpoint in %v seconds",
poolConn.GetID(), timeS/2)
}
// same as current endpoint
newEndpoint = snh.manager.options.GetAddr()
// delay the handoff for timeS/2 seconds to the same endpoint
// do this in a goroutine to avoid blocking the notification handler
// NOTE: This timer is started while parsing the notification, so the connection is not marked for handoff
// and there should be no possibility of a race condition or double handoff.
time.AfterFunc(time.Duration(timeS/2)*time.Second, func() {
if poolConn == nil || poolConn.IsClosed() {
return
}
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
// Log error but don't fail the goroutine - use background context since original may be cancelled
internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
}
})
return nil
}
return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
}
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err)
// Connection is already marked for handoff, which is acceptable
// This can happen if multiple MOVING notifications are received for the same connection
return nil
}
// Optionally track in hitless manager for monitoring/debugging
if snh.manager != nil {
connID := conn.GetID()
// Track the operation (ignore errors since this is optional)
_ = snh.manager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
} else {
return fmt.Errorf("hitless: manager not initialized")
}
return nil
}
// handleMigrating processes MIGRATING notifications.
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATING notifications indicate that a connection is about to be migrated
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid MIGRATING notification: %v", notification)
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATING notification")
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATING notification")
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
internal.Logger.Printf(ctx, "hitless: conn[%d] applying relaxed timeout (%v) for MIGRATING notification",
conn.GetID(),
snh.manager.config.RelaxedTimeout)
}
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleMigrated processes MIGRATED notifications.
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATED notifications indicate that a connection migration has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid MIGRATED notification: %v", notification)
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATED notification")
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATED notification")
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
connID := conn.GetID()
internal.Logger.Printf(ctx, "hitless: conn[%d] clearing relaxed timeout for MIGRATED notification", connID)
}
conn.ClearRelaxedTimeout()
return nil
}
// handleFailingOver processes FAILING_OVER notifications.
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILING_OVER notifications indicate that a connection is about to failover
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid FAILING_OVER notification: %v", notification)
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILING_OVER notification")
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILING_OVER notification")
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
connID := conn.GetID()
internal.Logger.Printf(ctx, "hitless: conn[%d] applying relaxed timeout (%v) for FAILING_OVER notification", connID, snh.manager.config.RelaxedTimeout)
}
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleFailedOver processes FAILED_OVER notifications.
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILED_OVER notifications indicate that a connection failover has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, "hitless: invalid FAILED_OVER notification: %v", notification)
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILED_OVER notification")
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILED_OVER notification")
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
if snh.manager.config.LogLevel.InfoOrAbove() { // Info level
connID := conn.GetID()
internal.Logger.Printf(ctx, "hitless: conn[%d] clearing relaxed timeout for FAILED_OVER notification", connID)
}
conn.ClearRelaxedTimeout()
return nil
}
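
For reference, a self-contained sketch of the payload shape handleMoving parses above; the concrete values are illustrative only.

package main

import (
	"fmt"
	"time"
)

func main() {
	// ["MOVING", seqNum, timeS, endpoint] - the endpoint may be absent or null,
	// in which case the handler schedules a handoff to the current endpoint after timeS/2.
	notification := []interface{}{"MOVING", int64(7), int64(15), "10.0.0.5:6379"}

	seqID := notification[1].(int64)
	timeS := notification[2].(int64)
	endpoint := notification[3].(string)
	deadline := time.Now().Add(time.Duration(timeS) * time.Second)

	fmt.Printf("seq=%d endpoint=%s deadline=%s\n", seqID, endpoint, deadline.Format(time.RFC3339))
}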

24
hitless/state.go Normal file
View File

@@ -0,0 +1,24 @@
package hitless
// State represents the current state of a hitless upgrade operation.
type State int
const (
// StateIdle indicates no upgrade is in progress
StateIdle State = iota
// StateMoving indicates a connection handoff is in progress
StateMoving
)
// String returns a string representation of the state.
func (s State) String() string {
switch s {
case StateIdle:
return "idle"
case StateMoving:
return "moving"
default:
return "unknown"
}
}
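
A tiny sketch of the State type above; the zero value is StateIdle and both values render through String().

package main

import (
	"fmt"

	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	var s hitless.State              // zero value is StateIdle
	fmt.Println(s)                   // prints "idle"
	fmt.Println(hitless.StateMoving) // prints "moving"
}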

View File

@@ -0,0 +1,54 @@
// Package interfaces provides shared interfaces used by both the main redis package
// and the hitless upgrade package to avoid circular dependencies.
package interfaces
import (
"context"
"net"
"time"
)
// NotificationProcessor mirrors push.NotificationProcessor.
// It is declared here (forward declaration) to avoid a circular import.
type NotificationProcessor interface {
RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
UnregisterHandler(pushNotificationName string) error
GetHandler(pushNotificationName string) interface{}
}
// ClientInterface defines the interface that clients must implement for hitless upgrades.
type ClientInterface interface {
// GetOptions returns the client options.
GetOptions() OptionsInterface
// GetPushProcessor returns the client's push notification processor.
GetPushProcessor() NotificationProcessor
}
// OptionsInterface defines the interface for client options.
// Uses an adapter pattern to avoid circular dependencies.
type OptionsInterface interface {
// GetReadTimeout returns the read timeout.
GetReadTimeout() time.Duration
// GetWriteTimeout returns the write timeout.
GetWriteTimeout() time.Duration
// GetNetwork returns the network type.
GetNetwork() string
// GetAddr returns the connection address.
GetAddr() string
// IsTLSEnabled returns true if TLS is enabled.
IsTLSEnabled() bool
// GetProtocol returns the protocol version.
GetProtocol() int
// GetPoolSize returns the connection pool size.
GetPoolSize() int
// NewDialer returns a new dialer function for the connection.
NewDialer() func(context.Context) (net.Conn, error)
}

View File

@@ -14,26 +14,20 @@ type Logging interface {
Printf(ctx context.Context, format string, v ...interface{})
}
type logger struct {
type DefaultLogger struct {
log *log.Logger
}
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
func (l *DefaultLogger) Printf(ctx context.Context, format string, v ...interface{}) {
_ = l.log.Output(2, fmt.Sprintf(format, v...))
}
func NewDefaultLogger() Logging {
return &DefaultLogger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
}
}
// Logger calls Output to print to the stderr.
// Arguments are handled in the manner of fmt.Print.
var Logger Logging = &logger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
}
// VoidLogger is a logger that does nothing.
// Used to disable logging and thus speed up the library.
type VoidLogger struct{}
func (v *VoidLogger) Printf(_ context.Context, _ string, _ ...interface{}) {
// do nothing
}
var _ Logging = (*VoidLogger)(nil)
var Logger Logging = NewDefaultLogger()
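
A short sketch, assuming this hunk belongs to the module-internal logger package (the internal package is only importable from within go-redis): logging can be silenced by swapping in the VoidLogger defined above.

package example // must live inside the go-redis module to import internal

import "github.com/redis/go-redis/v9/internal"

// disableRedisLogging replaces the package logger with the no-op VoidLogger.
func disableRedisLogging() {
	internal.Logger = &internal.VoidLogger{}
}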

View File

@@ -2,6 +2,7 @@ package pool_test
import (
"context"
"errors"
"fmt"
"testing"
"time"
@@ -31,7 +32,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
@@ -75,7 +76,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
@@ -89,7 +90,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
if err != nil {
b.Fatal(err)
}
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("Bench test remove"))
}
})
})

View File

@@ -26,7 +26,7 @@ var _ = Describe("Buffer Size Configuration", func() {
It("should use default buffer sizes when not specified", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
})
@@ -48,7 +48,7 @@ var _ = Describe("Buffer Size Configuration", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: customReadSize,
WriteBufferSize: customWriteSize,
@@ -69,7 +69,7 @@ var _ = Describe("Buffer Size Configuration", func() {
It("should handle zero buffer sizes by using defaults", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: 0, // Should use default
WriteBufferSize: 0, // Should use default
@@ -105,7 +105,7 @@ var _ = Describe("Buffer Size Configuration", func() {
// without setting ReadBufferSize and WriteBufferSize
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
// ReadBufferSize and WriteBufferSize are not set (will be 0)
})

View File

@@ -3,7 +3,10 @@ package pool
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
@@ -12,17 +15,65 @@ import (
var noDeadline = time.Time{}
// Global atomic counter for connection IDs
var connIDCounter uint64
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
type atomicNetConn struct {
conn net.Conn
}
// generateConnID generates a fast unique identifier for a connection with zero allocations
func generateConnID() uint64 {
return atomic.AddUint64(&connIDCounter, 1)
}
type Conn struct {
usedAt int64 // atomic
netConn net.Conn
usedAt int64 // atomic
// Lock-free netConn access using atomic.Value
// Contains *atomicNetConn wrapper, accessed atomically for better performance
netConnAtomic atomic.Value // stores *atomicNetConn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
Inited bool
// Lightweight mutex to protect reader operations during handoff
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
readerMu sync.RWMutex
Inited atomic.Bool
pooled bool
pubsub bool
closed atomic.Bool
createdAt time.Time
expiresAt time.Time
// Hitless upgrade support: relaxed timeouts during migrations/failovers
// Using atomic operations for lock-free access to avoid mutex contention
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
// Counter to track multiple relaxed timeout setters if we have nested calls
// will be decremented when ClearRelaxedTimeout is called or deadline is reached
// if counter reaches 0, we clear the relaxed timeouts
relaxedCounter atomic.Int32
// Connection initialization function for reconnections
initConnFunc func(context.Context, *Conn) error
// Connection identifier for unique tracking across handoffs
id uint64 // Unique numeric identifier for this connection
// Handoff state - using atomic operations for lock-free access
usableAtomic atomic.Bool // Connection usability state
shouldHandoffAtomic atomic.Bool // Whether connection should be handed off
movingSeqIDAtomic atomic.Int64 // Sequence ID from MOVING notification
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
// newEndpointAtomic needs special handling as it's a string
newEndpointAtomic atomic.Value // stores string
onClose func() error
}
@@ -33,8 +84,8 @@ func NewConn(netConn net.Conn) *Conn {
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
cn := &Conn{
netConn: netConn,
createdAt: time.Now(),
id: generateConnID(), // Generate unique ID for this connection
}
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
@@ -50,6 +101,16 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
}
// Store netConn atomically for lock-free access using wrapper
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
// Initialize atomic handoff state
cn.usableAtomic.Store(false) // false initially, set to true after initialization
cn.shouldHandoffAtomic.Store(false) // false initially
cn.movingSeqIDAtomic.Store(0) // 0 initially
cn.handoffRetriesAtomic.Store(0) // 0 initially
cn.newEndpointAtomic.Store("") // empty string initially
cn.wr = proto.NewWriter(cn.bw)
cn.SetUsedAt(time.Now())
return cn
@@ -64,23 +125,366 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix())
}
// getNetConn returns the current network connection using atomic load (lock-free).
// This is the fast path for accessing netConn without mutex overhead.
func (cn *Conn) getNetConn() net.Conn {
if v := cn.netConnAtomic.Load(); v != nil {
if wrapper, ok := v.(*atomicNetConn); ok {
return wrapper.conn
}
}
return nil
}
// setNetConn stores the network connection atomically (lock-free).
// This is used for the fast path of connection replacement.
func (cn *Conn) setNetConn(netConn net.Conn) {
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
}
// Lock-free helper methods for handoff state management
// isUsable returns true if the connection is safe to use (lock-free).
func (cn *Conn) isUsable() bool {
return cn.usableAtomic.Load()
}
// setUsable sets the usable flag atomically (lock-free).
func (cn *Conn) setUsable(usable bool) {
cn.usableAtomic.Store(usable)
}
// shouldHandoff returns true if connection needs handoff (lock-free).
func (cn *Conn) shouldHandoff() bool {
return cn.shouldHandoffAtomic.Load()
}
// setShouldHandoff sets the handoff flag atomically (lock-free).
func (cn *Conn) setShouldHandoff(should bool) {
cn.shouldHandoffAtomic.Store(should)
}
// getMovingSeqID returns the sequence ID atomically (lock-free).
func (cn *Conn) getMovingSeqID() int64 {
return cn.movingSeqIDAtomic.Load()
}
// setMovingSeqID sets the sequence ID atomically (lock-free).
func (cn *Conn) setMovingSeqID(seqID int64) {
cn.movingSeqIDAtomic.Store(seqID)
}
// getNewEndpoint returns the new endpoint atomically (lock-free).
func (cn *Conn) getNewEndpoint() string {
if endpoint := cn.newEndpointAtomic.Load(); endpoint != nil {
return endpoint.(string)
}
return ""
}
// setNewEndpoint sets the new endpoint atomically (lock-free).
func (cn *Conn) setNewEndpoint(endpoint string) {
cn.newEndpointAtomic.Store(endpoint)
}
// setHandoffRetries sets the retry count atomically (lock-free).
func (cn *Conn) setHandoffRetries(retries int) {
cn.handoffRetriesAtomic.Store(uint32(retries))
}
// incrementHandoffRetries atomically increments and returns the new retry count (lock-free).
func (cn *Conn) incrementHandoffRetries(delta int) int {
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
}
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
func (cn *Conn) IsUsable() bool {
return cn.isUsable()
}
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
func (cn *Conn) IsPooled() bool {
return cn.pooled
}
// IsPubSub returns true if the connection is used for PubSub.
func (cn *Conn) IsPubSub() bool {
return cn.pubsub
}
func (cn *Conn) IsInited() bool {
return cn.Inited.Load()
}
// SetUsable sets the usable flag for the connection (lock-free).
func (cn *Conn) SetUsable(usable bool) {
cn.setUsable(usable)
}
// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades.
// These timeouts are used for all subsequent commands until ClearRelaxedTimeout is called
// (or, when set via SetRelaxedTimeoutWithDeadline, until the deadline expires).
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
cn.relaxedCounter.Add(1)
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
}
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
cn.SetRelaxedTimeout(readTimeout, writeTimeout)
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
}
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
// Uses atomic operations for lock-free access.
func (cn *Conn) ClearRelaxedTimeout() {
// Atomically decrement counter and check if we should clear
newCount := cn.relaxedCounter.Add(-1)
if newCount <= 0 {
// Use atomic load to get current value for CAS to avoid stale value race
current := cn.relaxedCounter.Load()
if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) {
cn.clearRelaxedTimeout()
}
}
}
func (cn *Conn) clearRelaxedTimeout() {
cn.relaxedReadTimeoutNs.Store(0)
cn.relaxedWriteTimeoutNs.Store(0)
cn.relaxedDeadlineNs.Store(0)
cn.relaxedCounter.Store(0)
}
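// A minimal usage sketch of the nested relaxed-timeout counter (caller-side,
// illustrative values; not part of this diff):
//
//	cn.SetRelaxedTimeout(30*time.Second, 30*time.Second) // counter: 1
//	cn.SetRelaxedTimeout(30*time.Second, 30*time.Second) // counter: 2 (nested caller)
//	cn.ClearRelaxedTimeout()                             // counter: 1, timeouts stay relaxed
//	cn.ClearRelaxedTimeout()                             // counter: 0, timeouts cleared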
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
// This checks both the timeout values and the deadline (if set).
// Uses atomic operations for lock-free access.
func (cn *Conn) HasRelaxedTimeout() bool {
// Fast path: no relaxed timeouts are set
if cn.relaxedCounter.Load() <= 0 {
return false
}
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// If no relaxed timeouts are set, return false
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
return false
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, relaxed timeouts are active
if deadlineNs == 0 {
return true
}
// If deadline is set, check if it's still in the future
return time.Now().UnixNano() < deadlineNs
}
// getEffectiveReadTimeout returns the timeout to use for read operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
// Fast path: no relaxed timeout set
if readTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(readTimeoutNs)
}
nowNs := time.Now().UnixNano()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(readTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
cn.relaxedCounter.Add(-1)
if cn.relaxedCounter.Load() <= 0 {
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
// getEffectiveWriteTimeout returns the timeout to use for write operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// Fast path: no relaxed timeout set
if writeTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(writeTimeoutNs)
}
nowNs := time.Now().UnixNano()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(writeTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
cn.relaxedCounter.Add(-1)
if cn.relaxedCounter.Load() <= 0 {
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
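// Precedence sketch for the effective-timeout helpers (hypothetical values; not part
// of this diff):
//
//	cn.SetRelaxedTimeoutWithDeadline(60*time.Second, 60*time.Second, time.Now().Add(10*time.Second))
//	_ = cn.getEffectiveReadTimeout(3 * time.Second) // 60s while the deadline is in the future
//	// After the deadline passes, the same call returns 3s and clears the relaxed state.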
func (cn *Conn) SetOnClose(fn func() error) {
cn.onClose = fn
}
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
cn.initConnFunc = fn
}
// ExecuteInitConn runs the stored connection initialization function if available.
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
if cn.initConnFunc != nil {
return cn.initConnFunc(ctx, cn)
}
return fmt.Errorf("redis: no initConnFunc set for conn[%d]", cn.GetID())
}
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
// Store the new connection atomically first (lock-free)
cn.setNetConn(netConn)
// Protect reader reset operations to avoid data races
// Use write lock since we're modifying the reader state
cn.readerMu.Lock()
cn.rd.Reset(netConn)
cn.readerMu.Unlock()
cn.bw.Reset(netConn)
}
// GetNetConn safely returns the current network connection using atomic load (lock-free).
// This method is used by the pool for health checks and provides better performance.
func (cn *Conn) GetNetConn() net.Conn {
return cn.getNetConn()
}
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
// New connection is not initialized yet
cn.Inited.Store(false)
// Replace the underlying connection
cn.SetNetConn(netConn)
return cn.ExecuteInitConn(ctx)
}
// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
// Returns an error if the connection is already marked for handoff.
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
// Use single atomic CAS operation for state transition
if !cn.shouldHandoffAtomic.CompareAndSwap(false, true) {
return errors.New("connection is already marked for handoff")
}
cn.setNewEndpoint(newEndpoint)
cn.setMovingSeqID(seqID)
return nil
}
func (cn *Conn) MarkQueuedForHandoff() error {
// Use single atomic CAS operation for state transition
if !cn.shouldHandoffAtomic.CompareAndSwap(true, false) {
return errors.New("connection was not marked for handoff")
}
cn.setUsable(false)
return nil
}
// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
func (cn *Conn) ShouldHandoff() bool {
return cn.shouldHandoff()
}
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
func (cn *Conn) GetHandoffEndpoint() string {
return cn.getNewEndpoint()
}
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
func (cn *Conn) GetMovingSeqID() int64 {
return cn.getMovingSeqID()
}
// GetID returns the unique identifier for this connection.
func (cn *Conn) GetID() uint64 {
return cn.id
}
// ClearHandoffState clears the handoff state after successful handoff (lock-free).
func (cn *Conn) ClearHandoffState() {
// clear handoff state
cn.setShouldHandoff(false)
cn.setNewEndpoint("")
cn.setMovingSeqID(0)
cn.setHandoffRetries(0)
cn.setUsable(true) // Connection is safe to use again after handoff completes
}
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
return cn.incrementHandoffRetries(n)
}
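// Handoff lifecycle sketch (endpoint and sequence ID are hypothetical; not part of
// this diff):
//
//	_ = cn.MarkForHandoff("new-node:6379", 42) // MOVING received: record endpoint and seqID
//	_ = cn.MarkQueuedForHandoff()              // handoff worker picked it up: conn becomes unusable
//	// ...dial the new endpoint and call cn.SetNetConnAndInitConn(ctx, newNetConn), then:
//	cn.ClearHandoffState()                     // conn is usable again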
// HasBufferedData safely checks if the connection has buffered data.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) HasBufferedData() bool {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
return cn.rd.Buffered() > 0
}
// PeekReplyTypeSafe safely peeks at the reply type.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
if cn.rd.Buffered() <= 0 {
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
}
return cn.rd.PeekReplyType()
}
func (cn *Conn) Write(b []byte) (int, error) {
return cn.netConn.Write(b)
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Write(b)
}
return 0, net.ErrClosed
}
func (cn *Conn) RemoteAddr() net.Addr {
if cn.netConn != nil {
return cn.netConn.RemoteAddr()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.RemoteAddr()
}
return nil
}
@@ -89,7 +493,16 @@ func (cn *Conn) WithReader(
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
// Get the connection directly from atomic storage
netConn := cn.getNetConn()
if netConn == nil {
return fmt.Errorf("redis: connection not available")
}
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
}
@@ -100,13 +513,26 @@ func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
return err
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
// Set the write deadline when a connection is available; if getNetConn() returns nil,
// fail fast instead so write operations cannot hang indefinitely.
if netConn := cn.getNetConn(); netConn != nil {
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
} else {
// If getNetConn() returns nil, we still need to respect the timeout
// Return an error to prevent indefinite blocking
return fmt.Errorf("redis: connection not available for write operation")
}
}
if cn.bw.Buffered() > 0 {
cn.bw.Reset(cn.netConn)
if netConn := cn.getNetConn(); netConn != nil {
cn.bw.Reset(netConn)
}
}
if err := fn(cn.wr); err != nil {
@@ -116,19 +542,33 @@ func (cn *Conn) WithWriter(
return cn.bw.Flush()
}
func (cn *Conn) IsClosed() bool {
return cn.closed.Load()
}
func (cn *Conn) Close() error {
cn.closed.Store(true)
if cn.onClose != nil {
// ignore error
_ = cn.onClose()
}
return cn.netConn.Close()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Close()
}
return nil
}
// MaybeHasData tries to peek at the next byte in the socket without consuming it
// This is used to check if there are push notifications available
// Important: This will work on Linux, but not on Windows
func (cn *Conn) MaybeHasData() bool {
return maybeHasData(cn.netConn)
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return maybeHasData(netConn)
}
return false
}
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {

View File

@@ -0,0 +1,92 @@
package pool
import (
"net"
"sync"
"testing"
"time"
)
// TestConcurrentRelaxedTimeoutClearing tests the race condition fix in ClearRelaxedTimeout
func TestConcurrentRelaxedTimeoutClearing(t *testing.T) {
// Create a dummy connection for testing
netConn := &net.TCPConn{}
cn := NewConn(netConn)
defer cn.Close()
// Set relaxed timeout multiple times to increase counter
cn.SetRelaxedTimeout(time.Second, time.Second)
cn.SetRelaxedTimeout(time.Second, time.Second)
cn.SetRelaxedTimeout(time.Second, time.Second)
// Verify counter is 3
if count := cn.relaxedCounter.Load(); count != 3 {
t.Errorf("Expected relaxed counter to be 3, got %d", count)
}
// Clear timeouts concurrently to test race condition fix
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
cn.ClearRelaxedTimeout()
}()
}
wg.Wait()
// Verify counter is 0 and timeouts are cleared
if count := cn.relaxedCounter.Load(); count != 0 {
t.Errorf("Expected relaxed counter to be 0 after clearing, got %d", count)
}
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed read timeout to be 0, got %d", timeout)
}
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed write timeout to be 0, got %d", timeout)
}
}
// TestRelaxedTimeoutCounterRaceCondition tests the specific race condition scenario
func TestRelaxedTimeoutCounterRaceCondition(t *testing.T) {
netConn := &net.TCPConn{}
cn := NewConn(netConn)
defer cn.Close()
// Set relaxed timeout once
cn.SetRelaxedTimeout(time.Second, time.Second)
// Verify counter is 1
if count := cn.relaxedCounter.Load(); count != 1 {
t.Errorf("Expected relaxed counter to be 1, got %d", count)
}
// Test concurrent clearing with race condition scenario
var wg sync.WaitGroup
// Multiple goroutines try to clear simultaneously
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
cn.ClearRelaxedTimeout()
}()
}
wg.Wait()
// Verify final state is consistent
if count := cn.relaxedCounter.Load(); count != 0 {
t.Errorf("Expected relaxed counter to be 0 after concurrent clearing, got %d", count)
}
// Verify timeouts are actually cleared
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed read timeout to be cleared, got %d", timeout)
}
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed write timeout to be cleared, got %d", timeout)
}
if deadline := cn.relaxedDeadlineNs.Load(); deadline != 0 {
t.Errorf("Expected relaxed deadline to be cleared, got %d", deadline)
}
}

View File

@@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) {
}
func (cn *Conn) NetConn() net.Conn {
return cn.netConn
return cn.getNetConn()
}
func (p *ConnPool) CheckMinIdleConns() {

114
internal/pool/hooks.go Normal file
View File

@@ -0,0 +1,114 @@
package pool
import (
"context"
"sync"
)
// PoolHook defines the interface for connection lifecycle hooks.
type PoolHook interface {
// OnGet is called when a connection is retrieved from the pool.
// It can modify the connection or return an error to prevent its use.
// The isNewConn flag indicates whether the connection was newly created (rather than
// taken from the idle pool) and can be used to gather pool hit/miss metrics.
OnGet(ctx context.Context, conn *Conn, isNewConn bool) error
// OnPut is called when a connection is returned to the pool.
// It returns whether the connection should be pooled and whether it should be removed.
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
}
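// A sketch of a hook that uses the isNewConn flag for hit/miss metrics (illustrative
// only, assuming "sync/atomic"; not part of this file):
//
//	type metricsHook struct{ hits, misses atomic.Uint64 }
//
//	func (h *metricsHook) OnGet(ctx context.Context, cn *Conn, isNewConn bool) error {
//		if isNewConn {
//			h.misses.Add(1) // pool miss: a fresh connection was dialed
//		} else {
//			h.hits.Add(1) // pool hit: an idle connection was reused
//		}
//		return nil
//	}
//
//	func (h *metricsHook) OnPut(ctx context.Context, cn *Conn) (bool, bool, error) {
//		return true, false, nil // pool the connection, do not remove it
//	}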
// PoolHookManager manages multiple pool hooks.
type PoolHookManager struct {
hooks []PoolHook
hooksMu sync.RWMutex
}
// NewPoolHookManager creates a new pool hook manager.
func NewPoolHookManager() *PoolHookManager {
return &PoolHookManager{
hooks: make([]PoolHook, 0),
}
}
// AddHook adds a pool hook to the manager.
// Hooks are called in the order they were added.
func (phm *PoolHookManager) AddHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
phm.hooks = append(phm.hooks, hook)
}
// RemoveHook removes a pool hook from the manager.
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
for i, h := range phm.hooks {
if h == hook {
// Remove hook by swapping with last element and truncating
phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
phm.hooks = phm.hooks[:len(phm.hooks)-1]
break
}
}
}
// ProcessOnGet calls all OnGet hooks in order.
// If any hook returns an error, processing stops and the error is returned.
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
for _, hook := range phm.hooks {
if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
return err
}
}
return nil
}
// ProcessOnPut calls all OnPut hooks in order.
// A hook that returns shouldRemove=true stops processing immediately; a shouldPool=false
// result is recorded, but the remaining hooks still run.
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
shouldPool = true // Default to pooling the connection
for _, hook := range phm.hooks {
hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)
if hookErr != nil {
return false, true, hookErr
}
// If any hook says to remove or not pool, respect that decision
if hookShouldRemove {
return false, true, nil
}
if !hookShouldPool {
shouldPool = false
}
}
return shouldPool, false, nil
}
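// Decision sketch (illustrative): with hooks A and B, if A returns (shouldPool=false,
// shouldRemove=false) and B returns (true, false), both hooks run and the result is
// shouldPool=false, shouldRemove=false; if A instead returned shouldRemove=true,
// ProcessOnPut would return (false, true, nil) without calling B.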
// GetHookCount returns the number of registered hooks (for testing).
func (phm *PoolHookManager) GetHookCount() int {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
return len(phm.hooks)
}
// GetHooks returns a copy of all registered hooks.
func (phm *PoolHookManager) GetHooks() []PoolHook {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
hooks := make([]PoolHook, len(phm.hooks))
copy(hooks, phm.hooks)
return hooks
}

213
internal/pool/hooks_test.go Normal file
View File

@@ -0,0 +1,213 @@
package pool
import (
"context"
"errors"
"net"
"testing"
"time"
)
// TestHook for testing hook functionality
type TestHook struct {
OnGetCalled int
OnPutCalled int
GetError error
PutError error
ShouldPool bool
ShouldRemove bool
}
func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
th.OnGetCalled++
return th.GetError
}
func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
th.OnPutCalled++
return th.ShouldPool, th.ShouldRemove, th.PutError
}
func TestPoolHookManager(t *testing.T) {
manager := NewPoolHookManager()
// Test initial state
if manager.GetHookCount() != 0 {
t.Errorf("Expected 0 hooks initially, got %d", manager.GetHookCount())
}
// Add hooks
hook1 := &TestHook{ShouldPool: true}
hook2 := &TestHook{ShouldPool: true}
manager.AddHook(hook1)
manager.AddHook(hook2)
if manager.GetHookCount() != 2 {
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
}
// Test ProcessOnGet
ctx := context.Background()
conn := &Conn{} // Mock connection
err := manager.ProcessOnGet(ctx, conn, false)
if err != nil {
t.Errorf("ProcessOnGet should not error: %v", err)
}
if hook1.OnGetCalled != 1 {
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
}
if hook2.OnGetCalled != 1 {
t.Errorf("Expected hook2.OnGetCalled to be 1, got %d", hook2.OnGetCalled)
}
// Test ProcessOnPut
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
if err != nil {
t.Errorf("ProcessOnPut should not error: %v", err)
}
if !shouldPool {
t.Error("Expected shouldPool to be true")
}
if shouldRemove {
t.Error("Expected shouldRemove to be false")
}
if hook1.OnPutCalled != 1 {
t.Errorf("Expected hook1.OnPutCalled to be 1, got %d", hook1.OnPutCalled)
}
if hook2.OnPutCalled != 1 {
t.Errorf("Expected hook2.OnPutCalled to be 1, got %d", hook2.OnPutCalled)
}
// Remove a hook
manager.RemoveHook(hook1)
if manager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
}
}
func TestHookErrorHandling(t *testing.T) {
manager := NewPoolHookManager()
// Hook that returns error on Get
errorHook := &TestHook{
GetError: errors.New("test error"),
ShouldPool: true,
}
normalHook := &TestHook{ShouldPool: true}
manager.AddHook(errorHook)
manager.AddHook(normalHook)
ctx := context.Background()
conn := &Conn{}
// Test that error stops processing
err := manager.ProcessOnGet(ctx, conn, false)
if err == nil {
t.Error("Expected error from ProcessOnGet")
}
if errorHook.OnGetCalled != 1 {
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
}
// normalHook should not be called due to error
if normalHook.OnGetCalled != 0 {
t.Errorf("Expected normalHook.OnGetCalled to be 0, got %d", normalHook.OnGetCalled)
}
}
func TestHookShouldRemove(t *testing.T) {
manager := NewPoolHookManager()
// Hook that says to remove connection
removeHook := &TestHook{
ShouldPool: false,
ShouldRemove: true,
}
normalHook := &TestHook{ShouldPool: true}
manager.AddHook(removeHook)
manager.AddHook(normalHook)
ctx := context.Background()
conn := &Conn{}
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
if err != nil {
t.Errorf("ProcessOnPut should not error: %v", err)
}
if shouldPool {
t.Error("Expected shouldPool to be false")
}
if !shouldRemove {
t.Error("Expected shouldRemove to be true")
}
if removeHook.OnPutCalled != 1 {
t.Errorf("Expected removeHook.OnPutCalled to be 1, got %d", removeHook.OnPutCalled)
}
// normalHook should not be called due to early return
if normalHook.OnPutCalled != 0 {
t.Errorf("Expected normalHook.OnPutCalled to be 0, got %d", normalHook.OnPutCalled)
}
}
func TestPoolWithHooks(t *testing.T) {
// Create a pool with hooks
hookManager := NewPoolHookManager()
testHook := &TestHook{ShouldPool: true}
hookManager.AddHook(testHook)
opt := &Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil // Mock connection
},
PoolSize: 1,
DialTimeout: time.Second,
}
pool := NewConnPool(opt)
defer pool.Close()
// Add hook to pool after creation
pool.AddPoolHook(testHook)
// Verify hooks are initialized
if pool.hookManager == nil {
t.Error("Expected hookManager to be initialized")
}
if pool.hookManager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount())
}
// Test adding hook to pool
additionalHook := &TestHook{ShouldPool: true}
pool.AddPoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 2 {
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount())
}
// Test removing hook from pool
pool.RemovePoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount())
}
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/util"
)
var (
@@ -22,6 +23,23 @@ var (
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
ErrPoolTimeout = errors.New("redis: connection pool timeout")
// popAttempts is the maximum number of attempts to find a usable connection
// when popping from the idle connection pool. This handles cases where connections
// are temporarily marked as unusable (e.g., during hitless upgrades or network issues).
// Value of 50 provides sufficient resilience without excessive overhead.
// This is capped by the idle connection count, so we won't loop excessively.
popAttempts = 50
// getAttempts is the maximum number of attempts to get a connection that passes
// hook validation (e.g., hitless upgrade hooks). This protects against race conditions
// where hooks might temporarily reject connections during cluster transitions.
// Value of 3 balances resilience with performance - most hook rejections resolve quickly.
getAttempts = 3
minTime = time.Unix(-2208988800, 0) // Jan 1, 1900
maxTime = minTime.Add(1<<63 - 1)
noExpiration = maxTime
)
var timers = sync.Pool{
@@ -38,11 +56,14 @@ type Stats struct {
Misses uint32 // number of times free connection was NOT found in the pool
Timeouts uint32 // number of times a wait timeout occurred
WaitCount uint32 // number of times a connection was waited
Unusable uint32 // number of times a connection was found to be unusable
WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds
TotalConns uint32 // number of total connections in the pool
IdleConns uint32 // number of idle connections in the pool
StaleConns uint32 // number of stale connections removed from the pool
PubSubStats PubSubStats
}
type Pooler interface {
@@ -57,29 +78,35 @@ type Pooler interface {
IdleLen() int
Stats() *Stats
AddPoolHook(hook PoolHook)
RemovePoolHook(hook PoolHook)
Close() error
}
type Options struct {
Dialer func(context.Context) (net.Conn, error)
PoolFIFO bool
PoolSize int
DialTimeout time.Duration
PoolTimeout time.Duration
MinIdleConns int
MaxIdleConns int
MaxActiveConns int
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
// Protocol version for optimization (3 = RESP3 with push notifications, 2 = RESP2 without)
Protocol int
Dialer func(context.Context) (net.Conn, error)
ReadBufferSize int
WriteBufferSize int
PoolFIFO bool
PoolSize int32
DialTimeout time.Duration
PoolTimeout time.Duration
MinIdleConns int32
MaxIdleConns int32
MaxActiveConns int32
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
PushNotificationsEnabled bool
// DialerRetries is the maximum number of retry attempts when dialing fails.
// Default: 5
DialerRetries int
// DialerRetryTimeout is the backoff duration between retry attempts.
// Default: 100ms
DialerRetryTimeout time.Duration
}
type lastDialErrorWrap struct {
@@ -95,16 +122,21 @@ type ConnPool struct {
queue chan struct{}
connsMu sync.Mutex
conns []*Conn
conns map[uint64]*Conn
idleConns []*Conn
poolSize int
idleConnsLen int
poolSize atomic.Int32
idleConnsLen atomic.Int32
idleCheckInProgress atomic.Bool
stats Stats
waitDurationNs atomic.Int64
_closed uint32 // atomic
// Pool hooks manager for flexible connection processing
hookManagerMu sync.RWMutex
hookManager *PoolHookManager
}
var _ Pooler = (*ConnPool)(nil)
@@ -114,34 +146,69 @@ func NewConnPool(opt *Options) *ConnPool {
cfg: opt,
queue: make(chan struct{}, opt.PoolSize),
conns: make([]*Conn, 0, opt.PoolSize),
conns: make(map[uint64]*Conn),
idleConns: make([]*Conn, 0, opt.PoolSize),
}
p.connsMu.Lock()
p.checkMinIdleConns()
p.connsMu.Unlock()
// Only create MinIdleConns if explicitly requested (> 0)
// This avoids creating connections during pool initialization for tests
if opt.MinIdleConns > 0 {
p.connsMu.Lock()
p.checkMinIdleConns()
p.connsMu.Unlock()
}
return p
}
// initializeHooks sets up the pool hooks system.
func (p *ConnPool) initializeHooks() {
p.hookManager = NewPoolHookManager()
}
// AddPoolHook adds a pool hook to the pool.
func (p *ConnPool) AddPoolHook(hook PoolHook) {
p.hookManagerMu.Lock()
defer p.hookManagerMu.Unlock()
if p.hookManager == nil {
p.initializeHooks()
}
p.hookManager.AddHook(hook)
}
// RemovePoolHook removes a pool hook from the pool.
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
p.hookManagerMu.Lock()
defer p.hookManagerMu.Unlock()
if p.hookManager != nil {
p.hookManager.RemoveHook(hook)
}
}
func (p *ConnPool) checkMinIdleConns() {
if !p.idleCheckInProgress.CompareAndSwap(false, true) {
return
}
defer p.idleCheckInProgress.Store(false)
if p.cfg.MinIdleConns == 0 {
return
}
for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
// Only create idle connections if we haven't reached the total pool size limit
// MinIdleConns should be a subset of PoolSize, not additional connections
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
select {
case p.queue <- struct{}{}:
p.poolSize++
p.idleConnsLen++
p.poolSize.Add(1)
p.idleConnsLen.Add(1)
go func() {
defer func() {
if err := recover(); err != nil {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
p.freeTurn()
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
@@ -150,12 +217,9 @@ func (p *ConnPool) checkMinIdleConns() {
err := p.addIdleConn()
if err != nil && err != ErrClosed {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
}
p.freeTurn()
}()
default:
@@ -172,6 +236,9 @@ func (p *ConnPool) addIdleConn() error {
if err != nil {
return err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
p.connsMu.Lock()
defer p.connsMu.Unlock()
@@ -182,11 +249,15 @@ func (p *ConnPool) addIdleConn() error {
return ErrClosed
}
p.conns = append(p.conns, cn)
p.conns[cn.GetID()] = cn
p.idleConns = append(p.idleConns, cn)
return nil
}
// NewConn creates a new connection and returns it to the user.
// It still obeys MaxActiveConns, but the connection is not added to the pool and does not increase the pool size.
//
// NOTE: Connections obtained this way are not pooled and do not support hitless upgrades.
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.newConn(ctx, false)
}
@@ -196,33 +267,44 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, ErrClosed
}
p.connsMu.Lock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
p.connsMu.Unlock()
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
return nil, ErrPoolExhausted
}
p.connsMu.Unlock()
cn, err := p.dialConn(ctx, pooled)
dialCtx, cancel := context.WithTimeout(ctx, p.cfg.DialTimeout)
defer cancel()
cn, err := p.dialConn(dialCtx, pooled)
if err != nil {
return nil, err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
_ = cn.Close()
return nil, ErrPoolExhausted
}
p.conns = append(p.conns, cn)
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.closed() {
_ = cn.Close()
return nil, ErrClosed
}
// Check if pool was closed while we were waiting for the lock
if p.conns == nil {
p.conns = make(map[uint64]*Conn)
}
p.conns[cn.GetID()] = cn
if pooled {
// If pool is full remove the cn on next Put.
if p.poolSize >= p.cfg.PoolSize {
currentPoolSize := p.poolSize.Load()
if currentPoolSize >= p.cfg.PoolSize {
cn.pooled = false
} else {
p.poolSize++
p.poolSize.Add(1)
}
}
@@ -238,18 +320,57 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, p.getLastDialError()
}
netConn, err := p.cfg.Dialer(ctx)
if err != nil {
p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
go p.tryDial()
}
return nil, err
// Retry dialing with backoff.
// The context passed in already enforces the dial timeout, so the loop may stop
// before maxRetries is reached; larger values are harmless.
maxRetries := p.cfg.DialerRetries
if maxRetries <= 0 {
maxRetries = 5 // Default value
}
backoffDuration := p.cfg.DialerRetryTimeout
if backoffDuration <= 0 {
backoffDuration = 100 * time.Millisecond // Default value
}
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
cn.pooled = pooled
return cn, nil
var lastErr error
shouldLoop := true
// When the context deadline is reached we stop retrying, but we keep lastErr so the
// caller gets the real dial error instead of a generic context deadline exceeded error.
for attempt := 0; (attempt < maxRetries) && shouldLoop; attempt++ {
netConn, err := p.cfg.Dialer(ctx)
if err != nil {
lastErr = err
// Back off before the next attempt, or stop early if the context is done.
select {
case <-ctx.Done():
shouldLoop = false
case <-time.After(backoffDuration):
// Continue with retry
}
continue
}
// Success - create connection
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
cn.pooled = pooled
if p.cfg.ConnMaxLifetime > 0 {
cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime)
} else {
cn.expiresAt = noExpiration
}
return cn, nil
}
internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr)
// All retries failed - handle error tracking
p.setLastDialError(lastErr)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
go p.tryDial()
}
return nil, lastErr
}
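// Timing sketch (hypothetical numbers; not part of this diff): with DialerRetries=3,
// DialerRetryTimeout=100ms and a dialer that always fails, dialConn makes up to 3
// attempts separated by ~100ms of backoff, and the whole sequence is additionally
// bounded by the DialTimeout-scoped context created by newConn.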
func (p *ConnPool) tryDial() {
@@ -289,6 +410,14 @@ func (p *ConnPool) getLastDialError() error {
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return p.getConn(ctx)
}
// getConn returns a connection from the pool.
func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
var cn *Conn
var err error
if p.closed() {
return nil, ErrClosed
}
@@ -297,9 +426,17 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return nil, err
}
now := time.Now()
attempts := 0
for {
if attempts >= getAttempts {
internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts)
break
}
attempts++
p.connsMu.Lock()
cn, err := p.popIdle()
cn, err = p.popIdle()
p.connsMu.Unlock()
if err != nil {
@@ -311,11 +448,25 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
break
}
if !p.isHealthyConn(cn) {
if !p.isHealthyConn(cn, now) {
_ = p.CloseConn(cn)
continue
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
// Failed to process connection, discard it
_ = p.CloseConn(cn)
continue
}
}
atomic.AddUint32(&p.stats.Hits, 1)
return cn, nil
}
@@ -328,6 +479,19 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return nil, err
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
// Failed to process connection, discard it
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err)
_ = p.CloseConn(newcn)
return nil, err
}
}
return newcn, nil
}
@@ -356,7 +520,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error {
}
return ctx.Err()
case p.queue <- struct{}{}:
p.waitDurationNs.Add(time.Since(start).Nanoseconds())
p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
atomic.AddUint32(&p.stats.WaitCount, 1)
if !timer.Stop() {
<-timer.C
@@ -376,68 +540,130 @@ func (p *ConnPool) popIdle() (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
defer p.checkMinIdleConns()
n := len(p.idleConns)
if n == 0 {
return nil, nil
}
var cn *Conn
if p.cfg.PoolFIFO {
cn = p.idleConns[0]
copy(p.idleConns, p.idleConns[1:])
p.idleConns = p.idleConns[:n-1]
} else {
idx := n - 1
cn = p.idleConns[idx]
p.idleConns = p.idleConns[:idx]
attempts := 0
maxAttempts := util.Min(popAttempts, n)
for attempts < maxAttempts {
if len(p.idleConns) == 0 {
return nil, nil
}
if p.cfg.PoolFIFO {
cn = p.idleConns[0]
copy(p.idleConns, p.idleConns[1:])
p.idleConns = p.idleConns[:len(p.idleConns)-1]
} else {
idx := len(p.idleConns) - 1
cn = p.idleConns[idx]
p.idleConns = p.idleConns[:idx]
}
attempts++
if cn.IsUsable() {
p.idleConnsLen.Add(-1)
break
}
// Connection is not usable, put it back in the pool
if p.cfg.PoolFIFO {
// FIFO: put at end (will be picked up last since we pop from front)
p.idleConns = append(p.idleConns, cn)
} else {
// LIFO: put at beginning (will be picked up last since we pop from end)
p.idleConns = append([]*Conn{cn}, p.idleConns...)
}
cn = nil
}
p.idleConnsLen--
p.checkMinIdleConns()
// If we exhausted all attempts without finding a usable connection, return nil
if attempts > 1 && attempts >= maxAttempts && int32(attempts) >= p.poolSize.Load() {
internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts)
return nil, nil
}
return cn, nil
}
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
// Process connection using the hooks system
shouldPool := true
shouldRemove := false
if cn.rd.Buffered() > 0 {
// Check if this might be push notification data
if p.cfg.Protocol == 3 {
// we know that there is something in the buffer, so peek at the next reply type without
// the potential to block and check if it's a push notification
if replyType, err := cn.rd.PeekReplyType(); err != nil || replyType != proto.RespPush {
shouldRemove = true
}
} else {
// not a push notification since protocol 2 doesn't support them
shouldRemove = true
}
var err error
if shouldRemove {
// For non-RESP3 or data that is not a push notification, buffered data is unexpected
internal.Logger.Printf(ctx, "Conn has unread data, closing it")
p.Remove(ctx, cn, BadConnError{})
if cn.HasBufferedData() {
// Peek at the reply type to check if it's a push notification
if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush {
// Not a push notification or error peeking, remove connection
internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it")
p.Remove(ctx, cn, err)
}
// It's a push notification, allow pooling (client will handle it)
}
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
if err != nil {
internal.Logger.Printf(ctx, "Connection hook error: %v", err)
p.Remove(ctx, cn, err)
return
}
}
// If hooks say to remove the connection, do so
if shouldRemove {
p.Remove(ctx, cn, errors.New("hook requested removal"))
return
}
// If processor says not to pool the connection, remove it
if !shouldPool {
p.Remove(ctx, cn, errors.New("hook requested no pooling"))
return
}
if !cn.pooled {
p.Remove(ctx, cn, nil)
p.Remove(ctx, cn, errors.New("connection not pooled"))
return
}
var shouldCloseConn bool
p.connsMu.Lock()
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns {
p.idleConns = append(p.idleConns, cn)
p.idleConnsLen++
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
// Unusable conns are expected to become usable again (a background process is reconnecting them),
// so put them at the opposite end of the queue.
if !cn.IsUsable() {
if p.cfg.PoolFIFO {
p.connsMu.Lock()
p.idleConns = append(p.idleConns, cn)
p.connsMu.Unlock()
} else {
p.connsMu.Lock()
p.idleConns = append([]*Conn{cn}, p.idleConns...)
p.connsMu.Unlock()
}
} else {
p.connsMu.Lock()
p.idleConns = append(p.idleConns, cn)
p.connsMu.Unlock()
}
p.idleConnsLen.Add(1)
} else {
p.removeConn(cn)
p.removeConnWithLock(cn)
shouldCloseConn = true
}
p.connsMu.Unlock()
p.freeTurn()
if shouldCloseConn {
@@ -447,8 +673,13 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
p.removeConnWithLock(cn)
p.freeTurn()
_ = p.closeConn(cn)
// Check if we need to create new idle connections to maintain MinIdleConns
p.checkMinIdleConns()
}
func (p *ConnPool) CloseConn(cn *Conn) error {
@@ -463,17 +694,23 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) {
}
func (p *ConnPool) removeConn(cn *Conn) {
for i, c := range p.conns {
if c == cn {
p.conns = append(p.conns[:i], p.conns[i+1:]...)
if cn.pooled {
p.poolSize--
p.checkMinIdleConns()
cid := cn.GetID()
delete(p.conns, cid)
atomic.AddUint32(&p.stats.StaleConns, 1)
// Decrement pool size counter when removing a connection
if cn.pooled {
p.poolSize.Add(-1)
// The removed conn may also be sitting in the idle list; drop it from there as well.
for idx, ic := range p.idleConns {
if ic.GetID() == cid {
internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid)
p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...)
p.idleConnsLen.Add(-1)
break
}
break
}
}
atomic.AddUint32(&p.stats.StaleConns, 1)
}
func (p *ConnPool) closeConn(cn *Conn) error {
@@ -491,9 +728,9 @@ func (p *ConnPool) Len() int {
// IdleLen returns number of idle connections.
func (p *ConnPool) IdleLen() int {
p.connsMu.Lock()
n := p.idleConnsLen
n := p.idleConnsLen.Load()
p.connsMu.Unlock()
return n
return int(n)
}
func (p *ConnPool) Stats() *Stats {
@@ -502,6 +739,7 @@ func (p *ConnPool) Stats() *Stats {
Misses: atomic.LoadUint32(&p.stats.Misses),
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
WaitCount: atomic.LoadUint32(&p.stats.WaitCount),
Unusable: atomic.LoadUint32(&p.stats.Unusable),
WaitDurationNs: p.waitDurationNs.Load(),
TotalConns: uint32(p.Len()),
@@ -542,30 +780,33 @@ func (p *ConnPool) Close() error {
}
}
p.conns = nil
p.poolSize = 0
p.poolSize.Store(0)
p.idleConns = nil
p.idleConnsLen = 0
p.idleConnsLen.Store(0)
p.connsMu.Unlock()
return firstErr
}
func (p *ConnPool) isHealthyConn(cn *Conn) bool {
now := time.Now()
if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
// slight optimization, check expiresAt first.
if cn.expiresAt.Before(now) {
return false
}
// Check if connection has exceeded idle timeout
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
return false
}
// Check connection health, but be aware of push notifications
if err := connCheck(cn.netConn); err != nil {
cn.SetUsedAt(now)
// Check basic connection health
// Use GetNetConn() to safely access netConn and avoid data races
if err := connCheck(cn.getNetConn()); err != nil {
// If there's unexpected data, it might be push notifications (RESP3)
// However, push notification processing is now handled by the client
// before WithReader to ensure proper context is available to handlers
if err == errUnexpectedRead && p.cfg.Protocol == 3 {
if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
// we know that there is something in the buffer, so peek at the next reply type without
// the potential to block
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
@@ -579,7 +820,5 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
return false
}
}
cn.SetUsedAt(now)
return true
}

View File

@@ -1,6 +1,8 @@
package pool
import "context"
import (
"context"
)
type SingleConnPool struct {
pool Pooler
@@ -56,3 +58,7 @@ func (p *SingleConnPool) IdleLen() int {
func (p *SingleConnPool) Stats() *Stats {
return &Stats{}
}
func (p *SingleConnPool) AddPoolHook(hook PoolHook) {}
func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {}

View File

@@ -199,3 +199,7 @@ func (p *StickyConnPool) IdleLen() int {
func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}
func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}
func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}

View File

@@ -2,15 +2,17 @@ package pool_test
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/logging"
)
var _ = Describe("ConnPool", func() {
@@ -20,7 +22,7 @@ var _ = Describe("ConnPool", func() {
BeforeEach(func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Hour,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
@@ -45,11 +47,11 @@ var _ = Describe("ConnPool", func() {
<-closedChan
return &net.TCPConn{}, nil
},
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Hour,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
MinIdleConns: minIdleConns,
MinIdleConns: int32(minIdleConns),
})
wg.Wait()
Expect(connPool.Close()).NotTo(HaveOccurred())
@@ -105,7 +107,7 @@ var _ = Describe("ConnPool", func() {
// ok
}
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
// Check that Get is unblocked.
select {
@@ -130,8 +132,8 @@ var _ = Describe("MinIdleConns", func() {
newConnPool := func() *pool.ConnPool {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: poolSize,
MinIdleConns: minIdleConns,
PoolSize: int32(poolSize),
MinIdleConns: int32(minIdleConns),
PoolTimeout: 100 * time.Millisecond,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: -1,
@@ -168,7 +170,7 @@ var _ = Describe("MinIdleConns", func() {
Context("after Remove", func() {
BeforeEach(func() {
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
})
It("has idle connections", func() {
@@ -245,7 +247,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() {
perform(len(cns), func(i int) {
mu.RLock()
connPool.Remove(ctx, cns[i], nil)
connPool.Remove(ctx, cns[i], errors.New("test"))
mu.RUnlock()
})
@@ -309,7 +311,7 @@ var _ = Describe("race", func() {
It("does not happen on Get, Put, and Remove", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Minute,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
@@ -328,7 +330,7 @@ var _ = Describe("race", func() {
cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
}
}
})
@@ -339,15 +341,15 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
},
PoolSize: 1000,
MinIdleConns: 50,
PoolSize: int32(1000),
MinIdleConns: int32(50),
PoolTimeout: 3 * time.Second,
DialTimeout: 1 * time.Second,
}
p := pool.NewConnPool(opt)
var wg sync.WaitGroup
for i := 0; i < opt.PoolSize; i++ {
for i := int32(0); i < opt.PoolSize; i++ {
wg.Add(1)
go func() {
defer wg.Done()
@@ -366,8 +368,8 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
panic("test panic")
},
PoolSize: 100,
MinIdleConns: 30,
PoolSize: int32(100),
MinIdleConns: int32(30),
}
p := pool.NewConnPool(opt)
@@ -377,14 +379,14 @@ var _ = Describe("race", func() {
state := p.Stats()
return state.TotalConns == 0 && state.IdleConns == 0 && p.QueueLen() == 0
}, "3s", "50ms").Should(BeTrue())
})
})
It("wait", func() {
opt := &pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
},
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 3 * time.Second,
}
p := pool.NewConnPool(opt)
@@ -415,7 +417,7 @@ var _ = Describe("race", func() {
return &net.TCPConn{}, nil
},
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: testPoolTimeout,
}
p := pool.NewConnPool(opt)
@@ -435,3 +437,73 @@ var _ = Describe("race", func() {
Expect(stats.Timeouts).To(Equal(uint32(1)))
})
})
// TestDialerRetryConfiguration tests the new DialerRetries and DialerRetryTimeout options
func TestDialerRetryConfiguration(t *testing.T) {
ctx := context.Background()
t.Run("CustomDialerRetries", func(t *testing.T) {
var attempts int64
failingDialer := func(ctx context.Context) (net.Conn, error) {
atomic.AddInt64(&attempts, 1)
return nil, errors.New("dial failed")
}
connPool := pool.NewConnPool(&pool.Options{
Dialer: failingDialer,
PoolSize: 1,
PoolTimeout: time.Second,
DialTimeout: time.Second,
DialerRetries: 3, // Custom retry count
DialerRetryTimeout: 10 * time.Millisecond, // Fast retries for testing
})
defer connPool.Close()
_, err := connPool.Get(ctx)
if err == nil {
t.Error("Expected error from failing dialer")
}
// Should have attempted at least 3 times (DialerRetries = 3)
// There might be additional attempts due to pool logic
finalAttempts := atomic.LoadInt64(&attempts)
if finalAttempts < 3 {
t.Errorf("Expected at least 3 dial attempts, got %d", finalAttempts)
}
if finalAttempts > 6 {
t.Errorf("Expected around 3 dial attempts, got %d (too many)", finalAttempts)
}
})
t.Run("DefaultDialerRetries", func(t *testing.T) {
var attempts int64
failingDialer := func(ctx context.Context) (net.Conn, error) {
atomic.AddInt64(&attempts, 1)
return nil, errors.New("dial failed")
}
connPool := pool.NewConnPool(&pool.Options{
Dialer: failingDialer,
PoolSize: 1,
PoolTimeout: time.Second,
DialTimeout: time.Second,
// DialerRetries and DialerRetryTimeout not set - should use defaults
})
defer connPool.Close()
_, err := connPool.Get(ctx)
if err == nil {
t.Error("Expected error from failing dialer")
}
// Should have attempted 5 times (default DialerRetries = 5)
finalAttempts := atomic.LoadInt64(&attempts)
if finalAttempts != 5 {
t.Errorf("Expected 5 dial attempts (default), got %d", finalAttempts)
}
})
}
func init() {
logging.Disable()
}

78
internal/pool/pubsub.go Normal file
View File

@@ -0,0 +1,78 @@
package pool
import (
"context"
"net"
"sync"
"sync/atomic"
)
type PubSubStats struct {
Created uint32
Untracked uint32
Active uint32
}
// PubSubPool manages a pool of PubSub connections.
type PubSubPool struct {
opt *Options
netDialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Map to track active PubSub connections
activeConns sync.Map // map[uint64]*Conn (connID -> conn)
closed atomic.Bool
stats PubSubStats
}
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
return &PubSubPool{
opt: opt,
netDialer: netDialer,
}
}
func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
if p.closed.Load() {
return nil, ErrClosed
}
netConn, err := p.netDialer(ctx, network, addr)
if err != nil {
return nil, err
}
cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
cn.pubsub = true
atomic.AddUint32(&p.stats.Created, 1)
return cn, nil
}
func (p *PubSubPool) TrackConn(cn *Conn) {
atomic.AddUint32(&p.stats.Active, 1)
p.activeConns.Store(cn.GetID(), cn)
}
func (p *PubSubPool) UntrackConn(cn *Conn) {
atomic.AddUint32(&p.stats.Active, ^uint32(0))
atomic.AddUint32(&p.stats.Untracked, 1)
p.activeConns.Delete(cn.GetID())
}
func (p *PubSubPool) Close() error {
p.closed.Store(true)
p.activeConns.Range(func(key, value interface{}) bool {
cn := value.(*Conn)
_ = cn.Close()
return true
})
return nil
}
func (p *PubSubPool) Stats() *PubSubStats {
// load stats atomically
return &PubSubStats{
Created: atomic.LoadUint32(&p.stats.Created),
Untracked: atomic.LoadUint32(&p.stats.Untracked),
Active: atomic.LoadUint32(&p.stats.Active),
}
}
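// Usage sketch (address and channel names are hypothetical; not part of this diff):
//
//	pp := NewPubSubPool(opt, netDialer)
//	cn, err := pp.NewConn(ctx, "tcp", "localhost:6379", []string{"news"})
//	if err != nil {
//		// handle dial error
//	}
//	pp.TrackConn(cn)         // counted in Stats().Active
//	defer pp.UntrackConn(cn) // untracked again on teardown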

3
internal/redis.go Normal file
View File

@@ -0,0 +1,3 @@
package internal
const RedisNull = "null"

View File

@@ -28,3 +28,14 @@ func MustParseFloat(s string) float64 {
}
return f
}
// SafeIntToInt32 safely converts an int to int32, returning an error if overflow would occur.
func SafeIntToInt32(value int, fieldName string) (int32, error) {
if value > math.MaxInt32 {
return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32)
}
if value < math.MinInt32 {
return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32)
}
return int32(value), nil
}
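// A plausible call-site sketch (field name is illustrative; not part of this diff):
//
//	poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
//	if err != nil {
//		return err
//	}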

17
internal/util/math.go Normal file
View File

@@ -0,0 +1,17 @@
package util
// Max returns the maximum of two integers
func Max(a, b int) int {
if a > b {
return a
}
return b
}
// Min returns the minimum of two integers
func Min(a, b int) int {
if a < b {
return a
}
return b
}

121
logging/logging.go Normal file
View File

@@ -0,0 +1,121 @@
// Package logging provides logging level constants and utilities for the go-redis library.
// This package centralizes logging configuration to ensure consistency across all components.
package logging
import (
"context"
"fmt"
"strings"
"github.com/redis/go-redis/v9/internal"
)
// LogLevel represents the logging level
type LogLevel int
// Log level constants for the entire go-redis library
const (
LogLevelError LogLevel = iota // 0 - errors only
LogLevelWarn // 1 - warnings and errors
LogLevelInfo // 2 - info, warnings, and errors
LogLevelDebug // 3 - debug, info, warnings, and errors
)
// String returns the string representation of the log level
func (l LogLevel) String() string {
switch l {
case LogLevelError:
return "ERROR"
case LogLevelWarn:
return "WARN"
case LogLevelInfo:
return "INFO"
case LogLevelDebug:
return "DEBUG"
default:
return "UNKNOWN"
}
}
// IsValid returns true if the log level is valid
func (l LogLevel) IsValid() bool {
return l >= LogLevelError && l <= LogLevelDebug
}
func (l LogLevel) WarnOrAbove() bool {
return l >= LogLevelWarn
}
func (l LogLevel) InfoOrAbove() bool {
return l >= LogLevelInfo
}
func (l LogLevel) DebugOrAbove() bool {
return l >= LogLevelDebug
}
// VoidLogger is a logger that does nothing.
// Used to disable logging and thus speed up the library.
type VoidLogger struct{}
func (v *VoidLogger) Printf(_ context.Context, _ string, _ ...interface{}) {
// do nothing
}
// Disable disables logging by setting the internal logger to a void logger.
// This can be used to speed up the library if logging is not needed.
// It will override any custom logger that was set before and set the VoidLogger.
func Disable() {
internal.Logger = &VoidLogger{}
}
// Enable enables logging by setting the internal logger to the default logger.
// This is the default behavior.
// You can use redis.SetLogger to set a custom logger.
//
// NOTE: This function is not thread-safe.
// It will override any custom logger that was set before and set the DefaultLogger.
func Enable() {
internal.Logger = internal.NewDefaultLogger()
}
// NewBlacklistLogger returns a new logger that filters out messages containing any of the substrings.
// This can be used to filter out messages containing sensitive information.
func NewBlacklistLogger(substr []string) internal.Logging {
l := internal.NewDefaultLogger()
return &filterLogger{logger: l, substr: substr, blacklist: true}
}
// NewWhitelistLogger returns a new logger that only logs messages containing any of the substrings.
// This can be used to only log messages related to specific commands or patterns.
func NewWhitelistLogger(substr []string) internal.Logging {
l := internal.NewDefaultLogger()
return &filterLogger{logger: l, substr: substr, blacklist: false}
}
type filterLogger struct {
logger internal.Logging
blacklist bool
substr []string
}
func (l *filterLogger) Printf(ctx context.Context, format string, v ...interface{}) {
msg := fmt.Sprintf(format, v...)
found := false
for _, substr := range l.substr {
if strings.Contains(msg, substr) {
found = true
if l.blacklist {
return
}
}
}
// whitelist, only log if one of the substrings is present
if !l.blacklist && !found {
return
}
if l.logger != nil {
l.logger.Printf(ctx, format, v...)
return
}
}
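// Usage sketch (the filtered substring is hypothetical; not part of this file):
//
//	// Drop any log line that mentions "AUTH", keep everything else.
//	redis.SetLogger(logging.NewBlacklistLogger([]string{"AUTH"}))
//
//	// Or silence go-redis logging entirely:
//	logging.Disable()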

59
logging/logging_test.go Normal file
View File

@@ -0,0 +1,59 @@
package logging
import "testing"
func TestLogLevel_String(t *testing.T) {
tests := []struct {
level LogLevel
expected string
}{
{LogLevelError, "ERROR"},
{LogLevelWarn, "WARN"},
{LogLevelInfo, "INFO"},
{LogLevelDebug, "DEBUG"},
{LogLevel(99), "UNKNOWN"},
}
for _, test := range tests {
if got := test.level.String(); got != test.expected {
t.Errorf("LogLevel(%d).String() = %q, want %q", test.level, got, test.expected)
}
}
}
func TestLogLevel_IsValid(t *testing.T) {
tests := []struct {
level LogLevel
expected bool
}{
{LogLevelError, true},
{LogLevelWarn, true},
{LogLevelInfo, true},
{LogLevelDebug, true},
{LogLevel(-1), false},
{LogLevel(4), false},
{LogLevel(99), false},
}
for _, test := range tests {
if got := test.level.IsValid(); got != test.expected {
t.Errorf("LogLevel(%d).IsValid() = %v, want %v", test.level, got, test.expected)
}
}
}
func TestLogLevelConstants(t *testing.T) {
// Test that constants have expected values
if LogLevelError != 0 {
t.Errorf("LogLevelError = %d, want 0", LogLevelError)
}
if LogLevelWarn != 1 {
t.Errorf("LogLevelWarn = %d, want 1", LogLevelWarn)
}
if LogLevelInfo != 2 {
t.Errorf("LogLevelInfo = %d, want 2", LogLevelInfo)
}
if LogLevelDebug != 3 {
t.Errorf("LogLevelDebug = %d, want 3", LogLevelDebug)
}
}

View File

@@ -13,6 +13,7 @@ import (
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/logging"
)
const (
@@ -102,6 +103,7 @@ var _ = BeforeSuite(func() {
fmt.Printf("RCEDocker: %v\n", RCEDocker)
fmt.Printf("REDIS_VERSION: %.1f\n", RedisVersion)
fmt.Printf("CLIENT_LIBS_TEST_IMAGE: %v\n", os.Getenv("CLIENT_LIBS_TEST_IMAGE"))
logging.Disable()
if RedisVersion < 7.0 || RedisVersion > 9 {
panic("incorrect or not supported redis version")

View File

@@ -14,9 +14,11 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/push"
)
// Limiter is the interface of a rate limiter or a circuit breaker.
@@ -107,9 +109,19 @@ type Options struct {
// DialTimeout for establishing new connections.
//
// default: 5 seconds
// default: 10 seconds
DialTimeout time.Duration
// DialerRetries is the maximum number of retry attempts when dialing fails.
//
// default: 5
DialerRetries int
// DialerRetryTimeout is the backoff duration between retry attempts.
//
// default: 100 milliseconds
DialerRetryTimeout time.Duration
// ReadTimeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking. Supported values:
//
@@ -153,6 +165,7 @@ type Options struct {
//
// Note that FIFO has slightly higher overhead compared to LIFO,
// but it helps close idle connections faster, reducing the pool size.
// default: false
PoolFIFO bool
// PoolSize is the base number of socket connections.
@@ -244,8 +257,19 @@ type Options struct {
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
// HitlessUpgradeConfig provides configuration options for hitless upgrades.
// This is an alias to hitless.Config for convenience.
type HitlessUpgradeConfig = hitless.Config
func (opt *Options) init() {
if opt.Addr == "" {
opt.Addr = "localhost:6379"
@@ -261,7 +285,13 @@ func (opt *Options) init() {
opt.Protocol = 3
}
if opt.DialTimeout == 0 {
opt.DialTimeout = 5 * time.Second
opt.DialTimeout = 10 * time.Second
}
if opt.DialerRetries == 0 {
opt.DialerRetries = 5
}
if opt.DialerRetryTimeout == 0 {
opt.DialerRetryTimeout = 100 * time.Millisecond
}
if opt.Dialer == nil {
opt.Dialer = NewDialer(opt)
@@ -320,13 +350,36 @@ func (opt *Options) init() {
case 0:
opt.MaxRetryBackoff = 512 * time.Millisecond
}
opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns)
// auto-detect endpoint type if not specified
endpointType := opt.HitlessUpgradeConfig.EndpointType
if endpointType == "" || endpointType == hitless.EndpointTypeAuto {
// Auto-detect endpoint type if not specified
endpointType = hitless.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
}
opt.HitlessUpgradeConfig.EndpointType = endpointType
}
func (opt *Options) clone() *Options {
clone := *opt
// Deep clone HitlessUpgradeConfig to avoid sharing between clients
if opt.HitlessUpgradeConfig != nil {
configClone := *opt.HitlessUpgradeConfig
clone.HitlessUpgradeConfig = &configClone
}
return &clone
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
return NewDialer(opt)
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
@@ -612,23 +665,84 @@ func getUserPassword(u *url.URL) (string, string) {
func newConnPool(
opt *Options,
dialer func(ctx context.Context, network, addr string) (net.Conn, error),
) *pool.ConnPool {
) (*pool.ConnPool, error) {
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
if err != nil {
return nil, err
}
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
if err != nil {
return nil, err
}
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
if err != nil {
return nil, err
}
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
if err != nil {
return nil, err
}
return pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return dialer(ctx, opt.Network, opt.Addr)
},
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
MinIdleConns: opt.MinIdleConns,
MaxIdleConns: opt.MaxIdleConns,
MaxActiveConns: opt.MaxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
// Pass protocol version for push notification optimization
Protocol: opt.Protocol,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
})
PoolFIFO: opt.PoolFIFO,
PoolSize: poolSize,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
DialerRetries: opt.DialerRetries,
DialerRetryTimeout: opt.DialerRetryTimeout,
MinIdleConns: minIdleConns,
MaxIdleConns: maxIdleConns,
MaxActiveConns: maxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
PushNotificationsEnabled: opt.Protocol == 3,
}), nil
}
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
) (*pool.PubSubPool, error) {
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
if err != nil {
return nil, err
}
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
if err != nil {
return nil, err
}
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
if err != nil {
return nil, err
}
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
if err != nil {
return nil, err
}
return pool.NewPubSubPool(&pool.Options{
PoolFIFO: opt.PoolFIFO,
PoolSize: poolSize,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
DialerRetries: opt.DialerRetries,
DialerRetryTimeout: opt.DialerRetryTimeout,
MinIdleConns: minIdleConns,
MaxIdleConns: maxIdleConns,
MaxActiveConns: maxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: 32 * 1024,
WriteBufferSize: 32 * 1024,
PushNotificationsEnabled: opt.Protocol == 3,
}, dialer), nil
}
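A configuration sketch using the options introduced in this diff; the Mode constant is the one referenced elsewhere in this change, and no other hitless.Config fields are assumed:

package main

import (
    "time"

    "github.com/redis/go-redis/v9"
    "github.com/redis/go-redis/v9/hitless"
)

func main() {
    rdb := redis.NewClient(&redis.Options{
        Addr:               "localhost:6379",
        Protocol:           3, // RESP3 is required for maintenance push notifications
        DialerRetries:      5,
        DialerRetryTimeout: 100 * time.Millisecond,
        HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
            // Enabled fails fast if the handshake cannot be performed;
            // leaving the config nil keeps the default "auto" behavior.
            Mode: hitless.MaintNotificationsEnabled,
        },
    })
    defer rdb.Close()
}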

View File

@@ -20,6 +20,7 @@ import (
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/rand"
"github.com/redis/go-redis/v9/push"
)
const (
@@ -38,6 +39,7 @@ type ClusterOptions struct {
ClientName string
// NewClient creates a cluster node client with provided name and options.
// If NewClient is set by the user, the user is responsible for handling hitless upgrades and push notifications.
NewClient func(opt *Options) *Client
// The maximum number of retries before giving up. Command is retried
@@ -125,10 +127,22 @@ type ClusterOptions struct {
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
UnstableResp3 bool
// PushNotificationProcessor is the processor for handling push notifications.
// If nil, a default processor will be created for RESP3 connections.
PushNotificationProcessor push.NotificationProcessor
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it.
// The ClusterClient does not handle hitless upgrades directly; that is left to the per-node clients in the Nodes map.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
func (opt *ClusterOptions) init() {
@@ -319,6 +333,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
}
func (opt *ClusterOptions) clientOptions() *Options {
// Clone HitlessUpgradeConfig to avoid sharing between cluster node clients
var hitlessConfig *HitlessUpgradeConfig
if opt.HitlessUpgradeConfig != nil {
configClone := *opt.HitlessUpgradeConfig
hitlessConfig = &configClone
}
return &Options{
ClientName: opt.ClientName,
Dialer: opt.Dialer,
@@ -360,8 +381,10 @@ func (opt *ClusterOptions) clientOptions() *Options {
// much use for ClusterSlots config). This means we cannot execute the
// READONLY command against that node -- setting readOnly to false in such
// situations in the options below will prevent that from happening.
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
HitlessUpgradeConfig: hitlessConfig,
PushNotificationProcessor: opt.PushNotificationProcessor,
}
}
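A cluster-side sketch using the same imports as the options sketch above; as noted in the comments, the config is cloned into each node client rather than handled by the ClusterClient itself:

cc := redis.NewClusterClient(&redis.ClusterOptions{
    Addrs:    []string{"localhost:7000", "localhost:7001", "localhost:7002"},
    Protocol: 3,
    HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
        Mode: hitless.MaintNotificationsEnabled,
    },
})
defer cc.Close()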
@@ -1830,12 +1853,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
return err
}
// hitless won't work here for now
func (c *ClusterClient) pubSub() *PubSub {
var node *clusterNode
pubsub := &PubSub{
opt: c.opt.clientOptions(),
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
if node != nil {
panic("node != nil")
}
@@ -1850,18 +1873,25 @@ func (c *ClusterClient) pubSub() *PubSub {
if err != nil {
return nil, err
}
cn, err := node.Client.newConn(context.TODO())
cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
if err != nil {
node = nil
return nil, err
}
// will return nil if already initialized
err = node.Client.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
node = nil
return nil, err
}
node.Client.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: func(cn *pool.Conn) error {
err := node.Client.connPool.CloseConn(cn)
// Untrack connection from PubSubPool
node.Client.pubSubPool.UntrackConn(cn)
err := cn.Close()
node = nil
return err
},

375
pool_pubsub_bench_test.go Normal file
View File

@@ -0,0 +1,375 @@
// Pool and PubSub Benchmark Suite
//
// This file contains comprehensive benchmarks for both pool operations and PubSub initialization.
// It's designed to be run against different branches to compare performance.
//
// Usage Examples:
// # Run all benchmarks
// go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go
//
// # Run only pool benchmarks
// go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go
//
// # Run only PubSub benchmarks
// go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go
//
// # Compare between branches
// git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt
// git checkout branch2 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch2.txt
// benchcmp branch1.txt branch2.txt
//
// # Run with memory profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go
//
// # Run with CPU profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go
package redis_test
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/internal/pool"
)
// dummyDialer creates a mock connection for benchmarking
func dummyDialer(ctx context.Context) (net.Conn, error) {
return &dummyConn{}, nil
}
// dummyConn implements net.Conn for benchmarking
type dummyConn struct{}
func (c *dummyConn) Read(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Close() error { return nil }
func (c *dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} }
func (c *dummyConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379}
}
func (c *dummyConn) SetDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetReadDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil }
// =============================================================================
// POOL BENCHMARKS
// =============================================================================
// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations
func BenchmarkPoolGetPut(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(poolSize),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0), // Start with no idle connections
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with MinIdleConns
func BenchmarkPoolGetPutWithMinIdle(b *testing.B) {
ctx := context.Background()
configs := []struct {
poolSize int
minIdleConns int
}{
{8, 2},
{16, 4},
{32, 8},
{64, 16},
}
for _, config := range configs {
b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(config.poolSize),
MinIdleConns: int32(config.minIdleConns),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency
func BenchmarkPoolConcurrentGetPut(b *testing.B) {
ctx := context.Background()
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(32),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0),
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
// Test with different levels of concurrency
concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64}
for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.SetParallelism(concurrency)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// =============================================================================
// PUBSUB BENCHMARKS
// =============================================================================
// benchmarkClient creates a Redis client for benchmarking with mock dialer
func benchmarkClient(poolSize int) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: "localhost:6379", // Mock address
DialTimeout: time.Second,
ReadTimeout: time.Second,
WriteTimeout: time.Second,
PoolSize: poolSize,
MinIdleConns: 0, // Start with no idle connections for consistent benchmarks
})
}
// BenchmarkPubSubCreation benchmarks PubSub creation and subscription
func BenchmarkPubSubCreation(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 4, 8, 16, 32}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}
}
// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription
func BenchmarkPubSubPatternCreation(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 4, 8, 16, 32}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.PSubscribe(ctx, "test-*")
pubsub.Close()
}
})
}
}
// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation
func BenchmarkPubSubConcurrentCreation(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
concurrencyLevels := []int{1, 2, 4, 8, 16}
for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
var wg sync.WaitGroup
semaphore := make(chan struct{}, concurrency)
for i := 0; i < b.N; i++ {
wg.Add(1)
semaphore <- struct{}{}
go func() {
defer wg.Done()
defer func() { <-semaphore }()
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}()
}
wg.Wait()
})
}
}
// BenchmarkPubSubMultipleChannels benchmarks subscribing to multiple channels
func BenchmarkPubSubMultipleChannels(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()
channelCounts := []int{1, 5, 10, 25, 50, 100}
for _, channelCount := range channelCounts {
b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) {
// Prepare channel names
channels := make([]string, channelCount)
for i := 0; i < channelCount; i++ {
channels[i] = fmt.Sprintf("channel-%d", i)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, channels...)
pubsub.Close()
}
})
}
}
// BenchmarkPubSubReuse benchmarks reusing PubSub connections
func BenchmarkPubSubReuse(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Benchmark just the creation and closing of PubSub connections
// This simulates reuse patterns without requiring actual Redis operations
pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i))
pubsub.Close()
}
}
// =============================================================================
// COMBINED BENCHMARKS
// =============================================================================
// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations
func BenchmarkPoolAndPubSubMixed(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// Mix of pool stats collection and PubSub creation
if pb.Next() {
// Pool stats operation
stats := client.PoolStats()
_ = stats.Hits + stats.Misses // Use the stats to prevent optimization
}
if pb.Next() {
// PubSub operation
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
}
})
}
// BenchmarkPoolStatsCollection benchmarks pool statistics collection
func BenchmarkPoolStatsCollection(b *testing.B) {
client := benchmarkClient(16)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
stats := client.PoolStats()
_ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization
}
}
// BenchmarkPoolHighContention tests pool performance under high contention
func BenchmarkPoolHighContention(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// High contention Get/Put operations
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}

View File

@@ -22,7 +22,7 @@ import (
type PubSub struct {
opt *Options
newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
newConn func(ctx context.Context, addr string, channels []string) (*pool.Conn, error)
closeConn func(*pool.Conn) error
mu sync.Mutex
@@ -42,6 +42,9 @@ type PubSub struct {
// Push notification processor for handling generic push notifications
pushProcessor push.NotificationProcessor
// Cleanup callback for hitless upgrade tracking
onClose func()
}
func (c *PubSub) init() {
@@ -73,10 +76,18 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er
return c.cn, nil
}
if c.opt.Addr == "" {
// TODO(hitless):
// this is probably a cluster client;
// c.newConn will ignore the addr argument.
// This will change once hitless upgrades are supported for cluster clients.
c.opt.Addr = internal.RedisNull
}
channels := mapKeys(c.channels)
channels = append(channels, newChannels...)
cn, err := c.newConn(ctx, channels)
cn, err := c.newConn(ctx, c.opt.Addr, channels)
if err != nil {
return nil, err
}
@@ -157,12 +168,31 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo
if c.cn != cn {
return
}
if !cn.IsUsable() || cn.ShouldHandoff() {
c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable"))
}
if isBadConn(err, allowTimeout, c.opt.Addr) {
c.reconnect(ctx, err)
}
}
func (c *PubSub) reconnect(ctx context.Context, reason error) {
if c.cn != nil && c.cn.ShouldHandoff() {
newEndpoint := c.cn.GetHandoffEndpoint()
// If new endpoint is NULL, use the original address
if newEndpoint == internal.RedisNull {
newEndpoint = c.opt.Addr
}
if newEndpoint != "" {
// Update the address in the options
oldAddr := c.cn.RemoteAddr().String()
c.opt.Addr = newEndpoint
internal.Logger.Printf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr)
}
}
_ = c.closeTheCn(reason)
_, _ = c.conn(ctx, nil)
}
@@ -171,9 +201,6 @@ func (c *PubSub) closeTheCn(reason error) error {
if c.cn == nil {
return nil
}
if !c.closed {
internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
}
err := c.closeConn(c.cn)
c.cn = nil
return err
@@ -189,6 +216,11 @@ func (c *PubSub) Close() error {
c.closed = true
close(c.exit)
// Call cleanup callback if set
if c.onClose != nil {
c.onClose()
}
return c.closeTheCn(pool.ErrClosed)
}
@@ -444,11 +476,10 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
internal.Logger.Printf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err)
}
return c.cmd.readReply(rd)
})
c.releaseConnWithLock(ctx, cn, err, timeout > 0)
if err != nil {
@@ -461,6 +492,12 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
// Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
// Receive returns a message as a Subscription, Message, Pong, or an error.
// See PubSub example for details. This is a low-level API and in most cases
// Channel should be used instead.
// This method blocks until a message is received or an error occurs.
// It may return early with an error if the context is canceled, the connection fails,
// or other internal errors occur.
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
return c.ReceiveTimeout(ctx, 0)
}
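A minimal receive-loop sketch; rdb and ctx are assumed to be an existing *redis.Client and context.Context, and the message types are those returned by ReceiveTimeout:

pubsub := rdb.Subscribe(ctx, "mychannel")
defer pubsub.Close()

for {
    msg, err := pubsub.Receive(ctx)
    if err != nil {
        break // context canceled, client closed, or connection error
    }
    switch m := msg.(type) {
    case *redis.Subscription:
        // subscription confirmation
    case *redis.Message:
        _ = m.Payload // handle the payload
    case *redis.Pong:
        // ping reply
    }
}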
@@ -543,7 +580,8 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac
}
func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
if c.pushProcessor == nil {
// Only process push notifications for RESP3 connections with a processor
if c.opt.Protocol != 3 || c.pushProcessor == nil {
return nil
}

View File

@@ -113,6 +113,9 @@ var _ = Describe("PubSub", func() {
pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2")
defer pubsub.Close()
// sleep a bit to make sure redis knows about the subscriptions
time.Sleep(10 * time.Millisecond)
channels, err = client.PubSubShardChannels(ctx, "mychannel*").Result()
Expect(err).NotTo(HaveOccurred())
Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"}))

View File

@@ -1,8 +1,6 @@
package push
import (
"github.com/redis/go-redis/v9/internal/pool"
)
// No imports needed for this file
// NotificationHandlerContext provides context information about where a push notification was received.
// This struct allows handlers to make informed decisions based on the source of the notification
@@ -35,7 +33,11 @@ type NotificationHandlerContext struct {
PubSub interface{}
// Conn is the specific connection on which the notification was received.
Conn *pool.Conn
// It is an interface{} both to allow for future expansion and to avoid
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *pool.Conn
Conn interface{}
// IsBlocking indicates if the notification was received on a blocking connection.
IsBlocking bool
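Since Conn is now an interface{}, a handler recovers the concrete type via assertion. A hypothetical sketch follows; maintHandler and its package placement are illustrative only, and internal/pool is importable only from within the go-redis module:

package hitless // illustrative placement; any package inside the go-redis module would do

import (
    "context"

    "github.com/redis/go-redis/v9/internal/pool"
    "github.com/redis/go-redis/v9/push"
)

type maintHandler struct{}

func (maintHandler) HandlePushNotification(
    ctx context.Context,
    handlerCtx push.NotificationHandlerContext,
    notification []interface{},
) error {
    cn, ok := handlerCtx.Conn.(*pool.Conn)
    if !ok {
        return nil // some other connection type; nothing to do here
    }
    _ = cn.GetID() // concrete connection available, e.g. for logging
    return nil
}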

315
push/processor_unit_test.go Normal file
View File

@@ -0,0 +1,315 @@
package push
import (
"context"
"testing"
)
// TestProcessorCreation tests processor creation and initialization
func TestProcessorCreation(t *testing.T) {
t.Run("NewProcessor", func(t *testing.T) {
processor := NewProcessor()
if processor == nil {
t.Fatal("NewProcessor should not return nil")
}
if processor.registry == nil {
t.Error("Processor should have a registry")
}
})
t.Run("NewVoidProcessor", func(t *testing.T) {
voidProcessor := NewVoidProcessor()
if voidProcessor == nil {
t.Fatal("NewVoidProcessor should not return nil")
}
})
}
// TestProcessorHandlerManagement tests handler registration and retrieval
func TestProcessorHandlerManagement(t *testing.T) {
processor := NewProcessor()
handler := &UnitTestHandler{name: "test-handler"}
t.Run("RegisterHandler", func(t *testing.T) {
err := processor.RegisterHandler("TEST", handler, false)
if err != nil {
t.Errorf("RegisterHandler should not error: %v", err)
}
// Verify handler is registered
retrievedHandler := processor.GetHandler("TEST")
if retrievedHandler != handler {
t.Error("GetHandler should return the registered handler")
}
})
t.Run("RegisterProtectedHandler", func(t *testing.T) {
protectedHandler := &UnitTestHandler{name: "protected-handler"}
err := processor.RegisterHandler("PROTECTED", protectedHandler, true)
if err != nil {
t.Errorf("RegisterHandler should not error for protected handler: %v", err)
}
// Verify handler is registered
retrievedHandler := processor.GetHandler("PROTECTED")
if retrievedHandler != protectedHandler {
t.Error("GetHandler should return the protected handler")
}
})
t.Run("GetNonExistentHandler", func(t *testing.T) {
handler := processor.GetHandler("NONEXISTENT")
if handler != nil {
t.Error("GetHandler should return nil for non-existent handler")
}
})
t.Run("UnregisterHandler", func(t *testing.T) {
err := processor.UnregisterHandler("TEST")
if err != nil {
t.Errorf("UnregisterHandler should not error: %v", err)
}
// Verify handler is removed
retrievedHandler := processor.GetHandler("TEST")
if retrievedHandler != nil {
t.Error("GetHandler should return nil after unregistering")
}
})
t.Run("UnregisterProtectedHandler", func(t *testing.T) {
err := processor.UnregisterHandler("PROTECTED")
if err == nil {
t.Error("UnregisterHandler should error for protected handler")
}
// Verify handler is still there
retrievedHandler := processor.GetHandler("PROTECTED")
if retrievedHandler == nil {
t.Error("Protected handler should not be removed")
}
})
}
// TestVoidProcessorBehavior tests void processor behavior
func TestVoidProcessorBehavior(t *testing.T) {
voidProcessor := NewVoidProcessor()
handler := &UnitTestHandler{name: "test-handler"}
t.Run("GetHandler", func(t *testing.T) {
retrievedHandler := voidProcessor.GetHandler("ANY")
if retrievedHandler != nil {
t.Error("VoidProcessor GetHandler should always return nil")
}
})
t.Run("RegisterHandler", func(t *testing.T) {
err := voidProcessor.RegisterHandler("TEST", handler, false)
if err == nil {
t.Error("VoidProcessor RegisterHandler should return error")
}
// Check error type
if !IsVoidProcessorError(err) {
t.Error("Error should be a VoidProcessorError")
}
})
t.Run("UnregisterHandler", func(t *testing.T) {
err := voidProcessor.UnregisterHandler("TEST")
if err == nil {
t.Error("VoidProcessor UnregisterHandler should return error")
}
// Check error type
if !IsVoidProcessorError(err) {
t.Error("Error should be a VoidProcessorError")
}
})
}
// TestProcessPendingNotificationsNilReader tests handling of nil reader
func TestProcessPendingNotificationsNilReader(t *testing.T) {
t.Run("ProcessorWithNilReader", func(t *testing.T) {
processor := NewProcessor()
ctx := context.Background()
handlerCtx := NotificationHandlerContext{}
err := processor.ProcessPendingNotifications(ctx, handlerCtx, nil)
if err != nil {
t.Errorf("ProcessPendingNotifications should not error with nil reader: %v", err)
}
})
t.Run("VoidProcessorWithNilReader", func(t *testing.T) {
voidProcessor := NewVoidProcessor()
ctx := context.Background()
handlerCtx := NotificationHandlerContext{}
err := voidProcessor.ProcessPendingNotifications(ctx, handlerCtx, nil)
if err != nil {
t.Errorf("VoidProcessor ProcessPendingNotifications should not error with nil reader: %v", err)
}
})
}
// TestWillHandleNotificationInClient tests the notification filtering logic
func TestWillHandleNotificationInClient(t *testing.T) {
testCases := []struct {
name string
notificationType string
shouldHandle bool
}{
// Pub/Sub notifications (should be handled in client)
{"message", "message", true},
{"pmessage", "pmessage", true},
{"subscribe", "subscribe", true},
{"unsubscribe", "unsubscribe", true},
{"psubscribe", "psubscribe", true},
{"punsubscribe", "punsubscribe", true},
{"smessage", "smessage", true},
{"ssubscribe", "ssubscribe", true},
{"sunsubscribe", "sunsubscribe", true},
// Push notifications (should be handled by processor)
{"MOVING", "MOVING", false},
{"MIGRATING", "MIGRATING", false},
{"MIGRATED", "MIGRATED", false},
{"FAILING_OVER", "FAILING_OVER", false},
{"FAILED_OVER", "FAILED_OVER", false},
{"custom", "custom", false},
{"unknown", "unknown", false},
{"empty", "", false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := willHandleNotificationInClient(tc.notificationType)
if result != tc.shouldHandle {
t.Errorf("willHandleNotificationInClient(%q) = %v, want %v", tc.notificationType, result, tc.shouldHandle)
}
})
}
}
// TestProcessorErrorHandlingUnit tests error handling scenarios
func TestProcessorErrorHandlingUnit(t *testing.T) {
processor := NewProcessor()
t.Run("RegisterNilHandler", func(t *testing.T) {
err := processor.RegisterHandler("TEST", nil, false)
if err == nil {
t.Error("RegisterHandler should error with nil handler")
}
// Check error type
if !IsHandlerNilError(err) {
t.Error("Error should be a HandlerNilError")
}
})
t.Run("RegisterDuplicateHandler", func(t *testing.T) {
handler1 := &UnitTestHandler{name: "handler1"}
handler2 := &UnitTestHandler{name: "handler2"}
// Register first handler
err := processor.RegisterHandler("DUPLICATE", handler1, false)
if err != nil {
t.Errorf("First RegisterHandler should not error: %v", err)
}
// Try to register second handler with same name
err = processor.RegisterHandler("DUPLICATE", handler2, false)
if err == nil {
t.Error("RegisterHandler should error when registering duplicate handler")
}
// Verify original handler is still there
retrievedHandler := processor.GetHandler("DUPLICATE")
if retrievedHandler != handler1 {
t.Error("Original handler should remain after failed duplicate registration")
}
})
t.Run("UnregisterNonExistentHandler", func(t *testing.T) {
err := processor.UnregisterHandler("NONEXISTENT")
if err != nil {
t.Errorf("UnregisterHandler should not error for non-existent handler: %v", err)
}
})
}
// TestProcessorConcurrentAccess tests concurrent access to processor
func TestProcessorConcurrentAccess(t *testing.T) {
processor := NewProcessor()
t.Run("ConcurrentRegisterAndGet", func(t *testing.T) {
done := make(chan bool, 2)
// Goroutine 1: Register handlers
go func() {
defer func() { done <- true }()
for i := 0; i < 100; i++ {
handler := &UnitTestHandler{name: "concurrent-handler"}
processor.RegisterHandler("CONCURRENT", handler, false)
processor.UnregisterHandler("CONCURRENT")
}
}()
// Goroutine 2: Get handlers
go func() {
defer func() { done <- true }()
for i := 0; i < 100; i++ {
processor.GetHandler("CONCURRENT")
}
}()
// Wait for both goroutines to complete
<-done
<-done
})
}
// TestProcessorInterfaceCompliance tests interface compliance
func TestProcessorInterfaceCompliance(t *testing.T) {
t.Run("ProcessorImplementsInterface", func(t *testing.T) {
var _ NotificationProcessor = (*Processor)(nil)
})
t.Run("VoidProcessorImplementsInterface", func(t *testing.T) {
var _ NotificationProcessor = (*VoidProcessor)(nil)
})
}
// UnitTestHandler is a test implementation of NotificationHandler
type UnitTestHandler struct {
name string
lastNotification []interface{}
errorToReturn error
callCount int
}
func (h *UnitTestHandler) HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error {
h.callCount++
h.lastNotification = notification
return h.errorToReturn
}
// Helper methods for UnitTestHandler
func (h *UnitTestHandler) GetCallCount() int {
return h.callCount
}
func (h *UnitTestHandler) GetLastNotification() []interface{} {
return h.lastNotification
}
func (h *UnitTestHandler) SetErrorToReturn(err error) {
h.errorToReturn = err
}
func (h *UnitTestHandler) Reset() {
h.callCount = 0
h.lastNotification = nil
h.errorToReturn = nil
}

View File

@@ -4,24 +4,6 @@ import (
"github.com/redis/go-redis/v9/push"
)
// Push notification constants for cluster operations
const (
// MOVING indicates a slot is being moved to a different node
PushNotificationMoving = "MOVING"
// MIGRATING indicates a slot is being migrated from this node
PushNotificationMigrating = "MIGRATING"
// MIGRATED indicates a slot has been migrated to this node
PushNotificationMigrated = "MIGRATED"
// FAILING_OVER indicates a failover is starting
PushNotificationFailingOver = "FAILING_OVER"
// FAILED_OVER indicates a failover has completed
PushNotificationFailedOver = "FAILED_OVER"
)
// NewPushNotificationProcessor creates a new push notification processor
// This processor maintains a registry of handlers and processes push notifications
// It is used for RESP3 connections where push notifications are available

234
redis.go
View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/hitless"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/hscan"
"github.com/redis/go-redis/v9/internal/pool"
@@ -204,19 +205,35 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
//------------------------------------------------------------------------------
type baseClient struct {
opt *Options
connPool pool.Pooler
opt *Options
optLock sync.RWMutex
connPool pool.Pooler
pubSubPool *pool.PubSubPool
hooksMixin
onClose func() error // hook called when client is closed
// Push notification processing
pushProcessor push.NotificationProcessor
// Hitless upgrade manager
hitlessManager *hitless.HitlessManager
hitlessManagerLock sync.RWMutex
}
func (c *baseClient) clone() *baseClient {
clone := *c
return &clone
c.hitlessManagerLock.RLock()
hitlessManager := c.hitlessManager
c.hitlessManagerLock.RUnlock()
clone := &baseClient{
opt: c.opt,
connPool: c.connPool,
onClose: c.onClose,
pushProcessor: c.pushProcessor,
hitlessManager: hitlessManager,
}
return clone
}
func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
@@ -234,21 +251,6 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}
func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
cn, err := c.connPool.NewConn(ctx)
if err != nil {
return nil, err
}
err = c.initConn(ctx, cn)
if err != nil {
_ = c.connPool.CloseConn(cn)
return nil, err
}
return cn, nil
}
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
if c.opt.Limiter != nil {
err := c.opt.Limiter.Allow()
@@ -274,7 +276,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return nil, err
}
if cn.Inited {
if cn.IsInited() {
return cn, nil
}
@@ -356,12 +358,10 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
}
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
if !cn.Inited.CompareAndSwap(false, true) {
return nil
}
var err error
cn.Inited = true
connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool, &c.hooksMixin)
@@ -430,6 +430,51 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return fmt.Errorf("failed to initialize connection options: %w", err)
}
// Enable maintenance notifications if hitless upgrades are configured
c.optLock.RLock()
hitlessEnabled := c.opt.HitlessUpgradeConfig != nil && c.opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled
protocol := c.opt.Protocol
endpointType := c.opt.HitlessUpgradeConfig.EndpointType
c.optLock.RUnlock()
var hitlessHandshakeErr error
if hitlessEnabled && protocol == 3 {
hitlessHandshakeErr = conn.ClientMaintNotifications(
ctx,
true,
endpointType.String(),
).Err()
if hitlessHandshakeErr != nil {
if !isRedisError(hitlessHandshakeErr) {
// if not redis error, fail the connection
return hitlessHandshakeErr
}
c.optLock.Lock()
// handshake failed - check and modify config atomically
switch c.opt.HitlessUpgradeConfig.Mode {
case hitless.MaintNotificationsEnabled:
// enabled mode, fail the connection
c.optLock.Unlock()
return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr)
default: // will handle auto and any other
internal.Logger.Printf(ctx, "hitless: auto mode fallback: hitless upgrades disabled due to handshake error: %v", hitlessHandshakeErr)
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsDisabled
c.optLock.Unlock()
// auto mode, disable hitless upgrades and continue
if err := c.disableHitlessUpgrades(); err != nil {
// Log error but continue - auto mode should be resilient
internal.Logger.Printf(ctx, "hitless: failed to disable hitless upgrades in auto mode: %v", err)
}
}
} else {
// handshake was executed successfully
// The handshake succeeded on this connection, so switch to enabled mode to
// force the handshake to be executed on all other connections as well.
c.optLock.Lock()
c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsEnabled
c.optLock.Unlock()
}
}
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
libName := ""
libVer := Version()
@@ -446,6 +491,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
}
}
cn.SetUsable(true)
cn.Inited.Store(true)
// Set the connection initialization function for potential reconnections
cn.SetInitConnFunc(c.createInitConnFunc())
if c.opt.OnConnect != nil {
return c.opt.OnConnect(ctx, conn)
}
@@ -512,6 +563,8 @@ func (c *baseClient) assertUnstableCommand(cmd Cmder) bool {
if c.opt.UnstableResp3 {
return true
} else {
// TODO: find the best way to remove the panic and return error here
// The client should not panic when executing a command, only when initializing.
panic("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3 . See the [README](https://github.com/redis/go-redis/blob/master/README.md) and the release notes for guidance.")
}
default:
@@ -593,19 +646,76 @@ func (c *baseClient) context(ctx context.Context) context.Context {
return context.Background()
}
// createInitConnFunc creates a connection initialization function that can be used for reconnections.
func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error {
return func(ctx context.Context, cn *pool.Conn) error {
return c.initConn(ctx, cn)
}
}
// enableHitlessUpgrades initializes the hitless upgrade manager and pool hook.
// This function is called during client initialization.
// It registers push notification handlers for all hitless upgrade events and
// starts background workers for handoff processing in the pool hook.
func (c *baseClient) enableHitlessUpgrades() error {
// Create client adapter
clientAdapterInstance := newClientAdapter(c)
// Create hitless manager directly
manager, err := hitless.NewHitlessManager(clientAdapterInstance, c.connPool, c.opt.HitlessUpgradeConfig)
if err != nil {
return err
}
// Set the manager reference and initialize pool hook
c.hitlessManagerLock.Lock()
c.hitlessManager = manager
c.hitlessManagerLock.Unlock()
// Initialize pool hook (safe to call without lock since manager is now set)
manager.InitPoolHook(c.dialHook)
return nil
}
func (c *baseClient) disableHitlessUpgrades() error {
c.hitlessManagerLock.Lock()
defer c.hitlessManagerLock.Unlock()
// Close the hitless manager
if c.hitlessManager != nil {
// Closing the manager will also shutdown the pool hook
// and remove it from the pool
c.hitlessManager.Close()
c.hitlessManager = nil
}
return nil
}
// Close closes the client, releasing any open resources.
//
// It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
var firstErr error
// Close hitless manager first
if err := c.disableHitlessUpgrades(); err != nil {
firstErr = err
}
if c.onClose != nil {
if err := c.onClose(); err != nil {
if err := c.onClose(); err != nil && firstErr == nil {
firstErr = err
}
}
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
if c.connPool != nil {
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if c.pubSubPool != nil {
if err := c.pubSubPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
@@ -796,6 +906,8 @@ func NewClient(opt *Options) *Client {
if opt == nil {
panic("redis: NewClient nil options")
}
// clone to not share options with the caller
opt = opt.clone()
opt.init()
// Push notifications are always enabled for RESP3 (cannot be disabled)
@@ -810,11 +922,40 @@ func NewClient(opt *Options) *Client {
// Initialize push notification processor using shared helper
// Use void processor for RESP2 connections (push notifications not available)
c.pushProcessor = initializePushProcessor(opt)
// set opt push processor for child clients
c.opt.PushNotificationProcessor = c.pushProcessor
// Update options with the initialized push processor for connection pool
opt.PushNotificationProcessor = c.pushProcessor
// Create connection pools
var err error
c.connPool, err = newConnPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
}
c.pubSubPool, err = newPubSubPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}
c.connPool = newConnPool(opt, c.dialHook)
// Initialize hitless upgrades first if enabled and protocol is RESP3
if opt.HitlessUpgradeConfig != nil && opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled && opt.Protocol == 3 {
err := c.enableHitlessUpgrades()
if err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
if opt.HitlessUpgradeConfig.Mode == hitless.MaintNotificationsEnabled {
/*
Design decision: panic here to fail fast if hitless upgrades cannot be enabled when explicitly requested.
We choose to panic instead of returning an error to avoid breaking the existing client API, which does not expect
an error from NewClient. This ensures that misconfiguration or critical initialization failures are surfaced
immediately, rather than allowing the client to continue in a partially initialized or inconsistent state.
Clients relying on hitless upgrades should be aware that initialization errors will cause a panic, and should
handle this accordingly (e.g., via recover or by validating configuration before calling NewClient).
This approach is only used when HitlessUpgradeConfig.Mode is MaintNotificationsEnabled, indicating that hitless
upgrades are required for correct operation. In other modes, initialization failures are logged but do not panic.
*/
panic(fmt.Errorf("failed to enable hitless upgrades: %w", err))
}
}
}
return &c
}
@@ -851,6 +992,14 @@ func (c *Client) Options() *Options {
return c.opt
}
// GetHitlessManager returns the hitless manager instance for monitoring and control.
// Returns nil if hitless upgrades are not enabled.
func (c *Client) GetHitlessManager() *hitless.HitlessManager {
c.hitlessManagerLock.RLock()
defer c.hitlessManagerLock.RUnlock()
return c.hitlessManager
}
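A monitoring sketch relying only on the documented nil contract above (usual redis import assumed):

rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Protocol: 3})
defer rdb.Close()

if mgr := rdb.GetHitlessManager(); mgr != nil {
    _ = mgr // hitless upgrades are active; the manager can be inspected here
} else {
    // RESP2, disabled mode, or an auto-mode fallback after a failed handshake.
}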
// initializePushProcessor initializes the push notification processor for any client type.
// This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient.
func initializePushProcessor(opt *Options) push.NotificationProcessor {
@@ -887,6 +1036,7 @@ type PoolStats pool.Stats
// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
stats := c.connPool.Stats()
stats.PubSubStats = *(c.pubSubPool.Stats())
return (*PoolStats)(stats)
}
@@ -921,11 +1071,27 @@ func (c *Client) TxPipeline() Pipeliner {
func (c *Client) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
closeConn: c.connPool.CloseConn,
pushProcessor: c.pushProcessor,
}
pubsub.init()
@@ -1113,6 +1279,6 @@ func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.Notifica
return push.NotificationHandlerContext{
Client: c,
ConnPool: c.connPool,
Conn: cn,
Conn: cn, // stored as interface{}; handlers type-assert back to *pool.Conn when needed
}
}

View File

@@ -12,7 +12,6 @@ import (
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/auth"
)

View File

@@ -16,8 +16,8 @@ import (
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/rand"
"github.com/redis/go-redis/v9/push"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/push"
)
//------------------------------------------------------------------------------
@@ -139,6 +139,14 @@ type FailoverOptions struct {
FailingTimeoutSeconds int
UnstableResp3 bool
// Hitless is not supported for FailoverClients at the moment
// HitlessUpgradeConfig provides custom configuration for hitless upgrades.
// When HitlessUpgradeConfig.Mode is not "disabled", the client will handle
// upgrade notifications gracefully and manage connection/pool state transitions
// seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, hitless upgrades are disabled.
//HitlessUpgradeConfig *HitlessUpgradeConfig
}
func (opt *FailoverOptions) clientOptions() *Options {
@@ -456,8 +464,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
opt.Dialer = masterReplicaDialer(failover)
opt.init()
var connPool *pool.ConnPool
rdb := &Client{
baseClient: &baseClient{
opt: opt,
@@ -469,15 +475,25 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
// Use void processor by default for RESP2 connections
rdb.pushProcessor = initializePushProcessor(opt)
connPool = newConnPool(opt, rdb.dialHook)
rdb.connPool = connPool
var err error
rdb.connPool, err = newConnPool(opt, rdb.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
}
rdb.pubSubPool, err = newPubSubPool(opt, rdb.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}
rdb.onClose = rdb.wrappedOnClose(failover.Close)
failover.mu.Lock()
failover.onFailover = func(ctx context.Context, addr string) {
_ = connPool.Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
})
if connPool, ok := rdb.connPool.(*pool.ConnPool); ok {
_ = connPool.Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
})
}
}
failover.mu.Unlock()
@@ -543,7 +559,15 @@ func NewSentinelClient(opt *Options) *SentinelClient {
dial: c.baseClient.dial,
process: c.baseClient.process,
})
c.connPool = newConnPool(opt, c.dialHook)
var err error
c.connPool, err = newConnPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
}
c.pubSubPool, err = newPubSubPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}
return c
}
@@ -570,13 +594,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
func (c *SentinelClient) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: c.connPool.CloseConn,
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
pushProcessor: c.pushProcessor,
}
pubsub.init()
return pubsub
}

2
tx.go
View File

@@ -24,7 +24,7 @@ type Tx struct {
func (c *Client) newTx() *Tx {
tx := Tx{
baseClient: baseClient{
opt: c.opt,
opt: c.opt.clone(), // Clone options to avoid sharing mutable state between transaction and parent client
connPool: pool.NewStickyConnPool(c.connPool),
hooksMixin: c.hooksMixin.clone(),
pushProcessor: c.pushProcessor, // Copy push processor from parent client

View File

@@ -122,6 +122,9 @@ type UniversalOptions struct {
// IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint).
IsClusterMode bool
// HitlessUpgradeConfig provides configuration for hitless upgrades.
HitlessUpgradeConfig *HitlessUpgradeConfig
}
// Cluster returns cluster options created from the universal options.
@@ -177,6 +180,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions {
IdentitySuffix: o.IdentitySuffix,
FailingTimeoutSeconds: o.FailingTimeoutSeconds,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
}
}
@@ -237,6 +241,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions {
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
// Note: HitlessUpgradeConfig not supported for FailoverOptions
}
}
@@ -284,10 +289,11 @@ func (o *UniversalOptions) Simple() *Options {
TLSConfig: o.TLSConfig,
DisableIdentity: o.DisableIdentity,
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
DisableIdentity: o.DisableIdentity,
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
HitlessUpgradeConfig: o.HitlessUpgradeConfig,
}
}
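A sketch of the universal path, using the same imports as the earlier options sketch; per this diff the config is forwarded to Cluster() and Simple() but not to Failover():

uc := redis.NewUniversalClient(&redis.UniversalOptions{
    Addrs:    []string{"localhost:6379"},
    Protocol: 3,
    HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
        Mode: hitless.MaintNotificationsEnabled,
    },
})
defer uc.Close()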