mirror of https://github.com/redis/go-redis.git synced 2025-12-02 06:22:31 +03:00

Merge branch 'master' into implement-tls-url-parameters-pr2076

Nedyalko Dyakov committed on 2025-10-21 12:00:21 +03:00 (via GitHub)
118 changed files with 18598 additions and 656 deletions


@@ -36,6 +36,8 @@ categories:
change-template: '- $TITLE (#$NUMBER)'
exclude-labels:
- 'skip-changelog'
exclude-contributors:
- 'dependabot'
template: |
# Changes


@@ -27,7 +27,7 @@ jobs:
steps:
- name: Set up ${{ matrix.go-version }}
uses: actions/setup-go@v5
uses: actions/setup-go@v6
with:
go-version: ${{ matrix.go-version }}


@@ -39,7 +39,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
uses: github/codeql-action/init@v4
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
@@ -50,7 +50,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v3
uses: github/codeql-action/autobuild@v4
# Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl
@@ -64,4 +64,4 @@ jobs:
# make release
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
uses: github/codeql-action/analyze@v4


@@ -31,7 +31,7 @@ jobs:
steps:
- name: Set up ${{ matrix.go-version }}
uses: actions/setup-go@v5
uses: actions/setup-go@v6
with:
go-version: ${{ matrix.go-version }}


@@ -8,7 +8,7 @@ jobs:
- name: Checkout
uses: actions/checkout@v5
- name: Check Spelling
uses: rojopolis/spellcheck-github-actions@0.51.0
uses: rojopolis/spellcheck-github-actions@0.52.0
with:
config_path: .github/spellcheck-settings.yml
task_name: Markdown


@@ -29,7 +29,7 @@ jobs:
path: redis-ee
- name: Set up ${{ matrix.go-version }}
uses: actions/setup-go@v5
uses: actions/setup-go@v6
with:
go-version: ${{ matrix.go-version }}

.gitignore (vendored)

@@ -9,3 +9,7 @@ coverage.txt
**/coverage.txt
.vscode
tmp/*
*.test
# maintenanceNotifications upgrade documentation (temporary)
maintenanceNotifications/docs/


@@ -1,5 +1,124 @@
# Release Notes
# 9.15.0-beta.3 (2025-09-26)
## Highlights
This beta release includes a pre-production version of push notification processing and hitless upgrades.
# Changes
- chore: Update hash_commands.go ([#3523](https://github.com/redis/go-redis/pull/3523))
## 🚀 New Features
- feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418))
## 🐛 Bug Fixes
- fix: pipeline repeatedly sets the error ([#3525](https://github.com/redis/go-redis/pull/3525))
## 🧰 Maintenance
- chore(deps): bump rojopolis/spellcheck-github-actions from 0.51.0 to 0.52.0 ([#3520](https://github.com/redis/go-redis/pull/3520))
- feat(e2e-testing): maintnotifications e2e and refactor ([#3526](https://github.com/redis/go-redis/pull/3526))
- feat(tag.sh): Improved resiliency of the release process ([#3530](https://github.com/redis/go-redis/pull/3530))
## Contributors
We'd like to thank all the contributors who worked on this release!
[@cxljs](https://github.com/cxljs), [@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), and [@omid-h70](https://github.com/omid-h70)
# 9.15.0-beta.1 (2025-09-10)
## Highlights
This beta release includes a pre-production version of push notification processing and hitless upgrades.
### Hitless Upgrades
Hitless upgrades is a major new feature that allows for zero-downtime upgrades in Redis clusters.
You can find more information in the [Hitless Upgrades documentation](https://github.com/redis/go-redis/tree/master/hitless).
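As a minimal, hedged sketch of opting in (mirroring the maintnotifications configuration used in the pubsub example added later in this commit; the values shown are illustrative):

rdb := redis.NewClient(&redis.Options{
    Addr: ":6379",
    MaintNotificationsConfig: &maintnotifications.Config{
        Mode:         maintnotifications.ModeEnabled,
        EndpointType: maintnotifications.EndpointTypeExternalIP,
    },
})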
# Changes
## 🚀 New Features
- [CAE-1088] & [CAE-1072] feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418))
## Contributors
We'd like to thank all the contributors who worked on this release!
[@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), [@ofekshenawa](https://github.com/ofekshenawa)
# 9.14.0 (2025-09-10)
## Highlights
- Added batch process method to the pipeline ([#3510](https://github.com/redis/go-redis/pull/3510))
# Changes
## 🚀 New Features
- Added batch process method to the pipeline ([#3510](https://github.com/redis/go-redis/pull/3510))
## 🐛 Bug Fixes
- fix: SetErr on Cmd if the command cannot be queued correctly in multi/exec ([#3509](https://github.com/redis/go-redis/pull/3509))
## 🧰 Maintenance
- Updates release drafter config to exclude dependabot ([#3511](https://github.com/redis/go-redis/pull/3511))
- chore(deps): bump actions/setup-go from 5 to 6 ([#3504](https://github.com/redis/go-redis/pull/3504))
## Contributors
We'd like to thank all the contributors who worked on this release!
[@elena-kolevska](https://github.com/elena-kolevska), [@htemelski-redis](https://github.com/htemelski-redis) and [@ndyakov](https://github.com/ndyakov)
# 9.13.0 (2025-09-03)
## Highlights
- Pipeliner expose queued commands ([#3496](https://github.com/redis/go-redis/pull/3496))
- Ensure that JSON.GET returns Nil response ([#3470](https://github.com/redis/go-redis/pull/3470))
- Fixes for read/write buffer sizes and UniversalOptions
## Changes
- Pipeliner expose queued commands ([#3496](https://github.com/redis/go-redis/pull/3496))
- fix(test): fix a timing issue in pubsub test ([#3498](https://github.com/redis/go-redis/pull/3498))
- Allow users to enable read-write splitting in failover mode. ([#3482](https://github.com/redis/go-redis/pull/3482))
- Set the read/write buffer size of the sentinel client to 4KiB ([#3476](https://github.com/redis/go-redis/pull/3476))
## 🚀 New Features
- fix(otel): register wait metrics ([#3499](https://github.com/redis/go-redis/pull/3499))
- Support subscriptions against cluster slave nodes ([#3480](https://github.com/redis/go-redis/pull/3480))
- Add wait metrics to otel ([#3493](https://github.com/redis/go-redis/pull/3493))
- Clean failing timeout implementation ([#3472](https://github.com/redis/go-redis/pull/3472))
## 🐛 Bug Fixes
- Do not assume that all non-IP hosts are loopbacks ([#3085](https://github.com/redis/go-redis/pull/3085))
- Ensure that JSON.GET returns Nil response ([#3470](https://github.com/redis/go-redis/pull/3470))
## 🧰 Maintenance
- fix(otel): register wait metrics ([#3499](https://github.com/redis/go-redis/pull/3499))
- fix(make test): Add default env in makefile ([#3491](https://github.com/redis/go-redis/pull/3491))
- Update the introduction to running tests in README.md ([#3495](https://github.com/redis/go-redis/pull/3495))
- test: Add comprehensive edge case tests for IncrByFloat command ([#3477](https://github.com/redis/go-redis/pull/3477))
- Set the default read/write buffer size of Redis connection to 32KiB ([#3483](https://github.com/redis/go-redis/pull/3483))
- Bumps test image to 8.2.1-pre ([#3478](https://github.com/redis/go-redis/pull/3478))
- fix UniversalOptions miss ReadBufferSize and WriteBufferSize options ([#3485](https://github.com/redis/go-redis/pull/3485))
- chore(deps): bump actions/checkout from 4 to 5 ([#3484](https://github.com/redis/go-redis/pull/3484))
- Removes dry run for stale issues policy ([#3471](https://github.com/redis/go-redis/pull/3471))
- Update otel metrics URL ([#3474](https://github.com/redis/go-redis/pull/3474))
## Contributors
We'd like to thank all the contributors who worked on this release!
[@LINKIWI](https://github.com/LINKIWI), [@cxljs](https://github.com/cxljs), [@cybersmeashish](https://github.com/cybersmeashish), [@elena-kolevska](https://github.com/elena-kolevska), [@htemelski-redis](https://github.com/htemelski-redis), [@mwhooker](https://github.com/mwhooker), [@ndyakov](https://github.com/ndyakov), [@ofekshenawa](https://github.com/ofekshenawa), [@suever](https://github.com/suever)
# 9.12.1 (2025-08-11)
## 🚀 Highlights
In the previous version (9.12.0), the client introduced larger write and read buffer sizes. The default value we set was 512KiB.

adapters.go (new file)

@@ -0,0 +1,111 @@
package redis
import (
"context"
"errors"
"net"
"time"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/push"
)
// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand.
var ErrInvalidCommand = errors.New("invalid command type")
// ErrInvalidPool is returned when the pool type is not supported.
var ErrInvalidPool = errors.New("invalid pool type")
// newClientAdapter creates a new client adapter for regular Redis clients.
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
return &clientAdapter{client: client}
}
// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
type clientAdapter struct {
client *baseClient
}
// GetOptions returns the client options.
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
return &optionsAdapter{options: ca.client.opt}
}
// GetPushProcessor returns the client's push notification processor.
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
return &pushProcessorAdapter{processor: ca.client.pushProcessor}
}
// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
type optionsAdapter struct {
options *Options
}
// GetReadTimeout returns the read timeout.
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
return oa.options.ReadTimeout
}
// GetWriteTimeout returns the write timeout.
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
return oa.options.WriteTimeout
}
// GetNetwork returns the network type.
func (oa *optionsAdapter) GetNetwork() string {
return oa.options.Network
}
// GetAddr returns the connection address.
func (oa *optionsAdapter) GetAddr() string {
return oa.options.Addr
}
// IsTLSEnabled returns true if TLS is enabled.
func (oa *optionsAdapter) IsTLSEnabled() bool {
return oa.options.TLSConfig != nil
}
// GetProtocol returns the protocol version.
func (oa *optionsAdapter) GetProtocol() int {
return oa.options.Protocol
}
// GetPoolSize returns the connection pool size.
func (oa *optionsAdapter) GetPoolSize() int {
return oa.options.PoolSize
}
// NewDialer returns a new dialer function for the connection.
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
baseDialer := oa.options.NewDialer()
return func(ctx context.Context) (net.Conn, error) {
// Extract network and address from the options
network := oa.options.Network
addr := oa.options.Addr
return baseDialer(ctx, network, addr)
}
}
// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
type pushProcessorAdapter struct {
processor push.NotificationProcessor
}
// RegisterHandler registers a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
if pushHandler, ok := handler.(push.NotificationHandler); ok {
return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
}
return errors.New("handler must implement push.NotificationHandler")
}
// UnregisterHandler removes a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
return ppa.processor.UnregisterHandler(pushNotificationName)
}
// GetHandler returns the handler for a specific push notification name.
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
return ppa.processor.GetHandler(pushNotificationName)
}
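A brief, hedged sketch of how these adapters are consumed through the interfaces package (the *baseClient value c and ctx are assumed):

var ci interfaces.ClientInterface = newClientAdapter(c) // c is an assumed *baseClient
dial := ci.GetOptions().NewDialer()
conn, err := dial(ctx) // dials the configured network and address
if err != nil {
    // handle dial error
}
_ = conn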


@@ -0,0 +1,353 @@
package redis
import (
"context"
"net"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/logging"
)
// mockNetConn implements net.Conn for testing
type mockNetConn struct {
addr string
}
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (m *mockNetConn) Close() error { return nil }
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
type mockAddr struct {
addr string
}
func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string { return m.addr }
// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow
func TestEventDrivenHandoffIntegration(t *testing.T) {
t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) {
// Create a base dialer for testing
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
// Create processor with event-driven handoff support
processor := maintnotifications.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create a test pool with hooks
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(5),
PoolTimeout: time.Second,
})
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
defer testPool.Close()
// Set the pool reference in the processor for connection removal on handoff failure
processor.SetPool(testPool)
ctx := context.Background()
// Get a connection and mark it for handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
// Set initialization function with a small delay to ensure handoff is pending
initConnCalled := false
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
initConnCalled = true
return nil
}
conn.SetInitConnFunc(initConnFunc)
// Mark connection for handoff
err = conn.MarkForHandoff("new-endpoint:6379", 12345)
if err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Return connection to pool - this should queue handoff
testPool.Put(ctx, conn)
// Give the on-demand worker a moment to start processing
time.Sleep(10 * time.Millisecond)
// Verify handoff was queued
if !processor.IsHandoffPending(conn) {
t.Error("Handoff should be queued in pending map")
}
// Try to get the same connection - should be skipped due to pending handoff
conn2, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get second connection: %v", err)
}
// Should get a different connection (the pending one should be skipped)
if conn == conn2 {
t.Error("Should have gotten a different connection while handoff is pending")
}
// Return the second connection
testPool.Put(ctx, conn2)
// Wait for handoff to complete
time.Sleep(200 * time.Millisecond)
// Verify handoff completed (removed from pending map)
if processor.IsHandoffPending(conn) {
t.Error("Handoff should have completed and been removed from pending map")
}
if !initConnCalled {
t.Error("InitConn should have been called during handoff")
}
// Now the original connection should be available again
conn3, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get third connection: %v", err)
}
// Could be the original connection (now handed off) or a new one
testPool.Put(ctx, conn3)
})
t.Run("ConcurrentHandoffs", func(t *testing.T) {
// Create a base dialer that simulates slow handoffs
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
time.Sleep(50 * time.Millisecond) // Simulate network delay
return &mockNetConn{addr: addr}, nil
}
processor := maintnotifications.NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(10),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
var wg sync.WaitGroup
// Start multiple concurrent handoffs
for i := 0; i < 5; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
// Get connection
conn, err := testPool.Get(ctx)
if err != nil {
t.Errorf("Failed to get conn[%d]: %v", id, err)
return
}
// Set initialization function
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
return nil
}
conn.SetInitConnFunc(initConnFunc)
// Mark for handoff
conn.MarkForHandoff("new-endpoint:6379", int64(id))
// Return to pool (starts async handoff)
testPool.Put(ctx, conn)
}(i)
}
wg.Wait()
// Wait for all handoffs to complete
time.Sleep(300 * time.Millisecond)
// Verify pool is still functional
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err)
}
testPool.Put(ctx, conn)
})
t.Run("HandoffFailureRecovery", func(t *testing.T) {
// Create a failing base dialer
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}}
}
processor := maintnotifications.NewPoolHook(failingDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(3),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
// Get connection and mark for handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
conn.MarkForHandoff("unreachable-endpoint:6379", 12345)
// Return to pool (starts async handoff that will fail)
testPool.Put(ctx, conn)
// Wait for handoff to fail
time.Sleep(200 * time.Millisecond)
// Connection should be removed from pending map after failed handoff
if processor.IsHandoffPending(conn) {
t.Error("Connection should be removed from pending map after failed handoff")
}
// Pool should still be functional
conn2, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Pool should still be functional: %v", err)
}
// In event-driven approach, the original connection remains in pool
// even after failed handoff (it's still a valid connection)
// We might get the same connection or a different one
testPool.Put(ctx, conn2)
})
t.Run("GracefulShutdown", func(t *testing.T) {
// Create a slow base dialer
slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
time.Sleep(100 * time.Millisecond)
return &mockNetConn{addr: addr}, nil
}
processor := maintnotifications.NewPoolHook(slowDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Create hooks manager and add processor as hook
hookManager := pool.NewPoolHookManager()
hookManager.AddHook(processor)
testPool := pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(2),
PoolTimeout: time.Second,
})
defer testPool.Close()
// Add the hook to the pool after creation
testPool.AddPoolHook(processor)
// Set the pool reference in the processor
processor.SetPool(testPool)
ctx := context.Background()
// Start a handoff
conn, err := testPool.Get(ctx)
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function with delay to ensure handoff is pending
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending
return nil
})
testPool.Put(ctx, conn)
// Give the on-demand worker a moment to start and begin processing
// The handoff should be pending because the slowDialer takes 100ms
time.Sleep(10 * time.Millisecond)
// Verify handoff was queued and is being processed
if !processor.IsHandoffPending(conn) {
t.Error("Handoff should be queued in pending map")
}
// Give the handoff a moment to start processing
time.Sleep(50 * time.Millisecond)
// Shutdown processor gracefully
// Use a longer timeout to account for slow dialer (100ms) plus processing overhead
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err = processor.Shutdown(shutdownCtx)
if err != nil {
t.Errorf("Graceful shutdown should succeed: %v", err)
}
// Handoff should have completed (removed from pending map)
if processor.IsHandoffPending(conn) {
t.Error("Handoff should have completed and been removed from pending map after shutdown")
}
})
}
func init() {
logging.Disable()
}


@@ -1,316 +0,0 @@
package redis
import (
"context"
"fmt"
"io"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/proto"
)
var ctx = context.TODO()
type ClientStub struct {
Cmdable
resp []byte
}
var initHello = []byte("%1\r\n+proto\r\n:3\r\n")
func NewClientStub(resp []byte) *ClientStub {
stub := &ClientStub{
resp: resp,
}
stub.Cmdable = NewClient(&Options{
PoolSize: 128,
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
return stub.stubConn(initHello), nil
},
DisableIdentity: true,
})
return stub
}
func NewClusterClientStub(resp []byte) *ClientStub {
stub := &ClientStub{
resp: resp,
}
client := NewClusterClient(&ClusterOptions{
PoolSize: 128,
Addrs: []string{":6379"},
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
return stub.stubConn(initHello), nil
},
DisableIdentity: true,
ClusterSlots: func(_ context.Context) ([]ClusterSlot, error) {
return []ClusterSlot{
{
Start: 0,
End: 16383,
Nodes: []ClusterNode{{Addr: "127.0.0.1:6379"}},
},
}, nil
},
})
stub.Cmdable = client
return stub
}
func (c *ClientStub) stubConn(init []byte) *ConnStub {
return &ConnStub{
init: init,
resp: c.resp,
}
}
type ConnStub struct {
init []byte
resp []byte
pos int
}
func (c *ConnStub) Read(b []byte) (n int, err error) {
// Return conn.init()
if len(c.init) > 0 {
n = copy(b, c.init)
c.init = c.init[n:]
return n, nil
}
if len(c.resp) == 0 {
return 0, io.EOF
}
if c.pos >= len(c.resp) {
c.pos = 0
}
n = copy(b, c.resp[c.pos:])
c.pos += n
return n, nil
}
func (c *ConnStub) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *ConnStub) Close() error { return nil }
func (c *ConnStub) LocalAddr() net.Addr { return nil }
func (c *ConnStub) RemoteAddr() net.Addr { return nil }
func (c *ConnStub) SetDeadline(_ time.Time) error { return nil }
func (c *ConnStub) SetReadDeadline(_ time.Time) error { return nil }
func (c *ConnStub) SetWriteDeadline(_ time.Time) error { return nil }
type ClientStubFunc func([]byte) *ClientStub
func BenchmarkDecode(b *testing.B) {
type Benchmark struct {
name string
stub ClientStubFunc
}
benchmarks := []Benchmark{
{"server", NewClientStub},
{"cluster", NewClusterClientStub},
}
for _, bench := range benchmarks {
b.Run(fmt.Sprintf("RespError-%s", bench.name), func(b *testing.B) {
respError(b, bench.stub)
})
b.Run(fmt.Sprintf("RespStatus-%s", bench.name), func(b *testing.B) {
respStatus(b, bench.stub)
})
b.Run(fmt.Sprintf("RespInt-%s", bench.name), func(b *testing.B) {
respInt(b, bench.stub)
})
b.Run(fmt.Sprintf("RespString-%s", bench.name), func(b *testing.B) {
respString(b, bench.stub)
})
b.Run(fmt.Sprintf("RespArray-%s", bench.name), func(b *testing.B) {
respArray(b, bench.stub)
})
b.Run(fmt.Sprintf("RespPipeline-%s", bench.name), func(b *testing.B) {
respPipeline(b, bench.stub)
})
b.Run(fmt.Sprintf("RespTxPipeline-%s", bench.name), func(b *testing.B) {
respTxPipeline(b, bench.stub)
})
// goroutine
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=5", bench.name), func(b *testing.B) {
dynamicGoroutine(b, bench.stub, 5)
})
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=20", bench.name), func(b *testing.B) {
dynamicGoroutine(b, bench.stub, 20)
})
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=50", bench.name), func(b *testing.B) {
dynamicGoroutine(b, bench.stub, 50)
})
b.Run(fmt.Sprintf("DynamicGoroutine-%s-pool=100", bench.name), func(b *testing.B) {
dynamicGoroutine(b, bench.stub, 100)
})
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=5", bench.name), func(b *testing.B) {
staticGoroutine(b, bench.stub, 5)
})
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=20", bench.name), func(b *testing.B) {
staticGoroutine(b, bench.stub, 20)
})
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=50", bench.name), func(b *testing.B) {
staticGoroutine(b, bench.stub, 50)
})
b.Run(fmt.Sprintf("StaticGoroutine-%s-pool=100", bench.name), func(b *testing.B) {
staticGoroutine(b, bench.stub, 100)
})
}
}
func respError(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("-ERR test error\r\n"))
respErr := proto.RedisError("ERR test error")
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := rdb.Get(ctx, "key").Err(); err != respErr {
b.Fatalf("response error, got %q, want %q", err, respErr)
}
}
}
func respStatus(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("+OK\r\n"))
var val string
b.ResetTimer()
for i := 0; i < b.N; i++ {
if val = rdb.Set(ctx, "key", "value", 0).Val(); val != "OK" {
b.Fatalf("response error, got %q, want OK", val)
}
}
}
func respInt(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte(":10\r\n"))
var val int64
b.ResetTimer()
for i := 0; i < b.N; i++ {
if val = rdb.Incr(ctx, "key").Val(); val != 10 {
b.Fatalf("response error, got %q, want 10", val)
}
}
}
func respString(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("$5\r\nhello\r\n"))
var val string
b.ResetTimer()
for i := 0; i < b.N; i++ {
if val = rdb.Get(ctx, "key").Val(); val != "hello" {
b.Fatalf("response error, got %q, want hello", val)
}
}
}
func respArray(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("*3\r\n$5\r\nhello\r\n:10\r\n+OK\r\n"))
var val []interface{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if val = rdb.MGet(ctx, "key").Val(); len(val) != 3 {
b.Fatalf("response error, got len(%d), want len(3)", len(val))
}
}
}
func respPipeline(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("+OK\r\n$5\r\nhello\r\n:1\r\n"))
var pipe Pipeliner
b.ResetTimer()
for i := 0; i < b.N; i++ {
pipe = rdb.Pipeline()
set := pipe.Set(ctx, "key", "value", 0)
get := pipe.Get(ctx, "key")
del := pipe.Del(ctx, "key")
_, err := pipe.Exec(ctx)
if err != nil {
b.Fatalf("response error, got %q, want nil", err)
}
if set.Val() != "OK" || get.Val() != "hello" || del.Val() != 1 {
b.Fatal("response error")
}
}
}
func respTxPipeline(b *testing.B, stub ClientStubFunc) {
rdb := stub([]byte("+OK\r\n+QUEUED\r\n+QUEUED\r\n+QUEUED\r\n*3\r\n+OK\r\n$5\r\nhello\r\n:1\r\n"))
b.ResetTimer()
for i := 0; i < b.N; i++ {
var set *StatusCmd
var get *StringCmd
var del *IntCmd
_, err := rdb.TxPipelined(ctx, func(pipe Pipeliner) error {
set = pipe.Set(ctx, "key", "value", 0)
get = pipe.Get(ctx, "key")
del = pipe.Del(ctx, "key")
return nil
})
if err != nil {
b.Fatalf("response error, got %q, want nil", err)
}
if set.Val() != "OK" || get.Val() != "hello" || del.Val() != 1 {
b.Fatal("response error")
}
}
}
func dynamicGoroutine(b *testing.B, stub ClientStubFunc, concurrency int) {
rdb := stub([]byte("$5\r\nhello\r\n"))
c := make(chan struct{}, concurrency)
b.ResetTimer()
for i := 0; i < b.N; i++ {
c <- struct{}{}
go func() {
if val := rdb.Get(ctx, "key").Val(); val != "hello" {
panic(fmt.Sprintf("response error, got %q, want hello", val))
}
<-c
}()
}
// Here no longer wait for all goroutines to complete, it will not affect the test results.
close(c)
}
func staticGoroutine(b *testing.B, stub ClientStubFunc, concurrency int) {
rdb := stub([]byte("$5\r\nhello\r\n"))
c := make(chan struct{}, concurrency)
b.ResetTimer()
for i := 0; i < concurrency; i++ {
go func() {
for {
_, ok := <-c
if !ok {
return
}
if val := rdb.Get(ctx, "key").Val(); val != "hello" {
panic(fmt.Sprintf("response error, got %q, want hello", val))
}
}
}()
}
for i := 0; i < b.N; i++ {
c <- struct{}{}
}
close(c)
}


@@ -193,6 +193,7 @@ type Cmdable interface {
ClientID(ctx context.Context) *IntCmd
ClientUnblock(ctx context.Context, id int64) *IntCmd
ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
ConfigResetStat(ctx context.Context) *StatusCmd
ConfigSet(ctx context.Context, parameter, value string) *StatusCmd
@@ -519,6 +520,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
return cmd
}
// ClientMaintNotifications enables or disables maintenance notifications for maintenance upgrades.
// When enabled, the client will receive push notifications about Redis maintenance events.
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
args := []interface{}{"client", "maint_notifications"}
if enabled {
if endpointType == "" {
endpointType = "none"
}
args = append(args, "on", "moving-endpoint-type", endpointType)
} else {
args = append(args, "off")
}
cmd := NewStatusCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// ------------------------------------------------------------------------------------------------
func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {
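A minimal usage sketch of the new command (hedged; rdb and ctx are assumed, and the endpoint type mirrors the internal-fqdn value seen in the instrumentation example output later in this diff):

status, err := rdb.ClientMaintNotifications(ctx, true, "internal-fqdn").Result()
if err != nil {
    // handle error
}
_ = status // expected to be a status reply such as "OK"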


@@ -3019,7 +3019,8 @@ var _ = Describe("Commands", func() {
res, err = client.HPTTL(ctx, "myhash", "key1", "key2", "key200").Result()
Expect(err).NotTo(HaveOccurred())
Expect(res[0]).To(BeNumerically("~", 10*time.Second.Milliseconds(), 1))
// overhead of the push notification check is about 1-2ms for 100 commands
Expect(res[0]).To(BeNumerically("~", 10*time.Second.Milliseconds(), 2))
})
It("should HGETDEL", Label("hash", "HGETDEL"), func() {


@@ -5,7 +5,7 @@ go 1.18
replace github.com/redis/go-redis/v9 => ../..
require (
github.com/redis/go-redis/v9 v9.12.1
github.com/redis/go-redis/v9 v9.16.0-beta.1
go.uber.org/zap v1.24.0
)


@@ -4,7 +4,7 @@ go 1.18
replace github.com/redis/go-redis/v9 => ../..
require github.com/redis/go-redis/v9 v9.12.1
require github.com/redis/go-redis/v9 v9.16.0-beta.1
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect


@@ -6,7 +6,7 @@ replace github.com/redis/go-redis/v9 => ../..
require (
github.com/davecgh/go-spew v1.1.1
github.com/redis/go-redis/v9 v9.12.1
github.com/redis/go-redis/v9 v9.16.0-beta.1
)
require (


@@ -4,7 +4,7 @@ go 1.18
replace github.com/redis/go-redis/v9 => ../..
require github.com/redis/go-redis/v9 v9.12.1
require github.com/redis/go-redis/v9 v9.16.0-beta.1
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect


@@ -11,8 +11,8 @@ replace github.com/redis/go-redis/extra/redisotel/v9 => ../../extra/redisotel
replace github.com/redis/go-redis/extra/rediscmd/v9 => ../../extra/rediscmd
require (
github.com/redis/go-redis/extra/redisotel/v9 v9.12.1
github.com/redis/go-redis/v9 v9.12.1
github.com/redis/go-redis/extra/redisotel/v9 v9.16.0-beta.1
github.com/redis/go-redis/v9 v9.16.0-beta.1
github.com/uptrace/uptrace-go v1.21.0
go.opentelemetry.io/otel v1.22.0
)
@@ -25,7 +25,7 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 // indirect
github.com/redis/go-redis/extra/rediscmd/v9 v9.12.1 // indirect
github.com/redis/go-redis/extra/rediscmd/v9 v9.16.0-beta.1 // indirect
go.opentelemetry.io/contrib/instrumentation/runtime v0.46.1 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v0.44.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0 // indirect

example/pubsub/go.mod (new file)

@@ -0,0 +1,12 @@
module github.com/redis/go-redis/example/pubsub
go 1.18
replace github.com/redis/go-redis/v9 => ../..
require github.com/redis/go-redis/v9 v9.11.0
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)

example/pubsub/go.sum (new file)

@@ -0,0 +1,6 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=

example/pubsub/main.go (new file)

@@ -0,0 +1,175 @@
package main
import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
var ctx = context.Background()
var cntErrors atomic.Int64
var cntSuccess atomic.Int64
var startTime = time.Now()
// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management.
// It was used to find regressions in pool management in maintnotifications mode.
// Please don't use it as a reference for how to use pubsub.
func main() {
startTime = time.Now()
wg := &sync.WaitGroup{}
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
EndpointType: maintnotifications.EndpointTypeExternalIP,
HandoffTimeout: 10 * time.Second,
RelaxedTimeout: 10 * time.Second,
PostHandoffRelaxedDuration: 10 * time.Second,
},
})
_ = rdb.FlushDB(ctx).Err()
maintnotificationsManager := rdb.GetMaintNotificationsManager()
if maintnotificationsManager == nil {
panic("maintnotifications manager is nil")
}
loggingHook := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
maintnotificationsManager.AddNotificationHook(loggingHook)
go func() {
for {
time.Sleep(2 * time.Second)
fmt.Printf("pool stats: %+v\n", rdb.PoolStats())
}
}()
err := rdb.Ping(ctx).Err()
if err != nil {
panic(err)
}
if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil {
panic(err)
}
if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil {
panic(err)
}
fmt.Println("published", rdb.Get(ctx, "published").Val())
fmt.Println("received", rdb.Get(ctx, "received").Val())
subCtx, cancelSubCtx := context.WithCancel(ctx)
pubCtx, cancelPublishers := context.WithCancel(ctx)
for i := 0; i < 10; i++ {
wg.Add(1)
go subscribe(subCtx, rdb, "test", i, wg)
}
time.Sleep(time.Second)
cancelSubCtx()
time.Sleep(time.Second)
subCtx, cancelSubCtx = context.WithCancel(ctx)
for i := 0; i < 10; i++ {
if err := rdb.Incr(ctx, "publishers").Err(); err != nil {
fmt.Println("incr error:", err)
cntErrors.Add(1)
}
wg.Add(1)
go floodThePool(pubCtx, rdb, wg)
}
for i := 0; i < 500; i++ {
if err := rdb.Incr(ctx, "subscribers").Err(); err != nil {
fmt.Println("incr error:", err)
cntErrors.Add(1)
}
wg.Add(1)
go subscribe(subCtx, rdb, "test2", i, wg)
}
time.Sleep(120 * time.Second)
fmt.Println("canceling publishers")
cancelPublishers()
time.Sleep(10 * time.Second)
fmt.Println("canceling subscribers")
cancelSubCtx()
wg.Wait()
published, err := rdb.Get(ctx, "published").Result()
received, err := rdb.Get(ctx, "received").Result()
publishers, err := rdb.Get(ctx, "publishers").Result()
subscribers, err := rdb.Get(ctx, "subscribers").Result()
fmt.Printf("publishers: %s\n", publishers)
fmt.Printf("published: %s\n", published)
fmt.Printf("subscribers: %s\n", subscribers)
fmt.Printf("received: %s\n", received)
publishedInt, err := rdb.Get(ctx, "published").Int()
subscribersInt, err := rdb.Get(ctx, "subscribers").Int()
fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt)
time.Sleep(2 * time.Second)
fmt.Println("errors:", cntErrors.Load())
fmt.Println("success:", cntSuccess.Load())
fmt.Println("time:", time.Since(startTime))
}
func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
}
err := rdb.Publish(ctx, "test2", "hello").Err()
if err != nil {
if err.Error() != "context canceled" {
log.Println("publish error:", err)
cntErrors.Add(1)
}
}
err = rdb.Incr(ctx, "published").Err()
if err != nil {
if err.Error() != "context canceled" {
log.Println("incr error:", err)
cntErrors.Add(1)
}
}
time.Sleep(10 * time.Nanosecond)
}
}
func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) {
defer wg.Done()
rec := rdb.Subscribe(ctx, topic)
recChan := rec.Channel()
for {
select {
case <-ctx.Done():
rec.Close()
return
default:
select {
case <-ctx.Done():
rec.Close()
return
case msg := <-recChan:
err := rdb.Incr(ctx, "received").Err()
if err != nil {
if err.Error() != "context canceled" {
log.Printf("%s\n", err.Error())
cntErrors.Add(1)
}
}
_ = msg // Use the message to avoid unused variable warning
}
}
}
}


@@ -4,7 +4,7 @@ go 1.18
replace github.com/redis/go-redis/v9 => ../..
require github.com/redis/go-redis/v9 v9.12.1
require github.com/redis/go-redis/v9 v9.16.0-beta.1
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect


@@ -6,7 +6,7 @@ replace github.com/redis/go-redis/v9 => ../..
require (
github.com/davecgh/go-spew v1.1.1
github.com/redis/go-redis/v9 v9.12.1
github.com/redis/go-redis/v9 v9.16.0-beta.1
)
require (


@@ -57,6 +57,8 @@ func Example_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[ping]>
}
@@ -78,6 +80,8 @@ func ExamplePipeline_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// pipeline finished processing: [[ping] [ping]]
}
@@ -99,6 +103,8 @@ func ExampleClient_Watch_instrumentation() {
// finished dialing tcp :6379
// starting processing: <[hello 3]>
// finished processing: <[hello 3]>
// starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]>
// finished processing: <[watch foo]>
// starting processing: <[ping]>
// finished processing: <[ping]>


@@ -7,8 +7,8 @@ replace github.com/redis/go-redis/v9 => ../..
replace github.com/redis/go-redis/extra/rediscmd/v9 => ../rediscmd
require (
github.com/redis/go-redis/extra/rediscmd/v9 v9.12.1
github.com/redis/go-redis/v9 v9.12.1
github.com/redis/go-redis/extra/rediscmd/v9 v9.16.0-beta.1
github.com/redis/go-redis/v9 v9.16.0-beta.1
go.opencensus.io v0.24.0
)
@@ -19,6 +19,7 @@ require (
)
retract (
v9.7.2 // This version was accidentally released.
v9.5.3 // This version was accidentally released.
v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead.
v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead.
)


@@ -7,7 +7,7 @@ replace github.com/redis/go-redis/v9 => ../..
require (
github.com/bsm/ginkgo/v2 v2.12.0
github.com/bsm/gomega v1.27.10
github.com/redis/go-redis/v9 v9.12.1
github.com/redis/go-redis/v9 v9.16.0-beta.1
)
require (
@@ -16,6 +16,7 @@ require (
)
retract (
v9.7.2 // This version was accidentally released.
v9.5.3 // This version was accidentally released.
v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead.
v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead.
)


@@ -1,6 +1,9 @@
package redisotel
import (
"strings"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
@@ -21,6 +24,7 @@ type config struct {
dbStmtEnabled bool
callerEnabled bool
filter func(cmd redis.Cmder) bool
// Metrics options.
@@ -124,6 +128,37 @@ func WithCallerEnabled(on bool) TracingOption {
})
}
// WithCommandFilter allows filtering of commands when tracing to omit commands that may have sensitive details like
// passwords.
func WithCommandFilter(filter func(cmd redis.Cmder) bool) TracingOption {
return tracingOption(func(conf *config) {
conf.filter = filter
})
}
func BasicCommandFilter(cmd redis.Cmder) bool {
if strings.ToLower(cmd.Name()) == "auth" {
return true
}
if strings.ToLower(cmd.Name()) == "hello" {
if len(cmd.Args()) < 3 {
return false
}
arg, exists := cmd.Args()[2].(string)
if !exists {
return false
}
if strings.ToLower(arg) == "auth" {
return true
}
}
return false
}
//------------------------------------------------------------------------------
type MetricsOption interface {
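A hedged usage sketch of the new filter option, assuming the package's existing InstrumentTracing entry point:

rdb := redis.NewClient(&redis.Options{Addr: ":6379"})
// Keep AUTH and credential-bearing HELLO commands out of traces.
if err := redisotel.InstrumentTracing(rdb, redisotel.WithCommandFilter(redisotel.BasicCommandFilter)); err != nil {
    panic(err)
}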


@@ -7,8 +7,8 @@ replace github.com/redis/go-redis/v9 => ../..
replace github.com/redis/go-redis/extra/rediscmd/v9 => ../rediscmd
require (
github.com/redis/go-redis/extra/rediscmd/v9 v9.12.1
github.com/redis/go-redis/v9 v9.12.1
github.com/redis/go-redis/extra/rediscmd/v9 v9.16.0-beta.1
github.com/redis/go-redis/v9 v9.16.0-beta.1
go.opentelemetry.io/otel v1.22.0
go.opentelemetry.io/otel/metric v1.22.0
go.opentelemetry.io/otel/sdk v1.22.0
@@ -24,6 +24,7 @@ require (
)
retract (
v9.7.2 // This version was accidentally released.
v9.5.3 // This version was accidentally released.
v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead.
v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead.
)


@@ -220,6 +220,8 @@ func reportPoolStats(rdb *redis.Client, conf *config) (metric.Registration, erro
idleMin,
connsMax,
usage,
waits,
waitsDuration,
timeouts,
hits,
misses,


@@ -102,6 +102,12 @@ func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook {
func (th *tracingHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
// Check if the command should be filtered out
if th.conf.filter != nil && th.conf.filter(cmd) {
// If so, just call the next hook
return hook(ctx, cmd)
}
attrs := make([]attribute.KeyValue, 0, 8)
if th.conf.callerEnabled {
fn, file, line := funcFileLine("github.com/redis/go-redis")


@@ -95,6 +95,138 @@ func TestWithoutCaller(t *testing.T) {
}
}
func TestWithCommandFilter(t *testing.T) {
t.Run("filter out ping command", func(t *testing.T) {
provider := sdktrace.NewTracerProvider()
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithCommandFilter(func(cmd redis.Cmder) bool {
return cmd.Name() == "ping"
}),
)
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
cmd := redis.NewCmd(ctx, "ping")
defer span.End()
processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error {
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
if innerSpan.Name() != "redis-test" || innerSpan.Name() == "ping" {
t.Fatalf("ping command should not be traced")
}
return nil
})
err := processHook(ctx, cmd)
if err != nil {
t.Fatal(err)
}
})
t.Run("do not filter ping command", func(t *testing.T) {
provider := sdktrace.NewTracerProvider()
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithCommandFilter(func(cmd redis.Cmder) bool {
return false // never filter
}),
)
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
cmd := redis.NewCmd(ctx, "ping")
defer span.End()
processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error {
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
if innerSpan.Name() != "ping" {
t.Fatalf("ping command should be traced")
}
return nil
})
err := processHook(ctx, cmd)
if err != nil {
t.Fatal(err)
}
})
t.Run("auth command filtered with basic command filter", func(t *testing.T) {
provider := sdktrace.NewTracerProvider()
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithCommandFilter(BasicCommandFilter),
)
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
cmd := redis.NewCmd(ctx, "auth", "test-password")
defer span.End()
processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error {
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
if innerSpan.Name() != "redis-test" || innerSpan.Name() == "auth" {
t.Fatalf("auth command should not be traced by default")
}
return nil
})
err := processHook(ctx, cmd)
if err != nil {
t.Fatal(err)
}
})
t.Run("hello command filtered with basic command filter when sensitive", func(t *testing.T) {
provider := sdktrace.NewTracerProvider()
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithCommandFilter(BasicCommandFilter),
)
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
cmd := redis.NewCmd(ctx, "hello", 3, "AUTH", "test-user", "test-password")
defer span.End()
processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error {
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
if innerSpan.Name() != "redis-test" || innerSpan.Name() == "hello" {
t.Fatalf("auth command should not be traced by default")
}
return nil
})
err := processHook(ctx, cmd)
if err != nil {
t.Fatal(err)
}
})
t.Run("hello command not filtered with basic command filter when not sensitive", func(t *testing.T) {
provider := sdktrace.NewTracerProvider()
hook := newTracingHook(
"",
WithTracerProvider(provider),
WithCommandFilter(BasicCommandFilter),
)
ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test")
cmd := redis.NewCmd(ctx, "hello", 3)
defer span.End()
processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error {
innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan)
if innerSpan.Name() != "hello" {
t.Fatalf("hello command should be traced")
}
return nil
})
err := processHook(ctx, cmd)
if err != nil {
t.Fatal(err)
}
})
}
func TestTracingHook_DialHook(t *testing.T) {
imsb := tracetest.NewInMemoryExporter()
provider := sdktrace.NewTracerProvider(sdktrace.WithSyncer(imsb))


@@ -6,7 +6,7 @@ replace github.com/redis/go-redis/v9 => ../..
require (
github.com/prometheus/client_golang v1.14.0
github.com/redis/go-redis/v9 v9.12.1
github.com/redis/go-redis/v9 v9.16.0-beta.1
)
require (
@@ -23,6 +23,7 @@ require (
)
retract (
v9.7.2 // This version was accidentally released.
v9.5.3 // This version was accidentally released.
v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead.
v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead.
)

go.mod

@@ -10,6 +10,8 @@ require (
)
retract (
v9.15.1 // This version is used to retract v9.15.0
v9.15.0 // This version was accidentally released. It is identical to 9.15.0-beta.2
v9.7.2 // This version was accidentally released. Please use version 9.7.3 instead.
v9.5.4 // This version was accidentally released. Please use version 9.6.0 instead.
v9.5.3 // This version was accidentally released. Please use version 9.6.0 instead.


@@ -116,16 +116,16 @@ func (c cmdable) HMGet(ctx context.Context, key string, fields ...string) *Slice
// HSet accepts values in following formats:
//
// - HSet("myhash", "key1", "value1", "key2", "value2")
// - HSet(ctx, "myhash", "key1", "value1", "key2", "value2")
//
// - HSet("myhash", []string{"key1", "value1", "key2", "value2"})
// - HSet(ctx, "myhash", []string{"key1", "value1", "key2", "value2"})
//
// - HSet("myhash", map[string]interface{}{"key1": "value1", "key2": "value2"})
// - HSet(ctx, "myhash", map[string]interface{}{"key1": "value1", "key2": "value2"})
//
// Playing struct With "redis" tag.
// type MyHash struct { Key1 string `redis:"key1"`; Key2 int `redis:"key2"` }
//
// - HSet("myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0
// - HSet(ctx, "myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0
//
// For struct, can be a structure pointer type, we only parse the field whose tag is redis.
// if you don't want the field to be read, you can use the `redis:"-"` flag to ignore it,
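A short sketch of the struct form described in the doc comment above (hedged; rdb, ctx, and the field values are illustrative):

type MyHash struct {
    Key1 string `redis:"key1"`
    Key2 int    `redis:"key2"`
}

if err := rdb.HSet(ctx, "myhash", MyHash{Key1: "value1", Key2: 2}).Err(); err != nil {
    // handle error
}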

hset_benchmark_test.go (new file)

@@ -0,0 +1,245 @@
package redis_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
)
// HSET Benchmark Tests
//
// This file contains benchmark tests for Redis HSET operations with different scales:
// 1, 10, 100, 1000, 10000, 100000 operations
//
// Prerequisites:
// - Redis server running on localhost:6379
// - No authentication required
//
// Usage:
// go test -bench=BenchmarkHSET -v ./hset_benchmark_test.go
// go test -bench=BenchmarkHSETPipelined -v ./hset_benchmark_test.go
// go test -bench=. -v ./hset_benchmark_test.go # Run all benchmarks
//
// Example output:
// BenchmarkHSET/HSET_1_operations-8 5000 250000 ns/op 1000000.00 ops/sec
// BenchmarkHSET/HSET_100_operations-8 100 10000000 ns/op 100000.00 ops/sec
//
// The benchmarks test three different approaches:
// 1. Individual HSET commands (BenchmarkHSET)
// 2. Pipelined HSET commands (BenchmarkHSETPipelined)
// BenchmarkHSET benchmarks HSET operations with different scales
func BenchmarkHSET(b *testing.B) {
ctx := context.Background()
// Setup Redis client
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
DB: 0,
})
defer rdb.Close()
// Test connection
if err := rdb.Ping(ctx).Err(); err != nil {
b.Skipf("Redis server not available: %v", err)
}
// Clean up before and after tests
defer func() {
rdb.FlushDB(ctx)
}()
scales := []int{1, 10, 100, 1000, 10000, 100000}
for _, scale := range scales {
b.Run(fmt.Sprintf("HSET_%d_operations", scale), func(b *testing.B) {
benchmarkHSETOperations(b, rdb, ctx, scale)
})
}
}
// benchmarkHSETOperations performs the actual HSET benchmark for a given scale
func benchmarkHSETOperations(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) {
hashKey := fmt.Sprintf("benchmark_hash_%d", operations)
b.ResetTimer()
b.StartTimer()
totalTimes := []time.Duration{}
for i := 0; i < b.N; i++ {
b.StopTimer()
// Clean up the hash before each iteration
rdb.Del(ctx, hashKey)
b.StartTimer()
startTime := time.Now()
// Perform the specified number of HSET operations
for j := 0; j < operations; j++ {
field := fmt.Sprintf("field_%d", j)
value := fmt.Sprintf("value_%d", j)
err := rdb.HSet(ctx, hashKey, field, value).Err()
if err != nil {
b.Fatalf("HSET operation failed: %v", err)
}
}
totalTimes = append(totalTimes, time.Now().Sub(startTime))
}
// Stop the timer to calculate metrics
b.StopTimer()
// Report operations per second
opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds()
b.ReportMetric(opsPerSec, "ops/sec")
// Report average time per operation
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
b.ReportMetric(float64(avgTimePerOp), "ns/op")
// report average time in milliseconds from totalTimes
avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes))
b.ReportMetric(float64(avgTimePerOpMs), "ms")
}
// BenchmarkHSETPipelined benchmarks HSET operations using pipelining for better performance
func BenchmarkHSETPipelined(b *testing.B) {
ctx := context.Background()
// Setup Redis client
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
DB: 0,
})
defer rdb.Close()
// Test connection
if err := rdb.Ping(ctx).Err(); err != nil {
b.Skipf("Redis server not available: %v", err)
}
// Clean up before and after tests
defer func() {
rdb.FlushDB(ctx)
}()
scales := []int{1, 10, 100, 1000, 10000, 100000}
for _, scale := range scales {
b.Run(fmt.Sprintf("HSET_Pipelined_%d_operations", scale), func(b *testing.B) {
benchmarkHSETPipelined(b, rdb, ctx, scale)
})
}
}
// benchmarkHSETPipelined performs HSET benchmark using pipelining
func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) {
hashKey := fmt.Sprintf("benchmark_hash_pipelined_%d", operations)
b.ResetTimer()
b.StartTimer()
totalTimes := []time.Duration{}
for i := 0; i < b.N; i++ {
b.StopTimer()
// Clean up the hash before each iteration
rdb.Del(ctx, hashKey)
b.StartTimer()
startTime := time.Now()
// Use pipelining for better performance
pipe := rdb.Pipeline()
// Add all HSET operations to the pipeline
for j := 0; j < operations; j++ {
field := fmt.Sprintf("field_%d", j)
value := fmt.Sprintf("value_%d", j)
pipe.HSet(ctx, hashKey, field, value)
}
// Execute all operations at once
_, err := pipe.Exec(ctx)
if err != nil {
b.Fatalf("Pipeline execution failed: %v", err)
}
totalTimes = append(totalTimes, time.Now().Sub(startTime))
}
b.StopTimer()
// Report operations per second
opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds()
b.ReportMetric(opsPerSec, "ops/sec")
// Report average time per operation
avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N)
b.ReportMetric(float64(avgTimePerOp), "ns/op")
// report average time in milliseconds from totalTimes
avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes))
b.ReportMetric(float64(avgTimePerOpMs), "ms")
}
// add same tests but with RESP2
func BenchmarkHSET_RESP2(b *testing.B) {
ctx := context.Background()
// Setup Redis client
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "", // no password docs
DB: 0, // use default DB
Protocol: 2,
})
defer rdb.Close()
// Test connection
if err := rdb.Ping(ctx).Err(); err != nil {
b.Skipf("Redis server not available: %v", err)
}
// Clean up before and after tests
defer func() {
rdb.FlushDB(ctx)
}()
scales := []int{1, 10, 100, 1000, 10000, 100000}
for _, scale := range scales {
b.Run(fmt.Sprintf("HSET_RESP2_%d_operations", scale), func(b *testing.B) {
benchmarkHSETOperations(b, rdb, ctx, scale)
})
}
}
func BenchmarkHSETPipelined_RESP2(b *testing.B) {
ctx := context.Background()
// Setup Redis client
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "", // no password docs
DB: 0, // use default DB
Protocol: 2,
})
defer rdb.Close()
// Test connection
if err := rdb.Ping(ctx).Err(); err != nil {
b.Skipf("Redis server not available: %v", err)
}
// Clean up before and after tests
defer func() {
rdb.FlushDB(ctx)
}()
scales := []int{1, 10, 100, 1000, 10000, 100000}
for _, scale := range scales {
b.Run(fmt.Sprintf("HSET_Pipelined_RESP2_%d_operations", scale), func(b *testing.B) {
benchmarkHSETPipelined(b, rdb, ctx, scale)
})
}
}


@@ -0,0 +1,54 @@
// Package interfaces provides shared interfaces used by both the main redis package
// and the maintnotifications upgrade package to avoid circular dependencies.
package interfaces
import (
"context"
"net"
"time"
)
// NotificationProcessor is (most probably) a push.NotificationProcessor
// forward declaration to avoid circular imports
type NotificationProcessor interface {
RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
UnregisterHandler(pushNotificationName string) error
GetHandler(pushNotificationName string) interface{}
}
// ClientInterface defines the interface that clients must implement for maintnotifications upgrades.
type ClientInterface interface {
// GetOptions returns the client options.
GetOptions() OptionsInterface
// GetPushProcessor returns the client's push notification processor.
GetPushProcessor() NotificationProcessor
}
// OptionsInterface defines the interface for client options.
// Uses an adapter pattern to avoid circular dependencies.
type OptionsInterface interface {
// GetReadTimeout returns the read timeout.
GetReadTimeout() time.Duration
// GetWriteTimeout returns the write timeout.
GetWriteTimeout() time.Duration
// GetNetwork returns the network type.
GetNetwork() string
// GetAddr returns the connection address.
GetAddr() string
// IsTLSEnabled returns true if TLS is enabled.
IsTLSEnabled() bool
// GetProtocol returns the protocol version.
GetProtocol() int
// GetPoolSize returns the connection pool size.
GetPoolSize() int
// NewDialer returns a new dialer function for the connection.
NewDialer() func(context.Context) (net.Conn, error)
}
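// Illustrative sketch (not part of this change): a minimal adapter that
// satisfies OptionsInterface. The type name, struct fields, and fixed return
// values are hypothetical and chosen only for the example.
var _ OptionsInterface = (*exampleOptionsAdapter)(nil)

type exampleOptionsAdapter struct {
	addr         string
	readTimeout  time.Duration
	writeTimeout time.Duration
}

func (o *exampleOptionsAdapter) GetReadTimeout() time.Duration  { return o.readTimeout }
func (o *exampleOptionsAdapter) GetWriteTimeout() time.Duration { return o.writeTimeout }
func (o *exampleOptionsAdapter) GetNetwork() string             { return "tcp" }
func (o *exampleOptionsAdapter) GetAddr() string                { return o.addr }
func (o *exampleOptionsAdapter) IsTLSEnabled() bool             { return false }
func (o *exampleOptionsAdapter) GetProtocol() int               { return 3 }
func (o *exampleOptionsAdapter) GetPoolSize() int               { return 10 }
func (o *exampleOptionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
	return func(ctx context.Context) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, o.GetNetwork(), o.GetAddr())
	}
}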

View File

@@ -7,20 +7,73 @@ import (
"os"
)
// TODO (ned): Revisit logging
// Add more standardized approach with log levels and configurability
type Logging interface {
Printf(ctx context.Context, format string, v ...interface{})
}
type logger struct {
type DefaultLogger struct {
log *log.Logger
}
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
func (l *DefaultLogger) Printf(ctx context.Context, format string, v ...interface{}) {
_ = l.log.Output(2, fmt.Sprintf(format, v...))
}
func NewDefaultLogger() Logging {
return &DefaultLogger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
}
}
// Logger calls Output to print to the stderr.
// Arguments are handled in the manner of fmt.Print.
var Logger Logging = &logger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
var Logger Logging = NewDefaultLogger()
var LogLevel LogLevelT = LogLevelError
// LogLevelT represents the logging level
type LogLevelT int
// Log level constants for the entire go-redis library
const (
LogLevelError LogLevelT = iota // 0 - errors only
LogLevelWarn // 1 - warnings and errors
LogLevelInfo // 2 - info, warnings, and errors
LogLevelDebug // 3 - debug, info, warnings, and errors
)
// String returns the string representation of the log level
func (l LogLevelT) String() string {
switch l {
case LogLevelError:
return "ERROR"
case LogLevelWarn:
return "WARN"
case LogLevelInfo:
return "INFO"
case LogLevelDebug:
return "DEBUG"
default:
return "UNKNOWN"
}
}
// IsValid returns true if the log level is valid
func (l LogLevelT) IsValid() bool {
return l >= LogLevelError && l <= LogLevelDebug
}
func (l LogLevelT) WarnOrAbove() bool {
return l >= LogLevelWarn
}
func (l LogLevelT) InfoOrAbove() bool {
return l >= LogLevelInfo
}
func (l LogLevelT) DebugOrAbove() bool {
return l >= LogLevelDebug
}
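// exampleConfigureLogging is an illustrative sketch (not part of this change)
// showing how the pieces above are meant to fit together: install a Logging
// implementation, pick a global level, and gate verbose output on that level.
func exampleConfigureLogging(ctx context.Context) {
	Logger = NewDefaultLogger()
	LogLevel = LogLevelInfo

	if LogLevel.InfoOrAbove() {
		Logger.Printf(ctx, "log level set to %s", LogLevel)
	}
}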

View File

@@ -0,0 +1,625 @@
package logs
import (
"encoding/json"
"fmt"
"regexp"
"github.com/redis/go-redis/v9/internal"
)
// appendJSONIfDebug appends JSON data to a message only if the global log level is Debug
func appendJSONIfDebug(message string, data map[string]interface{}) string {
if internal.LogLevel.DebugOrAbove() {
jsonData, _ := json.Marshal(data)
return fmt.Sprintf("%s %s", message, string(jsonData))
}
return message
}
const (
// ========================================
// CIRCUIT_BREAKER.GO - Circuit breaker management
// ========================================
CircuitBreakerTransitioningToHalfOpenMessage = "circuit breaker transitioning to half-open"
CircuitBreakerOpenedMessage = "circuit breaker opened"
CircuitBreakerReopenedMessage = "circuit breaker reopened"
CircuitBreakerClosedMessage = "circuit breaker closed"
CircuitBreakerCleanupMessage = "circuit breaker cleanup"
CircuitBreakerOpenMessage = "circuit breaker is open, failing fast"
// ========================================
// CONFIG.GO - Configuration and debug
// ========================================
DebugLoggingEnabledMessage = "debug logging enabled"
ConfigDebugMessage = "config debug"
// ========================================
// ERRORS.GO - Error message constants
// ========================================
InvalidRelaxedTimeoutErrorMessage = "relaxed timeout must be greater than 0"
InvalidHandoffTimeoutErrorMessage = "handoff timeout must be greater than 0"
InvalidHandoffWorkersErrorMessage = "MaxWorkers must be greater than or equal to 0"
InvalidHandoffQueueSizeErrorMessage = "handoff queue size must be greater than 0"
InvalidPostHandoffRelaxedDurationErrorMessage = "post-handoff relaxed duration must be greater than or equal to 0"
InvalidEndpointTypeErrorMessage = "invalid endpoint type"
InvalidMaintNotificationsErrorMessage = "invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')"
InvalidHandoffRetriesErrorMessage = "MaxHandoffRetries must be between 1 and 10"
InvalidClientErrorMessage = "invalid client type"
InvalidNotificationErrorMessage = "invalid notification format"
MaxHandoffRetriesReachedErrorMessage = "max handoff retries reached"
HandoffQueueFullErrorMessage = "handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration"
InvalidCircuitBreakerFailureThresholdErrorMessage = "circuit breaker failure threshold must be >= 1"
InvalidCircuitBreakerResetTimeoutErrorMessage = "circuit breaker reset timeout must be >= 0"
InvalidCircuitBreakerMaxRequestsErrorMessage = "circuit breaker max requests must be >= 1"
ConnectionMarkedForHandoffErrorMessage = "connection marked for handoff"
ConnectionInvalidHandoffStateErrorMessage = "connection is in invalid state for handoff"
ShutdownErrorMessage = "shutdown"
CircuitBreakerOpenErrorMessage = "circuit breaker is open, failing fast"
// ========================================
// EXAMPLE_HOOKS.GO - Example metrics hooks
// ========================================
MetricsHookProcessingNotificationMessage = "metrics hook processing"
MetricsHookRecordedErrorMessage = "metrics hook recorded error"
// ========================================
// HANDOFF_WORKER.GO - Connection handoff processing
// ========================================
HandoffStartedMessage = "handoff started"
HandoffFailedMessage = "handoff failed"
ConnectionNotMarkedForHandoffMessage = "is not marked for handoff and has no retries"
ConnectionNotMarkedForHandoffErrorMessage = "is not marked for handoff"
HandoffRetryAttemptMessage = "Performing handoff"
CannotQueueHandoffForRetryMessage = "can't queue handoff for retry"
HandoffQueueFullMessage = "handoff queue is full"
FailedToDialNewEndpointMessage = "failed to dial new endpoint"
ApplyingRelaxedTimeoutDueToPostHandoffMessage = "applying relaxed timeout due to post-handoff"
HandoffSuccessMessage = "handoff succeeded"
RemovingConnectionFromPoolMessage = "removing connection from pool"
NoPoolProvidedMessageCannotRemoveMessage = "no pool provided, cannot remove connection, closing it"
WorkerExitingDueToShutdownMessage = "worker exiting due to shutdown"
WorkerExitingDueToShutdownWhileProcessingMessage = "worker exiting due to shutdown while processing request"
WorkerPanicRecoveredMessage = "worker panic recovered"
WorkerExitingDueToInactivityTimeoutMessage = "worker exiting due to inactivity timeout"
ReachedMaxHandoffRetriesMessage = "reached max handoff retries"
// ========================================
// MANAGER.GO - Moving operation tracking and handler registration
// ========================================
DuplicateMovingOperationMessage = "duplicate MOVING operation ignored"
TrackingMovingOperationMessage = "tracking MOVING operation"
UntrackingMovingOperationMessage = "untracking MOVING operation"
OperationNotTrackedMessage = "operation not tracked"
FailedToRegisterHandlerMessage = "failed to register handler"
// ========================================
// HOOKS.GO - Notification processing hooks
// ========================================
ProcessingNotificationMessage = "processing notification started"
ProcessingNotificationFailedMessage = "processing notification failed"
ProcessingNotificationSucceededMessage = "processing notification succeeded"
// ========================================
// POOL_HOOK.GO - Pool connection management
// ========================================
FailedToQueueHandoffMessage = "failed to queue handoff"
MarkedForHandoffMessage = "connection marked for handoff"
// ========================================
// PUSH_NOTIFICATION_HANDLER.GO - Push notification validation and processing
// ========================================
InvalidNotificationFormatMessage = "invalid notification format"
InvalidNotificationTypeFormatMessage = "invalid notification type format"
InvalidSeqIDInMovingNotificationMessage = "invalid seqID in MOVING notification"
InvalidTimeSInMovingNotificationMessage = "invalid timeS in MOVING notification"
InvalidNewEndpointInMovingNotificationMessage = "invalid newEndpoint in MOVING notification"
NoConnectionInHandlerContextMessage = "no connection in handler context"
InvalidConnectionTypeInHandlerContextMessage = "invalid connection type in handler context"
SchedulingHandoffToCurrentEndpointMessage = "scheduling handoff to current endpoint"
RelaxedTimeoutDueToNotificationMessage = "applying relaxed timeout due to notification"
UnrelaxedTimeoutMessage = "clearing relaxed timeout"
ManagerNotInitializedMessage = "manager not initialized"
FailedToMarkForHandoffMessage = "failed to mark connection for handoff"
// ========================================
// used in pool/conn
// ========================================
UnrelaxedTimeoutAfterDeadlineMessage = "clearing relaxed timeout after deadline"
)
func HandoffStarted(connID uint64, newEndpoint string) string {
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffStartedMessage, newEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": newEndpoint,
})
}
func HandoffFailed(connID uint64, newEndpoint string, attempt int, maxAttempts int, err error) string {
message := fmt.Sprintf("conn[%d] %s to %s (attempt %d/%d): %v", connID, HandoffFailedMessage, newEndpoint, attempt, maxAttempts, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": newEndpoint,
"attempt": attempt,
"maxAttempts": maxAttempts,
"error": err.Error(),
})
}
func HandoffSucceeded(connID uint64, newEndpoint string) string {
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffSuccessMessage, newEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": newEndpoint,
})
}
// Timeout-related log functions
func RelaxedTimeoutDueToNotification(connID uint64, notificationType string, timeout interface{}) string {
message := fmt.Sprintf("conn[%d] %s %s (%v)", connID, RelaxedTimeoutDueToNotificationMessage, notificationType, timeout)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"notificationType": notificationType,
"timeout": fmt.Sprintf("%v", timeout),
})
}
func UnrelaxedTimeout(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
func UnrelaxedTimeoutAfterDeadline(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutAfterDeadlineMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
// Handoff queue and marking functions
func HandoffQueueFull(queueLen, queueCap int) string {
message := fmt.Sprintf("%s (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration", HandoffQueueFullMessage, queueLen, queueCap)
return appendJSONIfDebug(message, map[string]interface{}{
"queueLen": queueLen,
"queueCap": queueCap,
})
}
func FailedToQueueHandoff(connID uint64, err error) string {
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToQueueHandoffMessage, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"error": err.Error(),
})
}
func FailedToMarkForHandoff(connID uint64, err error) string {
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToMarkForHandoffMessage, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"error": err.Error(),
})
}
func FailedToDialNewEndpoint(connID uint64, endpoint string, err error) string {
message := fmt.Sprintf("conn[%d] %s %s: %v", connID, FailedToDialNewEndpointMessage, endpoint, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"error": err.Error(),
})
}
func ReachedMaxHandoffRetries(connID uint64, endpoint string, maxRetries int) string {
message := fmt.Sprintf("conn[%d] %s to %s (max retries: %d)", connID, ReachedMaxHandoffRetriesMessage, endpoint, maxRetries)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"maxRetries": maxRetries,
})
}
// Notification processing functions
func ProcessingNotification(connID uint64, seqID int64, notificationType string, notification interface{}) string {
message := fmt.Sprintf("conn[%d] seqID[%d] %s %s: %v", connID, seqID, ProcessingNotificationMessage, notificationType, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seqID": seqID,
"notificationType": notificationType,
"notification": fmt.Sprintf("%v", notification),
})
}
func ProcessingNotificationFailed(connID uint64, notificationType string, err error, notification interface{}) string {
message := fmt.Sprintf("conn[%d] %s %s: %v - %v", connID, ProcessingNotificationFailedMessage, notificationType, err, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"notificationType": notificationType,
"error": err.Error(),
"notification": fmt.Sprintf("%v", notification),
})
}
func ProcessingNotificationSucceeded(connID uint64, notificationType string) string {
message := fmt.Sprintf("conn[%d] %s %s", connID, ProcessingNotificationSucceededMessage, notificationType)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"notificationType": notificationType,
})
}
// Moving operation tracking functions
func DuplicateMovingOperation(connID uint64, endpoint string, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, DuplicateMovingOperationMessage, endpoint, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"seqID": seqID,
})
}
func TrackingMovingOperation(connID uint64, endpoint string, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, TrackingMovingOperationMessage, endpoint, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"seqID": seqID,
})
}
func UntrackingMovingOperation(connID uint64, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, UntrackingMovingOperationMessage, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seqID": seqID,
})
}
func OperationNotTracked(connID uint64, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, OperationNotTrackedMessage, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seqID": seqID,
})
}
// Connection pool functions
func RemovingConnectionFromPool(connID uint64, reason error) string {
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, RemovingConnectionFromPoolMessage, reason)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"reason": reason.Error(),
})
}
func NoPoolProvidedCannotRemove(connID uint64, reason error) string {
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, NoPoolProvidedMessageCannotRemoveMessage, reason)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"reason": reason.Error(),
})
}
// Circuit breaker functions
func CircuitBreakerOpen(connID uint64, endpoint string) string {
message := fmt.Sprintf("conn[%d] %s for %s", connID, CircuitBreakerOpenMessage, endpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
})
}
// Additional handoff functions for specific cases
func ConnectionNotMarkedForHandoff(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
func ConnectionNotMarkedForHandoffError(connID uint64) string {
return fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffErrorMessage)
}
func HandoffRetryAttempt(connID uint64, retries int, newEndpoint string, oldEndpoint string) string {
message := fmt.Sprintf("conn[%d] Retry %d: %s to %s(was %s)", connID, retries, HandoffRetryAttemptMessage, newEndpoint, oldEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"retries": retries,
"newEndpoint": newEndpoint,
"oldEndpoint": oldEndpoint,
})
}
func CannotQueueHandoffForRetry(err error) string {
message := fmt.Sprintf("%s: %v", CannotQueueHandoffForRetryMessage, err)
return appendJSONIfDebug(message, map[string]interface{}{
"error": err.Error(),
})
}
// Validation and error functions
func InvalidNotificationFormat(notification interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidNotificationFormatMessage, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"notification": fmt.Sprintf("%v", notification),
})
}
func InvalidNotificationTypeFormat(notificationType interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidNotificationTypeFormatMessage, notificationType)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": fmt.Sprintf("%v", notificationType),
})
}
// InvalidNotification creates a log message for invalid notifications of any type
func InvalidNotification(notificationType string, notification interface{}) string {
message := fmt.Sprintf("invalid %s notification: %v", notificationType, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"notification": fmt.Sprintf("%v", notification),
})
}
func InvalidSeqIDInMovingNotification(seqID interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidSeqIDInMovingNotificationMessage, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"seqID": fmt.Sprintf("%v", seqID),
})
}
func InvalidTimeSInMovingNotification(timeS interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidTimeSInMovingNotificationMessage, timeS)
return appendJSONIfDebug(message, map[string]interface{}{
"timeS": fmt.Sprintf("%v", timeS),
})
}
func InvalidNewEndpointInMovingNotification(newEndpoint interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidNewEndpointInMovingNotificationMessage, newEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"newEndpoint": fmt.Sprintf("%v", newEndpoint),
})
}
func NoConnectionInHandlerContext(notificationType string) string {
message := fmt.Sprintf("%s for %s notification", NoConnectionInHandlerContextMessage, notificationType)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
})
}
func InvalidConnectionTypeInHandlerContext(notificationType string, conn interface{}, handlerCtx interface{}) string {
message := fmt.Sprintf("%s for %s notification - %T %#v", InvalidConnectionTypeInHandlerContextMessage, notificationType, conn, handlerCtx)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"connType": fmt.Sprintf("%T", conn),
})
}
func SchedulingHandoffToCurrentEndpoint(connID uint64, seconds float64) string {
message := fmt.Sprintf("conn[%d] %s in %v seconds", connID, SchedulingHandoffToCurrentEndpointMessage, seconds)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seconds": seconds,
})
}
func ManagerNotInitialized() string {
return appendJSONIfDebug(ManagerNotInitializedMessage, map[string]interface{}{})
}
func FailedToRegisterHandler(notificationType string, err error) string {
message := fmt.Sprintf("%s for %s: %v", FailedToRegisterHandlerMessage, notificationType, err)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"error": err.Error(),
})
}
func ShutdownError() string {
return appendJSONIfDebug(ShutdownErrorMessage, map[string]interface{}{})
}
// Configuration validation error functions
func InvalidRelaxedTimeoutError() string {
return appendJSONIfDebug(InvalidRelaxedTimeoutErrorMessage, map[string]interface{}{})
}
func InvalidHandoffTimeoutError() string {
return appendJSONIfDebug(InvalidHandoffTimeoutErrorMessage, map[string]interface{}{})
}
func InvalidHandoffWorkersError() string {
return appendJSONIfDebug(InvalidHandoffWorkersErrorMessage, map[string]interface{}{})
}
func InvalidHandoffQueueSizeError() string {
return appendJSONIfDebug(InvalidHandoffQueueSizeErrorMessage, map[string]interface{}{})
}
func InvalidPostHandoffRelaxedDurationError() string {
return appendJSONIfDebug(InvalidPostHandoffRelaxedDurationErrorMessage, map[string]interface{}{})
}
func InvalidEndpointTypeError() string {
return appendJSONIfDebug(InvalidEndpointTypeErrorMessage, map[string]interface{}{})
}
func InvalidMaintNotificationsError() string {
return appendJSONIfDebug(InvalidMaintNotificationsErrorMessage, map[string]interface{}{})
}
func InvalidHandoffRetriesError() string {
return appendJSONIfDebug(InvalidHandoffRetriesErrorMessage, map[string]interface{}{})
}
func InvalidClientError() string {
return appendJSONIfDebug(InvalidClientErrorMessage, map[string]interface{}{})
}
func InvalidNotificationError() string {
return appendJSONIfDebug(InvalidNotificationErrorMessage, map[string]interface{}{})
}
func MaxHandoffRetriesReachedError() string {
return appendJSONIfDebug(MaxHandoffRetriesReachedErrorMessage, map[string]interface{}{})
}
func HandoffQueueFullError() string {
return appendJSONIfDebug(HandoffQueueFullErrorMessage, map[string]interface{}{})
}
func InvalidCircuitBreakerFailureThresholdError() string {
return appendJSONIfDebug(InvalidCircuitBreakerFailureThresholdErrorMessage, map[string]interface{}{})
}
func InvalidCircuitBreakerResetTimeoutError() string {
return appendJSONIfDebug(InvalidCircuitBreakerResetTimeoutErrorMessage, map[string]interface{}{})
}
func InvalidCircuitBreakerMaxRequestsError() string {
return appendJSONIfDebug(InvalidCircuitBreakerMaxRequestsErrorMessage, map[string]interface{}{})
}
// Configuration and debug functions
func DebugLoggingEnabled() string {
return appendJSONIfDebug(DebugLoggingEnabledMessage, map[string]interface{}{})
}
func ConfigDebug(config interface{}) string {
message := fmt.Sprintf("%s: %+v", ConfigDebugMessage, config)
return appendJSONIfDebug(message, map[string]interface{}{
"config": fmt.Sprintf("%+v", config),
})
}
// Handoff worker functions
func WorkerExitingDueToShutdown() string {
return appendJSONIfDebug(WorkerExitingDueToShutdownMessage, map[string]interface{}{})
}
func WorkerExitingDueToShutdownWhileProcessing() string {
return appendJSONIfDebug(WorkerExitingDueToShutdownWhileProcessingMessage, map[string]interface{}{})
}
func WorkerPanicRecovered(panicValue interface{}) string {
message := fmt.Sprintf("%s: %v", WorkerPanicRecoveredMessage, panicValue)
return appendJSONIfDebug(message, map[string]interface{}{
"panic": fmt.Sprintf("%v", panicValue),
})
}
func WorkerExitingDueToInactivityTimeout(timeout interface{}) string {
message := fmt.Sprintf("%s (%v)", WorkerExitingDueToInactivityTimeoutMessage, timeout)
return appendJSONIfDebug(message, map[string]interface{}{
"timeout": fmt.Sprintf("%v", timeout),
})
}
func ApplyingRelaxedTimeoutDueToPostHandoff(connID uint64, timeout interface{}, until string) string {
message := fmt.Sprintf("conn[%d] %s (%v) until %s", connID, ApplyingRelaxedTimeoutDueToPostHandoffMessage, timeout, until)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"timeout": fmt.Sprintf("%v", timeout),
"until": until,
})
}
// Example hooks functions
func MetricsHookProcessingNotification(notificationType string, connID uint64) string {
message := fmt.Sprintf("%s %s notification on conn[%d]", MetricsHookProcessingNotificationMessage, notificationType, connID)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"connID": connID,
})
}
func MetricsHookRecordedError(notificationType string, connID uint64, err error) string {
message := fmt.Sprintf("%s for %s notification on conn[%d]: %v", MetricsHookRecordedErrorMessage, notificationType, connID, err)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"connID": connID,
"error": err.Error(),
})
}
// Pool hook functions
func MarkedForHandoff(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, MarkedForHandoffMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
// Circuit breaker additional functions
func CircuitBreakerTransitioningToHalfOpen(endpoint string) string {
message := fmt.Sprintf("%s for %s", CircuitBreakerTransitioningToHalfOpenMessage, endpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
})
}
func CircuitBreakerOpened(endpoint string, failures int64) string {
message := fmt.Sprintf("%s for endpoint %s after %d failures", CircuitBreakerOpenedMessage, endpoint, failures)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
"failures": failures,
})
}
func CircuitBreakerReopened(endpoint string) string {
message := fmt.Sprintf("%s for endpoint %s due to failure in half-open state", CircuitBreakerReopenedMessage, endpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
})
}
func CircuitBreakerClosed(endpoint string, successes int64) string {
message := fmt.Sprintf("%s for endpoint %s after %d successful requests", CircuitBreakerClosedMessage, endpoint, successes)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
"successes": successes,
})
}
func CircuitBreakerCleanup(removed int, total int) string {
message := fmt.Sprintf("%s removed %d/%d entries", CircuitBreakerCleanupMessage, removed, total)
return appendJSONIfDebug(message, map[string]interface{}{
"removed": removed,
"total": total,
})
}
// ExtractDataFromLogMessage extracts structured data from maintnotifications log messages
// Returns a map containing the parsed key-value pairs from the structured data section
// Example: "conn[123] handoff started to localhost:6379 {"connID":123,"endpoint":"localhost:6379"}"
// Returns: map[string]interface{}{"connID": 123, "endpoint": "localhost:6379"}
func ExtractDataFromLogMessage(logMessage string) map[string]interface{} {
result := make(map[string]interface{})
// Find the JSON data section at the end of the message
re := regexp.MustCompile(`(\{.*\})$`)
matches := re.FindStringSubmatch(logMessage)
if len(matches) < 2 {
return result
}
jsonStr := matches[1]
if jsonStr == "" {
return result
}
// Parse the JSON directly
var jsonResult map[string]interface{}
if err := json.Unmarshal([]byte(jsonStr), &jsonResult); err == nil {
return jsonResult
}
// If JSON parsing fails, return empty map
return result
}
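// exampleExtractHandoffData is an illustrative sketch (not part of this change)
// showing the round trip: with the global log level at Debug, the helpers above
// append a JSON tail that ExtractDataFromLogMessage can recover. Note that
// encoding/json decodes numbers into float64, hence the cast back to uint64.
func exampleExtractHandoffData() (uint64, string) {
	msg := HandoffStarted(42, "localhost:6380")
	data := ExtractDataFromLogMessage(msg)

	connID, _ := data["connID"].(float64)
	endpoint, _ := data["endpoint"].(string)
	return uint64(connID), endpoint
}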

View File

@@ -2,6 +2,7 @@ package pool_test
import (
"context"
"errors"
"fmt"
"testing"
"time"
@@ -31,7 +32,7 @@ func BenchmarkPoolGetPut(b *testing.B) {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
@@ -75,7 +76,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
@@ -89,7 +90,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {
if err != nil {
b.Fatal(err)
}
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("Bench test remove"))
}
})
})

View File

@@ -26,7 +26,7 @@ var _ = Describe("Buffer Size Configuration", func() {
It("should use default buffer sizes when not specified", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
})
@@ -48,7 +48,7 @@ var _ = Describe("Buffer Size Configuration", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: customReadSize,
WriteBufferSize: customWriteSize,
@@ -69,7 +69,7 @@ var _ = Describe("Buffer Size Configuration", func() {
It("should handle zero buffer sizes by using defaults", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: 0, // Should use default
WriteBufferSize: 0, // Should use default
@@ -105,7 +105,7 @@ var _ = Describe("Buffer Size Configuration", func() {
// without setting ReadBufferSize and WriteBufferSize
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 1000,
// ReadBufferSize and WriteBufferSize are not set (will be 0)
})

View File

@@ -3,26 +3,88 @@ package pool
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/proto"
)
var noDeadline = time.Time{}
// Global atomic counter for connection IDs
var connIDCounter uint64
// HandoffState represents the atomic state for connection handoffs
// This struct is stored atomically to prevent race conditions between
// checking handoff status and reading handoff parameters
type HandoffState struct {
ShouldHandoff bool // Whether connection should be handed off
Endpoint string // New endpoint for handoff
SeqID int64 // Sequence ID from MOVING notification
}
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
type atomicNetConn struct {
conn net.Conn
}
// generateConnID generates a fast unique identifier for a connection with zero allocations
func generateConnID() uint64 {
return atomic.AddUint64(&connIDCounter, 1)
}
type Conn struct {
usedAt int64 // atomic
netConn net.Conn
usedAt int64 // atomic
// Lock-free netConn access using atomic.Value
// Contains *atomicNetConn wrapper, accessed atomically for better performance
netConnAtomic atomic.Value // stores *atomicNetConn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
Inited bool
// Lightweight mutex to protect reader operations during handoff
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
readerMu sync.RWMutex
Inited atomic.Bool
pooled bool
pubsub bool
closed atomic.Bool
createdAt time.Time
expiresAt time.Time
// maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers
// Using atomic operations for lock-free access to avoid mutex contention
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
// Counter to track nested relaxed timeout setters.
// It is decremented when ClearRelaxedTimeout is called or the deadline is reached;
// when the counter reaches 0, the relaxed timeouts are cleared.
relaxedCounter atomic.Int32
// Connection initialization function for reconnections
initConnFunc func(context.Context, *Conn) error
// Connection identifier for unique tracking
id uint64 // Unique numeric identifier for this connection
// Handoff state - using atomic operations for lock-free access
usableAtomic atomic.Bool // Connection usability state
handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts
// Atomic handoff state to prevent race conditions
// Stores *HandoffState to ensure atomic updates of all handoff-related fields
handoffStateAtomic atomic.Value // stores *HandoffState
onClose func() error
}
@@ -33,8 +95,8 @@ func NewConn(netConn net.Conn) *Conn {
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
cn := &Conn{
netConn: netConn,
createdAt: time.Now(),
id: generateConnID(), // Generate unique ID for this connection
}
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
@@ -50,6 +112,21 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
}
// Store netConn atomically for lock-free access using wrapper
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
// Initialize atomic state
cn.usableAtomic.Store(false) // false initially, set to true after initialization
cn.handoffRetriesAtomic.Store(0) // 0 initially
// Initialize handoff state atomically
initialHandoffState := &HandoffState{
ShouldHandoff: false,
Endpoint: "",
SeqID: 0,
}
cn.handoffStateAtomic.Store(initialHandoffState)
cn.wr = proto.NewWriter(cn.bw)
cn.SetUsedAt(time.Now())
return cn
@@ -64,23 +141,439 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix())
}
// getNetConn returns the current network connection using atomic load (lock-free).
// This is the fast path for accessing netConn without mutex overhead.
func (cn *Conn) getNetConn() net.Conn {
if v := cn.netConnAtomic.Load(); v != nil {
if wrapper, ok := v.(*atomicNetConn); ok {
return wrapper.conn
}
}
return nil
}
// setNetConn stores the network connection atomically (lock-free).
// This is used for the fast path of connection replacement.
func (cn *Conn) setNetConn(netConn net.Conn) {
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
}
// Lock-free helper methods for handoff state management
// isUsable returns true if the connection is safe to use (lock-free).
func (cn *Conn) isUsable() bool {
return cn.usableAtomic.Load()
}
// setUsable sets the usable flag atomically (lock-free).
func (cn *Conn) setUsable(usable bool) {
cn.usableAtomic.Store(usable)
}
// getHandoffState returns the current handoff state atomically (lock-free).
func (cn *Conn) getHandoffState() *HandoffState {
state := cn.handoffStateAtomic.Load()
if state == nil {
// Return default state if not initialized
return &HandoffState{
ShouldHandoff: false,
Endpoint: "",
SeqID: 0,
}
}
return state.(*HandoffState)
}
// setHandoffState sets the handoff state atomically (lock-free).
func (cn *Conn) setHandoffState(state *HandoffState) {
cn.handoffStateAtomic.Store(state)
}
// shouldHandoff returns true if connection needs handoff (lock-free).
func (cn *Conn) shouldHandoff() bool {
return cn.getHandoffState().ShouldHandoff
}
// getMovingSeqID returns the sequence ID atomically (lock-free).
func (cn *Conn) getMovingSeqID() int64 {
return cn.getHandoffState().SeqID
}
// getNewEndpoint returns the new endpoint atomically (lock-free).
func (cn *Conn) getNewEndpoint() string {
return cn.getHandoffState().Endpoint
}
// setHandoffRetries sets the retry count atomically (lock-free).
func (cn *Conn) setHandoffRetries(retries int) {
cn.handoffRetriesAtomic.Store(uint32(retries))
}
// incrementHandoffRetries atomically increments and returns the new retry count (lock-free).
func (cn *Conn) incrementHandoffRetries(delta int) int {
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
}
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
func (cn *Conn) IsUsable() bool {
return cn.isUsable()
}
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
func (cn *Conn) IsPooled() bool {
return cn.pooled
}
// IsPubSub returns true if the connection is used for PubSub.
func (cn *Conn) IsPubSub() bool {
return cn.pubsub
}
func (cn *Conn) IsInited() bool {
return cn.Inited.Load()
}
// SetUsable sets the usable flag for the connection (lock-free).
func (cn *Conn) SetUsable(usable bool) {
cn.setUsable(usable)
}
// SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades.
// These timeouts will be used for all subsequent commands until the deadline expires.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
cn.relaxedCounter.Add(1)
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
}
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
cn.SetRelaxedTimeout(readTimeout, writeTimeout)
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
}
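// Illustrative usage (not part of this change): a maintenance-notification
// handler might relax timeouts on the affected connection for a bounded
// window, for example:
//
//	cn.SetRelaxedTimeoutWithDeadline(10*time.Second, 10*time.Second, time.Now().Add(30*time.Second))
//
// and rely on the deadline (or a later ClearRelaxedTimeout call) to restore
// the normal timeouts.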
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
// Uses atomic operations for lock-free access.
func (cn *Conn) ClearRelaxedTimeout() {
// Atomically decrement counter and check if we should clear
newCount := cn.relaxedCounter.Add(-1)
deadlineNs := cn.relaxedDeadlineNs.Load()
if newCount <= 0 && (deadlineNs == 0 || time.Now().UnixNano() >= deadlineNs) {
// Use atomic load to get current value for CAS to avoid stale value race
current := cn.relaxedCounter.Load()
if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) {
cn.clearRelaxedTimeout()
}
}
}
func (cn *Conn) clearRelaxedTimeout() {
cn.relaxedReadTimeoutNs.Store(0)
cn.relaxedWriteTimeoutNs.Store(0)
cn.relaxedDeadlineNs.Store(0)
cn.relaxedCounter.Store(0)
}
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
// This checks both the timeout values and the deadline (if set).
// Uses atomic operations for lock-free access.
func (cn *Conn) HasRelaxedTimeout() bool {
// Fast path: no relaxed timeouts are set
if cn.relaxedCounter.Load() <= 0 {
return false
}
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// If no relaxed timeouts are set, return false
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
return false
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, relaxed timeouts are active
if deadlineNs == 0 {
return true
}
// If deadline is set, check if it's still in the future
return time.Now().UnixNano() < deadlineNs
}
// getEffectiveReadTimeout returns the timeout to use for read operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
// Fast path: no relaxed timeout set
if readTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(readTimeoutNs)
}
nowNs := time.Now().UnixNano()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(readTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
newCount := cn.relaxedCounter.Add(-1)
if newCount <= 0 {
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
// getEffectiveWriteTimeout returns the timeout to use for write operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// Fast path: no relaxed timeout set
if writeTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(writeTimeoutNs)
}
nowNs := time.Now().UnixNano()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(writeTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
newCount := cn.relaxedCounter.Add(-1)
if newCount <= 0 {
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
func (cn *Conn) SetOnClose(fn func() error) {
cn.onClose = fn
}
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
cn.initConnFunc = fn
}
// ExecuteInitConn runs the stored connection initialization function if available.
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
if cn.initConnFunc != nil {
return cn.initConnFunc(ctx, cn)
}
return fmt.Errorf("redis: no initConnFunc set for conn[%d]", cn.GetID())
}
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
// Store the new connection atomically first (lock-free)
cn.setNetConn(netConn)
// Protect reader reset operations to avoid data races
// Use write lock since we're modifying the reader state
cn.readerMu.Lock()
cn.rd.Reset(netConn)
cn.readerMu.Unlock()
cn.bw.Reset(netConn)
}
// GetNetConn safely returns the current network connection using atomic load (lock-free).
// This method is used by the pool for health checks and provides better performance.
func (cn *Conn) GetNetConn() net.Conn {
return cn.getNetConn()
}
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
// New connection is not initialized yet
cn.Inited.Store(false)
// Replace the underlying connection
cn.SetNetConn(netConn)
return cn.ExecuteInitConn(ctx)
}
// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free).
// Returns an error if the connection is already marked for handoff.
// This method uses atomic compare-and-swap to ensure all handoff state is updated atomically.
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
const maxRetries = 50
const baseDelay = time.Microsecond
for attempt := 0; attempt < maxRetries; attempt++ {
currentState := cn.getHandoffState()
// Check if already marked for handoff
if currentState.ShouldHandoff {
return errors.New("connection is already marked for handoff")
}
// Create new state with handoff enabled
newState := &HandoffState{
ShouldHandoff: true,
Endpoint: newEndpoint,
SeqID: seqID,
}
// Atomic compare-and-swap to update entire state
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
return nil
}
// If CAS failed, add exponential backoff to reduce contention
if attempt < maxRetries-1 {
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
}
}
return fmt.Errorf("failed to mark connection for handoff after %d attempts due to high contention", maxRetries)
}
func (cn *Conn) MarkQueuedForHandoff() error {
const maxRetries = 50
const baseDelay = time.Microsecond
for attempt := 0; attempt < maxRetries; attempt++ {
currentState := cn.getHandoffState()
// Check if marked for handoff
if !currentState.ShouldHandoff {
return errors.New("connection was not marked for handoff")
}
// Create new state with handoff disabled (queued)
newState := &HandoffState{
ShouldHandoff: false,
Endpoint: currentState.Endpoint, // Preserve endpoint for handoff processing
SeqID: currentState.SeqID, // Preserve seqID for handoff processing
}
// Atomic compare-and-swap to update state
if cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
cn.setUsable(false)
return nil
}
// If CAS failed, add exponential backoff to reduce contention
// the delay will be 1, 2, 4... up to 512 microseconds
if attempt < maxRetries-1 {
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
time.Sleep(delay)
}
}
return fmt.Errorf("failed to mark connection as queued for handoff after %d attempts due to high contention", maxRetries)
}
// ShouldHandoff returns true if the connection needs to be handed off (lock-free).
func (cn *Conn) ShouldHandoff() bool {
return cn.shouldHandoff()
}
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
func (cn *Conn) GetHandoffEndpoint() string {
return cn.getNewEndpoint()
}
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
func (cn *Conn) GetMovingSeqID() int64 {
return cn.getMovingSeqID()
}
// GetHandoffInfo returns all handoff information atomically (lock-free).
// This method prevents race conditions by returning all handoff state in a single atomic operation.
// Returns (shouldHandoff, endpoint, seqID).
func (cn *Conn) GetHandoffInfo() (bool, string, int64) {
state := cn.getHandoffState()
return state.ShouldHandoff, state.Endpoint, state.SeqID
}
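// Illustrative usage (not part of this change): callers can read the full
// handoff state in a single shot and avoid a check-then-read race, e.g.
//
//	if shouldHandoff, endpoint, seqID := cn.GetHandoffInfo(); shouldHandoff {
//		queueHandoff(cn, endpoint, seqID) // queueHandoff is a hypothetical helper
//	}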
// GetID returns the unique identifier for this connection.
func (cn *Conn) GetID() uint64 {
return cn.id
}
// ClearHandoffState clears the handoff state after successful handoff (lock-free).
func (cn *Conn) ClearHandoffState() {
// Create clean state
cleanState := &HandoffState{
ShouldHandoff: false,
Endpoint: "",
SeqID: 0,
}
// Atomically set clean state
cn.setHandoffState(cleanState)
cn.setHandoffRetries(0)
cn.setUsable(true) // Connection is safe to use again after handoff completes
}
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
return cn.incrementHandoffRetries(n)
}
// HandoffRetries returns the current handoff retry count (lock-free).
func (cn *Conn) HandoffRetries() int {
return int(cn.handoffRetriesAtomic.Load())
}
// HasBufferedData safely checks if the connection has buffered data.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) HasBufferedData() bool {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
return cn.rd.Buffered() > 0
}
// PeekReplyTypeSafe safely peeks at the reply type.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
if cn.rd.Buffered() <= 0 {
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
}
return cn.rd.PeekReplyType()
}
func (cn *Conn) Write(b []byte) (int, error) {
return cn.netConn.Write(b)
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Write(b)
}
return 0, net.ErrClosed
}
func (cn *Conn) RemoteAddr() net.Addr {
if cn.netConn != nil {
return cn.netConn.RemoteAddr()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.RemoteAddr()
}
return nil
}
@@ -89,7 +582,16 @@ func (cn *Conn) WithReader(
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
// Get the connection directly from atomic storage
netConn := cn.getNetConn()
if netConn == nil {
return fmt.Errorf("redis: connection not available")
}
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
}
@@ -100,13 +602,26 @@ func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
return err
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
// Set the write deadline when a connection is available; if getNetConn()
// returns nil, fail fast instead of letting the write hang indefinitely
if netConn := cn.getNetConn(); netConn != nil {
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
} else {
// If getNetConn() returns nil, we still need to respect the timeout
// Return an error to prevent indefinite blocking
return fmt.Errorf("redis: conn[%d] not available for write operation", cn.GetID())
}
}
if cn.bw.Buffered() > 0 {
cn.bw.Reset(cn.netConn)
if netConn := cn.getNetConn(); netConn != nil {
cn.bw.Reset(netConn)
}
}
if err := fn(cn.wr); err != nil {
@@ -116,12 +631,33 @@ func (cn *Conn) WithWriter(
return cn.bw.Flush()
}
func (cn *Conn) IsClosed() bool {
return cn.closed.Load()
}
func (cn *Conn) Close() error {
cn.closed.Store(true)
if cn.onClose != nil {
// ignore error
_ = cn.onClose()
}
return cn.netConn.Close()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Close()
}
return nil
}
// MaybeHasData tries to peek at the next byte in the socket without consuming it
// This is used to check if there are push notifications available
// Important: This will work on Linux, but not on Windows
func (cn *Conn) MaybeHasData() bool {
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return maybeHasData(netConn)
}
return false
}
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {

View File

@@ -12,6 +12,9 @@ import (
var errUnexpectedRead = errors.New("unexpected read from socket")
// connCheck checks if the connection is still alive and whether there is data in the socket.
// It peeks at the next byte without consuming it, since we may want to work with the data
// later on (e.g. push notifications).
func connCheck(conn net.Conn) error {
// Reset previous timeout.
_ = conn.SetDeadline(time.Time{})
@@ -29,7 +32,9 @@ func connCheck(conn net.Conn) error {
if err := rawConn.Read(func(fd uintptr) bool {
var buf [1]byte
n, err := syscall.Read(int(fd), buf[:])
// Use MSG_PEEK to peek at data without consuming it
n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK|syscall.MSG_DONTWAIT)
switch {
case n == 0 && err == nil:
sysErr = io.EOF
@@ -47,3 +52,8 @@ func connCheck(conn net.Conn) error {
return sysErr
}
// maybeHasData checks if there is data in the socket without consuming it
func maybeHasData(conn net.Conn) bool {
return connCheck(conn) == errUnexpectedRead
}

View File

@@ -2,8 +2,19 @@
package pool
import "net"
import (
"errors"
"net"
)
func connCheck(conn net.Conn) error {
// errUnexpectedRead is a placeholder error variable for non-unix build constraints
var errUnexpectedRead = errors.New("unexpected read from socket")
func connCheck(_ net.Conn) error {
return nil
}
// since we can't check for data on the socket, we just assume there is some
func maybeHasData(_ net.Conn) bool {
return true
}

View File

@@ -0,0 +1,92 @@
package pool
import (
"net"
"sync"
"testing"
"time"
)
// TestConcurrentRelaxedTimeoutClearing tests the race condition fix in ClearRelaxedTimeout
func TestConcurrentRelaxedTimeoutClearing(t *testing.T) {
// Create a dummy connection for testing
netConn := &net.TCPConn{}
cn := NewConn(netConn)
defer cn.Close()
// Set relaxed timeout multiple times to increase counter
cn.SetRelaxedTimeout(time.Second, time.Second)
cn.SetRelaxedTimeout(time.Second, time.Second)
cn.SetRelaxedTimeout(time.Second, time.Second)
// Verify counter is 3
if count := cn.relaxedCounter.Load(); count != 3 {
t.Errorf("Expected relaxed counter to be 3, got %d", count)
}
// Clear timeouts concurrently to test race condition fix
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
cn.ClearRelaxedTimeout()
}()
}
wg.Wait()
// Verify counter is 0 and timeouts are cleared
if count := cn.relaxedCounter.Load(); count != 0 {
t.Errorf("Expected relaxed counter to be 0 after clearing, got %d", count)
}
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed read timeout to be 0, got %d", timeout)
}
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed write timeout to be 0, got %d", timeout)
}
}
// TestRelaxedTimeoutCounterRaceCondition tests the specific race condition scenario
func TestRelaxedTimeoutCounterRaceCondition(t *testing.T) {
netConn := &net.TCPConn{}
cn := NewConn(netConn)
defer cn.Close()
// Set relaxed timeout once
cn.SetRelaxedTimeout(time.Second, time.Second)
// Verify counter is 1
if count := cn.relaxedCounter.Load(); count != 1 {
t.Errorf("Expected relaxed counter to be 1, got %d", count)
}
// Test concurrent clearing with race condition scenario
var wg sync.WaitGroup
// Multiple goroutines try to clear simultaneously
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
cn.ClearRelaxedTimeout()
}()
}
wg.Wait()
// Verify final state is consistent
if count := cn.relaxedCounter.Load(); count != 0 {
t.Errorf("Expected relaxed counter to be 0 after concurrent clearing, got %d", count)
}
// Verify timeouts are actually cleared
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed read timeout to be cleared, got %d", timeout)
}
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
t.Errorf("Expected relaxed write timeout to be cleared, got %d", timeout)
}
if deadline := cn.relaxedDeadlineNs.Load(); deadline != 0 {
t.Errorf("Expected relaxed deadline to be cleared, got %d", deadline)
}
}

View File

@@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) {
}
func (cn *Conn) NetConn() net.Conn {
return cn.netConn
return cn.getNetConn()
}
func (p *ConnPool) CheckMinIdleConns() {

114
internal/pool/hooks.go Normal file
View File

@@ -0,0 +1,114 @@
package pool
import (
"context"
"sync"
)
// PoolHook defines the interface for connection lifecycle hooks.
type PoolHook interface {
// OnGet is called when a connection is retrieved from the pool.
// It can modify the connection or return an error to prevent its use.
// The isNewConn flag indicates whether this is a new connection rather than an idle one from the pool;
// it can be used for gathering metrics on the pool hit/miss ratio.
OnGet(ctx context.Context, conn *Conn, isNewConn bool) error
// OnPut is called when a connection is returned to the pool.
// It returns whether the connection should be pooled and whether it should be removed.
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
}
// PoolHookManager manages multiple pool hooks.
type PoolHookManager struct {
hooks []PoolHook
hooksMu sync.RWMutex
}
// NewPoolHookManager creates a new pool hook manager.
func NewPoolHookManager() *PoolHookManager {
return &PoolHookManager{
hooks: make([]PoolHook, 0),
}
}
// AddHook adds a pool hook to the manager.
// Hooks are called in the order they were added.
func (phm *PoolHookManager) AddHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
phm.hooks = append(phm.hooks, hook)
}
// RemoveHook removes a pool hook from the manager.
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
for i, h := range phm.hooks {
if h == hook {
// Remove hook by swapping with last element and truncating
phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
phm.hooks = phm.hooks[:len(phm.hooks)-1]
break
}
}
}
// ProcessOnGet calls all OnGet hooks in order.
// If any hook returns an error, processing stops and the error is returned.
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
for _, hook := range phm.hooks {
if err := hook.OnGet(ctx, conn, isNewConn); err != nil {
return err
}
}
return nil
}
// ProcessOnPut calls all OnPut hooks in order.
// The first hook that returns shouldRemove=true or shouldPool=false will stop processing.
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
shouldPool = true // Default to pooling the connection
for _, hook := range phm.hooks {
hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)
if hookErr != nil {
return false, true, hookErr
}
// If any hook says to remove or not pool, respect that decision
if hookShouldRemove {
return false, true, nil
}
if !hookShouldPool {
shouldPool = false
}
}
return shouldPool, false, nil
}
// GetHookCount returns the number of registered hooks (for testing).
func (phm *PoolHookManager) GetHookCount() int {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
return len(phm.hooks)
}
// GetHooks returns a copy of all registered hooks.
func (phm *PoolHookManager) GetHooks() []PoolHook {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
hooks := make([]PoolHook, len(phm.hooks))
copy(hooks, phm.hooks)
return hooks
}
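// Illustrative sketch (not part of this change): a minimal PoolHook that only
// counts pool hits and misses via the isNewConn flag. The type and its fields
// are hypothetical and exist purely to show the interface contract.
var _ PoolHook = (*exampleCountingHook)(nil)

type exampleCountingHook struct {
	mu           sync.Mutex
	hits, misses int
}

func (h *exampleCountingHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
	h.mu.Lock()
	defer h.mu.Unlock()
	if isNewConn {
		h.misses++
	} else {
		h.hits++
	}
	return nil
}

func (h *exampleCountingHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
	// Keep healthy connections pooled and never force removal.
	return true, false, nil
}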

213
internal/pool/hooks_test.go Normal file
View File

@@ -0,0 +1,213 @@
package pool
import (
"context"
"errors"
"net"
"testing"
"time"
)
// TestHook for testing hook functionality
type TestHook struct {
OnGetCalled int
OnPutCalled int
GetError error
PutError error
ShouldPool bool
ShouldRemove bool
}
func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
th.OnGetCalled++
return th.GetError
}
func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
th.OnPutCalled++
return th.ShouldPool, th.ShouldRemove, th.PutError
}
func TestPoolHookManager(t *testing.T) {
manager := NewPoolHookManager()
// Test initial state
if manager.GetHookCount() != 0 {
t.Errorf("Expected 0 hooks initially, got %d", manager.GetHookCount())
}
// Add hooks
hook1 := &TestHook{ShouldPool: true}
hook2 := &TestHook{ShouldPool: true}
manager.AddHook(hook1)
manager.AddHook(hook2)
if manager.GetHookCount() != 2 {
t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount())
}
// Test ProcessOnGet
ctx := context.Background()
conn := &Conn{} // Mock connection
err := manager.ProcessOnGet(ctx, conn, false)
if err != nil {
t.Errorf("ProcessOnGet should not error: %v", err)
}
if hook1.OnGetCalled != 1 {
t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled)
}
if hook2.OnGetCalled != 1 {
t.Errorf("Expected hook2.OnGetCalled to be 1, got %d", hook2.OnGetCalled)
}
// Test ProcessOnPut
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
if err != nil {
t.Errorf("ProcessOnPut should not error: %v", err)
}
if !shouldPool {
t.Error("Expected shouldPool to be true")
}
if shouldRemove {
t.Error("Expected shouldRemove to be false")
}
if hook1.OnPutCalled != 1 {
t.Errorf("Expected hook1.OnPutCalled to be 1, got %d", hook1.OnPutCalled)
}
if hook2.OnPutCalled != 1 {
t.Errorf("Expected hook2.OnPutCalled to be 1, got %d", hook2.OnPutCalled)
}
// Remove a hook
manager.RemoveHook(hook1)
if manager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount())
}
}
func TestHookErrorHandling(t *testing.T) {
manager := NewPoolHookManager()
// Hook that returns error on Get
errorHook := &TestHook{
GetError: errors.New("test error"),
ShouldPool: true,
}
normalHook := &TestHook{ShouldPool: true}
manager.AddHook(errorHook)
manager.AddHook(normalHook)
ctx := context.Background()
conn := &Conn{}
// Test that error stops processing
err := manager.ProcessOnGet(ctx, conn, false)
if err == nil {
t.Error("Expected error from ProcessOnGet")
}
if errorHook.OnGetCalled != 1 {
t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled)
}
// normalHook should not be called due to error
if normalHook.OnGetCalled != 0 {
t.Errorf("Expected normalHook.OnGetCalled to be 0, got %d", normalHook.OnGetCalled)
}
}
func TestHookShouldRemove(t *testing.T) {
manager := NewPoolHookManager()
// Hook that says to remove connection
removeHook := &TestHook{
ShouldPool: false,
ShouldRemove: true,
}
normalHook := &TestHook{ShouldPool: true}
manager.AddHook(removeHook)
manager.AddHook(normalHook)
ctx := context.Background()
conn := &Conn{}
shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn)
if err != nil {
t.Errorf("ProcessOnPut should not error: %v", err)
}
if shouldPool {
t.Error("Expected shouldPool to be false")
}
if !shouldRemove {
t.Error("Expected shouldRemove to be true")
}
if removeHook.OnPutCalled != 1 {
t.Errorf("Expected removeHook.OnPutCalled to be 1, got %d", removeHook.OnPutCalled)
}
// normalHook should not be called due to early return
if normalHook.OnPutCalled != 0 {
t.Errorf("Expected normalHook.OnPutCalled to be 0, got %d", normalHook.OnPutCalled)
}
}
func TestPoolWithHooks(t *testing.T) {
// Create a pool with hooks
hookManager := NewPoolHookManager()
testHook := &TestHook{ShouldPool: true}
hookManager.AddHook(testHook)
opt := &Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil // Mock connection
},
PoolSize: 1,
DialTimeout: time.Second,
}
pool := NewConnPool(opt)
defer pool.Close()
// Add hook to pool after creation
pool.AddPoolHook(testHook)
// Verify hooks are initialized
if pool.hookManager == nil {
t.Error("Expected hookManager to be initialized")
}
if pool.hookManager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount())
}
// Test adding hook to pool
additionalHook := &TestHook{ShouldPool: true}
pool.AddPoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 2 {
t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount())
}
// Test removing hook from pool
pool.RemovePoolHook(additionalHook)
if pool.hookManager.GetHookCount() != 1 {
t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount())
}
}

View File

@@ -9,6 +9,8 @@ import (
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/util"
)
var (
@@ -21,6 +23,23 @@ var (
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
ErrPoolTimeout = errors.New("redis: connection pool timeout")
// popAttempts is the maximum number of attempts to find a usable connection
// when popping from the idle connection pool. This handles cases where connections
// are temporarily marked as unusable (e.g., during maintnotifications upgrades or network issues).
// Value of 50 provides sufficient resilience without excessive overhead.
// This is capped by the idle connection count, so we won't loop excessively.
popAttempts = 50
// getAttempts is the maximum number of attempts to get a connection that passes
// hook validation (e.g., maintnotifications upgrade hooks). This protects against race conditions
// where hooks might temporarily reject connections during cluster transitions.
// Value of 3 balances resilience with performance - most hook rejections resolve quickly.
getAttempts = 3
minTime = time.Unix(-2208988800, 0) // Jan 1, 1900
maxTime = minTime.Add(1<<63 - 1)
noExpiration = maxTime
)
var timers = sync.Pool{
@@ -37,11 +56,14 @@ type Stats struct {
Misses uint32 // number of times free connection was NOT found in the pool
Timeouts uint32 // number of times a wait timeout occurred
WaitCount uint32 // number of times a connection was waited
Unusable uint32 // number of times a connection was found to be unusable
WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds
TotalConns uint32 // number of total connections in the pool
IdleConns uint32 // number of idle connections in the pool
StaleConns uint32 // number of stale connections removed from the pool
PubSubStats PubSubStats
}
type Pooler interface {
@@ -56,24 +78,35 @@ type Pooler interface {
IdleLen() int
Stats() *Stats
AddPoolHook(hook PoolHook)
RemovePoolHook(hook PoolHook)
Close() error
}
type Options struct {
Dialer func(context.Context) (net.Conn, error)
PoolFIFO bool
PoolSize int
DialTimeout time.Duration
PoolTimeout time.Duration
MinIdleConns int
MaxIdleConns int
MaxActiveConns int
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
Dialer func(context.Context) (net.Conn, error)
ReadBufferSize int
WriteBufferSize int
PoolFIFO bool
PoolSize int32
DialTimeout time.Duration
PoolTimeout time.Duration
MinIdleConns int32
MaxIdleConns int32
MaxActiveConns int32
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
PushNotificationsEnabled bool
// DialerRetries is the maximum number of retry attempts when dialing fails.
// Default: 5
DialerRetries int
// DialerRetryTimeout is the backoff duration between retry attempts.
// Default: 100ms
DialerRetryTimeout time.Duration
}
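For orientation, a hedged construction of the reworked Options from outside the package (mirroring the pool tests later in this diff; address and values are placeholders):
```go
opt := &pool.Options{
	Dialer: func(ctx context.Context) (net.Conn, error) {
		return net.Dial("tcp", "localhost:6379")
	},
	PoolSize:           int32(10),
	MinIdleConns:       int32(2),
	DialTimeout:        time.Second,
	PoolTimeout:        time.Second,
	DialerRetries:      3,                     // default is 5 when unset
	DialerRetryTimeout: 50 * time.Millisecond, // default is 100ms when unset
}
connPool := pool.NewConnPool(opt)
defer connPool.Close()
```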
type lastDialErrorWrap struct {
@@ -89,16 +122,21 @@ type ConnPool struct {
queue chan struct{}
connsMu sync.Mutex
conns []*Conn
conns map[uint64]*Conn
idleConns []*Conn
poolSize int
idleConnsLen int
poolSize atomic.Int32
idleConnsLen atomic.Int32
idleCheckInProgress atomic.Bool
stats Stats
waitDurationNs atomic.Int64
_closed uint32 // atomic
// Pool hooks manager for flexible connection processing
hookManagerMu sync.RWMutex
hookManager *PoolHookManager
}
var _ Pooler = (*ConnPool)(nil)
@@ -108,34 +146,69 @@ func NewConnPool(opt *Options) *ConnPool {
cfg: opt,
queue: make(chan struct{}, opt.PoolSize),
conns: make([]*Conn, 0, opt.PoolSize),
conns: make(map[uint64]*Conn),
idleConns: make([]*Conn, 0, opt.PoolSize),
}
p.connsMu.Lock()
p.checkMinIdleConns()
p.connsMu.Unlock()
// Only create MinIdleConns if explicitly requested (> 0)
// This avoids creating connections during pool initialization for tests
if opt.MinIdleConns > 0 {
p.connsMu.Lock()
p.checkMinIdleConns()
p.connsMu.Unlock()
}
return p
}
// initializeHooks sets up the pool hooks system.
func (p *ConnPool) initializeHooks() {
p.hookManager = NewPoolHookManager()
}
// AddPoolHook adds a pool hook to the pool.
func (p *ConnPool) AddPoolHook(hook PoolHook) {
p.hookManagerMu.Lock()
defer p.hookManagerMu.Unlock()
if p.hookManager == nil {
p.initializeHooks()
}
p.hookManager.AddHook(hook)
}
// RemovePoolHook removes a pool hook from the pool.
func (p *ConnPool) RemovePoolHook(hook PoolHook) {
p.hookManagerMu.Lock()
defer p.hookManagerMu.Unlock()
if p.hookManager != nil {
p.hookManager.RemoveHook(hook)
}
}
func (p *ConnPool) checkMinIdleConns() {
if !p.idleCheckInProgress.CompareAndSwap(false, true) {
return
}
defer p.idleCheckInProgress.Store(false)
if p.cfg.MinIdleConns == 0 {
return
}
for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
// Only create idle connections if we haven't reached the total pool size limit
// MinIdleConns should be a subset of PoolSize, not additional connections
for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns {
select {
case p.queue <- struct{}{}:
p.poolSize++
p.idleConnsLen++
p.poolSize.Add(1)
p.idleConnsLen.Add(1)
go func() {
defer func() {
if err := recover(); err != nil {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
p.freeTurn()
internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
@@ -144,12 +217,9 @@ func (p *ConnPool) checkMinIdleConns() {
err := p.addIdleConn()
if err != nil && err != ErrClosed {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
p.poolSize.Add(-1)
p.idleConnsLen.Add(-1)
}
p.freeTurn()
}()
default:
@@ -166,6 +236,9 @@ func (p *ConnPool) addIdleConn() error {
if err != nil {
return err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
p.connsMu.Lock()
defer p.connsMu.Unlock()
@@ -176,11 +249,15 @@ func (p *ConnPool) addIdleConn() error {
return ErrClosed
}
p.conns = append(p.conns, cn)
p.conns[cn.GetID()] = cn
p.idleConns = append(p.idleConns, cn)
return nil
}
// NewConn creates a new connection and returns it to the user.
// It still obeys MaxActiveConns, but the connection is not added to the pool and does not count toward the pool size.
//
// NOTE: Connections obtained this way are not pooled and will not support maintnotifications upgrades.
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.newConn(ctx, false)
}
@@ -190,33 +267,44 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, ErrClosed
}
p.connsMu.Lock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
p.connsMu.Unlock()
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) {
return nil, ErrPoolExhausted
}
p.connsMu.Unlock()
cn, err := p.dialConn(ctx, pooled)
dialCtx, cancel := context.WithTimeout(ctx, p.cfg.DialTimeout)
defer cancel()
cn, err := p.dialConn(dialCtx, pooled)
if err != nil {
return nil, err
}
// Mark connection as usable after successful creation
// This is essential for normal pool operations
cn.SetUsable(true)
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) {
_ = cn.Close()
return nil, ErrPoolExhausted
}
p.conns = append(p.conns, cn)
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.closed() {
_ = cn.Close()
return nil, ErrClosed
}
// Defensive: p.conns can be nil if the pool was closed while we were dialing
if p.conns == nil {
p.conns = make(map[uint64]*Conn)
}
p.conns[cn.GetID()] = cn
if pooled {
// If pool is full remove the cn on next Put.
if p.poolSize >= p.cfg.PoolSize {
currentPoolSize := p.poolSize.Load()
if currentPoolSize >= p.cfg.PoolSize {
cn.pooled = false
} else {
p.poolSize++
p.poolSize.Add(1)
}
}
@@ -232,18 +320,57 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, p.getLastDialError()
}
netConn, err := p.cfg.Dialer(ctx)
if err != nil {
p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
go p.tryDial()
}
return nil, err
// Retry dialing with backoff
// The dial context already enforces DialTimeout, so we may never reach the retry limit;
// higher retry values do not hurt.
maxRetries := p.cfg.DialerRetries
if maxRetries <= 0 {
maxRetries = 5 // Default value
}
backoffDuration := p.cfg.DialerRetryTimeout
if backoffDuration <= 0 {
backoffDuration = 100 * time.Millisecond // Default value
}
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
cn.pooled = pooled
return cn, nil
var lastErr error
shouldLoop := true
// when the timeout is reached, we should stop retrying
// but keep the lastErr to return to the caller
// instead of a generic context deadline exceeded error
for attempt := 0; (attempt < maxRetries) && shouldLoop; attempt++ {
netConn, err := p.cfg.Dialer(ctx)
if err != nil {
lastErr = err
// Back off before the next attempt,
// or stop retrying early if the context is done.
select {
case <-ctx.Done():
shouldLoop = false
case <-time.After(backoffDuration):
// Continue with retry
}
continue
}
// Success - create connection
cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
cn.pooled = pooled
if p.cfg.ConnMaxLifetime > 0 {
cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime)
} else {
cn.expiresAt = noExpiration
}
return cn, nil
}
internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr)
// All retries failed - handle error tracking
p.setLastDialError(lastErr)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
go p.tryDial()
}
return nil, lastErr
}
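For intuition about the retry budget: with the defaults (5 attempts, 100 ms backoff) a persistently failing dialer spends at most roughly 5 × 100 ms = 500 ms sleeping between attempts, and the dial context created in newConn additionally caps the whole sequence at DialTimeout; whichever limit is hit first ends the loop, and the last concrete dial error is returned instead of a generic deadline message.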
func (p *ConnPool) tryDial() {
@@ -283,6 +410,14 @@ func (p *ConnPool) getLastDialError() error {
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return p.getConn(ctx)
}
// getConn returns a connection from the pool.
func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
var cn *Conn
var err error
if p.closed() {
return nil, ErrClosed
}
@@ -291,9 +426,17 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return nil, err
}
now := time.Now()
attempts := 0
for {
if attempts >= getAttempts {
internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts)
break
}
attempts++
p.connsMu.Lock()
cn, err := p.popIdle()
cn, err = p.popIdle()
p.connsMu.Unlock()
if err != nil {
@@ -305,11 +448,25 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
break
}
if !p.isHealthyConn(cn) {
if !p.isHealthyConn(cn, now) {
_ = p.CloseConn(cn)
continue
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil {
internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err)
// Failed to process connection, discard it
_ = p.CloseConn(cn)
continue
}
}
atomic.AddUint32(&p.stats.Hits, 1)
return cn, nil
}
@@ -322,6 +479,19 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
return nil, err
}
// Process connection using the hooks system
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil {
// Failed to process connection, discard it
internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err)
_ = p.CloseConn(newcn)
return nil, err
}
}
return newcn, nil
}
@@ -350,7 +520,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error {
}
return ctx.Err()
case p.queue <- struct{}{}:
p.waitDurationNs.Add(time.Since(start).Nanoseconds())
p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano())
atomic.AddUint32(&p.stats.WaitCount, 1)
if !timer.Stop() {
<-timer.C
@@ -370,52 +540,130 @@ func (p *ConnPool) popIdle() (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
defer p.checkMinIdleConns()
n := len(p.idleConns)
if n == 0 {
return nil, nil
}
var cn *Conn
if p.cfg.PoolFIFO {
cn = p.idleConns[0]
copy(p.idleConns, p.idleConns[1:])
p.idleConns = p.idleConns[:n-1]
} else {
idx := n - 1
cn = p.idleConns[idx]
p.idleConns = p.idleConns[:idx]
attempts := 0
maxAttempts := util.Min(popAttempts, n)
for attempts < maxAttempts {
if len(p.idleConns) == 0 {
return nil, nil
}
if p.cfg.PoolFIFO {
cn = p.idleConns[0]
copy(p.idleConns, p.idleConns[1:])
p.idleConns = p.idleConns[:len(p.idleConns)-1]
} else {
idx := len(p.idleConns) - 1
cn = p.idleConns[idx]
p.idleConns = p.idleConns[:idx]
}
attempts++
if cn.IsUsable() {
p.idleConnsLen.Add(-1)
break
}
// Connection is not usable, put it back in the pool
if p.cfg.PoolFIFO {
// FIFO: put at end (will be picked up last since we pop from front)
p.idleConns = append(p.idleConns, cn)
} else {
// LIFO: put at beginning (will be picked up last since we pop from end)
p.idleConns = append([]*Conn{cn}, p.idleConns...)
}
cn = nil
}
p.idleConnsLen--
p.checkMinIdleConns()
// If we exhausted all attempts without finding a usable connection, return nil
if attempts > 1 && attempts >= maxAttempts && int32(attempts) >= p.poolSize.Load() {
internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts)
return nil, nil
}
return cn, nil
}
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
if cn.rd.Buffered() > 0 {
internal.Logger.Printf(ctx, "Conn has unread data")
p.Remove(ctx, cn, BadConnError{})
// Process connection using the hooks system
shouldPool := true
shouldRemove := false
var err error
if cn.HasBufferedData() {
// Peek at the reply type to check if it's a push notification
if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush {
// Not a push notification or error peeking, remove connection
internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it")
p.Remove(ctx, cn, err)
return
}
// It's a push notification, allow pooling (client will handle it)
}
p.hookManagerMu.RLock()
hookManager := p.hookManager
p.hookManagerMu.RUnlock()
if hookManager != nil {
shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
if err != nil {
internal.Logger.Printf(ctx, "Connection hook error: %v", err)
p.Remove(ctx, cn, err)
return
}
}
// If hooks say to remove the connection, do so
if shouldRemove {
p.Remove(ctx, cn, errors.New("hook requested removal"))
return
}
// If processor says not to pool the connection, remove it
if !shouldPool {
p.Remove(ctx, cn, errors.New("hook requested no pooling"))
return
}
if !cn.pooled {
p.Remove(ctx, cn, nil)
p.Remove(ctx, cn, errors.New("connection not pooled"))
return
}
var shouldCloseConn bool
p.connsMu.Lock()
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns {
p.idleConns = append(p.idleConns, cn)
p.idleConnsLen++
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns {
// Unusable conns are expected to become usable again (a background process is reconnecting them);
// park them at the cold end of the idle queue so they are picked up last.
if !cn.IsUsable() {
if p.cfg.PoolFIFO {
p.connsMu.Lock()
p.idleConns = append(p.idleConns, cn)
p.connsMu.Unlock()
} else {
p.connsMu.Lock()
p.idleConns = append([]*Conn{cn}, p.idleConns...)
p.connsMu.Unlock()
}
} else {
p.connsMu.Lock()
p.idleConns = append(p.idleConns, cn)
p.connsMu.Unlock()
}
p.idleConnsLen.Add(1)
} else {
p.removeConn(cn)
p.removeConnWithLock(cn)
shouldCloseConn = true
}
p.connsMu.Unlock()
p.freeTurn()
if shouldCloseConn {
@@ -425,8 +673,13 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
p.removeConnWithLock(cn)
p.freeTurn()
_ = p.closeConn(cn)
// Check if we need to create new idle connections to maintain MinIdleConns
p.checkMinIdleConns()
}
func (p *ConnPool) CloseConn(cn *Conn) error {
@@ -441,17 +694,23 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) {
}
func (p *ConnPool) removeConn(cn *Conn) {
for i, c := range p.conns {
if c == cn {
p.conns = append(p.conns[:i], p.conns[i+1:]...)
if cn.pooled {
p.poolSize--
p.checkMinIdleConns()
cid := cn.GetID()
delete(p.conns, cid)
atomic.AddUint32(&p.stats.StaleConns, 1)
// Decrement pool size counter when removing a connection
if cn.pooled {
p.poolSize.Add(-1)
// the connection may also be sitting in the idle list; remove it from there too
for idx, ic := range p.idleConns {
if ic.GetID() == cid {
internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid)
p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...)
p.idleConnsLen.Add(-1)
break
}
break
}
}
atomic.AddUint32(&p.stats.StaleConns, 1)
}
func (p *ConnPool) closeConn(cn *Conn) error {
@@ -469,9 +728,9 @@ func (p *ConnPool) Len() int {
// IdleLen returns number of idle connections.
func (p *ConnPool) IdleLen() int {
p.connsMu.Lock()
n := p.idleConnsLen
n := p.idleConnsLen.Load()
p.connsMu.Unlock()
return n
return int(n)
}
func (p *ConnPool) Stats() *Stats {
@@ -480,6 +739,7 @@ func (p *ConnPool) Stats() *Stats {
Misses: atomic.LoadUint32(&p.stats.Misses),
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
WaitCount: atomic.LoadUint32(&p.stats.WaitCount),
Unusable: atomic.LoadUint32(&p.stats.Unusable),
WaitDurationNs: p.waitDurationNs.Load(),
TotalConns: uint32(p.Len()),
@@ -520,28 +780,45 @@ func (p *ConnPool) Close() error {
}
}
p.conns = nil
p.poolSize = 0
p.poolSize.Store(0)
p.idleConns = nil
p.idleConnsLen = 0
p.idleConnsLen.Store(0)
p.connsMu.Unlock()
return firstErr
}
func (p *ConnPool) isHealthyConn(cn *Conn) bool {
now := time.Now()
if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
// slight optimization, check expiresAt first.
if cn.expiresAt.Before(now) {
return false
}
// Check if connection has exceeded idle timeout
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
return false
}
if connCheck(cn.netConn) != nil {
return false
}
cn.SetUsedAt(now)
// Check basic connection health
// Use GetNetConn() to safely access netConn and avoid data races
if err := connCheck(cn.getNetConn()); err != nil {
// If there's unexpected data, it might be push notifications (RESP3)
// However, push notification processing is now handled by the client
// before WithReader to ensure proper context is available to handlers
if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead {
// we know that there is something in the buffer, so peek at the next reply type without
// the potential to block
if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
// For RESP3 connections with push notifications, we allow some buffered data
// The client will process these notifications before using the connection
internal.Logger.Printf(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID())
return true // Connection is healthy, client will handle notifications
}
return false // Unexpected data, not push notifications, connection is unhealthy
} else {
return false
}
}
return true
}

View File

@@ -1,6 +1,8 @@
package pool
import "context"
import (
"context"
)
type SingleConnPool struct {
pool Pooler
@@ -56,3 +58,7 @@ func (p *SingleConnPool) IdleLen() int {
func (p *SingleConnPool) Stats() *Stats {
return &Stats{}
}
func (p *SingleConnPool) AddPoolHook(hook PoolHook) {}
func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {}

View File

@@ -199,3 +199,7 @@ func (p *StickyConnPool) IdleLen() int {
func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}
func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}
func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}

View File

@@ -2,15 +2,17 @@ package pool_test
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/logging"
)
var _ = Describe("ConnPool", func() {
@@ -20,7 +22,7 @@ var _ = Describe("ConnPool", func() {
BeforeEach(func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Hour,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
@@ -45,11 +47,11 @@ var _ = Describe("ConnPool", func() {
<-closedChan
return &net.TCPConn{}, nil
},
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Hour,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
MinIdleConns: minIdleConns,
MinIdleConns: int32(minIdleConns),
})
wg.Wait()
Expect(connPool.Close()).NotTo(HaveOccurred())
@@ -105,7 +107,7 @@ var _ = Describe("ConnPool", func() {
// ok
}
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
// Check that Get is unblocked.
select {
@@ -130,8 +132,8 @@ var _ = Describe("MinIdleConns", func() {
newConnPool := func() *pool.ConnPool {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: poolSize,
MinIdleConns: minIdleConns,
PoolSize: int32(poolSize),
MinIdleConns: int32(minIdleConns),
PoolTimeout: 100 * time.Millisecond,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: -1,
@@ -168,7 +170,7 @@ var _ = Describe("MinIdleConns", func() {
Context("after Remove", func() {
BeforeEach(func() {
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
})
It("has idle connections", func() {
@@ -245,7 +247,7 @@ var _ = Describe("MinIdleConns", func() {
BeforeEach(func() {
perform(len(cns), func(i int) {
mu.RLock()
connPool.Remove(ctx, cns[i], nil)
connPool.Remove(ctx, cns[i], errors.New("test"))
mu.RUnlock()
})
@@ -309,7 +311,7 @@ var _ = Describe("race", func() {
It("does not happen on Get, Put, and Remove", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolSize: int32(10),
PoolTimeout: time.Minute,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Millisecond,
@@ -328,7 +330,7 @@ var _ = Describe("race", func() {
cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Remove(ctx, cn, nil)
connPool.Remove(ctx, cn, errors.New("test"))
}
}
})
@@ -339,15 +341,15 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
},
PoolSize: 1000,
MinIdleConns: 50,
PoolSize: int32(1000),
MinIdleConns: int32(50),
PoolTimeout: 3 * time.Second,
DialTimeout: 1 * time.Second,
}
p := pool.NewConnPool(opt)
var wg sync.WaitGroup
for i := 0; i < opt.PoolSize; i++ {
for i := int32(0); i < opt.PoolSize; i++ {
wg.Add(1)
go func() {
defer wg.Done()
@@ -366,8 +368,8 @@ var _ = Describe("race", func() {
Dialer: func(ctx context.Context) (net.Conn, error) {
panic("test panic")
},
PoolSize: 100,
MinIdleConns: 30,
PoolSize: int32(100),
MinIdleConns: int32(30),
}
p := pool.NewConnPool(opt)
@@ -377,14 +379,14 @@ var _ = Describe("race", func() {
state := p.Stats()
return state.TotalConns == 0 && state.IdleConns == 0 && p.QueueLen() == 0
}, "3s", "50ms").Should(BeTrue())
})
})
It("wait", func() {
opt := &pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
},
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: 3 * time.Second,
}
p := pool.NewConnPool(opt)
@@ -415,7 +417,7 @@ var _ = Describe("race", func() {
return &net.TCPConn{}, nil
},
PoolSize: 1,
PoolSize: int32(1),
PoolTimeout: testPoolTimeout,
}
p := pool.NewConnPool(opt)
@@ -435,3 +437,73 @@ var _ = Describe("race", func() {
Expect(stats.Timeouts).To(Equal(uint32(1)))
})
})
// TestDialerRetryConfiguration tests the new DialerRetries and DialerRetryTimeout options
func TestDialerRetryConfiguration(t *testing.T) {
ctx := context.Background()
t.Run("CustomDialerRetries", func(t *testing.T) {
var attempts int64
failingDialer := func(ctx context.Context) (net.Conn, error) {
atomic.AddInt64(&attempts, 1)
return nil, errors.New("dial failed")
}
connPool := pool.NewConnPool(&pool.Options{
Dialer: failingDialer,
PoolSize: 1,
PoolTimeout: time.Second,
DialTimeout: time.Second,
DialerRetries: 3, // Custom retry count
DialerRetryTimeout: 10 * time.Millisecond, // Fast retries for testing
})
defer connPool.Close()
_, err := connPool.Get(ctx)
if err == nil {
t.Error("Expected error from failing dialer")
}
// Should have attempted at least 3 times (DialerRetries = 3)
// There might be additional attempts due to pool logic
finalAttempts := atomic.LoadInt64(&attempts)
if finalAttempts < 3 {
t.Errorf("Expected at least 3 dial attempts, got %d", finalAttempts)
}
if finalAttempts > 6 {
t.Errorf("Expected around 3 dial attempts, got %d (too many)", finalAttempts)
}
})
t.Run("DefaultDialerRetries", func(t *testing.T) {
var attempts int64
failingDialer := func(ctx context.Context) (net.Conn, error) {
atomic.AddInt64(&attempts, 1)
return nil, errors.New("dial failed")
}
connPool := pool.NewConnPool(&pool.Options{
Dialer: failingDialer,
PoolSize: 1,
PoolTimeout: time.Second,
DialTimeout: time.Second,
// DialerRetries and DialerRetryTimeout not set - should use defaults
})
defer connPool.Close()
_, err := connPool.Get(ctx)
if err == nil {
t.Error("Expected error from failing dialer")
}
// Should have attempted 5 times (default DialerRetries = 5)
finalAttempts := atomic.LoadInt64(&attempts)
if finalAttempts != 5 {
t.Errorf("Expected 5 dial attempts (default), got %d", finalAttempts)
}
})
}
func init() {
logging.Disable()
}

78
internal/pool/pubsub.go Normal file
View File

@@ -0,0 +1,78 @@
package pool
import (
"context"
"net"
"sync"
"sync/atomic"
)
type PubSubStats struct {
Created uint32
Untracked uint32
Active uint32
}
// PubSubPool manages a pool of PubSub connections.
type PubSubPool struct {
opt *Options
netDialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Map to track active PubSub connections
activeConns sync.Map // map[uint64]*Conn (connID -> conn)
closed atomic.Bool
stats PubSubStats
}
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
return &PubSubPool{
opt: opt,
netDialer: netDialer,
}
}
func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
if p.closed.Load() {
return nil, ErrClosed
}
netConn, err := p.netDialer(ctx, network, addr)
if err != nil {
return nil, err
}
cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
cn.pubsub = true
atomic.AddUint32(&p.stats.Created, 1)
return cn, nil
}
func (p *PubSubPool) TrackConn(cn *Conn) {
atomic.AddUint32(&p.stats.Active, 1)
p.activeConns.Store(cn.GetID(), cn)
}
func (p *PubSubPool) UntrackConn(cn *Conn) {
atomic.AddUint32(&p.stats.Active, ^uint32(0))
atomic.AddUint32(&p.stats.Untracked, 1)
p.activeConns.Delete(cn.GetID())
}
func (p *PubSubPool) Close() error {
p.closed.Store(true)
p.activeConns.Range(func(key, value interface{}) bool {
cn := value.(*Conn)
_ = cn.Close()
return true
})
return nil
}
func (p *PubSubPool) Stats() *PubSubStats {
// load stats atomically
return &PubSubStats{
Created: atomic.LoadUint32(&p.stats.Created),
Untracked: atomic.LoadUint32(&p.stats.Untracked),
Active: atomic.LoadUint32(&p.stats.Active),
}
}
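A hedged usage sketch of the PubSub pool from outside the package (address and buffer sizes are placeholders; the channels argument is not used by the code above, so nil is passed):
```go
pubsub := pool.NewPubSubPool(
	&pool.Options{ReadBufferSize: 32 * 1024, WriteBufferSize: 32 * 1024},
	func(ctx context.Context, network, addr string) (net.Conn, error) {
		return net.Dial(network, addr)
	},
)

cn, err := pubsub.NewConn(ctx, "tcp", "localhost:6379", nil)
if err != nil {
	// handle dial failure
}
pubsub.TrackConn(cn)   // counted as Active in Stats()
// ... drive SUBSCRIBE/PSUBSCRIBE traffic over cn ...
pubsub.UntrackConn(cn) // stop tracking before closing
_ = cn.Close()
```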

View File

@@ -0,0 +1,614 @@
package proto
import (
"bytes"
"fmt"
"math/rand"
"strings"
"testing"
)
// TestPeekPushNotificationName tests the updated PeekPushNotificationName method
func TestPeekPushNotificationName(t *testing.T) {
t.Run("ValidPushNotifications", func(t *testing.T) {
testCases := []struct {
name string
notification string
expected string
}{
{"MOVING", "MOVING", "MOVING"},
{"MIGRATING", "MIGRATING", "MIGRATING"},
{"MIGRATED", "MIGRATED", "MIGRATED"},
{"FAILING_OVER", "FAILING_OVER", "FAILING_OVER"},
{"FAILED_OVER", "FAILED_OVER", "FAILED_OVER"},
{"message", "message", "message"},
{"pmessage", "pmessage", "pmessage"},
{"subscribe", "subscribe", "subscribe"},
{"unsubscribe", "unsubscribe", "unsubscribe"},
{"psubscribe", "psubscribe", "psubscribe"},
{"punsubscribe", "punsubscribe", "punsubscribe"},
{"smessage", "smessage", "smessage"},
{"ssubscribe", "ssubscribe", "ssubscribe"},
{"sunsubscribe", "sunsubscribe", "sunsubscribe"},
{"custom", "custom", "custom"},
{"short", "a", "a"},
{"empty", "", ""},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
buf := createValidPushNotification(tc.notification, "data")
reader := NewReader(buf)
// Prime the buffer by peeking first
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error for valid notification: %v", err)
}
if name != tc.expected {
t.Errorf("Expected notification name '%s', got '%s'", tc.expected, name)
}
})
}
})
t.Run("NotificationWithMultipleArguments", func(t *testing.T) {
// Create push notification with multiple arguments
buf := createPushNotificationWithArgs("MOVING", "slot", "123", "from", "node1", "to", "node2")
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error: %v", err)
}
if name != "MOVING" {
t.Errorf("Expected 'MOVING', got '%s'", name)
}
})
t.Run("SingleElementNotification", func(t *testing.T) {
// Create push notification with single element
buf := createSingleElementPushNotification("TEST")
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error: %v", err)
}
if name != "TEST" {
t.Errorf("Expected 'TEST', got '%s'", name)
}
})
t.Run("ErrorDetection", func(t *testing.T) {
t.Run("NotPushNotification", func(t *testing.T) {
// Test with regular array instead of push notification
buf := &bytes.Buffer{}
buf.WriteString("*2\r\n$6\r\nMOVING\r\n$4\r\ndata\r\n")
reader := NewReader(buf)
_, err := reader.PeekPushNotificationName()
if err == nil {
t.Error("PeekPushNotificationName should error for non-push notification")
}
// The error might be "no data available" or "can't parse push notification"
if !strings.Contains(err.Error(), "can't peek push notification name") {
t.Errorf("Error should mention push notification parsing, got: %v", err)
}
})
t.Run("InsufficientData", func(t *testing.T) {
// Test with buffer smaller than peek size - this might panic due to bounds checking
buf := &bytes.Buffer{}
buf.WriteString(">")
reader := NewReader(buf)
func() {
defer func() {
if r := recover(); r != nil {
t.Logf("PeekPushNotificationName panicked as expected for insufficient data: %v", r)
}
}()
_, err := reader.PeekPushNotificationName()
if err == nil {
t.Error("PeekPushNotificationName should error for insufficient data")
}
}()
})
t.Run("EmptyBuffer", func(t *testing.T) {
buf := &bytes.Buffer{}
reader := NewReader(buf)
_, err := reader.PeekPushNotificationName()
if err == nil {
t.Error("PeekPushNotificationName should error for empty buffer")
}
})
t.Run("DifferentRESPTypes", func(t *testing.T) {
// Test with different RESP types that should be rejected
respTypes := []byte{'+', '-', ':', '$', '*', '%', '~', '|', '('}
for _, respType := range respTypes {
t.Run(fmt.Sprintf("Type_%c", respType), func(t *testing.T) {
buf := &bytes.Buffer{}
buf.WriteByte(respType)
buf.WriteString("test data that fills the buffer completely")
reader := NewReader(buf)
_, err := reader.PeekPushNotificationName()
if err == nil {
t.Errorf("PeekPushNotificationName should error for RESP type '%c'", respType)
}
// The error might be "no data available" or "can't parse push notification"
if !strings.Contains(err.Error(), "can't peek push notification name") {
t.Errorf("Error should mention push notification parsing, got: %v", err)
}
})
}
})
})
t.Run("EdgeCases", func(t *testing.T) {
t.Run("ZeroLengthArray", func(t *testing.T) {
// Create push notification with zero elements: >0\r\n
buf := &bytes.Buffer{}
buf.WriteString(">0\r\npadding_data_to_fill_buffer_completely")
reader := NewReader(buf)
_, err := reader.PeekPushNotificationName()
if err == nil {
t.Error("PeekPushNotificationName should error for zero-length array")
}
})
t.Run("EmptyNotificationName", func(t *testing.T) {
// Create push notification with empty name: >1\r\n$0\r\n\r\n
buf := createValidPushNotification("", "data")
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error for empty name: %v", err)
}
if name != "" {
t.Errorf("Expected empty notification name, got '%s'", name)
}
})
t.Run("CorruptedData", func(t *testing.T) {
corruptedCases := []struct {
name string
data string
}{
{"CorruptedLength", ">abc\r\n$6\r\nMOVING\r\n"},
{"MissingCRLF", ">2$6\r\nMOVING\r\n$4\r\ndata\r\n"},
{"InvalidStringLength", ">2\r\n$abc\r\nMOVING\r\n$4\r\ndata\r\n"},
{"NegativeStringLength", ">2\r\n$-1\r\n$4\r\ndata\r\n"},
{"IncompleteString", ">1\r\n$6\r\nMOV"},
}
for _, tc := range corruptedCases {
t.Run(tc.name, func(t *testing.T) {
buf := &bytes.Buffer{}
buf.WriteString(tc.data)
reader := NewReader(buf)
// Some corrupted data might not error but return unexpected results
// This is acceptable behavior for malformed input
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Logf("PeekPushNotificationName errored for corrupted data %s: %v (DATA: %s)", tc.name, err, tc.data)
} else {
t.Logf("PeekPushNotificationName returned '%s' for corrupted data NAME: %s, DATA: %s", name, tc.name, tc.data)
}
})
}
})
})
t.Run("BoundaryConditions", func(t *testing.T) {
t.Run("ExactlyPeekSize", func(t *testing.T) {
// Create buffer that is exactly 36 bytes (the peek window size)
buf := &bytes.Buffer{}
// ">1\r\n$4\r\nTEST\r\n" = 14 bytes, need 22 more
buf.WriteString(">1\r\n$4\r\nTEST\r\n1234567890123456789012")
if buf.Len() != 36 {
t.Errorf("Expected buffer length 36, got %d", buf.Len())
}
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should work for exact peek size: %v", err)
}
if name != "TEST" {
t.Errorf("Expected 'TEST', got '%s'", name)
}
})
t.Run("LessThanPeekSize", func(t *testing.T) {
// Create buffer smaller than 36 bytes but with complete notification
buf := createValidPushNotification("TEST", "")
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should work for complete notification: %v", err)
}
if name != "TEST" {
t.Errorf("Expected 'TEST', got '%s'", name)
}
})
t.Run("LongNotificationName", func(t *testing.T) {
// Test with notification name that might exceed peek window
longName := strings.Repeat("A", 20) // 20 character name (safe size)
buf := createValidPushNotification(longName, "data")
reader := NewReader(buf)
// Prime the buffer
_, _ = reader.rd.Peek(1)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should work for long name: %v", err)
}
if name != longName {
t.Errorf("Expected '%s', got '%s'", longName, name)
}
})
})
}
// Helper functions to create test data
// createValidPushNotification creates a valid RESP3 push notification
func createValidPushNotification(notificationName, data string) *bytes.Buffer {
buf := &bytes.Buffer{}
simpleOrString := rand.Intn(2) == 0
if data == "" {
// Single element notification
buf.WriteString(">1\r\n")
if simpleOrString {
buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
} else {
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
}
} else {
// Two element notification
buf.WriteString(">2\r\n")
if simpleOrString {
buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName))
buf.WriteString(fmt.Sprintf("+%s\r\n", data))
} else {
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
}
}
return buf
}
// createReaderWithPrimedBuffer creates a reader and primes the buffer
func createReaderWithPrimedBuffer(buf *bytes.Buffer) *Reader {
reader := NewReader(buf)
// Prime the buffer by peeking first
_, _ = reader.rd.Peek(1)
return reader
}
// createPushNotificationWithArgs creates a push notification with multiple arguments
func createPushNotificationWithArgs(notificationName string, args ...string) *bytes.Buffer {
buf := &bytes.Buffer{}
totalElements := 1 + len(args)
buf.WriteString(fmt.Sprintf(">%d\r\n", totalElements))
// Write notification name
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
// Write arguments
for _, arg := range args {
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(arg), arg))
}
return buf
}
// createSingleElementPushNotification creates a push notification with single element
func createSingleElementPushNotification(notificationName string) *bytes.Buffer {
buf := &bytes.Buffer{}
buf.WriteString(">1\r\n")
buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName))
return buf
}
// BenchmarkPeekPushNotificationName benchmarks the method performance
func BenchmarkPeekPushNotificationName(b *testing.B) {
testCases := []struct {
name string
notification string
}{
{"Short", "TEST"},
{"Medium", "MOVING_NOTIFICATION"},
{"Long", "VERY_LONG_NOTIFICATION_NAME_FOR_TESTING"},
}
for _, tc := range testCases {
b.Run(tc.name, func(b *testing.B) {
buf := createValidPushNotification(tc.notification, "data")
data := buf.Bytes()
b.ResetTimer()
for i := 0; i < b.N; i++ {
reader := NewReader(bytes.NewReader(data))
_, err := reader.PeekPushNotificationName()
if err != nil {
b.Errorf("PeekPushNotificationName should not error: %v", err)
}
}
})
}
}
// TestPeekPushNotificationNameSpecialCases tests special cases and realistic scenarios
func TestPeekPushNotificationNameSpecialCases(t *testing.T) {
t.Run("RealisticNotifications", func(t *testing.T) {
// Test realistic Redis push notifications
realisticCases := []struct {
name string
notification []string
expected string
}{
{"MovingSlot", []string{"MOVING", "slot", "123", "from", "127.0.0.1:7000", "to", "127.0.0.1:7001"}, "MOVING"},
{"MigratingSlot", []string{"MIGRATING", "slot", "456", "from", "127.0.0.1:7001", "to", "127.0.0.1:7002"}, "MIGRATING"},
{"MigratedSlot", []string{"MIGRATED", "slot", "789", "from", "127.0.0.1:7002", "to", "127.0.0.1:7000"}, "MIGRATED"},
{"FailingOver", []string{"FAILING_OVER", "node", "127.0.0.1:7000"}, "FAILING_OVER"},
{"FailedOver", []string{"FAILED_OVER", "node", "127.0.0.1:7000"}, "FAILED_OVER"},
{"PubSubMessage", []string{"message", "mychannel", "hello world"}, "message"},
{"PubSubPMessage", []string{"pmessage", "pattern*", "mychannel", "hello world"}, "pmessage"},
{"Subscribe", []string{"subscribe", "mychannel", "1"}, "subscribe"},
{"Unsubscribe", []string{"unsubscribe", "mychannel", "0"}, "unsubscribe"},
}
for _, tc := range realisticCases {
t.Run(tc.name, func(t *testing.T) {
buf := createPushNotificationWithArgs(tc.notification[0], tc.notification[1:]...)
reader := createReaderWithPrimedBuffer(buf)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error for %s: %v", tc.name, err)
}
if name != tc.expected {
t.Errorf("Expected '%s', got '%s'", tc.expected, name)
}
})
}
})
t.Run("SpecialCharactersInName", func(t *testing.T) {
specialCases := []struct {
name string
notification string
}{
{"WithUnderscore", "test_notification"},
{"WithDash", "test-notification"},
{"WithNumbers", "test123"},
{"WithDots", "test.notification"},
{"WithColon", "test:notification"},
{"WithSlash", "test/notification"},
{"MixedCase", "TestNotification"},
{"AllCaps", "TESTNOTIFICATION"},
{"AllLower", "testnotification"},
{"Unicode", "tëst"},
}
for _, tc := range specialCases {
t.Run(tc.name, func(t *testing.T) {
buf := createValidPushNotification(tc.notification, "data")
reader := createReaderWithPrimedBuffer(buf)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error for '%s': %v", tc.notification, err)
}
if name != tc.notification {
t.Errorf("Expected '%s', got '%s'", tc.notification, name)
}
})
}
})
t.Run("IdempotentPeek", func(t *testing.T) {
// Test that multiple peeks return the same result
buf := createValidPushNotification("MOVING", "data")
reader := createReaderWithPrimedBuffer(buf)
// First peek
name1, err1 := reader.PeekPushNotificationName()
if err1 != nil {
t.Errorf("First PeekPushNotificationName should not error: %v", err1)
}
// Second peek should return the same result
name2, err2 := reader.PeekPushNotificationName()
if err2 != nil {
t.Errorf("Second PeekPushNotificationName should not error: %v", err2)
}
if name1 != name2 {
t.Errorf("Peek should be idempotent: first='%s', second='%s'", name1, name2)
}
if name1 != "MOVING" {
t.Errorf("Expected 'MOVING', got '%s'", name1)
}
})
}
// TestPeekPushNotificationNamePerformance tests performance characteristics
func TestPeekPushNotificationNamePerformance(t *testing.T) {
t.Run("RepeatedCalls", func(t *testing.T) {
// Test that repeated calls work correctly
buf := createValidPushNotification("TEST", "data")
reader := createReaderWithPrimedBuffer(buf)
// Call multiple times
for i := 0; i < 10; i++ {
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error on call %d: %v", i, err)
}
if name != "TEST" {
t.Errorf("Expected 'TEST' on call %d, got '%s'", i, name)
}
}
})
t.Run("LargeNotifications", func(t *testing.T) {
// Test with large notification data
largeData := strings.Repeat("x", 1000)
buf := createValidPushNotification("LARGE", largeData)
reader := createReaderWithPrimedBuffer(buf)
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error for large notification: %v", err)
}
if name != "LARGE" {
t.Errorf("Expected 'LARGE', got '%s'", name)
}
})
}
// TestPeekPushNotificationNameBehavior documents the method's behavior
func TestPeekPushNotificationNameBehavior(t *testing.T) {
t.Run("MethodBehavior", func(t *testing.T) {
// Test that the method works as intended:
// 1. Peek at the buffer without consuming it
// 2. Detect push notifications (RESP type '>')
// 3. Extract the notification name from the first element
// 4. Return the name for filtering decisions
buf := createValidPushNotification("MOVING", "slot_data")
reader := createReaderWithPrimedBuffer(buf)
// Peek should not consume the buffer
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error: %v", err)
}
if name != "MOVING" {
t.Errorf("Expected 'MOVING', got '%s'", name)
}
// Buffer should still be available for normal reading
replyType, err := reader.PeekReplyType()
if err != nil {
t.Errorf("PeekReplyType should work after PeekPushNotificationName: %v", err)
}
if replyType != RespPush {
t.Errorf("Expected RespPush, got %v", replyType)
}
})
t.Run("BufferNotConsumed", func(t *testing.T) {
// Verify that peeking doesn't consume the buffer
buf := createValidPushNotification("TEST", "data")
originalData := buf.Bytes()
reader := createReaderWithPrimedBuffer(buf)
// Peek the notification name
name, err := reader.PeekPushNotificationName()
if err != nil {
t.Errorf("PeekPushNotificationName should not error: %v", err)
}
if name != "TEST" {
t.Errorf("Expected 'TEST', got '%s'", name)
}
// Read the actual notification
reply, err := reader.ReadReply()
if err != nil {
t.Errorf("ReadReply should work after peek: %v", err)
}
// Verify we got the complete notification
if replySlice, ok := reply.([]interface{}); ok {
if len(replySlice) != 2 {
t.Errorf("Expected 2 elements, got %d", len(replySlice))
}
if replySlice[0] != "TEST" {
t.Errorf("Expected 'TEST', got %v", replySlice[0])
}
} else {
t.Errorf("Expected slice reply, got %T", reply)
}
// Verify buffer was properly consumed
if buf.Len() != 0 {
t.Errorf("Buffer should be empty after reading, but has %d bytes: %q", buf.Len(), buf.Bytes())
}
t.Logf("Original buffer size: %d bytes", len(originalData))
t.Logf("Successfully peeked and then read complete notification")
})
t.Run("ImplementationSuccess", func(t *testing.T) {
// Document that the implementation is now working correctly
t.Log("PeekPushNotificationName implementation status:")
t.Log("1. ✅ Correctly parses RESP3 push notifications")
t.Log("2. ✅ Extracts notification names properly")
t.Log("3. ✅ Handles buffer peeking without consumption")
t.Log("4. ✅ Works with various notification types")
t.Log("5. ✅ Supports empty notification names")
t.Log("")
t.Log("RESP3 format parsing:")
t.Log(">2\\r\\n$6\\r\\nMOVING\\r\\n$4\\r\\ndata\\r\\n")
t.Log("✅ Correctly identifies push notification marker (>)")
t.Log("✅ Skips array length (2)")
t.Log("✅ Parses string marker ($) and length (6)")
t.Log("✅ Extracts notification name (MOVING)")
t.Log("✅ Returns name without consuming buffer")
t.Log("")
t.Log("Note: Buffer must be primed with a peek operation first")
})
}

View File

@@ -99,6 +99,92 @@ func (r *Reader) PeekReplyType() (byte, error) {
return b[0], nil
}
func (r *Reader) PeekPushNotificationName() (string, error) {
// "prime" the buffer by peeking at the next byte
c, err := r.Peek(1)
if err != nil {
return "", err
}
if c[0] != RespPush {
return "", fmt.Errorf("redis: can't peek push notification name, next reply is not a push notification")
}
// peek 36 bytes at most, should be enough to read the push notification name
toPeek := 36
buffered := r.Buffered()
if buffered == 0 {
return "", fmt.Errorf("redis: can't peek push notification name, no data available")
}
if buffered < toPeek {
toPeek = buffered
}
buf, err := r.rd.Peek(toPeek)
if err != nil {
return "", err
}
if buf[0] != RespPush {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
if len(buf) < 3 {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
// remove push notification type
buf = buf[1:]
// remove first line - e.g. >2\r\n
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
buf = buf[i+2:]
break
} else {
if buf[i] < '0' || buf[i] > '9' {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
}
}
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
// next line should be a bulk string ($<length>\r\n<name>\r\n) or a simple string (+<name>\r\n)
// carrying the push notification name
if buf[0] != RespString && buf[0] != RespStatus {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
typeOfName := buf[0]
// remove the type of the push notification name
buf = buf[1:]
if typeOfName == RespString {
// remove the length of the string
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
buf = buf[i+2:]
break
} else {
if buf[i] < '0' || buf[i] > '9' {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
}
}
}
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
// keep only the notification name
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
buf = buf[:i]
break
}
}
return util.BytesToString(buf), nil
}
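A hedged usage sketch from outside the package (the tests elsewhere in this diff also prime the underlying bufio reader with rd.Peek(1) first; the Peek(1) at the top of this method should have the same effect):
```go
buf := bytes.NewBufferString(">2\r\n$6\r\nMOVING\r\n$4\r\ndata\r\n")
r := proto.NewReader(buf)

name, err := r.PeekPushNotificationName() // expected: "MOVING", nil
_, _ = name, err
// Only a peek: a subsequent ReadReply() still returns the complete push frame.
```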
// ReadLine Return a valid reply, it will check the protocol or redis error,
// and discard the attribute type.
func (r *Reader) ReadLine() ([]byte, error) {

3
internal/redis.go Normal file
View File

@@ -0,0 +1,3 @@
package internal
const RedisNull = "<nil>"

View File

@@ -28,3 +28,14 @@ func MustParseFloat(s string) float64 {
}
return f
}
// SafeIntToInt32 safely converts an int to int32, returning an error if overflow would occur.
func SafeIntToInt32(value int, fieldName string) (int32, error) {
if value > math.MaxInt32 {
return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32)
}
if value < math.MinInt32 {
return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32)
}
return int32(value), nil
}

17
internal/util/math.go Normal file
View File

@@ -0,0 +1,17 @@
package util
// Max returns the maximum of two integers
func Max(a, b int) int {
if a > b {
return a
}
return b
}
// Min returns the minimum of two integers
func Min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@@ -16,6 +16,8 @@ import (
. "github.com/bsm/gomega"
)
var ctx = context.TODO()
var _ = Describe("newClusterState", func() {
var state *clusterState

91
logging/logging.go Normal file
View File

@@ -0,0 +1,91 @@
// Package logging provides logging level constants and utilities for the go-redis library.
// This package centralizes logging configuration to ensure consistency across all components.
package logging
import (
"context"
"fmt"
"strings"
"github.com/redis/go-redis/v9/internal"
)
type LogLevelT = internal.LogLevelT
const (
LogLevelError = internal.LogLevelError
LogLevelWarn = internal.LogLevelWarn
LogLevelInfo = internal.LogLevelInfo
LogLevelDebug = internal.LogLevelDebug
)
// VoidLogger is a logger that does nothing.
// Used to disable logging and thus speed up the library.
type VoidLogger struct{}
func (v *VoidLogger) Printf(_ context.Context, _ string, _ ...interface{}) {
// do nothing
}
// Disable disables logging by setting the internal logger to a void logger.
// This can be used to speed up the library if logging is not needed.
// It will override any custom logger that was set before and set the VoidLogger.
func Disable() {
internal.Logger = &VoidLogger{}
}
// Enable enables logging by setting the internal logger to the default logger.
// This is the default behavior.
// You can use redis.SetLogger to set a custom logger.
//
// NOTE: This function is not thread-safe.
// It will override any custom logger that was set before and set the DefaultLogger.
func Enable() {
internal.Logger = internal.NewDefaultLogger()
}
// SetLogLevel sets the log level for the library.
func SetLogLevel(logLevel LogLevelT) {
internal.LogLevel = logLevel
}
// NewBlacklistLogger returns a new logger that filters out messages containing any of the substrings.
// This can be used to filter out messages containing sensitive information.
func NewBlacklistLogger(substr []string) internal.Logging {
l := internal.NewDefaultLogger()
return &filterLogger{logger: l, substr: substr, blacklist: true}
}
// NewWhitelistLogger returns a new logger that only logs messages containing any of the substrings.
// This can be used to only log messages related to specific commands or patterns.
func NewWhitelistLogger(substr []string) internal.Logging {
l := internal.NewDefaultLogger()
return &filterLogger{logger: l, substr: substr, blacklist: false}
}
type filterLogger struct {
logger internal.Logging
blacklist bool
substr []string
}
func (l *filterLogger) Printf(ctx context.Context, format string, v ...interface{}) {
msg := fmt.Sprintf(format, v...)
found := false
for _, substr := range l.substr {
if strings.Contains(msg, substr) {
found = true
if l.blacklist {
return
}
}
}
// whitelist, only log if one of the substrings is present
if !l.blacklist && !found {
return
}
if l.logger != nil {
l.logger.Printf(ctx, format, v...)
return
}
}
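A hedged usage sketch (redis.SetLogger is the usual entry point for installing a custom logger; the substrings are examples):
```go
// Silence all library logging.
logging.Disable()

// Or keep logging, raise the threshold, and drop lines that mention AUTH.
logging.SetLogLevel(logging.LogLevelWarn)
redis.SetLogger(logging.NewBlacklistLogger([]string{"AUTH"}))
```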

59
logging/logging_test.go Normal file
View File

@@ -0,0 +1,59 @@
package logging
import "testing"
func TestLogLevel_String(t *testing.T) {
tests := []struct {
level LogLevelT
expected string
}{
{LogLevelError, "ERROR"},
{LogLevelWarn, "WARN"},
{LogLevelInfo, "INFO"},
{LogLevelDebug, "DEBUG"},
{LogLevelT(99), "UNKNOWN"},
}
for _, test := range tests {
if got := test.level.String(); got != test.expected {
t.Errorf("LogLevel(%d).String() = %q, want %q", test.level, got, test.expected)
}
}
}
func TestLogLevel_IsValid(t *testing.T) {
tests := []struct {
level LogLevelT
expected bool
}{
{LogLevelError, true},
{LogLevelWarn, true},
{LogLevelInfo, true},
{LogLevelDebug, true},
{LogLevelT(-1), false},
{LogLevelT(4), false},
{LogLevelT(99), false},
}
for _, test := range tests {
if got := test.level.IsValid(); got != test.expected {
t.Errorf("LogLevel(%d).IsValid() = %v, want %v", test.level, got, test.expected)
}
}
}
func TestLogLevelConstants(t *testing.T) {
// Test that constants have expected values
if LogLevelError != 0 {
t.Errorf("LogLevelError = %d, want 0", LogLevelError)
}
if LogLevelWarn != 1 {
t.Errorf("LogLevelWarn = %d, want 1", LogLevelWarn)
}
if LogLevelInfo != 2 {
t.Errorf("LogLevelInfo = %d, want 2", LogLevelInfo)
}
if LogLevelDebug != 3 {
t.Errorf("LogLevelDebug = %d, want 3", LogLevelDebug)
}
}


@@ -13,6 +13,7 @@ import (
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/logging"
)
const (
@@ -102,6 +103,7 @@ var _ = BeforeSuite(func() {
fmt.Printf("RCEDocker: %v\n", RCEDocker)
fmt.Printf("REDIS_VERSION: %.1f\n", RedisVersion)
fmt.Printf("CLIENT_LIBS_TEST_IMAGE: %v\n", os.Getenv("CLIENT_LIBS_TEST_IMAGE"))
logging.Disable()
if RedisVersion < 7.0 || RedisVersion > 9 {
panic("incorrect or not supported redis version")


@@ -0,0 +1,100 @@
# Maintenance Notifications
Seamless Redis connection handoffs during cluster maintenance operations without dropping connections.
## ⚠️ **Important Note**
**Maintenance notifications are currently supported only in standalone Redis clients.** Cluster clients (ClusterClient, FailoverClient, etc.) do not yet support this functionality.
## Quick Start
```go
client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Protocol: 3, // RESP3 required
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
},
})
```
## Modes
- **`ModeDisabled`** - Maintenance notifications disabled
- **`ModeEnabled`** - Forcefully enabled (fails if the server doesn't support the feature)
- **`ModeAuto`** - Auto-detect server support (default)
## Configuration
```go
&maintnotifications.Config{
Mode: maintnotifications.ModeAuto,
EndpointType: maintnotifications.EndpointTypeAuto,
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxHandoffRetries: 3,
MaxWorkers: 0, // Auto-calculated
HandoffQueueSize: 0, // Auto-calculated
PostHandoffRelaxedDuration: 0, // 2 * RelaxedTimeout
}
```
### Endpoint Types
- **`EndpointTypeAuto`** - Auto-detect based on connection (default)
- **`EndpointTypeInternalIP`** - Internal IP address
- **`EndpointTypeInternalFQDN`** - Internal FQDN
- **`EndpointTypeExternalIP`** - External IP address
- **`EndpointTypeExternalFQDN`** - External FQDN
- **`EndpointTypeNone`** - No endpoint (reconnect with current config)
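If auto-detection does not suit a deployment, the endpoint type can be pinned explicitly. A minimal sketch (the chosen values are illustrative only):

```go
cfg := &maintnotifications.Config{
    Mode:         maintnotifications.ModeAuto,
    EndpointType: maintnotifications.EndpointTypeInternalFQDN,
}
```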
### Auto-Scaling
**Workers**: `min(PoolSize/2, max(10, PoolSize/3))` when auto-calculated
**Queue**: `max(20×Workers, PoolSize)` capped by `MaxActiveConns+1` or `5×PoolSize`
**Examples:**
- Pool 100: 33 workers, 660 queue (capped at 500)
- Pool 100 + MaxActiveConns 150: 33 workers, 151 queue
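The formulas above can be checked with a small helper. This is an illustrative sketch of the documented auto-scaling rules, not an exported go-redis function (it uses the Go 1.21+ built-in `min`/`max`):

```go
// autoScale mirrors the documented defaults for workers and handoff queue size.
func autoScale(poolSize, maxActiveConns int) (workers, queue int) {
	workers = min(poolSize/2, max(10, poolSize/3))
	queue = max(20*workers, poolSize)
	queueCap := 5 * poolSize
	if maxActiveConns > 0 {
		queueCap = maxActiveConns + 1
	}
	return workers, min(queue, queueCap)
}

// autoScale(100, 0)   => 33 workers, 500 queue (660 capped at 5×PoolSize)
// autoScale(100, 150) => 33 workers, 151 queue (capped at MaxActiveConns+1)
```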
## How It Works
1. Redis sends push notifications about cluster maintenance operations
2. Client creates new connections to updated endpoints
3. Active operations transfer to new connections
4. Old connections close gracefully
## Supported Notifications
- `MOVING` - Slot moving to new node
- `MIGRATING` - Slot in migration state
- `MIGRATED` - Migration completed
- `FAILING_OVER` - Node failing over
- `FAILED_OVER` - Failover completed
## Hooks (Optional)
Monitor and customize maintenance notification operations:
```go
type NotificationHook interface {
PreHook(ctx, notificationCtx, notificationType, notification) ([]interface{}, bool)
PostHook(ctx, notificationCtx, notificationType, notification, result)
}
// Add custom hook
manager.AddNotificationHook(&MyHook{})
```
### Metrics Hook Example
```go
// Create metrics hook
metricsHook := maintnotifications.NewMetricsHook()
manager.AddNotificationHook(metricsHook)
// Access collected metrics
metrics := metricsHook.GetMetrics()
fmt.Printf("Notification counts: %v\n", metrics["notification_counts"])
fmt.Printf("Processing times: %v\n", metrics["processing_times"])
fmt.Printf("Error counts: %v\n", metrics["error_counts"])
```


@@ -0,0 +1,353 @@
package maintnotifications
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
)
// CircuitBreakerState represents the state of a circuit breaker
type CircuitBreakerState int32
const (
// CircuitBreakerClosed - normal operation, requests allowed
CircuitBreakerClosed CircuitBreakerState = iota
// CircuitBreakerOpen - failing fast, requests rejected
CircuitBreakerOpen
// CircuitBreakerHalfOpen - testing if service recovered
CircuitBreakerHalfOpen
)
func (s CircuitBreakerState) String() string {
switch s {
case CircuitBreakerClosed:
return "closed"
case CircuitBreakerOpen:
return "open"
case CircuitBreakerHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling
type CircuitBreaker struct {
// Configuration
failureThreshold int // Number of failures before opening
resetTimeout time.Duration // How long to stay open before testing
maxRequests int // Max requests allowed in half-open state
// State tracking (atomic for lock-free access)
state atomic.Int32 // CircuitBreakerState
failures atomic.Int64 // Current failure count
successes atomic.Int64 // Success count in half-open state
requests atomic.Int64 // Request count in half-open state
lastFailureTime atomic.Int64 // Unix timestamp of last failure
lastSuccessTime atomic.Int64 // Unix timestamp of last success
// Endpoint identification
endpoint string
config *Config
}
// newCircuitBreaker creates a new circuit breaker for an endpoint
func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker {
// Use configuration values with sensible defaults
failureThreshold := 5
resetTimeout := 60 * time.Second
maxRequests := 3
if config != nil {
failureThreshold = config.CircuitBreakerFailureThreshold
resetTimeout = config.CircuitBreakerResetTimeout
maxRequests = config.CircuitBreakerMaxRequests
}
return &CircuitBreaker{
failureThreshold: failureThreshold,
resetTimeout: resetTimeout,
maxRequests: maxRequests,
endpoint: endpoint,
config: config,
state: atomic.Int32{}, // Defaults to CircuitBreakerClosed (0)
}
}
// IsOpen returns true if the circuit breaker is open (rejecting requests)
func (cb *CircuitBreaker) IsOpen() bool {
state := CircuitBreakerState(cb.state.Load())
return state == CircuitBreakerOpen
}
// shouldAttemptReset checks if enough time has passed to attempt reset
func (cb *CircuitBreaker) shouldAttemptReset() bool {
lastFailure := time.Unix(cb.lastFailureTime.Load(), 0)
return time.Since(lastFailure) >= cb.resetTimeout
}
// Execute runs the given function with circuit breaker protection
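//
// Illustrative usage sketch (config and dialEndpoint are placeholders, not part of this package):
//
//	cb := newCircuitBreaker("redis-1:6379", config)
//	err := cb.Execute(func() error { return dialEndpoint() })
//	if err == ErrCircuitBreakerOpen {
//		// The endpoint is failing fast; back off or try another endpoint.
//	}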
func (cb *CircuitBreaker) Execute(fn func() error) error {
// Single atomic state load for consistency
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerOpen:
if cb.shouldAttemptReset() {
// Attempt transition to half-open
if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
cb.requests.Store(0)
cb.successes.Store(0)
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint))
}
// Fall through to half-open logic
} else {
return ErrCircuitBreakerOpen
}
} else {
return ErrCircuitBreakerOpen
}
fallthrough
case CircuitBreakerHalfOpen:
requests := cb.requests.Add(1)
if requests > int64(cb.maxRequests) {
cb.requests.Add(-1) // Revert the increment
return ErrCircuitBreakerOpen
}
}
// Execute the function with consistent state
err := fn()
if err != nil {
cb.recordFailure()
return err
}
cb.recordSuccess()
return nil
}
// recordFailure records a failure and potentially opens the circuit
func (cb *CircuitBreaker) recordFailure() {
cb.lastFailureTime.Store(time.Now().Unix())
failures := cb.failures.Add(1)
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerClosed:
if failures >= int64(cb.failureThreshold) {
if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) {
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures))
}
}
}
case CircuitBreakerHalfOpen:
// Any failure in half-open state immediately opens the circuit
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) {
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint))
}
}
}
}
// recordSuccess records a success and potentially closes the circuit
func (cb *CircuitBreaker) recordSuccess() {
cb.lastSuccessTime.Store(time.Now().Unix())
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerClosed:
// Reset failure count on success in closed state
cb.failures.Store(0)
case CircuitBreakerHalfOpen:
successes := cb.successes.Add(1)
// If we've had enough successful requests, close the circuit
if successes >= int64(cb.maxRequests) {
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
cb.failures.Store(0)
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes))
}
}
}
}
}
// GetState returns the current state of the circuit breaker
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
return CircuitBreakerState(cb.state.Load())
}
// GetStats returns current statistics for monitoring
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
return CircuitBreakerStats{
Endpoint: cb.endpoint,
State: cb.GetState(),
Failures: cb.failures.Load(),
Successes: cb.successes.Load(),
Requests: cb.requests.Load(),
LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0),
LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0),
}
}
// CircuitBreakerStats provides statistics about a circuit breaker
type CircuitBreakerStats struct {
Endpoint string
State CircuitBreakerState
Failures int64
Successes int64
Requests int64
LastFailureTime time.Time
LastSuccessTime time.Time
}
// CircuitBreakerEntry wraps a circuit breaker with access tracking
type CircuitBreakerEntry struct {
breaker *CircuitBreaker
lastAccess atomic.Int64 // Unix timestamp
created time.Time
}
// CircuitBreakerManager manages circuit breakers for multiple endpoints
type CircuitBreakerManager struct {
breakers sync.Map // map[string]*CircuitBreakerEntry
config *Config
cleanupStop chan struct{}
cleanupMu sync.Mutex
lastCleanup atomic.Int64 // Unix timestamp
}
// newCircuitBreakerManager creates a new circuit breaker manager
func newCircuitBreakerManager(config *Config) *CircuitBreakerManager {
cbm := &CircuitBreakerManager{
config: config,
cleanupStop: make(chan struct{}),
}
cbm.lastCleanup.Store(time.Now().Unix())
// Start background cleanup goroutine
go cbm.cleanupLoop()
return cbm
}
// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary
func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker {
now := time.Now().Unix()
if entry, ok := cbm.breakers.Load(endpoint); ok {
cbEntry := entry.(*CircuitBreakerEntry)
cbEntry.lastAccess.Store(now)
return cbEntry.breaker
}
// Create new circuit breaker with metadata
newBreaker := newCircuitBreaker(endpoint, cbm.config)
newEntry := &CircuitBreakerEntry{
breaker: newBreaker,
created: time.Now(),
}
newEntry.lastAccess.Store(now)
actual, _ := cbm.breakers.LoadOrStore(endpoint, newEntry)
return actual.(*CircuitBreakerEntry).breaker
}
// GetAllStats returns statistics for all circuit breakers
func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats {
var stats []CircuitBreakerStats
cbm.breakers.Range(func(key, value interface{}) bool {
entry := value.(*CircuitBreakerEntry)
stats = append(stats, entry.breaker.GetStats())
return true
})
return stats
}
// cleanupLoop runs background cleanup of unused circuit breakers
func (cbm *CircuitBreakerManager) cleanupLoop() {
ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
defer ticker.Stop()
for {
select {
case <-ticker.C:
cbm.cleanup()
case <-cbm.cleanupStop:
return
}
}
}
// cleanup removes circuit breakers that haven't been accessed recently
func (cbm *CircuitBreakerManager) cleanup() {
// Prevent concurrent cleanups
if !cbm.cleanupMu.TryLock() {
return
}
defer cbm.cleanupMu.Unlock()
now := time.Now()
cutoff := now.Add(-30 * time.Minute).Unix() // 30 minute TTL
var toDelete []string
count := 0
cbm.breakers.Range(func(key, value interface{}) bool {
endpoint := key.(string)
entry := value.(*CircuitBreakerEntry)
count++
// Remove if not accessed recently
if entry.lastAccess.Load() < cutoff {
toDelete = append(toDelete, endpoint)
}
return true
})
// Delete expired entries
for _, endpoint := range toDelete {
cbm.breakers.Delete(endpoint)
}
// Log cleanup results
if len(toDelete) > 0 && internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count))
}
cbm.lastCleanup.Store(now.Unix())
}
// Shutdown stops the cleanup goroutine
func (cbm *CircuitBreakerManager) Shutdown() {
close(cbm.cleanupStop)
}
// Reset resets all circuit breakers (useful for testing)
func (cbm *CircuitBreakerManager) Reset() {
cbm.breakers.Range(func(key, value interface{}) bool {
entry := value.(*CircuitBreakerEntry)
breaker := entry.breaker
breaker.state.Store(int32(CircuitBreakerClosed))
breaker.failures.Store(0)
breaker.successes.Store(0)
breaker.requests.Store(0)
breaker.lastFailureTime.Store(0)
breaker.lastSuccessTime.Store(0)
return true
})
}


@@ -0,0 +1,348 @@
package maintnotifications
import (
"errors"
"testing"
"time"
)
func TestCircuitBreaker(t *testing.T) {
config := &Config{
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
}
t.Run("InitialState", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
if cb.IsOpen() {
t.Error("Circuit breaker should start in closed state")
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("SuccessfulExecution", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
err := cb.Execute(func() error {
return nil // Success
})
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("FailureThreshold", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Fail 4 times (below threshold of 5)
for i := 0; i < 4; i++ {
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Circuit should still be closed after %d failures", i+1)
}
}
// 5th failure should open the circuit
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
}
})
t.Run("OpenCircuitFailsFast", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Now it should fail fast
err := cb.Execute(func() error {
t.Error("Function should not be called when circuit is open")
return nil
})
if err != ErrCircuitBreakerOpen {
t.Errorf("Expected ErrCircuitBreakerOpen, got %v", err)
}
})
t.Run("HalfOpenTransition", func(t *testing.T) {
testConfig := &Config{
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 100 * time.Millisecond, // Short timeout for testing
CircuitBreakerMaxRequests: 3,
}
cb := newCircuitBreaker("test-endpoint:6379", testConfig)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
if cb.GetState() != CircuitBreakerOpen {
t.Error("Circuit should be open")
}
// Wait for reset timeout
time.Sleep(150 * time.Millisecond)
// Next call should transition to half-open
executed := false
err := cb.Execute(func() error {
executed = true
return nil // Success
})
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if !executed {
t.Error("Function should have been executed in half-open state")
}
})
t.Run("HalfOpenToClosedTransition", func(t *testing.T) {
testConfig := &Config{
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 50 * time.Millisecond,
CircuitBreakerMaxRequests: 3,
}
cb := newCircuitBreaker("test-endpoint:6379", testConfig)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Wait for reset timeout
time.Sleep(100 * time.Millisecond)
// Execute successful requests in half-open state
for i := 0; i < 3; i++ {
err := cb.Execute(func() error {
return nil // Success
})
if err != nil {
t.Errorf("Expected no error on attempt %d, got %v", i+1, err)
}
}
// Circuit should now be closed
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("HalfOpenToOpenOnFailure", func(t *testing.T) {
testConfig := &Config{
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 50 * time.Millisecond,
CircuitBreakerMaxRequests: 3,
}
cb := newCircuitBreaker("test-endpoint:6379", testConfig)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Wait for reset timeout
time.Sleep(100 * time.Millisecond)
// First request in half-open state fails
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
// Circuit should be open again
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
}
})
t.Run("Stats", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Execute some operations
cb.Execute(func() error { return testError }) // Failure
cb.Execute(func() error { return testError }) // Failure
stats := cb.GetStats()
if stats.Endpoint != "test-endpoint:6379" {
t.Errorf("Expected endpoint 'test-endpoint:6379', got %s", stats.Endpoint)
}
if stats.Failures != 2 {
t.Errorf("Expected 2 failures, got %d", stats.Failures)
}
if stats.State != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, stats.State)
}
// Test that success resets failure count
cb.Execute(func() error { return nil }) // Success
stats = cb.GetStats()
if stats.Failures != 0 {
t.Errorf("Expected 0 failures after success, got %d", stats.Failures)
}
})
}
func TestCircuitBreakerManager(t *testing.T) {
config := &Config{
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
}
t.Run("GetCircuitBreaker", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
cb1 := manager.GetCircuitBreaker("endpoint1:6379")
cb2 := manager.GetCircuitBreaker("endpoint2:6379")
cb3 := manager.GetCircuitBreaker("endpoint1:6379") // Same as cb1
if cb1 == cb2 {
t.Error("Different endpoints should have different circuit breakers")
}
if cb1 != cb3 {
t.Error("Same endpoint should return the same circuit breaker")
}
})
t.Run("GetAllStats", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
// Create circuit breakers for different endpoints
cb1 := manager.GetCircuitBreaker("endpoint1:6379")
cb2 := manager.GetCircuitBreaker("endpoint2:6379")
// Execute some operations
cb1.Execute(func() error { return nil })
cb2.Execute(func() error { return errors.New("test error") })
stats := manager.GetAllStats()
if len(stats) != 2 {
t.Errorf("Expected 2 circuit breaker stats, got %d", len(stats))
}
// Check that we have stats for both endpoints
endpoints := make(map[string]bool)
for _, stat := range stats {
endpoints[stat.Endpoint] = true
}
if !endpoints["endpoint1:6379"] || !endpoints["endpoint2:6379"] {
t.Error("Missing stats for expected endpoints")
}
})
t.Run("Reset", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
testError := errors.New("test error")
cb := manager.GetCircuitBreaker("test-endpoint:6379")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
if cb.GetState() != CircuitBreakerOpen {
t.Error("Circuit should be open")
}
// Reset all circuit breakers
manager.Reset()
if cb.GetState() != CircuitBreakerClosed {
t.Error("Circuit should be closed after reset")
}
if cb.failures.Load() != 0 {
t.Error("Failure count should be reset to 0")
}
})
t.Run("ConfigurableParameters", func(t *testing.T) {
config := &Config{
CircuitBreakerFailureThreshold: 10,
CircuitBreakerResetTimeout: 30 * time.Second,
CircuitBreakerMaxRequests: 5,
}
cb := newCircuitBreaker("test-endpoint:6379", config)
// Test that configuration values are used
if cb.failureThreshold != 10 {
t.Errorf("Expected failureThreshold=10, got %d", cb.failureThreshold)
}
if cb.resetTimeout != 30*time.Second {
t.Errorf("Expected resetTimeout=30s, got %v", cb.resetTimeout)
}
if cb.maxRequests != 5 {
t.Errorf("Expected maxRequests=5, got %d", cb.maxRequests)
}
// Test that circuit opens after configured threshold
testError := errors.New("test error")
for i := 0; i < 9; i++ {
err := cb.Execute(func() error { return testError })
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Circuit should still be closed after %d failures", i+1)
}
}
// 10th failure should open the circuit
err := cb.Execute(func() error { return testError })
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
}
})
}


@@ -0,0 +1,458 @@
package maintnotifications
import (
"context"
"net"
"runtime"
"strings"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/util"
)
// Mode represents the maintenance notifications mode
type Mode string
// Constants for maintenance push notifications modes
const (
ModeDisabled Mode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
ModeEnabled Mode = "enabled" // Client forcefully sends command, interrupts connection on error
ModeAuto Mode = "auto" // Client tries to send command, disables feature on error
)
// IsValid returns true if the maintenance notifications mode is valid
func (m Mode) IsValid() bool {
switch m {
case ModeDisabled, ModeEnabled, ModeAuto:
return true
default:
return false
}
}
// String returns the string representation of the mode
func (m Mode) String() string {
return string(m)
}
// EndpointType represents the type of endpoint to request in MOVING notifications
type EndpointType string
// Constants for endpoint types
const (
EndpointTypeAuto EndpointType = "auto" // Auto-detect based on connection
EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address
EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN
EndpointTypeExternalIP EndpointType = "external-ip" // External IP address
EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN
EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config)
)
// IsValid returns true if the endpoint type is valid
func (e EndpointType) IsValid() bool {
switch e {
case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone:
return true
default:
return false
}
}
// String returns the string representation of the endpoint type
func (e EndpointType) String() string {
return string(e)
}
// Config provides configuration options for maintenance notifications
type Config struct {
// Mode controls how client maintenance notifications are handled.
// Valid values: ModeDisabled, ModeEnabled, ModeAuto
// Default: ModeAuto
Mode Mode
// EndpointType specifies the type of endpoint to request in MOVING notifications.
// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
// EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
// Default: EndpointTypeAuto
EndpointType EndpointType
// RelaxedTimeout is the concrete timeout value to use during
// MIGRATING/FAILING_OVER states to accommodate increased latency.
// This applies to both read and write timeouts.
// Default: 10 seconds
RelaxedTimeout time.Duration
// HandoffTimeout is the maximum time to wait for connection handoff to complete.
// If handoff takes longer than this, the old connection will be forcibly closed.
// Default: 15 seconds (matches server-side eviction timeout)
HandoffTimeout time.Duration
// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
// Workers are created on-demand and automatically cleaned up when idle.
// If zero, the value is auto-calculated from the pool size to handle bursts effectively.
// If explicitly set, a minimum of PoolSize/2 is enforced.
//
// Default: min(PoolSize/2, max(10, PoolSize/3)), Minimum when set: PoolSize/2
MaxWorkers int
// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
// If the queue is full, new handoff requests will be rejected.
// Scales with both worker count and pool size for better burst handling.
//
// Default: max(20×MaxWorkers, PoolSize), capped by MaxActiveConns+1 (if set) or 5×PoolSize
// When set: minimum 200, capped by MaxActiveConns+1 (if set) or 5×PoolSize
HandoffQueueSize int
// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
// after a handoff completes. This provides additional resilience during cluster transitions.
// Default: 2 * RelaxedTimeout
PostHandoffRelaxedDuration time.Duration
// Circuit breaker configuration for endpoint failure handling
// CircuitBreakerFailureThreshold is the number of failures before opening the circuit.
// Default: 5
CircuitBreakerFailureThreshold int
// CircuitBreakerResetTimeout is how long to wait before testing if the endpoint recovered.
// Default: 60 seconds
CircuitBreakerResetTimeout time.Duration
// CircuitBreakerMaxRequests is the maximum number of requests allowed in half-open state.
// Default: 3
CircuitBreakerMaxRequests int
// MaxHandoffRetries is the maximum number of times to retry a failed handoff.
// After this many retries, the connection will be removed from the pool.
// Default: 3
MaxHandoffRetries int
}
func (c *Config) IsEnabled() bool {
return c != nil && c.Mode != ModeDisabled
}
// DefaultConfig returns a Config with sensible defaults.
func DefaultConfig() *Config {
return &Config{
Mode: ModeAuto, // Enable by default for Redis Cloud
EndpointType: EndpointTypeAuto, // Auto-detect based on connection
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: 0, // Auto-calculated based on pool size
HandoffQueueSize: 0, // Auto-calculated based on max workers
PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
// Circuit breaker configuration
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
// Connection Handoff Configuration
MaxHandoffRetries: 3,
}
}
// Validate checks if the configuration is valid.
func (c *Config) Validate() error {
if c.RelaxedTimeout <= 0 {
return ErrInvalidRelaxedTimeout
}
if c.HandoffTimeout <= 0 {
return ErrInvalidHandoffTimeout
}
// Validate worker configuration
// Allow 0 for auto-calculation, but negative values are invalid
if c.MaxWorkers < 0 {
return ErrInvalidHandoffWorkers
}
// HandoffQueueSize validation - allow 0 for auto-calculation
if c.HandoffQueueSize < 0 {
return ErrInvalidHandoffQueueSize
}
if c.PostHandoffRelaxedDuration < 0 {
return ErrInvalidPostHandoffRelaxedDuration
}
// Circuit breaker validation
if c.CircuitBreakerFailureThreshold < 1 {
return ErrInvalidCircuitBreakerFailureThreshold
}
if c.CircuitBreakerResetTimeout < 0 {
return ErrInvalidCircuitBreakerResetTimeout
}
if c.CircuitBreakerMaxRequests < 1 {
return ErrInvalidCircuitBreakerMaxRequests
}
// Validate Mode (maintenance notifications mode)
if !c.Mode.IsValid() {
return ErrInvalidMaintNotifications
}
// Validate EndpointType
if !c.EndpointType.IsValid() {
return ErrInvalidEndpointType
}
// Validate configuration fields
if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
return ErrInvalidHandoffRetries
}
return nil
}
// ApplyDefaults applies default values to any zero-value fields in the configuration.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaults() *Config {
return c.ApplyDefaultsWithPoolSize(0)
}
// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
// using the provided pool size to calculate worker defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
return c.ApplyDefaultsWithPoolConfig(poolSize, 0)
}
// ApplyDefaultsWithPoolConfig applies default values to any zero-value fields in the configuration,
// using the provided pool size and max active connections to calculate worker and queue defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
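//
// Illustrative sketch of the resulting values (pool size 100, no MaxActiveConns cap):
//
//	cfg := (&Config{MaxWorkers: 60}).ApplyDefaultsWithPoolConfig(100, 0)
//	// cfg.MaxWorkers == 60 (explicit value kept, >= poolSize/2)
//	// cfg.HandoffQueueSize == 500 (max(20*60, 100) capped at 5*poolSize)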
func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *Config {
if c == nil {
return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
}
defaults := DefaultConfig()
result := &Config{}
// Apply defaults for enum fields (empty/zero means not set)
result.Mode = defaults.Mode
if c.Mode != "" {
result.Mode = c.Mode
}
result.EndpointType = defaults.EndpointType
if c.EndpointType != "" {
result.EndpointType = c.EndpointType
}
// Apply defaults for duration fields (zero means not set)
result.RelaxedTimeout = defaults.RelaxedTimeout
if c.RelaxedTimeout > 0 {
result.RelaxedTimeout = c.RelaxedTimeout
}
result.HandoffTimeout = defaults.HandoffTimeout
if c.HandoffTimeout > 0 {
result.HandoffTimeout = c.HandoffTimeout
}
// Copy worker configuration
result.MaxWorkers = c.MaxWorkers
// Apply worker defaults based on pool size
result.applyWorkerDefaults(poolSize)
// Apply queue size defaults with new scaling approach
// Default: max(20x workers, PoolSize), capped by maxActiveConns or 5x pool size
workerBasedSize := result.MaxWorkers * 20
poolBasedSize := poolSize
result.HandoffQueueSize = util.Max(workerBasedSize, poolBasedSize)
if c.HandoffQueueSize > 0 {
// When explicitly set: enforce minimum of 200
result.HandoffQueueSize = util.Max(200, c.HandoffQueueSize)
}
// Cap queue size: use maxActiveConns+1 if set, otherwise 5x pool size
var queueCap int
if maxActiveConns > 0 {
queueCap = maxActiveConns + 1
// Ensure queue cap is at least 2 for very small maxActiveConns
if queueCap < 2 {
queueCap = 2
}
} else {
queueCap = poolSize * 5
}
result.HandoffQueueSize = util.Min(result.HandoffQueueSize, queueCap)
// Ensure minimum queue size of 2 (fallback for very small pools)
if result.HandoffQueueSize < 2 {
result.HandoffQueueSize = 2
}
result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
if c.PostHandoffRelaxedDuration > 0 {
result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
}
// Apply defaults for configuration fields
result.MaxHandoffRetries = defaults.MaxHandoffRetries
if c.MaxHandoffRetries > 0 {
result.MaxHandoffRetries = c.MaxHandoffRetries
}
// Circuit breaker configuration
result.CircuitBreakerFailureThreshold = defaults.CircuitBreakerFailureThreshold
if c.CircuitBreakerFailureThreshold > 0 {
result.CircuitBreakerFailureThreshold = c.CircuitBreakerFailureThreshold
}
result.CircuitBreakerResetTimeout = defaults.CircuitBreakerResetTimeout
if c.CircuitBreakerResetTimeout > 0 {
result.CircuitBreakerResetTimeout = c.CircuitBreakerResetTimeout
}
result.CircuitBreakerMaxRequests = defaults.CircuitBreakerMaxRequests
if c.CircuitBreakerMaxRequests > 0 {
result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests
}
if internal.LogLevel.DebugOrAbove() {
internal.Logger.Printf(context.Background(), logs.DebugLoggingEnabled())
internal.Logger.Printf(context.Background(), logs.ConfigDebug(result))
}
return result
}
// Clone creates a deep copy of the configuration.
func (c *Config) Clone() *Config {
if c == nil {
return DefaultConfig()
}
return &Config{
Mode: c.Mode,
EndpointType: c.EndpointType,
RelaxedTimeout: c.RelaxedTimeout,
HandoffTimeout: c.HandoffTimeout,
MaxWorkers: c.MaxWorkers,
HandoffQueueSize: c.HandoffQueueSize,
PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
// Circuit breaker configuration
CircuitBreakerFailureThreshold: c.CircuitBreakerFailureThreshold,
CircuitBreakerResetTimeout: c.CircuitBreakerResetTimeout,
CircuitBreakerMaxRequests: c.CircuitBreakerMaxRequests,
// Configuration fields
MaxHandoffRetries: c.MaxHandoffRetries,
}
}
// applyWorkerDefaults calculates and applies worker defaults based on pool size
func (c *Config) applyWorkerDefaults(poolSize int) {
// Calculate defaults based on pool size
if poolSize <= 0 {
poolSize = 10 * runtime.GOMAXPROCS(0)
}
// When not set: min(poolSize/2, max(10, poolSize/3)) - balanced scaling approach
originalMaxWorkers := c.MaxWorkers
c.MaxWorkers = util.Min(poolSize/2, util.Max(10, poolSize/3))
if originalMaxWorkers != 0 {
// When explicitly set: max(poolSize/2, set_value) - ensure at least poolSize/2 workers
c.MaxWorkers = util.Max(poolSize/2, originalMaxWorkers)
}
// Ensure minimum of 1 worker (fallback for very small pools)
if c.MaxWorkers < 1 {
c.MaxWorkers = 1
}
}
// DetectEndpointType automatically detects the appropriate endpoint type
// based on the connection address and TLS configuration.
//
// For IP addresses:
// - If TLS is enabled: requests FQDN for proper certificate validation
// - If TLS is disabled: requests IP for better performance
//
// For hostnames:
// - If TLS is enabled: always requests FQDN for proper certificate validation
// - If TLS is disabled: requests IP for better performance
//
// Internal vs External detection:
// - For IPs: uses private IP range detection
// - For hostnames: uses heuristics based on common internal naming patterns
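//
// Illustrative results:
//
//	DetectEndpointType("10.0.0.5:6379", false)          // EndpointTypeInternalIP
//	DetectEndpointType("redis.example.com:6379", true)  // EndpointTypeExternalFQDN
//	DetectEndpointType("redis-1:6379", false)           // EndpointTypeInternalFQDN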
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
// Extract host from "host:port" format
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr // Assume no port
}
// Check if the host is an IP address or hostname
ip := net.ParseIP(host)
isIPAddress := ip != nil
var endpointType EndpointType
if isIPAddress {
// Address is an IP - determine if it's private or public
isPrivate := ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
if tlsEnabled {
// TLS with IP addresses - still prefer FQDN for certificate validation
if isPrivate {
endpointType = EndpointTypeInternalFQDN
} else {
endpointType = EndpointTypeExternalFQDN
}
} else {
// No TLS - can use IP addresses directly
if isPrivate {
endpointType = EndpointTypeInternalIP
} else {
endpointType = EndpointTypeExternalIP
}
}
} else {
// Address is a hostname
isInternalHostname := isInternalHostname(host)
if isInternalHostname {
endpointType = EndpointTypeInternalFQDN
} else {
endpointType = EndpointTypeExternalFQDN
}
}
return endpointType
}
// isInternalHostname determines if a hostname appears to be internal/private.
// This is a heuristic based on common naming patterns.
func isInternalHostname(hostname string) bool {
// Convert to lowercase for comparison
hostname = strings.ToLower(hostname)
// Common internal hostname patterns
internalPatterns := []string{
"localhost",
".local",
".internal",
".corp",
".lan",
".intranet",
".private",
}
// Check for exact match or suffix match
for _, pattern := range internalPatterns {
if hostname == pattern || strings.HasSuffix(hostname, pattern) {
return true
}
}
// Single-label hostnames (e.g., redis-1, db-server) are likely internal.
// If the hostname doesn't contain dots, treat it as internal.
if !strings.Contains(hostname, ".") {
return true
}
// Default to external for fully qualified domain names
return false
}


@@ -0,0 +1,481 @@
package maintnotifications
import (
"context"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/util"
)
func TestConfig(t *testing.T) {
t.Run("DefaultConfig", func(t *testing.T) {
config := DefaultConfig()
// MaxWorkers should be 0 in default config (auto-calculated)
if config.MaxWorkers != 0 {
t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers)
}
// HandoffQueueSize should be 0 in default config (auto-calculated)
if config.HandoffQueueSize != 0 {
t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize)
}
if config.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s, got %v", config.RelaxedTimeout)
}
// Test configuration fields have proper defaults
if config.MaxHandoffRetries != 3 {
t.Errorf("Expected MaxHandoffRetries to be 3, got %d", config.MaxHandoffRetries)
}
// Circuit breaker defaults
if config.CircuitBreakerFailureThreshold != 5 {
t.Errorf("Expected CircuitBreakerFailureThreshold=5, got %d", config.CircuitBreakerFailureThreshold)
}
if config.CircuitBreakerResetTimeout != 60*time.Second {
t.Errorf("Expected CircuitBreakerResetTimeout=60s, got %v", config.CircuitBreakerResetTimeout)
}
if config.CircuitBreakerMaxRequests != 3 {
t.Errorf("Expected CircuitBreakerMaxRequests=3, got %d", config.CircuitBreakerMaxRequests)
}
if config.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout)
}
if config.PostHandoffRelaxedDuration != 0 {
t.Errorf("Expected PostHandoffRelaxedDuration to be 0 (auto-calculated), got %v", config.PostHandoffRelaxedDuration)
}
// Test that defaults are applied correctly
configWithDefaults := config.ApplyDefaultsWithPoolSize(100)
if configWithDefaults.PostHandoffRelaxedDuration != 20*time.Second {
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout) after applying defaults, got %v", configWithDefaults.PostHandoffRelaxedDuration)
}
})
t.Run("ConfigValidation", func(t *testing.T) {
// Valid config with applied defaults
config := DefaultConfig().ApplyDefaults()
if err := config.Validate(); err != nil {
t.Errorf("Default config with applied defaults should be valid: %v", err)
}
// Invalid worker configuration (negative MaxWorkers)
config = &Config{
RelaxedTimeout: 30 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: -1, // This should be invalid
HandoffQueueSize: 100,
PostHandoffRelaxedDuration: 10 * time.Second,
MaxHandoffRetries: 3, // Add required field
}
if err := config.Validate(); err != ErrInvalidHandoffWorkers {
t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err)
}
// Invalid HandoffQueueSize
config = DefaultConfig().ApplyDefaults()
config.HandoffQueueSize = -1
if err := config.Validate(); err != ErrInvalidHandoffQueueSize {
t.Errorf("Expected ErrInvalidHandoffQueueSize, got %v", err)
}
// Invalid PostHandoffRelaxedDuration
config = DefaultConfig().ApplyDefaults()
config.PostHandoffRelaxedDuration = -1 * time.Second
if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration {
t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err)
}
})
t.Run("ConfigClone", func(t *testing.T) {
original := DefaultConfig()
original.MaxWorkers = 20
original.HandoffQueueSize = 200
cloned := original.Clone()
if cloned.MaxWorkers != 20 {
t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers)
}
if cloned.HandoffQueueSize != 200 {
t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize)
}
// Modify original to ensure clone is independent
original.MaxWorkers = 2
if cloned.MaxWorkers != 20 {
t.Error("Clone should be independent of original")
}
})
}
func TestApplyDefaults(t *testing.T) {
t.Run("NilConfig", func(t *testing.T) {
var config *Config
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// With nil config, should get default config with auto-calculated workers
if result.MaxWorkers <= 0 {
t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers)
}
// HandoffQueueSize should be auto-calculated with hybrid scaling
workerBasedSize := result.MaxWorkers * 20
poolSize := 100 // Default pool size used in ApplyDefaults
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
}
})
t.Run("PartialConfig", func(t *testing.T) {
config := &Config{
MaxWorkers: 60, // Set this field explicitly (> poolSize/2 = 50)
// Leave other fields as zero values
}
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Should keep the explicitly set values when > poolSize/2
if result.MaxWorkers != 60 {
t.Errorf("Expected MaxWorkers to be 60 (explicitly set), got %d", result.MaxWorkers)
}
// Should apply default for unset fields (auto-calculated queue size with hybrid scaling)
workerBasedSize := result.MaxWorkers * 20
poolSize := 100 // Default pool size used in ApplyDefaults
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
}
// Test explicit queue size capping by 5x pool size
configWithLargeQueue := &Config{
MaxWorkers: 5,
HandoffQueueSize: 1000, // Much larger than 5x pool size
}
resultCapped := configWithLargeQueue.ApplyDefaultsWithPoolSize(20) // Small pool size
expectedCap := 20 * 5 // 5x pool size = 100
if resultCapped.HandoffQueueSize != expectedCap {
t.Errorf("Expected HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedCap, resultCapped.HandoffQueueSize)
}
// Test explicit queue size minimum enforcement
configWithSmallQueue := &Config{
MaxWorkers: 5,
HandoffQueueSize: 10, // Below minimum of 200
}
resultMinimum := configWithSmallQueue.ApplyDefaultsWithPoolSize(100) // Large pool size
if resultMinimum.HandoffQueueSize != 200 {
t.Errorf("Expected HandoffQueueSize to be enforced minimum (200), got %d", resultMinimum.HandoffQueueSize)
}
// Test that large explicit values are capped by 5x pool size
configWithVeryLargeQueue := &Config{
MaxWorkers: 5,
HandoffQueueSize: 1000, // Much larger than 5x pool size
}
resultVeryLarge := configWithVeryLargeQueue.ApplyDefaultsWithPoolSize(100) // Pool size 100
expectedVeryLargeCap := 100 * 5 // 5x pool size = 500
if resultVeryLarge.HandoffQueueSize != expectedVeryLargeCap {
t.Errorf("Expected very large HandoffQueueSize to be capped by 5x pool size (%d), got %d", expectedVeryLargeCap, resultVeryLarge.HandoffQueueSize)
}
if result.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
}
if result.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", result.HandoffTimeout)
}
})
t.Run("ZeroValues", func(t *testing.T) {
config := &Config{
MaxWorkers: 0, // Zero value should get auto-calculated defaults
HandoffQueueSize: 0, // Zero value should get default
RelaxedTimeout: 0, // Zero value should get default
}
result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Zero values should get auto-calculated defaults
if result.MaxWorkers <= 0 {
t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers)
}
// HandoffQueueSize should be auto-calculated with hybrid scaling
workerBasedSize := result.MaxWorkers * 20
poolSize := 100 // Default pool size used in ApplyDefaults
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if result.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, result.HandoffQueueSize)
}
if result.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout)
}
})
}
func TestProcessorWithConfig(t *testing.T) {
t.Run("ProcessorUsesConfigValues", func(t *testing.T) {
config := &Config{
MaxWorkers: 5,
HandoffQueueSize: 50,
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 5 * time.Second,
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// The processor should be created successfully with custom config
if processor == nil {
t.Error("Processor should be created with custom config")
}
})
t.Run("ProcessorWithPartialConfig", func(t *testing.T) {
config := &Config{
MaxWorkers: 7, // Only set worker field
// Other fields will get defaults
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Should work with partial config (defaults applied)
if processor == nil {
t.Error("Processor should be created with partial config")
}
})
t.Run("ProcessorWithNilConfig", func(t *testing.T) {
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
// Should use default config when nil is passed
if processor == nil {
t.Error("Processor should be created with nil config (using defaults)")
}
})
}
func TestIntegrationWithApplyDefaults(t *testing.T) {
t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) {
// Create a partial config with only some fields set
partialConfig := &Config{
MaxWorkers: 15, // Custom value (>= 10 to test preservation)
// Other fields left as zero values - should get defaults
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
// Create processor - should apply defaults to missing fields
processor := NewPoolHook(baseDialer, "tcp", partialConfig, nil)
defer processor.Shutdown(context.Background())
// Processor should be created successfully
if processor == nil {
t.Error("Processor should be created with partial config")
}
// Test that the ApplyDefaults method worked correctly by creating the same config
// and applying defaults manually
expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing
// Should preserve custom values (when >= poolSize/2)
if expectedConfig.MaxWorkers != 50 { // max(poolSize/2, 15) = max(50, 15) = 50
t.Errorf("Expected MaxWorkers to be 50, got %d", expectedConfig.MaxWorkers)
}
// Should apply defaults for missing fields (auto-calculated queue size with hybrid scaling)
workerBasedSize := expectedConfig.MaxWorkers * 20
poolSize := 100 // Default pool size used in ApplyDefaults
poolBasedSize := poolSize
expectedQueueSize := util.Max(workerBasedSize, poolBasedSize)
expectedQueueSize = util.Min(expectedQueueSize, poolSize*5) // Cap by 5x pool size
if expectedConfig.HandoffQueueSize != expectedQueueSize {
t.Errorf("Expected HandoffQueueSize to be %d (max(20*MaxWorkers=%d, poolSize=%d) capped by 5*poolSize=%d), got %d",
expectedQueueSize, workerBasedSize, poolBasedSize, poolSize*5, expectedConfig.HandoffQueueSize)
}
// Test that queue size is always capped by 5x pool size
if expectedConfig.HandoffQueueSize > poolSize*5 {
t.Errorf("HandoffQueueSize (%d) should never exceed 5x pool size (%d)",
expectedConfig.HandoffQueueSize, poolSize*5)
}
if expectedConfig.RelaxedTimeout != 10*time.Second {
t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", expectedConfig.RelaxedTimeout)
}
if expectedConfig.HandoffTimeout != 15*time.Second {
t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout)
}
if expectedConfig.PostHandoffRelaxedDuration != 20*time.Second {
t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout), got %v", expectedConfig.PostHandoffRelaxedDuration)
}
})
}
func TestEnhancedConfigValidation(t *testing.T) {
t.Run("ValidateFields", func(t *testing.T) {
config := DefaultConfig()
config = config.ApplyDefaultsWithPoolSize(100) // Apply defaults with pool size 100
// Should pass validation with default values
if err := config.Validate(); err != nil {
t.Errorf("Default config should be valid, got error: %v", err)
}
// Test invalid MaxHandoffRetries
config.MaxHandoffRetries = 0
if err := config.Validate(); err == nil {
t.Error("Expected validation error for MaxHandoffRetries = 0")
}
config.MaxHandoffRetries = 11
if err := config.Validate(); err == nil {
t.Error("Expected validation error for MaxHandoffRetries = 11")
}
config.MaxHandoffRetries = 3 // Reset to valid value
// Test circuit breaker validation
config.CircuitBreakerFailureThreshold = 0
if err := config.Validate(); err != ErrInvalidCircuitBreakerFailureThreshold {
t.Errorf("Expected ErrInvalidCircuitBreakerFailureThreshold, got %v", err)
}
config.CircuitBreakerFailureThreshold = 5 // Reset to valid value
config.CircuitBreakerResetTimeout = -1 * time.Second
if err := config.Validate(); err != ErrInvalidCircuitBreakerResetTimeout {
t.Errorf("Expected ErrInvalidCircuitBreakerResetTimeout, got %v", err)
}
config.CircuitBreakerResetTimeout = 60 * time.Second // Reset to valid value
config.CircuitBreakerMaxRequests = 0
if err := config.Validate(); err != ErrInvalidCircuitBreakerMaxRequests {
t.Errorf("Expected ErrInvalidCircuitBreakerMaxRequests, got %v", err)
}
config.CircuitBreakerMaxRequests = 3 // Reset to valid value
// Should pass validation again
if err := config.Validate(); err != nil {
t.Errorf("Config should be valid after reset, got error: %v", err)
}
})
}
func TestConfigClone(t *testing.T) {
original := DefaultConfig()
original.MaxHandoffRetries = 7
original.HandoffTimeout = 8 * time.Second
cloned := original.Clone()
// Test that values are copied
if cloned.MaxHandoffRetries != 7 {
t.Errorf("Expected cloned MaxHandoffRetries to be 7, got %d", cloned.MaxHandoffRetries)
}
if cloned.HandoffTimeout != 8*time.Second {
t.Errorf("Expected cloned HandoffTimeout to be 8s, got %v", cloned.HandoffTimeout)
}
// Test that modifying clone doesn't affect original
cloned.MaxHandoffRetries = 10
if original.MaxHandoffRetries != 7 {
t.Errorf("Modifying clone should not affect original, original MaxHandoffRetries changed to %d", original.MaxHandoffRetries)
}
}
func TestMaxWorkersLogic(t *testing.T) {
t.Run("AutoCalculatedMaxWorkers", func(t *testing.T) {
testCases := []struct {
poolSize int
expectedWorkers int
description string
}{
{6, 3, "Small pool: min(6/2, max(10, 6/3)) = min(3, max(10, 2)) = min(3, 10) = 3"},
{15, 7, "Medium pool: min(15/2, max(10, 15/3)) = min(7, max(10, 5)) = min(7, 10) = 7"},
{30, 10, "Large pool: min(30/2, max(10, 30/3)) = min(15, max(10, 10)) = min(15, 10) = 10"},
{60, 20, "Very large pool: min(60/2, max(10, 60/3)) = min(30, max(10, 20)) = min(30, 20) = 20"},
{120, 40, "Huge pool: min(120/2, max(10, 120/3)) = min(60, max(10, 40)) = min(60, 40) = 40"},
}
for _, tc := range testCases {
config := &Config{} // MaxWorkers = 0 (not set)
result := config.ApplyDefaultsWithPoolSize(tc.poolSize)
if result.MaxWorkers != tc.expectedWorkers {
t.Errorf("PoolSize=%d: expected MaxWorkers=%d, got %d (%s)",
tc.poolSize, tc.expectedWorkers, result.MaxWorkers, tc.description)
}
}
})
t.Run("ExplicitlySetMaxWorkers", func(t *testing.T) {
testCases := []struct {
setValue int
expectedWorkers int
description string
}{
{1, 50, "Set 1: max(poolSize/2, 1) = max(50, 1) = 50 (enforced minimum)"},
{5, 50, "Set 5: max(poolSize/2, 5) = max(50, 5) = 50 (enforced minimum)"},
{8, 50, "Set 8: max(poolSize/2, 8) = max(50, 8) = 50 (enforced minimum)"},
{10, 50, "Set 10: max(poolSize/2, 10) = max(50, 10) = 50 (enforced minimum)"},
{15, 50, "Set 15: max(poolSize/2, 15) = max(50, 15) = 50 (enforced minimum)"},
{60, 60, "Set 60: max(poolSize/2, 60) = max(50, 60) = 60 (respects user choice)"},
}
for _, tc := range testCases {
config := &Config{
MaxWorkers: tc.setValue, // Explicitly set
}
result := config.ApplyDefaultsWithPoolSize(100) // Pool size 100 enforces a minimum of poolSize/2 = 50 for explicit values
if result.MaxWorkers != tc.expectedWorkers {
t.Errorf("Set MaxWorkers=%d: expected %d, got %d (%s)",
tc.setValue, tc.expectedWorkers, result.MaxWorkers, tc.description)
}
}
})
}

maintnotifications/e2e/.gitignore vendored Normal file

@@ -0,0 +1,30 @@
# E2E test artifacts
*.log
*.out
test-results/
coverage/
profiles/
# Test data
test-data/
temp/
*.tmp
# CI artifacts
artifacts/
reports/
# Redis data files (if running local Redis for testing)
dump.rdb
appendonly.aof
redis.conf.local
# Performance test results
*.prof
*.trace
benchmarks/
# Docker compose files for local testing
docker-compose.override.yml
.env.local
infra/


@@ -0,0 +1,363 @@
# Database Management with Fault Injector
This document describes how to use the fault injector's database management endpoints to create and delete Redis databases during E2E testing.
## Overview
The fault injector now supports two new endpoints for database management:
1. **CREATE_DATABASE** - Create a new Redis database with custom configuration
2. **DELETE_DATABASE** - Delete an existing Redis database
These endpoints are useful for E2E tests that need to dynamically create and destroy databases as part of their test scenarios.
## Action Types
### CREATE_DATABASE
Creates a new Redis database with the specified configuration.
**Parameters:**
- `cluster_index` (int): The index of the cluster where the database should be created
- `database_config` (object): The database configuration (see structure below)
**Raises:**
- `CreateDatabaseException`: When database creation fails
### DELETE_DATABASE
Deletes an existing Redis database.
**Parameters:**
- `cluster_index` (int): The index of the cluster containing the database
- `bdb_id` (int): The database ID to delete
**Raises:**
- `DeleteDatabaseException`: When database deletion fails
## Database Configuration Structure
The `database_config` object supports the following fields:
```go
type DatabaseConfig struct {
Name string `json:"name"`
Port int `json:"port"`
MemorySize int64 `json:"memory_size"`
Replication bool `json:"replication"`
EvictionPolicy string `json:"eviction_policy"`
Sharding bool `json:"sharding"`
AutoUpgrade bool `json:"auto_upgrade"`
ShardsCount int `json:"shards_count"`
ModuleList []DatabaseModule `json:"module_list,omitempty"`
OSSCluster bool `json:"oss_cluster"`
OSSClusterAPIPreferredIPType string `json:"oss_cluster_api_preferred_ip_type,omitempty"`
ProxyPolicy string `json:"proxy_policy,omitempty"`
ShardsPlacement string `json:"shards_placement,omitempty"`
ShardKeyRegex []ShardKeyRegexPattern `json:"shard_key_regex,omitempty"`
}
type DatabaseModule struct {
ModuleArgs string `json:"module_args"`
ModuleName string `json:"module_name"`
}
type ShardKeyRegexPattern struct {
Regex string `json:"regex"`
}
```
### Example Configuration
#### Simple Database
```json
{
"name": "simple-db",
"port": 12000,
"memory_size": 268435456,
"replication": false,
"eviction_policy": "noeviction",
"sharding": false,
"auto_upgrade": true,
"shards_count": 1,
"oss_cluster": false
}
```
#### Clustered Database with Modules
```json
{
"name": "ioredis-cluster",
"port": 11112,
"memory_size": 1273741824,
"replication": true,
"eviction_policy": "noeviction",
"sharding": true,
"auto_upgrade": true,
"shards_count": 3,
"module_list": [
{
"module_args": "",
"module_name": "ReJSON"
},
{
"module_args": "",
"module_name": "search"
},
{
"module_args": "",
"module_name": "timeseries"
},
{
"module_args": "",
"module_name": "bf"
}
],
"oss_cluster": true,
"oss_cluster_api_preferred_ip_type": "external",
"proxy_policy": "all-master-shards",
"shards_placement": "sparse",
"shard_key_regex": [
{
"regex": ".*\\{(?<tag>.*)\\}.*"
},
{
"regex": "(?<tag>.*)"
}
]
}
```
## Usage Examples
### Example 1: Create a Simple Database
```go
ctx := context.Background()
faultInjector := NewFaultInjectorClient("http://127.0.0.1:20324")
dbConfig := DatabaseConfig{
Name: "test-db",
Port: 12000,
MemorySize: 268435456, // 256MB
Replication: false,
EvictionPolicy: "noeviction",
Sharding: false,
AutoUpgrade: true,
ShardsCount: 1,
OSSCluster: false,
}
resp, err := faultInjector.CreateDatabase(ctx, 0, dbConfig)
if err != nil {
log.Fatalf("Failed to create database: %v", err)
}
// Wait for creation to complete
status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
WithMaxWaitTime(5*time.Minute))
if err != nil {
log.Fatalf("Failed to wait for action: %v", err)
}
if status.Status == StatusSuccess {
log.Println("Database created successfully!")
}
```
### Example 2: Create a Database with Modules
```go
dbConfig := DatabaseConfig{
Name: "modules-db",
Port: 12001,
MemorySize: 536870912, // 512MB
Replication: true,
EvictionPolicy: "noeviction",
Sharding: true,
AutoUpgrade: true,
ShardsCount: 3,
ModuleList: []DatabaseModule{
{ModuleArgs: "", ModuleName: "ReJSON"},
{ModuleArgs: "", ModuleName: "search"},
},
OSSCluster: true,
OSSClusterAPIPreferredIPType: "external",
ProxyPolicy: "all-master-shards",
ShardsPlacement: "sparse",
}
resp, err := faultInjector.CreateDatabase(ctx, 0, dbConfig)
// ... handle response
```
### Example 3: Create Database Using a Map
```go
dbConfigMap := map[string]interface{}{
"name": "map-db",
"port": 12002,
"memory_size": 268435456,
"replication": false,
"eviction_policy": "volatile-lru",
"sharding": false,
"auto_upgrade": true,
"shards_count": 1,
"oss_cluster": false,
}
resp, err := faultInjector.CreateDatabaseFromMap(ctx, 0, dbConfigMap)
// ... handle response
```
### Example 4: Delete a Database
```go
clusterIndex := 0
bdbID := 1
resp, err := faultInjector.DeleteDatabase(ctx, clusterIndex, bdbID)
if err != nil {
log.Fatalf("Failed to delete database: %v", err)
}
status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
WithMaxWaitTime(2*time.Minute))
if err != nil {
log.Fatalf("Failed to wait for action: %v", err)
}
if status.Status == StatusSuccess {
log.Println("Database deleted successfully!")
}
```
### Example 5: Complete Lifecycle (Create and Delete)
```go
// Create database
dbConfig := DatabaseConfig{
Name: "temp-db",
Port: 13000,
MemorySize: 268435456,
Replication: false,
EvictionPolicy: "noeviction",
Sharding: false,
AutoUpgrade: true,
ShardsCount: 1,
OSSCluster: false,
}
createResp, err := faultInjector.CreateDatabase(ctx, 0, dbConfig)
if err != nil {
log.Fatalf("Failed to create database: %v", err)
}
createStatus, err := faultInjector.WaitForAction(ctx, createResp.ActionID,
WithMaxWaitTime(5*time.Minute))
if err != nil || createStatus.Status != StatusSuccess {
log.Fatalf("Database creation failed")
}
// Extract bdb_id from output
var bdbID int
if id, ok := createStatus.Output["bdb_id"].(float64); ok {
bdbID = int(id)
}
// Use the database for testing...
time.Sleep(10 * time.Second)
// Delete the database
deleteResp, err := faultInjector.DeleteDatabase(ctx, 0, bdbID)
if err != nil {
log.Fatalf("Failed to delete database: %v", err)
}
deleteStatus, err := faultInjector.WaitForAction(ctx, deleteResp.ActionID,
WithMaxWaitTime(2*time.Minute))
if err != nil || deleteStatus.Status != StatusSuccess {
log.Fatalf("Database deletion failed")
}
log.Println("Database lifecycle completed successfully!")
```
## Available Methods
The `FaultInjectorClient` provides the following methods for database management:
### CreateDatabase
```go
func (c *FaultInjectorClient) CreateDatabase(
ctx context.Context,
clusterIndex int,
databaseConfig DatabaseConfig,
) (*ActionResponse, error)
```
Creates a new database using a structured `DatabaseConfig` object.
### CreateDatabaseFromMap
```go
func (c *FaultInjectorClient) CreateDatabaseFromMap(
ctx context.Context,
clusterIndex int,
databaseConfig map[string]interface{},
) (*ActionResponse, error)
```
Creates a new database using a flexible map configuration. Useful when you need to pass custom or dynamic configurations.
### DeleteDatabase
```go
func (c *FaultInjectorClient) DeleteDatabase(
ctx context.Context,
clusterIndex int,
bdbID int,
) (*ActionResponse, error)
```
Deletes an existing database by its ID.
## Testing
To run the database management E2E tests:
```bash
# Run all database management tests
go test -tags=e2e -v ./maintnotifications/e2e/ -run TestDatabase
# Run specific test
go test -tags=e2e -v ./maintnotifications/e2e/ -run TestDatabaseLifecycle
```
## Notes
- Database creation can take several minutes depending on the configuration
- Always use `WaitForAction` to ensure the operation completes before proceeding
- The `bdb_id` returned in the creation output should be used for deletion
- Deleting a non-existent database will result in a failed action status
- Memory sizes are specified in bytes (e.g., 268435456 = 256MB); see the sketch after this list
- Port numbers should be unique and not conflict with existing databases
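A minimal sketch of spelling out those byte counts when building a config in Go (the constant and database name are made up for illustration):
```go
const mib = 1 << 20 // one mebibyte in bytes
dbConfig := DatabaseConfig{
    Name:       "bytes-demo-db", // illustrative
    Port:       12000,
    MemorySize: 256 * mib, // 268435456 bytes, the 256MB used in the examples above
}
```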
## Common Eviction Policies
- `noeviction` - Return errors when memory limit is reached
- `allkeys-lru` - Evict any key using LRU algorithm
- `volatile-lru` - Evict keys with TTL using LRU algorithm
- `allkeys-random` - Evict random keys
- `volatile-random` - Evict random keys with TTL
- `volatile-ttl` - Evict keys with TTL, shortest TTL first
## Common Proxy Policies
- `all-master-shards` - Route to all master shards
- `all-nodes` - Route to all nodes
- `single-shard` - Route to a single shard

View File

@@ -0,0 +1,156 @@
# E2E Test Scenarios for Push Notifications
This directory contains comprehensive end-to-end test scenarios for Redis push notifications and maintenance notifications functionality. Each scenario tests different aspects of the system under various conditions.
## ⚠️ **Important Note**
**Maintenance notifications are currently supported only in standalone Redis clients.** Cluster clients (ClusterClient, FailoverClient, etc.) do not yet support maintenance notifications functionality.
## Introduction
To run these tests you need a fault injector service; review the client in this package and feel free to implement a fault injector of your choice. The tests are tailored for Redis Enterprise but can be adapted to other Redis distributions where a fault injector is available.
Once the fault injector service is up and running, execute the tests with the `run-e2e-tests.sh` script.
Three environment variables must be set before running the tests:
- `REDIS_ENDPOINTS_CONFIG_PATH`: Path to Redis endpoints configuration
- `FAULT_INJECTION_API_URL`: URL of the fault injector server
- `E2E_SCENARIO_TESTS`: Set to `true` to enable scenario tests
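For example (the endpoints config path and fault injector URL below are illustrative):
```bash
export REDIS_ENDPOINTS_CONFIG_PATH=/path/to/endpoints.json   # illustrative path
export FAULT_INJECTION_API_URL=http://127.0.0.1:20324        # illustrative URL
export E2E_SCENARIO_TESTS=true
./scripts/run-e2e-tests.sh
```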
## Test Scenarios Overview
### 1. Basic Push Notifications (`scenario_push_notifications_test.go`)
**Original template scenario**
- **Purpose**: Basic functionality test for Redis Enterprise push notifications
- **Features Tested**: FAILING_OVER, FAILED_OVER, MIGRATING, MIGRATED, MOVING notifications
- **Configuration**: Standard enterprise cluster setup
- **Duration**: ~10 minutes
- **Key Validations**:
- All notification types received
- Timeout behavior (relaxed/unrelaxed)
- Handoff success rates
- Connection pool management
### 2. Endpoint Types Scenario (`scenario_endpoint_types_test.go`)
**Different endpoint resolution strategies**
- **Purpose**: Test push notifications with different endpoint types
- **Features Tested**: ExternalIP, InternalIP, InternalFQDN, ExternalFQDN endpoint types
- **Configuration**: Standard setup with varying endpoint types
- **Duration**: ~5 minutes
- **Key Validations**:
- Functionality with each endpoint type
- Proper endpoint resolution
- Notification delivery consistency
- Handoff behavior per endpoint type
### 3. Database Management Scenario (`scenario_database_management_test.go`)
**Dynamic database creation and deletion**
- **Purpose**: Test database lifecycle management via fault injector
- **Features Tested**: CREATE_DATABASE, DELETE_DATABASE endpoints
- **Configuration**: Various database configurations (simple, with modules, clustered)
- **Duration**: ~10 minutes
- **Key Validations**:
- Database creation with different configurations
- Database creation with Redis modules (ReJSON, search, timeseries, bf)
- Database deletion
- Complete lifecycle (create → use → delete)
- Configuration validation
See [DATABASE_MANAGEMENT.md](DATABASE_MANAGEMENT.md) for detailed documentation on database management endpoints.
### 4. Timeout Configurations Scenario (`scenario_timeout_configs_test.go`)
**Various timeout strategies**
- **Purpose**: Test different timeout configurations and their impact
- **Features Tested**: Conservative, Aggressive, HighLatency timeouts
- **Configuration**:
- Conservative: 60s handoff, 20s relaxed, 5s post-handoff
- Aggressive: 5s handoff, 3s relaxed, 1s post-handoff
- HighLatency: 90s handoff, 30s relaxed, 10m post-handoff
- **Duration**: ~10 minutes (3 sub-tests)
- **Key Validations**:
- Timeout behavior matches configuration
- Recovery times appropriate for each strategy
- Error rates correlate with timeout aggressiveness
### 5. TLS Configurations Scenario (`scenario_tls_configs_test.go`)
**Security and encryption testing framework**
- **Purpose**: Test push notifications with different TLS configurations
- **Features Tested**: NoTLS, TLSInsecure, TLSSecure, TLSMinimal, TLSStrict
- **Configuration**: Framework for testing various TLS settings (TLS config handled at connection level)
- **Duration**: ~10 minutes (multiple sub-tests)
- **Key Validations**:
- Functionality with each TLS configuration
- Performance impact of encryption
- Certificate handling (where applicable)
- Security compliance
- **Note**: TLS configuration is handled at the Redis connection config level, not client options level
### 6. Stress Test Scenario (`scenario_stress_test.go`)
**Extreme load and concurrent operations**
- **Purpose**: Test system limits and behavior under extreme stress
- **Features Tested**: Maximum concurrent operations, multiple clients
- **Configuration**:
- 4 clients with 150 pool size each
- 200 max connections per client
- 50 workers, 1000 queue size
- Concurrent failover/migration actions
- **Duration**: ~15 minutes
- **Key Validations**:
- System stability under extreme load
- Error rates within stress limits (<20%)
- Resource utilization and limits
- Concurrent fault injection handling
## Running the Scenarios
### Prerequisites
- Set environment variable: `E2E_SCENARIO_TESTS=true`
- Redis Enterprise cluster available
- Fault injection service available
- Appropriate network access and permissions
- **Note**: Tests use standalone Redis clients only (cluster clients not supported)
### Individual Scenario Execution
```bash
# Run a specific scenario
E2E_SCENARIO_TESTS=true go test -v ./maintnotifications/e2e -run TestEndpointTypesPushNotifications
# Run with timeout
E2E_SCENARIO_TESTS=true go test -v -timeout 30m ./maintnotifications/e2e -run TestStressPushNotifications
```
### All Scenarios Execution
```bash
./scripts/run-e2e-tests.sh
```
## Expected Outcomes
### Success Criteria
- All notifications received and processed correctly
- Error rates within acceptable limits for each scenario
- No notification processing errors
- Proper timeout behavior
- Successful handoffs
- Connection pool management within limits
### Performance Benchmarks
- **Basic**: >1000 operations, <1% errors
- **Stress**: >10000 operations, <20% errors
- **Others**: Functionality over performance
## Troubleshooting
### Common Issues
1. **Enterprise cluster not available**: Most scenarios require Redis Enterprise
2. **Fault injector unavailable**: Some scenarios need fault injection service
3. **Network timeouts**: Increase test timeouts for slow networks
4. **TLS certificate issues**: Some TLS scenarios may fail without proper certs
5. **Resource limits**: Stress scenarios may hit system limits
### Debug Options
- Enable detailed logging in scenarios
- Use `dump = true` to see full log analysis
- Check pool statistics for connection issues
- Monitor client resources during stress tests

View File

@@ -0,0 +1,137 @@
package e2e
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
)
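// CommandRunnerStats is a point-in-time snapshot of the counters collected by a CommandRunner.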
type CommandRunnerStats struct {
Operations int64
Errors int64
TimeoutErrors int64
ErrorsList []error
}
// CommandRunner provides utilities for running commands during tests
type CommandRunner struct {
client redis.UniversalClient
stopCh chan struct{}
operationCount atomic.Int64
errorCount atomic.Int64
timeoutErrors atomic.Int64
errors []error
errorsMutex sync.Mutex
}
// NewCommandRunner creates a new command runner
func NewCommandRunner(client redis.UniversalClient) (*CommandRunner, func()) {
stopCh := make(chan struct{})
return &CommandRunner{
client: client,
stopCh: stopCh,
errors: make([]error, 0),
}, func() {
stopCh <- struct{}{}
}
}
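// Stop signals the runner to stop, giving up after 500ms if nothing is listening on the stop channel.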
func (cr *CommandRunner) Stop() {
select {
case cr.stopCh <- struct{}{}:
return
case <-time.After(500 * time.Millisecond):
return
}
}
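// Close closes the stop channel so the runner loop exits; Stop must not be called after Close.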
func (cr *CommandRunner) Close() {
close(cr.stopCh)
}
// FireCommandsUntilStop runs commands continuously until stop signal
func (cr *CommandRunner) FireCommandsUntilStop(ctx context.Context) {
fmt.Printf("[CR] Starting command runner...\n")
defer fmt.Printf("[CR] Command runner stopped\n")
// High frequency for timeout testing
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
counter := 0
for {
select {
case <-cr.stopCh:
return
case <-ctx.Done():
return
case <-ticker.C:
poolSize := cr.client.PoolStats().IdleConns
if poolSize == 0 {
poolSize = 1
}
wg := sync.WaitGroup{}
for i := 0; i < int(poolSize); i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
key := fmt.Sprintf("timeout-test-key-%d-%d", counter, i)
value := fmt.Sprintf("timeout-test-value-%d-%d", counter, i)
// Use a short timeout context for individual operations
opCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
err := cr.client.Set(opCtx, key, value, time.Minute).Err()
cancel()
cr.operationCount.Add(1)
if err != nil {
if err == redis.ErrClosed || strings.Contains(err.Error(), "client is closed") {
select {
case <-cr.stopCh:
return
default:
}
return
}
fmt.Printf("Error: %v\n", err)
cr.errorCount.Add(1)
// Check if it's a timeout error
if isTimeoutError(err) {
cr.timeoutErrors.Add(1)
}
cr.errorsMutex.Lock()
cr.errors = append(cr.errors, err)
cr.errorsMutex.Unlock()
}
}(i)
}
wg.Wait()
counter++
}
}
}
// GetStats returns operation statistics
func (cr *CommandRunner) GetStats() CommandRunnerStats {
cr.errorsMutex.Lock()
defer cr.errorsMutex.Unlock()
errorList := make([]error, len(cr.errors))
copy(errorList, cr.errors)
stats := CommandRunnerStats{
Operations: cr.operationCount.Load(),
Errors: cr.errorCount.Load(),
TimeoutErrors: cr.timeoutErrors.Load(),
ErrorsList: errorList,
}
return stats
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,21 @@
// Package e2e provides end-to-end testing scenarios for the maintenance notifications system.
//
// This package contains comprehensive test scenarios that validate the maintenance notifications
// functionality in realistic environments. The tests are designed to work with Redis Enterprise
// clusters and require specific environment configuration.
//
// Environment Variables:
// - E2E_SCENARIO_TESTS: Set to "true" to enable scenario tests
// - REDIS_ENDPOINTS_CONFIG_PATH: Path to endpoints configuration file
// - FAULT_INJECTION_API_URL: URL for fault injection API (optional)
//
// Test Scenarios:
// - Basic Push Notifications: Core functionality testing
// - Endpoint Types: Different endpoint resolution strategies
// - Timeout Configurations: Various timeout strategies
// - TLS Configurations: Different TLS setups
// - Stress Testing: Extreme load and concurrent operations
//
// Note: Maintenance notifications are currently supported only in standalone Redis clients.
// Cluster clients (ClusterClient, FailoverClient, etc.) do not yet support this functionality.
package e2e

View File

@@ -0,0 +1,110 @@
{
"standalone0": {
"password": "foobared",
"tls": false,
"endpoints": [
"redis://localhost:6379"
]
},
"standalone0-tls": {
"username": "default",
"password": "foobared",
"tls": true,
"certificatesLocation": "redis1-2-5-8-sentinel/work/tls",
"endpoints": [
"rediss://localhost:6390"
]
},
"standalone0-acl": {
"username": "acljedis",
"password": "fizzbuzz",
"tls": false,
"endpoints": [
"redis://localhost:6379"
]
},
"standalone0-acl-tls": {
"username": "acljedis",
"password": "fizzbuzz",
"tls": true,
"certificatesLocation": "redis1-2-5-8-sentinel/work/tls",
"endpoints": [
"rediss://localhost:6390"
]
},
"cluster0": {
"username": "default",
"password": "foobared",
"tls": false,
"endpoints": [
"redis://localhost:7001",
"redis://localhost:7002",
"redis://localhost:7003",
"redis://localhost:7004",
"redis://localhost:7005",
"redis://localhost:7006"
]
},
"cluster0-tls": {
"username": "default",
"password": "foobared",
"tls": true,
"certificatesLocation": "redis1-2-5-8-sentinel/work/tls",
"endpoints": [
"rediss://localhost:7011",
"rediss://localhost:7012",
"rediss://localhost:7013",
"rediss://localhost:7014",
"rediss://localhost:7015",
"rediss://localhost:7016"
]
},
"sentinel0": {
"username": "default",
"password": "foobared",
"tls": false,
"endpoints": [
"redis://localhost:26379",
"redis://localhost:26380",
"redis://localhost:26381"
]
},
"modules-docker": {
"tls": false,
"endpoints": [
"redis://localhost:6479"
]
},
"enterprise-cluster": {
"bdb_id": 1,
"username": "default",
"password": "enterprise-password",
"tls": true,
"raw_endpoints": [
{
"addr": ["10.0.0.1"],
"addr_type": "ipv4",
"dns_name": "redis-enterprise-cluster.example.com",
"oss_cluster_api_preferred_endpoint_type": "internal",
"oss_cluster_api_preferred_ip_type": "ipv4",
"port": 12000,
"proxy_policy": "single",
"uid": "endpoint-1"
},
{
"addr": ["10.0.0.2"],
"addr_type": "ipv4",
"dns_name": "redis-enterprise-cluster-2.example.com",
"oss_cluster_api_preferred_endpoint_type": "internal",
"oss_cluster_api_preferred_ip_type": "ipv4",
"port": 12000,
"proxy_policy": "single",
"uid": "endpoint-2"
}
],
"endpoints": [
"rediss://redis-enterprise-cluster.example.com:12000",
"rediss://redis-enterprise-cluster-2.example.com:12000"
]
}
}

View File

@@ -0,0 +1,644 @@
package e2e
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
)
// ActionType represents the type of fault injection action
type ActionType string
const (
// Redis cluster actions
ActionClusterFailover ActionType = "cluster_failover"
ActionClusterReshard ActionType = "cluster_reshard"
ActionClusterAddNode ActionType = "cluster_add_node"
ActionClusterRemoveNode ActionType = "cluster_remove_node"
ActionClusterMigrate ActionType = "cluster_migrate"
// Node-level actions
ActionNodeRestart ActionType = "node_restart"
ActionNodeStop ActionType = "node_stop"
ActionNodeStart ActionType = "node_start"
ActionNodeKill ActionType = "node_kill"
// Network simulation actions
ActionNetworkPartition ActionType = "network_partition"
ActionNetworkLatency ActionType = "network_latency"
ActionNetworkPacketLoss ActionType = "network_packet_loss"
ActionNetworkBandwidth ActionType = "network_bandwidth"
ActionNetworkRestore ActionType = "network_restore"
// Redis configuration actions
ActionConfigChange ActionType = "config_change"
ActionMaintenanceMode ActionType = "maintenance_mode"
ActionSlotMigration ActionType = "slot_migration"
// Sequence and complex actions
ActionSequence ActionType = "sequence_of_actions"
ActionExecuteCommand ActionType = "execute_command"
// Database management actions
ActionDeleteDatabase ActionType = "delete_database"
ActionCreateDatabase ActionType = "create_database"
)
// ActionStatus represents the status of an action
type ActionStatus string
const (
StatusPending ActionStatus = "pending"
StatusRunning ActionStatus = "running"
StatusFinished ActionStatus = "finished"
StatusFailed ActionStatus = "failed"
StatusSuccess ActionStatus = "success"
StatusCancelled ActionStatus = "cancelled"
)
// ActionRequest represents a request to trigger an action
type ActionRequest struct {
Type ActionType `json:"type"`
Parameters map[string]interface{} `json:"parameters,omitempty"`
}
// ActionResponse represents the response from triggering an action
type ActionResponse struct {
ActionID string `json:"action_id"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
}
// ActionStatusResponse represents the status of an action
type ActionStatusResponse struct {
ActionID string `json:"action_id"`
Status ActionStatus `json:"status"`
Error interface{} `json:"error,omitempty"`
Output map[string]interface{} `json:"output,omitempty"`
Progress float64 `json:"progress,omitempty"`
StartTime time.Time `json:"start_time,omitempty"`
EndTime time.Time `json:"end_time,omitempty"`
}
// SequenceAction represents an action in a sequence
type SequenceAction struct {
Type ActionType `json:"type"`
Parameters map[string]interface{} `json:"params,omitempty"`
Delay time.Duration `json:"delay,omitempty"`
}
// FaultInjectorClient provides programmatic control over test infrastructure
type FaultInjectorClient struct {
baseURL string
httpClient *http.Client
}
// NewFaultInjectorClient creates a new fault injector client
func NewFaultInjectorClient(baseURL string) *FaultInjectorClient {
return &FaultInjectorClient{
baseURL: strings.TrimSuffix(baseURL, "/"),
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// GetBaseURL returns the base URL of the fault injector server
func (c *FaultInjectorClient) GetBaseURL() string {
return c.baseURL
}
// ListActions lists all available actions
func (c *FaultInjectorClient) ListActions(ctx context.Context) ([]ActionType, error) {
var actions []ActionType
err := c.request(ctx, "GET", "/actions", nil, &actions)
return actions, err
}
// TriggerAction triggers a specific action
func (c *FaultInjectorClient) TriggerAction(ctx context.Context, action ActionRequest) (*ActionResponse, error) {
var response ActionResponse
fmt.Printf("[FI] Triggering action: %+v\n", action)
err := c.request(ctx, "POST", "/action", action, &response)
return &response, err
}
func (c *FaultInjectorClient) TriggerSequence(ctx context.Context, bdbID int, actions []SequenceAction) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionSequence,
Parameters: map[string]interface{}{
"bdb_id": bdbID,
"actions": actions,
},
})
}
// GetActionStatus gets the status of a specific action
func (c *FaultInjectorClient) GetActionStatus(ctx context.Context, actionID string) (*ActionStatusResponse, error) {
var status ActionStatusResponse
err := c.request(ctx, "GET", fmt.Sprintf("/action/%s", actionID), nil, &status)
return &status, err
}
// WaitForAction waits for an action to complete
func (c *FaultInjectorClient) WaitForAction(ctx context.Context, actionID string, options ...WaitOption) (*ActionStatusResponse, error) {
config := &waitConfig{
pollInterval: 1 * time.Second,
maxWaitTime: 60 * time.Second,
}
for _, opt := range options {
opt(config)
}
deadline := time.Now().Add(config.maxWaitTime)
ticker := time.NewTicker(config.pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(time.Until(deadline)):
return nil, fmt.Errorf("timeout waiting for action %s after %v", actionID, config.maxWaitTime)
case <-ticker.C:
status, err := c.GetActionStatus(ctx, actionID)
if err != nil {
return nil, fmt.Errorf("failed to get action status: %w", err)
}
switch status.Status {
case StatusFinished, StatusSuccess, StatusFailed, StatusCancelled:
return status, nil
}
}
}
}
// Cluster Management Actions
// TriggerClusterFailover triggers a cluster failover
func (c *FaultInjectorClient) TriggerClusterFailover(ctx context.Context, nodeID string, force bool) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionClusterFailover,
Parameters: map[string]interface{}{
"node_id": nodeID,
"force": force,
},
})
}
// TriggerClusterReshard triggers cluster resharding
func (c *FaultInjectorClient) TriggerClusterReshard(ctx context.Context, slots []int, sourceNode, targetNode string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionClusterReshard,
Parameters: map[string]interface{}{
"slots": slots,
"source_node": sourceNode,
"target_node": targetNode,
},
})
}
// TriggerSlotMigration triggers migration of specific slots
func (c *FaultInjectorClient) TriggerSlotMigration(ctx context.Context, startSlot, endSlot int, sourceNode, targetNode string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionSlotMigration,
Parameters: map[string]interface{}{
"start_slot": startSlot,
"end_slot": endSlot,
"source_node": sourceNode,
"target_node": targetNode,
},
})
}
// Node Management Actions
// RestartNode restarts a specific Redis node
func (c *FaultInjectorClient) RestartNode(ctx context.Context, nodeID string, graceful bool) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNodeRestart,
Parameters: map[string]interface{}{
"node_id": nodeID,
"graceful": graceful,
},
})
}
// StopNode stops a specific Redis node
func (c *FaultInjectorClient) StopNode(ctx context.Context, nodeID string, graceful bool) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNodeStop,
Parameters: map[string]interface{}{
"node_id": nodeID,
"graceful": graceful,
},
})
}
// StartNode starts a specific Redis node
func (c *FaultInjectorClient) StartNode(ctx context.Context, nodeID string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNodeStart,
Parameters: map[string]interface{}{
"node_id": nodeID,
},
})
}
// KillNode forcefully kills a Redis node
func (c *FaultInjectorClient) KillNode(ctx context.Context, nodeID string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNodeKill,
Parameters: map[string]interface{}{
"node_id": nodeID,
},
})
}
// Network Simulation Actions
// SimulateNetworkPartition simulates a network partition
func (c *FaultInjectorClient) SimulateNetworkPartition(ctx context.Context, nodes []string, duration time.Duration) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNetworkPartition,
Parameters: map[string]interface{}{
"nodes": nodes,
"duration": duration.String(),
},
})
}
// SimulateNetworkLatency adds network latency
func (c *FaultInjectorClient) SimulateNetworkLatency(ctx context.Context, nodes []string, latency time.Duration, jitter time.Duration) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNetworkLatency,
Parameters: map[string]interface{}{
"nodes": nodes,
"latency": latency.String(),
"jitter": jitter.String(),
},
})
}
// SimulatePacketLoss simulates packet loss
func (c *FaultInjectorClient) SimulatePacketLoss(ctx context.Context, nodes []string, lossPercent float64) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNetworkPacketLoss,
Parameters: map[string]interface{}{
"nodes": nodes,
"loss_percent": lossPercent,
},
})
}
// LimitBandwidth limits network bandwidth
func (c *FaultInjectorClient) LimitBandwidth(ctx context.Context, nodes []string, bandwidth string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNetworkBandwidth,
Parameters: map[string]interface{}{
"nodes": nodes,
"bandwidth": bandwidth,
},
})
}
// RestoreNetwork restores normal network conditions
func (c *FaultInjectorClient) RestoreNetwork(ctx context.Context, nodes []string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionNetworkRestore,
Parameters: map[string]interface{}{
"nodes": nodes,
},
})
}
// Configuration Actions
// ChangeConfig changes Redis configuration
func (c *FaultInjectorClient) ChangeConfig(ctx context.Context, nodeID string, config map[string]string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionConfigChange,
Parameters: map[string]interface{}{
"node_id": nodeID,
"config": config,
},
})
}
// EnableMaintenanceMode enables maintenance mode
func (c *FaultInjectorClient) EnableMaintenanceMode(ctx context.Context, nodeID string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionMaintenanceMode,
Parameters: map[string]interface{}{
"node_id": nodeID,
"enabled": true,
},
})
}
// DisableMaintenanceMode disables maintenance mode
func (c *FaultInjectorClient) DisableMaintenanceMode(ctx context.Context, nodeID string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionMaintenanceMode,
Parameters: map[string]interface{}{
"node_id": nodeID,
"enabled": false,
},
})
}
// Database Management Actions
// DatabaseConfig represents the configuration for creating a database
type DatabaseConfig struct {
Name string `json:"name"`
Port int `json:"port"`
MemorySize int64 `json:"memory_size"`
Replication bool `json:"replication"`
EvictionPolicy string `json:"eviction_policy"`
Sharding bool `json:"sharding"`
AutoUpgrade bool `json:"auto_upgrade"`
ShardsCount int `json:"shards_count"`
ModuleList []DatabaseModule `json:"module_list,omitempty"`
OSSCluster bool `json:"oss_cluster"`
OSSClusterAPIPreferredIPType string `json:"oss_cluster_api_preferred_ip_type,omitempty"`
ProxyPolicy string `json:"proxy_policy,omitempty"`
ShardsPlacement string `json:"shards_placement,omitempty"`
ShardKeyRegex []ShardKeyRegexPattern `json:"shard_key_regex,omitempty"`
}
// DatabaseModule represents a Redis module configuration
type DatabaseModule struct {
ModuleArgs string `json:"module_args"`
ModuleName string `json:"module_name"`
}
// ShardKeyRegexPattern represents a shard key regex pattern
type ShardKeyRegexPattern struct {
Regex string `json:"regex"`
}
// DeleteDatabase deletes a database
// Parameters:
// - clusterIndex: The index of the cluster
// - bdbID: The database ID to delete
func (c *FaultInjectorClient) DeleteDatabase(ctx context.Context, clusterIndex int, bdbID int) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionDeleteDatabase,
Parameters: map[string]interface{}{
"cluster_index": clusterIndex,
"bdb_id": bdbID,
},
})
}
// CreateDatabase creates a new database
// Parameters:
// - clusterIndex: The index of the cluster
// - databaseConfig: The database configuration
func (c *FaultInjectorClient) CreateDatabase(ctx context.Context, clusterIndex int, databaseConfig DatabaseConfig) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionCreateDatabase,
Parameters: map[string]interface{}{
"cluster_index": clusterIndex,
"database_config": databaseConfig,
},
})
}
// CreateDatabaseFromMap creates a new database using a map for configuration
// This is useful when you want to pass a raw configuration map
// Parameters:
// - clusterIndex: The index of the cluster
// - databaseConfig: The database configuration as a map
func (c *FaultInjectorClient) CreateDatabaseFromMap(ctx context.Context, clusterIndex int, databaseConfig map[string]interface{}) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionCreateDatabase,
Parameters: map[string]interface{}{
"cluster_index": clusterIndex,
"database_config": databaseConfig,
},
})
}
// Complex Actions
// ExecuteSequence executes a sequence of actions
func (c *FaultInjectorClient) ExecuteSequence(ctx context.Context, actions []SequenceAction) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionSequence,
Parameters: map[string]interface{}{
"actions": actions,
},
})
}
// ExecuteCommand executes a custom command
func (c *FaultInjectorClient) ExecuteCommand(ctx context.Context, nodeID, command string) (*ActionResponse, error) {
return c.TriggerAction(ctx, ActionRequest{
Type: ActionExecuteCommand,
Parameters: map[string]interface{}{
"node_id": nodeID,
"command": command,
},
})
}
// Convenience Methods
// SimulateClusterUpgrade simulates a complete cluster upgrade scenario
func (c *FaultInjectorClient) SimulateClusterUpgrade(ctx context.Context, nodes []string) (*ActionResponse, error) {
actions := make([]SequenceAction, 0, len(nodes)*2)
// Rolling restart of all nodes
for i, nodeID := range nodes {
actions = append(actions, SequenceAction{
Type: ActionNodeRestart,
Parameters: map[string]interface{}{
"node_id": nodeID,
"graceful": true,
},
Delay: time.Duration(i*10) * time.Second, // Stagger restarts
})
}
return c.ExecuteSequence(ctx, actions)
}
// SimulateNetworkIssues simulates various network issues
func (c *FaultInjectorClient) SimulateNetworkIssues(ctx context.Context, nodes []string) (*ActionResponse, error) {
actions := []SequenceAction{
{
Type: ActionNetworkLatency,
Parameters: map[string]interface{}{
"nodes": nodes,
"latency": "100ms",
"jitter": "20ms",
},
},
{
Type: ActionNetworkPacketLoss,
Parameters: map[string]interface{}{
"nodes": nodes,
"loss_percent": 2.0,
},
Delay: 30 * time.Second,
},
{
Type: ActionNetworkRestore,
Parameters: map[string]interface{}{
"nodes": nodes,
},
Delay: 60 * time.Second,
},
}
return c.ExecuteSequence(ctx, actions)
}
// Helper types and functions
type waitConfig struct {
pollInterval time.Duration
maxWaitTime time.Duration
}
type WaitOption func(*waitConfig)
// WithPollInterval sets the polling interval for waiting
func WithPollInterval(interval time.Duration) WaitOption {
return func(c *waitConfig) {
c.pollInterval = interval
}
}
// WithMaxWaitTime sets the maximum wait time
func WithMaxWaitTime(maxWait time.Duration) WaitOption {
return func(c *waitConfig) {
c.maxWaitTime = maxWait
}
}
// Internal HTTP request method
func (c *FaultInjectorClient) request(ctx context.Context, method, path string, body interface{}, result interface{}) error {
url := c.baseURL + path
var reqBody io.Reader
if body != nil {
jsonData, err := json.Marshal(body)
if err != nil {
return fmt.Errorf("failed to marshal request body: %w", err)
}
reqBody = bytes.NewReader(jsonData)
}
req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode >= 400 {
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
}
if result != nil {
if err := json.Unmarshal(respBody, result); err != nil {
// This happens when the API changes shape: sometimes the output of the action status
// is a map, sometimes a plain value. Since there is no stable response structure,
// handle both shapes here.
if result, ok := result.(*ActionStatusResponse); ok {
mapResult := map[string]interface{}{}
err = json.Unmarshal(respBody, &mapResult)
if err != nil {
fmt.Println("Failed to unmarshal response:", string(respBody))
panic(err)
}
result.Error = mapResult["error"]
result.Output = map[string]interface{}{"result": mapResult["output"]}
if status, ok := mapResult["status"].(string); ok {
result.Status = ActionStatus(status)
}
if result.Status == StatusSuccess || result.Status == StatusFailed || result.Status == StatusCancelled {
result.EndTime = time.Now()
}
if progress, ok := mapResult["progress"].(float64); ok {
result.Progress = progress
}
if actionID, ok := mapResult["action_id"].(string); ok {
result.ActionID = actionID
}
return nil
}
fmt.Println("Failed to unmarshal response:", string(respBody))
panic(err)
}
}
return nil
}
// Utility functions for common scenarios
// GetClusterNodes returns a list of cluster node IDs
func GetClusterNodes() []string {
// TODO Implement
// This would typically be configured via environment or discovery
return []string{"node-1", "node-2", "node-3", "node-4", "node-5", "node-6"}
}
// GetMasterNodes returns a list of master node IDs
func GetMasterNodes() []string {
// TODO Implement
return []string{"node-1", "node-2", "node-3"}
}
// GetSlaveNodes returns a list of slave node IDs
func GetSlaveNodes() []string {
// TODO Implement
return []string{"node-4", "node-5", "node-6"}
}
// ParseNodeID extracts node ID from various formats
func ParseNodeID(nodeAddr string) string {
// Extract node ID from address like "redis-node-1:7001" -> "node-1"
parts := strings.Split(nodeAddr, ":")
if len(parts) > 0 {
addr := parts[0]
if strings.Contains(addr, "redis-") {
return strings.TrimPrefix(addr, "redis-")
}
return addr
}
return nodeAddr
}
// FormatSlotRange formats a slot range for Redis commands
func FormatSlotRange(start, end int) string {
if start == end {
return strconv.Itoa(start)
}
return fmt.Sprintf("%d-%d", start, end)
}

View File

@@ -0,0 +1,434 @@
package e2e
import (
"context"
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
)
// logs is a slice of strings that provides additional functionality
// for filtering and analysis
type logs []string
func (l logs) Contains(searchString string) bool {
for _, log := range l {
if log == searchString {
return true
}
}
return false
}
func (l logs) GetCount() int {
return len(l)
}
func (l logs) GetCountThatContain(searchString string) int {
count := 0
for _, log := range l {
if strings.Contains(log, searchString) {
count++
}
}
return count
}
func (l logs) GetLogsFiltered(filter func(string) bool) []string {
filteredLogs := make([]string, 0, len(l))
for _, log := range l {
if filter(log) {
filteredLogs = append(filteredLogs, log)
}
}
return filteredLogs
}
func (l logs) GetTimedOutLogs() logs {
return l.GetLogsFiltered(isTimeout)
}
func (l logs) GetLogsPerConn(connID uint64) logs {
return l.GetLogsFiltered(func(log string) bool {
return strings.Contains(log, fmt.Sprintf("conn[%d]", connID))
})
}
func (l logs) GetAnalysis() *LogAnalisis {
return NewLogAnalysis(l)
}
// TestLogCollector is a simple logger that captures logs for analysis
// It is thread safe and can be used to capture logs from multiple clients
// It uses type logs to provide additional functionality like filtering
// and analysis
type TestLogCollector struct {
l logs
doPrint bool
matchFuncs []*MatchFunc
matchFuncsMutex sync.Mutex
mu sync.Mutex
}
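// DontPrint disables echoing of captured log lines to stdout.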
func (tlc *TestLogCollector) DontPrint() {
tlc.mu.Lock()
defer tlc.mu.Unlock()
tlc.doPrint = false
}
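// DoPrint clears any previously captured logs and starts echoing new log lines to stdout.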
func (tlc *TestLogCollector) DoPrint() {
tlc.mu.Lock()
defer tlc.mu.Unlock()
tlc.l = make([]string, 0)
tlc.doPrint = true
}
// MatchFunc is a slice of functions that check the logs for a specific condition
// use in WaitForLogMatchFunc
type MatchFunc struct {
completed atomic.Bool
F func(lstring string) bool
matches []string
found chan struct{} // channel to notify when match is found, will be closed
done func()
}
func (tlc *TestLogCollector) Printf(_ context.Context, format string, v ...interface{}) {
tlc.mu.Lock()
defer tlc.mu.Unlock()
lstr := fmt.Sprintf(format, v...)
if len(tlc.matchFuncs) > 0 {
go func(lstr string) {
for _, matchFunc := range tlc.matchFuncs {
if matchFunc.F(lstr) {
matchFunc.matches = append(matchFunc.matches, lstr)
matchFunc.done()
return
}
}
}(lstr)
}
if tlc.doPrint {
fmt.Println(lstr)
}
tlc.l = append(tlc.l, lstr)
}
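// WaitForLogContaining polls the captured logs for searchString until it is found or the timeout elapses.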
func (tlc *TestLogCollector) WaitForLogContaining(searchString string, timeout time.Duration) bool {
timeoutCh := time.After(timeout)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-timeoutCh:
return false
case <-ticker.C:
if tlc.Contains(searchString) {
return true
}
}
}
}
func (tlc *TestLogCollector) MatchOrWaitForLogMatchFunc(mf func(string) bool, timeout time.Duration) (string, bool) {
if logs := tlc.GetLogsFiltered(mf); len(logs) > 0 {
return logs[0], true
}
return tlc.WaitForLogMatchFunc(mf, timeout)
}
func (tlc *TestLogCollector) WaitForLogMatchFunc(mf func(string) bool, timeout time.Duration) (string, bool) {
matchFunc := &MatchFunc{
completed: atomic.Bool{},
F: mf,
found: make(chan struct{}),
matches: make([]string, 0),
}
matchFunc.done = func() {
if !matchFunc.completed.CompareAndSwap(false, true) {
return
}
close(matchFunc.found)
tlc.matchFuncsMutex.Lock()
defer tlc.matchFuncsMutex.Unlock()
for i, mf := range tlc.matchFuncs {
if mf == matchFunc {
tlc.matchFuncs = append(tlc.matchFuncs[:i], tlc.matchFuncs[i+1:]...)
return
}
}
}
tlc.matchFuncsMutex.Lock()
tlc.matchFuncs = append(tlc.matchFuncs, matchFunc)
tlc.matchFuncsMutex.Unlock()
select {
case <-matchFunc.found:
return matchFunc.matches[0], true
case <-time.After(timeout):
return "", false
}
}
func (tlc *TestLogCollector) GetLogs() logs {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l
}
func (tlc *TestLogCollector) DumpLogs() {
tlc.mu.Lock()
defer tlc.mu.Unlock()
fmt.Println("Dumping logs:")
fmt.Println("===================================================")
for _, log := range tlc.l {
fmt.Println(log)
}
}
func (tlc *TestLogCollector) ClearLogs() {
tlc.mu.Lock()
defer tlc.mu.Unlock()
tlc.l = make([]string, 0)
}
func (tlc *TestLogCollector) Contains(searchString string) bool {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l.Contains(searchString)
}
func (tlc *TestLogCollector) MatchContainsAll(searchStrings []string) []string {
// match a log that contains all
return tlc.GetLogsFiltered(func(log string) bool {
for _, searchString := range searchStrings {
if !strings.Contains(log, searchString) {
return false
}
}
return true
})
}
func (tlc *TestLogCollector) GetLogCount() int {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l.GetCount()
}
func (tlc *TestLogCollector) GetLogCountThatContain(searchString string) int {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l.GetCountThatContain(searchString)
}
func (tlc *TestLogCollector) GetLogsFiltered(filter func(string) bool) logs {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l.GetLogsFiltered(filter)
}
func (tlc *TestLogCollector) GetTimedOutLogs() []string {
return tlc.GetLogsFiltered(isTimeout)
}
func (tlc *TestLogCollector) GetLogsPerConn(connID uint64) logs {
tlc.mu.Lock()
defer tlc.mu.Unlock()
return tlc.l.GetLogsPerConn(connID)
}
func (tlc *TestLogCollector) GetAnalysisForConn(connID uint64) *LogAnalisis {
return NewLogAnalysis(tlc.GetLogsPerConn(connID))
}
func NewTestLogCollector() *TestLogCollector {
return &TestLogCollector{
l: make([]string, 0),
}
}
func (tlc *TestLogCollector) GetAnalysis() *LogAnalisis {
return NewLogAnalysis(tlc.GetLogs())
}
func (tlc *TestLogCollector) Clear() {
tlc.mu.Lock()
defer tlc.mu.Unlock()
tlc.matchFuncs = make([]*MatchFunc, 0)
tlc.l = make([]string, 0)
}
// LogAnalisis provides analysis of logs captured by TestLogCollector
type LogAnalisis struct {
logs []string
TimeoutErrorsCount int64
RelaxedTimeoutCount int64
RelaxedPostHandoffCount int64
UnrelaxedTimeoutCount int64
UnrelaxedAfterMoving int64
ConnectionCount int64
connLogs map[uint64][]string
connIds map[uint64]bool
TotalNotifications int64
MovingCount int64
MigratingCount int64
MigratedCount int64
FailingOverCount int64
FailedOverCount int64
UnexpectedCount int64
TotalHandoffCount int64
FailedHandoffCount int64
SucceededHandoffCount int64
TotalHandoffRetries int64
TotalHandoffToCurrentEndpoint int64
}
func NewLogAnalysis(logs []string) *LogAnalisis {
la := &LogAnalisis{
logs: logs,
connLogs: make(map[uint64][]string),
connIds: make(map[uint64]bool),
}
la.Analyze()
return la
}
func (la *LogAnalisis) Analyze() {
hasMoving := false
for _, log := range la.logs {
if isTimeout(log) {
la.TimeoutErrorsCount++
}
if strings.Contains(log, "MOVING") {
hasMoving = true
}
if strings.Contains(log, logs2.RelaxedTimeoutDueToNotificationMessage) {
la.RelaxedTimeoutCount++
}
if strings.Contains(log, logs2.ApplyingRelaxedTimeoutDueToPostHandoffMessage) {
la.RelaxedTimeoutCount++
la.RelaxedPostHandoffCount++
}
if strings.Contains(log, logs2.UnrelaxedTimeoutMessage) {
la.UnrelaxedTimeoutCount++
}
if strings.Contains(log, logs2.UnrelaxedTimeoutAfterDeadlineMessage) {
if hasMoving {
la.UnrelaxedAfterMoving++
} else {
fmt.Printf("Unrelaxed after deadline but no MOVING: %s\n", log)
}
}
if strings.Contains(log, logs2.ProcessingNotificationMessage) {
la.TotalNotifications++
switch {
case notificationType(log, "MOVING"):
la.MovingCount++
case notificationType(log, "MIGRATING"):
la.MigratingCount++
case notificationType(log, "MIGRATED"):
la.MigratedCount++
case notificationType(log, "FAILING_OVER"):
la.FailingOverCount++
case notificationType(log, "FAILED_OVER"):
la.FailedOverCount++
default:
fmt.Printf("[ERROR] Unexpected notification: %s\n", log)
la.UnexpectedCount++
}
}
if strings.Contains(log, "conn[") {
connID := extractConnID(log)
if _, ok := la.connIds[connID]; !ok {
la.connIds[connID] = true
la.ConnectionCount++
}
la.connLogs[connID] = append(la.connLogs[connID], log)
}
if strings.Contains(log, logs2.SchedulingHandoffToCurrentEndpointMessage) {
la.TotalHandoffToCurrentEndpoint++
}
if strings.Contains(log, logs2.HandoffSuccessMessage) {
la.SucceededHandoffCount++
}
if strings.Contains(log, logs2.HandoffFailedMessage) {
la.FailedHandoffCount++
}
if strings.Contains(log, logs2.HandoffStartedMessage) {
la.TotalHandoffCount++
}
if strings.Contains(log, logs2.HandoffRetryAttemptMessage) {
la.TotalHandoffRetries++
}
}
}
func (la *LogAnalisis) Print(t *testing.T) {
t.Logf("Log Analysis results for %d logs and %d connections:", len(la.logs), len(la.connIds))
t.Logf("Connection Count: %d", la.ConnectionCount)
t.Logf("-------------")
t.Logf("-Timeout Analysis-")
t.Logf("-------------")
t.Logf("Timeout Errors: %d", la.TimeoutErrorsCount)
t.Logf("Relaxed Timeout Count: %d", la.RelaxedTimeoutCount)
t.Logf(" - Relaxed Timeout After Post-Handoff: %d", la.RelaxedPostHandoffCount)
t.Logf("Unrelaxed Timeout Count: %d", la.UnrelaxedTimeoutCount)
t.Logf(" - Unrelaxed Timeout After Moving: %d", la.UnrelaxedAfterMoving)
t.Logf("-------------")
t.Logf("-Handoff Analysis-")
t.Logf("-------------")
t.Logf("Total Handoffs: %d", la.TotalHandoffCount)
t.Logf(" - Succeeded: %d", la.SucceededHandoffCount)
t.Logf(" - Failed: %d", la.FailedHandoffCount)
t.Logf(" - Retries: %d", la.TotalHandoffRetries)
t.Logf(" - Handoffs to current endpoint: %d", la.TotalHandoffToCurrentEndpoint)
t.Logf("-------------")
t.Logf("-Notification Analysis-")
t.Logf("-------------")
t.Logf("Total Notifications: %d", la.TotalNotifications)
t.Logf(" - MOVING: %d", la.MovingCount)
t.Logf(" - MIGRATING: %d", la.MigratingCount)
t.Logf(" - MIGRATED: %d", la.MigratedCount)
t.Logf(" - FAILING_OVER: %d", la.FailingOverCount)
t.Logf(" - FAILED_OVER: %d", la.FailedOverCount)
t.Logf(" - Unexpected: %d", la.UnexpectedCount)
t.Logf("-------------")
t.Logf("Log Analysis completed successfully")
}
func extractConnID(log string) uint64 {
logParts := strings.Split(log, "conn[")
if len(logParts) < 2 {
return 0
}
connIDStr := strings.Split(logParts[1], "]")[0]
connID, err := strconv.ParseUint(connIDStr, 10, 64)
if err != nil {
return 0
}
return connID
}
func notificationType(log string, nt string) bool {
return strings.Contains(log, nt)
}
func connID(log string, connID uint64) bool {
return strings.Contains(log, fmt.Sprintf("conn[%d]", connID))
}
func seqID(log string, seqID int64) bool {
return strings.Contains(log, fmt.Sprintf("seqID[%d]", seqID))
}

View File

@@ -0,0 +1,39 @@
package e2e
import (
"log"
"os"
"testing"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/logging"
)
// Global log collector
var logCollector *TestLogCollector
// Global fault injector client
var faultInjector *FaultInjectorClient
func TestMain(m *testing.M) {
var err error
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
log.Println("Skipping scenario tests, E2E_SCENARIO_TESTS is not set")
return
}
faultInjector, err = CreateTestFaultInjector()
if err != nil {
panic("Failed to create fault injector: " + err.Error())
}
// use log collector to capture logs from redis clients
logCollector = NewTestLogCollector()
redis.SetLogger(logCollector)
redis.SetLogLevel(logging.LogLevelDebug)
logCollector.Clear()
log.Println("Running scenario tests...")
status := m.Run()
// os.Exit does not run deferred calls, so clear the collector explicitly before exiting.
logCollector.Clear()
os.Exit(status)
}

View File

@@ -0,0 +1,435 @@
package e2e
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/push"
)
// DiagnosticsEvent represents a notification event
// it may be a push notification or an error when processing
// push notifications
type DiagnosticsEvent struct {
// is this pre or post hook
Type string `json:"type"`
ConnID uint64 `json:"connID"`
SeqID int64 `json:"seqID"`
Error error `json:"error"`
Pre bool `json:"pre"`
Timestamp time.Time `json:"timestamp"`
Details map[string]interface{} `json:"details"`
}
// TrackingNotificationsHook is a notification hook that tracks notifications
type TrackingNotificationsHook struct {
// unique connection count
connectionCount atomic.Int64
// timeouts
relaxedTimeoutCount atomic.Int64
unrelaxedTimeoutCount atomic.Int64
notificationProcessingErrors atomic.Int64
// notification types
totalNotifications atomic.Int64
migratingCount atomic.Int64
migratedCount atomic.Int64
failingOverCount atomic.Int64
failedOverCount atomic.Int64
movingCount atomic.Int64
unexpectedNotificationCount atomic.Int64
diagnosticsLog []DiagnosticsEvent
connIds map[uint64]bool
connLogs map[uint64][]DiagnosticsEvent
mutex sync.RWMutex
}
// NewTrackingNotificationsHook creates a new notification hook with counters
func NewTrackingNotificationsHook() *TrackingNotificationsHook {
return &TrackingNotificationsHook{
diagnosticsLog: make([]DiagnosticsEvent, 0),
connIds: make(map[uint64]bool),
connLogs: make(map[uint64][]DiagnosticsEvent),
}
}
// Clear resets all counters and captured events. The hook is not meant to be reused across tests;
// Clear exists mainly to keep the API consistent with the log collector.
func (tnh *TrackingNotificationsHook) Clear() {
tnh.mutex.Lock()
defer tnh.mutex.Unlock()
tnh.diagnosticsLog = make([]DiagnosticsEvent, 0)
tnh.connIds = make(map[uint64]bool)
tnh.connLogs = make(map[uint64][]DiagnosticsEvent)
tnh.relaxedTimeoutCount.Store(0)
tnh.unrelaxedTimeoutCount.Store(0)
tnh.notificationProcessingErrors.Store(0)
tnh.totalNotifications.Store(0)
tnh.migratingCount.Store(0)
tnh.migratedCount.Store(0)
tnh.failingOverCount.Store(0)
tnh.failedOverCount.Store(0)
tnh.movingCount.Store(0)
tnh.unexpectedNotificationCount.Store(0)
tnh.connectionCount.Store(0)
}
// FindOrWaitForNotification returns a notification of the given type if the pre-hook has already seen one,
// otherwise it polls until one arrives or the timeout elapses.
func (tnh *TrackingNotificationsHook) FindOrWaitForNotification(notificationType string, timeout time.Duration) (notification []interface{}, found bool) {
if notification, found := tnh.FindNotification(notificationType); found {
return notification, true
}
// wait for notification
timeoutCh := time.After(timeout)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-timeoutCh:
return nil, false
case <-ticker.C:
if notification, found := tnh.FindNotification(notificationType); found {
return notification, true
}
}
}
}
func (tnh *TrackingNotificationsHook) FindNotification(notificationType string) (notification []interface{}, found bool) {
tnh.mutex.RLock()
defer tnh.mutex.RUnlock()
for _, event := range tnh.diagnosticsLog {
if event.Type == notificationType {
return event.Details["notification"].([]interface{}), true
}
}
return nil, false
}
// PreHook captures timeout-related events before processing
func (tnh *TrackingNotificationsHook) PreHook(_ context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
tnh.increaseNotificationCount(notificationType)
tnh.storeDiagnosticsEvent(notificationType, notification, notificationCtx)
tnh.increaseRelaxedTimeoutCount(notificationType)
return notification, true
}
func (tnh *TrackingNotificationsHook) getConnID(notificationCtx push.NotificationHandlerContext) uint64 {
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
return conn.GetID()
}
return 0
}
func (tnh *TrackingNotificationsHook) getSeqID(notification []interface{}) int64 {
seqID, ok := notification[1].(int64)
if !ok {
return 0
}
return seqID
}
// PostHook captures the result after processing push notification
func (tnh *TrackingNotificationsHook) PostHook(_ context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, err error) {
if err != nil {
event := DiagnosticsEvent{
Type: notificationType + "_ERROR",
ConnID: tnh.getConnID(notificationCtx),
SeqID: tnh.getSeqID(notification),
Error: err,
Timestamp: time.Now(),
Details: map[string]interface{}{
"notification": notification,
"context": "post-hook",
},
}
tnh.notificationProcessingErrors.Add(1)
tnh.mutex.Lock()
tnh.diagnosticsLog = append(tnh.diagnosticsLog, event)
tnh.mutex.Unlock()
}
}
func (tnh *TrackingNotificationsHook) storeDiagnosticsEvent(notificationType string, notification []interface{}, notificationCtx push.NotificationHandlerContext) {
connID := tnh.getConnID(notificationCtx)
event := DiagnosticsEvent{
Type: notificationType,
ConnID: connID,
SeqID: tnh.getSeqID(notification),
Pre: true,
Timestamp: time.Now(),
Details: map[string]interface{}{
"notification": notification,
"context": "pre-hook",
},
}
tnh.mutex.Lock()
if v, ok := tnh.connIds[connID]; !ok || !v {
tnh.connIds[connID] = true
tnh.connectionCount.Add(1)
}
tnh.connLogs[connID] = append(tnh.connLogs[connID], event)
tnh.diagnosticsLog = append(tnh.diagnosticsLog, event)
tnh.mutex.Unlock()
}
// GetRelaxedTimeoutCount returns the count of relaxed timeout events
func (tnh *TrackingNotificationsHook) GetRelaxedTimeoutCount() int64 {
return tnh.relaxedTimeoutCount.Load()
}
// GetUnrelaxedTimeoutCount returns the count of unrelaxed timeout events
func (tnh *TrackingNotificationsHook) GetUnrelaxedTimeoutCount() int64 {
return tnh.unrelaxedTimeoutCount.Load()
}
// GetNotificationProcessingErrors returns the count of timeout errors
func (tnh *TrackingNotificationsHook) GetNotificationProcessingErrors() int64 {
return tnh.notificationProcessingErrors.Load()
}
// GetTotalNotifications returns the total number of notifications processed
func (tnh *TrackingNotificationsHook) GetTotalNotifications() int64 {
return tnh.totalNotifications.Load()
}
// GetConnectionCount returns the current connection count
func (tnh *TrackingNotificationsHook) GetConnectionCount() int64 {
return tnh.connectionCount.Load()
}
// GetMovingCount returns the count of MOVING notifications
func (tnh *TrackingNotificationsHook) GetMovingCount() int64 {
return tnh.movingCount.Load()
}
// GetDiagnosticsLog returns a copy of the diagnostics log
func (tnh *TrackingNotificationsHook) GetDiagnosticsLog() []DiagnosticsEvent {
tnh.mutex.RLock()
defer tnh.mutex.RUnlock()
logCopy := make([]DiagnosticsEvent, len(tnh.diagnosticsLog))
copy(logCopy, tnh.diagnosticsLog)
return logCopy
}
func (tnh *TrackingNotificationsHook) increaseNotificationCount(notificationType string) {
tnh.totalNotifications.Add(1)
switch notificationType {
case "MOVING":
tnh.movingCount.Add(1)
case "MIGRATING":
tnh.migratingCount.Add(1)
case "MIGRATED":
tnh.migratedCount.Add(1)
case "FAILING_OVER":
tnh.failingOverCount.Add(1)
case "FAILED_OVER":
tnh.failedOverCount.Add(1)
default:
tnh.unexpectedNotificationCount.Add(1)
}
}
func (tnh *TrackingNotificationsHook) increaseRelaxedTimeoutCount(notificationType string) {
switch notificationType {
case "MIGRATING", "FAILING_OVER":
tnh.relaxedTimeoutCount.Add(1)
case "MIGRATED", "FAILED_OVER":
tnh.unrelaxedTimeoutCount.Add(1)
}
}
// setupNotificationHook sets up tracking for both regular and cluster clients with notification hooks
func setupNotificationHook(client redis.UniversalClient, hook maintnotifications.NotificationHook) {
if clusterClient, ok := client.(*redis.ClusterClient); ok {
setupClusterClientNotificationHook(clusterClient, hook)
} else if regularClient, ok := client.(*redis.Client); ok {
setupRegularClientNotificationHook(regularClient, hook)
}
}
// setupNotificationHooks sets up tracking for both regular and cluster clients with notification hooks
func setupNotificationHooks(client redis.UniversalClient, hooks ...maintnotifications.NotificationHook) {
for _, hook := range hooks {
setupNotificationHook(client, hook)
}
}
// setupRegularClientNotificationHook sets up notification hook for regular clients
func setupRegularClientNotificationHook(client *redis.Client, hook maintnotifications.NotificationHook) {
maintnotificationsManager := client.GetMaintNotificationsManager()
if maintnotificationsManager != nil {
maintnotificationsManager.AddNotificationHook(hook)
} else {
fmt.Printf("[TNH] Warning: Maintenance notifications manager not available for tracking\n")
}
}
// setupClusterClientNotificationHook sets up notification hook for cluster clients
func setupClusterClientNotificationHook(client *redis.ClusterClient, hook maintnotifications.NotificationHook) {
ctx := context.Background()
// Register hook on existing nodes
err := client.ForEachShard(ctx, func(ctx context.Context, nodeClient *redis.Client) error {
maintnotificationsManager := nodeClient.GetMaintNotificationsManager()
if maintnotificationsManager != nil {
maintnotificationsManager.AddNotificationHook(hook)
} else {
fmt.Printf("[TNH] Warning: Maintenance notifications manager not available for tracking on node: %s\n", nodeClient.Options().Addr)
}
return nil
})
if err != nil {
fmt.Printf("[TNH] Warning: Failed to register timeout tracking hooks on existing cluster nodes: %v\n", err)
}
// Register hook on new nodes
client.OnNewNode(func(nodeClient *redis.Client) {
maintnotificationsManager := nodeClient.GetMaintNotificationsManager()
if maintnotificationsManager != nil {
maintnotificationsManager.AddNotificationHook(hook)
} else {
fmt.Printf("[TNH] Warning: Maintenance notifications manager not available for tracking on new node: %s\n", nodeClient.Options().Addr)
}
})
}
// filterPushNotificationLogs filters the diagnostics log for push notification events
func filterPushNotificationLogs(diagnosticsLog []DiagnosticsEvent) []DiagnosticsEvent {
var pushNotificationLogs []DiagnosticsEvent
for _, log := range diagnosticsLog {
switch log.Type {
case "MOVING", "MIGRATING", "MIGRATED":
pushNotificationLogs = append(pushNotificationLogs, log)
}
}
return pushNotificationLogs
}
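// GetAnalysis returns a DiagnosticsAnalysis built from the full diagnostics log.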
func (tnh *TrackingNotificationsHook) GetAnalysis() *DiagnosticsAnalysis {
return NewDiagnosticsAnalysis(tnh.GetDiagnosticsLog())
}
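// GetDiagnosticsLogForConn returns the diagnostics events recorded for the given connection ID.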
func (tnh *TrackingNotificationsHook) GetDiagnosticsLogForConn(connID uint64) []DiagnosticsEvent {
tnh.mutex.RLock()
defer tnh.mutex.RUnlock()
var connLogs []DiagnosticsEvent
for _, log := range tnh.diagnosticsLog {
if log.ConnID == connID {
connLogs = append(connLogs, log)
}
}
return connLogs
}
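// GetAnalysisForConn returns a DiagnosticsAnalysis restricted to the given connection ID.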
func (tnh *TrackingNotificationsHook) GetAnalysisForConn(connID uint64) *DiagnosticsAnalysis {
return NewDiagnosticsAnalysis(tnh.GetDiagnosticsLogForConn(connID))
}
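// DiagnosticsAnalysis aggregates per-type notification counts, timeout transitions,
// processing errors and per-connection logs computed from a diagnostics log.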
type DiagnosticsAnalysis struct {
RelaxedTimeoutCount int64
UnrelaxedTimeoutCount int64
NotificationProcessingErrors int64
ConnectionCount int64
MovingCount int64
MigratingCount int64
MigratedCount int64
FailingOverCount int64
FailedOverCount int64
UnexpectedNotificationCount int64
TotalNotifications int64
diagnosticsLog []DiagnosticsEvent
connLogs map[uint64][]DiagnosticsEvent
connIds map[uint64]bool
}
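// NewDiagnosticsAnalysis builds an analysis over the given diagnostics log and runs Analyze immediately.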
func NewDiagnosticsAnalysis(diagnosticsLog []DiagnosticsEvent) *DiagnosticsAnalysis {
da := &DiagnosticsAnalysis{
diagnosticsLog: diagnosticsLog,
connLogs: make(map[uint64][]DiagnosticsEvent),
connIds: make(map[uint64]bool),
}
da.Analyze()
return da
}
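// Analyze walks the diagnostics log once, tallying notifications per type, processing errors,
// relaxed/unrelaxed timeout transitions and the set of distinct connections.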
func (da *DiagnosticsAnalysis) Analyze() {
for _, log := range da.diagnosticsLog {
da.TotalNotifications++
switch log.Type {
case "MOVING":
da.MovingCount++
case "MIGRATING":
da.MigratingCount++
case "MIGRATED":
da.MigratedCount++
case "FAILING_OVER":
da.FailingOverCount++
case "FAILED_OVER":
da.FailedOverCount++
default:
da.UnexpectedNotificationCount++
}
if log.Error != nil {
fmt.Printf("[ERROR] Notification processing error: %v\n", log.Error)
fmt.Printf("[ERROR] Notification: %v\n", log.Details["notification"])
fmt.Printf("[ERROR] Context: %v\n", log.Details["context"])
da.NotificationProcessingErrors++
}
if log.Type == "MIGRATING" || log.Type == "FAILING_OVER" {
da.RelaxedTimeoutCount++
} else if log.Type == "MIGRATED" || log.Type == "FAILED_OVER" {
da.UnrelaxedTimeoutCount++
}
if log.ConnID != 0 {
if v, ok := da.connIds[log.ConnID]; !ok || !v {
da.connIds[log.ConnID] = true
da.connLogs[log.ConnID] = make([]DiagnosticsEvent, 0)
da.ConnectionCount++
}
da.connLogs[log.ConnID] = append(da.connLogs[log.ConnID], log)
}
}
}
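// Print logs a human-readable summary of the analysis via t.Logf.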
func (a *DiagnosticsAnalysis) Print(t *testing.T) {
t.Logf("Notification Analysis results for %d events and %d connections:", len(a.diagnosticsLog), len(a.connIds))
t.Logf("-------------")
t.Logf("-Timeout Analysis based on type of notification-")
t.Logf("Note: MIGRATED and FAILED_OVER notifications are not tracked by the hook, so they are not included in the relaxed/unrelaxed count")
t.Logf("Note: The hook only tracks timeouts that occur after the notification is processed, so timeouts that occur during processing are not included")
t.Logf("-------------")
t.Logf(" - Relaxed Timeout Count: %d", a.RelaxedTimeoutCount)
t.Logf(" - Unrelaxed Timeout Count: %d", a.UnrelaxedTimeoutCount)
t.Logf("-------------")
t.Logf("-Notification Analysis-")
t.Logf("-------------")
t.Logf(" - MOVING: %d", a.MovingCount)
t.Logf(" - MIGRATING: %d", a.MigratingCount)
t.Logf(" - MIGRATED: %d", a.MigratedCount)
t.Logf(" - FAILING_OVER: %d", a.FailingOverCount)
t.Logf(" - FAILED_OVER: %d", a.FailedOverCount)
t.Logf(" - Unexpected: %d", a.UnexpectedNotificationCount)
t.Logf("-------------")
t.Logf(" - Total Notifications: %d", a.TotalNotifications)
t.Logf(" - Notification Processing Errors: %d", a.NotificationProcessingErrors)
t.Logf(" - Connection Count: %d", a.ConnectionCount)
t.Logf("-------------")
t.Logf("Diagnostics Analysis completed successfully")
}

View File

@@ -0,0 +1,374 @@
package e2e
import (
"context"
"fmt"
"net"
"os"
"strings"
"testing"
"time"
"github.com/redis/go-redis/v9/internal"
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
// TestEndpointTypesPushNotifications tests push notifications with different endpoint types
func TestEndpointTypesPushNotifications(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Minute)
defer cancel()
var dump = true
var errorsDetected = false
// Test different endpoint types
endpointTypes := []struct {
name string
endpointType maintnotifications.EndpointType
description string
}{
{
name: "ExternalIP",
endpointType: maintnotifications.EndpointTypeExternalIP,
description: "External IP endpoint type for enterprise clusters",
},
{
name: "ExternalFQDN",
endpointType: maintnotifications.EndpointTypeExternalFQDN,
description: "External FQDN endpoint type for DNS-based routing",
},
{
name: "None",
endpointType: maintnotifications.EndpointTypeNone,
description: "No endpoint type - reconnect with current config",
},
}
defer func() {
logCollector.Clear()
}()
// Test each endpoint type with its own fresh database
for _, endpointTest := range endpointTypes {
t.Run(endpointTest.name, func(t *testing.T) {
// Setup: Create fresh database and client factory for THIS endpoint type test
bdbID, factory, cleanup := SetupTestDatabaseAndFactory(t, ctx, "standalone")
defer cleanup()
t.Logf("[ENDPOINT-TYPES-%s] Created test database with bdb_id: %d", endpointTest.name, bdbID)
// Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
t.Fatalf("[ERROR] Failed to create fault injector: %v", err)
}
// Get endpoint config from factory (now connected to new database)
endpointConfig := factory.GetConfig()
defer func() {
if dump {
fmt.Println("Pool stats:")
factory.PrintPoolStats(t)
}
}()
// Clear logs between endpoint type tests
logCollector.Clear()
// reset errors detected flag
errorsDetected = false
// reset dump flag
dump = true
// redefine p and e for each sub-test so logs carry the proper
// test name and failures are attributed to the right sub-test
var p = func(format string, args ...interface{}) {
printLog("ENDPOINT-TYPES", false, format, args...)
}
var e = func(format string, args ...interface{}) {
errorsDetected = true
printLog("ENDPOINT-TYPES", true, format, args...)
}
var ef = func(format string, args ...interface{}) {
printLog("ENDPOINT-TYPES", true, format, args...)
t.FailNow()
}
p("Testing endpoint type: %s - %s", endpointTest.name, endpointTest.description)
minIdleConns := 3
poolSize := 8
maxConnections := 12
// Create Redis client with specific endpoint type
client, err := factory.Create(fmt.Sprintf("endpoint-test-%s", endpointTest.name), &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 30 * time.Second,
RelaxedTimeout: 8 * time.Second,
PostHandoffRelaxedDuration: 2 * time.Second,
MaxWorkers: 15,
EndpointType: endpointTest.endpointType, // Test specific endpoint type
},
ClientName: fmt.Sprintf("endpoint-test-%s", endpointTest.name),
})
if err != nil {
ef("Failed to create client for %s: %v", endpointTest.name, err)
}
// Create timeout tracker
tracker := NewTrackingNotificationsHook()
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
setupNotificationHooks(client, tracker, logger)
defer func() {
tracker.Clear()
}()
// Verify initial connectivity
err = client.Ping(ctx).Err()
if err != nil {
ef("Failed to ping Redis with %s endpoint type: %v", endpointTest.name, err)
}
p("Client connected successfully with %s endpoint type", endpointTest.name)
commandsRunner, _ := NewCommandRunner(client)
defer func() {
if dump {
stats := commandsRunner.GetStats()
p("%s endpoint stats: Operations: %d, Errors: %d, Timeout Errors: %d",
endpointTest.name, stats.Operations, stats.Errors, stats.TimeoutErrors)
}
commandsRunner.Stop()
}()
// Test failover with this endpoint type
p("Testing failover with %s endpoint type on database [bdb_id:%s]...", endpointTest.name, endpointConfig.BdbID)
failoverResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "failover",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
ef("Failed to trigger failover action for %s: %v", endpointTest.name, err)
}
// Start command traffic
go func() {
commandsRunner.FireCommandsUntilStop(ctx)
}()
// Wait for failover to complete
status, err := faultInjector.WaitForAction(ctx, failoverResp.ActionID,
WithMaxWaitTime(240*time.Second),
WithPollInterval(2*time.Second),
)
if err != nil {
ef("[FI] Failover action failed for %s: %v", endpointTest.name, err)
}
p("[FI] Failover action completed for %s: %s %s", endpointTest.name, status.Status, actionOutputIfFailed(status))
// Wait for FAILING_OVER notification
match, found := logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "FAILING_OVER")
}, 3*time.Minute)
if !found {
ef("FAILING_OVER notification was not received for %s endpoint type", endpointTest.name)
}
failingOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILING_OVER notification received for %s. %v", endpointTest.name, failingOverData)
// Wait for FAILED_OVER notification
seqIDToObserve := int64(failingOverData["seqID"].(float64))
connIDToObserve := uint64(failingOverData["connID"].(float64))
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return notificationType(s, "FAILED_OVER") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
}, 3*time.Minute)
if !found {
ef("FAILED_OVER notification was not received for %s endpoint type", endpointTest.name)
}
failedOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILED_OVER notification received for %s. %v", endpointTest.name, failedOverData)
// Test migration with this endpoint type
p("Testing migration with %s endpoint type...", endpointTest.name)
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "migrate",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
ef("Failed to trigger migrate action for %s: %v", endpointTest.name, err)
}
// Wait for migration to complete
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
WithMaxWaitTime(240*time.Second),
WithPollInterval(2*time.Second),
)
if err != nil {
ef("[FI] Migrate action failed for %s: %v", endpointTest.name, err)
}
p("[FI] Migrate action completed for %s: %s %s", endpointTest.name, status.Status, actionOutputIfFailed(status))
// Wait for MIGRATING notification
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
}, 60*time.Second)
if !found {
ef("MIGRATING notification was not received for %s endpoint type", endpointTest.name)
}
migrateData := logs2.ExtractDataFromLogMessage(match)
p("MIGRATING notification received for %s: %v", endpointTest.name, migrateData)
// Wait for MIGRATED notification
seqIDToObserve = int64(migrateData["seqID"].(float64))
connIDToObserve = uint64(migrateData["connID"].(float64))
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return notificationType(s, "MIGRATED") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
}, 3*time.Minute)
if !found {
ef("MIGRATED notification was not received for %s endpoint type", endpointTest.name)
}
migratedData := logs2.ExtractDataFromLogMessage(match)
p("MIGRATED notification received for %s. %v", endpointTest.name, migratedData)
// Complete migration with bind action
bindResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "bind",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
ef("Failed to trigger bind action for %s: %v", endpointTest.name, err)
}
// Wait for MOVING notification
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "MOVING")
}, 3*time.Minute)
if !found {
ef("MOVING notification was not received for %s endpoint type", endpointTest.name)
}
movingData := logs2.ExtractDataFromLogMessage(match)
p("MOVING notification received for %s. %v", endpointTest.name, movingData)
notification, ok := movingData["notification"].(string)
if !ok {
e("invalid notification message")
}
notification = notification[:len(notification)-1]
notificationParts := strings.Split(notification, " ")
address := notificationParts[len(notificationParts)-1]
switch endpointTest.endpointType {
case maintnotifications.EndpointTypeExternalFQDN:
address = strings.Split(address, ":")[0]
addressParts := strings.SplitN(address, ".", 2)
if len(addressParts) != 2 {
e("invalid address %s", address)
} else {
address = addressParts[1]
}
var expectedAddress string
hostParts := strings.SplitN(endpointConfig.Host, ".", 2)
if len(hostParts) != 2 {
e("invalid host %s", endpointConfig.Host)
} else {
expectedAddress = hostParts[1]
}
if address != expectedAddress {
e("invalid fqdn, expected: %s, got: %s", expectedAddress, address)
}
case maintnotifications.EndpointTypeExternalIP:
address = strings.Split(address, ":")[0]
ip := net.ParseIP(address)
if ip == nil {
e("invalid message format, expected valid IP, got: %s", address)
}
case maintnotifications.EndpointTypeNone:
if address != internal.RedisNull {
e("invalid endpoint type, expected: %s, got: %s", internal.RedisNull, address)
}
}
// Wait for bind to complete
bindStatus, err := faultInjector.WaitForAction(ctx, bindResp.ActionID,
WithMaxWaitTime(240*time.Second),
WithPollInterval(2*time.Second))
if err != nil {
ef("Bind action failed for %s: %v", endpointTest.name, err)
}
p("Bind action completed for %s: %s %s", endpointTest.name, bindStatus.Status, actionOutputIfFailed(bindStatus))
// Continue traffic for analysis
time.Sleep(30 * time.Second)
commandsRunner.Stop()
// Analyze results for this endpoint type
trackerAnalysis := tracker.GetAnalysis()
if trackerAnalysis.NotificationProcessingErrors > 0 {
e("Notification processing errors with %s endpoint type: %d", endpointTest.name, trackerAnalysis.NotificationProcessingErrors)
}
if trackerAnalysis.UnexpectedNotificationCount > 0 {
e("Unexpected notifications with %s endpoint type: %d", endpointTest.name, trackerAnalysis.UnexpectedNotificationCount)
}
// Validate we received all expected notification types
if trackerAnalysis.FailingOverCount == 0 {
e("Expected FAILING_OVER notifications with %s endpoint type, got none", endpointTest.name)
}
if trackerAnalysis.FailedOverCount == 0 {
e("Expected FAILED_OVER notifications with %s endpoint type, got none", endpointTest.name)
}
if trackerAnalysis.MigratingCount == 0 {
e("Expected MIGRATING notifications with %s endpoint type, got none", endpointTest.name)
}
if trackerAnalysis.MigratedCount == 0 {
e("Expected MIGRATED notifications with %s endpoint type, got none", endpointTest.name)
}
if trackerAnalysis.MovingCount == 0 {
e("Expected MOVING notifications with %s endpoint type, got none", endpointTest.name)
}
logAnalysis := logCollector.GetAnalysis()
if logAnalysis.TotalHandoffCount == 0 {
e("Expected at least one handoff with %s endpoint type, got none", endpointTest.name)
}
if logAnalysis.TotalHandoffCount != logAnalysis.SucceededHandoffCount {
e("Expected all handoffs to succeed with %s endpoint type, got %d failed", endpointTest.name, logAnalysis.FailedHandoffCount)
}
if errorsDetected {
logCollector.DumpLogs()
trackerAnalysis.Print(t)
logCollector.Clear()
tracker.Clear()
ef("[FAIL] Errors detected with %s endpoint type", endpointTest.name)
}
p("Endpoint type %s test completed successfully", endpointTest.name)
logCollector.GetAnalysis().Print(t)
trackerAnalysis.Print(t)
logCollector.Clear()
tracker.Clear()
})
}
t.Log("All endpoint types tested successfully")
}

View File

@@ -0,0 +1,513 @@
package e2e
import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
// TestPushNotifications tests Redis Enterprise push notifications (MOVING, MIGRATING, MIGRATED)
func TestPushNotifications(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()
// Setup: Create fresh database and client factory for this test
bdbID, factory, cleanup := SetupTestDatabaseAndFactory(t, ctx, "standalone")
defer cleanup()
t.Logf("[PUSH-NOTIFICATIONS] Created test database with bdb_id: %d", bdbID)
// Wait for database to be fully ready
time.Sleep(10 * time.Second)
var dump = true
var seqIDToObserve int64
var connIDToObserve uint64
var match string
var found bool
var status *ActionStatusResponse
var errorsDetected = false
var p = func(format string, args ...interface{}) {
printLog("PUSH-NOTIFICATIONS", false, format, args...)
}
var e = func(format string, args ...interface{}) {
errorsDetected = true
printLog("PUSH-NOTIFICATIONS", true, format, args...)
}
var ef = func(format string, args ...interface{}) {
printLog("PUSH-NOTIFICATIONS", true, format, args...)
t.FailNow()
}
logCollector.ClearLogs()
defer func() {
logCollector.Clear()
}()
// Get endpoint config from factory (now connected to new database)
endpointConfig := factory.GetConfig()
// Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
ef("Failed to create fault injector: %v", err)
}
minIdleConns := 5
poolSize := 10
maxConnections := 15
// Create Redis client with push notifications enabled
client, err := factory.Create("push-notification-client", &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 40 * time.Second, // 40 seconds
RelaxedTimeout: 10 * time.Second, // 10 seconds relaxed timeout
PostHandoffRelaxedDuration: 2 * time.Second, // 2 seconds post-handoff relaxed duration
MaxWorkers: 20,
EndpointType: maintnotifications.EndpointTypeExternalIP, // Use external IP for enterprise
},
ClientName: "push-notification-test-client",
})
if err != nil {
ef("Failed to create client: %v", err)
}
defer func() {
factory.DestroyAll()
}()
// Create timeout tracker
tracker := NewTrackingNotificationsHook()
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
setupNotificationHooks(client, tracker, logger)
defer func() {
tracker.Clear()
}()
// Verify initial connectivity
err = client.Ping(ctx).Err()
if err != nil {
ef("Failed to ping Redis: %v", err)
}
p("Client connected successfully, starting push notification test")
commandsRunner, _ := NewCommandRunner(client)
defer func() {
if dump {
p("Command runner stats:")
p("Operations: %d, Errors: %d, Timeout Errors: %d",
commandsRunner.GetStats().Operations, commandsRunner.GetStats().Errors, commandsRunner.GetStats().TimeoutErrors)
}
p("Stopping command runner...")
commandsRunner.Stop()
}()
p("Starting FAILING_OVER / FAILED_OVER notifications test...")
// Test: Trigger failover action to generate FAILING_OVER, FAILED_OVER notifications
p("Triggering failover action to generate push notifications...")
failoverResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "failover",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
ef("Failed to trigger failover action: %v", err)
}
go func() {
p("Waiting for FAILING_OVER notification")
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "FAILING_OVER")
}, 3*time.Minute)
commandsRunner.Stop()
}()
commandsRunner.FireCommandsUntilStop(ctx)
if !found {
ef("FAILING_OVER notification was not received within 3 minutes")
}
failingOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILING_OVER notification received. %v", failingOverData)
seqIDToObserve = int64(failingOverData["seqID"].(float64))
connIDToObserve = uint64(failingOverData["connID"].(float64))
go func() {
p("Waiting for FAILED_OVER notification on conn %d with seqID %d...", connIDToObserve, seqIDToObserve+1)
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return notificationType(s, "FAILED_OVER") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
}, 3*time.Minute)
commandsRunner.Stop()
}()
commandsRunner.FireCommandsUntilStop(ctx)
if !found {
ef("FAILED_OVER notification was not received within 3 minutes")
}
failedOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILED_OVER notification received. %v", failedOverData)
status, err = faultInjector.WaitForAction(ctx, failoverResp.ActionID,
WithMaxWaitTime(240*time.Second),
WithPollInterval(2*time.Second),
)
if err != nil {
ef("[FI] Failover action failed: %v", err)
}
p("[FI] Failover action completed: %v %s", status.Status, actionOutputIfFailed(status))
p("FAILING_OVER / FAILED_OVER notifications test completed successfully")
// Test: Trigger migrate action to generate MOVING, MIGRATING, MIGRATED notifications
p("Triggering migrate action to generate push notifications...")
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "migrate",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
ef("Failed to trigger migrate action: %v", err)
}
go func() {
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
}, 60*time.Second)
commandsRunner.Stop()
}()
commandsRunner.FireCommandsUntilStop(ctx)
if !found {
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
WithMaxWaitTime(240*time.Second),
WithPollInterval(2*time.Second),
)
if err != nil {
ef("[FI] Migrate action failed: %v", err)
}
p("[FI] Migrate action completed: %s %s", status.Status, actionOutputIfFailed(status))
ef("MIGRATING notification for migrate action was not received within 60 seconds")
}
migrateData := logs2.ExtractDataFromLogMessage(match)
seqIDToObserve = int64(migrateData["seqID"].(float64))
connIDToObserve = uint64(migrateData["connID"].(float64))
p("MIGRATING notification received: seqID: %d, connID: %d", seqIDToObserve, connIDToObserve)
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
WithMaxWaitTime(240*time.Second),
WithPollInterval(2*time.Second),
)
if err != nil {
ef("[FI] Migrate action failed: %v", err)
}
p("[FI] Migrate action completed: %s %s", status.Status, actionOutputIfFailed(status))
go func() {
p("Waiting for MIGRATED notification on conn %d with seqID %d...", connIDToObserve, seqIDToObserve+1)
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return notificationType(s, "MIGRATED") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
}, 3*time.Minute)
commandsRunner.Stop()
}()
commandsRunner.FireCommandsUntilStop(ctx)
if !found {
ef("MIGRATED notification was not received within 3 minutes")
}
migratedData := logs2.ExtractDataFromLogMessage(match)
p("MIGRATED notification received. %v", migratedData)
p("MIGRATING / MIGRATED notifications test completed successfully")
// Trigger bind action to complete the migration process
p("Triggering bind action to complete migration...")
bindResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "bind",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
ef("Failed to trigger bind action: %v", err)
}
// start a second client but don't execute any commands on it yet (its runner starts once MOVING is observed)
p("Starting a second client to observe notification during moving...")
client2, err := factory.Create("push-notification-client-2", &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 40 * time.Second, // 40 seconds
RelaxedTimeout: 30 * time.Minute, // 30 minutes relaxed timeout for second client
PostHandoffRelaxedDuration: 2 * time.Second, // 2 seconds post-handoff relaxed duration
MaxWorkers: 20,
EndpointType: maintnotifications.EndpointTypeExternalIP, // Use external IP for enterprise
},
ClientName: "push-notification-test-client-2",
})
if err != nil {
ef("failed to create client: %v", err)
}
// setup tracking for second client
tracker2 := NewTrackingNotificationsHook()
logger2 := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
setupNotificationHooks(client2, tracker2, logger2)
commandsRunner2, _ := NewCommandRunner(client2)
p("Second client created")
// Use a channel to communicate errors from the goroutine
errChan := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("goroutine panic: %v", r)
}
}()
p("Waiting for MOVING notification on first client")
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "MOVING")
}, 3*time.Minute)
commandsRunner.Stop()
if !found {
errChan <- fmt.Errorf("MOVING notification was not received within 3 minutes ON A FIRST CLIENT")
return
}
// once moving is received, start a second client commands runner
p("Starting commands on second client")
go commandsRunner2.FireCommandsUntilStop(ctx)
defer func() {
// stop the second runner
commandsRunner2.Stop()
// destroy the second client
factory.Destroy("push-notification-client-2")
}()
p("Waiting for MOVING notification on second client")
matchNotif, fnd := tracker2.FindOrWaitForNotification("MOVING", 3*time.Minute)
if !fnd {
errChan <- fmt.Errorf("MOVING notification was not received within 3 minutes ON A SECOND CLIENT")
return
} else {
p("MOVING notification received on second client %v", matchNotif)
}
// Signal success
errChan <- nil
}()
commandsRunner.FireCommandsUntilStop(ctx)
// wait for MOVING on the first client:
// once the commandsRunner stops, the wait on the logCollector match
// has completed and we can proceed
if !found {
ef("MOVING notification was not received within 3 minutes")
}
movingData := logs2.ExtractDataFromLogMessage(match)
p("MOVING notification received. %v", movingData)
seqIDToObserve = int64(movingData["seqID"].(float64))
connIDToObserve = uint64(movingData["connID"].(float64))
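// brief pause before starting the third client so it connects while the MOVING window is already counting down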
time.Sleep(3 * time.Second)
// start a third client and immediately run commands on it to observe the in-progress MOVING notification
p("Starting a third client to observe notification during moving...")
client3, err := factory.Create("push-notification-client-3", &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 40 * time.Second, // 40 seconds
RelaxedTimeout: 30 * time.Minute, // 30 minutes relaxed timeout for third client
PostHandoffRelaxedDuration: 2 * time.Second, // 2 seconds post-handoff relaxed duration
MaxWorkers: 20,
EndpointType: maintnotifications.EndpointTypeExternalIP, // Use external IP for enterprise
},
ClientName: "push-notification-test-client-3",
})
if err != nil {
ef("failed to create client: %v", err)
}
// setup tracking for third client
tracker3 := NewTrackingNotificationsHook()
logger3 := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
setupNotificationHooks(client3, tracker3, logger3)
commandsRunner3, _ := NewCommandRunner(client3)
p("Third client created")
go commandsRunner3.FireCommandsUntilStop(ctx)
// wait for moving on third client
movingNotification, found := tracker3.FindOrWaitForNotification("MOVING", 3*time.Minute)
if !found {
p("[NOTICE] MOVING notification was not received within 3 minutes ON A THIRD CLIENT")
} else {
p("MOVING notification received on third client. %v", movingNotification)
if len(movingNotification) != 4 {
p("[NOTICE] Invalid MOVING notification format: %s", movingNotification)
}
mNotifTimeS, ok := movingNotification[2].(int64)
if !ok {
p("[NOTICE] Invalid timeS in MOVING notification: %s", movingNotification)
}
// expect timeS to be less than 15, since the MOVING window was already in progress when this client connected
if mNotifTimeS >= 15 {
p("[NOTICE] Expected timeS < 15, got %d", mNotifTimeS)
}
}
commandsRunner3.Stop()
// Wait for the goroutine to complete and check for errors
if err := <-errChan; err != nil {
ef("Second client goroutine error: %v", err)
}
// Wait for bind action to complete
bindStatus, err := faultInjector.WaitForAction(ctx, bindResp.ActionID,
WithMaxWaitTime(240*time.Second),
WithPollInterval(2*time.Second))
if err != nil {
ef("Bind action failed: %v", err)
}
p("Bind action completed: %s %s", bindStatus.Status, actionOutputIfFailed(bindStatus))
p("MOVING notification test completed successfully")
p("Executing commands and collecting logs for analysis... This will take 30 seconds...")
go commandsRunner.FireCommandsUntilStop(ctx)
time.Sleep(30 * time.Second)
commandsRunner.Stop()
allLogsAnalysis := logCollector.GetAnalysis()
trackerAnalysis := tracker.GetAnalysis()
if allLogsAnalysis.TimeoutErrorsCount > 0 {
e("Unexpected timeout errors: %d", allLogsAnalysis.TimeoutErrorsCount)
}
if trackerAnalysis.UnexpectedNotificationCount > 0 {
e("Unexpected notifications: %d", trackerAnalysis.UnexpectedNotificationCount)
}
if trackerAnalysis.NotificationProcessingErrors > 0 {
e("Notification processing errors: %d", trackerAnalysis.NotificationProcessingErrors)
}
if allLogsAnalysis.RelaxedTimeoutCount == 0 {
e("Expected relaxed timeouts, got none")
}
if allLogsAnalysis.UnrelaxedTimeoutCount == 0 {
e("Expected unrelaxed timeouts, got none")
}
if allLogsAnalysis.UnrelaxedAfterMoving == 0 {
e("Expected unrelaxed timeouts after moving, got none")
}
if allLogsAnalysis.RelaxedPostHandoffCount == 0 {
e("Expected relaxed timeouts after post-handoff, got none")
}
// validate that the number of connections does not exceed the maximum;
// we started three clients, so allow up to 3x the per-client limit
if allLogsAnalysis.ConnectionCount > int64(maxConnections)*3 {
e("Expected no more than %d connections, got %d", maxConnections*3, allLogsAnalysis.ConnectionCount)
}
if allLogsAnalysis.ConnectionCount < int64(minIdleConns) {
e("Expected at least %d connections, got %d", minIdleConns, allLogsAnalysis.ConnectionCount)
}
// validate logs are present for all connections
for connID := range trackerAnalysis.connIds {
if len(allLogsAnalysis.connLogs[connID]) == 0 {
e("No logs found for connection %d", connID)
}
}
// validate number of notifications in tracker matches number of notifications in logs
// allow for more MOVING notifications in the logs since we started additional clients
if trackerAnalysis.TotalNotifications > allLogsAnalysis.TotalNotifications {
e("Expected %d or more notifications, got %d", trackerAnalysis.TotalNotifications, allLogsAnalysis.TotalNotifications)
}
// and per type
// allow for more MOVING notifications in the logs since we started additional clients
if trackerAnalysis.MovingCount > allLogsAnalysis.MovingCount {
e("Expected %d or more MOVING notifications, got %d", trackerAnalysis.MovingCount, allLogsAnalysis.MovingCount)
}
if trackerAnalysis.MigratingCount != allLogsAnalysis.MigratingCount {
e("Expected %d MIGRATING notifications, got %d", trackerAnalysis.MigratingCount, allLogsAnalysis.MigratingCount)
}
if trackerAnalysis.MigratedCount != allLogsAnalysis.MigratedCount {
e("Expected %d MIGRATED notifications, got %d", trackerAnalysis.MigratedCount, allLogsAnalysis.MigratedCount)
}
if trackerAnalysis.FailingOverCount != allLogsAnalysis.FailingOverCount {
e("Expected %d FAILING_OVER notifications, got %d", trackerAnalysis.FailingOverCount, allLogsAnalysis.FailingOverCount)
}
if trackerAnalysis.FailedOverCount != allLogsAnalysis.FailedOverCount {
e("Expected %d FAILED_OVER notifications, got %d", trackerAnalysis.FailedOverCount, allLogsAnalysis.FailedOverCount)
}
if trackerAnalysis.UnexpectedNotificationCount != allLogsAnalysis.UnexpectedCount {
e("Expected %d unexpected notifications, got %d", trackerAnalysis.UnexpectedNotificationCount, allLogsAnalysis.UnexpectedCount)
}
// unrelaxed (and relaxed) timeouts after MOVING won't be tracked by the hook, so we have to exclude them
if trackerAnalysis.UnrelaxedTimeoutCount != allLogsAnalysis.UnrelaxedTimeoutCount-allLogsAnalysis.UnrelaxedAfterMoving {
e("Expected %d unrelaxed timeouts, got %d", trackerAnalysis.UnrelaxedTimeoutCount, allLogsAnalysis.UnrelaxedTimeoutCount)
}
if trackerAnalysis.RelaxedTimeoutCount != allLogsAnalysis.RelaxedTimeoutCount-allLogsAnalysis.RelaxedPostHandoffCount {
e("Expected %d relaxed timeouts, got %d", trackerAnalysis.RelaxedTimeoutCount, allLogsAnalysis.RelaxedTimeoutCount)
}
// validate all handoffs succeeded
if allLogsAnalysis.FailedHandoffCount > 0 {
e("Expected no failed handoffs, got %d", allLogsAnalysis.FailedHandoffCount)
}
if allLogsAnalysis.SucceededHandoffCount == 0 {
e("Expected at least one successful handoff, got none")
}
if allLogsAnalysis.TotalHandoffCount != allLogsAnalysis.SucceededHandoffCount {
e("Expected total handoffs to match successful handoffs, got %d != %d", allLogsAnalysis.TotalHandoffCount, allLogsAnalysis.SucceededHandoffCount)
}
// no additional retries
if allLogsAnalysis.TotalHandoffRetries != allLogsAnalysis.TotalHandoffCount {
e("Expected no additional handoff retries, got %d", allLogsAnalysis.TotalHandoffRetries-allLogsAnalysis.TotalHandoffCount)
}
if errorsDetected {
logCollector.DumpLogs()
trackerAnalysis.Print(t)
logCollector.Clear()
tracker.Clear()
ef("[FAIL] Errors detected in push notification test")
}
p("Analysis complete, no errors found")
allLogsAnalysis.Print(t)
trackerAnalysis.Print(t)
p("Command runner stats:")
p("Operations: %d, Errors: %d, Timeout Errors: %d",
commandsRunner.GetStats().Operations, commandsRunner.GetStats().Errors, commandsRunner.GetStats().TimeoutErrors)
p("Push notification test completed successfully")
}

View File

@@ -0,0 +1,311 @@
package e2e
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
// TestStressPushNotifications tests push notifications under extreme stress conditions
func TestStressPushNotifications(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("[STRESS][SKIP] Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 35*time.Minute)
defer cancel()
// Setup: Create fresh database and client factory for this test
bdbID, factory, cleanup := SetupTestDatabaseAndFactory(t, ctx, "standalone")
defer cleanup()
t.Logf("[STRESS] Created test database with bdb_id: %d", bdbID)
// Wait for database to be fully ready
time.Sleep(10 * time.Second)
var dump = true
var errorsDetected = false
var p = func(format string, args ...interface{}) {
printLog("STRESS", false, format, args...)
}
var e = func(format string, args ...interface{}) {
errorsDetected = true
printLog("STRESS", true, format, args...)
}
var ef = func(format string, args ...interface{}) {
printLog("STRESS", true, format, args...)
t.FailNow()
}
logCollector.ClearLogs()
defer func() {
logCollector.Clear()
}()
// Get endpoint config from factory (now connected to new database)
endpointConfig := factory.GetConfig()
// Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
ef("Failed to create fault injector: %v", err)
}
// Extreme stress configuration
minIdleConns := 50
poolSize := 150
maxConnections := 200
numClients := 4
var clients []redis.UniversalClient
var trackers []*TrackingNotificationsHook
var commandRunners []*CommandRunner
// Create multiple clients for extreme stress
for i := 0; i < numClients; i++ {
client, err := factory.Create(fmt.Sprintf("stress-client-%d", i), &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 60 * time.Second, // Longer timeout for stress
RelaxedTimeout: 20 * time.Second, // Longer relaxed timeout
PostHandoffRelaxedDuration: 5 * time.Second, // Longer post-handoff duration
MaxWorkers: 50, // Maximum workers for stress
HandoffQueueSize: 1000, // Large queue for stress
EndpointType: maintnotifications.EndpointTypeExternalIP,
},
ClientName: fmt.Sprintf("stress-test-client-%d", i),
})
if err != nil {
ef("Failed to create stress client %d: %v", i, err)
}
clients = append(clients, client)
// Setup tracking for each client
tracker := NewTrackingNotificationsHook()
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelWarn)) // Minimal logging for stress
setupNotificationHooks(client, tracker, logger)
trackers = append(trackers, tracker)
// Create command runner for each client
commandRunner, _ := NewCommandRunner(client)
commandRunners = append(commandRunners, commandRunner)
}
defer func() {
if dump {
p("Pool stats:")
factory.PrintPoolStats(t)
}
for _, runner := range commandRunners {
runner.Stop()
}
factory.DestroyAll()
}()
// Verify initial connectivity for all clients
for i, client := range clients {
err = client.Ping(ctx).Err()
if err != nil {
ef("Failed to ping Redis with stress client %d: %v", i, err)
}
}
p("All %d stress clients connected successfully", numClients)
// Start extreme traffic load on all clients
var trafficWg sync.WaitGroup
for i, runner := range commandRunners {
trafficWg.Add(1)
go func(clientID int, r *CommandRunner) {
defer trafficWg.Done()
p("Starting extreme traffic load on stress client %d", clientID)
r.FireCommandsUntilStop(ctx)
}(i, runner)
}
// Wait for traffic to stabilize
time.Sleep(10 * time.Second)
// Trigger multiple concurrent fault injection actions
var actionWg sync.WaitGroup
var actionResults []string
var actionMutex sync.Mutex
actions := []struct {
name string
action string
delay time.Duration
}{
{"failover-1", "failover", 0},
{"migrate-1", "migrate", 5 * time.Second},
{"failover-2", "failover", 10 * time.Second},
}
p("Starting %d concurrent fault injection actions under extreme stress...", len(actions))
for _, action := range actions {
actionWg.Add(1)
go func(actionName, actionType string, delay time.Duration) {
defer actionWg.Done()
if delay > 0 {
time.Sleep(delay)
}
p("Triggering %s action under extreme stress...", actionName)
var resp *ActionResponse
var err error
switch actionType {
case "failover":
resp, err = faultInjector.TriggerAction(ctx, ActionRequest{
Type: "failover",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
case "migrate":
resp, err = faultInjector.TriggerAction(ctx, ActionRequest{
Type: "migrate",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
}
if err != nil {
e("Failed to trigger %s action: %v", actionName, err)
return
}
// Wait for action to complete
status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
WithMaxWaitTime(360*time.Second), // Longer wait time for stress
WithPollInterval(2*time.Second),
)
if err != nil {
e("[FI] %s action failed: %v", actionName, err)
return
}
actionMutex.Lock()
actionResults = append(actionResults, fmt.Sprintf("%s: %s %s", actionName, status.Status, actionOutputIfFailed(status)))
actionMutex.Unlock()
p("[FI] %s action completed: %s %s", actionName, status.Status, actionOutputIfFailed(status))
}(action.name, action.action, action.delay)
}
// Wait for all actions to complete
actionWg.Wait()
// Continue stress for a bit longer
p("All fault injection actions completed, continuing stress for 2 more minutes...")
time.Sleep(2 * time.Minute)
// Stop all command runners
for _, runner := range commandRunners {
runner.Stop()
}
trafficWg.Wait()
// Analyze stress test results
allLogsAnalysis := logCollector.GetAnalysis()
totalOperations := int64(0)
totalErrors := int64(0)
totalTimeoutErrors := int64(0)
for i, runner := range commandRunners {
stats := runner.GetStats()
p("Stress client %d stats: Operations: %d, Errors: %d, Timeout Errors: %d",
i, stats.Operations, stats.Errors, stats.TimeoutErrors)
totalOperations += stats.Operations
totalErrors += stats.Errors
totalTimeoutErrors += stats.TimeoutErrors
}
p("STRESS TEST RESULTS:")
p("Total operations across all clients: %d", totalOperations)
p("Total errors: %d (%.2f%%)", totalErrors, float64(totalErrors)/float64(totalOperations)*100)
p("Total timeout errors: %d (%.2f%%)", totalTimeoutErrors, float64(totalTimeoutErrors)/float64(totalOperations)*100)
p("Total connections used: %d", allLogsAnalysis.ConnectionCount)
// Print action results
actionMutex.Lock()
p("Fault injection action results:")
for _, result := range actionResults {
p(" %s", result)
}
actionMutex.Unlock()
// Validate stress test results
if totalOperations < 1000 {
e("Expected at least 1000 operations under stress, got %d", totalOperations)
}
// Allow higher error rates under extreme stress (up to 20%)
errorRate := float64(totalErrors) / float64(totalOperations) * 100
if errorRate > 20.0 {
e("Error rate too high under stress: %.2f%% (max allowed: 20%%)", errorRate)
}
// Validate connection limits weren't exceeded
expectedMaxConnections := int64(numClients * maxConnections)
if allLogsAnalysis.ConnectionCount > expectedMaxConnections {
e("Connection count exceeded limit: %d > %d", allLogsAnalysis.ConnectionCount, expectedMaxConnections)
}
// Validate notifications were processed
totalTrackerNotifications := int64(0)
totalProcessingErrors := int64(0)
for _, tracker := range trackers {
analysis := tracker.GetAnalysis()
totalTrackerNotifications += analysis.TotalNotifications
totalProcessingErrors += analysis.NotificationProcessingErrors
}
if totalProcessingErrors > totalTrackerNotifications/10 { // Allow up to 10% processing errors under stress
e("Too many notification processing errors under stress: %d/%d", totalProcessingErrors, totalTrackerNotifications)
}
if errorsDetected {
logCollector.DumpLogs()
for i, tracker := range trackers {
p("=== Stress Client %d Analysis ===", i)
tracker.GetAnalysis().Print(t)
}
logCollector.Clear()
for _, tracker := range trackers {
tracker.Clear()
}
ef("Errors detected under stress")
}
dump = false
p("[SUCCESS] Stress test completed successfully!")
p("Processed %d operations across %d clients with %d connections",
totalOperations, numClients, allLogsAnalysis.ConnectionCount)
p("Error rate: %.2f%%, Notification processing errors: %d/%d",
errorRate, totalProcessingErrors, totalTrackerNotifications)
// Print final analysis
allLogsAnalysis.Print(t)
for i, tracker := range trackers {
p("=== Stress Client %d Analysis ===", i)
tracker.GetAnalysis().Print(t)
}
}

View File

@@ -0,0 +1,245 @@
package e2e
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/hitless"
)
// TestScenarioTemplate is a template for writing scenario tests
// Copy this file and rename it to scenario_your_test_name.go
func TestScenarioTemplate(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
// Step 1: Create client factory from configuration
factory, err := CreateTestClientFactory("enterprise-cluster") // or "standalone0"
if err != nil {
t.Fatalf("Failed to create client factory: %v", err)
}
defer factory.DestroyAll()
// Step 2: Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
t.Fatalf("Failed to create fault injector: %v", err)
}
// Step 3: Create Redis client with hitless upgrades
client, err := factory.Create("scenario-client", &CreateClientOptions{
Protocol: 3,
HitlessUpgradeConfig: &hitless.Config{
Mode: hitless.MaintNotificationsEnabled,
HandoffTimeout: 30000, // 30 seconds
RelaxedTimeout: 10000, // 10 seconds
MaxWorkers: 20,
},
ClientName: "scenario-test-client",
})
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Step 4: Verify initial connectivity
err = client.Ping(ctx).Err()
if err != nil {
t.Fatalf("Failed to ping Redis: %v", err)
}
t.Log("Initial setup completed successfully")
// Step 5: Start background operations (optional)
stopCh := make(chan struct{})
defer close(stopCh)
go func() {
counter := 0
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-stopCh:
return
case <-ticker.C:
key := fmt.Sprintf("test-key-%d", counter)
value := fmt.Sprintf("test-value-%d", counter)
err := client.Set(ctx, key, value, time.Minute).Err()
if err != nil {
t.Logf("Background operation failed: %v", err)
}
counter++
}
}
}()
// Step 6: Wait for baseline operations
time.Sleep(5 * time.Second)
// Step 7: Trigger fault injection scenario
t.Log("Triggering fault injection scenario...")
// Example: Cluster failover
// resp, err := faultInjector.TriggerClusterFailover(ctx, "node-1", false)
// if err != nil {
// t.Fatalf("Failed to trigger failover: %v", err)
// }
// Example: Network latency
// nodes := []string{"localhost:7001", "localhost:7002"}
// resp, err := faultInjector.SimulateNetworkLatency(ctx, nodes, 100*time.Millisecond, 20*time.Millisecond)
// if err != nil {
// t.Fatalf("Failed to simulate latency: %v", err)
// }
// Example: Complex sequence
// sequence := []SequenceAction{
// {
// Type: ActionNetworkLatency,
// Parameters: map[string]interface{}{
// "nodes": []string{"localhost:7001"},
// "latency": "50ms",
// },
// },
// {
// Type: ActionClusterFailover,
// Parameters: map[string]interface{}{
// "node_id": "node-1",
// "force": false,
// },
// Delay: 10 * time.Second,
// },
// }
// resp, err := faultInjector.ExecuteSequence(ctx, sequence)
// if err != nil {
// t.Fatalf("Failed to execute sequence: %v", err)
// }
// Step 8: Wait for fault injection to complete
// status, err := faultInjector.WaitForAction(ctx, resp.ActionID,
// WithMaxWaitTime(240*time.Second),
// WithPollInterval(2*time.Second))
// if err != nil {
// t.Fatalf("Fault injection failed: %v", err)
// }
// t.Logf("Fault injection completed: %s", status.Status)
// Step 9: Verify client remains operational during and after fault injection
time.Sleep(10 * time.Second)
err = client.Ping(ctx).Err()
if err != nil {
t.Errorf("Client not responsive after fault injection: %v", err)
}
// Step 10: Perform additional validation
testKey := "validation-key"
testValue := "validation-value"
err = client.Set(ctx, testKey, testValue, time.Minute).Err()
if err != nil {
t.Errorf("Failed to set validation key: %v", err)
}
retrievedValue, err := client.Get(ctx, testKey).Result()
if err != nil {
t.Errorf("Failed to get validation key: %v", err)
} else if retrievedValue != testValue {
t.Errorf("Validation failed: expected %s, got %s", testValue, retrievedValue)
}
t.Log("Scenario test completed successfully")
}
// Helper functions for common scenario patterns
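// performContinuousOperations issues a SET every 100ms with a 5-second per-operation timeout
// until stopCh is closed, sending failures to errorCh without blocking.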
func performContinuousOperations(ctx context.Context, client redis.UniversalClient, workerID int, stopCh <-chan struct{}, errorCh chan<- error) {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
counter := 0
for {
select {
case <-stopCh:
return
case <-ticker.C:
key := fmt.Sprintf("worker_%d_key_%d", workerID, counter)
value := fmt.Sprintf("value_%d", counter)
// Perform operation with timeout
opCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
err := client.Set(opCtx, key, value, time.Minute).Err()
cancel()
if err != nil {
select {
case errorCh <- err:
default:
}
}
counter++
}
}
}
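// validateClusterHealth pings the server and round-trips a test key to confirm basic read/write health.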
func validateClusterHealth(ctx context.Context, client redis.UniversalClient) error {
// Basic connectivity test
if err := client.Ping(ctx).Err(); err != nil {
return fmt.Errorf("ping failed: %w", err)
}
// Test basic operations
testKey := "health-check-key"
testValue := "health-check-value"
if err := client.Set(ctx, testKey, testValue, time.Minute).Err(); err != nil {
return fmt.Errorf("set operation failed: %w", err)
}
retrievedValue, err := client.Get(ctx, testKey).Result()
if err != nil {
return fmt.Errorf("get operation failed: %w", err)
}
if retrievedValue != testValue {
return fmt.Errorf("value mismatch: expected %s, got %s", testValue, retrievedValue)
}
// Clean up
client.Del(ctx, testKey)
return nil
}
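// waitForStableOperations runs validateClusterHealth once per second until the duration elapses,
// returning the first failed check or context error.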
func waitForStableOperations(ctx context.Context, client redis.UniversalClient, duration time.Duration) error {
deadline := time.Now().Add(duration)
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if err := validateClusterHealth(ctx, client); err != nil {
return fmt.Errorf("cluster health check failed: %w", err)
}
}
}
return nil
}

View File

@@ -0,0 +1,357 @@
package e2e
import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
// TestTimeoutConfigurationsPushNotifications tests push notifications with different timeout configurations
func TestTimeoutConfigurationsPushNotifications(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
var dump = true
var errorsDetected = false
var p = func(format string, args ...interface{}) {
printLog("TIMEOUT-CONFIGS", false, format, args...)
}
var e = func(format string, args ...interface{}) {
errorsDetected = true
printLog("TIMEOUT-CONFIGS", true, format, args...)
}
// Test different timeout configurations
timeoutConfigs := []struct {
name string
handoffTimeout time.Duration
relaxedTimeout time.Duration
postHandoffRelaxedDuration time.Duration
description string
expectedBehavior string
}{
{
name: "Conservative",
handoffTimeout: 60 * time.Second,
relaxedTimeout: 30 * time.Second,
postHandoffRelaxedDuration: 2 * time.Minute,
description: "Conservative timeouts for stable environments",
expectedBehavior: "Longer timeouts, fewer timeout errors",
},
{
name: "Aggressive",
handoffTimeout: 5 * time.Second,
relaxedTimeout: 3 * time.Second,
postHandoffRelaxedDuration: 1 * time.Second,
description: "Aggressive timeouts for fast failover",
expectedBehavior: "Shorter timeouts, faster recovery",
},
{
name: "HighLatency",
handoffTimeout: 90 * time.Second,
relaxedTimeout: 30 * time.Second,
postHandoffRelaxedDuration: 10 * time.Minute,
description: "High latency environment timeouts",
expectedBehavior: "Very long timeouts for high latency networks",
},
}
logCollector.ClearLogs()
defer func() {
logCollector.Clear()
}()
// Test each timeout configuration with its own fresh database
for _, timeoutTest := range timeoutConfigs {
t.Run(timeoutTest.name, func(t *testing.T) {
// Setup: Create fresh database and client factory for THIS timeout config test
bdbID, factory, cleanup := SetupTestDatabaseAndFactory(t, ctx, "standalone")
defer cleanup()
t.Logf("[TIMEOUT-CONFIGS-%s] Created test database with bdb_id: %d", timeoutTest.name, bdbID)
// Get endpoint config from factory (now connected to new database)
endpointConfig := factory.GetConfig()
// Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
t.Fatalf("[ERROR] Failed to create fault injector: %v", err)
}
defer func() {
if dump {
p("Pool stats:")
factory.PrintPoolStats(t)
}
}()
errorsDetected = false
var ef = func(format string, args ...interface{}) {
printLog("TIMEOUT-CONFIGS", true, format, args...)
t.FailNow()
}
p("Testing timeout configuration: %s - %s", timeoutTest.name, timeoutTest.description)
p("Expected behavior: %s", timeoutTest.expectedBehavior)
p("Handoff timeout: %v, Relaxed timeout: %v, Post-handoff duration: %v",
timeoutTest.handoffTimeout, timeoutTest.relaxedTimeout, timeoutTest.postHandoffRelaxedDuration)
minIdleConns := 4
poolSize := 10
maxConnections := 15
// Create Redis client with specific timeout configuration
client, err := factory.Create(fmt.Sprintf("timeout-test-%s", timeoutTest.name), &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: timeoutTest.handoffTimeout,
RelaxedTimeout: timeoutTest.relaxedTimeout,
PostHandoffRelaxedDuration: timeoutTest.postHandoffRelaxedDuration,
MaxWorkers: 20,
EndpointType: maintnotifications.EndpointTypeExternalIP,
},
ClientName: fmt.Sprintf("timeout-test-%s", timeoutTest.name),
})
if err != nil {
ef("Failed to create client for %s: %v", timeoutTest.name, err)
}
// Create timeout tracker
tracker := NewTrackingNotificationsHook()
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
setupNotificationHooks(client, tracker, logger)
defer func() {
tracker.Clear()
}()
// Verify initial connectivity
err = client.Ping(ctx).Err()
if err != nil {
ef("Failed to ping Redis with %s timeout config: %v", timeoutTest.name, err)
}
p("Client connected successfully with %s timeout configuration", timeoutTest.name)
commandsRunner, _ := NewCommandRunner(client)
defer func() {
if dump {
stats := commandsRunner.GetStats()
p("%s timeout config stats: Operations: %d, Errors: %d, Timeout Errors: %d",
timeoutTest.name, stats.Operations, stats.Errors, stats.TimeoutErrors)
}
commandsRunner.Stop()
}()
// Start command traffic
go func() {
commandsRunner.FireCommandsUntilStop(ctx)
}()
// Record start time for timeout analysis
testStartTime := time.Now()
// Test failover with this timeout configuration
p("Testing failover with %s timeout configuration...", timeoutTest.name)
failoverResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "failover",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
ef("Failed to trigger failover action for %s: %v", timeoutTest.name, err)
}
// Wait for FAILING_OVER notification
match, found := logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "FAILING_OVER")
}, 3*time.Minute)
if !found {
ef("FAILING_OVER notification was not received for %s timeout config", timeoutTest.name)
}
failingOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILING_OVER notification received for %s. %v", timeoutTest.name, failingOverData)
// Wait for FAILED_OVER notification
seqIDToObserve := int64(failingOverData["seqID"].(float64))
connIDToObserve := uint64(failingOverData["connID"].(float64))
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return notificationType(s, "FAILED_OVER") && connID(s, connIDToObserve) && seqID(s, seqIDToObserve+1)
}, 3*time.Minute)
if !found {
ef("FAILED_OVER notification was not received for %s timeout config", timeoutTest.name)
}
failedOverData := logs2.ExtractDataFromLogMessage(match)
p("FAILED_OVER notification received for %s. %v", timeoutTest.name, failedOverData)
// Wait for failover to complete
status, err := faultInjector.WaitForAction(ctx, failoverResp.ActionID,
WithMaxWaitTime(180*time.Second),
WithPollInterval(2*time.Second),
)
if err != nil {
ef("[FI] Failover action failed for %s: %v", timeoutTest.name, err)
}
p("[FI] Failover action completed for %s: %s %s", timeoutTest.name, status.Status, actionOutputIfFailed(status))
// Continue traffic to observe timeout behavior
p("Continuing traffic for %v to observe timeout behavior...", timeoutTest.relaxedTimeout*2)
time.Sleep(timeoutTest.relaxedTimeout * 2)
// Test migration to trigger more timeout scenarios
p("Testing migration with %s timeout configuration...", timeoutTest.name)
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "migrate",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
ef("Failed to trigger migrate action for %s: %v", timeoutTest.name, err)
}
// Wait for migration to complete
status, err = faultInjector.WaitForAction(ctx, migrateResp.ActionID,
WithMaxWaitTime(240*time.Second),
WithPollInterval(2*time.Second),
)
if err != nil {
ef("[FI] Migrate action failed for %s: %v", timeoutTest.name, err)
}
p("[FI] Migrate action completed for %s: %s %s", timeoutTest.name, status.Status, actionOutputIfFailed(status))
// Wait for MIGRATING notification
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
}, 60*time.Second)
if !found {
ef("MIGRATING notification was not received for %s timeout config", timeoutTest.name)
}
migrateData := logs2.ExtractDataFromLogMessage(match)
p("MIGRATING notification received for %s: %v", timeoutTest.name, migrateData)
// Trigger a bind action
bindResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "bind",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
ef("Failed to trigger bind action for %s: %v", timeoutTest.name, err)
}
status, err = faultInjector.WaitForAction(ctx, bindResp.ActionID,
WithMaxWaitTime(240*time.Second),
WithPollInterval(2*time.Second),
)
if err != nil {
ef("[FI] Bind action failed for %s: %v", timeoutTest.name, err)
}
p("[FI] Bind action completed for %s: %s %s", timeoutTest.name, status.Status, actionOutputIfFailed(status))
// Wait for MOVING notification
match, found = logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && notificationType(s, "MOVING")
}, 3*time.Minute)
if !found {
ef("MOVING notification was not received for %s timeout config", timeoutTest.name)
}
movingData := logs2.ExtractDataFromLogMessage(match)
p("MOVING notification received for %s. %v", timeoutTest.name, movingData)
// Continue traffic for post-handoff timeout observation
p("Continuing traffic for %v to observe post-handoff timeout behavior...", 1*time.Minute)
time.Sleep(1 * time.Minute)
commandsRunner.Stop()
testDuration := time.Since(testStartTime)
// Analyze timeout behavior
trackerAnalysis := tracker.GetAnalysis()
logAnalysis := logCollector.GetAnalysis()
if trackerAnalysis.NotificationProcessingErrors > 0 {
e("Notification processing errors with %s timeout config: %d", timeoutTest.name, trackerAnalysis.NotificationProcessingErrors)
}
// Validate timeout-specific behavior
switch timeoutTest.name {
case "Conservative":
if trackerAnalysis.UnrelaxedTimeoutCount > trackerAnalysis.RelaxedTimeoutCount {
e("Conservative config should have more relaxed than unrelaxed timeouts, got relaxed=%d, unrelaxed=%d",
trackerAnalysis.RelaxedTimeoutCount, trackerAnalysis.UnrelaxedTimeoutCount)
}
case "Aggressive":
// Aggressive timeouts should complete faster
if testDuration > 5*time.Minute {
e("Aggressive config took too long: %v", testDuration)
}
if logAnalysis.TotalHandoffRetries > logAnalysis.TotalHandoffCount {
e("Expect handoff retries since aggressive timeouts are shorter, got %d retries for %d handoffs",
logAnalysis.TotalHandoffRetries, logAnalysis.TotalHandoffCount)
}
case "HighLatency":
// High latency config should have very few unrelaxed after moving
if logAnalysis.UnrelaxedAfterMoving > 2 {
e("High latency config should have minimal unrelaxed timeouts after moving, got %d", logAnalysis.UnrelaxedAfterMoving)
}
}
// Validate we received expected notifications
if trackerAnalysis.FailingOverCount == 0 {
e("Expected FAILING_OVER notifications with %s timeout config, got none", timeoutTest.name)
}
if trackerAnalysis.FailedOverCount == 0 {
e("Expected FAILED_OVER notifications with %s timeout config, got none", timeoutTest.name)
}
if trackerAnalysis.MigratingCount == 0 {
e("Expected MIGRATING notifications with %s timeout config, got none", timeoutTest.name)
}
// Validate timeout counts are reasonable
if trackerAnalysis.RelaxedTimeoutCount == 0 {
e("Expected relaxed timeouts with %s config, got none", timeoutTest.name)
}
if logAnalysis.SucceededHandoffCount == 0 {
e("Expected successful handoffs with %s config, got none", timeoutTest.name)
}
if errorsDetected {
logCollector.DumpLogs()
trackerAnalysis.Print(t)
logCollector.Clear()
tracker.Clear()
ef("[FAIL] Errors detected with %s timeout config", timeoutTest.name)
}
p("Timeout configuration %s test completed successfully in %v", timeoutTest.name, testDuration)
p("Command runner stats:")
p("Operations: %d, Errors: %d, Timeout Errors: %d",
commandsRunner.GetStats().Operations, commandsRunner.GetStats().Errors, commandsRunner.GetStats().TimeoutErrors)
p("Relaxed timeouts: %d, Unrelaxed timeouts: %d", trackerAnalysis.RelaxedTimeoutCount, trackerAnalysis.UnrelaxedTimeoutCount)
})
// Clear logs between timeout configuration tests
logCollector.ClearLogs()
}
p("All timeout configurations tested successfully")
}

View File

@@ -0,0 +1,261 @@
package e2e
import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
logs2 "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/logging"
"github.com/redis/go-redis/v9/maintnotifications"
)
// TODO ADD TLS CONFIGS
// TestTLSConfigurationsPushNotifications tests push notifications with different TLS configurations
func TestTLSConfigurationsPushNotifications(t *testing.T) {
if os.Getenv("E2E_SCENARIO_TESTS") != "true" {
t.Skip("Scenario tests require E2E_SCENARIO_TESTS=true")
}
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Minute)
defer cancel()
var dump = true
var errorsDetected = false
var p = func(format string, args ...interface{}) {
printLog("TLS-CONFIGS", false, format, args...)
}
var e = func(format string, args ...interface{}) {
errorsDetected = true
printLog("TLS-CONFIGS", true, format, args...)
}
// Test different TLS configurations
// Note: TLS configuration is typically handled at the Redis connection config level
// This scenario demonstrates the testing pattern for different TLS setups
tlsConfigs := []struct {
name string
description string
skipReason string
}{
{
name: "NoTLS",
description: "No TLS encryption (plain text)",
},
{
name: "TLSInsecure",
description: "TLS with insecure skip verify (testing only)",
},
{
name: "TLSSecure",
description: "Secure TLS with certificate verification",
skipReason: "Requires valid certificates in test environment",
},
{
name: "TLSMinimal",
description: "TLS with minimal version requirements",
},
{
name: "TLSStrict",
description: "Strict TLS with TLS 1.3 and specific cipher suites",
},
}
logCollector.ClearLogs()
defer func() {
logCollector.Clear()
}()
// Test each TLS configuration with its own fresh database
for _, tlsTest := range tlsConfigs {
t.Run(tlsTest.name, func(t *testing.T) {
// Setup: Create fresh database and client factory for THIS TLS config test
bdbID, factory, cleanup := SetupTestDatabaseAndFactory(t, ctx, "standalone")
defer cleanup()
t.Logf("[TLS-CONFIGS-%s] Created test database with bdb_id: %d", tlsTest.name, bdbID)
// Get endpoint config from factory (now connected to new database)
endpointConfig := factory.GetConfig()
// Create fault injector
faultInjector, err := CreateTestFaultInjector()
if err != nil {
t.Fatalf("[ERROR] Failed to create fault injector: %v", err)
}
defer func() {
if dump {
p("Pool stats:")
factory.PrintPoolStats(t)
}
}()
errorsDetected = false
var ef = func(format string, args ...interface{}) {
printLog("TLS-CONFIGS", true, format, args...)
t.FailNow()
}
if tlsTest.skipReason != "" {
t.Skipf("Skipping %s: %s", tlsTest.name, tlsTest.skipReason)
}
p("Testing TLS configuration: %s - %s", tlsTest.name, tlsTest.description)
minIdleConns := 3
poolSize := 8
maxConnections := 12
// Create Redis client with specific TLS configuration
// Note: TLS configuration is handled at the factory/connection level
client, err := factory.Create(fmt.Sprintf("tls-test-%s", tlsTest.name), &CreateClientOptions{
Protocol: 3, // RESP3 required for push notifications
PoolSize: poolSize,
MinIdleConns: minIdleConns,
MaxActiveConns: maxConnections,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
HandoffTimeout: 30 * time.Second,
RelaxedTimeout: 10 * time.Second,
PostHandoffRelaxedDuration: 2 * time.Second,
MaxWorkers: 15,
EndpointType: maintnotifications.EndpointTypeExternalIP,
},
ClientName: fmt.Sprintf("tls-test-%s", tlsTest.name),
})
if err != nil {
// Some TLS configurations might fail in test environments
if tlsTest.name == "TLSSecure" || tlsTest.name == "TLSStrict" {
t.Skipf("TLS configuration %s failed (expected in test environment): %v", tlsTest.name, err)
}
ef("Failed to create client for %s: %v", tlsTest.name, err)
}
// Create timeout tracker
tracker := NewTrackingNotificationsHook()
logger := maintnotifications.NewLoggingHook(int(logging.LogLevelDebug))
setupNotificationHooks(client, tracker, logger)
defer func() {
tracker.Clear()
}()
// Verify initial connectivity
err = client.Ping(ctx).Err()
if err != nil {
if tlsTest.name == "TLSSecure" || tlsTest.name == "TLSStrict" {
t.Skipf("TLS configuration %s ping failed (expected in test environment): %v", tlsTest.name, err)
}
ef("Failed to ping Redis with %s TLS config: %v", tlsTest.name, err)
}
p("Client connected successfully with %s TLS configuration", tlsTest.name)
commandsRunner, _ := NewCommandRunner(client)
defer func() {
if dump {
stats := commandsRunner.GetStats()
p("%s TLS config stats: Operations: %d, Errors: %d, Timeout Errors: %d",
tlsTest.name, stats.Operations, stats.Errors, stats.TimeoutErrors)
}
commandsRunner.Stop()
}()
// Start command traffic
go func() {
commandsRunner.FireCommandsUntilStop(ctx)
}()
// Test migration with this TLS configuration
p("Testing migration with %s TLS configuration...", tlsTest.name)
migrateResp, err := faultInjector.TriggerAction(ctx, ActionRequest{
Type: "migrate",
Parameters: map[string]interface{}{
"bdb_id": endpointConfig.BdbID,
},
})
if err != nil {
ef("Failed to trigger migrate action for %s: %v", tlsTest.name, err)
}
// Wait for MIGRATING notification
match, found := logCollector.MatchOrWaitForLogMatchFunc(func(s string) bool {
return strings.Contains(s, logs2.ProcessingNotificationMessage) && strings.Contains(s, "MIGRATING")
}, 60*time.Second)
if !found {
ef("MIGRATING notification was not received for %s TLS config", tlsTest.name)
}
migrateData := logs2.ExtractDataFromLogMessage(match)
p("MIGRATING notification received for %s: %v", tlsTest.name, migrateData)
// Wait for migration to complete
status, err := faultInjector.WaitForAction(ctx, migrateResp.ActionID,
WithMaxWaitTime(240*time.Second),
WithPollInterval(2*time.Second),
)
if err != nil {
ef("[FI] Migrate action failed for %s: %v", tlsTest.name, err)
}
p("[FI] Migrate action completed for %s: %s %s", tlsTest.name, status.Status, actionOutputIfFailed(status))
// Continue traffic for a bit to observe TLS behavior
time.Sleep(5 * time.Second)
commandsRunner.Stop()
// Analyze results for this TLS configuration
trackerAnalysis := tracker.GetAnalysis()
if trackerAnalysis.NotificationProcessingErrors > 0 {
e("Notification processing errors with %s TLS config: %d", tlsTest.name, trackerAnalysis.NotificationProcessingErrors)
}
if trackerAnalysis.UnexpectedNotificationCount > 0 {
e("Unexpected notifications with %s TLS config: %d", tlsTest.name, trackerAnalysis.UnexpectedNotificationCount)
}
// Validate we received expected notifications
if trackerAnalysis.FailingOverCount == 0 {
e("Expected FAILING_OVER notifications with %s TLS config, got none", tlsTest.name)
}
if trackerAnalysis.FailedOverCount == 0 {
e("Expected FAILED_OVER notifications with %s TLS config, got none", tlsTest.name)
}
if trackerAnalysis.MigratingCount == 0 {
e("Expected MIGRATING notifications with %s TLS config, got none", tlsTest.name)
}
if errorsDetected {
logCollector.DumpLogs()
trackerAnalysis.Print(t)
logCollector.Clear()
tracker.Clear()
ef("[FAIL] Errors detected with %s TLS config", tlsTest.name)
}
// TLS-specific validations
stats := commandsRunner.GetStats()
switch tlsTest.name {
case "NoTLS":
// Plain text should work fine
p("Plain text connection processed %d operations", stats.Operations)
case "TLSInsecure", "TLSMinimal":
// Insecure TLS should work in test environments
p("Insecure TLS connection processed %d operations", stats.Operations)
if stats.Operations == 0 {
e("Expected operations with %s TLS config, got none", tlsTest.name)
}
case "TLSStrict":
// Strict TLS might have different performance characteristics
p("Strict TLS connection processed %d operations", stats.Operations)
}
p("TLS configuration %s test completed successfully", tlsTest.name)
})
// Clear logs between TLS configuration tests
logCollector.ClearLogs()
}
p("All TLS configurations tested successfully")
}

View File

@@ -0,0 +1,213 @@
#!/bin/bash
# Maintenance Notifications E2E Tests Runner
# This script sets up the environment and runs the maintnotifications upgrade E2E tests
set -euo pipefail
# Script directory and repository root
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
E2E_DIR="${REPO_ROOT}/maintnotifications/e2e"
# Configuration
FAULT_INJECTOR_URL="http://127.0.0.1:20324"
CONFIG_PATH="${REPO_ROOT}/maintnotifications/e2e/infra/cae-client-testing/endpoints.json"
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Logging functions
log_info() {
echo -e "${BLUE}[INFO]${NC} $1" >&2
}
log_success() {
echo -e "${GREEN}[SUCCESS]${NC} $1" >&2
}
log_warning() {
echo -e "${YELLOW}[WARNING]${NC} $1" >&2
}
log_error() {
echo -e "${RED}[ERROR]${NC} $1" >&2
}
# Help function
show_help() {
cat << EOF
Maintenance Notifications E2E Tests Runner
Usage: $0 [OPTIONS]
OPTIONS:
-h, --help Show this help message
-v, --verbose Enable verbose test output
-t, --timeout DURATION Test timeout (default: 30m)
-r, --run PATTERN Run only tests matching pattern
--dry-run Show what would be executed without running
--list List available tests
--config PATH Override config path (default: infra/cae-client-testing/endpoints.json)
--fault-injector URL Override fault injector URL (default: http://127.0.0.1:20324)
EXAMPLES:
$0 # Run all E2E tests
$0 -v # Run with verbose output
$0 -r TestPushNotifications # Run only push notification tests
$0 -t 45m # Run with 45 minute timeout
$0 --dry-run # Show what would be executed
$0 --list # List available tests
ENVIRONMENT:
The script automatically sets up the required environment variables:
- REDIS_ENDPOINTS_CONFIG_PATH: Path to Redis endpoints configuration
- FAULT_INJECTION_API_URL: URL of the fault injector server
- E2E_SCENARIO_TESTS: Enables scenario tests
EOF
}
# Parse command line arguments
VERBOSE=""
TIMEOUT="30m"
RUN_PATTERN=""
DRY_RUN=false
LIST_TESTS=false
while [[ $# -gt 0 ]]; do
case $1 in
-h|--help)
show_help
exit 0
;;
-v|--verbose)
VERBOSE="-v"
shift
;;
-t|--timeout)
TIMEOUT="$2"
shift 2
;;
-r|--run)
RUN_PATTERN="$2"
shift 2
;;
--dry-run)
DRY_RUN=true
shift
;;
--list)
LIST_TESTS=true
shift
;;
--config)
CONFIG_PATH="$2"
shift 2
;;
--fault-injector)
FAULT_INJECTOR_URL="$2"
shift 2
;;
*)
log_error "Unknown option: $1"
show_help
exit 1
;;
esac
done
# Validate configuration file exists
if [[ ! -f "$CONFIG_PATH" ]]; then
log_error "Configuration file not found: $CONFIG_PATH"
log_info "Please ensure the endpoints.json file exists at the specified path"
exit 1
fi
# Set up environment variables
export REDIS_ENDPOINTS_CONFIG_PATH="$CONFIG_PATH"
export FAULT_INJECTION_API_URL="$FAULT_INJECTOR_URL"
export E2E_SCENARIO_TESTS="true"
# Build test command
TEST_CMD="go test -json -tags=e2e"
if [[ -n "$TIMEOUT" ]]; then
TEST_CMD="$TEST_CMD -timeout=$TIMEOUT"
fi
# Note: the -v flag is intentionally not appended here; the -json output format
# already streams per-test events, so it provides the verbose information itself
if [[ -n "$RUN_PATTERN" ]]; then
TEST_CMD="$TEST_CMD -run $RUN_PATTERN"
fi
TEST_CMD="$TEST_CMD ./maintnotifications/e2e/ "
# List tests if requested
if [[ "$LIST_TESTS" == true ]]; then
log_info "Available E2E tests:"
cd "$REPO_ROOT"
go test -tags=e2e ./maintnotifications/e2e/ -list=. | grep -E "^Test" | sort
exit 0
fi
# Show configuration
log_info "Maintenance notifications E2E Tests Configuration:"
echo " Repository Root: $REPO_ROOT" >&2
echo " E2E Directory: $E2E_DIR" >&2
echo " Config Path: $CONFIG_PATH" >&2
echo " Fault Injector URL: $FAULT_INJECTOR_URL" >&2
echo " Test Timeout: $TIMEOUT" >&2
if [[ -n "$RUN_PATTERN" ]]; then
echo " Test Pattern: $RUN_PATTERN" >&2
fi
echo "" >&2
# Validate fault injector connectivity
log_info "Checking fault injector connectivity..."
if command -v curl >/dev/null 2>&1; then
if curl -s --connect-timeout 5 "$FAULT_INJECTOR_URL/health" >/dev/null 2>&1; then
log_success "Fault injector is accessible at $FAULT_INJECTOR_URL"
else
log_warning "Cannot connect to fault injector at $FAULT_INJECTOR_URL"
log_warning "Tests may fail if fault injection is required"
fi
else
log_warning "curl not available, skipping fault injector connectivity check"
fi
# Show what would be executed in dry-run mode
if [[ "$DRY_RUN" == true ]]; then
log_info "Dry run mode - would execute:"
echo " cd $REPO_ROOT" >&2
echo " export REDIS_ENDPOINTS_CONFIG_PATH=\"$CONFIG_PATH\"" >&2
echo " export FAULT_INJECTION_API_URL=\"$FAULT_INJECTOR_URL\"" >&2
echo " export E2E_SCENARIO_TESTS=\"true\"" >&2
echo " $TEST_CMD" >&2
exit 0
fi
# Change to repository root
cd "$REPO_ROOT"
# Run the tests
log_info "Starting E2E tests..."
log_info "Command: $TEST_CMD"
echo "" >&2
if eval "$TEST_CMD"; then
echo "" >&2
log_success "All E2E tests completed successfully!"
exit 0
else
echo "" >&2
log_error "E2E tests failed!"
log_info "Check the test output above for details"
exit 1
fi

View File

@@ -0,0 +1,76 @@
package e2e
import (
"fmt"
"path/filepath"
"runtime"
"time"
)
func isTimeout(errMsg string) bool {
return contains(errMsg, "i/o timeout") ||
contains(errMsg, "deadline exceeded") ||
contains(errMsg, "context deadline exceeded")
}
// isTimeoutError checks if an error is a timeout error
func isTimeoutError(err error) bool {
if err == nil {
return false
}
// Check for various timeout error types
errStr := err.Error()
return isTimeout(errStr)
}
// contains reports whether s contains substr (case-sensitive)
func contains(s, substr string) bool {
return len(s) >= len(substr) &&
(s == substr ||
(len(s) > len(substr) &&
(s[:len(substr)] == substr ||
s[len(s)-len(substr):] == substr ||
containsSubstring(s, substr))))
}
func containsSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func printLog(group string, isError bool, format string, args ...interface{}) {
_, filename, line, _ := runtime.Caller(2)
filename = filepath.Base(filename)
finalFormat := "%s:%d [%s][%s] " + format + "\n"
if isError {
finalFormat = "%s:%d [%s][%s][ERROR] " + format + "\n"
}
ts := time.Now().Format("15:04:05.000")
args = append([]interface{}{filename, line, ts, group}, args...)
fmt.Printf(finalFormat, args...)
}
func actionOutputIfFailed(status *ActionStatusResponse) string {
if status.Status != StatusFailed {
return ""
}
if status.Error != nil {
return fmt.Sprintf("%v", status.Error)
}
if status.Output == nil {
return ""
}
return fmt.Sprintf("%+v", status.Output)
}

View File

@@ -0,0 +1,63 @@
package maintnotifications
import (
"errors"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
)
// Configuration errors
var (
ErrInvalidRelaxedTimeout = errors.New(logs.InvalidRelaxedTimeoutError())
ErrInvalidHandoffTimeout = errors.New(logs.InvalidHandoffTimeoutError())
ErrInvalidHandoffWorkers = errors.New(logs.InvalidHandoffWorkersError())
ErrInvalidHandoffQueueSize = errors.New(logs.InvalidHandoffQueueSizeError())
ErrInvalidPostHandoffRelaxedDuration = errors.New(logs.InvalidPostHandoffRelaxedDurationError())
ErrInvalidEndpointType = errors.New(logs.InvalidEndpointTypeError())
ErrInvalidMaintNotifications = errors.New(logs.InvalidMaintNotificationsError())
ErrMaxHandoffRetriesReached = errors.New(logs.MaxHandoffRetriesReachedError())
// Configuration validation errors
ErrInvalidHandoffRetries = errors.New(logs.InvalidHandoffRetriesError())
)
// Integration errors
var (
ErrInvalidClient = errors.New(logs.InvalidClientError())
)
// Handoff errors
var (
ErrHandoffQueueFull = errors.New(logs.HandoffQueueFullError())
)
// Notification errors
var (
ErrInvalidNotification = errors.New(logs.InvalidNotificationError())
)
// connection handoff errors
var (
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
// and should not be used until the handoff is complete
ErrConnectionMarkedForHandoff = errors.New("" + logs.ConnectionMarkedForHandoffErrorMessage)
// ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff
ErrConnectionInvalidHandoffState = errors.New("" + logs.ConnectionInvalidHandoffStateErrorMessage)
)
// general errors
var (
ErrShutdown = errors.New(logs.ShutdownError())
)
// circuit breaker errors
var (
ErrCircuitBreakerOpen = errors.New("" + logs.CircuitBreakerOpenErrorMessage)
)
// circuit breaker configuration errors
var (
ErrInvalidCircuitBreakerFailureThreshold = errors.New(logs.InvalidCircuitBreakerFailureThresholdError())
ErrInvalidCircuitBreakerResetTimeout = errors.New(logs.InvalidCircuitBreakerResetTimeoutError())
ErrInvalidCircuitBreakerMaxRequests = errors.New(logs.InvalidCircuitBreakerMaxRequestsError())
)
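
// isRetryableHandoffError is a minimal illustrative sketch, not part of the
// original file: it shows how callers could branch on the sentinel errors
// above with errors.Is. The retry classification is an assumption made for
// demonstration, not the library's prescribed policy.
func isRetryableHandoffError(err error) bool {
    switch {
    case errors.Is(err, ErrHandoffQueueFull):
        // Queue pressure is usually transient; a later retry may succeed.
        return true
    case errors.Is(err, ErrShutdown), errors.Is(err, ErrMaxHandoffRetriesReached):
        // Terminal conditions; retrying will not help.
        return false
    default:
        return false
    }
}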

View File

@@ -0,0 +1,101 @@
package maintnotifications
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// contextKey is a custom type for context keys to avoid collisions
type contextKey string
const (
startTimeKey contextKey = "maint_notif_start_time"
)
// MetricsHook collects metrics about notification processing.
type MetricsHook struct {
NotificationCounts map[string]int64
ProcessingTimes map[string]time.Duration
ErrorCounts map[string]int64
HandoffCounts int64 // Total handoffs initiated
HandoffSuccesses int64 // Successful handoffs
HandoffFailures int64 // Failed handoffs
}
// NewMetricsHook creates a new metrics collection hook.
func NewMetricsHook() *MetricsHook {
return &MetricsHook{
NotificationCounts: make(map[string]int64),
ProcessingTimes: make(map[string]time.Duration),
ErrorCounts: make(map[string]int64),
}
}
// PreHook records the start time for processing metrics.
func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
mh.NotificationCounts[notificationType]++
// Log connection information if available
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
internal.Logger.Printf(ctx, logs.MetricsHookProcessingNotification(notificationType, conn.GetID()))
}
// Record the start time for duration calculation. Note that PreHook cannot
// return a modified context, so the value-carrying context is discarded here
// and PostHook's ctx.Value lookup will not observe this start time.
startTime := time.Now()
_ = context.WithValue(ctx, startTimeKey, startTime)
return notification, true
}
// PostHook records processing completion and any errors.
func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
// Calculate processing duration
if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
duration := time.Since(startTime)
mh.ProcessingTimes[notificationType] = duration
}
// Record errors
if result != nil {
mh.ErrorCounts[notificationType]++
// Log error details with connection information
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
internal.Logger.Printf(ctx, logs.MetricsHookRecordedError(notificationType, conn.GetID(), result))
}
}
}
// GetMetrics returns a summary of collected metrics.
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
return map[string]interface{}{
"notification_counts": mh.NotificationCounts,
"processing_times": mh.ProcessingTimes,
"error_counts": mh.ErrorCounts,
}
}
// ExampleCircuitBreakerMonitor demonstrates how to monitor circuit breaker status
func ExampleCircuitBreakerMonitor(poolHook *PoolHook) {
// Get circuit breaker statistics
stats := poolHook.GetCircuitBreakerStats()
for _, stat := range stats {
fmt.Printf("Circuit Breaker for %s:\n", stat.Endpoint)
fmt.Printf(" State: %s\n", stat.State)
fmt.Printf(" Failures: %d\n", stat.Failures)
fmt.Printf(" Last Failure: %v\n", stat.LastFailureTime)
fmt.Printf(" Last Success: %v\n", stat.LastSuccessTime)
// Alert if circuit breaker is open
if stat.State.String() == "open" {
fmt.Printf(" ⚠️ ALERT: Circuit breaker is OPEN for %s\n", stat.Endpoint)
}
}
}
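
// ExampleMetricsHookUsage is a minimal sketch (assuming a *Manager created
// elsewhere via NewManager) of how the metrics hook could be attached and
// read; it is illustrative rather than a prescribed integration. Note that
// MetricsHook's maps are not synchronized, so guard access if notifications
// may be processed concurrently.
func ExampleMetricsHookUsage(manager *Manager) {
    metrics := NewMetricsHook()
    manager.AddNotificationHook(metrics)

    // ... after some notifications have been processed ...
    for name, value := range metrics.GetMetrics() {
        fmt.Printf("%s: %v\n", name, value)
    }
}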

View File

@@ -0,0 +1,480 @@
package maintnotifications
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
)
// handoffWorkerManager manages background workers and queue for connection handoffs
type handoffWorkerManager struct {
// Event-driven handoff support
handoffQueue chan HandoffRequest // Queue for handoff requests
shutdown chan struct{} // Shutdown signal
shutdownOnce sync.Once // Ensure clean shutdown
workerWg sync.WaitGroup // Track worker goroutines
// On-demand worker management
maxWorkers int
activeWorkers atomic.Int32
workerTimeout time.Duration // How long workers wait for work before exiting
workersScaling atomic.Bool
// Simple state tracking
pending sync.Map // map[uint64]int64 (connID -> seqID)
// Configuration for the maintenance notifications
config *Config
// Pool hook reference for handoff processing
poolHook *PoolHook
// Circuit breaker manager for endpoint failure handling
circuitBreakerManager *CircuitBreakerManager
}
// newHandoffWorkerManager creates a new handoff worker manager
func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
return &handoffWorkerManager{
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
shutdown: make(chan struct{}),
maxWorkers: config.MaxWorkers,
activeWorkers: atomic.Int32{}, // Start with no workers - create on demand
workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
config: config,
poolHook: poolHook,
circuitBreakerManager: newCircuitBreakerManager(config),
}
}
// getCurrentWorkers returns the current number of active workers (for testing)
func (hwm *handoffWorkerManager) getCurrentWorkers() int {
return int(hwm.activeWorkers.Load())
}
// getPendingMap returns the pending map for testing purposes
func (hwm *handoffWorkerManager) getPendingMap() *sync.Map {
return &hwm.pending
}
// getMaxWorkers returns the max workers for testing purposes
func (hwm *handoffWorkerManager) getMaxWorkers() int {
return hwm.maxWorkers
}
// getHandoffQueue returns the handoff queue for testing purposes
func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
return hwm.handoffQueue
}
// getCircuitBreakerStats returns circuit breaker statistics for monitoring
func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats {
return hwm.circuitBreakerManager.GetAllStats()
}
// resetCircuitBreakers resets all circuit breakers (useful for testing)
func (hwm *handoffWorkerManager) resetCircuitBreakers() {
hwm.circuitBreakerManager.Reset()
}
// isHandoffPending returns true if the given connection has a pending handoff
func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
_, pending := hwm.pending.Load(conn.GetID())
return pending
}
// ensureWorkerAvailable ensures at least one worker is available to process requests
// Creates a new worker if needed and under the max limit
func (hwm *handoffWorkerManager) ensureWorkerAvailable() {
select {
case <-hwm.shutdown:
return
default:
if hwm.workersScaling.CompareAndSwap(false, true) {
defer hwm.workersScaling.Store(false)
// Check if we need a new worker
currentWorkers := hwm.activeWorkers.Load()
workersWas := currentWorkers
for currentWorkers < int32(hwm.maxWorkers) {
hwm.workerWg.Add(1)
go hwm.onDemandWorker()
currentWorkers++
}
// workersWas is always <= currentWorkers. At this point currentWorkers equals
// maxWorkers, but a worker may have exited while we were creating new ones, so
// only add the difference between currentWorkers and the count observed
// initially (i.e. the number of workers we actually created).
hwm.activeWorkers.Add(currentWorkers - workersWas)
}
}
}
// onDemandWorker processes handoff requests and exits when idle
func (hwm *handoffWorkerManager) onDemandWorker() {
defer func() {
// Handle panics to ensure proper cleanup
if r := recover(); r != nil {
internal.Logger.Printf(context.Background(), logs.WorkerPanicRecovered(r))
}
// Decrement active worker count when exiting
hwm.activeWorkers.Add(-1)
hwm.workerWg.Done()
}()
// Create reusable timer to prevent timer leaks
timer := time.NewTimer(hwm.workerTimeout)
defer timer.Stop()
for {
// Reset timer for next iteration
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(hwm.workerTimeout)
select {
case <-hwm.shutdown:
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdown())
}
return
case <-timer.C:
// Worker has been idle for too long, exit to save resources
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout))
}
return
case request := <-hwm.handoffQueue:
// Check for shutdown before processing
select {
case <-hwm.shutdown:
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing())
}
// Clean up the request before exiting
hwm.pending.Delete(request.ConnID)
return
default:
// Process the request
hwm.processHandoffRequest(request)
}
}
}
}
// processHandoffRequest processes a single handoff request
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
// Remove from pending map
defer hwm.pending.Delete(request.Conn.GetID())
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint))
}
// Create a context with handoff timeout from config
handoffTimeout := 15 * time.Second // Default timeout
if hwm.config != nil && hwm.config.HandoffTimeout > 0 {
handoffTimeout = hwm.config.HandoffTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
defer cancel()
// Create a context that also respects the shutdown signal
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
defer shutdownCancel()
// Monitor shutdown signal in a separate goroutine
go func() {
select {
case <-hwm.shutdown:
shutdownCancel()
case <-shutdownCtx.Done():
}
}()
// Perform the handoff with cancellable context
shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn)
minRetryBackoff := 500 * time.Millisecond
if err != nil {
if shouldRetry {
now := time.Now()
deadline, ok := shutdownCtx.Deadline()
thirdOfTimeout := handoffTimeout / 3
if !ok || deadline.Before(now) {
// wait a third of the timeout before retrying if there is no deadline or it has already passed
deadline = now.Add(thirdOfTimeout)
}
afterTime := deadline.Sub(now)
if afterTime < minRetryBackoff {
afterTime = minRetryBackoff
}
if internal.LogLevel.InfoOrAbove() {
// Get current retry count for better logging
currentRetries := request.Conn.HandoffRetries()
maxRetries := 3 // Default fallback
if hwm.config != nil {
maxRetries = hwm.config.MaxHandoffRetries
}
internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err))
}
time.AfterFunc(afterTime, func() {
if err := hwm.queueHandoff(request.Conn); err != nil {
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err))
}
hwm.closeConnFromRequest(context.Background(), request, err)
}
})
return
} else {
go hwm.closeConnFromRequest(ctx, request, err)
}
// Clear handoff state if not returned for retry
seqID := request.Conn.GetMovingSeqID()
connID := request.Conn.GetID()
if hwm.poolHook.operationsManager != nil {
hwm.poolHook.operationsManager.UntrackOperationWithConnID(seqID, connID)
}
}
}
// queueHandoff queues a handoff request for processing
// if err is returned, connection will be removed from pool
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
// Get handoff info atomically to prevent race conditions
shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
// On retries the connection will not be marked for handoff, but it will have retries > 0.
// If shouldHandoff is false and retries is 0, this is neither a retry nor a new handoff, so reject it.
if !shouldHandoff && conn.HandoffRetries() == 0 {
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID()))
}
return errors.New(logs.ConnectionNotMarkedForHandoffError(conn.GetID()))
}
// Create handoff request with atomically retrieved data
request := HandoffRequest{
Conn: conn,
ConnID: conn.GetID(),
Endpoint: endpoint,
SeqID: seqID,
Pool: hwm.poolHook.pool, // Include pool for connection removal on failure
}
select {
// priority to shutdown
case <-hwm.shutdown:
return ErrShutdown
default:
select {
case <-hwm.shutdown:
return ErrShutdown
case hwm.handoffQueue <- request:
// Store in pending map
hwm.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
hwm.ensureWorkerAvailable()
return nil
default:
select {
case <-hwm.shutdown:
return ErrShutdown
case hwm.handoffQueue <- request:
// Store in pending map
hwm.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
hwm.ensureWorkerAvailable()
return nil
case <-time.After(100 * time.Millisecond): // give workers a chance to process
// Queue is full - log and attempt scaling
queueLen := len(hwm.handoffQueue)
queueCap := cap(hwm.handoffQueue)
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap))
}
}
}
}
// Ensure we have workers available to handle the load
hwm.ensureWorkerAvailable()
return ErrHandoffQueueFull
}
// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete
func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error {
hwm.shutdownOnce.Do(func() {
close(hwm.shutdown)
// workers will exit when they finish their current request
// Shutdown circuit breaker manager cleanup goroutine
if hwm.circuitBreakerManager != nil {
hwm.circuitBreakerManager.Shutdown()
}
})
// Wait for workers to complete
done := make(chan struct{})
go func() {
hwm.workerWg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// performConnectionHandoff performs the actual connection handoff
// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached
func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) {
// Resolve the connection ID and the handoff target endpoint
connID := conn.GetID()
newEndpoint := conn.GetHandoffEndpoint()
if newEndpoint == "" {
return false, ErrConnectionInvalidHandoffState
}
// Use circuit breaker to protect against failing endpoints
circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint)
// Check if circuit breaker is open before attempting handoff
if circuitBreaker.IsOpen() {
internal.Logger.Printf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint))
return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open
}
// Perform the handoff
shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID)
// Update circuit breaker based on result
if err != nil {
// Only track dial/network errors in circuit breaker, not initialization errors
if shouldRetry {
circuitBreaker.recordFailure()
}
return shouldRetry, err
}
// Success - record in circuit breaker
circuitBreaker.recordSuccess()
return false, nil
}
// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) {
retries := conn.IncrementAndGetHandoffRetries(1)
internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String()))
maxRetries := 3 // Default fallback
if hwm.config != nil {
maxRetries = hwm.config.MaxHandoffRetries
}
if retries > maxRetries {
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries))
}
// won't retry on ErrMaxHandoffRetriesReached
return false, ErrMaxHandoffRetriesReached
}
// Create endpoint-specific dialer
endpointDialer := hwm.createEndpointDialer(newEndpoint)
// Create new connection to the new endpoint
newNetConn, err := endpointDialer(ctx)
if err != nil {
internal.Logger.Printf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err))
// will retry
// Maybe a network error - retry after a delay
return true, err
}
// Get the old connection
oldConn := conn.GetNetConn()
// Apply relaxed timeout to the new connection for the configured post-handoff duration
// This gives the new connection more time to handle operations during cluster transition
// Setting this here (before initing the connection) ensures that the connection is going
// to use the relaxed timeout for the first operation (auth/ACL select)
if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 {
relaxedTimeout := hwm.config.RelaxedTimeout
// Set relaxed timeout with deadline - no background goroutine needed
deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration)
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000")))
}
}
// Replace the connection and execute initialization
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
if err != nil {
// won't retry
// Initialization failed - remove the connection
return false, err
}
defer func() {
if oldConn != nil {
oldConn.Close()
}
}()
conn.ClearHandoffState()
internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint))
return false, nil
}
// createEndpointDialer creates a dialer function that connects to a specific endpoint
func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
// Parse endpoint to extract host and port
host, port, err := net.SplitHostPort(endpoint)
if err != nil {
// No port was specified; fall back to the default Redis port
host = endpoint
port = "6379"
}
// Use the base dialer to connect to the new endpoint
return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port))
}
}
// closeConnFromRequest closes the connection and logs the reason
func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
pooler := request.Pool
conn := request.Conn
if pooler != nil {
pooler.Remove(ctx, conn, err)
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err))
}
} else {
conn.Close()
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err))
}
}
}

View File

@@ -0,0 +1,60 @@
package maintnotifications
import (
"context"
"slices"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// LoggingHook is an example hook implementation that logs all notifications.
type LoggingHook struct {
LogLevel int // 0=Error, 1=Warn, 2=Info, 3=Debug
}
// PreHook logs the notification before processing and allows modification.
func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
if lh.LogLevel >= 2 { // Info level
// Log the notification type and content
connID := uint64(0)
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
connID = conn.GetID()
}
seqID := int64(0)
if slices.Contains(maintenanceNotificationTypes, notificationType) {
// seqID is the second element in the notification array
if len(notification) > 1 {
if parsedSeqID, ok := notification[1].(int64); ok {
seqID = parsedSeqID
}
}
}
internal.Logger.Printf(ctx, logs.ProcessingNotification(connID, seqID, notificationType, notification))
}
return notification, true // Continue processing with unmodified notification
}
// PostHook logs the result after processing.
func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
connID := uint64(0)
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
connID = conn.GetID()
}
if result != nil && lh.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx, logs.ProcessingNotificationFailed(connID, notificationType, result, notification))
} else if lh.LogLevel >= 3 { // Debug level
internal.Logger.Printf(ctx, logs.ProcessingNotificationSucceeded(connID, notificationType))
}
}
// NewLoggingHook creates a new logging hook with the specified log level.
// Log levels: 0=Error, 1=Warn, 2=Info, 3=Debug
func NewLoggingHook(logLevel int) *LoggingHook {
return &LoggingHook{LogLevel: logLevel}
}
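
// attachDebugLogging is a minimal usage sketch (assuming an existing *Manager,
// e.g. one created via NewManager): it attaches a debug-level logging hook so
// every maintenance notification is logged before and after processing.
func attachDebugLogging(manager *Manager) {
    manager.AddNotificationHook(NewLoggingHook(3)) // 3 = Debug, per the levels documented above
}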

View File

@@ -0,0 +1,320 @@
package maintnotifications
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// Push notification type constants for maintenance
const (
NotificationMoving = "MOVING"
NotificationMigrating = "MIGRATING"
NotificationMigrated = "MIGRATED"
NotificationFailingOver = "FAILING_OVER"
NotificationFailedOver = "FAILED_OVER"
)
// maintenanceNotificationTypes contains all notification types that maintenance handles
var maintenanceNotificationTypes = []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
// NotificationHook is called before and after notification processing.
// PreHook can modify the notification and return false to skip processing.
// PostHook is called after processing completes and receives the result, including any error.
type NotificationHook interface {
PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool)
PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error)
}
// MovingOperationKey provides a unique key for tracking MOVING operations
// that combines sequence ID with connection identifier to handle duplicate
// sequence IDs across multiple connections to the same node.
type MovingOperationKey struct {
SeqID int64 // Sequence ID from MOVING notification
ConnID uint64 // Unique connection identifier
}
// String returns a string representation of the key for debugging
func (k MovingOperationKey) String() string {
return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
}
// Manager provides simplified upgrade functionality with hooks and atomic state tracking.
type Manager struct {
client interfaces.ClientInterface
config *Config
options interfaces.OptionsInterface
pool pool.Pooler
// MOVING operation tracking - using sync.Map for better concurrent performance
activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation
// Atomic state tracking - no locks needed for state queries
activeOperationCount atomic.Int64 // Number of active operations
closed atomic.Bool // Manager closed state
// Notification hooks for extensibility
hooks []NotificationHook
hooksMu sync.RWMutex // Protects hooks slice
poolHooksRef *PoolHook
}
// MovingOperation tracks an active MOVING operation.
type MovingOperation struct {
SeqID int64
NewEndpoint string
StartTime time.Time
Deadline time.Time
}
// NewManager creates a new simplified manager.
func NewManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*Manager, error) {
if client == nil {
return nil, ErrInvalidClient
}
hm := &Manager{
client: client,
pool: pool,
options: client.GetOptions(),
config: config.Clone(),
hooks: make([]NotificationHook, 0),
}
// Set up push notification handling
if err := hm.setupPushNotifications(); err != nil {
return nil, err
}
return hm, nil
}
// InitPoolHook creates a pool hook with the given base dialer and registers it with the pool.
func (hm *Manager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
poolHook := hm.createPoolHook(baseDialer)
hm.pool.AddPoolHook(poolHook)
}
// setupPushNotifications sets up push notification handling by registering with the client's processor.
func (hm *Manager) setupPushNotifications() error {
processor := hm.client.GetPushProcessor()
if processor == nil {
return ErrInvalidClient // Client doesn't support push notifications
}
// Create our notification handler
handler := &NotificationHandler{manager: hm, operationsManager: hm}
// Register handlers for all upgrade notifications with the client's processor
for _, notificationType := range maintenanceNotificationTypes {
if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
return errors.New(logs.FailedToRegisterHandler(notificationType, err))
}
}
return nil
}
// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
func (hm *Manager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Create MOVING operation record
movingOp := &MovingOperation{
SeqID: seqID,
NewEndpoint: newEndpoint,
StartTime: time.Now(),
Deadline: deadline,
}
// Use LoadOrStore for atomic check-and-set operation
if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
// Duplicate MOVING notification, ignore
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID))
}
return nil
}
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID))
}
// Increment active operation count atomically
hm.activeOperationCount.Add(1)
return nil
}
// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
func (hm *Manager) UntrackOperationWithConnID(seqID int64, connID uint64) {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Remove from active operations atomically
if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.UntrackingMovingOperation(connID, seqID))
}
// Decrement active operation count only if operation existed
hm.activeOperationCount.Add(-1)
} else {
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.OperationNotTracked(connID, seqID))
}
}
}
// GetActiveMovingOperations returns active operations with composite keys.
// WARNING: This method creates a new map and copies all operations on every call.
// Use sparingly, especially in hot paths or high-frequency logging.
func (hm *Manager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
result := make(map[MovingOperationKey]*MovingOperation)
// Iterate over sync.Map to build result
hm.activeMovingOps.Range(func(key, value interface{}) bool {
k := key.(MovingOperationKey)
op := value.(*MovingOperation)
// Create a copy to avoid sharing references
result[k] = &MovingOperation{
SeqID: op.SeqID,
NewEndpoint: op.NewEndpoint,
StartTime: op.StartTime,
Deadline: op.Deadline,
}
return true // Continue iteration
})
return result
}
// IsHandoffInProgress returns true if any handoff is in progress.
// Uses atomic counter for lock-free operation.
func (hm *Manager) IsHandoffInProgress() bool {
return hm.activeOperationCount.Load() > 0
}
// GetActiveOperationCount returns the number of active operations.
// Uses atomic counter for lock-free operation.
func (hm *Manager) GetActiveOperationCount() int64 {
return hm.activeOperationCount.Load()
}
// Close closes the manager.
func (hm *Manager) Close() error {
// Use atomic operation for thread-safe close check
if !hm.closed.CompareAndSwap(false, true) {
return nil // Already closed
}
// Shutdown the pool hook if it exists
if hm.poolHooksRef != nil {
// Use a timeout to prevent hanging indefinitely
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := hm.poolHooksRef.Shutdown(shutdownCtx)
if err != nil {
// was not able to close pool hook, keep closed state false
hm.closed.Store(false)
return err
}
// Remove the pool hook from the pool
if hm.pool != nil {
hm.pool.RemovePoolHook(hm.poolHooksRef)
}
}
// Clear all active operations
hm.activeMovingOps.Range(func(key, value interface{}) bool {
hm.activeMovingOps.Delete(key)
return true
})
// Reset counter
hm.activeOperationCount.Store(0)
return nil
}
// GetState returns current state using atomic counter for lock-free operation.
func (hm *Manager) GetState() State {
if hm.activeOperationCount.Load() > 0 {
return StateMoving
}
return StateIdle
}
// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
func (hm *Manager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
currentNotification := notification
for _, hook := range hm.hooks {
modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification)
if !shouldContinue {
return modifiedNotification, false
}
currentNotification = modifiedNotification
}
return currentNotification, true
}
// processPostHooks calls all post-hooks with the processing result.
func (hm *Manager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
for _, hook := range hm.hooks {
hook.PostHook(ctx, notificationCtx, notificationType, notification, result)
}
}
// createPoolHook creates a pool hook with this manager already set.
func (hm *Manager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
if hm.poolHooksRef != nil {
return hm.poolHooksRef
}
// Get pool size from client options for better worker defaults
poolSize := 0
if hm.options != nil {
poolSize = hm.options.GetPoolSize()
}
hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
hm.poolHooksRef.SetPool(hm.pool)
return hm.poolHooksRef
}
func (hm *Manager) AddNotificationHook(notificationHook NotificationHook) {
hm.hooksMu.Lock()
defer hm.hooksMu.Unlock()
hm.hooks = append(hm.hooks, notificationHook)
}
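
// countingHook is a minimal illustrative NotificationHook, not part of the
// original file: it counts notifications per type and can be registered via
// AddNotificationHook. It assumes hooks may be invoked concurrently and
// therefore guards its map with a mutex.
type countingHook struct {
    mu     sync.Mutex
    counts map[string]int
}

func (ch *countingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
    ch.mu.Lock()
    defer ch.mu.Unlock()
    if ch.counts == nil {
        ch.counts = make(map[string]int)
    }
    ch.counts[notificationType]++
    return notification, true // continue processing with the unmodified notification
}

func (ch *countingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
    // No-op: counting happens in PreHook.
}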

View File

@@ -0,0 +1,260 @@
package maintnotifications
import (
"context"
"net"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/interfaces"
)
// MockClient implements interfaces.ClientInterface for testing
type MockClient struct {
options interfaces.OptionsInterface
}
func (mc *MockClient) GetOptions() interfaces.OptionsInterface {
return mc.options
}
func (mc *MockClient) GetPushProcessor() interfaces.NotificationProcessor {
return &MockPushProcessor{}
}
// MockPushProcessor implements interfaces.NotificationProcessor for testing
type MockPushProcessor struct{}
func (mpp *MockPushProcessor) RegisterHandler(notificationType string, handler interface{}, protected bool) error {
return nil
}
func (mpp *MockPushProcessor) UnregisterHandler(pushNotificationName string) error {
return nil
}
func (mpp *MockPushProcessor) GetHandler(pushNotificationName string) interface{} {
return nil
}
// MockOptions implements interfaces.OptionsInterface for testing
type MockOptions struct{}
func (mo *MockOptions) GetReadTimeout() time.Duration {
return 5 * time.Second
}
func (mo *MockOptions) GetWriteTimeout() time.Duration {
return 5 * time.Second
}
func (mo *MockOptions) GetAddr() string {
return "localhost:6379"
}
func (mo *MockOptions) IsTLSEnabled() bool {
return false
}
func (mo *MockOptions) GetProtocol() int {
return 3 // RESP3
}
func (mo *MockOptions) GetPoolSize() int {
return 10
}
func (mo *MockOptions) GetNetwork() string {
return "tcp"
}
func (mo *MockOptions) NewDialer() func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
return nil, nil
}
}
func TestManagerRefactoring(t *testing.T) {
t.Run("AtomicStateTracking", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create maintnotifications manager: %v", err)
}
defer manager.Close()
// Test initial state
if manager.IsHandoffInProgress() {
t.Error("Expected no handoff in progress initially")
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateIdle {
t.Errorf("Expected StateIdle, got %v", manager.GetState())
}
// Add an operation
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
err = manager.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Failed to track operation: %v", err)
}
// Test state after adding operation
if !manager.IsHandoffInProgress() {
t.Error("Expected handoff in progress after adding operation")
}
if manager.GetActiveOperationCount() != 1 {
t.Errorf("Expected 1 active operation, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateMoving {
t.Errorf("Expected StateMoving, got %v", manager.GetState())
}
// Remove the operation
manager.UntrackOperationWithConnID(12345, 1)
// Test state after removing operation
if manager.IsHandoffInProgress() {
t.Error("Expected no handoff in progress after removing operation")
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount())
}
if manager.GetState() != StateIdle {
t.Errorf("Expected StateIdle, got %v", manager.GetState())
}
})
t.Run("SyncMapPerformance", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create maintnotifications manager: %v", err)
}
defer manager.Close()
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
// Test concurrent operations
const numOps = 100
for i := 0; i < numOps; i++ {
err := manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, int64(i), uint64(i))
if err != nil {
t.Fatalf("Failed to track operation %d: %v", i, err)
}
}
if manager.GetActiveOperationCount() != numOps {
t.Errorf("Expected %d active operations, got %d", numOps, manager.GetActiveOperationCount())
}
// Test GetActiveMovingOperations
operations := manager.GetActiveMovingOperations()
if len(operations) != numOps {
t.Errorf("Expected %d operations in map, got %d", numOps, len(operations))
}
// Remove all operations
for i := 0; i < numOps; i++ {
manager.UntrackOperationWithConnID(int64(i), uint64(i))
}
if manager.GetActiveOperationCount() != 0 {
t.Errorf("Expected 0 active operations after cleanup, got %d", manager.GetActiveOperationCount())
}
})
t.Run("DuplicateOperationHandling", func(t *testing.T) {
config := DefaultConfig()
client := &MockClient{options: &MockOptions{}}
manager, err := NewManager(client, nil, config)
if err != nil {
t.Fatalf("Failed to create maintnotifications manager: %v", err)
}
defer manager.Close()
ctx := context.Background()
deadline := time.Now().Add(30 * time.Second)
// Add operation
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Failed to track operation: %v", err)
}
// Try to add duplicate operation
err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1)
if err != nil {
t.Fatalf("Duplicate operation should not return error: %v", err)
}
// Should still have only 1 operation
if manager.GetActiveOperationCount() != 1 {
t.Errorf("Expected 1 active operation after duplicate, got %d", manager.GetActiveOperationCount())
}
})
t.Run("NotificationTypeConstants", func(t *testing.T) {
// Test that constants are properly defined
expectedTypes := []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
if len(maintenanceNotificationTypes) != len(expectedTypes) {
t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(maintenanceNotificationTypes))
}
// Test that all expected types are present
typeMap := make(map[string]bool)
for _, nt := range maintenanceNotificationTypes {
typeMap[nt] = true
}
for _, expected := range expectedTypes {
if !typeMap[expected] {
t.Errorf("Expected notification type %s not found in maintenanceNotificationTypes", expected)
}
}
// Test that maintenanceNotificationTypes contains all expected constants
expectedConstants := []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
for _, expected := range expectedConstants {
found := false
for _, actual := range maintenanceNotificationTypes {
if actual == expected {
found = true
break
}
}
if !found {
t.Errorf("Expected constant %s not found in maintenanceNotificationTypes", expected)
}
}
})
}

View File

@@ -0,0 +1,180 @@
package maintnotifications
import (
"context"
"net"
"sync"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
)
// OperationsManagerInterface defines the interface for completing handoff operations
type OperationsManagerInterface interface {
TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
UntrackOperationWithConnID(seqID int64, connID uint64)
}
// HandoffRequest represents a request to handoff a connection to a new endpoint
type HandoffRequest struct {
Conn *pool.Conn
ConnID uint64 // Unique connection identifier
Endpoint string
SeqID int64
Pool pool.Pooler // Pool to remove connection from on failure
}
// PoolHook implements pool.PoolHook for Redis-specific connection handling
// with maintenance notifications support.
type PoolHook struct {
// Base dialer for creating connections to new endpoints during handoffs
// args are network and address
baseDialer func(context.Context, string, string) (net.Conn, error)
// Network type (e.g., "tcp", "unix")
network string
// Worker manager for background handoff processing
workerManager *handoffWorkerManager
// Configuration for the maintenance notifications
config *Config
// Operations manager interface for operation completion tracking
operationsManager OperationsManagerInterface
// Pool interface for removing connections on handoff failure
pool pool.Pooler
}
// NewPoolHook creates a new pool hook
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface) *PoolHook {
return NewPoolHookWithPoolSize(baseDialer, network, config, operationsManager, 0)
}
// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface, poolSize int) *PoolHook {
// Apply defaults if config is nil or has zero values
if config == nil {
config = config.ApplyDefaultsWithPoolSize(poolSize)
}
ph := &PoolHook{
// baseDialer is used to create connections to new endpoints during handoffs
baseDialer: baseDialer,
network: network,
config: config,
operationsManager: operationsManager,
}
// Create worker manager
ph.workerManager = newHandoffWorkerManager(config, ph)
return ph
}
// SetPool sets the pool interface for removing connections on handoff failure
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
ph.pool = pooler
}
// GetCurrentWorkers returns the current number of active workers (for testing)
func (ph *PoolHook) GetCurrentWorkers() int {
return ph.workerManager.getCurrentWorkers()
}
// IsHandoffPending returns true if the given connection has a pending handoff
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
return ph.workerManager.isHandoffPending(conn)
}
// GetPendingMap returns the pending map for testing purposes
func (ph *PoolHook) GetPendingMap() *sync.Map {
return ph.workerManager.getPendingMap()
}
// GetMaxWorkers returns the max workers for testing purposes
func (ph *PoolHook) GetMaxWorkers() int {
return ph.workerManager.getMaxWorkers()
}
// GetHandoffQueue returns the handoff queue for testing purposes
func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest {
return ph.workerManager.getHandoffQueue()
}
// GetCircuitBreakerStats returns circuit breaker statistics for monitoring
func (ph *PoolHook) GetCircuitBreakerStats() []CircuitBreakerStats {
return ph.workerManager.getCircuitBreakerStats()
}
// ResetCircuitBreakers resets all circuit breakers (useful for testing)
func (ph *PoolHook) ResetCircuitBreakers() {
ph.workerManager.resetCircuitBreakers()
}
// OnGet is called when a connection is retrieved from the pool
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error {
// NOTE: Two checks below ensure we never return a connection that is marked for
// handoff or is currently in a handoff state.
// Check that the connection is usable (i.e. not mid-handoff).
// This should not happen, since the pool will not return an unusable connection.
if !conn.IsUsable() {
return ErrConnectionMarkedForHandoff
}
// Check if connection is marked for handoff, which means it will be queued for handoff on put.
if conn.ShouldHandoff() {
return ErrConnectionMarkedForHandoff
}
return nil
}
// OnPut is called when a connection is returned to the pool
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
// First check whether the connection needs a handoff, so non-handoff puts are handled quickly.
if !conn.ShouldHandoff() {
// Default behavior (no handoff): pool the connection
return true, false, nil
}
// Check for a pending handoff so the same connection is not queued twice.
if ph.workerManager.isHandoffPending(conn) {
// Default behavior (pending handoff): pool the connection
return true, false, nil
}
if err := ph.workerManager.queueHandoff(conn); err != nil {
// Failed to queue handoff, remove the connection
internal.Logger.Printf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err))
// Don't pool, remove connection, no error to caller
return false, true, nil
}
// Check if handoff was already processed by a worker before we can mark it as queued
if !conn.ShouldHandoff() {
// Handoff was already processed - this is normal and the connection should be pooled
return true, false, nil
}
if err := conn.MarkQueuedForHandoff(); err != nil {
// If marking fails, check if handoff was processed in the meantime
if !conn.ShouldHandoff() {
// Handoff was processed - this is normal, pool the connection
return true, false, nil
}
// Other error - remove the connection
return false, true, nil
}
internal.Logger.Printf(ctx, logs.MarkedForHandoff(conn.GetID()))
return true, false, nil
}
// Shutdown gracefully shuts down the processor, waiting for workers to complete
func (ph *PoolHook) Shutdown(ctx context.Context) error {
return ph.workerManager.shutdownWorkers(ctx)
}
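
// newPoolHookExample is an illustrative sketch added for this write-up, not part
// of the upstream file. It shows how the hook might be wired into a pool, using
// only the APIs visible above (NewPoolHook, SetPool, AddPoolHook, Shutdown);
// connPool stands in for any pool.Pooler implementation.
func newPoolHookExample(connPool pool.Pooler) *PoolHook {
	// Plain TCP dialer; a real client would reuse its configured dialer.
	baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
		return (&net.Dialer{}).DialContext(ctx, network, addr)
	}
	hook := NewPoolHook(baseDialer, "tcp", nil, nil) // nil config: defaults are applied
	hook.SetPool(connPool)                           // lets the hook remove connections on handoff failure
	connPool.AddPoolHook(hook)                       // OnGet/OnPut now run on every checkout/checkin
	return hook
}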

View File

@@ -0,0 +1,954 @@
package maintnotifications
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/pool"
)
// mockNetConn implements net.Conn for testing
type mockNetConn struct {
addr string
shouldFailInit bool
}
func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil }
func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (m *mockNetConn) Close() error { return nil }
func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} }
func (m *mockNetConn) SetDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil }
type mockAddr struct {
addr string
}
func (m *mockAddr) Network() string { return "tcp" }
func (m *mockAddr) String() string { return m.addr }
// createMockPoolConnection creates a mock pool connection for testing
func createMockPoolConnection() *pool.Conn {
mockNetConn := &mockNetConn{addr: "test:6379"}
conn := pool.NewConn(mockNetConn)
conn.SetUsable(true) // Make connection usable for testing
return conn
}
// mockPool implements pool.Pooler for testing
type mockPool struct {
removedConnections map[uint64]bool
mu sync.Mutex
}
func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) {
return nil, errors.New("not implemented")
}
func (mp *mockPool) CloseConn(conn *pool.Conn) error {
return nil
}
func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) {
return nil, errors.New("not implemented")
}
func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) {
// Not implemented for testing
}
func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) {
mp.mu.Lock()
defer mp.mu.Unlock()
// Use pool.Conn directly - no adapter needed
mp.removedConnections[conn.GetID()] = true
}
// WasRemoved safely checks if a connection was removed from the pool
func (mp *mockPool) WasRemoved(connID uint64) bool {
mp.mu.Lock()
defer mp.mu.Unlock()
return mp.removedConnections[connID]
}
func (mp *mockPool) Len() int {
return 0
}
func (mp *mockPool) IdleLen() int {
return 0
}
func (mp *mockPool) Stats() *pool.Stats {
return &pool.Stats{}
}
func (mp *mockPool) AddPoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
func (mp *mockPool) RemovePoolHook(hook pool.PoolHook) {
// Mock implementation - do nothing
}
func (mp *mockPool) Close() error {
return nil
}
// TestConnectionHook tests the Redis connection processor functionality
func TestConnectionHook(t *testing.T) {
// Create a base dialer for testing
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) {
config := &Config{
Mode: ModeAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 1, // Use only 1 worker to ensure synchronization
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
MaxHandoffRetries: 3,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Verify connection is marked for handoff
if !conn.ShouldHandoff() {
t.Fatal("Connection should be marked for handoff")
}
// Set a mock initialization function with synchronization
initConnCalled := make(chan bool, 1)
proceedWithInit := make(chan bool, 1)
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
select {
case initConnCalled <- true:
default:
}
// Wait for test to proceed
<-proceedWithInit
return nil
}
conn.SetInitConnFunc(initConnFunc)
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
// Should pool the connection immediately (handoff queued)
if !shouldPool {
t.Error("Connection should be pooled immediately with event-driven handoff")
}
if shouldRemove {
t.Error("Connection should not be removed when queuing handoff")
}
// Wait for initialization to be called (indicates handoff started)
select {
case <-initConnCalled:
// Good, initialization was called
case <-time.After(1 * time.Second):
t.Fatal("Timeout waiting for initialization function to be called")
}
// Connection should be in pending map while initialization is blocked
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
t.Error("Connection should be in pending handoffs map")
}
// Allow initialization to proceed
proceedWithInit <- true
// Wait for handoff to complete with proper timeout and polling
timeout := time.After(2 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify handoff completed (removed from pending map)
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map after handoff")
}
// Verify connection is usable again
if !conn.IsUsable() {
t.Error("Connection should be usable after successful handoff")
}
// Verify handoff state is cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after completion")
}
})
t.Run("HandoffNotNeeded", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
// Don't mark for handoff
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error when handoff not needed: %v", err)
}
// Should pool the connection normally
if !shouldPool {
t.Error("Connection should be pooled when no handoff needed")
}
if shouldRemove {
t.Error("Connection should not be removed when no handoff needed")
}
})
t.Run("EmptyEndpoint", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error with empty endpoint: %v", err)
}
// Should pool the connection (empty endpoint clears state)
if !shouldPool {
t.Error("Connection should be pooled after clearing empty endpoint")
}
if shouldRemove {
t.Error("Connection should not be removed after clearing empty endpoint")
}
// State should be cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after clearing empty endpoint")
}
})
t.Run("EventDrivenHandoffDialerError", func(t *testing.T) {
// Create a failing base dialer
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("dial failed")
}
config := &Config{
Mode: ModeAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 2, // Reduced retries for faster test
HandoffTimeout: 500 * time.Millisecond, // Shorter timeout for faster test
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not return error to caller: %v", err)
}
// Should pool the connection initially (handoff queued)
if !shouldPool {
t.Error("Connection should be pooled initially with event-driven handoff")
}
if shouldRemove {
t.Error("Connection should not be removed when queuing handoff")
}
// Wait for handoff to complete and fail with proper timeout and polling
timeout := time.After(3 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
// wait for handoff to start
time.Sleep(50 * time.Millisecond)
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for failed handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Connection should be removed from pending map after failed handoff
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map after failed handoff")
}
// Wait for retries to complete (with MaxHandoffRetries=2, it will retry twice then give up)
// Each retry has a delay of handoffTimeout/2 = 250ms, so wait for all retries to complete
time.Sleep(800 * time.Millisecond)
// After max retries are reached, the connection should be removed from pool
// and handoff state should be cleared
if conn.ShouldHandoff() {
t.Error("Connection should not be marked for handoff after max retries reached")
}
t.Logf("EventDrivenHandoffDialerError test completed successfully")
})
t.Run("BufferedDataRESP2", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
// For this test, we'll just verify the logic works for connections without buffered data
// The actual buffered data detection is handled by the pool's connection health check
// which is outside the scope of the Redis connection processor
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
// Should pool the connection normally (no buffered data in mock)
if !shouldPool {
t.Error("Connection should be pooled when no buffered data")
}
if shouldRemove {
t.Error("Connection should not be removed when no buffered data")
}
})
t.Run("OnGet", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
conn := createMockPoolConnection()
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should not error for normal connection: %v", err)
}
})
t.Run("OnGetWithPendingHandoff", func(t *testing.T) {
config := &Config{
Mode: ModeAuto,
EndpointType: EndpointTypeAuto,
MaxWorkers: 2,
HandoffQueueSize: 10, // Explicit queue size to avoid a 0-size queue
MaxHandoffRetries: 3,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
// Simulate a pending handoff by marking for handoff and queuing
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
// Clean up
processor.GetPendingMap().Delete(conn.GetID())
})
t.Run("EventDrivenStateManagement", func(t *testing.T) {
processor := NewPoolHook(baseDialer, "tcp", nil, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
// Test initial state - no pending handoffs
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("New connection should not have pending handoffs")
}
// Test adding to pending map
conn.MarkForHandoff("new-endpoint:6379", 12345)
processor.GetPendingMap().Store(conn.GetID(), int64(12345)) // Store connID -> seqID
conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false)
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
t.Error("Connection should be in pending map")
}
// Test OnGet with pending handoff
ctx := context.Background()
err := processor.OnGet(ctx, conn, false)
if err != ErrConnectionMarkedForHandoff {
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
}
// Test removing from pending map and clearing handoff state
processor.GetPendingMap().Delete(conn.GetID())
if _, pending := processor.GetPendingMap().Load(conn.GetID()); pending {
t.Error("Connection should be removed from pending map")
}
// Clear handoff state to simulate completed handoff
conn.ClearHandoffState()
conn.SetUsable(true) // Make connection usable again
// Test OnGet without pending handoff
err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("Should not return error for non-pending connection: %v", err)
}
})
t.Run("EventDrivenQueueOptimization", func(t *testing.T) {
// Create processor with small queue to test optimization features
config := &Config{
MaxWorkers: 3,
HandoffQueueSize: 2, // Small queue to trigger optimizations
MaxHandoffRetries: 3,
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Add small delay to simulate network latency
time.Sleep(10 * time.Millisecond)
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create multiple connections that need handoff to fill the queue
connections := make([]*pool.Conn, 5)
for i := 0; i < 5; i++ {
connections[i] = createMockPoolConnection()
if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil {
t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err)
}
// Set a mock initialization function
connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
}
ctx := context.Background()
successCount := 0
// Process connections - should trigger scaling and timeout logic
for _, conn := range connections {
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Logf("OnPut returned error (expected with timeout): %v", err)
}
if shouldPool && !shouldRemove {
successCount++
}
}
// With timeout and scaling, most handoffs should eventually succeed
if successCount == 0 {
t.Error("Should have queued some handoffs with timeout and scaling")
}
t.Logf("Successfully queued %d handoffs with optimization features", successCount)
// Give time for workers to process and scaling to occur
time.Sleep(100 * time.Millisecond)
})
t.Run("WorkerScalingBehavior", func(t *testing.T) {
// Create processor with small queue to test scaling behavior
config := &Config{
MaxWorkers: 15, // Set to >= 10 to test explicit value preservation
HandoffQueueSize: 1,
MaxHandoffRetries: 3, // Very small queue to force scaling
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Verify initial worker count (should be 0 with on-demand workers)
if processor.GetCurrentWorkers() != 0 {
t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers())
}
if processor.GetMaxWorkers() != 15 {
t.Errorf("Expected maxWorkers=15, got %d", processor.GetMaxWorkers())
}
// The on-demand worker behavior creates workers only when needed
// This test just verifies the basic configuration is correct
t.Logf("On-demand worker configuration verified - Max: %d, Current: %d",
processor.GetMaxWorkers(), processor.GetCurrentWorkers())
})
t.Run("PassiveTimeoutRestoration", func(t *testing.T) {
// Create processor with fast post-handoff duration for testing
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3, // Allow retries for successful handoff
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing
RelaxedTimeout: 5 * time.Second,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
ctx := context.Background()
// Create a connection and trigger handoff
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("Handoff should succeed: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after handoff")
}
// Wait for handoff to complete with proper timeout and polling
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify relaxed timeout is set with deadline
if !conn.HasRelaxedTimeout() {
t.Error("Connection should have relaxed timeout after handoff")
}
// Test that timeout is still active before deadline
// We'll use HasRelaxedTimeout which internally checks the deadline
if !conn.HasRelaxedTimeout() {
t.Error("Connection should still have active relaxed timeout before deadline")
}
// Wait for deadline to pass
time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer
// Test that timeout is automatically restored after deadline
// HasRelaxedTimeout should return false after deadline passes
if conn.HasRelaxedTimeout() {
t.Error("Connection should not have active relaxed timeout after deadline")
}
// Additional verification: calling HasRelaxedTimeout again should still return false
// and should have cleared the internal timeout values
if conn.HasRelaxedTimeout() {
t.Error("Connection should not have relaxed timeout after deadline (second check)")
}
t.Logf("Passive timeout restoration test completed successfully")
})
t.Run("UsableFlagBehavior", func(t *testing.T) {
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
ctx := context.Background()
// Create a new connection without setting it usable
mockNetConn := &mockNetConn{addr: "test:6379"}
conn := pool.NewConn(mockNetConn)
// Initially, connection should not be usable (not initialized)
if conn.IsUsable() {
t.Error("New connection should not be usable before initialization")
}
// Simulate initialization by setting usable to true
conn.SetUsable(true)
if !conn.IsUsable() {
t.Error("Connection should be usable after initialization")
}
// OnGet should succeed for usable connection
err := processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed for usable connection: %v", err)
}
// Mark connection for handoff
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
// Connection should still be usable until queued, but marked for handoff
if !conn.IsUsable() {
t.Error("Connection should still be usable after being marked for handoff (until queued)")
}
if !conn.ShouldHandoff() {
t.Error("Connection should be marked for handoff")
}
// OnGet should fail for connection marked for handoff
err = processor.OnGet(ctx, conn, false)
if err == nil {
t.Error("OnGet should fail for connection marked for handoff")
}
if err != ErrConnectionMarkedForHandoff {
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
}
// Process the connection to trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should succeed: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after handoff")
}
// Wait for handoff to complete
time.Sleep(50 * time.Millisecond)
// After handoff completion, connection should be usable again
if !conn.IsUsable() {
t.Error("Connection should be usable after handoff completion")
}
// OnGet should succeed again
err = processor.OnGet(ctx, conn, false)
if err != nil {
t.Errorf("OnGet should succeed after handoff completion: %v", err)
}
t.Logf("Usable flag behavior test completed successfully")
})
t.Run("StaticQueueBehavior", func(t *testing.T) {
config := &Config{
MaxWorkers: 3,
HandoffQueueSize: 50, // Explicit static queue size
MaxHandoffRetries: 3,
}
processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100
defer processor.Shutdown(context.Background())
// Verify queue capacity matches configured size
queueCapacity := cap(processor.GetHandoffQueue())
if queueCapacity != 50 {
t.Errorf("Expected queue capacity 50, got %d", queueCapacity)
}
// Test that queue size is static regardless of pool size
// (No dynamic resizing should occur)
ctx := context.Background()
// Fill part of the queue
for i := 0; i < 10; i++ {
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil {
t.Fatalf("Failed to mark conn[%d] for handoff: %v", i, err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("Failed to queue handoff %d: %v", i, err)
}
if !shouldPool || shouldRemove {
t.Errorf("conn[%d] should be pooled after handoff (shouldPool=%v, shouldRemove=%v)",
i, shouldPool, shouldRemove)
}
}
// Verify queue capacity remains static (the main purpose of this test)
finalCapacity := cap(processor.GetHandoffQueue())
if finalCapacity != 50 {
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
}
// Note: We don't check queue size here because workers process items quickly
// The important thing is that the capacity remains static regardless of pool size
})
t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) {
// Create a failing dialer that will cause handoff initialization to fail
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Return a connection that will fail during initialization
return &mockNetConn{addr: addr, shouldFailInit: true}, nil
}
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3,
}
processor := NewPoolHook(failingDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create a mock pool that tracks removals
mockPool := &mockPool{removedConnections: make(map[uint64]bool)}
processor.SetPool(mockPool)
ctx := context.Background()
// Create a connection and mark it for handoff
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a failing initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return fmt.Errorf("initialization failed")
})
// Process the connection - handoff should fail and connection should be removed
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Errorf("OnPut should not error: %v", err)
}
if !shouldPool || shouldRemove {
t.Error("Connection should be pooled after failed handoff attempt")
}
// Wait for handoff to be attempted and fail
time.Sleep(100 * time.Millisecond)
// Verify that the connection was removed from the pool
if !mockPool.WasRemoved(conn.GetID()) {
t.Errorf("conn[%d] should have been removed from pool after handoff failure", conn.GetID())
}
t.Logf("Connection removal on handoff failure test completed successfully")
})
t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) {
// Create config with short post-handoff duration for testing
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
MaxHandoffRetries: 3, // Allow retries for successful handoff
RelaxedTimeout: 5 * time.Second,
PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing
}
baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
return &mockNetConn{addr: addr}, nil
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a mock initialization function
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
return nil
})
ctx := context.Background()
shouldPool, shouldRemove, err := processor.OnPut(ctx, conn)
if err != nil {
t.Fatalf("OnPut failed: %v", err)
}
if !shouldPool {
t.Error("Connection should be pooled after successful handoff")
}
if shouldRemove {
t.Error("Connection should not be removed after successful handoff")
}
// Wait for the handoff to complete (it happens asynchronously)
timeout := time.After(1 * time.Second)
ticker := time.NewTicker(5 * time.Millisecond)
defer ticker.Stop()
handoffCompleted := false
for !handoffCompleted {
select {
case <-timeout:
t.Fatal("Timeout waiting for handoff to complete")
case <-ticker.C:
if _, pending := processor.GetPendingMap().Load(conn.GetID()); !pending {
handoffCompleted = true
}
}
}
// Verify that relaxed timeout was applied to the new connection
if !conn.HasRelaxedTimeout() {
t.Error("New connection should have relaxed timeout applied after handoff")
}
// Wait for the post-handoff duration to expire
time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration
// Verify that relaxed timeout was automatically cleared
if conn.HasRelaxedTimeout() {
t.Error("Relaxed timeout should be automatically cleared after post-handoff duration")
}
})
t.Run("MarkForHandoff returns error when already marked", func(t *testing.T) {
conn := createMockPoolConnection()
// First mark should succeed
if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil {
t.Fatalf("First MarkForHandoff should succeed: %v", err)
}
// Second mark should fail
if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil {
t.Fatal("Second MarkForHandoff should return error")
} else if err.Error() != "connection is already marked for handoff" {
t.Fatalf("Expected specific error message, got: %v", err)
}
// Verify original handoff data is preserved
if !conn.ShouldHandoff() {
t.Fatal("Connection should still be marked for handoff")
}
if conn.GetHandoffEndpoint() != "new-endpoint:6379" {
t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint())
}
if conn.GetMovingSeqID() != 1 {
t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID())
}
})
t.Run("HandoffTimeoutConfiguration", func(t *testing.T) {
// Test that HandoffTimeout from config is actually used
customTimeout := 2 * time.Second
config := &Config{
MaxWorkers: 2,
HandoffQueueSize: 10,
HandoffTimeout: customTimeout, // Custom timeout
MaxHandoffRetries: 1, // Single retry to speed up test
}
processor := NewPoolHook(baseDialer, "tcp", config, nil)
defer processor.Shutdown(context.Background())
// Create a connection that will test the timeout
conn := createMockPoolConnection()
if err := conn.MarkForHandoff("test-endpoint:6379", 123); err != nil {
t.Fatalf("Failed to mark connection for handoff: %v", err)
}
// Set a dialer that will check the context timeout
var timeoutVerified int32 // Use atomic for thread safety
conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error {
// Check that the context has the expected timeout
deadline, ok := ctx.Deadline()
if !ok {
t.Error("Context should have a deadline")
return errors.New("no deadline")
}
// The deadline should be approximately customTimeout from now
expectedDeadline := time.Now().Add(customTimeout)
timeDiff := deadline.Sub(expectedDeadline)
if timeDiff < -500*time.Millisecond || timeDiff > 500*time.Millisecond {
t.Errorf("Context deadline not as expected. Expected around %v, got %v (diff: %v)",
expectedDeadline, deadline, timeDiff)
} else {
atomic.StoreInt32(&timeoutVerified, 1)
}
return nil // Successful handoff
})
// Trigger handoff
shouldPool, shouldRemove, err := processor.OnPut(context.Background(), conn)
if err != nil {
t.Errorf("OnPut should not return error: %v", err)
}
// Connection should be queued for handoff
if !shouldPool || shouldRemove {
t.Errorf("Connection should be pooled for handoff processing")
}
// Wait for handoff to complete
time.Sleep(500 * time.Millisecond)
if atomic.LoadInt32(&timeoutVerified) == 0 {
t.Error("HandoffTimeout was not properly applied to context")
}
t.Logf("HandoffTimeout configuration test completed successfully")
})
}

View File

@@ -0,0 +1,282 @@
package maintnotifications
import (
"context"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// NotificationHandler handles push notifications for the simplified manager.
type NotificationHandler struct {
manager *Manager
operationsManager OperationsManagerInterface
}
// HandlePushNotification processes push notifications with hook support.
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) == 0 {
internal.Logger.Printf(ctx, logs.InvalidNotificationFormat(notification))
return ErrInvalidNotification
}
notificationType, ok := notification[0].(string)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidNotificationTypeFormat(notification[0]))
return ErrInvalidNotification
}
// Process pre-hooks - they can modify the notification or skip processing
modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, handlerCtx, notificationType, notification)
if !shouldContinue {
return nil // Hooks decided to skip processing
}
var err error
switch notificationType {
case NotificationMoving:
err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
case NotificationMigrating:
err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
case NotificationMigrated:
err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
case NotificationFailingOver:
err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
case NotificationFailedOver:
err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
default:
// Ignore other notification types (e.g., pub/sub messages)
err = nil
}
// Process post-hooks with the result
snh.manager.processPostHooks(ctx, handlerCtx, notificationType, modifiedNotification, err)
return err
}
// handleMoving processes MOVING notifications.
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 3 {
internal.Logger.Printf(ctx, logs.InvalidNotification("MOVING", notification))
return ErrInvalidNotification
}
seqID, ok := notification[1].(int64)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1]))
return ErrInvalidNotification
}
// Extract timeS
timeS, ok := notification[2].(int64)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidTimeSInMovingNotification(notification[2]))
return ErrInvalidNotification
}
newEndpoint := ""
if len(notification) > 3 {
// Extract new endpoint
newEndpoint, ok = notification[3].(string)
if !ok {
stringified := fmt.Sprintf("%v", notification[3])
// this could be <nil> which is valid
if notification[3] == nil || stringified == internal.RedisNull {
newEndpoint = ""
} else {
internal.Logger.Printf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3]))
return ErrInvalidNotification
}
}
}
// Get the connection that received this notification
conn := handlerCtx.Conn
if conn == nil {
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MOVING"))
return ErrInvalidNotification
}
// Type assert to get the underlying pool connection
var poolConn *pool.Conn
if pc, ok := conn.(*pool.Conn); ok {
poolConn = pc
} else {
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx))
return ErrInvalidNotification
}
// If the connection is closed or not pooled, we can ignore the notification
// this connection won't be remembered by the pool and will be garbage collected
// Keep pubsub connections around since they are not pooled but are long-lived
// and should be allowed to handoff (the pubsub instance will reconnect and change
// the underlying *pool.Conn)
if (poolConn.IsClosed() || !poolConn.IsPooled()) && !poolConn.IsPubSub() {
return nil
}
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
if newEndpoint == "" || newEndpoint == internal.RedisNull {
if internal.LogLevel.DebugOrAbove() {
internal.Logger.Printf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2))
}
// same as current endpoint
newEndpoint = snh.manager.options.GetAddr()
// delay the handoff for timeS/2 seconds to the same endpoint
// do this in a goroutine to avoid blocking the notification handler
// NOTE: This timer is started while parsing the notification, so the connection is not marked for handoff
// and there should be no possibility of a race condition or double handoff.
time.AfterFunc(time.Duration(timeS/2)*time.Second, func() {
if poolConn == nil || poolConn.IsClosed() {
return
}
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
// Log error but don't fail the goroutine - use background context since original may be cancelled
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err))
}
})
return nil
}
return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
}
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err))
// Connection is already marked for handoff, which is acceptable
// This can happen if multiple MOVING notifications are received for the same connection
return nil
}
// Optionally track the operation in the operations manager, if one is configured
if snh.operationsManager != nil {
connID := conn.GetID()
// Track the operation (ignore errors since this is optional)
_ = snh.operationsManager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
} else {
return errors.New(logs.ManagerNotInitialized())
}
return nil
}
// handleMigrating processes MIGRATING notifications.
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATING notifications indicate that a connection is about to be migrated
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATING", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATING"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout))
}
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleMigrated processes MIGRATED notifications.
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATED notifications indicate that a connection migration has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATED", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATED"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
if internal.LogLevel.InfoOrAbove() {
connID := conn.GetID()
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
}
conn.ClearRelaxedTimeout()
return nil
}
// handleFailingOver processes FAILING_OVER notifications.
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILING_OVER notifications indicate that a connection is about to failover
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILING_OVER", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
if internal.LogLevel.InfoOrAbove() {
connID := conn.GetID()
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout))
}
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleFailedOver processes FAILED_OVER notifications.
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILED_OVER notifications indicate that a connection failover has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILED_OVER", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
if internal.LogLevel.InfoOrAbove() {
connID := conn.GetID()
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
}
conn.ClearRelaxedTimeout()
return nil
}
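
// movingPayloadExample is an illustrative sketch added for this write-up, not
// part of the upstream file. It restates the MOVING payload shape handled by
// handleMoving above - ["MOVING", seqID, timeS, endpoint], where endpoint may
// be nil or empty - using made-up values and only types already imported here.
func movingPayloadExample() {
	notification := []interface{}{"MOVING", int64(42), int64(30), "new-host:6379"}
	seqID := notification[1].(int64) // sequence ID used to track/untrack the operation
	timeS := notification[2].(int64) // seconds until the old endpoint goes away
	// handleMoving derives the handoff deadline from timeS...
	deadline := time.Now().Add(time.Duration(timeS) * time.Second)
	// ...and, when the endpoint is nil/empty, it re-schedules a handoff to the
	// current endpoint after timeS/2 seconds instead of acting immediately.
	_, _ = seqID, deadline
}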

View File

@@ -0,0 +1,24 @@
package maintnotifications
// State represents the current state of a maintenance operation
type State int
const (
// StateIdle indicates no upgrade is in progress
StateIdle State = iota
// StateMoving indicates a connection handoff (MOVING) is in progress
StateMoving
)
// String returns a string representation of the state.
func (s State) String() string {
switch s {
case StateIdle:
return "idle"
case StateMoving:
return "moving"
default:
return "unknown"
}
}
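
// Illustrative note, not part of the upstream file: the manager in this package
// reports these states via GetState(), as exercised by its tests, e.g.:
//
//	if manager.GetState() == StateMoving {
//		// at least one MOVING handoff is currently being tracked
//	}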

View File

@@ -16,6 +16,9 @@ import (
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/push"
)
// Limiter is the interface of a rate limiter or a circuit breaker.
@@ -109,6 +112,16 @@ type Options struct {
// default: 5 seconds
DialTimeout time.Duration
// DialerRetries is the maximum number of retry attempts when dialing fails.
//
// default: 5
DialerRetries int
// DialerRetryTimeout is the backoff duration between retry attempts.
//
// default: 100 milliseconds
DialerRetryTimeout time.Duration
// ReadTimeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking. Supported values:
//
@@ -152,6 +165,7 @@ type Options struct {
//
// Note that FIFO has slightly higher overhead compared to LIFO,
// but it helps closing idle connections faster reducing the pool size.
// default: false
PoolFIFO bool
// PoolSize is the base number of socket connections.
@@ -232,10 +246,24 @@ type Options struct {
// When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult
UnstableResp3 bool
// Push notifications are always enabled for RESP3 connections (Protocol: 3)
// and are not available for RESP2 connections. No configuration option is needed.
// PushNotificationProcessor is the processor for handling push notifications.
// If nil, a default processor will be created for RESP3 connections.
PushNotificationProcessor push.NotificationProcessor
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// MaintNotificationsConfig provides custom configuration for maintnotifications.
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, maintnotifications are in "auto" mode and will be enabled if the server supports it.
MaintNotificationsConfig *maintnotifications.Config
}
func (opt *Options) init() {
@@ -255,6 +283,12 @@ func (opt *Options) init() {
if opt.DialTimeout == 0 {
opt.DialTimeout = 5 * time.Second
}
if opt.DialerRetries == 0 {
opt.DialerRetries = 5
}
if opt.DialerRetryTimeout == 0 {
opt.DialerRetryTimeout = 100 * time.Millisecond
}
if opt.Dialer == nil {
opt.Dialer = NewDialer(opt)
}
@@ -312,13 +346,36 @@ func (opt *Options) init() {
case 0:
opt.MaxRetryBackoff = 512 * time.Millisecond
}
opt.MaintNotificationsConfig = opt.MaintNotificationsConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns)
// auto-detect endpoint type if not specified
endpointType := opt.MaintNotificationsConfig.EndpointType
if endpointType == "" || endpointType == maintnotifications.EndpointTypeAuto {
// Derive the endpoint type from the address and TLS configuration
endpointType = maintnotifications.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
}
opt.MaintNotificationsConfig.EndpointType = endpointType
}
func (opt *Options) clone() *Options {
clone := *opt
// Deep clone MaintNotificationsConfig to avoid sharing between clients
if opt.MaintNotificationsConfig != nil {
configClone := *opt.MaintNotificationsConfig
clone.MaintNotificationsConfig = &configClone
}
return &clone
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
return NewDialer(opt)
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
@@ -666,21 +723,84 @@ func getUserPassword(u *url.URL) (string, string) {
func newConnPool(
opt *Options,
dialer func(ctx context.Context, network, addr string) (net.Conn, error),
) *pool.ConnPool {
) (*pool.ConnPool, error) {
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
if err != nil {
return nil, err
}
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
if err != nil {
return nil, err
}
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
if err != nil {
return nil, err
}
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
if err != nil {
return nil, err
}
return pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return dialer(ctx, opt.Network, opt.Addr)
},
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
MinIdleConns: opt.MinIdleConns,
MaxIdleConns: opt.MaxIdleConns,
MaxActiveConns: opt.MaxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
})
PoolFIFO: opt.PoolFIFO,
PoolSize: poolSize,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
DialerRetries: opt.DialerRetries,
DialerRetryTimeout: opt.DialerRetryTimeout,
MinIdleConns: minIdleConns,
MaxIdleConns: maxIdleConns,
MaxActiveConns: maxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
PushNotificationsEnabled: opt.Protocol == 3,
}), nil
}
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
) (*pool.PubSubPool, error) {
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
if err != nil {
return nil, err
}
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
if err != nil {
return nil, err
}
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
if err != nil {
return nil, err
}
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
if err != nil {
return nil, err
}
return pool.NewPubSubPool(&pool.Options{
PoolFIFO: opt.PoolFIFO,
PoolSize: poolSize,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
DialerRetries: opt.DialerRetries,
DialerRetryTimeout: opt.DialerRetryTimeout,
MinIdleConns: minIdleConns,
MaxIdleConns: maxIdleConns,
MaxActiveConns: maxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: 32 * 1024,
WriteBufferSize: 32 * 1024,
PushNotificationsEnabled: opt.Protocol == 3,
}, dialer), nil
}
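
// Usage sketch (illustrative, not part of the diff above): enabling maintenance
// notifications on a client. It relies only on the options documented in this
// file (Protocol, MaintNotificationsConfig) and the exported maintnotifications
// constants used elsewhere in this commit; the address is a placeholder.
//
//	client := redis.NewClient(&redis.Options{
//		Addr:     "localhost:6379",
//		Protocol: 3, // RESP3 is required for push notifications
//		MaintNotificationsConfig: &maintnotifications.Config{
//			Mode:         maintnotifications.ModeAuto,
//			EndpointType: maintnotifications.EndpointTypeAuto,
//		},
//	})
//	defer client.Close()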

View File

@@ -20,6 +20,8 @@ import (
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/rand"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/push"
)
const (
@@ -38,6 +40,7 @@ type ClusterOptions struct {
ClientName string
// NewClient creates a cluster node client with provided name and options.
// If NewClient is set by the user, the user is responsible for handling maintnotifications upgrades and push notifications.
NewClient func(opt *Options) *Client
// The maximum number of retries before giving up. Command is retried
@@ -125,10 +128,22 @@ type ClusterOptions struct {
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
UnstableResp3 bool
// PushNotificationProcessor is the processor for handling push notifications.
// If nil, a default processor will be created for RESP3 connections.
PushNotificationProcessor push.NotificationProcessor
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// MaintNotificationsConfig provides custom configuration for maintnotifications upgrades.
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, maintnotifications upgrades are in "auto" mode and will be enabled if the server supports it.
// The ClusterClient does not directly work with maintnotifications, it is up to the clients in the Nodes map to work with maintnotifications.
MaintNotificationsConfig *maintnotifications.Config
}
func (opt *ClusterOptions) init() {
@@ -385,6 +400,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
}
func (opt *ClusterOptions) clientOptions() *Options {
// Clone MaintNotificationsConfig to avoid sharing between cluster node clients
var maintNotificationsConfig *maintnotifications.Config
if opt.MaintNotificationsConfig != nil {
configClone := *opt.MaintNotificationsConfig
maintNotificationsConfig = &configClone
}
return &Options{
ClientName: opt.ClientName,
Dialer: opt.Dialer,
@@ -426,8 +448,10 @@ func (opt *ClusterOptions) clientOptions() *Options {
// much use for ClusterSlots config). This means we cannot execute the
// READONLY command against that node -- setting readOnly to false in such
// situations in the options below will prevent that from happening.
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
MaintNotificationsConfig: maintNotificationsConfig,
PushNotificationProcessor: opt.PushNotificationProcessor,
}
}
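The cloning above is worth a second look: because each node client receives its own copy of the config, a later mutation through one node's options cannot leak into another's. A minimal, self-contained sketch of the same idea (cloneConfig is a hypothetical helper mirroring the inline copy):

package main

import (
	"fmt"

	"github.com/redis/go-redis/v9/maintnotifications"
)

// cloneConfig mirrors the inline copy in clientOptions: a shallow copy is
// enough for a flat config struct, and nil stays nil.
func cloneConfig(src *maintnotifications.Config) *maintnotifications.Config {
	if src == nil {
		return nil
	}
	c := *src
	return &c
}

func main() {
	shared := &maintnotifications.Config{}
	a, b := cloneConfig(shared), cloneConfig(shared)
	fmt.Println(a != b) // true: each node client gets its own instance
}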
@@ -1730,7 +1754,7 @@ func (c *ClusterClient) processTxPipelineNode(
}
func (c *ClusterClient) processTxPipelineNodeConn(
ctx context.Context, _ *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
) error {
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
@@ -1748,7 +1772,7 @@ func (c *ClusterClient) processTxPipelineNodeConn(
trimmedCmds := cmds[1 : len(cmds)-1]
if err := c.txPipelineReadQueued(
ctx, rd, statusCmd, trimmedCmds, failedCmds,
ctx, node, cn, rd, statusCmd, trimmedCmds, failedCmds,
); err != nil {
setCmdsErr(cmds, err)
@@ -1760,30 +1784,56 @@ func (c *ClusterClient) processTxPipelineNodeConn(
return err
}
return pipelineReadCmds(rd, trimmedCmds)
return node.Client.pipelineReadCmds(ctx, cn, rd, trimmedCmds)
})
}
func (c *ClusterClient) txPipelineReadQueued(
ctx context.Context,
node *clusterNode,
cn *pool.Conn,
rd *proto.Reader,
statusCmd *StatusCmd,
cmds []Cmder,
failedCmds *cmdsMap,
) error {
// Parse queued replies.
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
if err := statusCmd.readReply(rd); err != nil {
return err
}
for _, cmd := range cmds {
err := statusCmd.readReply(rd)
if err == nil || c.checkMovedErr(ctx, cmd, err, failedCmds) || isRedisError(err) {
continue
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
err := statusCmd.readReply(rd)
if err != nil {
if c.checkMovedErr(ctx, cmd, err, failedCmds) {
// will be processed later
continue
}
cmd.SetErr(err)
if !isRedisError(err) {
return err
}
}
return err
}
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
// Parse number of replies.
line, err := rd.ReadLine()
if err != nil {
@@ -1889,12 +1939,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
return err
}
// maintenance notifications won't work here for now
func (c *ClusterClient) pubSub() *PubSub {
var node *clusterNode
pubsub := &PubSub{
opt: c.opt.clientOptions(),
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
if node != nil {
panic("node != nil")
}
@@ -1928,18 +1978,25 @@ func (c *ClusterClient) pubSub() *PubSub {
return nil, err
}
}
cn, err := node.Client.newConn(context.TODO())
cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
if err != nil {
node = nil
return nil, err
}
// will return nil if already initialized
err = node.Client.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
node = nil
return nil, err
}
node.Client.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: func(cn *pool.Conn) error {
err := node.Client.connPool.CloseConn(cn)
// Untrack connection from PubSubPool
node.Client.pubSubPool.UntrackConn(cn)
err := cn.Close()
node = nil
return err
},
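The pooling and tracking above are internal; callers keep using the regular ClusterClient PubSub API. A minimal usage sketch (the address is illustrative):

package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewClusterClient(&redis.ClusterOptions{
		Addrs: []string{"127.0.0.1:7000"},
	})
	defer rdb.Close()

	pubsub := rdb.Subscribe(ctx, "news")
	defer pubsub.Close()

	// The connection behind this subscription now comes from the node
	// client's PubSub pool and is tracked/untracked as shown above.
	msg, err := pubsub.ReceiveMessage(ctx)
	if err != nil {
		panic(err)
	}
	fmt.Println(msg.Channel, msg.Payload)
}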

View File

@@ -30,9 +30,12 @@ type Pipeliner interface {
// If a certain Redis command is not yet supported, you can use Do to execute it.
Do(ctx context.Context, args ...interface{}) *Cmd
// Process puts the commands to be executed into the pipeline buffer.
// Process queues the cmd for later execution.
Process(ctx context.Context, cmd Cmder) error
// BatchProcess adds multiple commands to be executed into the pipeline buffer.
BatchProcess(ctx context.Context, cmd ...Cmder) error
// Discard discards all commands in the pipeline buffer that have not yet been executed.
Discard()
@@ -79,7 +82,12 @@ func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd {
// Process queues the cmd for later execution.
func (c *Pipeline) Process(ctx context.Context, cmd Cmder) error {
c.cmds = append(c.cmds, cmd)
return c.BatchProcess(ctx, cmd)
}
// BatchProcess queues multiple cmds for later execution.
func (c *Pipeline) BatchProcess(ctx context.Context, cmd ...Cmder) error {
c.cmds = append(c.cmds, cmd...)
return nil
}
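BatchProcess lets callers queue several prepared commands in one call instead of looping over Process. A short usage sketch against a single-node client (the address is illustrative):

package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
	defer rdb.Close()

	pipe := rdb.Pipeline()

	// Build the commands up front, then queue them with a single call.
	set := redis.NewStatusCmd(ctx, "set", "key", "value")
	get := redis.NewStringCmd(ctx, "get", "key")
	if err := pipe.BatchProcess(ctx, set, get); err != nil {
		panic(err)
	}

	if _, err := pipe.Exec(ctx); err != nil {
		panic(err)
	}
	fmt.Println(get.Val()) // "value"
}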

View File

@@ -60,6 +60,39 @@ var _ = Describe("pipelining", func() {
Expect(cmds).To(BeEmpty())
})
It("pipeline: basic exec", func() {
p := client.Pipeline()
p.Get(ctx, "key")
p.Set(ctx, "key", "value", 0)
p.Get(ctx, "key")
cmds, err := p.Exec(ctx)
Expect(err).To(Equal(redis.Nil))
Expect(cmds).To(HaveLen(3))
Expect(cmds[0].Err()).To(Equal(redis.Nil))
Expect(cmds[1].(*redis.StatusCmd).Val()).To(Equal("OK"))
Expect(cmds[1].Err()).NotTo(HaveOccurred())
Expect(cmds[2].(*redis.StringCmd).Val()).To(Equal("value"))
Expect(cmds[2].Err()).NotTo(HaveOccurred())
})
It("pipeline: exec pipeline when get conn failed", func() {
p := client.Pipeline()
p.Get(ctx, "key")
p.Set(ctx, "key", "value", 0)
p.Get(ctx, "key")
client.Close()
cmds, err := p.Exec(ctx)
Expect(err).To(Equal(redis.ErrClosed))
Expect(cmds).To(HaveLen(3))
for _, cmd := range cmds {
Expect(cmd.Err()).To(Equal(redis.ErrClosed))
}
client = redis.NewClient(redisOptions())
})
assertPipeline := func() {
It("returns no errors when there are no commands", func() {
_, err := pipe.Exec(ctx)
@@ -114,6 +147,25 @@ var _ = Describe("pipelining", func() {
err := pipe.Do(ctx).Err()
Expect(err).To(Equal(errors.New("redis: please enter the command to be executed")))
})
It("should process", func() {
err := pipe.Process(ctx, redis.NewCmd(ctx, "asking"))
Expect(err).To(BeNil())
Expect(pipe.Cmds()).To(HaveLen(1))
})
It("should batchProcess", func() {
err := pipe.BatchProcess(ctx, redis.NewCmd(ctx, "asking"))
Expect(err).To(BeNil())
Expect(pipe.Cmds()).To(HaveLen(1))
pipe.Discard()
Expect(pipe.Cmds()).To(HaveLen(0))
err = pipe.BatchProcess(ctx, redis.NewCmd(ctx, "asking"), redis.NewCmd(ctx, "set", "key", "value"))
Expect(err).To(BeNil())
Expect(pipe.Cmds()).To(HaveLen(2))
})
}
Describe("Pipeline", func() {

pool_pubsub_bench_test.go (new file, 375 lines)
View File

@@ -0,0 +1,375 @@
// Pool and PubSub Benchmark Suite
//
// This file contains comprehensive benchmarks for both pool operations and PubSub initialization.
// It's designed to be run against different branches to compare performance.
//
// Usage Examples:
// # Run all benchmarks
// go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go
//
// # Run only pool benchmarks
// go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go
//
// # Run only PubSub benchmarks
// go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go
//
// # Compare between branches
// git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt
// git checkout branch2 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch2.txt
// benchcmp branch1.txt branch2.txt
//
// # Run with memory profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go
//
// # Run with CPU profiling
// go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go
package redis_test
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/internal/pool"
)
// dummyDialer creates a mock connection for benchmarking
func dummyDialer(ctx context.Context) (net.Conn, error) {
return &dummyConn{}, nil
}
// dummyConn implements net.Conn for benchmarking
type dummyConn struct{}
func (c *dummyConn) Read(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *dummyConn) Close() error { return nil }
func (c *dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} }
func (c *dummyConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379}
}
func (c *dummyConn) SetDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetReadDeadline(t time.Time) error { return nil }
func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil }
// =============================================================================
// POOL BENCHMARKS
// =============================================================================
// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations
func BenchmarkPoolGetPut(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(poolSize),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0), // Start with no idle connections
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with MinIdleConns
func BenchmarkPoolGetPutWithMinIdle(b *testing.B) {
ctx := context.Background()
configs := []struct {
poolSize int
minIdleConns int
}{
{8, 2},
{16, 4},
{32, 8},
{64, 16},
}
for _, config := range configs {
b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(config.poolSize),
MinIdleConns: int32(config.minIdleConns),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency
func BenchmarkPoolConcurrentGetPut(b *testing.B) {
ctx := context.Background()
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(32),
PoolTimeout: time.Second,
DialTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
MinIdleConns: int32(0),
})
defer connPool.Close()
b.ResetTimer()
b.ReportAllocs()
// Test with different levels of concurrency
concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64}
for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.SetParallelism(concurrency)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
// =============================================================================
// PUBSUB BENCHMARKS
// =============================================================================
// benchmarkClient creates a Redis client for benchmarking. Note that no mock dialer is installed here, so the PubSub benchmarks expect a Redis server reachable at localhost:6379.
func benchmarkClient(poolSize int) *redis.Client {
return redis.NewClient(&redis.Options{
Addr:         "localhost:6379", // Local Redis address (no mock dialer is installed)
DialTimeout: time.Second,
ReadTimeout: time.Second,
WriteTimeout: time.Second,
PoolSize: poolSize,
MinIdleConns: 0, // Start with no idle connections for consistent benchmarks
})
}
// BenchmarkPubSubCreation benchmarks PubSub creation and subscription
func BenchmarkPubSubCreation(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 4, 8, 16, 32}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}
}
// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription
func BenchmarkPubSubPatternCreation(b *testing.B) {
ctx := context.Background()
poolSizes := []int{1, 4, 8, 16, 32}
for _, poolSize := range poolSizes {
b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) {
client := benchmarkClient(poolSize)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.PSubscribe(ctx, "test-*")
pubsub.Close()
}
})
}
}
// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation
func BenchmarkPubSubConcurrentCreation(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
concurrencyLevels := []int{1, 2, 4, 8, 16}
for _, concurrency := range concurrencyLevels {
b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
var wg sync.WaitGroup
semaphore := make(chan struct{}, concurrency)
for i := 0; i < b.N; i++ {
wg.Add(1)
semaphore <- struct{}{}
go func() {
defer wg.Done()
defer func() { <-semaphore }()
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}()
}
wg.Wait()
})
}
}
// BenchmarkPubSubMultipleChannels benchmarks subscribing to multiple channels
func BenchmarkPubSubMultipleChannels(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()
channelCounts := []int{1, 5, 10, 25, 50, 100}
for _, channelCount := range channelCounts {
b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) {
// Prepare channel names
channels := make([]string, channelCount)
for i := 0; i < channelCount; i++ {
channels[i] = fmt.Sprintf("channel-%d", i)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
pubsub := client.Subscribe(ctx, channels...)
pubsub.Close()
}
})
}
}
// BenchmarkPubSubReuse benchmarks reusing PubSub connections
func BenchmarkPubSubReuse(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(16)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Benchmark just the creation and closing of PubSub connections
// This simulates reuse patterns without requiring actual Redis operations
pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i))
pubsub.Close()
}
}
// =============================================================================
// COMBINED BENCHMARKS
// =============================================================================
// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations
func BenchmarkPoolAndPubSubMixed(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
	i := 0
	for pb.Next() {
		// Alternate between pool stats collection and PubSub creation
		// rather than consuming extra pb.Next() iterations inside the loop body.
		if i%2 == 0 {
			stats := client.PoolStats()
			_ = stats.Hits + stats.Misses // Use the stats to prevent optimization
		} else {
			pubsub := client.Subscribe(ctx, "test-channel")
			pubsub.Close()
		}
		i++
	}
})
}
// BenchmarkPoolStatsCollection benchmarks pool statistics collection
func BenchmarkPoolStatsCollection(b *testing.B) {
client := benchmarkClient(16)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
stats := client.PoolStats()
_ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization
}
}
// BenchmarkPoolHighContention tests pool performance under high contention
func BenchmarkPoolHighContention(b *testing.B) {
ctx := context.Background()
client := benchmarkClient(32)
defer client.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// High contention Get/Put operations
pubsub := client.Subscribe(ctx, "test-channel")
pubsub.Close()
}
})
}

Some files were not shown because too many files have changed in this diff.